Skip to content

Commit

Permalink
Unhandle the case where tensor factor shardings have overflow axes ex…
Browse files Browse the repository at this point in the history
…plicitly.

If some tensor factor shardings have non-empty overflow axes, the pass of inserting explicit reshards is a no-op. This case is planned to be handled separately.

PiperOrigin-RevId: 688505548
  • Loading branch information
Google-ML-Automation authored and copybara-github committed Oct 22, 2024
1 parent 78ea64e commit 05815cf
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,34 @@ namespace sdy {

namespace {

// Returns true iff any tensor factor sharding has non-empty overflow axes.
bool hasOverflowAxes(const ShardingProjection& projection) {
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
for (const auto& [_, factorSharding] :
tensorFactorSharding.factorIndexToSharding) {
if (!factorSharding.overflowAxes.empty()) {
return true;
}
}
}
return false;
}

// Checks if factor sharding is compatible, that is, it satisfies:
// 1. Factors are sharded the same way across operands and results.
bool hasCompatibleFactorSharding(const ShardingProjection& projection) {
//
// Assumes factor shardings do not have overflow axes.
// TODO(enver): Handle the case when some factor shardings have overflow axes.
bool hasCompatibleFactorShardings(const ShardingProjection& projection) {
FactorIndexToSharding factorIndexToCommonSharding;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
// Detects conflicts within the same factor.
for (const auto& [factorIndex, factorSharding] :
tensorFactorSharding.factorIndexToSharding) {
// TODO(enver): Handle the case when some factor shardings have overflow
// axes.
if (!factorSharding.overflowAxes.empty()) {
return false;
}
auto commonFactorShardingIt =
factorIndexToCommonSharding.find(factorIndex);
if (commonFactorShardingIt == factorIndexToCommonSharding.end()) {
Expand Down Expand Up @@ -103,11 +116,22 @@ struct InsertExplicitReshardsPass
assert(mesh && "unknown mesh");
ShardingProjection shardingProjection =
ShardingProjection::build(op, shardingRule, mesh);

// Return without inserting reshards if any factor sharding has overflow
// axes. This case is not handled yet.
// TODO(enver): Handle the case when factor shardings have overflow axes.
if (hasOverflowAxes(shardingProjection)) {
return;
}

// Checks if factors are sharded the same way across operands and results.
if (hasCompatibleFactorSharding(shardingProjection)) {
if (hasCompatibleFactorShardings(shardingProjection)) {
return;
}

// TODO(enver): Build a projection where, for each factor, factor
// shardings are the same across all operands and results;

// TODO(enver): Insert the explicit reshard ops.
});
}
Expand Down

0 comments on commit 05815cf

Please sign in to comment.