diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td
index 250add55..e454dd7e 100644
--- a/shardy/dialect/sdy/transforms/export/passes.td
+++ b/shardy/dialect/sdy/transforms/export/passes.td
@@ -25,11 +25,9 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
   let description = [{
     Moves the sharding of each `DataFlowEdgeOp` to its input (the root target
     of the edge), and replaces the op with its input.
-
-    TODO(tomnatan): consider moving the sharding to all targets that can have a
-    sharding attached.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
+  // TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.
 }
 
 def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "func::FuncOp"> {
@@ -54,7 +52,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
 
     After propagation, some operations may still have incompatible shardings.
 
-    Please note, when an axis (or sub-axis) is used to shard non-corresponding
+    Note that when an axis (or sub-axis) is used to shard non-corresponding
     dimensions (e.g. non-contracting dimensions in matmul) across multiple
     tensors, or when an axis shards a dimension in one tensor but not the
     corresponding dimension in the other tensor, it is said that the operation
@@ -66,7 +64,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
     and results, and every axis (or sub-axis) can only be used to shard a
     single dimension type.
 
-    A clarifying example:
+    Example:
 
     Input:
     ```mlir
@@ -89,7 +87,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
 
     In the example above, there is a conflict since `lhs` and `rhs` tensors
     are both sharded on axis "x" on their non-contracting dimensions. Here,
-    `rhs` tensor is resharded, before the dot operation, explicitly to be
+    the `rhs` tensor is explicitly resharded before the dot operation, to be
     sharded only on its first dimension and on axis "x". This way, the dot
     operation becomes compatible.
   }];
@@ -100,12 +98,12 @@ def ReshardToCollectivesPass : Pass<"sdy-reshard-to-collectives", "func::FuncOp"
   let summary = "Converts ReshardOp into various Shardy collective ops.";
   let dependentDialects = ["mlir::sdy::SdyDialect"];
   let description = [{
-    Here we match reshard ops and rewrite them into various Shardy collective
+    Matches reshard ops and rewrites them into various Shardy collective
     ops. After this pass, no reshard ops remain in the module. This pass assumes
-    that xplicit reshards have already been inserted
+    that explicit reshards have already been inserted
     (`sdy-insert-explicit-reshards`).
 
-    A clarifying example:
+    Example:
 
     Input:
     ```mlir
diff --git a/shardy/dialect/sdy/transforms/import/passes.td b/shardy/dialect/sdy/transforms/import/passes.td
index d114d6a7..c2ed7f5e 100644
--- a/shardy/dialect/sdy/transforms/import/passes.td
+++ b/shardy/dialect/sdy/transforms/import/passes.td
@@ -20,14 +20,14 @@ def LiftInlinedMeshesPass : Pass<"sdy-lift-inlined-meshes", "ModuleOp"> {
   let description = [{
     Replaces any inlined `MeshAttr` in a `TensorShardingAttr` with a mesh symbol
     name, referencing either an existing or new `MeshOp` in the module, such
-    that no two `MeshOp`s with an identical `MeshAttr` (existing `MeshOp`s are
+    that no two `MeshOp`s have an identical `MeshAttr` (existing `MeshOp`s are
     deduped as well).
 
     The name of each new `MeshOp` will either be:
 
     * `maximal_mesh_{device-id}`, for a maximal mesh (i.e., empty axis list and
-      a single device ID).
-    * The first available name in [`mesh`, `mesh_0`, `mesh_1`, ...], otherwise.
+      a single device ID), or
+    * The first available name in [`mesh`, `mesh_0`, `mesh_1`, ...].
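+
+    For example (the mesh and sharding below are illustrative), an inlined
+    mesh such as:
+
+    ```mlir
+    func.func @main(%arg0: tensor<8x8xf32>
+        {sdy.sharding = #sdy.sharding<mesh<["x"=2]>, [{"x"}, {}]>})
+        -> tensor<8x8xf32> {
+      return %arg0 : tensor<8x8xf32>
+    }
+    ```
+
+    would be lifted to a `MeshOp` that the sharding references by name:
+
+    ```mlir
+    sdy.mesh @mesh = <["x"=2]>
+
+    func.func @main(%arg0: tensor<8x8xf32>
+        {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>})
+        -> tensor<8x8xf32> {
+      return %arg0 : tensor<8x8xf32>
+    }
+    ```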
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
 
@@ -41,10 +41,9 @@ def AddDataFlowEdgesPass : Pass<"sdy-add-data-flow-edges", "func::FuncOp"> {
 
     The inserted `DataFlowEdgeOp` will take the existing sharding of the owner
     target if it exists.
-
-    TODO(b/330339693): update this doc when `getDataFlowEdgeOwners` is removed.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
+  // TODO(b/330339693): update this doc when `getDataFlowEdgeOwners` is removed.
 }
 
 def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func::FuncOp"> {
@@ -60,8 +59,8 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func:
     * The input doesn't have any other users of type `ShardingConstraintOp` or
       `ManualComputationOp` with a different sharding.
 
-    Which indicates that the `ShardingConstraintOp` dictates the sharding of
-    its input.
+    These conditions indicate that the `ShardingConstraintOp` dictates the
+    sharding of its input.
 
     Note that the sharding of a `ShardingConstraintOp` will propagate to its
     input or users during propagation regardless of this pass, but since the
@@ -72,12 +71,12 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func:
     satisfy all of the following:
 
     * The tensor isn't produced by a `ShardingConstraintOp` and doesn't have any
-      other users of type `ShardingConstraintOp` or `ManualComputationOp`.
+      other users of type `ShardingConstraintOp` or `ManualComputationOp`;
     * None of the `ShardingConstraintOp`s in the chain have more than one use
-      except the last one.
+      except the last one;
     * The last `ShardingConstraintOp` in the chain doesn't have any users of
       type `ShardingConstraintOp` or `ManualComputationOp` (otherwise it's not
-      the last in the chain).
+      the last in the chain);
 
     then this pass replaces all other uses of the input of the chain, that are
     defined after the last `ShardingConstraintOp` in the chain (and within the
@@ -112,7 +111,7 @@ def ConstantSplitterPass : Pass<"sdy-constant-splitter", "func::FuncOp"> {
     Note that within a constant sub-computation, a value can have multiple
     uses within that sub-computation.
 
-    NOTE: This pass is the MLIR equivalent of xla::HloConstantSplitter,
+    NOTE: This pass is the MLIR equivalent of `xla::HloConstantSplitter`,
     needed for the purpose of Shardy Propagation.
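+
+    For example (illustrative, using StableHLO ops), a constant with two
+    users:
+
+    ```mlir
+    %const = stablehlo.constant dense<1.000000e+00> : tensor<8xf32>
+    %0 = stablehlo.add %arg0, %const : tensor<8xf32>
+    %1 = stablehlo.multiply %arg1, %const : tensor<8xf32>
+    ```
+
+    would be split so that each use gets its own copy of the constant,
+    allowing each copy to take on a different sharding during propagation:
+
+    ```mlir
+    %const_0 = stablehlo.constant dense<1.000000e+00> : tensor<8xf32>
+    %const_1 = stablehlo.constant dense<1.000000e+00> : tensor<8xf32>
+    %0 = stablehlo.add %arg0, %const_0 : tensor<8xf32>
+    %1 = stablehlo.multiply %arg1, %const_1 : tensor<8xf32>
+    ```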
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
@@ -124,7 +123,8 @@ def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
     Applies canonicalization and validation to sharding groups upon import.
     Namely these are:
 
-    1) Sharding Group Unification -
+    1. Sharding Group Unification
+
     Combines sharding groups using the transitive property of group
     membership. Any time that a tensor T is in a sharding group G1 *and*
     sharding group G2, then we can infer that all members in G1 and G2 should
@@ -132,7 +132,8 @@ def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
     group. The set of canonical group ids after merging will be 0,1,...N-1
     for the minimum set of groups.
 
-    2) Sharding Group Validation
+    2. Sharding Group Validation
+
     Validates that sharding groups are well formed and conform to assumptions
     within the implementation. This currently asserts that if a sharding
     group contains a `Value` defined inside the block of a
@@ -145,10 +146,10 @@ def ManualAxesCleanupPass : Pass<"sdy-manual-axes-cleanup", "ModuleOp"> {
   let summary = "Cleans up the use of manual axes in `ManualComputationOp`s";
   let description = [{
-    1) For any in/out sharding that hasn't specified a manual axis, add that
+    1. For any in/out sharding that hasn't specified a manual axis, add that
       manual axis to its replicated_axes. This is to ensure manual axes are
       always fully specified.
-    2) Sorts the manual axes in mesh axis declaration order.
+    2. Sorts the manual axes in mesh axis declaration order.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
 
diff --git a/shardy/dialect/sdy/transforms/propagation/passes.td b/shardy/dialect/sdy/transforms/propagation/passes.td
index 4d78c1b2..e37a8115 100644
--- a/shardy/dialect/sdy/transforms/propagation/passes.td
+++ b/shardy/dialect/sdy/transforms/propagation/passes.td
@@ -22,20 +22,19 @@ def BasicPropagationPass : PassBase<"sdy-basic-propagate", "BasicPropagationPass
     hierarchy, that doesn't do any conflict resolution, and instead propagates
     axes that are compatible between all operands and results.
 
-    Options:
-    * `-keep-sharding-rules` : whether to keep existing and created op
-      sharding rules
-    * `-module-dump-directory` : where to dump any rewritten modules for
-      debugging
-    * `-conservative-propagation` : whether to disallow split axes and
-      non-divisible sharding axes during propagation
-    * `-debug-sharding-origins` : whether to save information about the origin
-      of a sharding on the MLIR module. These would be the shardings on the
-      function inputs, outputs, sharding constraints and manual computations
-      before propagation.
-    * `-debug-edge-source-sharding` : whether to save information about the
-      edge source of a sharding on the MLIR module. These are what
-      operand/result introduced a sharding on some op result.
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge
+      source of a sharding on the MLIR module. These record which operand/result
+      introduced a sharding on some op result.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
@@ -48,9 +47,20 @@ def AggressivePropagationPass : PassBase<"sdy-aggressive-propagate", "Aggressive
     aggressive strategy resolves conflicts. Higher aggressiveness can reduce
     the memory footprint at the cost of potential communication.
 
-    Options:
-    * All options from `BasicPropagationPass`
-    * `-propagation-strategy` : which factor propagation strategy to use
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge
+      source of a sharding on the MLIR module. These record which operand/result
+      introduced a sharding on some op result.
+    - `-propagation-strategy`: which factor propagation strategy to use.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
@@ -75,10 +85,22 @@ def OpPriorityPropagationPass : PassBase<"sdy-op-priority-propagate", "OpPriorit
     means that at each op-priority iteration, a full aggressive propagation is
     applied (see `AggressivePropagationPass`).
 
-    Options:
-    * All options from `AggressivePropagationPass`
-    * `-run-op-priority-propagation` : whether to run (or skip) op-priority
-      propagation
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge
+      source of a sharding on the MLIR module. These record which operand/result
+      introduced a sharding on some op result.
+    - `-propagation-strategy`: which factor propagation strategy to use.
+    - `-run-op-priority-propagation`: whether to run (or skip) op-priority
+      propagation.
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
@@ -95,8 +117,22 @@ def UserPriorityPropagationPass : PassBase<"sdy-user-priority-propagate", "UserP
     which means that at each user-priority iteration, a full op-priority
     propagation is applied (see `OpPriorityPropagationPass`).
 
-    Options:
-    * All options from `OpPriorityPropagationPass`
+    **Options:**
+    - `-keep-sharding-rules`: whether to keep existing and created op sharding
+      rules.
+    - `-module-dump-directory`: where to dump any rewritten modules for debugging.
+    - `-conservative-propagation`: whether to disallow split axes and non-divisible
+      sharding axes during propagation.
+    - `-debug-sharding-origins`: whether to save information about the origin of a
+      sharding on the MLIR module. These would be the shardings on the function
+      inputs, outputs, sharding constraints and manual computations before
+      propagation.
+    - `-debug-edge-source-sharding`: whether to save information about the edge
+      source of a sharding on the MLIR module. These record which operand/result
+      introduced a sharding on some op result.
+    - `-propagation-strategy`: which factor propagation strategy to use.
+    - `-run-op-priority-propagation`: whether to run (or skip) op-priority
+      propagation.
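+
+    For example (the mesh, sharding, and priorities below are illustrative),
+    priorities can be attached to the dimension shardings of a value, with
+    `p0` being the highest priority:
+
+    ```mlir
+    sdy.mesh @mesh = <["x"=2, "y"=2]>
+
+    func.func @main(%arg0: tensor<8x8xf32>
+        {sdy.sharding = #sdy.sharding<@mesh, [{"x"}p0, {"y"}p1]>})
+        -> tensor<8x8xf32> {
+      return %arg0 : tensor<8x8xf32>
+    }
+    ```
+
+    Here, the sharding of the first dimension along axis "x" is propagated in
+    the first iteration, and only afterwards is the sharding of the second
+    dimension along axis "y" propagated.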
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }