Skip to content

Commit

Permalink
Merge branch 'main' into opset-19-doc
Browse files Browse the repository at this point in the history
  • Loading branch information
cjvolzka authored Feb 23, 2024
2 parents a694268 + 3b94f16 commit 4a969d7
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 15 deletions.
6 changes: 5 additions & 1 deletion src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===------------------ ElementwiseBroadcast.cpp - ONNX Operations --------===//
//
// Copyright 2019-2023 The IBM Research Authors.
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -299,6 +299,10 @@ LogicalResult ONNXModOp::verify() {
// must be set to 1.
if (elementType.isa<FloatType>() && (getFmod() != 1))
return emitOpError("fmod must be 1 when the input type is floating point");
// Verify that when the input type is integer, then `fmod` attribute
// must be set to 0.
if (elementType.isa<IntegerType>() && (getFmod() != 0))
return emitOpError("fmod must be 0 when the input type is an integer");

return success();
}
Expand Down
24 changes: 23 additions & 1 deletion src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===//
//
// Copyright 2019-2023 The IBM Research Authors.
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -217,6 +217,28 @@ struct ElementWiseBinaryOpImpl<ONNXMaxOp, T> {
static T eval(T lhs, T rhs) { return std::max<T>(lhs, rhs); }
};

template <>
struct ElementWiseBinaryOpImpl<ONNXModOp, int64_t, EnableNotBool<int64_t>> {
  // Integer Mod (fmod == 0): floored modulo, i.e. a nonzero result takes the
  // sign of the divisor.
  static int64_t eval(int64_t lhs, int64_t rhs) {
    // C++ '%' truncates toward zero; start from that remainder.
    const int64_t rem = lhs % rhs;
    // When the operands' signs differ and the remainder is nonzero, shift the
    // remainder by the divisor so the result carries the divisor's sign.
    const bool oppositeSigns = (lhs < 0) != (rhs < 0);
    return (rem != 0 && oppositeSigns) ? rem + rhs : rem;
  }
};

template <>
struct ElementWiseBinaryOpImpl<ONNXModOp, double, EnableNotBool<double>> {
  // Floating-point Mod (fmod == 1): C-style truncated remainder via std::fmod,
  // so a nonzero result takes the sign of the dividend.
  static double eval(double lhs, double rhs) {
    // Rounding to match the results of the backend tests
    // NOTE(review): floor-truncates the fmod result at 9 decimal digits so the
    // constant-folded value matches the backend test expectations; floor rounds
    // toward -inf, which biases negative results downward — confirm this stays
    // in sync with the tolerance those tests use.
    return (std::floor(fmod(lhs, rhs) * 1000000000) / 1000000000);
  }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXEqualOp, T> {
static bool eval(T lhs, T rhs) { return lhs == rhs; }
Expand Down
28 changes: 22 additions & 6 deletions src/Transform/ONNX/ConstProp.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

//===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===//
//
// Copyright 2019-2020 The IBM Research Authors.
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -166,6 +166,9 @@ def CreateGreaterOrEqualOfTwoConst :
def CreatePowOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXPowOp>($_builder, $0, $1, $2)">;

// Folds a Mod of two constants by delegating to the C++ helper
// ConstPropElementwiseBinary ($0 = the original Mod op, used for the result
// type; $1/$2 = the constant operands).
def CreateModOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXModOp>($_builder, $0, $1, $2)">;

def CreateWhereOfThreeConst :
NativeCodeCall<"ConstPropWhere($_builder, $0, $1, $2, $3)">;

Expand Down Expand Up @@ -522,7 +525,7 @@ def DivOnesOnRhs : NamedPat<"DivOnesOnRhs",
]>;

//===----------------------------------------------------------------------===//
// Constant propagate ONNXEqualOp
// Constant propagation for ONNXEqualOp
//===----------------------------------------------------------------------===//

def EqualConstProp : NamedPat<"EqualConstProp",
Expand All @@ -537,7 +540,7 @@ def EqualConstProp : NamedPat<"EqualConstProp",
(IsIntOrFloatType:$lhs), (SatisfiesExpansionBound:$result)]>;

//===----------------------------------------------------------------------===//
// Constant propagate ONNXLessOp
// Constant propagation for ONNXLessOp
//===----------------------------------------------------------------------===//

def LessConstPropPattern : NamedPat<"LessConstPropPattern",
Expand All @@ -549,7 +552,7 @@ def LessConstPropPattern : NamedPat<"LessConstPropPattern",
(SatisfiesExpansionBound:$result)]>;

//===----------------------------------------------------------------------===//
// Constant propagate ONNXGreaterOp
// Constant propagation for ONNXGreaterOp
//===----------------------------------------------------------------------===//

def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern",
Expand All @@ -561,7 +564,7 @@ def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern",
(SatisfiesExpansionBound:$result)]>;

//===----------------------------------------------------------------------===//
// Constant propagate ONNXLessOrEqualOp
// Constant propagation for ONNXLessOrEqualOp
//===----------------------------------------------------------------------===//

def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern",
Expand All @@ -573,7 +576,7 @@ def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern",
(SatisfiesExpansionBound:$result)]>;

//===----------------------------------------------------------------------===//
// Constant propagate ONNXGreaterOrEqualOp
// Constant propagation for ONNXGreaterOrEqualOp
//===----------------------------------------------------------------------===//

def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern",
Expand All @@ -584,6 +587,19 @@ def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern",
[(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs),
(SatisfiesExpansionBound:$result)]>;

//===----------------------------------------------------------------------===//
// Constant propagation for ONNXModOp
//===----------------------------------------------------------------------===//

// Fold Mod(const, const) into a single constant. The `fmod` attribute is
// captured but not inspected here: the ONNXModOp verifier ties fmod to the
// element type (0 for integers, 1 for floats), and ConstPropElementwiseBinary
// picks the integer vs. floating-point evaluation from the element type.
def ModConstPropPattern : NamedPat<"ModConstPropPattern",
  (ONNXModOp:$modOp
    (ONNXConstantOp:$A $_, $_, $_, $_, $_, $_, $_, $_),
    (ONNXConstantOp:$B $_, $_, $_, $_, $_, $_, $_, $_),
    $fmod),
  (CreateModOfTwoConst $modOp, $A, $B),
  [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B),
   (SatisfiesExpansionBound:$modOp)]>;

//===----------------------------------------------------------------------===//
// Patterns for Where.
//===----------------------------------------------------------------------===//
Expand Down
75 changes: 68 additions & 7 deletions test/mlir/onnx/onnx_constprop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ func.func @test_div_ones(%arg0 : tensor<1x2xui8>) -> tensor<1x2xui8> {
}

//===----------------------------------------------------------------------===//
/// Equal tests
/// Equal test

// -----

Expand All @@ -395,7 +395,7 @@ func.func @test_equal() -> tensor<3xi1> {
}

//===----------------------------------------------------------------------===//
/// Less tests
/// Less test

// -----

Expand All @@ -410,7 +410,7 @@ func.func @test_less() -> tensor<3xi1> {
}

//===----------------------------------------------------------------------===//
/// Greater tests
/// Greater test

// -----

Expand All @@ -425,7 +425,7 @@ func.func @test_greater() -> tensor<3xi1> {
}

//===----------------------------------------------------------------------===//
/// LessOrEqual tests
/// LessOrEqual test

// -----

Expand All @@ -440,7 +440,7 @@ func.func @test_lessorequal() -> tensor<3xi1> {
}

//===----------------------------------------------------------------------===//
/// GreaterOrEqual tests
/// GreaterOrEqual test

// -----

Expand All @@ -455,7 +455,67 @@ func.func @test_greaterorequal() -> tensor<3xi1> {
}

//===----------------------------------------------------------------------===//
/// Sqrt tests
/// Modulo tests

// -----

// Floored modulo with two negative operands: -7 mod -5 = -2 (the result takes
// the divisor's sign). CHECK-NOT added for consistency with the float tests:
// the Mod op must be fully folded away.
// CHECK-LABEL: @test_modulo_int_both_neg() -> tensor<i64>
func.func @test_modulo_int_both_neg() -> tensor<i64> {
  %0 = onnx.Constant dense<-7> : tensor<i64>
  %1 = onnx.Constant dense<-5> : tensor<i64>
  %2 = "onnx.Mod"(%0, %1) : (tensor<i64> , tensor<i64>) -> tensor<i64>
  "onnx.Return"(%2) : (tensor<i64>) -> ()
  // CHECK: [[CONST:%.+]] = onnx.Constant dense<-2> : tensor<i64>
  // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}}
}

// -----

// Negative dividend, exact multiple: -4 mod 2 = 0 (no sign correction when the
// remainder is zero). CHECK-NOT added for consistency with the float tests:
// the Mod op must be fully folded away.
// CHECK-LABEL: @test_modulo_int_neg() -> tensor<i64>
func.func @test_modulo_int_neg() -> tensor<i64> {
  %0 = onnx.Constant dense<-4> : tensor<i64>
  %1 = onnx.Constant dense<2> : tensor<i64>
  %2 = "onnx.Mod"(%0, %1) : (tensor<i64> , tensor<i64>) -> tensor<i64>
  "onnx.Return"(%2) : (tensor<i64>) -> ()
  // CHECK: [[CONST:%.+]] = onnx.Constant dense<0> : tensor<i64>
  // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}}
}

// -----

// Both operands positive with dividend < divisor: 5 mod 8 = 5. CHECK-NOT added
// for consistency with the float tests: the Mod op must be fully folded away.
// CHECK-LABEL: @test_modulo_int_pos() -> tensor<i64>
func.func @test_modulo_int_pos() -> tensor<i64> {
  %0 = onnx.Constant dense<5> : tensor<i64>
  %1 = onnx.Constant dense<8> : tensor<i64>
  %2 = "onnx.Mod"(%0, %1) : (tensor<i64> , tensor<i64>) -> tensor<i64>
  "onnx.Return"(%2) : (tensor<i64>) -> ()
  // CHECK: [[CONST:%.+]] = onnx.Constant dense<5> : tensor<i64>
  // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}}
}

// -----

// Float Mod with fmod = 1 (C-style remainder): fmod(2.0, 7.0) == 2.0, since
// the dividend is already smaller than the divisor.
// CHECK-LABEL: @test_modulo_float() -> tensor<1xf32>
func.func @test_modulo_float() -> tensor<1xf32> {
  %0 = onnx.Constant dense<[2.0]> : tensor<1xf32>
  %1 = onnx.Constant dense<[7.0]> : tensor<1xf32>
  %2 = "onnx.Mod"(%0, %1) {fmod = 1 : si64} : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  "onnx.Return"(%2) : (tensor<1xf32>) -> ()
  // CHECK: [[CONST:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<1xf32>
  // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}}
}

// -----

// Float Mod with mixed signs: fmod(-4.3, 2.1) is about -0.1 (result takes the
// dividend's sign under fmod = 1). NOTE(review): the expected -0.100000381
// presumably reflects f32 rounding plus the folder's floor-at-1e-9 truncation
// in ConstProp.cpp — confirm if that rounding hack ever changes.
// CHECK-LABEL: @test_modulo_float_mixed() -> tensor<1xf32>
func.func @test_modulo_float_mixed() -> tensor<1xf32> {
  %0 = onnx.Constant dense<[-4.3]> : tensor<1xf32>
  %1 = onnx.Constant dense<[2.1]> : tensor<1xf32>
  %2 = "onnx.Mod"(%0, %1) {fmod = 1 : si64} : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  "onnx.Return"(%2) : (tensor<1xf32>) -> ()
  // CHECK: [[CONST:%.+]] = onnx.Constant dense<-0.100000381> : tensor<1xf32>
  // CHECK-NOT: {{.*}} = "onnx.Mod"{{.*}}
}

//===----------------------------------------------------------------------===//
/// Sqrt test

// -----

Expand All @@ -468,7 +528,8 @@ func.func @test_sqrt() -> tensor<1x2xf32> {
// CHECK-NOT: {{.*}} = "onnx.Sqrt"{{.*}}
}

/// Relu tests
//===----------------------------------------------------------------------===//
/// Relu test

// -----

Expand Down

0 comments on commit 4a969d7

Please sign in to comment.