[Codegen][GPU] Add tiling cleanup pattern to fuse pad without zero guard (iree-org#18748)

This PR adds a way to fuse tensor.pad in ApplyGPUTilingLevel when we
know the pad will never receive an empty slice. This is useful when the
tensor.pad pads up to the tile size we are tiling with, so it can never
generate an empty slice.
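The mechanism relies on MLIR's linalg::ExtractSliceOfPadTensorSwapPattern, whose control callback decides per tensor.extract_slice whether the rewrite emits a guard against empty slices. Below is a minimal, self-contained sketch of how that cleanup hook is wired up, mirroring the diff further down; the helper name addPadFusionCleanup is made up for illustration, and the callback semantics (false = fuse without the guard, true = keep the guard, std::nullopt = leave the slice alone) are assumed from upstream MLIR.

#include <optional>

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

using namespace mlir;

// Sketch only: attach the pad/extract_slice swap pattern as a tiling cleanup,
// with the zero-slice guard disabled because the caller promises no empty
// slices can occur. `addPadFusionCleanup` is a hypothetical helper name.
static void addPadFusionCleanup(MLIRContext *context, bool allowZeroSlices,
                                scf::SCFTileAndFuseOptions &options) {
  RewritePatternSet cleanupPatterns(context);
  if (allowZeroSlices) {
    // Returning false fuses the pad without emitting an "is this slice
    // empty?" guard; returning std::nullopt would leave the slice untouched.
    auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional<bool> {
      return false;
    };
    cleanupPatterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
        context, zeroSliceGuard);
  }
  options.cleanupPatterns = FrozenRewritePatternSet(std::move(cleanupPatterns));
}

With allow-zero-slices=false (the default) the guard-free rewrite is never added, so pads stay outside the tiled loop; that is what the CHECK prefix in the new test verifies, while the NOZERO prefix checks the fused case.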
Groverkss authored Oct 15, 2024
1 parent 7622770 commit afe18d2
Showing 3 changed files with 76 additions and 6 deletions.
@@ -13,6 +13,7 @@
#include "llvm/ADT/STLForwardCompat.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -65,10 +66,9 @@ collectTiledAndFusedOps(Operation *op,

/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
static LogicalResult
applyTileAndFuseToEachRoot(RewriterBase &rewriter,
                           llvm::SmallDenseSet<TilingInterface> &payloadOps,
                           IREE::GPU::TilingLevel tilingLevel) {
static LogicalResult applyTileAndFuseToEachRoot(
    RewriterBase &rewriter, llvm::SmallDenseSet<TilingInterface> &payloadOps,
    IREE::GPU::TilingLevel tilingLevel, bool allowZeroSlices) {
  MLIRContext *context = rewriter.getContext();
  for (TilingInterface tilingInterfaceOp : payloadOps) {
    mlir::DominanceInfo dominanceInfo(tilingInterfaceOp);
@@ -137,7 +137,8 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
      Operation *owner = originalProducer.getOwner();
      if (tilingLevel == IREE::GPU::TilingLevel::Reduction ||
          tilingLevel == IREE::GPU::TilingLevel::Subgroup) {
        // Do not fuse pad in reduction and subgroup tiling.
        // Do not fuse pad in reduction and subgroup tiling. We instead fuse
        // pad without zero slice guard as a cleanup pattern.
        if (isa<tensor::PadOp>(owner)) {
          return std::nullopt;
        }
@@ -161,6 +162,22 @@
    };
    tileAndFuseOptions.setFusionControlFn(controlFn);

    RewritePatternSet cleanupPatterns(context);

    if (allowZeroSlices) {
      // Add pattern to fuse pad operations without zero slice guard, if we
      // know we have no zero slices.
      auto zeroSliceGuard = [](tensor::ExtractSliceOp) -> std::optional<bool> {
        // Do not use zero slice guard.
        return false;
      };
      cleanupPatterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
          context, zeroSliceGuard);
    }

    tileAndFuseOptions.cleanupPatterns =
        FrozenRewritePatternSet(std::move(cleanupPatterns));

    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                  tileAndFuseOptions);
@@ -221,7 +238,8 @@ void GPUApplyTilingLevelPass::runOnOperation() {
      getTiledOps(funcOp, tilingLevel);

  IRRewriter rewriter(funcOp);
  if (failed(applyTileAndFuseToEachRoot(rewriter, targetOps, tilingLevel))) {
  if (failed(applyTileAndFuseToEachRoot(rewriter, targetOps, tilingLevel,
                                        allowZeroSlices))) {
    funcOp.emitError() << "tiling of level "
                       << IREE::GPU::stringifyEnum(tilingLevel) << " failed\n";
    return signalPassFailure();
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -205,6 +205,9 @@ def GPUApplyTilingLevelPass :
        clEnumValN(IREE::GPU::TilingLevel::Subgroup, "subgroup",
                   "Tile and fuse all annotated ops to threads")
      )}]>,
    Option<"allowZeroSlices", "allow-zero-slices", "bool",
           /*default=*/"false",
           "Allow pad fusion to generate zero size slices">
  ];
}

@@ -1,4 +1,5 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level, canonicalize, cse))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{allow-zero-slices=true}, canonicalize, cse))" %s | FileCheck %s --check-prefix=NOZERO
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=thread}, canonicalize, cse))" %s | FileCheck %s --check-prefix=THREAD
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-apply-tiling-level{tiling-level=subgroup}, canonicalize, cse))" %s | FileCheck %s --check-prefix=SUBGROUP

@@ -474,3 +475,51 @@ module {
// SUBGROUP: scf.forall.in_parallel
// SUBGROUP: tensor.parallel_insert_slice %[[MMA]] into %[[INIT]]
// SUBGROUP: return

// -----

// This test checks when a tensor.pad gets fused during tiling. We disable
// tensor.pad fusion by default, because it generates a guard to prevent
// empty slices, which is hard to vectorize.
//
// However, if we already know no zero slices will be generated, we can fuse
// the pad directly.

#map = affine_map<()[s0] -> (s0 * -16 + 19, 16)>
#map1 = affine_map<()[s0] -> (-s0 + 16)>
module {
  func.func @fuse_pad_no_zero_slice(%arg0: tensor<?x17xf32>, %arg1: tensor<17x17xf32>, %arg2: index, %arg3: index) -> tensor<?x17xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = affine.min #map()[%arg2]
    %1 = tensor.empty() : tensor<16x32xf32>
    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<16x32xf32>) -> tensor<16x32xf32>
    %3 = affine.apply #map1()[%0]
    %padded = tensor.pad %arg0 low[0, 0] high[%3, 7] {
    ^bb0(%arg4: index, %arg5: index):
      tensor.yield %cst : f32
    } : tensor<?x17xf32> to tensor<16x24xf32>
    %padded_0 = tensor.pad %arg1 low[0, 0] high[7, 15] {
    ^bb0(%arg4: index, %arg5: index):
      tensor.yield %cst : f32
    } : tensor<17x17xf32> to tensor<24x32xf32>
    %4 = linalg.matmul {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 8]}>} ins(%padded, %padded_0 : tensor<16x24xf32>, tensor<24x32xf32>) outs(%2 : tensor<16x32xf32>) -> tensor<16x32xf32>
    %extracted_slice = tensor.extract_slice %4[0, 0] [%0, 17] [1, 1] : tensor<16x32xf32> to tensor<?x17xf32>
    return %extracted_slice : tensor<?x17xf32>
  }
}

// Only fuse the pad when allow-zero-slices is true.

// CHECK-LABEL: @fuse_pad_no_zero_slice
// CHECK: tensor.pad
// CHECK: tensor.pad
// CHECK: scf.for
// CHECK-NOT: tensor.pad
// CHECK: linalg.matmul

// NOZERO-LABEL: @fuse_pad_no_zero_slice
// NOZERO-NOT: tensor.pad
// NOZERO: scf.for
// NOZERO: tensor.pad
// NOZERO: tensor.pad
// NOZERO: linalg.matmul
