
Commit c94100f

#sdy copy sharding of ShardingConstraintOp to input even if it has other users of type `ShardingConstraintOp` or `ManualComputationOp`, as long as they have the same sharding.

PiperOrigin-RevId: 687233502
tomnatan30 authored and copybara-github committed Oct 18, 2024
1 parent d247923 commit c94100f
Showing 3 changed files with 48 additions and 7 deletions.
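
To illustrate the behavior change, here is a sketch adapted from the new test cases in this commit: previously, a `ShardingConstraintOp` whose input had any other `ShardingConstraintOp` or `ManualComputationOp` user would not have its sharding copied to the input; now it is copied, provided every such user carries the same sharding.

// Sketch (adapted from the new tests below): both constraint users of %0
// carry the identical sharding, so it is now copied onto the defining op.
%0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>
%2 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>
// => %0 gets {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>}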
22 changes: 20 additions & 2 deletions shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc
@@ -33,6 +33,24 @@ namespace sdy {
 
 namespace {
 
+// Returns true if `input` is used by any `ShardingConstraintOp` or
+// `ManualComputationOp` that has a different sharding than `sharding`.
+bool isUsedByConstraintWithDifferentSharding(Value input,
+                                             TensorShardingAttr sharding) {
+  return llvm::any_of(input.getUses(), [&](OpOperand& use) {
+    if (auto otherShardingConstraint =
+            dyn_cast<ShardingConstraintOp>(use.getOwner())) {
+      return otherShardingConstraint.getSharding() != sharding;
+    }
+    if (auto manualComputation =
+            dyn_cast<ManualComputationOp>(use.getOwner())) {
+      return manualComputation.getInSharding(use.getOperandNumber()) !=
+             sharding;
+    }
+    return false;
+  });
+}
+
 // Returns true if `input` is used by any `ShardingConstraintOp` or
 // `ManualComputationOp`, that isn't `optionalShardingConstraint` if provided.
 bool isUsedByOtherShardingConstraint(
@@ -62,8 +80,8 @@ bool shouldApply(Value input, TensorShardingAttr sharding, Operation* op) {
 
   // TODO(b/358627707): revisit restricting to a single use if not dangling.
   // Return true if `input` has no other uses of type `ShardingConstraintOp` or
-  // `ManualComputationOp`
-  return !isUsedByOtherShardingConstraint(input, op);
+  // `ManualComputationOp` with a different sharding.
+  return !isUsedByConstraintWithDifferentSharding(input, sharding);
 }
 
 // If `curShardingConstraintOp` is the last `ShardingConstraintOp` in a chain
2 changes: 1 addition & 1 deletion shardy/dialect/sdy/transforms/import/passes.td
@@ -58,7 +58,7 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func:
       of all targets of the edge.
     * The sharding of the `ShardingConstraintOp` is fully closed.
    * The input doesn't have any other users of type `ShardingConstraintOp` or
-      `ManualComputationOp`.
+      `ManualComputationOp` with a different sharding.
 
     Which indicates that the `ShardingConstraintOp` dictates the sharding of
     its input.
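
To make the last condition concrete, a minimal sketch of the blocking case (hypothetical shardings, assuming a mesh with axes "a" and "b"; the renamed @has_different_sharding_constraint_user test below exercises this):

// The two constraints on %0 disagree, so neither sharding is applied to %0:
%0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {}]> : tensor<8x8xf32>
%2 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>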
@@ -55,8 +55,8 @@ func.func @no_other_sharding_constraint_users(%arg0: tensor<8x8xf32>)
   return %0, %1, %2 : tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>
 }
 
-// CHECK-LABEL: func @has_other_sharding_constraint_user
-func.func @has_other_sharding_constraint_user(%arg0: tensor<8x8xf32>)
+// CHECK-LABEL: func @has_different_sharding_constraint_user
+func.func @has_different_sharding_constraint_user(%arg0: tensor<8x8xf32>)
     -> (tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>) {
   // CHECK-NEXT: stablehlo.add %arg0, %arg0
   // CHECK-NOT: sdy.sharding
@@ -67,8 +67,18 @@ func.func @has_other_sharding_constraint_user(%arg0: tensor<8x8xf32>)
   return %0, %1, %2 : tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>
 }
 
-// CHECK-LABEL: func @has_other_manual_computation_user
-func.func @has_other_manual_computation_user(%arg0: tensor<8x8xf32>)
+// CHECK-LABEL: func @has_other_identical_sharding_constraint_user
+func.func @has_other_identical_sharding_constraint_user(%arg0: tensor<8x8xf32>)
+    -> (tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>) {
+  // CHECK-NEXT: stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>}
+  %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
+  %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>
+  %2 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>
+  return %0, %1, %2 : tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>
+}
+
+// CHECK-LABEL: func @has_other_manual_computation_user_diff_sharding
+func.func @has_other_manual_computation_user_diff_sharding(%arg0: tensor<8x8xf32>)
     -> (tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>) {
   // CHECK-NEXT: stablehlo.add %arg0, %arg0
   // CHECK-NOT: sdy.sharding
@@ -82,6 +92,19 @@ func.func @has_other_manual_computation_user(%arg0: tensor<8x8xf32>)
   return %0, %1, %2 : tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>
 }
 
+// CHECK-LABEL: func @has_other_manual_computation_user_same_sharding
+func.func @has_other_manual_computation_user_same_sharding(%arg0: tensor<8x8xf32>)
+    -> (tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>) {
+  // CHECK-NEXT: stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>}
+  %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
+  %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32>
+  %2 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{}, {"b"}]>] out_shardings=[<@mesh, [{}, {"b"}]>]
+      manual_axes={"b"} (%arg2: tensor<8x4xf32>) {
+    sdy.return %arg2 : tensor<8x4xf32>
+  } : (tensor<8x8xf32>) -> tensor<8x8xf32>
+  return %0, %1, %2 : tensor<8x8xf32>, tensor<8x8xf32>, tensor<8x8xf32>
+}
+
 // CHECK-LABEL: func @dangling_and_no_other_sharding_constraint_users
 func.func @dangling_and_no_other_sharding_constraint_users(%arg0: tensor<8x8xf32>)
     -> (tensor<8x8xf32>, tensor<8x8xf32>) {
