[LLVMGPU] Combine parallel and reduction padding in LLVMGPUPadAndVectorDistribute (iree-org#18771)

Since iree-org#18748, tensor.pad ops can be fused in with tiling. This patch
combines the separate parallel-dim and reduction-dim padding passes into a
single pass that pads all dimensions at once; the resulting pads are then
fused during tiling.
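As a rough illustration of the combined behavior, here is a minimal standalone sketch (the helper name and vectors are hypothetical, not IREE API) of how a single set of pad multiples can be derived as the element-wise maximum of the workgroup and reduction tile sizes, mirroring the logic added to LLVMGPUPromoteMatmulToFitMMAPass in the diff below:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: combine workgroup and reduction tile sizes into one
// set of pad multiples by taking the element-wise maximum. A tile size of 0
// (untiled dim) contributes nothing; the default multiple of 1 means that
// loop effectively needs no padding.
std::vector<int64_t> combinePadMultiples(const std::vector<int64_t> &wgTiles,
                                         const std::vector<int64_t> &redTiles) {
  assert(wgTiles.size() == redTiles.size() && "tile lists must have equal rank");
  std::vector<int64_t> padToMultipleOf(wgTiles.size(), 1);
  for (size_t i = 0; i < wgTiles.size(); ++i)
    padToMultipleOf[i] =
        std::max({wgTiles[i], redTiles[i], padToMultipleOf[i]});
  return padToMultipleOf;
}

// Example: a matmul tiled with workgroup sizes [64, 64, 0] and reduction
// sizes [0, 0, 32] is padded so M and N become multiples of 64 and K a
// multiple of 32:
//   combinePadMultiples({64, 64, 0}, {0, 0, 32}) == {64, 64, 32}
```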
Groverkss authored Oct 25, 2024
1 parent 1fc6e5b commit 1aa5825
Showing 6 changed files with 41 additions and 253 deletions.
@@ -27,25 +27,18 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
public:
using impl::LLVMGPUPromoteMatmulToFitMMAPassBase<
LLVMGPUPromoteMatmulToFitMMAPass>::LLVMGPUPromoteMatmulToFitMMAPassBase;
explicit LLVMGPUPromoteMatmulToFitMMAPass(
const LLVMGPUMatmulPadOption &option) {
this->targetDimensions.setValue(option);
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, linalg::LinalgDialect>();
}

void padWithZeroValue(RewriterBase &rewriter, linalg::LinalgOp op,
ArrayRef<int64_t> paddingDims,
ArrayRef<int64_t> padToMultipleOf, bool noFold) const {
assert(paddingDims.size() == padToMultipleOf.size() &&
"invalid pad multiples for padding dimensions");

ArrayRef<int64_t> padToMultipleOf) const {
LLVM_DEBUG(llvm::dbgs() << "candidate: " << op << "\n");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);

SmallVector<bool> nofoldFlags(op.getNumDpsInputs(), noFold);
SmallVector<int64_t> paddingDims =
llvm::to_vector(llvm::seq<int64_t>(padToMultipleOf.size()));

SmallVector<Attribute> paddingValueAttributes;
for (auto &operand : op->getOpOperands()) {
@@ -58,7 +51,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
.setPaddingDimensions(paddingDims)
.setPaddingValues(paddingValueAttributes)
.setPadToMultipleOf(padToMultipleOf)
.setNofoldFlags(nofoldFlags)
.setCopyBackOp(linalg::LinalgPaddingOptions::CopyBackOp::None);

FailureOr<linalg::LinalgOp> result =
@@ -72,26 +64,6 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
MLIRContext *ctx = &getContext();
auto funcOp = getOperation();

// Preserve the innermost tensor.pad ops (i.e., pad for reduction dims), so
// we can kick canonicalization patterns to fold outer tensor.pad ops away.
bool noFold = false;
utils::IteratorType targetIterType = utils::IteratorType::parallel;
switch (targetDimensions) {
case LLVMGPUMatmulPadOption::ParallelDims:
LLVM_DEBUG(llvm::dbgs() << "padding parallel dims\n");
targetIterType = utils::IteratorType::parallel;
noFold = false;
break;
case LLVMGPUMatmulPadOption::ReductionDims:
LLVM_DEBUG(llvm::dbgs() << "padding reduction dims\n");
targetIterType = utils::IteratorType::reduction;
noFold = true;
break;
default: // Unreachable.
assert(false);
break;
};

SmallVector<linalg::LinalgOp> candidates;
funcOp->walk([&](linalg::LinalgOp op) {
if (linalg::isaContractionOpInterface(op)) {
@@ -101,46 +73,27 @@ class LLVMGPUPromoteMatmulToFitMMAPass final

IRRewriter rewriter(ctx);
for (linalg::LinalgOp op : candidates) {
SmallVector<int64_t> padMultiples(op.getNumLoops(), 1);
auto config = dyn_cast_or_null<IREE::GPU::LoweringConfigAttr>(
getLoweringConfig(op));
if (config) {
switch (targetDimensions) {
case LLVMGPUMatmulPadOption::ParallelDims:
padMultiples = config.getStaticTilingLevelSizes(
static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
break;
case LLVMGPUMatmulPadOption::ReductionDims:
padMultiples = config.getStaticTilingLevelSizes(
static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);
break;
default:
assert(false && "Unexpected target dimensions");
break;
}
if (!config) {
continue;
}

// Populate padding dimensions.
SmallVector<int64_t> paddingDimensions;
for (auto [idx, iter] : llvm::enumerate(op.getIteratorTypesArray())) {
if (iter == targetIterType) {
paddingDimensions.push_back(idx);
}
}
SmallVector<int64_t> wgTiles = config.getStaticTilingLevelSizes(
static_cast<unsigned>(IREE::GPU::TilingLevel::Workgroup), op);
SmallVector<int64_t> redTiles = config.getStaticTilingLevelSizes(
static_cast<unsigned>(IREE::GPU::TilingLevel::Reduction), op);

// Populate tile sizes. We pad to multiples of workgroup/reduction
// tile sizes based on the selected target tiling dimensions.
// This pass is ran after the select target tiling is done to pad
// all dimensions to the select tile sizes.
SmallVector<int64_t> padToMultipleOf;
for (int64_t dim : paddingDimensions) {
if (padMultiples[dim] != 0) {
padToMultipleOf.push_back(padMultiples[dim]);
}
// Populate padding dimensions to maximum of possible tile sizes.
SmallVector<int64_t> padToMultipleOf(op.getNumLoops(), 1);
for (auto [wgTile, redTile, padMultiple] :
llvm::zip_equal(wgTiles, redTiles, padToMultipleOf)) {
padMultiple = std::max({wgTile, redTile, padMultiple});
}
SmallVector<int64_t> paddingDimensions =
llvm::to_vector(llvm::seq<int64_t>(op.getNumLoops()));

padWithZeroValue(rewriter, op, paddingDimensions, padToMultipleOf,
noFold);
padWithZeroValue(rewriter, op, padToMultipleOf);
}

{
@@ -156,58 +109,8 @@ class LLVMGPUPromoteMatmulToFitMMAPass final
return signalPassFailure();
}
}

// XXX(hanchung): This is needed for pad op fusion, which will remove
// outer pad ops. I.e., it mainly wants to remove first pad op in the
// pad->extract_slice->pad chain, while the canonicalization pattern can
// only recognize slice->pad->slice->pad.
{
SmallVector<tensor::PadOp> padOps;
funcOp.walk([&](tensor::PadOp op) { padOps.push_back(op); });
for (auto op : padOps) {
auto srcExtractSliceOp =
op.getSource().getDefiningOp<tensor::ExtractSliceOp>();
if (!srcExtractSliceOp) {
continue;
}
auto producerPadOp =
srcExtractSliceOp.getSource().getDefiningOp<tensor::PadOp>();
if (!producerPadOp) {
continue;
}
auto src = producerPadOp.getSource()
.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!src) {
continue;
}

rewriter.setInsertionPointAfter(src);
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, op.getLoc(), src);
SmallVector<OpFoldResult> offsets(sizes.size(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(sizes.size(),
rewriter.getIndexAttr(1));
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
op.getLoc(), src.getResult(), offsets, sizes, strides);
rewriter.startOpModification(op);
producerPadOp.getSourceMutable().assign(extractSliceOp.getResult());
rewriter.finalizeOpModification(op);
}

RewritePatternSet patterns(ctx);
tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option) {
return std::make_unique<LLVMGPUPromoteMatmulToFitMMAPass>(option);
}

} // namespace mlir::iree_compiler
9 changes: 2 additions & 7 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -858,25 +858,20 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

if (usePadToModelSharedMemcpy) {
LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ParallelDims;
funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass());
}

// Tile to reduction loops.
{
GPUApplyTilingLevelPassOptions options;
options.tilingLevel = IREE::GPU::TilingLevel::Reduction;
options.allowZeroSlices = true;
funcPassManager.addPass(createGPUApplyTilingLevelPass(options));
funcPassManager.addPass(affine::createLoopCoalescingPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
}

if (usePadToModelSharedMemcpy) {
LLVMGPUMatmulPadOption option = LLVMGPUMatmulPadOption::ReductionDims;
funcPassManager.addPass(createLLVMGPUPromoteMatmulToFitMMAPass(option));
}

funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
4 changes: 0 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -103,10 +103,6 @@ verifyGPUMatmulPipeline(Operation *op,
// Wrappers that not use tablegen options.
//------------------------------------------------------------------------------

enum class LLVMGPUMatmulPadOption { ParallelDims, ReductionDims };
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMGPUPromoteMatmulToFitMMAPass(LLVMGPUMatmulPadOption option);

enum class GPUTensorCoreType {
WMMA = 0,
MMA_SYNC = 1,
13 changes: 0 additions & 13 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -105,19 +105,6 @@ def LLVMGPUPrefetchSharedMemoryPass :
def LLVMGPUPromoteMatmulToFitMMAPass :
InterfacePass<"iree-llvmgpu-promote-matmul-to-fit-mma", "mlir::FunctionOpInterface"> {
let summary = "Pass to promote contraction ops to fit mma shapes";
let options = [
Option<"targetDimensions", "target-dimensions", "mlir::iree_compiler::LLVMGPUMatmulPadOption",
/*default=*/"mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims",
"Select the strategy to control how multi_reduction is lowered.",
[{::llvm::cl::values(
clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ParallelDims,
"parallel",
"Pad all the parallel dims for contraction ops."),
clEnumValN(mlir::iree_compiler::LLVMGPUMatmulPadOption::ReductionDims,
"reduction",
"Pad all the reduction dims for contraction ops.")
)}]>
];
}

def LLVMGPUSelectLoweringStrategyPass :
@@ -511,7 +511,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
// CHECK: %[[RHS_LOAD:.+]] = vector.transfer_read %[[RHS_GLOBAL_SUB]]{{.+}} {in_bounds = [true, false, false]}
// CHECK: vector.transfer_write %[[LHS_LOAD]], %[[LHS_SHARED]]
// CHECK: vector.transfer_write %[[RHS_LOAD]], %[[RHS_SHARED]]
// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1265 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
// CHECK: %[[RES:.+]] scf.for {{.*}} = %c0 to %c1280 step %c16 iter_args({{.*}}) -> (vector<1x1x1x1x1x1x1x4x1xf16>)
// CHECK-DAG: %[[LHS_GLOBAL_SUB:.+]] = memref.subview %[[LHS_GLOBAL]]
// CHECK-DAG: %[[RHS_GLOBAL_SUB:.+]] = memref.subview %[[RHS_GLOBAL]]
// CHECK: %[[LHS_LOAD:.+]] = vector.transfer_read %[[LHS_GLOBAL_SUB]]
@@ -581,9 +581,11 @@ hal.executable public @pad_batch_matmul {
// CHECK-SAME: memref<196x16x24xf32
// CHECK-SAME: vector<1x1x1xf32>
// RHS
// The dynamic dimension should be removed after:
// https://github.com/llvm/llvm-project/pull/112236
// CHECK: vector.transfer_read
// CHECK-SAME: in_bounds = [true, true, false]
// CHECK-SAME: memref<1x8x24xf32
// CHECK-SAME: in_bounds = [true, false, false]
// CHECK-SAME: memref<1x?x24xf32
// CHECK-SAME: vector<1x1x2xf32>
// CHECK: scf.yield
// OUTPUT
Expand Down
Loading
