From d59d0b6e5a88252d1d7e9b380e5488f49fadf87f Mon Sep 17 00:00:00 2001
From: penguin_wwy <940375606@qq.com>
Date: Wed, 5 Jun 2024 07:05:39 +0800
Subject: [PATCH] [Linalg] Promote type for compare tensor op (#3416)

---
 .../TorchToLinalg/Uncategorized.cpp           | 103 +++++-------------
 projects/pt1/e2e_testing/xfail_sets.py        |   3 +
 .../test_suite/elementwise_comparison.py      |  45 ++++++++
 3 files changed, 76 insertions(+), 75 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 12b2264bc244..d11fd987482e 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
   return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
 }

-template <typename OpTy>
-static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
-                                   Value lhs, Value rhs) {
-  static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
-                    std::is_same<OpTy, AtenLeTensorOp>() ||
-                    std::is_same<OpTy, AtenGtTensorOp>() ||
-                    std::is_same<OpTy, AtenGeTensorOp>() ||
-                    std::is_same<OpTy, AtenEqTensorOp>() ||
-                    std::is_same<OpTy, AtenNeTensorOp>(),
-                "unimplemented: op type not supported");
-
-  Type lhsDtype = lhs.getType();
-  Type rhsDtype = rhs.getType();
-
-  // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
-  // to be handled.
-  if (lhsDtype != rhsDtype) {
-    op.emitError("unimplemented: lhs and rhs dtype must be same");
-    return nullptr;
-  }
-
-  Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
-  if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
-    return createLessThan(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
-    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
-    return createGreaterThan(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
-    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
-    return createEqual(b, loc, elementalType, lhs, rhs);
-  }
-  if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
-    return createNotEqual(b, loc, elementalType, lhs, rhs);
-  }
-  llvm_unreachable("unimplemented: op type not supported");
-}
+template <typename T, typename... Ts>
+struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};

 template <typename OpTy>
-static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
-                                   Value lhs, Value rhs) {
-  static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
-                    std::is_same<OpTy, AtenLeScalarOp>() ||
-                    std::is_same<OpTy, AtenGtScalarOp>() ||
-                    std::is_same<OpTy, AtenGeScalarOp>() ||
-                    std::is_same<OpTy, AtenEqScalarOp>() ||
-                    std::is_same<OpTy, AtenNeScalarOp>(),
-                "unimplemented: op type not supported");
+static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs,
+                             Value rhs) {
+  static_assert(
+      is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenGtScalarOp,
+                  AtenGeScalarOp, AtenEqScalarOp, AtenNeScalarOp,
+                  AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
+                  AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
+      "unimplemented: op type not supported");

   Type lhsDtype = lhs.getType();
   Type rhsDtype = rhs.getType();
@@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
     return nullptr;
   }

-  if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>()) {
     return createLessThan(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenLeScalarOp, AtenLeTensorOp>()) {
     return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenGtScalarOp, AtenGtTensorOp>()) {
     return createGreaterThan(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenGeScalarOp, AtenGeTensorOp>()) {
     return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>()) {
     return createEqual(b, loc, elementalType, lhs, rhs);
   }
-  if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
+  if constexpr (is_any_same<OpTy, AtenNeScalarOp, AtenNeTensorOp>()) {
     return createNotEqual(b, loc, elementalType, lhs, rhs);
   }
   llvm_unreachable("unimplemented: op type not supported");
@@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create(loc, lhs, rhs);
   }
   if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
-    return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
-                                 payloadArgs[1]);
+    return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
     AtenDivTensorOp::Adaptor adaptor(operands);
@@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }

   if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
   }

   if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]);
   }

   if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
   }

   if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]);
   }

   if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
   }

   if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
-    return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
+    return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]);
   }

   if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 7dc557b44ee2..65153e4f5ba3 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -27,6 +27,7 @@
     "InterpolateDynamicModule_sizes_nearest",
     "InterpolateStaticModule_scales_bilinear_align_corners",
     "InterpolateDynamicModule_scales_recompute_bilinear",
+    "ElementwiseFloatTensorGtIntTensorModule_basic",
 }

 LINALG_CRASHING_SET = {
@@ -2707,6 +2708,7 @@
     "ElementwiseTanIntModule_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
     "ElementwiseUnaryIntModule_basic",
+    "ElementwiseFloatTensorGtIntTensorModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
@@ -3786,6 +3788,7 @@
"ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", "ElementwiseFlattenBroadcastModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseFmodTensor_Float_basic", "ElementwiseFmodTensor_Int_Float_basic", "ElementwiseFmodTensor_Int_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 7fdfb454d362..304bc422e4d2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) +class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.float64, True), + ] + ) + def forward(self, x, y): + return torch.lt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule()) +def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64)) + + +class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.gt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule()) +def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(3, 5, high=10).to(torch.float32), + tu.randint(5, high=10, dtype=torch.int32), + ) + + # ==============================================================================