// Patch summary (explanatory text preceding the `diff --git` header; not part of any hunk).
//
// Target: shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
// (index f9f4837..89fc302). The patch has three hunks:
//
//   1. @@ -37,9 +37,27 @@  Adds hasOverflowAxes(const ShardingProjection&), which
//      iterates the factor shardings of every operand and result and returns
//      true as soon as one has a non-empty `overflowAxes`, otherwise false.
//      Also renames hasCompatibleFactorSharding to hasCompatibleFactorShardings
//      and extends its comment to state that it assumes no overflow axes.
//
//   2. @@ -47,11 +65,6 @@  Removes the in-loop check that made the compatibility
//      function return false when a factor sharding had overflow axes.
//
//   3. @@ -103,11 +116,22 @@  In InsertExplicitReshardsPass, after building the
//      ShardingProjection, returns early when hasOverflowAxes(...) is true,
//      switches the call site to the renamed hasCompatibleFactorShardings, and
//      adds a TODO about building a projection with common factor shardings.
//
// Visible control-flow effect: before this patch, overflow axes caused the
// compatibility check to return false, so the caller did not take its early
// `return` and fell through to the TODO section; after it, the caller returns
// before the compatibility check is reached.
//
// NOTE(review): the hunk line breaks appear to have been collapsed into spaces
// in this copy, so it presumably will not apply as-is — restore the original
// newlines/indentation from the upstream commit before using it.
// NOTE(review): `llvm::concat(` appears here without an explicit template
// argument; LLVM's llvm::concat normally takes the value type explicitly
// (e.g. llvm::concat<const T>(...)) — confirm against upstream that the
// angle-bracketed argument was not stripped during extraction.
diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc index f9f4837..89fc302 100644 --- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc @@ -37,9 +37,27 @@ namespace sdy { namespace { +// Returns true iff any tensor factor sharding has non-empty overflow axes. +bool hasOverflowAxes(const ShardingProjection& projection) { + for (const TensorFactorShardings& tensorFactorSharding : + llvm::concat(projection.getOperands(), + projection.getResults())) { + for (const auto& [_, factorSharding] : + tensorFactorSharding.factorIndexToSharding) { + if (!factorSharding.overflowAxes.empty()) { + return true; + } + } + } + return false; +} + // Checks if factor sharding is compatible, that is, it satisfies: // 1. Factors are sharded the same way across operands and results. -bool hasCompatibleFactorSharding(const ShardingProjection& projection) { +// +// Assumes factor shardings do not have overflow axes. +// TODO(enver): Handle the case when some factor shardings have overflow axes. +bool hasCompatibleFactorShardings(const ShardingProjection& projection) { FactorIndexToSharding factorIndexToCommonSharding; for (const TensorFactorShardings& tensorFactorSharding : llvm::concat(projection.getOperands(), @@ -47,11 +65,6 @@ bool hasCompatibleFactorSharding(const ShardingProjection& projection) { // Detects conflicts within the same factor. for (const auto& [factorIndex, factorSharding] : tensorFactorSharding.factorIndexToSharding) { - // TODO(enver): Handle the case when some factor shardings have overflow - // axes. 
- if (!factorSharding.overflowAxes.empty()) { - return false; - } auto commonFactorShardingIt = factorIndexToCommonSharding.find(factorIndex); if (commonFactorShardingIt == factorIndexToCommonSharding.end()) { @@ -103,11 +116,22 @@ struct InsertExplicitReshardsPass assert(mesh && "unknown mesh"); ShardingProjection shardingProjection = ShardingProjection::build(op, shardingRule, mesh); + + // Return without inserting reshards if any factor sharding has overflow + // axes. This case is not handled yet. + // TODO(enver): Handle the case when factor shardings have overflow axes. + if (hasOverflowAxes(shardingProjection)) { + return; + } + // Checks if factors are sharded the same way across operands and results. - if (hasCompatibleFactorSharding(shardingProjection)) { + if (hasCompatibleFactorShardings(shardingProjection)) { return; } + // TODO(enver): Build a projection where, for each factor, factor + // shardings are the same across all operands and results; + // TODO(enver): Insert the explicit reshard ops. }); }