PR #276: Improve propagation passes docstrings
Imported from GitHub PR #276

- Move some comments outside the rendered docs;
- Rewrite options using the `Options<>` syntax;
- Correct typos and suggest some wording changes.
Copybara import of the project:

--
4cdd396 by Melissa Weber Mendonça <melissawm@gmail.com>:

Improve propagation passes docstrings

--
7b3b534 by Melissa Weber Mendonça <melissawm@gmail.com>:

Improve import and export passes docstrings

--
5c8284d by Melissa Weber Mendonça <melissawm@gmail.com>:

Undo Options syntax

--
f3e1b7a by Melissa Weber Mendonça <melissawm@gmail.com>:

Revert options formatting

Merging this change closes #276

COPYBARA_INTEGRATE_REVIEW=#276 from melissawm:passes-doc f3e1b7a
PiperOrigin-RevId: 708367080
melissawm authored and copybara-github committed Dec 20, 2024
1 parent ad2a7fd commit bed5ac5
Showing 3 changed files with 82 additions and 47 deletions.
16 changes: 7 additions & 9 deletions shardy/dialect/sdy/transforms/export/passes.td
@@ -25,11 +25,9 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
   let description = [{
     Moves the sharding of each `DataFlowEdgeOp` to its input (the root target of
     the edge), and replaces the op with its input.
-
-    TODO(tomnatan): consider moving the sharding to all targets that can have a
-    sharding attached.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
+  //TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.
 }

def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "func::FuncOp"> {
@@ -54,7 +52,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> {

     After propagation, some operations may still have incompatible shardings.

-    Please note, when an axis (or sub-axis) is used to shard non-corresponding
+    Note that when an axis (or sub-axis) is used to shard non-corresponding
     dimensions (e.g. non-contracting dimensions in matmul) across multiple
     tensors, or when an axis shards a dimension in one tensor but not the
     corresponding dimension in the other tensor, it is said that the operation
@@ -66,7 +64,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> {
     and results, and every axis (or sub-axis) can only be used to shard a single
     dimension type.

-    A clarifying example:
+    Example:

     Input:
     ```mlir
@@ -89,7 +87,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> {

     In the example above, there is a conflict since `lhs` and `rhs` tensors
     are both sharded on axis "x" on their non-contracting dimensions. Here,
-    `rhs` tensor is resharded, before the dot operation, explicitly to be
+    `rhs` tensor is resharded before the dot operation explicitly, to be
     sharded only on its first dimension and on axis "x". This way, the dot
     operation becomes compatible.
   }];
@@ -100,12 +98,12 @@ def ReshardToCollectivesPass : Pass<"sdy-reshard-to-collectives", "func::FuncOp"> {
   let summary = "Converts ReshardOp into various Shardy collective ops.";
   let dependentDialects = ["mlir::sdy::SdyDialect"];
   let description = [{
-    Here we match reshard ops and rewrite them into various Shardy collective
+    Matches reshard ops and rewrites them into various Shardy collective
     ops. After this pass, no reshard ops remain in the module. This pass assumes
-    that xplicit reshards have already been inserted
+    that explicit reshards have already been inserted
     (`sdy-insert-explicit-reshards`).

-    A clarifying example:
+    Example:

     Input:
     ```mlir
31 changes: 16 additions & 15 deletions shardy/dialect/sdy/transforms/import/passes.td
@@ -20,14 +20,14 @@ def LiftInlinedMeshesPass : Pass<"sdy-lift-inlined-meshes", "ModuleOp"> {
   let description = [{
     Replaces any inlined `MeshAttr` in a `TensorShardingAttr` with a mesh symbol
     name, referencing either an existing or new `MeshOp` in the module, such
-    that no two `MeshOp`s with an identical `MeshAttr` (existing `MeshOp`s are
+    that no two `MeshOp`s have an identical `MeshAttr` (existing `MeshOp`s are
     deduped as well).

     The name of each new `MeshOp` will either be:

     * `maximal_mesh_{device-id}`, for a maximal mesh (i.e., empty axis list and
-      a single device ID).
-    * The first available name in [`mesh`, `mesh_0`, `mesh_1`, ...], otherwise.
+      a single device ID), or
+    * The first available name in [`mesh`, `mesh_0`, `mesh_1`, ...].
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
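The naming rule documented in this hunk can be sketched as follows. This is an illustrative Python model of the documented behavior only, not the Shardy C++ pass; the function name `pick_mesh_name` is hypothetical:

```python
def pick_mesh_name(existing_names, is_maximal, device_id=None):
    """Pick a symbol name for a new MeshOp, per the documented naming rule."""
    if is_maximal:
        # Maximal mesh: empty axis list and a single device ID.
        return f"maximal_mesh_{device_id}"
    # Otherwise: first available name in [mesh, mesh_0, mesh_1, ...].
    if "mesh" not in existing_names:
        return "mesh"
    i = 0
    while f"mesh_{i}" in existing_names:
        i += 1
    return f"mesh_{i}"
```

Note how existing symbols are skipped rather than renamed, which matches the pass also deduping existing `MeshOp`s instead of creating fresh ones.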
@@ -41,10 +41,9 @@ def AddDataFlowEdgesPass : Pass<"sdy-add-data-flow-edges", "func::FuncOp"> {

     The inserted `DataFlowEdgeOp` will take the existing sharding of the owner
     target if it exists.
-
-    TODO(b/330339693): update this doc when `getDataFlowEdgeOwners` is removed.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
+  //TODO(b/330339693): update this doc when `getDataFlowEdgeOwners` is removed.
 }

def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func::FuncOp"> {
@@ -60,8 +59,8 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func::FuncOp"> {

     * The input doesn't have any other users of type `ShardingConstraintOp` or
       `ManualComputationOp` with a different sharding.

-    Which indicates that the `ShardingConstraintOp` dictates the sharding of
-    its input.
+    These conditions indicate that the `ShardingConstraintOp` dictates the
+    sharding of its input.

     Note that the sharding of a `ShardingConstraintOp` will propagate to its
     input or users during propagation regardless of this pass, but since the
@@ -72,12 +71,12 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func::FuncOp"> {
     satisfy all of the following:

     * The tensor isn't produced by a `ShardingConstraintOp` and doesn't have any
-      other users of type `ShardingConstraintOp` or `ManualComputationOp`.
+      other users of type `ShardingConstraintOp` or `ManualComputationOp`;
     * None of the `ShardingConstraintOp`s in the chain have more than one use
-      except the last one.
+      except the last one;
     * The last `ShardingConstraintOp` in the chain doesn't have any users of
       type `ShardingConstraintOp` or `ManualComputationOp` (otherwise it's not
-      the last in the chain).
+      the last in the chain);

     then this pass replaces all other uses of the input of the chain, that are
     defined after the last `ShardingConstraintOp` in the chain (and within the
@@ -112,7 +111,7 @@ def ConstantSplitterPass : Pass<"sdy-constant-splitter", "func::FuncOp"> {
     Note that within a constant sub-computation, a value can have multiple uses
     within that sub-computation.

-    NOTE: This pass is the MLIR equivalent of xla::HloConstantSplitter,
+    NOTE: This pass is the MLIR equivalent of `xla::HloConstantSplitter`,
     needed for the purpose of Shardy Propagation.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
@@ -124,15 +123,17 @@ def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
     Applies canonicalization and validation to sharding groups upon import.
     Namely these are:

-    1) Sharding Group Unification -
+    1. Sharding Group Unification
+
     Combines sharding groups using the transitive property of group
     membership. Any time that a tensor T is in a sharding group G1 *and*
     sharding group G2, then we can infer that all members in G1 and G2 should
     be sharded in the same way. Thus we can combine G1 and G2 into a single
     group. The set of canonical group ids after merging will be 0,1,...N-1
     for the minimum set of groups.

-    2) Sharding Group Validation
+    2. Sharding Group Validation
+
     Validates that sharding groups are well formed and conform to assumptions
     within the implementation. This currently asserts that if a sharding
     group contains a `Value` defined inside the block of a
@@ -145,10 +146,10 @@ def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
 def ManualAxesCleanupPass : Pass<"sdy-manual-axes-cleanup", "ModuleOp"> {
   let summary = "Cleans up the use of manual axes in `ManualComputationOp`s";
   let description = [{
-    1) For any in/out sharding that hasn't specified a manual axis, add that
+    1. For any in/out sharding that hasn't specified a manual axis, add that
     manual axis to its replicated_axes. This is to ensure manual axes are
     always fully specified.
-    2) Sorts the manual axes in mesh axis declaration order.
+    2. Sorts the manual axes in mesh axis declaration order.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
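The sharding-group unification described in the `sdy-sharding-group-import` docstring above is, in essence, a union-find over group ids followed by renumbering to canonical ids 0..N-1. A minimal illustrative sketch of that idea (Python, purely hypothetical names; the real pass is implemented in C++ inside Shardy):

```python
def unify_groups(memberships):
    """memberships: list of (tensor, group_id) pairs -> {tensor: canonical_id}.

    A tensor appearing in two groups forces those groups to merge
    (transitivity); surviving groups are renumbered 0, 1, ..., N-1.
    """
    parent = {}

    def find(g):
        # Union-find root lookup with path halving.
        while parent.setdefault(g, g) != g:
            parent[g] = parent[parent[g]]
            g = parent[g]
        return g

    tensor_group = {}
    for tensor, gid in memberships:
        if tensor in tensor_group:
            # The tensor links its previously seen group with this one.
            parent[find(tensor_group[tensor])] = find(gid)
        else:
            tensor_group[tensor] = gid
            find(gid)

    # Renumber merged roots to canonical ids 0..N-1 in first-seen order.
    canonical, result = {}, {}
    for tensor, gid in tensor_group.items():
        root = find(gid)
        result[tensor] = canonical.setdefault(root, len(canonical))
    return result
```

For example, if tensor `a` is in groups 0 and 1, groups 0 and 1 collapse into one canonical group shared by every member of either.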
82 changes: 59 additions & 23 deletions shardy/dialect/sdy/transforms/propagation/passes.td
@@ -22,20 +22,19 @@ def BasicPropagationPass : PassBase<"sdy-basic-propagate", "BasicPropagationPass
     hierarchy, that doesn't do any conflict resolution, and instead propagates
     axes that are compatible between all operands and results.

-    Options:
-    * `-keep-sharding-rules` : whether to keep existing and created op
-      sharding rules
-    * `-module-dump-directory` : where to dump any rewritten modules for
-      debugging
-    * `-conservative-propagation` : whether to disallow split axes and
-      non-divisible sharding axes during propagation
-    * `-debug-sharding-origins` : whether to save information about the origin
-      of a sharding on the MLIR module. These would be the shardings on the
-      function inputs, outputs, sharding constraints and manual computations
-      before propagation.
-    * `-debug-edge-source-sharding` : whether to save information about the
-      edge source of a sharding on the MLIR module. These are what
-      operand/result introduced a sharding on some op result.
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge source
+      of a sharding on the MLIR module. These are what operand/result introduced a
+      sharding on some op result.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
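One way to read "propagates axes that are compatible between all operands and results" is as a longest-common-prefix rule over the sharding axes each tensor assigns to a factor. The sketch below is a hypothetical illustration under that assumption (the actual strategy lives in Shardy's C++ propagation code, and `compatible_axes` is an invented name):

```python
def compatible_axes(axes_per_tensor):
    """axes_per_tensor: one axis-name list per operand/result for a factor.

    Returns the longest common prefix of those lists: the axes that all
    tensors agree on, i.e. the only axes basic propagation would push,
    since it performs no conflict resolution.
    """
    if not axes_per_tensor:
        return []
    prefix = []
    # zip stops at the shortest list; disagreement ends the common prefix.
    for position in zip(*axes_per_tensor):
        if all(axis == position[0] for axis in position):
            prefix.append(position[0])
        else:
            break
    return prefix
```

Under this reading, operands sharded on `["x", "y"]` and `["x", "z"]` would only propagate `["x"]`, leaving the conflicting tail to the more aggressive strategies described below.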
@@ -48,9 +47,20 @@ def AggressivePropagationPass : PassBase<"sdy-aggressive-propagate", "Aggressive
     aggressive strategy resolves conflicts. Higher aggressiveness can reduce the
     memory footprint at the cost of potential communication.

-    Options:
-    * All options from `BasicPropagationPass`
-    * `-propagation-strategy` : which factor propagation strategy to use
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge source
+      of a sharding on the MLIR module. These are what operand/result introduced a
+      sharding on some op result.
+    - `-propagation-strategy`: which factor propagation strategy to use.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
@@ -75,10 +85,22 @@ def OpPriorityPropagationPass : PassBase<"sdy-op-priority-propagate", "OpPriorit
     means that at each op-priority iteration, a full aggressive propagation is
     applied (see `AggressivePropagationPass`).

-    Options:
-    * All options from `AggressivePropagationPass`
-    * `-run-op-priority-propagation` : whether to run (or skip) op-priority
-      propagation
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge source
+      of a sharding on the MLIR module. These are what operand/result introduced a
+      sharding on some op result.
+    - `-propagation-strategy`: which factor propagation strategy to use.
+    - `-run-op-priority-propagation`: whether to run (or skip) op-priority
+      propagation.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
@@ -95,8 +117,22 @@ def UserPriorityPropagationPass : PassBase<"sdy-user-priority-propagate", "UserP
     which means that at each user-priority iteration, a full op-priority
     propagation is applied (see `OpPriorityPropagationPass`).

-    Options:
-    * All options from `OpPriorityPropagationPass`
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge source
+      of a sharding on the MLIR module. These are what operand/result introduced a
+      sharding on some op result.
+    - `-propagation-strategy`: which factor propagation strategy to use.
+    - `-run-op-priority-propagation`: whether to run (or skip) op-priority
+      propagation.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
