From c8c435b87f21f9a2bd898cca5421970ed6c16557 Mon Sep 17 00:00:00 2001 From: Tom Natan Date: Thu, 18 Jul 2024 12:17:51 -0700 Subject: [PATCH] #sdy Make `-sharding-constraint-to-reshard` pass simple by always replacing an `sdy.sharding_constraint` with an `sdy.reshard`. Reshards will be optimized using canonicalization patterns. PiperOrigin-RevId: 653713049 --- .../export/sharding_constraint_to_reshard.cc | 15 ++------------ .../test/sharding_constraint_to_reshard.mlir | 20 +++---------------- 2 files changed, 5 insertions(+), 30 deletions(-) diff --git a/shardy/dialect/sdy/transforms/export/sharding_constraint_to_reshard.cc b/shardy/dialect/sdy/transforms/export/sharding_constraint_to_reshard.cc index ff31a71..c895cef 100644 --- a/shardy/dialect/sdy/transforms/export/sharding_constraint_to_reshard.cc +++ b/shardy/dialect/sdy/transforms/export/sharding_constraint_to_reshard.cc @@ -24,7 +24,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "shardy/dialect/sdy/ir/utils.h" namespace mlir { namespace sdy { @@ -43,18 +42,8 @@ class ShardingConstraintPattern LogicalResult matchAndRewrite( ShardingConstraintOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - Value input = adaptor.getInput(); - if (TensorShardingAttr sharding = getSharding(input)) { - if (sharding != op.getShardingAttr()) { - rewriter.replaceOpWithNewOp<ReshardOp>(op, input) - .setShardingAttr(op.getShardingAttr()); - } else { - rewriter.replaceOp(op, input); - } - } else { - setSharding(input, op.getShardingAttr()); - rewriter.replaceOp(op, input); - } + rewriter.replaceOpWithNewOp<ReshardOp>(op, adaptor.getInput()) + .setShardingAttr(adaptor.getSharding()); return success(); } }; diff --git a/shardy/dialect/sdy/transforms/export/test/sharding_constraint_to_reshard.mlir b/shardy/dialect/sdy/transforms/export/test/sharding_constraint_to_reshard.mlir index 3c442d6..4f15c33 
100644 --- a/shardy/dialect/sdy/transforms/export/test/sharding_constraint_to_reshard.mlir +++ b/shardy/dialect/sdy/transforms/export/test/sharding_constraint_to_reshard.mlir @@ -2,23 +2,9 @@ sdy.mesh @mesh = <"a"=2, "b"=2> -// CHECK-LABEL: func @remove_redundant_sharding_constraint -func.func @remove_redundant_sharding_constraint(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}) -> tensor<8x8xf32> { - // CHECK: return %arg0 : tensor<8x8xf32> - %0 = sdy.sharding_constraint %arg0 <@mesh, [{"a", ?}, {?}]> : tensor<8x8xf32> - return %0 : tensor<8x8xf32> -} - // CHECK-LABEL: func @sharding_constraint_to_reshard -func.func @sharding_constraint_to_reshard(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}) -> tensor<8x8xf32> { - // CHECK: %0 = sdy.reshard %arg0 <@mesh, [{?}, {?}]> : tensor<8x8xf32> - %0 = sdy.sharding_constraint %arg0 <@mesh, [{?}, {?}]> : tensor<8x8xf32> - return %0 : tensor<8x8xf32> -} - -// CHECK: func.func @get_sharding_from_sharding_constraint(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {?}]>}) -> tensor<8x8xf32> -func.func @get_sharding_from_sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: return %arg0 : tensor<8x8xf32> - %0 = sdy.sharding_constraint %arg0 <@mesh, [{"a", ?}, {?}]> : tensor<8x8xf32> +func.func @sharding_constraint_to_reshard(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK: %0 = sdy.reshard %arg0 <@mesh, [{"a"}, {?}]> : tensor<8x8xf32> + %0 = sdy.sharding_constraint %arg0 <@mesh, [{"a"}, {?}]> : tensor<8x8xf32> return %0 : tensor<8x8xf32> }