diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index 5b144503c0ec..ccb0c033c617 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -113,6 +113,25 @@ struct OpBinder {
     return failure();
   }
 
+  ParseResult f32FloatAttr(float &value, StringRef nameSuffix,
+                           float defaultValue = 0.0f) {
+    SmallString<64> name("torch.onnx.");
+    name.append(nameSuffix);
+    auto attr = op->getAttr(name);
+    if (!attr) {
+      value = defaultValue;
+      return success();
+    }
+    if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+      FloatType t = cast<FloatType>(floatAttr.getType());
+      if (t.getWidth() != 32)
+        return failure();
+      value = floatAttr.getValueAsDouble();
+      return success();
+    }
+    return failure();
+  }
+
   ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                      std::string defaultValue = "") {
     SmallString<64> name("torch.onnx.");
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 23af89f329ab..b9fd49bd33f7 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -26,4 +26,33 @@ using namespace mlir::torch::onnx_c;
 // results in a lot of ONNX test cases that all reduce to the exact same
 // thing here, so we simplify.
 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
-    OnnxCustomOpConversionPattern &patterns) {}
+    OnnxCustomOpConversionPattern &patterns) {
+
+  patterns.onOp(
+      "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        float alpha, gamma;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.f32FloatAttr(alpha, "alpha") ||
+            binder.f32FloatAttr(gamma, "gamma") ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        Value vAlpha = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
+
+        Value vScale = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), gamma));
+
+        Value vInputScale = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), 1.0));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenEluOp>(
+            binder.op, resultType, operand, vAlpha, vScale, vInputScale);
+        return success();
+      });
+}
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
new file mode 100644
index 000000000000..8b98838dc769
--- /dev/null
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -0,0 +1,16 @@
+// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s
+// Generally, the test cases accumulated here come from running the importer
+// over all included backend tests that involve simple ops with no model
+// level constants. This is a pragmatic choice which lets us have a lot
+// of tests in this file, whereas the others tend to be more bespoke.
+
+
+// CHECK-LABEL: func.func @test_selu
+func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
+  // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
+  // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2
+  // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3
+  // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]]
+  %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
\ No newline at end of file
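Note (not part of the patch): the Selu-to-Elu mapping above works because ONNX Selu computes gamma * x for x > 0 and gamma * alpha * (exp(x) - 1) otherwise, while aten.elu(x, alpha, scale, input_scale) computes scale * x for x > 0 and scale * alpha * (exp(input_scale * x) - 1) otherwise; with scale = gamma and input_scale = 1 the two coincide, which is exactly what vScale and vInputScale encode. A minimal standalone C++ sketch of that identity (illustrative only; the helper names are hypothetical, not part of torch-mlir):

#include <cassert>
#include <cmath>
#include <initializer_list>

// ONNX Selu: gamma * (x for x > 0, alpha * (exp(x) - 1) otherwise).
static float onnxSelu(float x, float alpha, float gamma) {
  return gamma * (x > 0 ? x : alpha * (std::exp(x) - 1.0f));
}

// aten.elu semantics: scale * (x for x > 0,
// alpha * (exp(input_scale * x) - 1) otherwise).
static float atenElu(float x, float alpha, float scale, float inputScale) {
  return scale * (x > 0 ? x : alpha * (std::exp(inputScale * x) - 1.0f));
}

int main() {
  // Same alpha/gamma as the FileCheck test above (alpha = 2, gamma = 3);
  // input_scale is pinned to 1, matching vInputScale in the pattern.
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
    assert(std::abs(onnxSelu(x, 2.0f, 3.0f) - atenElu(x, 2.0f, 3.0f, 1.0f)) <
           1e-6f);
  return 0;
}

One caveat: f32FloatAttr falls back to 0.0f when an attribute is absent, whereas the ONNX spec defaults Selu's alpha to roughly 1.67326 and gamma to roughly 1.05070, so models that omit the attributes would need those defaults passed explicitly.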