[LLVMGPU] Use forall workgroup distribution in TileAndFuse pipeline (iree-org#18565)

This switches the TileAndFuse pipeline to scf.forall-based workgroup
distribution. Using scf.forall distribution also requires some changes to the
pass ordering in the TileAndFuse pipeline, which this PR handles as well:
1. The main difference is that PackToIntrinsics now happens before workgroup
distribution. Otherwise, collapse_shape ops can end up at the end of the
workgroup forall, and an extra buffer is created.
2. Pack decomposition is now staged: packs/unpacks at the function
boundaries are decomposed early, before workgroup distribution, and the
rest are decomposed after reduction tiling as before. This prevents
unpacks from being fused into the workgroup forall and causing the same
problem as in (1).
3. `ConcretizeMmaShapes` now runs before workgroup tiling as well, so the
resulting collapse_shape on the multi_mma op result can be propagated to
the function boundary before any tiling. This also avoids the same
problem as in (1).
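
For reference, a minimal sketch of what scf.forall workgroup distribution
produces for a plain matmul, before reduction tiling. The shapes, tile sizes,
and the `#iree_codegen.workgroup_mapping` attribute are illustrative
assumptions, not output copied from this PR:

```mlir
// Illustrative sketch only: parallel dims tiled to 64x64 workgroup tiles and
// distributed with scf.forall; the mapping attribute name is an assumption.
func.func @matmul_sketch(%lhs: tensor<128x256xf16>, %rhs: tensor<256x128xf16>,
                         %init: tensor<128x128xf32>) -> tensor<128x128xf32> {
  %0 = scf.forall (%i, %j) = (0, 0) to (128, 128) step (64, 64)
      shared_outs(%out = %init) -> (tensor<128x128xf32>) {
    %lhs_slice = tensor.extract_slice %lhs[%i, 0] [64, 256] [1, 1]
        : tensor<128x256xf16> to tensor<64x256xf16>
    %rhs_slice = tensor.extract_slice %rhs[0, %j] [256, 64] [1, 1]
        : tensor<256x128xf16> to tensor<256x64xf16>
    %acc = tensor.extract_slice %out[%i, %j] [64, 64] [1, 1]
        : tensor<128x128xf32> to tensor<64x64xf32>
    // The K dimension is still untiled here; reduction tiling happens later
    // in the pipeline (Step 2 in Passes.cpp).
    %mm = linalg.matmul
        ins(%lhs_slice, %rhs_slice : tensor<64x256xf16>, tensor<256x64xf16>)
        outs(%acc : tensor<64x64xf32>) -> tensor<64x64xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %mm into %out[%i, %j] [64, 64] [1, 1]
          : tensor<64x64xf32> into tensor<128x128xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  return %0 : tensor<128x128xf32>
}
```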

The lowering configs on the MMA path have also changed, since they now need
to account for the inner tile sizes introduced by packing.
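
As a concrete example (assuming a 16x16x16 MFMA intrinsic, which is an
illustrative assumption): with mWarpCount = 2 and mTileCount = 4, the old
config computed workgroupTileSizes[mDim] = 2 * 4 * 16 = 128, while the new
config computes 2 * 4 = 8, because the factor of 16 now lives in the packed
inner tile. This is why the tests below change workgroup = [128, 128, 0] to
[8, 8, 0].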

Depends on iree-org#18852.

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Max191 authored Oct 24, 2024
1 parent 4d20b82 commit c3fae2f
Showing 5 changed files with 183 additions and 135 deletions.
@@ -245,10 +245,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
}

// Compute the M/N dimension tile size by multiplying subgroup information.
workgroupTileSizes[mDim] =
schedule->mWarpCount * schedule->mTileCount * schedule->mSize;
workgroupTileSizes[nDim] =
schedule->nWarpCount * schedule->nTileCount * schedule->nSize;
workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;

// Specify the subgroup tile sizes from the mma schedule. This is applied
subgroupTileSizes[mDim] = schedule->mTileCount;
73 changes: 52 additions & 21 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -19,6 +19,7 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
@@ -190,18 +191,23 @@ static void addBufferizePasses(OpPassManager &funcPassManager) {
}

static void tileAndDistributeToWorkgroup(
OpPassManager &funcPassManager,
OpPassManager &funcPassManager, bool useForall,
std::optional<ConvertToDestinationPassingStylePassOptions>
convertToDpsOptions = ConvertToDestinationPassingStylePassOptions{}) {
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
kNumMaxParallelDims,
linalg::DistributionMethod::CyclicNumProcsEqNumIters));
funcPassManager.addPass(createCSEPass());

if (convertToDpsOptions) {
if (useForall) {
funcPassManager.addPass(
createConvertToDestinationPassingStylePass(*convertToDpsOptions));
createTileAndDistributeToWorkgroupsUsingForallOpPass());
} else {
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass(
kNumMaxParallelDims,
linalg::DistributionMethod::CyclicNumProcsEqNumIters));
funcPassManager.addPass(createCSEPass());
if (convertToDpsOptions) {
funcPassManager.addPass(
createConvertToDestinationPassingStylePass(*convertToDpsOptions));
}
}

// TODO(#16421): Disable decomposition due to failure in bufferization.
// funcPassManager.addPass(
// IREE::LinalgExt::createTileAndDecomposeAttentionPass());
@@ -212,7 +218,8 @@ static void tileAndBufferize(OpPassManager &funcPassManager) {
static void tileAndBufferize(OpPassManager &funcPassManager) {
ConvertToDestinationPassingStylePassOptions options;
options.useWARForCooperativeMatrixCodegen = true;
tileAndDistributeToWorkgroup(funcPassManager, options);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false,
/*convertToDpsOptions=*/options);
addBufferizePasses(funcPassManager);
}

@@ -243,7 +250,7 @@ static void addGPUVectorizationPasses(OpPassManager &funcPassManager,
//===---------------------------------------------------------------------===//

void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
@@ -323,22 +330,45 @@ static void addGPUBufferizePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCSEPass());
}

/// Control function for decomposing pack and unpack ops. Returns true if the
/// op is a PackOp with a DispatchTensorLoadOp producer, or an UnPackOp with
/// only DispatchTensorStoreOp consumers.
LogicalResult isAtBoundary(Operation *op) {
if (isa<tensor::PackOp>(op)) {
if (isa_and_nonnull<IREE::Flow::DispatchTensorLoadOp>(
op->getOperand(0).getDefiningOp())) {
return success();
}
} else if (isa<tensor::UnPackOp>(op)) {
if (llvm::all_of(op->getUsers(), [](Operation *user) {
return isa<IREE::Flow::DispatchTensorStoreOp>(user);
})) {
return success();
}
}
return failure();
}

void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &pipelineOptions) {
tileAndDistributeToWorkgroup(funcPassManager,
/*convertToDpsOptions=*/std::nullopt);

// Step 1. Promote matmul operands and pack to intrinsic shapes.
funcPassManager.addPass(createGPUPromoteMatmulOperandsPass());
funcPassManager.addPass(IREE::GPU::createPackToIntrinsicsPass());
// Decompose packs and unpacks that are at the function boundary.
funcPassManager.addPass(createDecomposeBoundaryPackUnPackOpsPass());

// Step 1.5. Expand result shapes of MultiMmaOps before reduction tiling.
// Step 1.5. Expand result shapes of MultiMmaOps before tiling, and
// propagate reshapes to the function boundary.
{
IREE::GPU::ConcretizeMmaShapesPassOptions options;
options.concretizeInputs = false;
options.concretizeResult = true;
funcPassManager.addPass(IREE::GPU::createConcretizeMmaShapesPass());
}
funcPassManager.addPass(createPropagateReshapesByExpansionPass());

tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
/*convertToDpsOptions=*/std::nullopt);

// Step 2. Tile and fuse tileable ops to reduction loops.
{
@@ -468,7 +498,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
//===---------------------------------------------------------------------===//

void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
@@ -505,7 +535,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {

void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
@@ -709,7 +739,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(

void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCanonicalizerPass());
@@ -814,7 +844,7 @@ static void addVectorBufferizePasses(OpPassManager &funcPassManager) {
void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options,
bool usePadToModelSharedMemcpy) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);

ReorderWorkgroupsStrategy reorderStrategy =
getReorderWorkgroupsStrategy(options.reorderStrategy);
@@ -914,7 +944,7 @@ void addGPUVectorDistributePassPipeline(OpPassManager &funcPassManager,
}

void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
funcPassManager.addPass(createRematerializeParallelOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createGPUTileReductionPass());
@@ -958,7 +988,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
}

void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
tileAndDistributeToWorkgroup(funcPassManager);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

@@ -994,7 +1024,8 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
const GPUPipelineOptions &options) {
ConvertToDestinationPassingStylePassOptions dpsOptions;
dpsOptions.useWARForCooperativeMatrixCodegen = true;
tileAndDistributeToWorkgroup(funcPassManager, dpsOptions);
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false,
/*convertToDpsOptions=*/dpsOptions);
if (options.enableUkernels) {
funcPassManager.addPass(createGPULowerToUKernelsPass());
}
@@ -38,7 +38,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
// CHECK-SAME: subgroup = [0, 0, 4, 1, 0]
// CHECK-SAME: workgroup = [1, 1, 64, 64, 0]
// CHECK-SAME: workgroup = [1, 1, 4, 4, 0]

// -----

@@ -63,7 +63,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 2]
// CHECK-SAME: subgroup = [4, 4, 0]
// CHECK-SAME: workgroup = [128, 128, 0]
// CHECK-SAME: workgroup = [8, 8, 0]

// -----
