From e05a320906c23313e7d9150cd27c15543591031e Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Mon, 16 Sep 2024 11:21:54 -0700 Subject: [PATCH] Add "PassUtils" for common functions to avoid duplicate definitions (#2547) [back-porting to gh](https://github.com/openxla/xla/commit/5a07f58bb626dec8ad17c6a7998d5410401d6d93) --- BUILD.bazel | 2 + .../chlo/chlo_legalize_to_stablehlo.mlir | 42 +++----------- ...ablehlo_create_compatibility_expander.mlir | 4 +- stablehlo/transforms/CMakeLists.txt | 1 + .../transforms/ChloDecompositionPatterns.td | 21 +------ .../transforms/ChloLegalizeToStablehlo.cpp | 24 +------- stablehlo/transforms/PassUtils.cpp | 34 +++++++++++ stablehlo/transforms/PassUtils.h | 57 +++++++++++++++++++ stablehlo/transforms/Passes.td | 1 + .../StablehloCreateCompatibilityExpander.cpp | 16 +----- ...ehloCreateCompatibilityExpanderPatterns.td | 5 +- 11 files changed, 111 insertions(+), 96 deletions(-) create mode 100644 stablehlo/transforms/PassUtils.cpp create mode 100644 stablehlo/transforms/PassUtils.h diff --git a/BUILD.bazel b/BUILD.bazel index 0551c756fdb..70a171f182b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1096,6 +1096,7 @@ cc_library( srcs = [ "stablehlo/transforms/ChloLegalizeToStablehlo.cpp", "stablehlo/transforms/PassPipelines.cpp", + "stablehlo/transforms/PassUtils.cpp", "stablehlo/transforms/ShapeLegalizeToStablehlo.cpp", "stablehlo/transforms/StablehloAggressiveFolder.cpp", "stablehlo/transforms/StablehloAggressiveSimplification.cpp", @@ -1115,6 +1116,7 @@ cc_library( ], hdrs = [ "stablehlo/transforms/MapStablehloToVhlo.h", + "stablehlo/transforms/PassUtils.h", "stablehlo/transforms/Passes.h", "stablehlo/transforms/StablehloRefineShapes.h", ], diff --git a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir index 3032fe77b57..e55bde8b226 100644 --- a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +++ b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir @@ -3029,28 +3029,11 @@ func.func @next_after_f32(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> // ----- -// CHECK-LABEL: @tan_f16 -// CHECK-SAME: (%[[ARG:.*]]: tensor) -func.func @tan_f16(%arg : tensor) -> tensor { - // %[[TMP_0:.*]] = stablehlo.convert [[ARG]] : (tensor) -> tensor - // %[[TMP_1:.*]] = stablehlo.sine %[[TMP_0]] - // %[[TMP_2:.*]] = stablehlo.cosine %[[TMP_0]] - // %[[TMP_3:.*]] = stablehlo.divide %[[TMP_1]], %[[TMP_2]] - // %[[TMP_4:.*]] = stablehlo.convert %[[TMP_3]] : (tensor) -> tensor - // return %[[TMP_4]] : tensor - %1 = chlo.tan %arg : tensor -> tensor - func.return %1 : tensor -} - -// ----- - // CHECK-LABEL: @tan_f32 // CHECK-SAME: (%[[ARG:.*]]: tensor) func.func @tan_f32(%arg : tensor) -> tensor { - // %[[TMP_0:.*]] = stablehlo.sine %[[ARG]] - // %[[TMP_1:.*]] = stablehlo.cosine %[[ARG]] - // %[[TMP_2:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_1]] - // return %[[TMP_2]] : tensor + // CHECK: %[[TMP_0:.*]] = stablehlo.tan %[[ARG]] : tensor + // CHECK: return %[[TMP_0]] : tensor %1 = chlo.tan %arg : tensor -> tensor func.return %1 : tensor } @@ -3060,22 +3043,11 @@ func.func @tan_f32(%arg : tensor) -> tensor { // CHECK-LABEL: @tan_complexf32 // CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32> func.func @tan_complexf32(%arg0 : tensor<1xf32>, %arg1 : tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) { - // CHECK: %[[COMPLEX:.+]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex> - // CHECK: %[[REAL:.+]] = stablehlo.real %[[COMPLEX]] : (tensor<1xcomplex>) -> tensor<1xf32> - // CHECK: %[[SINE:.+]] = stablehlo.sine %[[REAL]] - // CHECK: %[[COS:.+]] = stablehlo.cosine %[[REAL]] - // CHECK: %[[TAN:.+]] = stablehlo.divide %[[SINE]], %[[COS]] - // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[COMPLEX]] : (tensor<1xcomplex>) -> tensor<1xf32> - // CHECK: %[[TANH:.+]] = stablehlo.tanh %[[IMAG]] - // CHECK: %[[NUM:.+]] = stablehlo.complex %[[TAN]], %[[TANH]] - // CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<1xf32> - // CHECK: %[[MUL:.+]] = stablehlo.multiply %[[TAN]], %[[TANH]] - // CHECK: %[[NEG:.+]] = stablehlo.negate %[[MUL]] - // CHECK: %[[DEN:.+]] = stablehlo.complex %[[ONE]], %[[NEG]] - // CHECK: %[[RES:.+]] = stablehlo.divide %[[NUM]], %[[DEN]] - // CHECK: %[[REAL:.+]] = stablehlo.real %[[RES]] - // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[RES]] - // CHECK: return %[[REAL]], %[[IMAG]] + // CHECK: %[[TMP_0:.*]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex> + // CHECK: %[[TMP_1:.*]] = stablehlo.tan %[[TMP_0]] : tensor<1xcomplex> + // CHECK: %[[TMP_2:.*]] = stablehlo.real %[[TMP_1]] : (tensor<1xcomplex>) -> tensor<1xf32> + // CHECK: %[[TMP_3:.*]] = stablehlo.imag %[[TMP_1]] : (tensor<1xcomplex>) -> tensor<1xf32> + // CHECK: return %[[TMP_2]], %[[TMP_3]] : tensor<1xf32>, tensor<1xf32> %0 = stablehlo.complex %arg0, %arg1 : tensor<1xcomplex> %1 = chlo.tan %0 : tensor<1xcomplex> -> tensor<1xcomplex> %2 = stablehlo.real %1 : (tensor<1xcomplex>) -> tensor<1xf32> diff --git a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir index ae0a376489d..024f1108c92 100644 --- a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +++ b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir @@ -1,5 +1,5 @@ -// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK -// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE +// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.0.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK +// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE // ----- diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index 1cf6d6a3218..96d51104cad 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -53,6 +53,7 @@ add_mlir_dialect_library(StablehloPasses StablehloRefineShapes.cpp VhloLegalizeToStablehlo.cpp VhloToVersion.cpp + PassUtils.cpp DEPENDS ChloDecompositionPatternsIncGen diff --git a/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/transforms/ChloDecompositionPatterns.td index a92e73d9de5..66170fb5474 100644 --- a/stablehlo/transforms/ChloDecompositionPatterns.td +++ b/stablehlo/transforms/ChloDecompositionPatterns.td @@ -109,26 +109,9 @@ def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input), (STABLEHLO_DEFAULT_COMPARISON_TYPE) )>; -// Express `tan` as -// sine(x) / cosine(x) -def : Pat<(CHLO_TanOp NonComplexElementType:$input), - (StableHLO_DivOp - (StableHLO_SineOp $input), - (StableHLO_CosineOp $input) - )>; - -// Express `tan(a + bi)` as -// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b)) -def : Pat<(CHLO_TanOp ComplexElementType:$input), - (StableHLO_DivOp - (StableHLO_ComplexOp - (CHLO_TanOp:$tan (StableHLO_RealOp $input)), - (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))), - (StableHLO_ComplexOp - (StableHLO_ConstantLike<"1.0"> $tan), - (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh))) - )>; +def : Pat<(CHLO_TanOp $input), + (StableHLO_TanOp $input)>; def : Pat<(CHLO_ConstantOp $v), (StableHLO_ConstantOp $v)>; diff --git a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index b5c07d3894f..b6d75b376b5 100644 --- a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -48,6 +48,7 @@ #include "stablehlo/dialect/BroadcastUtils.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/PassUtils.h" #include "stablehlo/transforms/Passes.h" namespace mlir { @@ -171,29 +172,6 @@ static void populateForBroadcastingBinaryOp(MLIRContext *context, context, args...); } -template -static Value getConstantLike(OpBuilder &b, Location loc, T constant, - Value val) { - Type ty = getElementTypeOrSelf(val.getType()); - auto getAttr = [&]() -> Attribute { - if (isa(ty)) return b.getIntegerAttr(ty, constant); - if (isa(ty)) return b.getFloatAttr(ty, constant); - if (auto complexTy = dyn_cast(ty)) { - return complex::NumberAttr::get(complexTy, constant, 0); - } - llvm_unreachable("unhandled element type"); - }; - return b.create(loc, cast(getAttr()), - val); -} - -static Value getConstantLike(OpBuilder &b, Location loc, - const APFloat &constant, Value val) { - Type ty = getElementTypeOrSelf(val.getType()); - return b.create(loc, b.getFloatAttr(ty, constant), - val); -} - static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, Value val) { auto ty = cast(getElementTypeOrSelf(val.getType())); diff --git a/stablehlo/transforms/PassUtils.cpp b/stablehlo/transforms/PassUtils.cpp new file mode 100644 index 00000000000..3de7faa956c --- /dev/null +++ b/stablehlo/transforms/PassUtils.cpp @@ -0,0 +1,34 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "stablehlo/transforms/PassUtils.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/ChloOps.h" + +namespace mlir { +namespace stablehlo { + +Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + return b.create(loc, b.getFloatAttr(ty, constant), + val); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/PassUtils.h b/stablehlo/transforms/PassUtils.h new file mode 100644 index 00000000000..73da6a2918b --- /dev/null +++ b/stablehlo/transforms/PassUtils.h @@ -0,0 +1,57 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASS_UTILS_H_ +#define THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASS_UTILS_H_ + +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/ChloOps.h" + +namespace mlir { +namespace stablehlo { +// Add utility functions common across passes. + +// Creates a chlo::ConstantLikeOp using a splat `constant` of the same shape +// as `val`. +template +Value getConstantLike(OpBuilder &b, Location loc, T constant, Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + auto getAttr = [&]() -> Attribute { + if (isa(ty)) return b.getIntegerAttr(ty, constant); + if (isa(ty)) return b.getFloatAttr(ty, constant); + if (auto complexTy = dyn_cast(ty)) { + return complex::NumberAttr::get(complexTy, constant, 0); + } + llvm_unreachable("unhandled element type"); + }; + return b.create(loc, cast(getAttr()), + val); +} + +// Creates a chlo::ConstantLikeOp using a APFloat splat `constant` of the +// same shape as `val`. +Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, + Value val); + +} // namespace stablehlo +} // namespace mlir + +#endif // THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASS_UTILS_H_ diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 862c82b9ee4..cf45240fadc 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -338,5 +338,6 @@ def StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibil ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", + "mlir::chlo::ChloDialect", ]; } diff --git a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp index b28e9cce9ad..28113bf1a4a 100644 --- a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +++ b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/Version.h" +#include "stablehlo/transforms/PassUtils.h" #include "stablehlo/transforms/Passes.h" namespace mlir { @@ -38,21 +39,6 @@ namespace { // Helpers. //===----------------------------------------------------------------------===// -// Creates a constant with all ones. -static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) { - if (!isa(getElementTypeOrSelf(val))) - llvm_unreachable("Unsupported element type, expecting float"); - - auto shapedTy = dyn_cast(val.getType()); - if (!shapedTy) llvm_unreachable("Unsupported shaped type."); - - mlir::DenseElementsAttr elementsAttr = - mlir::DenseElementsAttr::get(shapedTy, 1.0); - - return b.create(loc, val.getType(), - elementsAttr); -} - // Check user-specified target version. vhlo::Version validateTargetVersion(llvm::StringRef versionRef) { auto failOrVersion = vhlo::Version::fromString(versionRef); diff --git a/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td index 493a9be9c6d..23a5ea4c899 100644 --- a/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td +++ b/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td @@ -24,7 +24,8 @@ def NonComplexElementType : Type< CPred<"!isa(cast($_self).getElementType())">, "Non-complex element type">; -def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">; +class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; // Express `tan` as // sine(x) / cosine(x) @@ -42,6 +43,6 @@ def TanOp_ComplexElementType_CompatiblityExpander : Pat<(StableHLO_TanOp Complex (StableHLO_TanOp:$tan (StableHLO_RealOp $input)), (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))), (StableHLO_ComplexOp - (createConstantWithAllOnes $tan), + (StableHLO_ConstantLike<"1.0"> $tan), (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh))) )>;