Skip to content

Commit

Permalink
#sdy use applyPatternsGreedily with config.fold=false to avoid co…
Browse files Browse the repository at this point in the history
…nstant folding which is expensive.

PiperOrigin-RevId: 708547565
  • Loading branch information
tomnatan30 authored and copybara-github committed Dec 21, 2024
1 parent 786b95d commit 7e99965
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
22 changes: 17 additions & 5 deletions shardy/dialect/sdy/transforms/import/import_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,28 @@ limitations under the License.
namespace mlir {
namespace sdy {

namespace {

GreedyRewriteConfig getCanonicalizerConfig() {
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Disabled;
config.fold = false;
return config;
}

} // namespace

void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory,
"sdy_module_before_sdy_import"));
// We need to apply the inliner pass so we have a single main function,
// otherwise we would need to propagate shardings between call ops and callee
// functions.
pm.addPass(createInlinerPass());
GreedyRewriteConfig canonicalizerConfig = getCanonicalizerConfig();
pm.addPass(createInlinerPass({}, [&](OpPassManager& pm) {
pm.addPass(createCanonicalizerPass(canonicalizerConfig));
}));
pm.addPass(createSymbolDCEPass());
pm.addPass(createLiftInlinedMeshesPass());
pm.addNestedPass<func::FuncOp>(createConstantSplitterPass());
Expand All @@ -43,11 +58,8 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
// members which have pre-propagation shardings due to sharding constraints.
pm.addPass(createShardingGroupImportPass());

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Disabled;
pm.addPass(createCanonicalizerPass(
/*config=*/config, /*disabledPatterns=*/{},
canonicalizerConfig, /*disabledPatterns=*/{},
/*enabledPatterns=*/{"DedupShardingGroupPattern"}));
pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory,
"sdy_module_after_sdy_import"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,13 @@ LogicalResult BasicPropagationPassImpl::propagate(
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns),
config))) {
config.fold = false;
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns), config))) {
// We should always converge in 2 iterations, if we don't, something is
// wrong.
moduleOp->emitError("Failed to converge after ")
<< config.maxIterations
<< " iterations. please contact the Shardy team.";
return failure();
}

Expand Down

0 comments on commit 7e99965

Please sign in to comment.