
torch-mlir-opt doesn't support a case for onnx.QLinearConv conversion #18124

Open
uazizTT opened this issue Aug 6, 2024 · 0 comments
Labels
bug 🐞 Something isn't working · integrations/onnx ONNX integration work · integrations/pytorch PyTorch integration work

Comments


uazizTT commented Aug 6, 2024

What happened?

I converted the ONNX model resnet50-v1-12-int8.onnx to MLIR using the import_onnx tool, which completed successfully:

python -m torch_mlir.tools.import_onnx resnet50-v1-12-int8.onnx -o restnet50-v1-12.int8.onnx.mlir

However, when converting from the torch-onnx dialect to torch with the torch-mlir-opt tool, the conversion fails on the onnx.QLinearConv operator:

restnet50-v1-12.int8.onnx.mlir:371:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %367 = torch.operator "onnx.QLinearConv"(%366, %1, %0, %2, %3, %4, %7, %6, %5) {torch.onnx.auto_pad = "NOTSET", torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,3,224,224],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64,3,7,7],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[64],si32>) -> !torch.vtensor<[?,64,112,112],ui8>
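
For context, QLinearConv is ONNX's quantized convolution, and the operand list in the error above follows the spec's order (x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, B). A lowering would essentially have to decompose it into dequantize, float convolution, and requantize steps. The sketch below is my own reference reading of the spec's semantics, not torch-mlir's implementation; the default stride and padding mirror the failing op's attributes:

```python
# Reference semantics of ONNX QLinearConv (my reading of the ONNX operator
# spec, not torch-mlir's lowering): dequantize -> float conv -> requantize.
import torch
import torch.nn.functional as F

def qlinearconv_ref(x_u8, x_scale, x_zp, w_i8, w_scale, w_zp,
                    y_scale, y_zp, bias_i32, stride=2, padding=3):
    # Dequantize the uint8 activation with its scalar scale/zero point.
    x_f = (x_u8.to(torch.float32) - float(x_zp)) * float(x_scale)
    # Dequantize the int8 weights; scale/zero point are per output channel.
    w_f = ((w_i8.to(torch.float32) - w_zp.float().view(-1, 1, 1, 1))
           * w_scale.view(-1, 1, 1, 1))
    # Per the spec, the int32 bias is quantized with scale x_scale * w_scale.
    b_f = bias_i32.to(torch.float32) * (float(x_scale) * w_scale)
    # Ordinary float convolution on the dequantized operands.
    y_f = F.conv2d(x_f, w_f, bias=b_f, stride=stride, padding=padding)
    # Requantize the float result back to uint8.
    y_q = torch.round(y_f / float(y_scale)) + float(y_zp)
    return torch.clamp(y_q, 0, 255).to(torch.uint8)
```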

This operator is currently marked as completed on the op tracker nod-ai/SHARK-ModelDev#215.

I can reproduce the same issue with the iree-import-onnx tool followed by iree-compile.

Steps to reproduce your issue

python -m torch_mlir.tools.import_onnx resnet50-v1-12-int8.onnx -o restnet50-v1-12.int8.onnx.mlir

./build/bin/torch-mlir-opt --convert-torch-onnx-to-torch restnet50-v1-12.int8.onnx.mlir -o restnet50-v1-12.int8.mlir

The ONNX model can be downloaded from https://github.com/onnx/models. A standalone single-op reproducer is sketched below.
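
To help triage, the op can also be isolated from the full network. The script below is a hypothetical standalone reproducer (the file name qlinearconv_repro.onnx and all shapes and quantization values are illustrative placeholders, chosen to mirror the operand types in the error above):

```python
# Builds a one-node QLinearConv model so the lowering can be tested without
# the full ResNet-50 graph. All values here are illustrative placeholders.
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

node = helper.make_node(
    "QLinearConv",
    inputs=["x", "x_scale", "x_zero_point", "w", "w_scale", "w_zero_point",
            "y_scale", "y_zero_point", "B"],
    outputs=["y"],
    kernel_shape=[7, 7], pads=[3, 3, 3, 3], strides=[2, 2],
    dilations=[1, 1], group=1,
)

x = helper.make_tensor_value_info("x", TensorProto.UINT8, [1, 3, 224, 224])
y = helper.make_tensor_value_info("y", TensorProto.UINT8, [1, 64, 112, 112])

initializers = [
    numpy_helper.from_array(np.array(0.02, dtype=np.float32), "x_scale"),
    numpy_helper.from_array(np.array(128, dtype=np.uint8), "x_zero_point"),
    numpy_helper.from_array(np.zeros((64, 3, 7, 7), dtype=np.int8), "w"),
    numpy_helper.from_array(np.full(64, 0.01, dtype=np.float32), "w_scale"),
    numpy_helper.from_array(np.zeros(64, dtype=np.int8), "w_zero_point"),
    numpy_helper.from_array(np.array(0.05, dtype=np.float32), "y_scale"),
    numpy_helper.from_array(np.array(128, dtype=np.uint8), "y_zero_point"),
    numpy_helper.from_array(np.zeros(64, dtype=np.int32), "B"),
]

graph = helper.make_graph([node], "qlinearconv_repro", [x], [y], initializers)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)
onnx.save(model, "qlinearconv_repro.onnx")
```

Feeding qlinearconv_repro.onnx through the same import_onnx and torch-mlir-opt commands above should hit the same legalization failure if the problem is op-level rather than specific to the ResNet model.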

What component(s) does this issue relate to?

Frontends, MLIR

Version information

No response

Additional context

No response

@uazizTT uazizTT added the bug 🐞 Something isn't working label Aug 6, 2024
@ScottTodd ScottTodd added the integrations/pytorch PyTorch integration work label Aug 6, 2024
@ScottTodd ScottTodd added the integrations/onnx ONNX integration work label Aug 19, 2024