Fix the dispatch failure when output C for addmm is not transposed #189

Status: Open. AlbertYang0112 wants to merge 1 commit into base: main.

Conversation

AlbertYang0112

Bug: Addmm dispatch failure when the output C is not transposed
How to reproduce:

  • Build the PyTorch Docker image with oneDNN and ACL
  • Run the following code in the container:
    import torch as th
    import torch.nn as nn

    # A Linear layer with a single output feature; its forward pass
    # dispatches to addmm. Per the report, with this shape the output C
    # is not transposed, which triggers the dispatch failure.
    l = nn.Linear(32, 1)
    l.eval()
    l(th.randn((1, 32)))


nSircombe self-assigned this on Aug 11, 2023
nSircombe added the bug label on Aug 11, 2023
Excerpt from the PR diff, with a review comment below:

+   } else {
+     mkldnn_matmul(a, b, c, beta.to<float>(), alpha.to<float>());
+   }
+   return;
+ }
+ #endif
  using opmath_t = at::opmath_type<scalar_t>;
  at::native::cpublas::gemm(
      transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,

diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
A Contributor commented on the diff:

I don't think any of the changes below this point are required.

@nSircombe

Hi @AlbertYang0112,

Thanks for your PR - I'll do some manual testing with it and get back to you.

@nSircombe

Hi @AlbertYang0112
The patch worked for me. However, we've not tested the oneDNN+ACL matmul backend for a non-transposed C, whereas the transposed case was tested fairly thoroughly for pytorch/pytorch#91763 (merged into PyTorch, but not yet in a released version, I believe).

I'd rather we simply pass on the non-transposed case and let it fall through to the non-oneDNN path.
Updating this patch to take an approach similar to the one added to aten/src/ATen/native/LinearAlgebra.cpp in pytorch/pytorch#91763, but changing L1425 to if (transpose_a && !transpose_b && transpose_c && result.scalar_type() == at::ScalarType::Float) { or similar, would be preferable.

Would you be able to update the PR along those lines?
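
For context, a minimal sketch of the kind of guard being suggested: take the oneDNN+ACL path only in the tested transposed-C case and let every other layout fall through to the existing cpublas::gemm path. This is an illustration only, written with the a/b/c names from the diff excerpt above rather than the names used in LinearAlgebra.cpp; the AT_MKLDNN_ACL_ENABLED() guard and the surrounding control flow are assumptions, not the actual patch.

  // Sketch only: dispatch to oneDNN+ACL solely for the case that has been
  // tested (A transposed, B not transposed, C transposed, float32); all
  // other layouts fall through to at::native::cpublas::gemm below.
  #if AT_MKLDNN_ACL_ENABLED()
  if (transpose_a && !transpose_b && transpose_c &&
      c.scalar_type() == at::ScalarType::Float) {
    mkldnn_matmul(a, b, c, beta.to<float>(), alpha.to<float>());
    return;
  }
  #endif
  // The non-transposed-C (and every other) case continues to the existing
  // cpublas::gemm call shown in the diff above.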
