From b2306a78b7884e936005eb46f8a852fae9a930f0 Mon Sep 17 00:00:00 2001
From: shardy authors
Date: Mon, 21 Oct 2024 02:48:27 -0700
Subject: [PATCH] Append Tensor to FactorSharding struct name.

FactorSharding has an isMinorMost field that depends on the tensor, while
the other fields (axisRefs, isClosed, and overflowAxes) can be shared among
the tensors. Hence TensorFactorSharding defines a tensor-dependent factor
sharding.

The target, to be reached by a subsequent change, is the following
structure (see also the illustrative sketch appended after the diff):

TensorFactorSharding {
  FactorSharding factorSharding,
  bool isMinorMost
}

FactorSharding {
  SmallVector<AxisRefAttr> axisRefs,
  bool isClosed,
  SmallVector<AxisRefAttr> overflowAxes
}

One motivation is ShardingProjection::build from common factor shardings,
where isMinorMost needs to be tensor-specific and should be determined by
the sharding rule, not by the common factor sharding.

PiperOrigin-RevId: 688056929
---
 .../export/insert_explicit_reshards.cc        |   4 +-
 .../aggressive_factor_propagation.cc          |   2 +-
 .../propagation/basic_factor_propagation.cc   |  18 +-
 .../propagation/basic_factor_propagation.h    |   6 +-
 .../basic_factor_propagation_test.cc          |  28 ++--
 .../propagation/basic_propagation.cc          |  10 +-
 .../propagation/sharding_projection.cc        |  26 ++--
 .../propagation/sharding_projection.h         |  36 ++---
 .../propagation/sharding_projection_test.cc   | 146 +++++++++---------
 .../propagation/test/basic_propagation.mlir   |   6 +-
 .../transforms/propagation/testing_utils.h    |  10 +-
 11 files changed, 146 insertions(+), 146 deletions(-)

diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
index f9f4837..b8b4d93 100644
--- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
+++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -41,8 +41,8 @@ namespace {
 // 1. Factors are sharded the same way across operands and results.
 bool hasCompatibleFactorSharding(const ShardingProjection& projection) {
   FactorIndexToSharding factorIndexToCommonSharding;
-  for (const TensorFactorShardings& tensorFactorSharding :
-       llvm::concat<const TensorFactorShardings>(projection.getOperands(),
+  for (const TensorFactorShardingMap& tensorFactorSharding :
+       llvm::concat<const TensorFactorShardingMap>(projection.getOperands(),
                                                  projection.getResults())) {
     // Detects conflicts within the same factor.
     for (const auto& [factorIndex, factorSharding] :
diff --git a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc
index 1d6e591..1b2aebc 100644
--- a/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.cc
@@ -70,7 +70,7 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
   // different shardings to different tensors along the same factor. Examples
   // are provided in the docstring of this class.
   for (const auto& [tensorIndex, tensorFactorShardings] :
-       llvm::enumerate(llvm::concat<const TensorFactorShardings>(
+       llvm::enumerate(llvm::concat<const TensorFactorShardingMap>(
            projection.getOperands(), projection.getResults()))) {
     // Propagate the axes got in Step 1, and resolve conflicts within a factor.
     FactorIndexToSharding newSharding =
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc
index a58ef5d..e05af79 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc
@@ -75,7 +75,7 @@ BasicFactorPropagation::compatiblePrefixNoConflictsAcrossFactors(
 std::optional<AxisRefAttr>
 BasicFactorPropagation::compatiblePrefixNoConflictsWithinFactor(
     AxisRefAttr axisRef, ArrayRef<AxisRefAttr> replicatedAxes,
-    const FactorSharding& factorSharding, int64_t prevShardedSize,
+    const TensorFactorSharding& factorSharding, int64_t prevShardedSize,
     int64_t factorSize, MeshAttr mesh) const {
   AxisRefAttr result = axisRef;
 
@@ -158,7 +158,7 @@ void BasicFactorPropagation::truncateAxesByRemovingConflicts(
 namespace {
 
 using DirectionBasedTensorShardings =
-    std::pair<ArrayRef<TensorFactorShardings>, ArrayRef<TensorFactorShardings>>;
+    std::pair<ArrayRef<TensorFactorShardingMap>, ArrayRef<TensorFactorShardingMap>>;
 
 // Gets the tensor shardings that should be processed first and then second.
 //
@@ -175,8 +175,8 @@ using DirectionBasedTensorShardings =
 // on the result factor shardings but not the operands.
 std::optional<DirectionBasedTensorShardings> getDirectionBasedTensorShardings(
     PropagationDirection direction, Operation* op,
-    ArrayRef<TensorFactorShardings> operands,
-    ArrayRef<TensorFactorShardings> results) {
+    ArrayRef<TensorFactorShardingMap> operands,
+    ArrayRef<TensorFactorShardingMap> results) {
   static const char* errMsg =
       "since Shardy is propagating {0} for this op, Shardy may not "
       "fully propagate to each of the multiple {1}s; {0} "
@@ -313,8 +313,8 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes(
   bool canExpand = true;
 
   auto updateCompatibleMajorAxesWithTensors =
-      [&](ArrayRef<TensorFactorShardings> tensors) {
-        for (const TensorFactorShardings& tensor : tensors) {
+      [&](ArrayRef<TensorFactorShardingMap> tensors) {
+        for (const TensorFactorShardingMap& tensor : tensors) {
          if (auto factorShardingIt =
                  tensor.factorIndexToSharding.find(factorIndex);
              factorShardingIt != tensor.factorIndexToSharding.end()) {
@@ -336,7 +336,7 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes(
 }
 
 std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(
-    AxisRefAttr axisRef, const TensorFactorShardings& tensorFactorSharding,
+    AxisRefAttr axisRef, const TensorFactorShardingMap& tensorFactorSharding,
     int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
     MeshAttr mesh) const {
   const FactorIndexToSharding& factorIndexToSharding =
@@ -368,8 +368,8 @@ std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(
     int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
     MeshAttr mesh) const {
   AxisRefAttr result = axisRef;
-  for (const TensorFactorShardings& tensorFactorSharding :
-       llvm::concat<const TensorFactorShardings>(projection.getOperands(),
+  for (const TensorFactorShardingMap& tensorFactorSharding :
+       llvm::concat<const TensorFactorShardingMap>(projection.getOperands(),
                                                  projection.getResults())) {
     SDY_ASSIGN_OR_RETURN_IF_NULLOPT(
         result, compatiblePrefix(result, tensorFactorSharding, factorIndex,
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h
index 6ef1d47..21ed510 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h
+++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h
@@ -127,7 +127,7 @@ class BasicFactorPropagation : public FactorPropagation {
   // Returns std::nullopt if the compatible prefix does not exist.
   std::optional<AxisRefAttr> compatiblePrefixNoConflictsWithinFactor(
       AxisRefAttr axisRef, ArrayRef<AxisRefAttr> replicatedAxes,
-      const FactorSharding& factorSharding, int64_t prevShardedSize,
+      const TensorFactorSharding& factorSharding, int64_t prevShardedSize,
       int64_t factorSize, MeshAttr mesh) const;
 
   // For each axis in `axes`, call `removeConflicts` to get the compatible
@@ -156,12 +156,12 @@ class BasicFactorPropagation : public FactorPropagation {
   // If this tensor is mapped to `factorIndex`, returns the prefix of `axisRef`
   // by removing conflicts with other factors and within the factor itself.
   std::optional<AxisRefAttr> compatiblePrefix(
-      AxisRefAttr axisRef, const TensorFactorShardings& tensorFactorSharding,
+      AxisRefAttr axisRef, const TensorFactorShardingMap& tensorFactorSharding,
       int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
       MeshAttr mesh) const;
 
   // Returns the largest compatible prefix of `axisRef` by removing conflicts
-  // with every `TensorFactorShardings` in `projection`.
+  // with every `TensorFactorShardingMap` in `projection`.
   std::optional<AxisRefAttr> compatiblePrefix(
       AxisRefAttr axisRef, const ShardingProjection& projection,
       int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc
index 90449ce..08c23c0 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation_test.cc
@@ -412,14 +412,14 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
   }
   MeshAttr mesh = MeshAttr::get(&context, meshAxisAttrs);
 
-  TensorFactorShardings operand = {
+  TensorFactorShardingMap operand = {
       .factorIndexToSharding = {
          {0,
           {.axisRefs = {createAxis("a"), createAxis("b")},
            .isMinorMost = true}},
          {1, {.axisRefs = {createAxis("c")}, .isMinorMost = true}},
      }};
-  TensorFactorShardings resultBefore = {
+  TensorFactorShardingMap resultBefore = {
      .factorIndexToSharding = {
          {0, {.axisRefs = {}, .isMinorMost = false}},
          {1, {.axisRefs = {}, .isMinorMost = true}},
@@ -433,7 +433,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
   // factor size (4) isn't divisible by the size of ["c"] (3).
   auto test = [&](ArrayRef<int64_t> factorSizes,
-                  const TensorFactorShardings& resultAfter) {
+                  const TensorFactorShardingMap& resultAfter) {
     ShardingProjection projectionBefore({operand}, {resultBefore});
     ShardingProjection projectionAfter({operand}, {resultAfter});
     auto [updateOperands, updateResults] = propagateFactorShardings(
@@ -446,7 +446,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
   {
     // The factor size (9) is divisible by the size of ["a", "b"] (9).
     SmallVector<int64_t> factorSizes = {9, 4};
-    TensorFactorShardings resultAfter = {
+    TensorFactorShardingMap resultAfter = {
        .factorIndexToSharding = {
            {0,
             {.axisRefs = {createAxis("a"), createAxis("b")},
@@ -460,7 +460,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
    // The factor size (6) is divisible by the size of ["a"] (3), but not by the
    // size of ["a", "b"] (9).
    SmallVector<int64_t> factorSizes = {6, 19};
-    TensorFactorShardings resultAfter = {
+    TensorFactorShardingMap resultAfter = {
        .factorIndexToSharding = {
            {0, {.axisRefs = {createAxis("a")}, .isMinorMost = false}},
            {1, {.axisRefs = {createAxis("c")}, .isMinorMost = true}},
@@ -471,7 +471,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
  {
    // The factor size (4) isn't divisible by the size of ["a"] (3).
    SmallVector<int64_t> factorSizes = {4, 1};
-    TensorFactorShardings resultAfter = {
+    TensorFactorShardingMap resultAfter = {
        .factorIndexToSharding = {
            {0, {.axisRefs = {}, .isMinorMost = false}},
            {1, {.axisRefs = {createAxis("c")}, .isMinorMost = true}},
@@ -481,30 +481,30 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
 }
 
 TEST_F(BasicFactorPropagationTest, UniDirectionalPropagation) {
-  TensorFactorShardings operandBefore0 = {
+  TensorFactorShardingMap operandBefore0 = {
      .factorIndexToSharding = {
          {0, {.axisRefs = {createAxis("a"), createAxis("b")}}},
          {1, {.axisRefs = {createAxis("d"), createAxis("e")}}},
      }};
-  TensorFactorShardings operandBefore1 = {
+  TensorFactorShardingMap operandBefore1 = {
      .factorIndexToSharding = {
          {0, {.axisRefs = {createAxis("a")}}},
          {1, {.axisRefs = {createAxis("d")}}},
      }};
-  TensorFactorShardings result0 = {
+  TensorFactorShardingMap result0 = {
      .factorIndexToSharding = {
          {0,
           {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}},
          {1, {.axisRefs = {createAxis("d")}}},
      }};
-  TensorFactorShardings operandAfter0 = {
+  TensorFactorShardingMap operandAfter0 = {
      .factorIndexToSharding = {
          {0,
           {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}},
          {1, {.axisRefs = {createAxis("d"), createAxis("e")}}},
      }};
-  TensorFactorShardings operandAfter1 = {
+  TensorFactorShardingMap operandAfter1 = {
      .factorIndexToSharding = {
          {0,
           {.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}},
@@ -546,14 +546,14 @@ TEST_F(BasicFactorPropagationTest, UniDirectionalPropagation) {
 }
 
 TEST_F(BasicFactorPropagationTest, UniDirectionalPropagationWithConflict) {
-  TensorFactorShardings operand0 = {
+  TensorFactorShardingMap operand0 = {
      .factorIndexToSharding = {
          {0, {.axisRefs = {createAxis("a"), createAxis("b")}}},
      }};
-  TensorFactorShardings operand1 = {.factorIndexToSharding = {
+  TensorFactorShardingMap operand1 = {.factorIndexToSharding = {
      {0, {.axisRefs = {createAxis("a")}}},
  }};
-  TensorFactorShardings result = {
+  TensorFactorShardingMap result = {
      .factorIndexToSharding = {
          {0,
           {.axisRefs = {createAxis("z"), createAxis("a"), createAxis("b")}}},
diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
index 2f220bd..29e6143 100644
--- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
+++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -120,14 +120,14 @@ void notifyShardingModified(Value value,
   notifyUsersModified(value, notifyOpModified);
 }
 
-// Update the sharding of `value` to the sharding in `tensorFactorShardings`.
+// Update the sharding of `value` to the sharding in `TensorFactorShardingMap`.
 //
 // Returns true if it's possible to update the sharding, i.e., if strided view
 // isn't needed and all non-minor-most factors are divisible by sharding axes.
 bool updateTensorSharding(
     TensorShardingAttr oldTensorSharding,
     SetTensorShardingCallback setTensorShardingCallback,
-    const TensorFactorShardings& tensorFactorShardings,
+    const TensorFactorShardingMap& tensorFactorShardings,
     TensorMappingAttr tensorMapping, ArrayRef<int64_t> factorSizes,
     StringRef meshName, MeshAttr mesh, Value modifiedValue,
     const ShardingGroupMap& shardingGroupMap,
@@ -173,17 +173,17 @@ bool updateTensorSharding(
   return true;
 }
 
-// Updates the sharding of all tensors according to `tensorFactorShardings`.
+// Updates the sharding of all tensors according to `TensorFactorShardingMap`.
 //
 // Skips tensors for which `updateTensor` is set to false.
 //
 // If an operand or result couldn't be updated to the corresponding sharding in
-// `tensorFactorShardings`, e.g., if strided view is required, sets the
+// `TensorFactorShardingMap`, e.g., if strided view is required, sets the
 // respective bit in `updateTensor` or `updateResult` to false.
 void updateTensorShardings(
     ValueRange tensors, ArrayRef<TensorShardingAttr> tensorShardings,
     SetShardingPerTensorCallback setTensorShardingCallback,
-    ArrayRef<TensorFactorShardings> tensorFactorShardings,
+    ArrayRef<TensorFactorShardingMap> tensorFactorShardings,
     ArrayRef<TensorMappingAttr> tensorMappings, ArrayRef<int64_t> factorSizes,
     BitVector& updateTensor, StringRef meshName, MeshAttr mesh,
     const ShardingGroupMap& shardingGroupMap,
diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc
index f63aeba..c1bf125 100644
--- a/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc
+++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection.cc
@@ -61,7 +61,7 @@ bool shouldUpdate(ArrayRef<AxisRefAttr> oldAxes,
   return oldAxes.back().strictPrefixOf(newAxes.back());
 }
 
-bool TensorFactorShardings::updateShardingAxes(int64_t factorIndex,
+bool TensorFactorShardingMap::updateShardingAxes(int64_t factorIndex,
                                                ArrayRef<AxisRefAttr> newAxes) {
   auto factorShardingIt = factorIndexToSharding.find(factorIndex);
   if (factorShardingIt == factorIndexToSharding.end()) {
@@ -103,7 +103,7 @@ int64_t addAxesToDimSharding(SmallVector<AxisRefAttr>& dimSharding,
 
 }  // namespace
 
-TensorShardingAttr TensorFactorShardings::createTensorShardingAttr(
+TensorShardingAttr TensorFactorShardingMap::createTensorShardingAttr(
     MLIRContext* ctx, TensorMappingAttr tensorMapping,
     ArrayRef<int64_t> factorSizes, StringRef meshName, MeshAttr mesh) const {
   SmallVector<DimensionShardingAttr> newDimShardings;
@@ -114,7 +114,7 @@ TensorShardingAttr TensorFactorShardings::createTensorShardingAttr(
     SmallVector<AxisRefAttr> dimSharding;
     for (int64_t factorIndex : dimMapping.getFactorIndices()) {
       int64_t factorSize = factorSizes[factorIndex];
-      const FactorSharding& factorSharding =
+      const TensorFactorSharding& factorSharding =
          factorIndexToSharding.at(factorIndex);
       isClosed |= factorSharding.isClosed;
@@ -207,7 +207,7 @@ void addRemainingAxes(SmallVector<AxisRefAttr>& currentAxes,
   }
 }
 
-// Builds a `TensorFactorShardings` for a tensor with the specified
+// Builds a `TensorFactorShardingMap` for a tensor with the specified
 // `optionalSharding` and `tensorMapping`.
 //
 // The high level algorithm for projecting a dimension sharding into factor
@@ -215,10 +215,10 @@ void addRemainingAxes(SmallVector<AxisRefAttr>& currentAxes,
 // current factor sharding (starting from the major-most factor and axis) until
 // the factor is fully sharded, which might require further splitting an axis,
 // or this is the minor-most factor, then moving to the next factor.
-TensorFactorShardings buildTensorFactorShardings(
+TensorFactorShardingMap buildTensorFactorShardings(
     TensorMappingAttr tensorMapping, TensorShardingAttr optionalSharding,
     ArrayRef<int64_t> factorSizes, MeshAttr mesh) {
-  TensorFactorShardings result;
+  TensorFactorShardingMap result;
   auto& [factorIndexToSharding, replicatedAxes] = result;
   factorIndexToSharding.reserve(factorSizes.size());
@@ -238,7 +238,7 @@ TensorFactorShardings buildTensorFactorShardings(
     bool hasOverflowAxes = false;
     for (int64_t factorIndex : dimMapping.getFactorIndices()) {
-      FactorSharding& factorSharding = factorIndexToSharding[factorIndex];
+      TensorFactorSharding& factorSharding = factorIndexToSharding[factorIndex];
       factorSharding.isMinorMost = dimMapping.isMinorMost(factorIndex);
 
       if (hasOverflowAxes) {
@@ -310,9 +310,9 @@ TensorFactorShardings buildTensorFactorShardings(
   return result;
 }
 
-TensorFactorShardings buildTensorFactorShardings(
-    TensorMappingAttr tensorMapping, ArrayRef<FactorSharding> factorShardings) {
-  TensorFactorShardings result;
+TensorFactorShardingMap buildTensorFactorShardings(
+    TensorMappingAttr tensorMapping, ArrayRef<TensorFactorSharding> factorShardings) {
+  TensorFactorShardingMap result;
   // TODO(enver): Drop replicatedAxes after propagation, perhaps isMinorMost as
   // well.
   result.factorIndexToSharding.reserve(factorShardings.size());
@@ -327,8 +327,8 @@
 }  // namespace
 
 ShardingProjection::ShardingProjection(
-    SmallVector<TensorFactorShardings> operands,
-    SmallVector<TensorFactorShardings> results)
+    SmallVector<TensorFactorShardingMap> operands,
+    SmallVector<TensorFactorShardingMap> results)
     : operands(std::move(operands)), results(std::move(results)) {}
 
 ShardingProjection ShardingProjection::build(
@@ -360,7 +360,7 @@ ShardingProjection ShardingProjection::build(Operation* op,
 }
 
 ShardingProjection ShardingProjection::build(
-    ArrayRef<FactorSharding> factorShardings, OpShardingRuleAttr shardingRule) {
+    ArrayRef<TensorFactorSharding> factorShardings, OpShardingRuleAttr shardingRule) {
   ShardingProjection projection;
   for (const auto& operandMapping : shardingRule.getOperandMappings()) {
     projection.operands.push_back(
diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection.h b/shardy/dialect/sdy/transforms/propagation/sharding_projection.h
index 10e5620..5f783a7 100644
--- a/shardy/dialect/sdy/transforms/propagation/sharding_projection.h
+++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection.h
@@ -32,44 +32,44 @@ bool shouldUpdate(ArrayRef<AxisRefAttr> oldAxes,
                   ArrayRef<AxisRefAttr> newAxes);
 
 // The axes along which a factor is sharded, and whether the factor can be
 // further sharded (unless it's fully sharded already).
-struct FactorSharding {
+struct TensorFactorSharding {
   SmallVector<AxisRefAttr> axisRefs;
   bool isClosed = false;
   bool isMinorMost = false;
   // Additional axes in the dimension sharding that was projected to this
-  // `FactorSharding`, such that the size of the first overflow axis doesn't
+  // `TensorFactorSharding`, such that the size of the first overflow axis doesn't
   // divide the factor size, and the factor is non-minor-most.
   //
   // We need to store these axes so that we can add them when projecting back to
   // dimension shardings.
   SmallVector<AxisRefAttr> overflowAxes;
 
-  bool operator==(const FactorSharding& other) const {
+  bool operator==(const TensorFactorSharding& other) const {
     return axisRefs == other.axisRefs && isClosed == other.isClosed &&
            isMinorMost == other.isMinorMost &&
            overflowAxes == other.overflowAxes;
   }
 
-  bool operator!=(const FactorSharding& other) const {
+  bool operator!=(const TensorFactorSharding& other) const {
     return !(*this == other);
   }
 };
 
-using FactorIndexToSharding = llvm::DenseMap<int64_t, FactorSharding>;
+using FactorIndexToSharding = llvm::DenseMap<int64_t, TensorFactorSharding>;
 
 // Holds the factor shardings and replicated axes of a tensor.
-struct TensorFactorShardings {
+struct TensorFactorShardingMap {
   // A mapping between factor index to the sharding of that factor.
   // TODO(tomnatan): consider using a vector with null for unmapped factors.
   FactorIndexToSharding factorIndexToSharding;
   SmallVector<AxisRefAttr> replicatedAxes;
 
-  bool operator==(const TensorFactorShardings& other) const {
+  bool operator==(const TensorFactorShardingMap& other) const {
     return factorIndexToSharding == other.factorIndexToSharding &&
            replicatedAxes == other.replicatedAxes;
   }
 
-  bool operator!=(const TensorFactorShardings& other) const {
+  bool operator!=(const TensorFactorShardingMap& other) const {
     return !(*this == other);
   }
 
@@ -83,7 +83,7 @@ struct TensorFactorShardings {
   bool updateShardingAxes(int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes);
 
   // Creates a `TensorShardingAttr` by projecting the factor shardings in
-  // this `TensorFactorShardings` to dimension shardings w.r.t. to
+  // this `TensorFactorShardingMap` to dimension shardings w.r.t. to
   // `tensorMapping`.
   //
   // Ignores sharding of any factor that needs strided view.
@@ -140,20 +140,20 @@ class ShardingProjection {
  public:
   ShardingProjection() = default;
 
-  ShardingProjection(SmallVector<TensorFactorShardings> operands,
-                     SmallVector<TensorFactorShardings> results);
+  ShardingProjection(SmallVector<TensorFactorShardingMap> operands,
+                     SmallVector<TensorFactorShardingMap> results);
 
   int64_t getNumOperands() const { return operands.size(); }
   int64_t getNumResults() const { return results.size(); }
   int64_t getNumTensors() const { return getNumOperands() + getNumResults(); }
 
-  ArrayRef<TensorFactorShardings> getOperands() const { return operands; }
-  ArrayRef<TensorFactorShardings> getResults() const { return results; }
+  ArrayRef<TensorFactorShardingMap> getOperands() const { return operands; }
+  ArrayRef<TensorFactorShardingMap> getResults() const { return results; }
 
-  const TensorFactorShardings& getOperand(int64_t operandNum) const {
+  const TensorFactorShardingMap& getOperand(int64_t operandNum) const {
     return operands[operandNum];
   }
-  const TensorFactorShardings& getResult(int64_t resultNum) const {
+  const TensorFactorShardingMap& getResult(int64_t resultNum) const {
     return results[resultNum];
   }
 
@@ -188,7 +188,7 @@ class ShardingProjection {
   // Builds a `ShardingProjection` w.r.t. the given `shardingRule` where factor
   // shardings are the same across all operands and results, and specified by
   // `factorShardings`.
-  static ShardingProjection build(ArrayRef<FactorSharding> factorShardings,
+  static ShardingProjection build(ArrayRef<TensorFactorSharding> factorShardings,
                                   OpShardingRuleAttr shardingRule);
 
   bool operator==(const ShardingProjection& other) const {
@@ -200,8 +200,8 @@ class ShardingProjection {
   }
 
  private:
-  SmallVector<TensorFactorShardings> operands;
-  SmallVector<TensorFactorShardings> results;
+  SmallVector<TensorFactorShardingMap> operands;
+  SmallVector<TensorFactorShardingMap> results;
 };
 
 }  // namespace sdy
diff --git a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc
index 4db8e56..d84d46e 100644
--- a/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc
+++ b/shardy/dialect/sdy/transforms/propagation/sharding_projection_test.cc
@@ -66,7 +66,7 @@ void verifyShardingAttrsMatch(TensorShardingAttr resultSharding,
 }
 
 void verifyReconstructedShardings(
-    ValueRange tensors, ArrayRef<TensorFactorShardings> tensorFactorShardings,
+    ValueRange tensors, ArrayRef<TensorFactorShardingMap> tensorFactorShardings,
     ArrayRef<TensorMappingAttr> tensorMappings, ArrayRef<int64_t> factorSizes,
     StringRef meshName, MeshAttr mesh) {
   for (auto [tensor, factorShardings, tensorMapping] :
@@ -82,7 +82,7 @@ void verifyReconstructedShardings(
 // Builds a `ShardingProjection` for the first OpTy in the main function.
 //
 // In addition, verifies that reconstructing the `TensorShardingAttr` for each
-// tensor (using `TensorFactorShardings::createTensorShardingAttr`) from the
+// tensor (using `TensorFactorShardingMap::createTensorShardingAttr`) from the
 // created projection matches the original sharding.
 template <class OpTy>
 ShardingProjection getShardingProjection(ModuleOp module) {
@@ -103,7 +103,7 @@ ShardingProjection getShardingProjection(ModuleOp module) {
 
 template <class OpTy>
 ShardingProjection getShardingProjection(
-    ModuleOp module, ArrayRef<FactorSharding> factorShardings) {
+    ModuleOp module, ArrayRef<TensorFactorSharding> factorShardings) {
   OpTy op = getFirstOp<OpTy>(module);
   OpShardingRuleAttr shardingRule = getOrCreateShardingRule(op);
   assert(shardingRule);
@@ -115,7 +115,7 @@ ShardingProjection getShardingProjection(
 // Tests for ShardingProjection::build
 //
-// TensorFactorShardings::createTensorShardingAttr is also tested indirectly
+// TensorFactorShardingMap::createTensorShardingAttr is also tested indirectly
 // by calling it using the created `ShardingProjection` and verifying that the
 // reconstructed `TensorShardingAttr` for each tensor matches the original one.
 //===----------------------------------------------------------------------===//
@@ -141,32 +141,32 @@ TEST_F(ShardingProjectionBuildTest, DotGeneralSimple) {
       getShardingProjection<stablehlo::DotGeneralOp>(module.get());
 
   EXPECT_THAT(projection.getOperand(0),
-              TensorFactorShardingsIs(
+              TensorFactorShardingMapIs(
                   /*factorIndexToSharding*/ UnorderedElementsAre(
-                      FactorShardingIs(/*index*/ 0, /*isClosed*/ true,
+                      TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                                        /*isMinorMost*/ true,
                                        ElementsAre(AxisRefIs("b"))),
-                      FactorShardingIs(/*index*/ 2, /*isClosed*/ false,
+                      TensorFactorShardingIs(/*index*/ 2, /*isClosed*/ false,
                                        /*isMinorMost*/ true,
                                        ElementsAre(AxisRefIs("a")))),
                   /*replicatedAxes*/ IsEmpty()));
   EXPECT_THAT(projection.getOperand(1),
-              TensorFactorShardingsIs(
+              TensorFactorShardingMapIs(
                  /*factorIndexToSharding*/ UnorderedElementsAre(
-                      FactorShardingIs(
+                      TensorFactorShardingIs(
                          /*index*/ 1, /*isClosed*/ false, /*isMinorMost*/ true,
                          ElementsAre(AxisRefIs("d"))),
-                      FactorShardingIs(
+                      TensorFactorShardingIs(
                          /*index*/ 2, /*isClosed*/ true, /*isMinorMost*/ true,
                          ElementsAre(AxisRefIs("a"), AxisRefIs("c")))),
                  /*replicatedAxes*/ ElementsAre(AxisRefIs("b"))));
   EXPECT_THAT(
       projection.getResult(0),
-      TensorFactorShardingsIs(
+      TensorFactorShardingMapIs(
          /*factorIndexToSharding*/ UnorderedElementsAre(
-              FactorShardingIs(/*index*/ 0, /*isClosed*/ false,
+              TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ false,
                                /*isMinorMost*/ true, IsEmpty()),
-              FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+              TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                /*isMinorMost*/ true,
                                ElementsAre(AxisRefIs("d"), AxisRefIs("e")))),
          /*replicatedAxes*/ ElementsAre(AxisRefIs("a"), AxisRefIs("c"))));
@@ -189,21 +189,21 @@ TEST_F(ShardingProjectionBuildTest, ReshapeSplitDim) {
 
   EXPECT_THAT(
       projection.getOperand(0),
-      TensorFactorShardingsIs(
+      TensorFactorShardingMapIs(
          /*factorIndexToSharding*/ UnorderedElementsAre(
-              FactorShardingIs(/*index*/ 0, /*isClosed*/ true,
+              TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                                /*isMinorMost*/ false,
                                ElementsAre(SubAxisRefIs("a", 1, 2))),
-              FactorShardingIs(
+              TensorFactorShardingIs(
                  /*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
                  ElementsAre(SubAxisRefIs("a", 2, 2), AxisRefIs("b")))),
          /*replicatedAxes*/ ElementsAre(AxisRefIs("c"))));
   EXPECT_THAT(projection.getResult(0),
-              TensorFactorShardingsIs(
+              TensorFactorShardingMapIs(
                  /*factorIndexToSharding*/ UnorderedElementsAre(
-                      FactorShardingIs(/*index*/ 0, /*isClosed*/ false,
+                      TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ false,
                                        /*isMinorMost*/ true, IsEmpty()),
-                      FactorShardingIs(/*index*/ 1, /*isClosed*/ false,
+                      TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ false,
                                        /*isMinorMost*/ true, IsEmpty())),
                  /*replicatedAxes*/ IsEmpty()));
 }
@@ -226,15 +226,15 @@ TEST_F(ShardingProjectionBuildTest, ReshapeSplitDimAxisAlreadySplit) {
       getShardingProjection<stablehlo::ReshapeOp>(module.get());
 
   EXPECT_THAT(projection.getOperand(0),
-              TensorFactorShardingsIs(
+              TensorFactorShardingMapIs(
                  /*factorIndexToSharding*/ UnorderedElementsAre(
-                      FactorShardingIs(
+                      TensorFactorShardingIs(
                          /*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ true,
                          ElementsAre(SubAxisRefIs("a", 1, 2), AxisRefIs("b"))),
-                      FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+                      TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                        /*isMinorMost*/ false,
                                        ElementsAre(SubAxisRefIs("a", 2, 2))),
-                      FactorShardingIs(/*index*/ 2, /*isClosed*/ true,
+                      TensorFactorShardingIs(/*index*/ 2, /*isClosed*/ true,
                                        /*isMinorMost*/ true,
                                        ElementsAre(SubAxisRefIs("a", 4, 4)))),
                  /*replicatedAxes*/
ElementsAre(SubAxisRefIs("c", 2, 2)))); @@ -259,15 +259,15 @@ TEST_F(ShardingProjectionBuildTest, ReshapeMergeDim) { EXPECT_THAT( projection.getOperand(0).factorIndexToSharding, UnorderedElementsAre( - FactorShardingIs(/*index*/ 0, /*isClosed*/ false, + TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ false, /*isMinorMost*/ true, ElementsAre(AxisRefIs("a"))), - FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, ElementsAre(AxisRefIs("b"))))); EXPECT_THAT( projection.getResult(0).factorIndexToSharding, - UnorderedElementsAre(FactorShardingIs(/*index*/ 0, /*isClosed*/ false, + UnorderedElementsAre(TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ false, /*isMinorMost*/ false, IsEmpty()), - FactorShardingIs(/*index*/ 1, /*isClosed*/ false, + TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ false, /*isMinorMost*/ true, IsEmpty()))); } @@ -289,11 +289,11 @@ TEST_F(ShardingProjectionBuildTest, ReshapeWithSizeOneDims) { EXPECT_THAT( projection.getOperand(0).factorIndexToSharding, UnorderedElementsAre( - FactorShardingIs(/*index*/ 0, /*isClosed*/ true, + TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ false, ElementsAre(AxisRefIs("a"))), - FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, IsEmpty()), - FactorShardingIs(/*index*/ 2, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 2, /*isClosed*/ true, /*isMinorMost*/ true, IsEmpty()))); } @@ -317,9 +317,9 @@ TEST_F(ShardingProjectionBuildTest, AddSingleFactorNonDivisible) { EXPECT_THAT( projection.getOperand(0).factorIndexToSharding, UnorderedElementsAre( - FactorShardingIs(/*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ true, IsEmpty()), - FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, ElementsAre(AxisRefIs("a"))))); } @@ -343,9 +343,9 @@ TEST_F(ShardingProjectionBuildTest, SingleFactorOverflows) { EXPECT_THAT( projection.getOperand(0).factorIndexToSharding, UnorderedElementsAre( - FactorShardingIs(/*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ true, IsEmpty()), - FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, + TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true, ElementsAre(AxisRefIs("a"), AxisRefIs("b"))))); } @@ -368,12 +368,12 @@ TEST_F(ShardingProjectionBuildTest, FactorWithSmallerSizeThanDimOverflows) { EXPECT_THAT( projection.getOperand(0).factorIndexToSharding, UnorderedElementsAre( - FactorShardingIs(/*index*/ 0, /*isClosed*/ false, + TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ false, /*isMinorMost*/ true, ElementsAre(AxisRefIs("a"))), - FactorShardingIs( + TensorFactorShardingIs( /*index*/ 1, /*isClosed*/ false, /*isMinorMost*/ true, /*axisRefs*/ ElementsAre(AxisRefIs("c"))), - FactorShardingIs( + TensorFactorShardingIs( /*index*/ 2, /*isClosed*/ true, /*isMinorMost*/ true, /*axisRefs*/ ElementsAre(AxisRefIs("b"), SubAxisRefIs("d", 2, 2), @@ -399,10 +399,10 @@ TEST_F(ShardingProjectionBuildTest, ReshapeMinorMostFactorNonDivisible) { EXPECT_THAT( projection.getOperand(0).factorIndexToSharding, UnorderedElementsAre( - FactorShardingIs(/*index*/ 0, /*isClosed*/ true, + 
+          TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                            /*isMinorMost*/ false,
                            ElementsAre(SubAxisRefIs("a", 1, 2))),
-          FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
+          TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
                            ElementsAre(SubAxisRefIs("a", 2, 3)))));
 }
@@ -425,10 +425,10 @@ TEST_F(ShardingProjectionBuildTest, ReshapeMinorMostFactorOverflows) {
   EXPECT_THAT(
       projection.getOperand(0).factorIndexToSharding,
       UnorderedElementsAre(
-          FactorShardingIs(/*index*/ 0, /*isClosed*/ true,
+          TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                            /*isMinorMost*/ false,
                            ElementsAre(SubAxisRefIs("a", 1, 2))),
-          FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
+          TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
                            ElementsAre(SubAxisRefIs("a", 2, 8)))));
 }
@@ -452,10 +452,10 @@ TEST_F(ShardingProjectionBuildTest,
   EXPECT_THAT(
       projection.getOperand(0).factorIndexToSharding,
       UnorderedElementsAre(
-          FactorShardingIs(/*index*/ 0, /*isClosed*/ true,
+          TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                            /*isMinorMost*/ false,
                            ElementsAre(SubAxisRefIs("a", 1, 2))),
-          FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
+          TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
                            ElementsAre(SubAxisRefIs("a", 2, 8), AxisRefIs("b"),
                                        AxisRefIs("c")))));
 }
@@ -478,11 +478,11 @@ TEST_F(ShardingProjectionBuildTest, ReshapeNonMinorMostFactorNonDivisible) {
   EXPECT_THAT(
       projection.getOperand(0).factorIndexToSharding,
-      UnorderedElementsAre(FactorShardingWithOverflowIs(
+      UnorderedElementsAre(TensorFactorShardingWithOverflowIs(
                                /*index*/ 0, /*isClosed*/ false,
                                /*isMinorMost*/ false, /*axisRefs*/ IsEmpty(),
                                /*overflowAxes*/ ElementsAre(AxisRefIs("a"))),
-                           FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+                           TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                             /*isMinorMost*/ true, IsEmpty())));
 }
@@ -505,11 +505,11 @@ TEST_F(ShardingProjectionBuildTest,
   EXPECT_THAT(projection.getOperand(0).factorIndexToSharding,
               UnorderedElementsAre(
-                  FactorShardingWithOverflowIs(
+                  TensorFactorShardingWithOverflowIs(
                      /*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ false,
                      /*axisRefs*/ ElementsAre(SubAxisRefIs("a", 1, 2)),
                      /*overflowAxes*/ ElementsAre(SubAxisRefIs("a", 2, 3))),
-                  FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+                  TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                    /*isMinorMost*/ true, IsEmpty())));
 }
@@ -532,11 +532,11 @@ TEST_F(ShardingProjectionBuildTest,
   EXPECT_THAT(projection.getOperand(0).factorIndexToSharding,
               UnorderedElementsAre(
-                  FactorShardingWithOverflowIs(
+                  TensorFactorShardingWithOverflowIs(
                      /*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ false,
                      /*axisRefs*/ ElementsAre(AxisRefIs("a")),
                      /*overflowAxes*/ ElementsAre(AxisRefIs("b"))),
-                  FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+                  TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                    /*isMinorMost*/ false, IsEmpty())));
 }
@@ -559,14 +559,14 @@ TEST_F(ShardingProjectionBuildTest, ReshapeMinorMostFactorSizeOneAxes) {
   EXPECT_THAT(
       projection.getOperand(0).factorIndexToSharding,
       UnorderedElementsAre(
-          FactorShardingIs(
+          TensorFactorShardingIs(
             /*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ false,
             ElementsAre(AxisRefIs("a"), AxisRefIs("b"), AxisRefIs("c"))),
-          FactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
+          TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true, /*isMinorMost*/ true,
                            IsEmpty())));
 }
 
-TEST_F(ShardingProjectionBuildTest, DotGeneralSimpleFromFactorShardings) {
+TEST_F(ShardingProjectionBuildTest, DotGeneralSimpleFromTensorFactorShardings) {
   const std::string program = R"mlir(
     sdy.mesh @mesh = <["a"=4, "b"=2, "c"=2, "d"=2]>
@@ -598,33 +598,33 @@ TEST_F(ShardingProjectionBuildTest, DotGeneralSimpleFromFactorShardings) {
 
   EXPECT_THAT(
       projection.getOperand(0),
-      TensorFactorShardingsIs(
+      TensorFactorShardingMapIs(
          /*factorIndexToSharding*/ UnorderedElementsAre(
-              FactorShardingIs(/*index*/ 0, /*isClosed*/ true,
+              TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                                /*isMinorMost*/ true,
                                ElementsAre(AxisRefIs("a"))),
-              FactorShardingIs(/*index*/ 2, /*isClosed*/ true,
+              TensorFactorShardingIs(/*index*/ 2, /*isClosed*/ true,
                                /*isMinorMost*/ true,
                                ElementsAre(AxisRefIs("c"), AxisRefIs("d")))),
          /*replicatedAxes*/ IsEmpty()));
   EXPECT_THAT(
       projection.getOperand(1),
-      TensorFactorShardingsIs(
+      TensorFactorShardingMapIs(
          /*factorIndexToSharding*/ UnorderedElementsAre(
-              FactorShardingIs(/*index*/ 2, /*isClosed*/ true,
+              TensorFactorShardingIs(/*index*/ 2, /*isClosed*/ true,
                                /*isMinorMost*/ true,
                                ElementsAre(AxisRefIs("c"), AxisRefIs("d"))),
-              FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+              TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                /*isMinorMost*/ true,
                                ElementsAre(AxisRefIs("b")))),
         /*replicatedAxes*/ IsEmpty()));
   EXPECT_THAT(projection.getResult(0),
-              TensorFactorShardingsIs(
+              TensorFactorShardingMapIs(
                  /*factorIndexToSharding*/ UnorderedElementsAre(
-                      FactorShardingIs(/*index*/ 0, /*isClosed*/ true,
+                      TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ true,
                                        /*isMinorMost*/ true,
                                        ElementsAre(AxisRefIs("a"))),
-                      FactorShardingIs(/*index*/ 1, /*isClosed*/ true,
+                      TensorFactorShardingIs(/*index*/ 1, /*isClosed*/ true,
                                        /*isMinorMost*/ true,
                                        ElementsAre(AxisRefIs("b")))),
                  /*replicatedAxes*/ IsEmpty()));
@@ -685,28 +685,28 @@
   // other members (isClosed, isMinorMost, overflowAxes).
   EXPECT_THAT(projection.getOperand(0).factorIndexToSharding,
               UnorderedElementsAre(
-                  FactorShardingIs(
+                  TensorFactorShardingIs(
                      /*index*/ 0, /*isClosed*/ true, /*isMinorMost*/ true,
                      ElementsAre(AxisRefIs("a"), AxisRefIs("b"))),
-                  FactorShardingIs(
+                  TensorFactorShardingIs(
                      /*index*/ 2, /*isClosed*/ false, /*isMinorMost*/ true,
                      ElementsAre(AxisRefIs("c")))));
   EXPECT_THAT(projection.getOperand(1).factorIndexToSharding,
               UnorderedElementsAre(
-                  FactorShardingIs(
+                  TensorFactorShardingIs(
                      /*index*/ 1, /*isClosed*/ false, /*isMinorMost*/ true,
                      ElementsAre(AxisRefIs("d"), SubAxisRefIs("f", 1, 2))),
-                  FactorShardingIs(
+                  TensorFactorShardingIs(
                      /*index*/ 2, /*isClosed*/ true, /*isMinorMost*/ true,
                      ElementsAre(AxisRefIs("e")))));
   EXPECT_THAT(projection.getResult(0).factorIndexToSharding,
               UnorderedElementsAre(
-                  FactorShardingIs(/*index*/ 0, /*isClosed*/ false,
+                  TensorFactorShardingIs(/*index*/ 0, /*isClosed*/ false,
                                    /*isMinorMost*/ true,
                                    ElementsAre(AxisRefIs("a"), AxisRefIs("b"))),
-                  FactorShardingIs(
+                  TensorFactorShardingIs(
                      /*index*/ 1, /*isClosed*/ false, /*isMinorMost*/ true,
                      ElementsAre(AxisRefIs("d"), AxisRefIs("f")))));
 }
@@ -750,7 +750,7 @@ TEST_F(ShouldUpdateTest, ShouldUpdateTest) {
 }
 
 //===----------------------------------------------------------------------===//
-// Tests for TensorFactorShardings::createTensorShardingAttr
+// Tests for TensorFactorShardingMap::createTensorShardingAttr
 //
 // Since ShardingProjectionBuildTest also tests this method indirectly in each
 // test case, here we only test the special cases that aren't tested above.
@@ -788,7 +788,7 @@ TEST_F(CreateTensorShardingAttrTest, ConsecutiveSubAxesMerged) {
   ASSERT_TRUE(module);
   auto op = getFirstOp(module.get());
   OpShardingRuleAttr shardingRule = getOrCreateShardingRule(op);
-  TensorFactorShardings factorShardings{
+  TensorFactorShardingMap factorShardings{
      .factorIndexToSharding =
          {{0, {.axisRefs = {createAxis("b"), createSubAxis("a", 2, 2)}}},
           {1, {.axisRefs = {createSubAxis("a", 4, 2)}}}},
@@ -819,7 +819,7 @@ TEST_F(CreateTensorShardingAttrTest, OverflowSubAxisMerged) {
   ASSERT_TRUE(module);
   auto op = getFirstOp(module.get());
   OpShardingRuleAttr shardingRule = getOrCreateShardingRule(op);
-  TensorFactorShardings factorShardings{
+  TensorFactorShardingMap factorShardings{
      .factorIndexToSharding = {{0,
                                 {.axisRefs = {createSubAxis("a", 1, 2)},
                                  .overflowAxes = {createSubAxis("a", 2, 3)}}},
@@ -850,7 +850,7 @@ TEST_F(CreateTensorShardingAttrTest, NonMinorMostFactorFullySharded) {
   ASSERT_TRUE(module);
   auto op = getFirstOp(module.get());
   OpShardingRuleAttr shardingRule = getOrCreateShardingRule(op);
-  TensorFactorShardings factorShardings{
+  TensorFactorShardingMap factorShardings{
      .factorIndexToSharding =
          {{0, {.axisRefs = {createAxis("a"), createAxis("b")}}},
           {1, {.axisRefs = {createAxis("c")}}}},
@@ -882,7 +882,7 @@ TEST_F(CreateTensorShardingAttrTest, NonMinorMostFactorPartiallySharded) {
   ASSERT_TRUE(module);
   auto op = getFirstOp(module.get());
   OpShardingRuleAttr shardingRule = getOrCreateShardingRule(op);
-  TensorFactorShardings factorShardings{
+  TensorFactorShardingMap factorShardings{
      .factorIndexToSharding = {{0, {.axisRefs = {createAxis("a")}}},
                                {1, {.axisRefs = {createAxis("b")}}}}};
 
@@ -910,7 +910,7 @@ TEST_F(CreateTensorShardingAttrTest, MinorMostFactorNotDivisible) {
   ASSERT_TRUE(module);
   auto op = getFirstOp(module.get());
   OpShardingRuleAttr shardingRule = getOrCreateShardingRule(op);
-  TensorFactorShardings factorShardings{
+  TensorFactorShardingMap factorShardings{
      .factorIndexToSharding = {{0, {.axisRefs = {createAxis("b")}}},
                                {1, {.axisRefs = {createAxis("a")}}}}};
 
diff --git a/shardy/dialect/sdy/transforms/propagation/test/basic_propagation.mlir b/shardy/dialect/sdy/transforms/propagation/test/basic_propagation.mlir
index d260bbb..54d7f12 100644
--- a/shardy/dialect/sdy/transforms/propagation/test/basic_propagation.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/test/basic_propagation.mlir
@@ -458,7 +458,7 @@ func.func @non_minor_most_factor_non_divisible(%arg0: tensor<8xf32> {sdy.shardin
 // NOTE: it's important to make sure that the sharding of %arg0 doesn't change,
 // because "b" is added to the ShardingProjection as an overflow axis (see
-// `FactorSharding`), that gets added back when creating the updated
+// `TensorFactorSharding`), that gets added back when creating the updated
 // `TensorShardingAttr`.
 // CHECK-LABEL: func @non_minor_most_factor_non_divisible_multiple_axes(
 // CHECK-SAME:      %arg0: tensor<2x2x32xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_3_c_2_d_2, [{"c"}, {"d", ?}, {"a", "b"}]>})
 func.func @non_minor_most_factor_non_divisible_multiple_axes(
@@ -472,7 +472,7 @@ func.func @non_minor_most_factor_non_divisible_multiple_axes(
 // NOTE: it's important to make sure that the sharding of %arg0 doesn't change,
 // because "a":(2)3 is added to the ShardingProjection as an overflow axis (see
-// `FactorSharding`), that gets added back when creating the updated
+// `TensorFactorSharding`), that gets added back when creating the updated
 // `TensorShardingAttr`.
 // CHECK-LABEL: func @non_minor_most_factor_non_divisible_sub_axis(
 // CHECK-SAME:      %arg0: tensor<2x32xf32> {sdy.sharding = #sdy.sharding<@mesh_a_6_b_2, [{"b", ?}, {"a"}]>})
 func.func @non_minor_most_factor_non_divisible_sub_axis(
@@ -485,7 +485,7 @@ func.func @non_minor_most_factor_non_divisible_sub_axis(
 }
 
 // This test verifies that "b" isn't propagated from the `stablehlo.reshape` to
-// %arg0, even though "b" in %arg0 is an overflow axis (see `FactorSharding`).
+// %arg0, even though "b" in %arg0 is an overflow axis (see `TensorFactorSharding`).
 // CHECK-LABEL: func @non_minor_most_factor_non_divisible_other_open_dim_unchanged(
 // CHECK-SAME:      %arg0: tensor<3x32xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_3, [{?}, {"a", "b", ?}]>})
 func.func @non_minor_most_factor_non_divisible_other_open_dim_unchanged(
diff --git a/shardy/dialect/sdy/transforms/propagation/testing_utils.h b/shardy/dialect/sdy/transforms/propagation/testing_utils.h
index 12661e4..85b6d00 100644
--- a/shardy/dialect/sdy/transforms/propagation/testing_utils.h
+++ b/shardy/dialect/sdy/transforms/propagation/testing_utils.h
@@ -49,7 +49,7 @@ MATCHER_P3(SubAxisRefIs, axisName, preSize, size,
          arg.getSubAxisInfo().getSize() == size;
 }
 
-MATCHER_P5(FactorShardingWithOverflowIs, index, isClosed, isMinorMost,
+MATCHER_P5(TensorFactorShardingWithOverflowIs, index, isClosed, isMinorMost,
            axisRefsMatcher, overflowAxesMatcher,
           "factor " + PrintToString(index) + " sharding that is " +
               (isClosed || negation ? "closed" : "open") +
@@ -73,18 +73,18 @@ MATCHER_P5(FactorShardingWithOverflowIs, index, isClosed, isMinorMost,
                             result_listener);
 }
 
-MATCHER_P4(FactorShardingIs, index, isClosed, isMinorMost, axisRefsMatcher,
+MATCHER_P4(TensorFactorShardingIs, index, isClosed, isMinorMost, axisRefsMatcher,
           DescribeMatcher(
-               FactorShardingWithOverflowIs(index, isClosed, isMinorMost,
+               TensorFactorShardingWithOverflowIs(index, isClosed, isMinorMost,
                                             axisRefsMatcher, IsEmpty()),
               negation)) {
  return ExplainMatchResult(
-      FactorShardingWithOverflowIs(index, isClosed, isMinorMost,
+      TensorFactorShardingWithOverflowIs(index, isClosed, isMinorMost,
                                    axisRefsMatcher, IsEmpty()),
      arg, result_listener);
 }
 
-MATCHER_P2(TensorFactorShardingsIs, factorIndexToShardingMatcher,
+MATCHER_P2(TensorFactorShardingMapIs, factorIndexToShardingMatcher,
           replicatedAxesMatcher,
          "tensor factor shardings that:\n" +
              DescribeMatcher(
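
Illustrative sketch (not part of this patch): the target layout named in the
commit message, assuming the field types of today's TensorFactorSharding carry
over unchanged; the follow-up change may differ in naming and field order. It
assumes llvm::SmallVector and sdy's AxisRefAttr are in scope, as they are in
sharding_projection.h.

    // Tensor-independent factor sharding, shareable across all tensors of an
    // op: the axes that shard the factor, whether it is closed to further
    // sharding, and any overflow axes.
    struct FactorSharding {
      SmallVector<AxisRefAttr> axisRefs;
      bool isClosed = false;
      SmallVector<AxisRefAttr> overflowAxes;
    };

    // Tensor-dependent wrapper: isMinorMost depends on how the factor maps
    // into a specific tensor's dimensions, so it is determined per tensor by
    // the sharding rule rather than taken from the common FactorSharding.
    struct TensorFactorSharding {
      FactorSharding factorSharding;
      bool isMinorMost = false;
    };

With this split, ShardingProjection::build(ArrayRef<FactorSharding>,
OpShardingRuleAttr) could take common, tensor-independent factor shardings and
fill in isMinorMost per tensor from the sharding rule's tensor mappings, which
is the motivation stated in the commit message.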