diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp index 22392cebb8..0c3092d66d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp @@ -4,7 +4,7 @@ //===------------------ ElementwiseBroadcast.cpp - ONNX Operations --------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -299,6 +299,10 @@ LogicalResult ONNXModOp::verify() { // must be set to 1. if (elementType.isa<FloatType>() && (getFmod() != 1)) return emitOpError("fmod must be 1 when the input type is floating point"); + // Verify that when the input type is integer, then `fmod` attribute + // must be set to 0. + if (elementType.isa<IntegerType>() && (getFmod() != 0)) + return emitOpError("fmod must be 0 when the input type is an integer"); return success(); } diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp index 116a41d653..5c5e1afda4 100644 --- a/src/Transform/ONNX/ConstProp.cpp +++ b/src/Transform/ONNX/ConstProp.cpp @@ -4,7 +4,7 @@ //===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -217,6 +217,28 @@ struct ElementWiseBinaryOpImpl<ONNXMaxOp, T> { static T eval(T lhs, T rhs) { return std::max(lhs, rhs); } }; +template <> +struct ElementWiseBinaryOpImpl<ONNXModOp, int64_t, EnableNotBool<int64_t>> { + static int64_t eval(int64_t lhs, int64_t rhs) { + // The original calculation for mod + int64_t mod = lhs % rhs; + // Handle the case when one of the int values are negative + // If both int values are positive or multiples of each other, we can + // calculate as normal + if ((mod != 0) && ((lhs < 0) ^ (rhs < 0))) + return (mod + rhs); + return mod; + } +}; + +template <> +struct ElementWiseBinaryOpImpl<ONNXModOp, double, EnableNotBool<double>> { + static double eval(double lhs, double rhs) { + // Rounding to match the results of the backend tests + return (std::floor(fmod(lhs, rhs) * 1000000000) / 1000000000); + } +}; + template <typename T> struct ElementWiseBinaryOpImpl<ONNXEqualOp, T> { static bool eval(T lhs, T rhs) { return lhs == rhs; } diff --git a/src/Transform/ONNX/ConstProp.td b/src/Transform/ONNX/ConstProp.td index 329ed8905f..46a4fc1d3f 100644 --- a/src/Transform/ONNX/ConstProp.td +++ b/src/Transform/ONNX/ConstProp.td @@ -2,7 +2,7 @@ //===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. 
// // ============================================================================= // @@ -166,6 +166,9 @@ def CreateGreaterOrEqualOfTwoConst : def CreatePowOfTwoConst : NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXPowOp>($_builder, $0, $1, $2)">; +def CreateModOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXModOp>($_builder, $0, $1, $2)">; + def CreateWhereOfThreeConst : NativeCodeCall<"ConstPropWhere($_builder, $0, $1, $2, $3)">; @@ -522,7 +525,7 @@ def DivOnesOnRhs : NamedPat<"DivOnesOnRhs", ]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXEqualOp +// Constant propagation for ONNXEqualOp //===----------------------------------------------------------------------===// def EqualConstProp : NamedPat<"EqualConstProp", @@ -537,7 +540,7 @@ def EqualConstProp : NamedPat<"EqualConstProp", (IsIntOrFloatType:$lhs), (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXLessOp +// Constant propagation for ONNXLessOp //===----------------------------------------------------------------------===// def LessConstPropPattern : NamedPat<"LessConstPropPattern", @@ -549,7 +552,7 @@ def LessConstPropPattern : NamedPat<"LessConstPropPattern", (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXGreaterOp +// Constant propagation for ONNXGreaterOp //===----------------------------------------------------------------------===// def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern", @@ -561,7 +564,7 @@ def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern", (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXLessOrEqualOp +// Constant propagation for ONNXLessOrEqualOp 
//===----------------------------------------------------------------------===// def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern", @@ -573,7 +576,7 @@ def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern", (SatisfiesExpansionBound:$result)]>; //===----------------------------------------------------------------------===// -// Constant propagate ONNXGreaterOrEqualOp +// Constant propagation for ONNXGreaterOrEqualOp //===----------------------------------------------------------------------===// def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern", @@ -584,6 +587,19 @@ def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern", [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), (SatisfiesExpansionBound:$result)]>; +//===----------------------------------------------------------------------===// +// Constant propagation for ONNXModOp +//===----------------------------------------------------------------------===// + +def ModConstPropPattern : NamedPat<"ModConstPropPattern", + (ONNXModOp:$modOp + (ONNXConstantOp:$A $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$B $_, $_, $_, $_, $_, $_, $_, $_), + $fmod), + (CreateModOfTwoConst $modOp, $A, $B), + [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), + (SatisfiesExpansionBound:$modOp)]>; + //===----------------------------------------------------------------------===// // Patterns for Where. 
//===----------------------------------------------------------------------===// diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index 724fe54824..3905b447de 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -380,7 +380,7 @@ func.func @test_div_ones(%arg0 : tensor<1x2xui8>) -> tensor<1x2xui8> { } //===----------------------------------------------------------------------===// -/// Equal tests +/// Equal test // ----- @@ -395,7 +395,7 @@ func.func @test_equal() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// Less tests +/// Less test // ----- @@ -410,7 +410,7 @@ func.func @test_less() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// Greater tests +/// Greater test // ----- @@ -425,7 +425,7 @@ func.func @test_greater() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// LessOrEqual tests +/// LessOrEqual test // ----- @@ -440,7 +440,7 @@ func.func @test_lessorequal() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// GreaterOrEqual tests +/// GreaterOrEqual test // ----- @@ -455,7 +455,67 @@ func.func @test_greaterorequal() -> tensor<3xi1> { } //===----------------------------------------------------------------------===// -/// Sqrt tests +/// Modulo tests + +// ----- + +// CHECK-LABEL: @test_modulo_int_both_neg() -> tensor<i64> +func.func @test_modulo_int_both_neg() -> tensor<i64> { + %0 = onnx.Constant dense<-7> : tensor<i64> + %1 = onnx.Constant dense<-5> : tensor<i64> + %2 = "onnx.Mod"(%0, %1) : (tensor<i64> , tensor<i64>) -> tensor<i64> + "onnx.Return"(%2) : (tensor<i64>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<-2> : tensor<i64> +} + +// ----- + +// CHECK-LABEL: @test_modulo_int_neg() -> tensor<i64> +func.func @test_modulo_int_neg() -> tensor<i64> { + %0 = onnx.Constant dense<-4> : tensor<i64> + %1 = 
onnx.Constant dense<2> : tensor<i64> + %2 = "onnx.Mod"(%0, %1) : (tensor<i64> , tensor<i64>) -> tensor<i64> + "onnx.Return"(%2) : (tensor<i64>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<0> : tensor<i64> +} + +// ----- + +// CHECK-LABEL: @test_modulo_int_pos() -> tensor<i64> +func.func @test_modulo_int_pos() -> tensor<i64> { + %0 = onnx.Constant dense<5> : tensor<i64> + %1 = onnx.Constant dense<8> : tensor<i64> + %2 = "onnx.Mod"(%0, %1) : (tensor<i64> , tensor<i64>) -> tensor<i64> + "onnx.Return"(%2) : (tensor<i64>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<5> : tensor<i64> +} + +// ----- + +// CHECK-LABEL: @test_modulo_float() -> tensor<1xf32> +func.func @test_modulo_float() -> tensor<1xf32> { + %0 = onnx.Constant dense<[2.0]> : tensor<1xf32> + %1 = onnx.Constant dense<[7.0]> : tensor<1xf32> + %2 = "onnx.Mod"(%0, %1) {fmod = 1 : si64} : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32> + "onnx.Return"(%2) : (tensor<1xf32>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<1xf32> + // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}} +} + +// ----- + +// CHECK-LABEL: @test_modulo_float_mixed() -> tensor<1xf32> +func.func @test_modulo_float_mixed() -> tensor<1xf32> { + %0 = onnx.Constant dense<[-4.3]> : tensor<1xf32> + %1 = onnx.Constant dense<[2.1]> : tensor<1xf32> + %2 = "onnx.Mod"(%0, %1) {fmod = 1 : si64} : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32> + "onnx.Return"(%2) : (tensor<1xf32>) -> () + // CHECK: [[CONST:%.+]] = onnx.Constant dense<-0.100000381> : tensor<1xf32> + // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}} +} + +//===----------------------------------------------------------------------===// +/// Sqrt test // ----- @@ -468,7 +528,8 @@ func.func @test_sqrt() -> tensor<1x2xf32> { // CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}} } -/// Relu tests +//===----------------------------------------------------------------------===// +/// Relu test // -----