Append Tensor to FactorSharding struct name.
FactorSharding has an isMinorMost field, which depends on the tensor, while the other fields (axisRefs, overflowAxes, and isClosed) can be shared among the tensors.

Hence TensorFactorSharding defines a tensor-dependent factor sharding. The goal, in a subsequent change, is to arrive at the following structure:

TensorFactorSharding {
  FactorSharding factorSharding,
  bool isMinorMost
}

FactorSharding {
  SmallVector<AxisRef> axisRefs,
  bool isClosed,
  SmallVector<AxisRef> overflowAxes
}

One motivation is ShardingProjection::build from common factor shardings, where isMinorMost needs to be tensor-specific and should be determined by the sharding rule rather than by the common factor sharding.
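
For illustration, a minimal, hypothetical C++ sketch of that end state (not code introduced by this change; buildForTensor is an invented helper and std::vector stands in for SmallVector), showing how a per-tensor wrapper could reuse a shared FactorSharding while isMinorMost is supplied per tensor by the sharding rule:

#include <utility>
#include <vector>

struct AxisRef {};  // Stand-in for sdy::AxisRefAttr.

// Tensor-independent part, shareable across the operands and results of an op.
struct FactorSharding {
  std::vector<AxisRef> axisRefs;
  bool isClosed = false;
  std::vector<AxisRef> overflowAxes;
};

// Tensor-dependent wrapper: only isMinorMost varies per tensor.
struct TensorFactorSharding {
  FactorSharding factorSharding;
  bool isMinorMost = false;
};

// Hypothetical helper: the common factor sharding is reused unchanged, while
// isMinorMost comes from the sharding rule for this particular tensor.
TensorFactorSharding buildForTensor(FactorSharding common, bool isMinorMostPerRule) {
  return TensorFactorSharding{std::move(common), isMinorMostPerRule};
}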

PiperOrigin-RevId: 688056929
Google-ML-Automation authored and copybara-github committed Oct 21, 2024
1 parent ed6ba35 commit b2306a7
Showing 11 changed files with 146 additions and 146 deletions.
@@ -41,8 +41,8 @@ namespace {
// 1. Factors are sharded the same way across operands and results.
bool hasCompatibleFactorSharding(const ShardingProjection& projection) {
FactorIndexToSharding factorIndexToCommonSharding;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
for (const TensorFactorShardingMap& tensorFactorSharding :
llvm::concat<const TensorFactorShardingMap>(projection.getOperands(),
projection.getResults())) {
// Detects conflicts within the same factor.
for (const auto& [factorIndex, factorSharding] :
@@ -70,7 +70,7 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
// different shardings to different tensors along the same factor. Examples
// are provided in the docstring of this class.
for (const auto& [tensorIndex, tensorFactorShardings] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
llvm::enumerate(llvm::concat<const TensorFactorShardingMap>(
projection.getOperands(), projection.getResults()))) {
// Propagate the axes got in Step 1, and resolve conflicts within a factor.
FactorIndexToSharding newSharding =
@@ -75,7 +75,7 @@ BasicFactorPropagation::compatiblePrefixNoConflictsAcrossFactors(
std::optional<AxisRefAttr>
BasicFactorPropagation::compatiblePrefixNoConflictsWithinFactor(
AxisRefAttr axisRef, ArrayRef<AxisRefAttr> replicatedAxes,
const FactorSharding& factorSharding, int64_t prevShardedSize,
const TensorFactorSharding& factorSharding, int64_t prevShardedSize,
int64_t factorSize, MeshAttr mesh) const {
AxisRefAttr result = axisRef;

@@ -158,7 +158,7 @@ void BasicFactorPropagation::truncateAxesByRemovingConflicts(
namespace {

using DirectionBasedTensorShardings =
std::pair<ArrayRef<TensorFactorShardings>, ArrayRef<TensorFactorShardings>>;
std::pair<ArrayRef<TensorFactorShardingMap>, ArrayRef<TensorFactorShardingMap>>;

// Gets the tensor shardings that should be processed first and then second.
//
@@ -175,8 +175,8 @@ using DirectionBasedTensorShardings =
// on the result factor shardings but not the operands.
std::optional<DirectionBasedTensorShardings> getDirectionBasedTensorShardings(
PropagationDirection direction, Operation* op,
ArrayRef<TensorFactorShardings> operands,
ArrayRef<TensorFactorShardings> results) {
ArrayRef<TensorFactorShardingMap> operands,
ArrayRef<TensorFactorShardingMap> results) {
static const char* errMsg =
"since Shardy is propagating {0} for this op, Shardy may not "
"fully propagate to each of the multiple {1}s; {0} "
@@ -313,8 +313,8 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes(
bool canExpand = true;

auto updateCompatibleMajorAxesWithTensors =
[&](ArrayRef<TensorFactorShardings> tensors) {
for (const TensorFactorShardings& tensor : tensors) {
[&](ArrayRef<TensorFactorShardingMap> tensors) {
for (const TensorFactorShardingMap& tensor : tensors) {
if (auto factorShardingIt =
tensor.factorIndexToSharding.find(factorIndex);
factorShardingIt != tensor.factorIndexToSharding.end()) {
@@ -336,7 +336,7 @@ SmallVector<AxisRefAttr> BasicFactorPropagation::getCompatibleMajorAxes(
}

std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(
AxisRefAttr axisRef, const TensorFactorShardings& tensorFactorSharding,
AxisRefAttr axisRef, const TensorFactorShardingMap& tensorFactorSharding,
int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
MeshAttr mesh) const {
const FactorIndexToSharding& factorIndexToSharding =
@@ -368,8 +368,8 @@ std::optional<AxisRefAttr> BasicFactorPropagation::compatiblePrefix(
int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
MeshAttr mesh) const {
AxisRefAttr result = axisRef;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
for (const TensorFactorShardingMap& tensorFactorSharding :
llvm::concat<const TensorFactorShardingMap>(projection.getOperands(),
projection.getResults())) {
SDY_ASSIGN_OR_RETURN_IF_NULLOPT(
result, compatiblePrefix(result, tensorFactorSharding, factorIndex,
@@ -127,7 +127,7 @@ class BasicFactorPropagation : public FactorPropagation {
// Returns std::nullopt if the compatible prefix does not exist.
std::optional<AxisRefAttr> compatiblePrefixNoConflictsWithinFactor(
AxisRefAttr axisRef, ArrayRef<AxisRefAttr> replicatedAxes,
const FactorSharding& factorSharding, int64_t prevShardedSize,
const TensorFactorSharding& factorSharding, int64_t prevShardedSize,
int64_t factorSize, MeshAttr mesh) const;

// For each axis in `axes`, call `removeConflicts` to get the compatible
@@ -156,12 +156,12 @@ class BasicFactorPropagation : public FactorPropagation {
// If this tensor is mapped to `factorIndex`, returns the prefix of `axisRef`
// by removing conflicts with other factors and within the factor itself.
std::optional<AxisRefAttr> compatiblePrefix(
AxisRefAttr axisRef, const TensorFactorShardings& tensorFactorSharding,
AxisRefAttr axisRef, const TensorFactorShardingMap& tensorFactorSharding,
int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
MeshAttr mesh) const;

// Returns the largest compatible prefix of `axisRef` by removing conflicts
// with every `TensorFactorShardings` in `projection`.
// with every `TensorFactorShardingMap` in `projection`.
std::optional<AxisRefAttr> compatiblePrefix(
AxisRefAttr axisRef, const ShardingProjection& projection,
int64_t factorIndex, int64_t prevShardedSize, int64_t factorSize,
@@ -412,14 +412,14 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
}
MeshAttr mesh = MeshAttr::get(&context, meshAxisAttrs);

TensorFactorShardings operand = {
TensorFactorShardingMap operand = {
.factorIndexToSharding = {
{0,
{.axisRefs = {createAxis("a"), createAxis("b")},
.isMinorMost = true}},
{1, {.axisRefs = {createAxis("c")}, .isMinorMost = true}},
}};
TensorFactorShardings resultBefore = {
TensorFactorShardingMap resultBefore = {
.factorIndexToSharding = {
{0, {.axisRefs = {}, .isMinorMost = false}},
{1, {.axisRefs = {}, .isMinorMost = true}},
@@ -433,7 +433,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
// factor size (4) isn't divisible by the size of ["c"] (3).

auto test = [&](ArrayRef<int64_t> factorSizes,
const TensorFactorShardings& resultAfter) {
const TensorFactorShardingMap& resultAfter) {
ShardingProjection projectionBefore({operand}, {resultBefore});
ShardingProjection projectionAfter({operand}, {resultAfter});
auto [updateOperands, updateResults] = propagateFactorShardings(
@@ -446,7 +446,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
{
// The factor size (9) is divisible by the size of ["a", "b"] (9).
SmallVector<int64_t> factorSizes = {9, 4};
TensorFactorShardings resultAfter = {
TensorFactorShardingMap resultAfter = {
.factorIndexToSharding = {
{0,
{.axisRefs = {createAxis("a"), createAxis("b")},
@@ -460,7 +460,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
// The factor size (6) is divisible by the size of ["a"] (3), but not by the
// size of ["a", "b"] (9).
SmallVector<int64_t> factorSizes = {6, 19};
TensorFactorShardings resultAfter = {
TensorFactorShardingMap resultAfter = {
.factorIndexToSharding = {
{0, {.axisRefs = {createAxis("a")}, .isMinorMost = false}},
{1, {.axisRefs = {createAxis("c")}, .isMinorMost = true}},
@@ -471,7 +471,7 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
{
// The factor size (4) isn't divisible by the size of ["a"] (3).
SmallVector<int64_t> factorSizes = {4, 1};
TensorFactorShardings resultAfter = {
TensorFactorShardingMap resultAfter = {
.factorIndexToSharding = {
{0, {.axisRefs = {}, .isMinorMost = false}},
{1, {.axisRefs = {createAxis("c")}, .isMinorMost = true}},
@@ -481,30 +481,30 @@ TEST_F(BasicFactorPropagationTest, MinorMostFactorNotDivisible) {
}

TEST_F(BasicFactorPropagationTest, UniDirectionalPropagation) {
TensorFactorShardings operandBefore0 = {
TensorFactorShardingMap operandBefore0 = {
.factorIndexToSharding = {
{0, {.axisRefs = {createAxis("a"), createAxis("b")}}},
{1, {.axisRefs = {createAxis("d"), createAxis("e")}}},
}};
TensorFactorShardings operandBefore1 = {
TensorFactorShardingMap operandBefore1 = {
.factorIndexToSharding = {
{0, {.axisRefs = {createAxis("a")}}},
{1, {.axisRefs = {createAxis("d")}}},
}};
TensorFactorShardings result0 = {
TensorFactorShardingMap result0 = {
.factorIndexToSharding = {
{0,
{.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}},
{1, {.axisRefs = {createAxis("d")}}},
}};

TensorFactorShardings operandAfter0 = {
TensorFactorShardingMap operandAfter0 = {
.factorIndexToSharding = {
{0,
{.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}},
{1, {.axisRefs = {createAxis("d"), createAxis("e")}}},
}};
TensorFactorShardings operandAfter1 = {
TensorFactorShardingMap operandAfter1 = {
.factorIndexToSharding = {
{0,
{.axisRefs = {createAxis("a"), createAxis("b"), createAxis("c")}}},
@@ -546,14 +546,14 @@ TEST_F(BasicFactorPropagationTest, UniDirectionalPropagation) {
}

TEST_F(BasicFactorPropagationTest, UniDirectionalPropagationWithConflict) {
TensorFactorShardings operand0 = {
TensorFactorShardingMap operand0 = {
.factorIndexToSharding = {
{0, {.axisRefs = {createAxis("a"), createAxis("b")}}},
}};
TensorFactorShardings operand1 = {.factorIndexToSharding = {
TensorFactorShardingMap operand1 = {.factorIndexToSharding = {
{0, {.axisRefs = {createAxis("a")}}},
}};
TensorFactorShardings result = {
TensorFactorShardingMap result = {
.factorIndexToSharding = {
{0,
{.axisRefs = {createAxis("z"), createAxis("a"), createAxis("b")}}},
10 changes: 5 additions & 5 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
@@ -120,14 +120,14 @@ void notifyShardingModified(Value value,
notifyUsersModified(value, notifyOpModified);
}

// Update the sharding of `value` to the sharding in `tensorFactorShardings`.
// Update the sharding of `value` to the sharding in `TensorFactorShardingMap`.
//
// Returns true if it's possible to update the sharding, i.e., if strided view
// isn't needed and all non-minor-most factors are divisible by sharding axes.
bool updateTensorSharding(
TensorShardingAttr oldTensorSharding,
SetTensorShardingCallback setTensorShardingCallback,
const TensorFactorShardings& tensorFactorShardings,
const TensorFactorShardingMap& tensorFactorShardings,
TensorMappingAttr tensorMapping, ArrayRef<int64_t> factorSizes,
StringRef meshName, MeshAttr mesh, Value modifiedValue,
const ShardingGroupMap& shardingGroupMap,
@@ -173,17 +173,17 @@ bool updateTensorSharding(
return true;
}

// Updates the sharding of all tensors according to `tensorFactorShardings`.
// Updates the sharding of all tensors according to `TensorFactorShardingMap`.
//
// Skips tensors for which `updateTensor` is set to false.
//
// If an operand or result couldn't be updated to the corresponding sharding in
// `tensorFactorShardings`, e.g., if strided view is required, sets the
// `TensorFactorShardingMap`, e.g., if strided view is required, sets the
// respective bit in `updateTensor` or `updateResult` to false.
void updateTensorShardings(
ValueRange tensors, ArrayRef<TensorShardingAttr> tensorShardings,
SetShardingPerTensorCallback setTensorShardingCallback,
ArrayRef<TensorFactorShardings> tensorFactorShardings,
ArrayRef<TensorFactorShardingMap> tensorFactorShardings,
ArrayRef<TensorMappingAttr> tensorMappings, ArrayRef<int64_t> factorSizes,
BitVector& updateTensor, StringRef meshName, MeshAttr mesh,
const ShardingGroupMap& shardingGroupMap,
26 changes: 13 additions & 13 deletions shardy/dialect/sdy/transforms/propagation/sharding_projection.cc
@@ -61,7 +61,7 @@ bool shouldUpdate(ArrayRef<AxisRefAttr> oldAxes,
return oldAxes.back().strictPrefixOf(newAxes.back());
}

bool TensorFactorShardings::updateShardingAxes(int64_t factorIndex,
bool TensorFactorShardingMap::updateShardingAxes(int64_t factorIndex,
ArrayRef<AxisRefAttr> newAxes) {
auto factorShardingIt = factorIndexToSharding.find(factorIndex);
if (factorShardingIt == factorIndexToSharding.end()) {
@@ -103,7 +103,7 @@ int64_t addAxesToDimSharding(SmallVector<AxisRefAttr>& dimSharding,

} // namespace

TensorShardingAttr TensorFactorShardings::createTensorShardingAttr(
TensorShardingAttr TensorFactorShardingMap::createTensorShardingAttr(
MLIRContext* ctx, TensorMappingAttr tensorMapping,
ArrayRef<int64_t> factorSizes, StringRef meshName, MeshAttr mesh) const {
SmallVector<DimensionShardingAttr> newDimShardings;
@@ -114,7 +114,7 @@ TensorShardingAttr TensorFactorShardings::createTensorShardingAttr(
SmallVector<AxisRefAttr> dimSharding;
for (int64_t factorIndex : dimMapping.getFactorIndices()) {
int64_t factorSize = factorSizes[factorIndex];
const FactorSharding& factorSharding =
const TensorFactorSharding& factorSharding =
factorIndexToSharding.at(factorIndex);
isClosed |= factorSharding.isClosed;

@@ -207,18 +207,18 @@ void addRemainingAxes(SmallVector<AxisRefAttr>& currentAxes,
}
}

// Builds a `TensorFactorShardings` for a tensor with the specified
// Builds a `TensorFactorShardingMap` for a tensor with the specified
// `optionalSharding` and `tensorMapping`.
//
// The high level algorithm for projecting a dimension sharding into factor
// shardings is to add axes (or sub-axes) from the dimension sharding to the
// current factor sharding (starting from the major-most factor and axis) until
// the factor is fully sharded, which might require further splitting an axis,
// or this is the minor-most factor, then moving to the next factor.
TensorFactorShardings buildTensorFactorShardings(
TensorFactorShardingMap buildTensorFactorShardings(
TensorMappingAttr tensorMapping, TensorShardingAttr optionalSharding,
ArrayRef<int64_t> factorSizes, MeshAttr mesh) {
TensorFactorShardings result;
TensorFactorShardingMap result;
auto& [factorIndexToSharding, replicatedAxes] = result;
factorIndexToSharding.reserve(factorSizes.size());

@@ -238,7 +238,7 @@ TensorFactorShardings buildTensorFactorShardings(

bool hasOverflowAxes = false;
for (int64_t factorIndex : dimMapping.getFactorIndices()) {
FactorSharding& factorSharding = factorIndexToSharding[factorIndex];
TensorFactorSharding& factorSharding = factorIndexToSharding[factorIndex];
factorSharding.isMinorMost = dimMapping.isMinorMost(factorIndex);

if (hasOverflowAxes) {
@@ -310,9 +310,9 @@ TensorFactorShardings buildTensorFactorShardings(
return result;
}

TensorFactorShardings buildTensorFactorShardings(
TensorMappingAttr tensorMapping, ArrayRef<FactorSharding> factorShardings) {
TensorFactorShardings result;
TensorFactorShardingMap buildTensorFactorShardings(
TensorMappingAttr tensorMapping, ArrayRef<TensorFactorSharding> factorShardings) {
TensorFactorShardingMap result;
// TODO(enver): Drop replicatedAxes after propagation, perhaps isMinorMost as
// well.
result.factorIndexToSharding.reserve(factorShardings.size());
@@ -327,8 +327,8 @@ TensorFactorShardings buildTensorFactorShardings(
} // namespace

ShardingProjection::ShardingProjection(
SmallVector<TensorFactorShardings> operands,
SmallVector<TensorFactorShardings> results)
SmallVector<TensorFactorShardingMap> operands,
SmallVector<TensorFactorShardingMap> results)
: operands(std::move(operands)), results(std::move(results)) {}

ShardingProjection ShardingProjection::build(
@@ -360,7 +360,7 @@ ShardingProjection ShardingProjection::build(Operation* op,
}

ShardingProjection ShardingProjection::build(
ArrayRef<FactorSharding> factorShardings, OpShardingRuleAttr shardingRule) {
ArrayRef<TensorFactorSharding> factorShardings, OpShardingRuleAttr shardingRule) {
ShardingProjection projection;
for (const auto& operandMapping : shardingRule.getOperandMappings()) {
projection.operands.push_back(

