Improve propagation passes docstrings #276

Closed · wants to merge 5 commits · Changes from 2 commits
16 changes: 7 additions & 9 deletions shardy/dialect/sdy/transforms/export/passes.td
@@ -25,11 +25,9 @@ def SinkDataFlowEdgesPass : Pass<"sdy-sink-data-flow-edges", "func::FuncOp"> {
let description = [{
Moves the sharding of each `DataFlowEdgeOp` to its input (the root target of
the edge), and replaces the op with its input.

TODO(tomnatan): consider moving the sharding to all targets that can have a
sharding attached.
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
//TODO(tomnatan): consider moving the sharding to all targets that can have a sharding attached.
}
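To make the sinking concrete, here is an illustrative before/after sketch (the MLIR below is an approximation of Shardy's assembly format, not actual pass output):

```mlir
// Before: the data-flow edge op carries the sharding for its root target %0.
%1 = sdy.data_flow_edge %0 sharding=<@mesh, [{"x"}, {}]> : tensor<8x8xf32>

// After: the sharding is attached to %0 (the root target of the edge), all
// uses of %1 are replaced with %0, and the sdy.data_flow_edge op is erased.
```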

def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible-input-output-shardings", "func::FuncOp"> {
@@ -54,7 +52,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun

After propagation, some operations may still have incompatible shardings.

Please note, when an axis (or sub-axis) is used to shard non-corresponding
Note that when an axis (or sub-axis) is used to shard non-corresponding
dimensions (e.g. non-contracting dimensions in matmul) across multiple
tensors, or when an axis shards a dimension in one tensor but not the
corresponding dimension in the other tensor, it is said that the operation
@@ -66,7 +64,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
and results, and every axis (or sub-axis) can only be used to shard a single
dimension type.

A clarifying example:
Example:

Input:
(example code collapsed in this diff view)
@@ -89,7 +87,7 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun

In the example above, there is a conflict since `lhs` and `rhs` tensors
are both sharded on axis "x" on their non-contracting dimensions. Here,
`rhs` tensor is resharded, before the dot operation, explicitly to be
`rhs` tensor is resharded before the dot operation explicitly, to be
sharded only on its first dimension and on axis "x". This way, the dot
operation becomes compatible.
}];
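Since the example code is collapsed in this diff view, here is a rough sketch of the rewrite the text describes (mesh, shapes, and exact syntax are illustrative assumptions):

```mlir
// Before (conflict): %lhs and %rhs are both sharded on axis "x" on their
// non-contracting dimensions.
%0 = stablehlo.dot %lhs, %rhs : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32>

// After: %rhs is explicitly resharded to be sharded only on its first
// dimension, on axis "x", so the dot operation becomes compatible.
%1 = sdy.reshard %rhs <@mesh, [{"x"}, {}]> : tensor<16x32xf32>
%0 = stablehlo.dot %lhs, %1 : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32>
```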
@@ -100,12 +98,12 @@ def ReshardToCollectivesPass : Pass<"sdy-reshard-to-collectives", "func::FuncOp"
let summary = "Converts ReshardOp into various Shardy collective ops.";
let dependentDialects = ["mlir::sdy::SdyDialect"];
let description = [{
Here we match reshard ops and rewrite them into various Shardy collective
Matches reshard ops and rewrites them into various Shardy collective
ops. After this pass, no reshard ops remain in the module. This pass assumes
that xplicit reshards have already been inserted
that explicit reshards have already been inserted
(`sdy-insert-explicit-reshards`).
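As a hedged illustration of the kind of rewrite this pass performs (the collective op name and syntax below are approximations, not guaranteed output):

```mlir
// Before: an explicit reshard that drops axis "y" from the second dimension
// (the input is sharded as [{"x"}, {"y"}]).
%0 = sdy.reshard %arg0 <@mesh, [{"x"}, {}]> : tensor<8x8xf32>

// After: the same transfer expressed as a Shardy collective op.
%0 = sdy.all_gather [{}, {"y"}] %arg0 out_sharding=<@mesh, [{"x"}, {}]> : tensor<8x8xf32>
```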

A clarifying example:
Example:

Input:
(example code collapsed in this diff view)
31 changes: 16 additions & 15 deletions shardy/dialect/sdy/transforms/import/passes.td
@@ -20,14 +20,14 @@ def LiftInlinedMeshesPass : Pass<"sdy-lift-inlined-meshes", "ModuleOp"> {
let description = [{
Replaces any inlined `MeshAttr` in a `TensorShardingAttr` with a mesh symbol
name, referencing either an existing or new `MeshOp` in the module, such
that no two `MeshOp`s with an identical `MeshAttr` (existing `MeshOp`s are
that no two `MeshOp`s have an identical `MeshAttr` (existing `MeshOp`s are
deduped as well).

The name of each new `MeshOp` will either be:

* `maximal_mesh_{device-id}`, for a maximal mesh (i.e., empty axis list and
a single device ID).
* The first available name in [`mesh`, `mesh_0`, `mesh_1`, ...], otherwise.
a single device ID), or
* The first available name in [`mesh`, `mesh_0`, `mesh_1`, ...].
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
}
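An illustrative sketch of the lifting (attribute syntax approximated):

```mlir
// Before: the mesh is inlined in the sharding attribute.
%0 = stablehlo.add %a, %b
  {sdy.sharding = #sdy.sharding_per_value<[<mesh<["x"=2]>, [{"x"}, {}]>]>}
  : tensor<8x8xf32>

// After: a mesh symbol is created (or an identical existing one is reused)
// and referenced by name.
sdy.mesh @mesh = <["x"=2]>
%0 = stablehlo.add %a, %b
  {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
  : tensor<8x8xf32>
```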
@@ -41,10 +41,9 @@ def AddDataFlowEdgesPass : Pass<"sdy-add-data-flow-edges", "func::FuncOp"> {

The inserted `DataFlowEdgeOp` will take the existing sharding of the owner
target if it exists.

TODO(b/330339693): update this doc when `getDataFlowEdgeOwners` is removed.
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
//TODO(b/330339693): update this doc when `getDataFlowEdgeOwners` is removed.
}
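For example (a hedged sketch; `stablehlo.optimization_barrier` is just one op whose results are data-flow edge owners, and the syntax is approximated):

```mlir
// Before:
%0 = stablehlo.optimization_barrier %arg0 : tensor<8x8xf32>

// After: the owner target is wrapped in a DataFlowEdgeOp, which takes the
// target's existing sharding if it has one; other uses of %0 now go
// through %1.
%0 = stablehlo.optimization_barrier %arg0 : tensor<8x8xf32>
%1 = sdy.data_flow_edge %0 : tensor<8x8xf32>
```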

def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func::FuncOp"> {
@@ -60,8 +59,8 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func:
* The input doesn't have any other users of type `ShardingConstraintOp` or
`ManualComputationOp` with a different sharding.

Which indicates that the `ShardingConstraintOp` dictates the sharding of
its input.
These conditions indicate that the `ShardingConstraintOp` dictates the
sharding of its input.
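A minimal sketch of this simple case (syntax approximated):

```mlir
// %0 has no sharding, and the constraint is its only user of type
// ShardingConstraintOp, so the pass copies the constraint's sharding
// onto %0 itself.
%0 = stablehlo.add %a, %b : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh, [{"x"}, {}]> : tensor<8x8xf32>
```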

Note that the sharding of a `ShardingConstraintOp` will propagate to its
input or users during propagation regardless of this pass, but since the
Expand All @@ -72,12 +71,12 @@ def ApplyShardingConstraintsPass : Pass<"sdy-apply-sharding-constraints", "func:
satisfy all of the following:

* The tensor isn't produced by a `ShardingConstraintOp` and doesn't have any
other users of type `ShardingConstraintOp` or `ManualComputationOp`.
other users of type `ShardingConstraintOp` or `ManualComputationOp`;
* None of the `ShardingConstraintOp`s in the chain have more than one use
except the last one.
except the last one;
* The last `ShardingConstraintOp` in the chain doesn't have any users of
type `ShardingConstraintOp` or `ManualComputationOp` (otherwise it's not
the last in the chain).
the last in the chain);

then this pass replaces all other uses of the input of the chain, that are
defined after the last `ShardingConstraintOp` in the chain (and within the
@@ -112,7 +111,7 @@ def ConstantSplitterPass : Pass<"sdy-constant-splitter", "func::FuncOp"> {
Note that within a constant sub-computation, a value can have multiple uses
within that sub-computation.

NOTE: This pass is the MLIR equivalent of xla::HloConstantSplitter,
NOTE: This pass is the MLIR equivalent of `xla::HloConstantSplitter`,
needed for the purpose of Shardy Propagation.
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
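A sketch of the splitting (op syntax approximated):

```mlir
// Before: one constant feeding two different uses.
%c = stablehlo.constant dense<1.0> : tensor<8xf32>
%0 = stablehlo.add %a, %c : tensor<8xf32>
%1 = stablehlo.multiply %b, %c : tensor<8xf32>

// After: each use gets its own copy of the constant sub-computation, so
// propagation can shard each copy independently.
%c0 = stablehlo.constant dense<1.0> : tensor<8xf32>
%c1 = stablehlo.constant dense<1.0> : tensor<8xf32>
%0 = stablehlo.add %a, %c0 : tensor<8xf32>
%1 = stablehlo.multiply %b, %c1 : tensor<8xf32>
```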
@@ -124,15 +123,17 @@ def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
Applies canonicalization and validation to sharding groups upon import.
Namely these are:

1) Sharding Group Unification -
1. Sharding Group Unification

Combines sharding groups using the transitive property of group
membership. Any time that a tensor T is in a sharding group G1 *and*
sharding group G2, then we can infer that all members in G1 and G2 should
be sharded in the same way. Thus we can combine G1 and G2 into a single
group. The set of canonical group ids after merging will be 0,1,...N-1
for the minimum set of groups.

2) Sharding Group Validation
2. Sharding Group Validation

Validates that sharding groups are well formed and conform to assumptions
within the implementation. This currently asserts that if a sharding
group contains a `Value` defined inside the block of a
@@ -145,10 +146,10 @@ def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
def ManualAxesCleanupPass : Pass<"sdy-manual-axes-cleanup", "ModuleOp"> {
let summary = "Cleans up the use of manual axes in `ManualComputationOp`s";
let description = [{
1) For any in/out sharding that hasn't specified a manual axis, add that
1. For any in/out sharding that hasn't specified a manual axis, add that
manual axis to its replicated_axes. This is to ensure manual axes are
always fully specified.
2) Sorts the manual axes in mesh axis declaration order.
2. Sorts the manual axes in mesh axis declaration order.
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
}
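A hedged sketch of point 1 (the `replicated` syntax is an approximation):

```mlir
// Before: manual axis "x" is not mentioned in the in/out shardings.
sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {}]>]
    out_shardings=[<@mesh, [{}, {}]>] manual_axes={"x"} ...

// After: "x" is added to the replicated axes, so every manual axis is
// fully specified in each in/out sharding.
sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{}, {}], replicated={"x"}>]
    out_shardings=[<@mesh, [{}, {}], replicated={"x"}>] manual_axes={"x"} ...
```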
132 changes: 104 additions & 28 deletions shardy/dialect/sdy/transforms/propagation/passes.td
@@ -21,23 +21,30 @@ def BasicPropagationPass : PassBase<"sdy-basic-propagate", "BasicPropagationPass
The basic propagation algorithm is the lowest strategy of propagation in the
hierarchy, that doesn't do any conflict resolution, and instead propagates
axes that are compatible between all operands and results.

Options:
* `-keep-sharding-rules` : whether to keep existing and created op
sharding rules
* `-module-dump-directory` : where to dump any rewritten modules for
debugging
* `-conservative-propagation` : whether to disallow split axes and
non-divisible sharding axes during propagation
* `-debug-sharding-origins` : whether to save information about the origin
of a sharding on the MLIR module. These would be the shardings on the
function inputs, outputs, sharding constraints and manual computations
before propagation.
* `-debug-edge-source-sharding` : whether to save information about the
edge source of a sharding on the MLIR module. These are what
operand/result introduced a sharding on some op result.
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
let options = [
Option<"keepShardingRules", "keep-sharding-rules", "bool",
/*default=*/"false",
"whether to keep existing and created op sharding rules.">,
Option<"moduleDumpDirectory", "module-dump-directory", "string",
/*default=*/"",
"where to dump any rewritten modules for debugging.">,
Option<"conservativePropagation", "conservative-propagation", "bool",
/*default=*/"false",
"whether to disallow split axes and non-divisible sharding axes during "
"propagation.">,
Option<"debugShardingOrigins", "debug-sharding-origins", "bool",
/*default=*/"false",
"whether to save information about the origin of a sharding on the MLIR "
"module. These would be the shardings on the function inputs, outputs, "
"sharding constraints and manual computations before propagation.">,
Option<"debugEdgeSourceSharding", "debug-edge-source-sharding", "bool",
/*default=*/"false",
"whether to save information about the edge source of a sharding on the "
"MLIR module. These are what operand/result introduced a sharding on "
"some op result.">,
];
}

def AggressivePropagationPass : PassBase<"sdy-aggressive-propagate", "AggressivePropagationPassImpl"> {
@@ -47,12 +54,33 @@ def AggressivePropagationPass : PassBase<"sdy-aggressive-propagate", "Aggressive
basic strategy only propagates shardings without conflicts, while the
aggressive strategy resolves conflicts. Higher aggressiveness can reduce the
memory footprint at the cost of potential communication.

Options:
* All options from `BasicPropagationPass`
* `-propagation-strategy` : which factor propagation strategy to use
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
let options = [
Option<"keepShardingRules", "keep-sharding-rules", "bool",
/*default=*/"false",
"whether to keep existing and created op sharding rules.">,
Option<"moduleDumpDirectory", "module-dump-directory", "string",
/*default=*/"",
"where to dump any rewritten modules for debugging.">,
Option<"conservativePropagation", "conservative-propagation", "bool",
/*default=*/"false",
"whether to disallow split axes and non-divisible sharding axes during "
"propagation.">,
Option<"debugShardingOrigins", "debug-sharding-origins", "bool",
/*default=*/"false",
"whether to save information about the origin of a sharding on the MLIR "
"module. These would be the shardings on the function inputs, outputs, "
"sharding constraints and manual computations before propagation.">,
Option<"debugEdgeSourceSharding", "debug-edge-source-sharding", "bool",
/*default=*/"false",
"whether to save information about the edge source of a sharding on the "
"MLIR module. These are what operand/result introduced a sharding on "
"some op result.">,
Option<"propagationStrategy", "propagation-strategy", "string",
/*default=*/"",
"which factor propagation strategy to use.">,
];
Collaborator:

Oh, I'm surprised this works! I remember that, due to the subclassing, we couldn't make these explicit options. Does this compile, and do all tests pass? I'm guessing it does, since the flags are duplicated?

I guess an annoying issue is that we can't create top-level constants which save these options in a list, and have to duplicate them. Hmm, I'm unsure what is better to do here... I'm leaning toward keeping it as is, @tomnatan30 any thoughts?

Collaborator:

I'm afraid this doesn't work, see failing build.

We can't declare the options at the moment because we define them in the c++ files so they can be inherited.

@melissawm feel free to update the description of the options and change their format so they will look better in the auto generated doc, but we can't switch to declarative options just yet.

Collaborator:

You can try to make it look exactly as it does with the declarative options you have now.

Contributor Author:

That's a pity! Will do.

Contributor Author:

I think repeating is fine - in terms of usability, having all options on hand when looking up one of these items is more helpful than having to click through other previously defined items. In addition, the items on the page are now shown in the same order as the source file so we can't rely on what was "previously" defined.

Contributor Author:

Yes, I understand the concern. It's up to you really - I'm happy to move that back and refer to the original options instead of repeating. Just let me know and I'll change it 😄

Collaborator:

I think we would want to try declaring the options like you did before, and this would require some refactoring, so better to keep it simple until we do (not repeating), if that's ok.

btw, when we use the `let options` syntax, the auto-generated doc would have:

    Options

    -some-option1 : whether to do A
    -some-option2 : whether to do B

Should we reuse this format or keep it as is? Maybe better to avoid the `#### Options` heading, given it looks out of place in the td file. wdyt?

Contributor Author:

Not sure I understand the difference - the #### are added automatically when the let options syntax is used, so I'm not sure that could be tweaked. I could use that in the description field but I think it would look out of place, as the markdown section hierarchy represented by the # characters would be lost. I would vote for just using bold face for now (as in **Options**)

Collaborator:

I was also referring to the code block all the options were in when using the `let options` syntax, but let's keep it as is :)

The PR got auto-imported with the previous version (with duplication), and this PR was automatically closed. This is fine, we can fix it on our side.

Thanks!

Contributor Author:

Oops sorry about that! Let me know if there's anything I can do, otherwise thanks for all the feedback!!

}

def OpPriorityPropagationPass : PassBase<"sdy-op-priority-propagate", "OpPriorityPropagationPassImpl"> {
@@ -74,13 +102,36 @@ def OpPriorityPropagationPass : PassBase<"sdy-op-priority-propagate", "OpPriorit
This propagation strategy extends the aggressive propagation strategy, which
means that at each op-priority iteration, a full aggressive propagation is
applied (see `AggressivePropagationPass`).

Options:
* All options from `AggressivePropagationPass`
* `-run-op-priority-propagation` : whether to run (or skip) op-priority
propagation
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
let options = [
Option<"keepShardingRules", "keep-sharding-rules", "bool",
/*default=*/"false",
"whether to keep existing and created op sharding rules.">,
Option<"moduleDumpDirectory", "module-dump-directory", "string",
/*default=*/"",
"where to dump any rewritten modules for debugging.">,
Option<"conservativePropagation", "conservative-propagation", "bool",
/*default=*/"false",
"whether to disallow split axes and non-divisible sharding axes during "
"propagation.">,
Option<"debugShardingOrigins", "debug-sharding-origins", "bool",
/*default=*/"false",
"whether to save information about the origin of a sharding on the MLIR "
"module. These would be the shardings on the function inputs, outputs, "
"sharding constraints and manual computations before propagation.">,
Option<"debugEdgeSourceSharding", "debug-edge-source-sharding", "bool",
/*default=*/"false",
"whether to save information about the edge source of a sharding on the "
"MLIR module. These are what operand/result introduced a sharding on "
"some op result.">,
Option<"propagationStrategy", "propagation-strategy", "string",
/*default=*/"",
"which factor propagation strategy to use.">,
Option<"runOpPriorityPropagation", "run-op-priority-propagation", "bool",
/*default=*/"false",
"whether to run (or skip) op-priority propagation.">,
];
}

def UserPriorityPropagationPass : PassBase<"sdy-user-priority-propagate", "UserPriorityPropagationPassImpl"> {
@@ -94,11 +145,36 @@ def UserPriorityPropagationPass : PassBase<"sdy-user-priority-propagate", "UserP
This propagation strategy extends the op-priority propagation strategy,
which means that at each user-priority iteration, a full op-priority
propagation is applied (see `OpPriorityPropagationPass`).

Options:
* All options from `OpPriorityPropagationPass`
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
let options = [
Option<"keepShardingRules", "keep-sharding-rules", "bool",
/*default=*/"false",
"whether to keep existing and created op sharding rules.">,
Option<"moduleDumpDirectory", "module-dump-directory", "string",
/*default=*/"",
"where to dump any rewritten modules for debugging.">,
Option<"conservativePropagation", "conservative-propagation", "bool",
/*default=*/"false",
"whether to disallow split axes and non-divisible sharding axes during "
"propagation.">,
Option<"debugShardingOrigins", "debug-sharding-origins", "bool",
/*default=*/"false",
"whether to save information about the origin of a sharding on the MLIR "
"module. These would be the shardings on the function inputs, outputs, "
"sharding constraints and manual computations before propagation.">,
Option<"debugEdgeSourceSharding", "debug-edge-source-sharding", "bool",
/*default=*/"false",
"whether to save information about the edge source of a sharding on the "
"MLIR module. These are what operand/result introduced a sharding on "
"some op result.">,
Option<"propagationStrategy", "propagation-strategy", "string",
/*default=*/"",
"which factor propagation strategy to use.">,
Option<"runOpPriorityPropagation", "run-op-priority-propagation", "bool",
/*default=*/"false",
"whether to run (or skip) op-priority propagation.">,
];
}

def PopulateOpShardingRulesPass : Pass<"sdy-populate-op-sharding-rules", "func::FuncOp"> {
@@ -116,6 +192,6 @@ def PopulateOpShardingRulesPass : Pass<"sdy-populate-op-sharding-rules", "func::
Option<"conservativePropagation", "conservative-propagation", "bool",
/*default=*/"false",
"whether to disllow rules that can propagate non-divisible sharding "
"axes">
"axes.">
];
}