diff --git a/BUILD.bazel b/BUILD.bazel
index 70a171f182b..58df9f183b6 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -345,11 +345,11 @@ gentbl_cc_library(
    tbl_outs = [
        (
            ["--gen-rewriters"],
-            "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc",
+            "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td",
+    td_file = "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.td",
    deps = [
        ":stablehlo_ops_td_files",
    ],
@@ -1101,8 +1101,8 @@ cc_library(
        "stablehlo/transforms/StablehloAggressiveFolder.cpp",
        "stablehlo/transforms/StablehloAggressiveSimplification.cpp",
        "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
+        "stablehlo/transforms/StablehloCompatibilityExpander.cpp",
        "stablehlo/transforms/StablehloConvertToSignless.cpp",
-        "stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp",
        "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
        "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
        "stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cpp",
diff --git a/docs/_toc.yaml b/docs/_toc.yaml
index 9314989eeae..577d15f51b5 100644
--- a/docs/_toc.yaml
+++ b/docs/_toc.yaml
@@ -21,6 +21,8 @@ toc:
    path: /stablehlo/spec
  - title: Compatibility guarantees
    path: /stablehlo/compatibility
+  - title: Dynamism support
+    path: /stablehlo/dynamism
  - title: Reference interpreter
    path: /stablehlo/reference
  - title: Roadmap
diff --git a/docs/dynamism.md b/docs/dynamism.md
new file mode 100644
index 00000000000..294325faffa
--- /dev/null
+++ b/docs/dynamism.md
@@ -0,0 +1,220 @@
+# Dynamism in StableHLO
+
+The current state of dynamism is formally spelled out in the
+[Dynamism RFC][dynamism-rfc]; this page provides a high-level overview of
+the RFC and discusses important APIs and tooling for interacting with dynamic
+programs.
+
+[dynamism-rfc]:https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md
+
+## Dynamism Terminology & Support Overview
+
+First, here are a few terms that will appear in this doc, along with a brief
+intro to their support in StableHLO:
+
+### Dynamic dimensions
+
+A dynamic dimension is any dimension whose size is unknown.
+In StableHLO we represent dynamic dimensions using `?`, e.g. `tensor<16x?xf32>`.
+
+### Bounded dynamism
+
+Bounded dynamism refers to a dynamic dimension whose value has a known upper
+bound. Generally this is useful for padding the tensor during execution.
+In StableHLO we represent bounded dynamism using `#stablehlo.bounds` as a
+tensor encoding, e.g. a rank-2 tensor with one dynamic dimension bounded at 16
+and the other without a bound can be represented as
+`tensor<?x?xf32, #stablehlo.bounds<16, ?>>`.
+
+StableHLO is able to represent bounded dynamism, but framework support is
+limited; it originates in TensorFlow, with some support in PyTorch/XLA.
+
+### Unbounded dynamism
+
+Unbounded dynamism, as the name implies, refers to a dynamic dimension with
+no known bound on its size. This type of dynamism is very common in StableHLO,
+with JAX, PyTorch/XLA, and TF support; it is often used for exporting models
+with a dynamic batch size or sequence length.
+
+In StableHLO we simply elide the bounds encoding for this form of dynamism,
+i.e. `tensor<?xf32>`.
+
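+To make the three flavors concrete, here is an illustrative comparison of the
+same rank-2 tensor type (the specific shapes and bound are made up for this
+example):
+
+```mlir
+tensor<4x16xf32>                           // static: all sizes known
+tensor<?x16xf32>                           // unbounded: dim 0 unknown
+tensor<?x16xf32, #stablehlo.bounds<8, ?>>  // bounded: dim 0 unknown, at most 8
+```
+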
+### Shape polymorphism
+
+Shape polymorphism is a [term we've inherited from JAX][shape-poly].
+
+There are two key implications of shape polymorphism:
+
+1. All dynamism in the program traces back to its input arguments.
+2. All dynamism pertains to tensor _shapes_ only, i.e. not data-dependent.
+
+With these two rules, once the static shapes of a program are known, we are
+able to take a dynamic program and fully refine it into a static program for
+compilation (see ["Compiler passes for refining dynamic programs"](#compiler-passes-for-refining-dynamic-programs)).
+
+Generally shape polymorphism uses unbounded dynamism: if known argument shapes
+can lead to a fully static program, there is no need to guess at how to bound
+the values.
+
+### Data-dependent dynamism
+
+Data-dependent dynamism refers to dynamic dimension sizes that pertain to the
+_data_ inside a tensor. The canonical example is a `nonzeros` function, which
+returns the indices of all elements that are nonzero in a tensor value. The
+shape cannot be known without evaluating the data, but it can often be
+compiled using bounded dynamism, spending extra memory on the potential
+output tensor size.
+
+Many data-dependent dynamic ops can be modeled using bounded dynamism, where
+an upper bound on a tensor size is specified, and hardware generally will
+implement this via tensor padding. Today there is some support for
+data-dependent dynamism in PyTorch/XLA and TensorFlow, but JAX does not
+currently trace operations which lead to data-dependent dynamism.
+
+[shape-poly]:https://jax.readthedocs.io/en/latest/export/shape_poly.html
+
+## Exporting programs with dynamic dimensions
+
+See our StableHLO tutorials for information on how to export programs with
+dynamic batch sizes or sequence lengths:
+
+- [JAX Tutorial > Export with Dynamic Batch Size][jax-export-dynamic]
+- [PyTorch/XLA Tutorial > Export with Dynamic Batch Size][pytorch-export-dynamic]
+
+[jax-export-dynamic]:https://openxla.org/stablehlo/tutorials/jax-export#export_with_dynamic_batch_size
+[pytorch-export-dynamic]:https://openxla.org/stablehlo/tutorials/pytorch-export#export_with_dynamic_batch_dimension
+
+## Compiler passes for refining dynamic programs
+
+### Remove dynamism pass pipeline
+
+There are a few useful passes for refining shapes; conveniently, they are all
+bundled in the pass pipeline
+[`createStablehloRemoveDynamismPipeline`][remove-dynamism] (a usage sketch
+appears at the end of this section):
+
+```c++
+void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
+                                           TypeRange refinedTypes);
+```
+
+### Individual passes for refining dynamism
+
+Individually, the passes that tend to be useful for shape refinement are:
+
+- [`stablehlo-refine-arguments`][refine-arguments] to replace input arguments
+  with concrete tensor types.
+- [`stablehlo-refine-shapes`][refine-shapes] to propagate the new input
+  argument shape information throughout the entire program.
+- [`stablehlo-canonicalize-dynamism`][canonicalize-dynamism] to replace
+  dynamic ops with their static variants.
+
+See the linked documentation for up-to-date information and examples.
+
+[remove-dynamism]:https://github.com/openxla/stablehlo/blob/ff13c96e56b73c62dcbb5b34b69f5ece9e71322f/stablehlo/transforms/Passes.h#L134
+[canonicalize-dynamism]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-canonicalize-dynamism
+[refine-arguments]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-refine-arguments
+[refine-shapes]:https://openxla.org/stablehlo/generated/stablehlo_passes#-stablehlo-refine-shapes
+
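+As a rough sketch, the bundled pipeline above can be driven from C++ as
+follows. The surrounding setup (`module`, `builder`, error handling) is
+assumed here for illustration and is not part of the pipeline API:
+
+```c++
+// Refine the single argument of @main to tensor<16xf32>, then run the
+// bundled refinement passes over the module.
+PassManager pm(module->getContext());
+SmallVector<Type> refinedTypes = {
+    RankedTensorType::get({16}, builder.getF32Type())};
+stablehlo::createStablehloRemoveDynamismPipeline(pm, refinedTypes);
+if (failed(pm.run(module)))
+  llvm::errs() << "failed to refine module\n";
+```
+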
+## Example: How is dynamism useful, and how can I use it?
+
+Dynamism has lots of uses; here we'll mainly focus on the common use case for
+shape polymorphism - creating a flexible exported model representation,
+generally used to represent a dynamic batch size or sequence length.
+
+### Static add_one model
+
+We'll use the following simple `add_one` model to demonstrate this:
+
+```py
+def add_one(x):
+  return x + 1
+```
+
+When traced using a `tensor<4xf32>` we'll get the following StableHLO program:
+
+```mlir
+// File: add_one.mlir
+func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  %cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
+  %0 = stablehlo.add %arg0, %cst : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+```
+
+This model will work _only_ for input arguments that have a `tensor<4xf32>`
+shape. If we ever changed our batch size or sequence length, we would need to
+re-trace the source code and re-lower to StableHLO, and there's no guarantee
+that we still have access to the source code!
+
+### Dynamic add_one model
+
+This is where shape polymorphic dynamism comes into play. Instead, JAX and
+PyTorch/XLA can emit the `add_one` model as IR that is valid for dynamic
+shapes, broadcasting the constant to match the dynamic input shape as follows:
+
+```mlir
+// File: add_one_dynamic.mlir
+func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %cst = stablehlo.constant dense<1.0> : tensor<f32>
+  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
+  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
+  %2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
+  %3 = stablehlo.add %arg0, %2 : tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+```
+
+This model representation is much more flexible, and allows deferred
+specification of values like batch size or sequence length. This model can be
+deployed on platforms with dynamic shape support (like [AI Edge][ai-edge]), or
+it can be refined using the dynamism passes mentioned in this documentation.
+
+[ai-edge]:https://github.com/google-ai-edge/ai-edge-torch
+
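+For completeness, a dynamic program of this shape can be produced via JAX's
+export API. This is a sketch assuming a recent `jax` release with the
+`jax.export` module; the symbol name `b` is arbitrary:
+
+```py
+import jax
+import jax.numpy as jnp
+from jax import export
+
+def add_one(x):
+  return x + 1
+
+# "b" is a symbolic batch dimension; it lowers to `?` in the StableHLO types.
+args = jax.ShapeDtypeStruct(export.symbolic_shape("b,"), jnp.float32)
+exported = export.export(jax.jit(add_one))(args)
+print(exported.mlir_module())
+```
+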
+### Refining the dynamic model
+
+For example, the following pass ordering can fully refine this program:
+
+```sh
+stablehlo-opt add_one_dynamic.mlir \
+  --stablehlo-refine-arguments='types=tensor<16xf32>' \
+  --stablehlo-refine-shapes \
+  --stablehlo-canonicalize-dynamism
+```
+
+Incrementally, this is how the program gets transformed:
+
+```mlir
+// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
+func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
+  %c = stablehlo.constant dense<16> : tensor<1xi64>
+  %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
+  ...
+  %3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
+  %4 = stablehlo.add %0, %3 : tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
+// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
+func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+  %c = stablehlo.constant dense<16> : tensor<1xi32>
+  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
+  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
+  return %1 : tensor<16xf32>
+}
+
+// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
+func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
+  %1 = stablehlo.add %arg0, %0 : tensor<16xf32>
+  return %1 : tensor<16xf32>
+}
+
+// (Bonus) Use the `--stablehlo-aggressive-simplification` pass to canonicalize
+// the constant broadcast, leaving us with the original static program in this
+// case.
+func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
+  %0 = stablehlo.add %arg0, %cst : tensor<16xf32>
+  return %0 : tensor<16xf32>
+}
+```
diff --git a/docs/generated/stablehlo_passes.md b/docs/generated/stablehlo_passes.md
index 80fc0867edb..ed591fab242 100755
--- a/docs/generated/stablehlo_passes.md
+++ b/docs/generated/stablehlo_passes.md
@@ -30,18 +30,21 @@ _Canonicalizes StableHLO operations_
_Canonicalizes dynamic StableHLO ops into static ops._

Replaces dynamic StableHLO ops like DynamicReshapeOp with the corresponding
-static counterparts like ReshapeOp if all the dynamic elements of these ops
-are actually constant.
+static counterparts like `DynamicReshapeOp` to `ReshapeOp` or
+`DynamicBroadcastInDim` to `BroadcastInDim` if all the dynamic elements of
+these ops are actually constants.

-For example, if the output_shape operand of DynamicReshapeOp is a constant
-value, then the operation can be transformed to ReshapeOp.
-### `-stablehlo-convert-to-signless`
+```
+  %c = stablehlo.constant dense<16> : tensor<1xi32>
+  %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>

-_Pass to transform the IR to be on signless integers._
+  ==>

-### `-stablehlo-create-compatibility-expander`
+  %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
+```
+### `-stablehlo-compatibility-expander`

-_Create compatibility expander for StableHLO operations._
+_Compatibility expander for StableHLO operations._

StableHLO ops gets updates or new op is introduced in the latest versions.
This opt-in pass expands backward compatibility with older StableHLO
@@ -65,11 +68,9 @@ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
  %1 = stablehlo.tan %arg0 : tensor<4xf64>
  func.return %1 : tensor<4xf64>
}
-```

-will become:
+==>

-```mlir
func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
  %0 = stablehlo.sine %arg0 : tensor<4xf64>
  %1 = stablehlo.cosine %arg0 : tensor<4xf64>
@@ -82,9 +83,13 @@ func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> {
```

-target : The target version. Must be a version of the form #.#.#.
``` +### `-stablehlo-convert-to-signless` + +_Pass to transform the IR to be on signless integers._ + ### `-stablehlo-legalize-composite-to-call` -_Replaces composite ops with a call to their decomposition_ +_Replaces composite ops with a call to their decomposition._ Replaces composite ops with a call to their decomposition, e.g. the below: @@ -232,9 +237,23 @@ _Legalize StableHLO to VHLO._ _Refines the argument shapes of the main function._ Modifies the arguments of the main function using the input type signature. -Wraps arguments in custom_call @stablehlo.shape_refinement_operand_wrapper +Wraps arguments in `custom_call @stablehlo.shape_refinement_operand_wrapper` to keep the IR valid before shape refinement is run. +``` +func.func public @main(%arg0: tensor) -> tensor { + ... +} + +==> + +func.func public @main(%arg0: tensor<16xf32>) -> tensor { + %c = stablehlo.constant dense<16> : tensor<1xi64> + %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...} + : (tensor<16xf32>, tensor<1xi64>) -> tensor + ... +``` + The `refinedTypesOption` can be used to specify a list of refined types. This can be specified in MLIR with `--types='tensor<...>,tensor<...>'`, or passed to the pass create method. The refinement type list must specify the @@ -255,6 +274,21 @@ programs to static shapes. If a dynamically-shaped StableHLO program has the right structure, then updating its argument types from dynamic shapes to static shapes and running this pass will propagate static shapes across the program. + +This pass removes `custom_call @shape_refinement_operand_wrapper` by +replacing uses of the result with the operand directly, and propagates +static shapes throughout the program. + +``` +%c = stablehlo.constant dense<16> : tensor<1xi64> +%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...} + : (tensor<16xf32>, tensor<1xi64>) -> tensor +%1 = stablehlo.add %0, %0 : tensor + +==> + +%1 = stablehlo.add %arg0, %arg0 : tensor<16xf32> +``` ### `-vhlo-legalize-to-stablehlo` _Legalize VHLO to StableHLO._ diff --git a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir index 569db7cb1d1..76aa10e488f 100644 --- a/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +++ b/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir @@ -1,5 +1,5 @@ -// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.0.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK -// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE +// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-compatibility-expander='target=1.0.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK +// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-compatibility-expander='target=1.6.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE // ----- diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index 96d51104cad..ccdcc26ed03 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -20,9 +20,9 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td) 
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters) add_public_tablegen_target(ChloDecompositionPatternsIncGen) -set(LLVM_TARGET_DEFINITIONS StablehloCreateCompatibilityExpanderPatterns.td) -mlir_tablegen(StablehloCreateCompatibilityExpanderPatterns.h.inc --gen-rewriters) -add_public_tablegen_target(StablehloCreateCompatibilityExpanderPatternsIncGen) +set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) +mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td) mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters) @@ -42,7 +42,7 @@ add_mlir_dialect_library(StablehloPasses StablehloAggressiveSimplification.cpp StablehloCanonicalizeDynamism.cpp StablehloConvertToSignless.cpp - StablehloCreateCompatibilityExpander.cpp + StablehloCompatibilityExpander.cpp StablehloLegalizeCompositeToCall.cpp StablehloLegalizeDeprecatedOps.cpp StablehloLegalizeQuantToMath.cpp @@ -60,7 +60,7 @@ add_mlir_dialect_library(StablehloPasses StablehloLegalizeDeprecatedOpsPatternsIncGen PassesIncGen VhloToVersionPatterns - StablehloCreateCompatibilityExpanderPatternsIncGen + StablehloCompatibilityExpanderPatternsIncGen LINK_LIBS PUBLIC ChloOps diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 0bbb48cfe37..5fc81c801aa 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -99,7 +99,7 @@ void populateShapeToStablehloPatterns(MLIRContext *context, /// Collection of patterns to create compatibility expander for StableHLO /// operations. -void populateStablehloCreateCompatibilityExpanderPatterns( +void populateStablehloCompatibilityExpanderPatterns( RewritePatternSet *patterns, MLIRContext *context, vhlo::Version targetVersion); diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index cf45240fadc..e5559075b50 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -15,73 +15,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism", "func::FuncOp"> { - let summary = "Canonicalizes dynamic StableHLO ops into static ops."; - let description = [{ - Replaces dynamic StableHLO ops like DynamicReshapeOp with the corresponding - static counterparts like ReshapeOp if all the dynamic elements of these ops - are actually constant. - - For example, if the output_shape operand of DynamicReshapeOp is a constant - value, then the operation can be transformed to ReshapeOp. - }]; -} - -def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { - let summary = "Legalize StableHLO to VHLO."; - let dependentDialects = ["mlir::vhlo::VhloDialect"]; -} - -def StablehloRefineShapesPass : Pass<"stablehlo-refine-shapes", "ModuleOp"> { - let summary = "Refines shapes across a StableHLO program."; - let description = [{ - Walks through a StableHLO program refining shapes within ops. - - The flagship use case for this pass is specializing dynamically-shaped - programs to static shapes. If a dynamically-shaped StableHLO program has the - right structure, then updating its argument types from dynamic shapes to - static shapes and running this pass will propagate static shapes across - the program. 
- }]; -} - -def StablehloRefineArgumentsPass : Pass<"stablehlo-refine-arguments", "ModuleOp"> { - let summary = "Refines the argument shapes of the main function."; - let description = [{ - Modifies the arguments of the main function using the input type signature. - Wraps arguments in custom_call @stablehlo.shape_refinement_operand_wrapper - to keep the IR valid before shape refinement is run. - - The `refinedTypesOption` can be used to specify a list of refined types. - This can be specified in MLIR with `--types='tensor<...>,tensor<...>'`, or - passed to the pass create method. The refinement type list must specify the - type of every argument to the `main` method being refined. - }]; - let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; - let options = [ - ListOption<"refinedTypesOption", "types", "std::string", - "The new types to be used for the main function's arguments, specified as an MLIR TypeRange 'tensor<1x2xf32>, ...'">, - ]; -} - -def VhloLegalizeToStablehloPass : Pass<"vhlo-legalize-to-stablehlo", "ModuleOp"> { - let summary = "Legalize VHLO to StableHLO."; - let dependentDialects = [ - "mlir::func::FuncDialect", - "mlir::quant::QuantizationDialect", - "mlir::shape::ShapeDialect", - "mlir::stablehlo::StablehloDialect", - ]; -} - -def VhloToVersionPass : Pass<"vhlo-to-version"> { - let summary = "Convert between versions of VHLO."; - let options = [ - Option<"targetVersionOption", "target", "std::string", "", - "The target version. Must be a version of the form #.#.# .">, - ]; -} - def ChloLegalizeToStablehloPass : Pass<"chlo-legalize-to-stablehlo", "func::FuncOp"> { let summary = "Legalizes from CHLO ops flow to StableHLO and Shape ops"; let dependentDialects = [ @@ -91,6 +24,18 @@ def ChloLegalizeToStablehloPass : Pass<"chlo-legalize-to-stablehlo", "func::Func ]; } +def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::FuncOp"> { + let summary = "Legalize shape-related ops to StableHLO."; + let description = [{ + An experimental pass that legalizes shape-related ops to StableHLO ops. + + Bringing shape and data computations together via an optional pass will + make it possible for the StableHLO ecosystem to potentially leverage the + compilation pipelines that use StableHLO operations to model dynamism. + }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} + def StablehloAggressiveFolderPass : Pass<"stablehlo-aggressive-folder", "func::FuncOp"> { let summary = "Folds StableHLO operations"; @@ -111,40 +56,79 @@ def StablehloAggressiveSimplificationPass ]; } -def StablehloConvertToSignlessPass : Pass<"stablehlo-convert-to-signless", "ModuleOp"> { - let summary = "Pass to transform the IR to be on signless integers."; -} - -def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::FuncOp"> { - let summary = "Legalize shape-related ops to StableHLO."; +def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism", "func::FuncOp"> { + let summary = "Canonicalizes dynamic StableHLO ops into static ops."; let description = [{ - An experimental pass that legalizes shape-related ops to StableHLO ops. + Replaces dynamic StableHLO ops like DynamicReshapeOp with the corresponding + static counterparts like `DynamicReshapeOp` to `ReshapeOp` or + `DynamicBroadcastInDim` to `BroadcastInDim` if all the dynamic elements of = + these ops are actually constants. 
- Bringing shape and data computations together via an optional pass will - make it possible for the StableHLO ecosystem to potentially leverage the - compilation pipelines that use StableHLO operations to model dynamism. + ``` + %c = stablehlo.constant dense<16> : tensor<1xi32> + %0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor, tensor<1xi32>) -> tensor<16xf32> + + ==> + + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<16xf32> + ``` }]; - let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; } -def StablehloLegalizeDeprecatedOpsPass : Pass<"stablehlo-legalize-deprecated-ops", "func::FuncOp"> { - let summary = "Legalize deprecated ops to well-supported ops."; +def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander", "mlir::func::FuncOp"> { + let summary = "Compatibility expander for StableHLO operations."; + let description = [{ - The StableHLO v1.0 Opset Deprecations RFC (#2283) proposes to remove - several redundant ops. This pass helps to evaluate the impact of these op - removals in various compilation pipelines by legalizing them to their - long-term supported counterparts. + StableHLO ops gets updates or new op is introduced in the latest versions. + This opt-in pass expands backward compatibility with older StableHLO + versions by decomposing newer StableHLO operations into equivalent + operations supported by those older versions. + + Why is this an opt-in pass? + + Occasionally, StableHLO op enhancements are used to greatly simplify the + handling of certain common patterns in the OpenXLA ecosystem. This + includes things like TanOp, which has high framework and compiler support, + as well as gather/scatter batching dimensions, which can be represented + using slices, but makes sharding much more difficult. For this category of + new features, we do not offer automatic downgrade, since it may throw away + important information used in subsequent optimizations. This pass can be + used to expand these ops based on a target version to maximize compatibility + at the expense of potentially less optimal compilation. + + ```mlir + func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { + %1 = stablehlo.tan %arg0 : tensor<4xf64> + func.return %1 : tensor<4xf64> + } + + ==> + + func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { + %0 = stablehlo.sine %arg0 : tensor<4xf64> + %1 = stablehlo.cosine %arg0 : tensor<4xf64> + %2 = stablehlo.divide %0, %1 : tensor<4xf64> + return %2 : tensor<4xf64> + } + ``` }]; - let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; let options = [ - Option<"failOnUnusedOps", "fail-on-unused", "bool", /*default=*/"true", - "Fail on (mostly) unused ops that are deprecated without any fallback.">, + Option<"targetVersionOption", "target", "std::string", "", + "The target version. Must be a version of the form #.#.#.">, ]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::chlo::ChloDialect", + ]; +} + +def StablehloConvertToSignlessPass : Pass<"stablehlo-convert-to-signless", "ModuleOp"> { + let summary = "Pass to transform the IR to be on signless integers."; } def StablehloLegalizeCompositeToCallPass : Pass<"stablehlo-legalize-composite-to-call", "func::FuncOp"> { - let summary = "Replaces composite ops with a call to their decomposition"; + let summary = "Replaces composite ops with a call to their decomposition."; let description = [{ Replaces composite ops with a call to their decomposition, e.g. 
the below: @@ -179,51 +163,51 @@ def StablehloLegalizeCompositeToCallPass : ]; } -def StablehloLegalizeQuantToMathPass : Pass<"stablehlo-legalize-quant-to-math", "mlir::func::FuncOp"> { - let summary = "Convert from StableHLO quantized ops to StableHLO primitive math ops."; +def StablehloLegalizeDeprecatedOpsPass : Pass<"stablehlo-legalize-deprecated-ops", "func::FuncOp"> { + let summary = "Legalize deprecated ops to well-supported ops."; + let description = [{ + The StableHLO v1.0 Opset Deprecations RFC (#2283) proposes to remove + several redundant ops. This pass helps to evaluate the impact of these op + removals in various compilation pipelines by legalizing them to their + long-term supported counterparts. + }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; + let options = [ + Option<"failOnUnusedOps", "fail-on-unused", "bool", /*default=*/"true", + "Fail on (mostly) unused ops that are deprecated without any fallback.">, + ]; +} + +def StablehloLegalizeQDQToQuantizedOpPass : Pass<"stablehlo-legalize-qdq-to-quantized-op", "mlir::func::FuncOp"> { + let summary = "Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation"; let description = [{ - Convert StableHLO programs using UniformQuantized types to semantically - equivalent integer math operations. + Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation + Note: The pass does not delete any preexisting op. + For example, the following program ```mlir - func.func @add(%arg0: tensor>, %arg1: tensor>) -> tensor> { - %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> - func.return %0 : tensor> + func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 : tensor<16x16xf32> + %2 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + func.return %2 : tensor<16x16x!quant.uniform> } ``` Will become: ```mlir - func.func @add(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = stablehlo.convert %arg0 : (tensor) -> tensor - %cst = stablehlo.constant dense<0.333333343> : tensor - %1 = chlo.broadcast_multiply %0, %cst : (tensor, tensor) -> tensor - %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor - %2 = chlo.broadcast_add %1, %cst_0 : (tensor, tensor) -> tensor - %3 = stablehlo.round_nearest_even %2 : tensor - %4 = stablehlo.convert %3 : (tensor) -> tensor - %5 = stablehlo.convert %arg1 : (tensor) -> tensor - %cst_1 = stablehlo.constant dense<0.666666686> : tensor - %6 = chlo.broadcast_multiply %5, %cst_1 : (tensor, tensor) -> tensor - %cst_2 = stablehlo.constant dense<1.33333337> : tensor - %7 = chlo.broadcast_add %6, %cst_2 : (tensor, tensor) -> tensor - %8 = stablehlo.round_nearest_even %7 : tensor - %9 = stablehlo.convert %8 : (tensor) -> tensor - %c = stablehlo.constant dense<2> : tensor - %10 = chlo.broadcast_add %4, %9 : (tensor, tensor) -> tensor - %11 = chlo.broadcast_subtract %10, %c : (tensor, tensor) -> tensor - %c_3 = stablehlo.constant dense<-128> : tensor - %c_4 = stablehlo.constant dense<127> : tensor - %12 = stablehlo.clamp %c_3, %11, %c_4 : tensor - %13 = stablehlo.convert %12 : (tensor) -> tensor - return %13 : tensor + func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { + %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %1 = stablehlo.abs %0 
: tensor<16x16xf32> + %2 = stablehlo.abs %arg0 : tensor<16x16x!quant.uniform> + %3 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + return %2 : tensor<16x16x!quant.uniform> } ``` }]; let dependentDialects = [ - "mlir::chlo::ChloDialect", "mlir::stablehlo::StablehloDialect", ]; } @@ -259,85 +243,135 @@ def StablehloLegalizeQuantizedOpToQDQPass : Pass<"stablehlo-legalize-quantized-o ]; } -def StablehloLegalizeQDQToQuantizedOpPass : Pass<"stablehlo-legalize-qdq-to-quantized-op", "mlir::func::FuncOp"> { - let summary = "Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation"; +def StablehloLegalizeQuantToMathPass : Pass<"stablehlo-legalize-quant-to-math", "mlir::func::FuncOp"> { + let summary = "Convert from StableHLO quantized ops to StableHLO primitive math ops."; let description = [{ - Fuse (de-quantize, floating-point operation and quantize) pattern into StableHLO quantized operation - Note: The pass does not delete any preexisting op. - For example, the following program + Convert StableHLO programs using UniformQuantized types to semantically + equivalent integer math operations. ```mlir - func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { - %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> - %1 = stablehlo.abs %0 : tensor<16x16xf32> - %2 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> - func.return %2 : tensor<16x16x!quant.uniform> + func.func @add(%arg0: tensor>, %arg1: tensor>) -> tensor> { + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> } ``` Will become: ```mlir - func.func @add(%arg0: tensor<16x16x!quant.uniform>) -> tensor<16x16x!quant.uniform> { - %0 = stablehlo.uniform_dequantize %arg0 : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> - %1 = stablehlo.abs %0 : tensor<16x16xf32> - %2 = stablehlo.abs %arg0 : tensor<16x16x!quant.uniform> - %3 = stablehlo.uniform_quantize %1 : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> - return %2 : tensor<16x16x!quant.uniform> + func.func @add(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.convert %arg0 : (tensor) -> tensor + %cst = stablehlo.constant dense<0.333333343> : tensor + %1 = chlo.broadcast_multiply %0, %cst : (tensor, tensor) -> tensor + %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor + %2 = chlo.broadcast_add %1, %cst_0 : (tensor, tensor) -> tensor + %3 = stablehlo.round_nearest_even %2 : tensor + %4 = stablehlo.convert %3 : (tensor) -> tensor + %5 = stablehlo.convert %arg1 : (tensor) -> tensor + %cst_1 = stablehlo.constant dense<0.666666686> : tensor + %6 = chlo.broadcast_multiply %5, %cst_1 : (tensor, tensor) -> tensor + %cst_2 = stablehlo.constant dense<1.33333337> : tensor + %7 = chlo.broadcast_add %6, %cst_2 : (tensor, tensor) -> tensor + %8 = stablehlo.round_nearest_even %7 : tensor + %9 = stablehlo.convert %8 : (tensor) -> tensor + %c = stablehlo.constant dense<2> : tensor + %10 = chlo.broadcast_add %4, %9 : (tensor, tensor) -> tensor + %11 = chlo.broadcast_subtract %10, %c : (tensor, tensor) -> tensor + %c_3 = stablehlo.constant dense<-128> : tensor + %c_4 = stablehlo.constant dense<127> : tensor + %12 = stablehlo.clamp %c_3, %11, %c_4 : tensor + %13 = stablehlo.convert %12 : (tensor) -> tensor + return %13 : tensor } ``` }]; let dependentDialects = [ + "mlir::chlo::ChloDialect", "mlir::stablehlo::StablehloDialect", ]; } -def 
StablehloCreateCompatibilityExpanderPass : Pass<"stablehlo-create-compatibility-expander", "mlir::func::FuncOp"> { - let summary = "Create compatibility expander for StableHLO operations."; +def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { + let summary = "Legalize StableHLO to VHLO."; + let dependentDialects = ["mlir::vhlo::VhloDialect"]; +} +def StablehloRefineArgumentsPass : Pass<"stablehlo-refine-arguments", "ModuleOp"> { + let summary = "Refines the argument shapes of the main function."; let description = [{ - StableHLO ops gets updates or new op is introduced in the latest versions. - This opt-in pass expands backward compatibility with older StableHLO - versions by decomposing newer StableHLO operations into equivalent - operations supported by those older versions. - - Why is this an opt-in pass? - - Occasionally, StableHLO op enhancements are used to greatly simplify the - handling of certain common patterns in the OpenXLA ecosystem. This - includes things like TanOp, which has high framework and compiler support, - as well as gather/scatter batching dimensions, which can be represented - using slices, but makes sharding much more difficult. For this category of - new features, we do not offer automatic downgrade, since it may throw away - important information used in subsequent optimizations. This pass can be - used to expand these ops based on a target version to maximize compatibility - at the expense of potentially less optimal compilation. + Modifies the arguments of the main function using the input type signature. + Wraps arguments in `custom_call @stablehlo.shape_refinement_operand_wrapper` + to keep the IR valid before shape refinement is run. - ```mlir - func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { - %1 = stablehlo.tan %arg0 : tensor<4xf64> - func.return %1 : tensor<4xf64> - } ``` + func.func public @main(%arg0: tensor) -> tensor { + ... + } - will become: + ==> - ```mlir - func.func @tan_op_non_complex(%arg0: tensor<4xf64>) -> tensor<4xf64> { - %0 = stablehlo.sine %arg0 : tensor<4xf64> - %1 = stablehlo.cosine %arg0 : tensor<4xf64> - %2 = stablehlo.divide %0, %1 : tensor<4xf64> - return %2 : tensor<4xf64> - } + func.func public @main(%arg0: tensor<16xf32>) -> tensor { + %c = stablehlo.constant dense<16> : tensor<1xi64> + %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...} + : (tensor<16xf32>, tensor<1xi64>) -> tensor + ... ``` + + The `refinedTypesOption` can be used to specify a list of refined types. + This can be specified in MLIR with `--types='tensor<...>,tensor<...>'`, or + passed to the pass create method. The refinement type list must specify the + type of every argument to the `main` method being refined. }]; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; let options = [ - Option<"targetVersionOption", "target", "std::string", "", - "The target version. Must be a version of the form #.#.#.">, + ListOption<"refinedTypesOption", "types", "std::string", + "The new types to be used for the main function's arguments, specified as an MLIR TypeRange 'tensor<1x2xf32>, ...'">, ]; +} + +def StablehloRefineShapesPass : Pass<"stablehlo-refine-shapes", "ModuleOp"> { + let summary = "Refines shapes across a StableHLO program."; + let description = [{ + Walks through a StableHLO program refining shapes within ops. + + The flagship use case for this pass is specializing dynamically-shaped + programs to static shapes. 
If a dynamically-shaped StableHLO program has the
+    right structure, then updating its argument types from dynamic shapes to
+    static shapes and running this pass will propagate static shapes across
+    the program.
+
+    This pass removes `custom_call @shape_refinement_operand_wrapper` by
+    replacing uses of the result with the operand directly, and propagates
+    static shapes throughout the program.
+
+    ```
+    %c = stablehlo.constant dense<16> : tensor<1xi64>
+    %0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {...}
+      : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
+    %1 = stablehlo.add %0, %0 : tensor<?xf32>
+
+    ==>
+
+    %1 = stablehlo.add %arg0, %arg0 : tensor<16xf32>
+    ```
+  }];
+}
+
+def VhloLegalizeToStablehloPass : Pass<"vhlo-legalize-to-stablehlo", "ModuleOp"> {
+  let summary = "Legalize VHLO to StableHLO.";
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::quant::QuantizationDialect",
+    "mlir::shape::ShapeDialect",
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
+
+def VhloToVersionPass : Pass<"vhlo-to-version"> {
+  let summary = "Convert between versions of VHLO.";
+  let options = [
+    Option<"targetVersionOption", "target", "std::string", "",
+           "The target version. Must be a version of the form #.#.# .">,
+  ];
+}
diff --git a/stablehlo/transforms/StablehloCompatibilityExpander.cpp b/stablehlo/transforms/StablehloCompatibilityExpander.cpp
new file mode 100644
index 00000000000..03bb810067d
--- /dev/null
+++ b/stablehlo/transforms/StablehloCompatibilityExpander.cpp
@@ -0,0 +1,255 @@
+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cassert>
+
+#include <algorithm>
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
+#include "stablehlo/dialect/Version.h"
+#include "stablehlo/transforms/PassUtils.h"
+#include "stablehlo/transforms/Passes.h"
+
+namespace mlir {
+namespace stablehlo {
+#define GEN_PASS_DEF_STABLEHLOCOMPATIBILITYEXPANDERPASS
+#include "stablehlo/transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helpers.
+//===----------------------------------------------------------------------===//
+
+// Check user-specified target version.
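+// Note that validation is assert-based: an empty or malformed `target`
+// string, or a version outside [minimum supported, current], trips an assert
+// in debug builds rather than emitting a proper pass error.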
+vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
+  auto failOrVersion = vhlo::Version::fromString(versionRef);
+  if (failed(failOrVersion)) {
+    assert(!versionRef.empty() &&
+           "No target version specified. Target version must be of the form "
+           "`#.#.#`.");
+    assert(versionRef.empty() &&
+           "Invalid target version argument. Target version must be of the "
+           "form `#.#.#`.");
+  }
+  vhlo::Version targetVersion = *failOrVersion;
+  assert((vhlo::Version::getMinimumVersion() <= targetVersion) &&
+         "target version is less than minimum supported.");
+  assert((targetVersion <= vhlo::Version::getCurrentVersion()) &&
+         "target version is greater than current version.");
+  return targetVersion;
+}
+
+SmallVector<int64_t> mergeSortedDims(ArrayRef<int64_t> dims1,
+                                     ArrayRef<int64_t> dims2) {
+  SmallVector<int64_t> result;
+  result.reserve(dims1.size() + dims2.size());
+  std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(),
+             std::back_inserter(result));
+  return result;
+}
+
+// Returns an updated indices tensor such that an `IotaOp` is prepended for
+// each dim in `indicesBatchingDims` with a `ConcatenateOp`.
+//
+// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to
+// have a trailing dimension of size 1 so it can be concatenated with the
+// `IotaOp`s.
+Value createConcatIndices(Value indices, int64_t indexVectorDim,
+                          ArrayRef<int64_t> indicesBatchingDims,
+                          PatternRewriter &rewriter) {
+  Location loc = indices.getLoc();
+  auto indicesType = cast<RankedTensorType>(indices.getType());
+  bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank();
+
+  SmallVector<int64_t> iotaShape(indicesType.getShape());
+  if (indexVectorDimOnLastDim) {
+    iotaShape.push_back(1);
+  } else {
+    iotaShape[indexVectorDim] = 1;
+  }
+  auto iotaType =
+      RankedTensorType::get(iotaShape, indicesType.getElementType());
+
+  SmallVector<Value> indicesToConcat;
+  indicesToConcat.reserve(indicesBatchingDims.size() + 1);
+  for (int64_t batchingDim : indicesBatchingDims) {
+    indicesToConcat.push_back(
+        rewriter.create<IotaOp>(loc, iotaType, batchingDim));
+  }
+  if (indexVectorDimOnLastDim) {
+    indicesToConcat.push_back(
+        rewriter.create<ReshapeOp>(loc, iotaType, indices));
+  } else {
+    indicesToConcat.push_back(indices);
+  }
+  return rewriter.create<ConcatenateOp>(loc, indicesToConcat, indexVectorDim);
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns (non DRR)
+//===----------------------------------------------------------------------===//
+
+// Converts a `GatherOp` with batching dims to a `GatherOp` without batching
+// dims, such that each batching dim becomes a collapsed slice dim with a
+// corresponding `IotaOp` concatenated to the start indices.
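+//
+// Sketch of the net effect (illustrative, not from the test suite): a gather
+// with operand_batching_dims = [0] and start_indices_batching_dims = [0]
+// becomes a gather with no batching dims whose start indices have a
+// `stablehlo.iota dim = 0` prepended via `stablehlo.concatenate`, and whose
+// collapsed_slice_dims / start_index_map additionally contain dim 0.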
+class GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers();
+    ArrayRef<int64_t> operandBatchingDims = dimNumbers.getOperandBatchingDims();
+    if (operandBatchingDims.empty())
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "gather op has no batching dims";
+      });
+
+    SmallVector<int64_t> newCollapsedSliceDims = mergeSortedDims(
+        operandBatchingDims, dimNumbers.getCollapsedSliceDims());
+    SmallVector<int64_t> newStartIndexMap =
+        llvm::to_vector(llvm::concat<const int64_t>(
+            operandBatchingDims, dimNumbers.getStartIndexMap()));
+    Value newIndices = createConcatIndices(
+        op.getStartIndices(), dimNumbers.getIndexVectorDim(),
+        dimNumbers.getStartIndicesBatchingDims(), rewriter);
+    rewriter.replaceOpWithNewOp<GatherOp>(
+        op, op.getOperand(), newIndices,
+        GatherDimensionNumbersAttr::get(
+            op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims,
+            /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{},
+            newStartIndexMap, dimNumbers.getIndexVectorDim()),
+        op.getSliceSizes(), /*indicesAreSorted=*/false);
+
+    return success();
+  }
+};
+
+// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching
+// dims, such that each batching dim becomes an inserted window dim with a
+// corresponding `IotaOp` concatenated to the scatter indices.
+class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers();
+    ArrayRef<int64_t> inputBatchingDims = dimNumbers.getInputBatchingDims();
+    if (inputBatchingDims.empty())
+      return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
+        diag << "scatter op has no batching dims";
+      });
+
+    SmallVector<int64_t> newInsertedWindowDims =
+        mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims());
+    SmallVector<int64_t> newScatterDimsToOperandDims =
+        llvm::to_vector(llvm::concat<const int64_t>(
+            inputBatchingDims, dimNumbers.getScatterDimsToOperandDims()));
+    Value newIndices = createConcatIndices(
+        op.getScatterIndices(), dimNumbers.getIndexVectorDim(),
+        dimNumbers.getScatterIndicesBatchingDims(), rewriter);
+    auto newScatterOp = rewriter.create<ScatterOp>(
+        op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices,
+        op.getUpdates(),
+        ScatterDimensionNumbersAttr::get(
+            op.getContext(), dimNumbers.getUpdateWindowDims(),
+            newInsertedWindowDims,
+            /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{},
+            newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()),
+        /*indicesAreSorted=*/false, op.getUniqueIndices());
+
+    newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation());
+    rewriter.replaceOp(op, newScatterOp.getResults());
+
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+struct StablehloCompatibilityExpanderPass
+    : public impl::StablehloCompatibilityExpanderPassBase<
+          StablehloCompatibilityExpanderPass> {
+  StablehloCompatibilityExpanderPass()
+      : StablehloCompatibilityExpanderPassBase<
+            StablehloCompatibilityExpanderPass>() {}
+  StablehloCompatibilityExpanderPass(
+      const StablehloCompatibilityExpanderPassOptions &opts)
+      : StablehloCompatibilityExpanderPassBase<
+            StablehloCompatibilityExpanderPass>(opts) {}
+
+ public:
+  LogicalResult initialize(MLIRContext *context) override {
+    auto targetVersion = validateTargetVersion(targetVersionOption);
+
+    config.useTopDownTraversal = true;
+    RewritePatternSet patterns_(context);
+    populateStablehloCompatibilityExpanderPatterns(&patterns_, context,
+                                                   targetVersion);
+    patterns = std::move(patterns_);
+    return success();
+  }
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
+      func.emitError(
+          "Failed to converge StableHLOCompatibilityExpanderPass in ")
+          << config.maxIterations << " iterations";
+      signalPassFailure();
+    }
+  }
+
+ private:
+  FrozenRewritePatternSet patterns;
+  GreedyRewriteConfig config;
+};
+
+#include "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc"
+
+}  // namespace
+
+void populateStablehloCompatibilityExpanderPatterns(
+    RewritePatternSet *patterns, MLIRContext *context,
+    vhlo::Version targetVersion) {
+  // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0.
+  if (targetVersion < vhlo::Version(1, 1, 0))
+    patterns
+        ->add<GatherWithBatchingDimsExpander, ScatterWithBatchingDimsExpander>(
+            context);
+
+  // StableHLO TanOp is introduced in v1.4.0.
+  if (targetVersion < vhlo::Version(1, 4, 0))
+    patterns->add<TanOp_ComplexElementType_CompatibilityExpander,
+                  TanOp_CompatibilityExpander>(context);
+}
+
+}  // namespace stablehlo
+}  // namespace mlir
diff --git a/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/transforms/StablehloCompatibilityExpanderPatterns.td
similarity index 100%
rename from stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
rename to stablehlo/transforms/StablehloCompatibilityExpanderPatterns.td
diff --git a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
index 0761a8d80cb..6d5d19eb56b 100644
--- a/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
+++ b/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
@@ -38,7 +38,7 @@ limitations under the License.
namespace mlir {
namespace stablehlo {
-#define GEN_PASS_DEF_STABLEHLOCREATECOMPATIBILITYEXPANDERPASS
+#define GEN_PASS_DEF_STABLEHLOCOMPATIBILITYEXPANDERPASS
#include "stablehlo/transforms/Passes.h.inc"

namespace {
@@ -196,16 +196,16 @@ class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
// Pass
//===----------------------------------------------------------------------===//

-struct StablehloCreateCompatibilityExpanderPass
-    : public impl::StablehloCreateCompatibilityExpanderPassBase<
-          StablehloCreateCompatibilityExpanderPass> {
-  StablehloCreateCompatibilityExpanderPass()
-      : StablehloCreateCompatibilityExpanderPassBase<
-            StablehloCreateCompatibilityExpanderPass>() {}
-  StablehloCreateCompatibilityExpanderPass(
-      const StablehloCreateCompatibilityExpanderPassOptions &opts)
-      : StablehloCreateCompatibilityExpanderPassBase<
-            StablehloCreateCompatibilityExpanderPass>(opts) {}
+struct StablehloCompatibilityExpanderPass
+    : public impl::StablehloCompatibilityExpanderPassBase<
+          StablehloCompatibilityExpanderPass> {
+  StablehloCompatibilityExpanderPass()
+      : StablehloCompatibilityExpanderPassBase<
+            StablehloCompatibilityExpanderPass>() {}
+  StablehloCompatibilityExpanderPass(
+      const StablehloCompatibilityExpanderPassOptions &opts)
+      : StablehloCompatibilityExpanderPassBase<
+            StablehloCompatibilityExpanderPass>(opts) {}

 public:
  LogicalResult initialize(MLIRContext *context) override {
@@ -213,8 +213,8 @@ struct StablehloCreateCompatibilityExpanderPass

    config.useTopDownTraversal = true;
    RewritePatternSet patterns_(context);
-    populateStablehloCreateCompatibilityExpanderPatterns(&patterns_, context,
-                                                         targetVersion);
+    populateStablehloCompatibilityExpanderPatterns(&patterns_, context,
+                                                   targetVersion);
    patterns = std::move(patterns_);
    return success();
  }
@@ -223,7 +223,7 @@ struct StablehloCreateCompatibilityExpanderPass
    auto func = getOperation();
    if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
      func.emitError(
-          "Failed to converge StableHLOCreateCompatibilityExpanderPass in ")
+          "Failed to converge StableHLOCompatibilityExpanderPass in ")
          << config.maxIterations << " iterations";
      signalPassFailure();
    }
@@ -234,11 +234,11 @@ struct StablehloCreateCompatibilityExpanderPass
  GreedyRewriteConfig config;
};

-#include "stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.h.inc"
+#include "stablehlo/transforms/StablehloCompatibilityExpanderPatterns.h.inc"

}  // namespace

-void populateStablehloCreateCompatibilityExpanderPatterns(
+void populateStablehloCompatibilityExpanderPatterns(
    RewritePatternSet *patterns, MLIRContext *context,
    vhlo::Version targetVersion) {
  // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0.