#sdy Apply input shardings of ManualComputationOp in `-apply-sharding-constraints` pass, since they are in essence sharding constraints given by the user.

PiperOrigin-RevId: 653578847
tomnatan30 authored and copybara-github committed Jul 18, 2024
1 parent 4cd73aa commit 4753ca3
Showing 3 changed files with 47 additions and 6 deletions.
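For illustration only (not part of this commit): the change treats the `in_shardings` of an `sdy.manual_computation` like user-given sharding constraints on its operands. A minimal before/after sketch, assuming a hypothetical mesh `@mesh` with an axis `"a"` of size 2:

```mlir
// Before `-sdy-apply-sharding-constraints` (hypothetical IR): %0 has no
// sharding of its own.
%0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
%1 = sdy.manual_computation(%0) in_shardings=[<@mesh, [{"a"}, {}]>]
    out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={"a"}
    (%arg1: tensor<4x8xf32>) {
  sdy.return %arg1 : tensor<4x8xf32>
} : (tensor<8x8xf32>) -> tensor<8x8xf32>

// After the pass, assuming %0 has no other uses and no pre-existing
// sharding, the in_sharding is copied onto the defining op:
%0 = stablehlo.add %arg0, %arg0
    {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>}
    : tensor<8x8xf32>
```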
33 changes: 27 additions & 6 deletions shardy/dialect/sdy/transforms/import/apply_sharding_constraints.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
#include "third_party/openxla/shardy/src/shardy/dialect/sdy/ir/dialect.h"

namespace mlir {
namespace sdy {
@@ -32,7 +33,7 @@ namespace sdy {
 
 namespace {
 
-bool shouldApply(Value input, ShardingConstraintOp shardingConstraintOp) {
+bool shouldApply(Value input, Operation* op) {
   if (getSharding(input)) {
     // `input` already has a sharding.
     return false;
@@ -43,6 +44,12 @@ bool shouldApply(Value input, ShardingConstraintOp shardingConstraintOp) {
     return true;
   }
 
+  auto shardingConstraintOp = dyn_cast<ShardingConstraintOp>(op);
+  if (!shardingConstraintOp) {
+    // `op` is a `ManualComputationOp` and `input` has other uses.
+    return false;
+  }
+
   // `shardingConstraintOp` is dangling and `input` has no other uses of type
   // `ShardingConstraintOp`.
   return shardingConstraintOp.use_empty() &&
@@ -58,11 +65,25 @@ struct ApplyShardingConstraintsPass
   using ApplyShardingConstraintsPassBase::ApplyShardingConstraintsPassBase;
 
   void runOnOperation() final {
-    getOperation().walk([](ShardingConstraintOp shardingConstraintOp) {
-      Value input = shardingConstraintOp.getInput();
-      if (shouldApply(input, shardingConstraintOp)) {
-        setSharding(input, shardingConstraintOp.getSharding());
-      }
+    getOperation().walk([](Operation* op) {
+      TypeSwitch<Operation*>(op)
+          .Case<ShardingConstraintOp>(
+              [](ShardingConstraintOp shardingConstraintOp) {
+                Value input = shardingConstraintOp.getInput();
+                if (shouldApply(input, shardingConstraintOp)) {
+                  setSharding(input, shardingConstraintOp.getSharding());
+                }
+              })
+          .Case<ManualComputationOp>(
+              [](ManualComputationOp manualComputationOp) {
+                for (auto [operand, sharding] : llvm::zip_equal(
+                         manualComputationOp.getOperands(),
+                         manualComputationOp.getInShardings().getShardings())) {
+                  if (shouldApply(operand, manualComputationOp)) {
+                    setSharding(operand, sharding);
+                  }
+                }
+              });
     });
   }
 };
5 changes: 5 additions & 0 deletions shardy/dialect/sdy/transforms/import/passes.td
@@ -48,6 +48,11 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func:
     input or users during propagation regardless of this pass, but since the
     closed property of a dimension doesn't propagate, it's important to copy the
     sharding to fully respect the constraint in the above cases.
 
+    The `in_shardings` of a `ManualComputationOp` are in essence sharding
+    constraints on the corresponding operands, so this pass also applies
+    their shardings if the above conditions are satisfied (except for the
+    dangling case).
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
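To illustrate the dangling case mentioned above (hypothetical IR, not from this commit): a dangling `sdy.sharding_constraint`, one whose result is unused, still has its sharding applied when its input has no other sharding-constraint users, whereas a `ManualComputationOp` operand only receives the in_sharding when it has no other uses at all:

```mlir
// The constraint's result %1 is unused (dangling), yet its sharding is
// still copied to %0's defining op, because %0 has no other
// `sdy.sharding_constraint` users. Assumes a hypothetical mesh @mesh.
%0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {}]> : tensor<8x8xf32>
```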
@@ -59,3 +59,18 @@ func.func @dangling_and_has_other_sharding_constraint_users(%arg0: tensor<8x8xf3
   return %0 : tensor<8x8xf32>
 }
 
+// CHECK-LABEL: func @manual_computation
+func.func @manual_computation(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) -> (tensor<8x8xf32>, tensor<8x8xf32>) {
+  // CHECK-NEXT: stablehlo.add %arg0, %arg0
+  // CHECK-NOT: sdy.sharding
+  // CHECK-NEXT: stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b", "a"}, {}]>]>}
+  // CHECK-NEXT: sdy.manual_computation
+  %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
+  %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32>
+  %2 = sdy.manual_computation(%0, %1) in_shardings=[<@mesh, [{"a", "b"}, {}]>, <@mesh, [{"b", "a"}, {}]>] out_shardings=[<@mesh, [{"a", "b"}, {}]>]
+      manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<2x8xf32>) {
+    %3 = stablehlo.add %arg2, %arg3 : tensor<2x8xf32>
+    sdy.return %3 : tensor<2x8xf32>
+  } : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32>
+  func.return %0, %2 : tensor<8x8xf32>, tensor<8x8xf32>
+}
