Skip to content

Commit

Permalink
[LLVMGPU][NFC] Create LLVMGPU pass for IGEMM (iree-org#18871)
Browse files Browse the repository at this point in the history
This PR refactors the ConvolutionToIGEMM pass to a shared transform
function, and creates a new pass for LLVMGPU. This keeps the lowering
config details in LLVMGPU separate from the common pass, and removes the
need for passing a control function or config function in the pass
constructor. This is also a precursor to adding some more complex logic
in the control function for LLVMGPU, which will be added in a later PR.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 authored Oct 25, 2024
1 parent c6b3592 commit 55c5562
Show file tree
Hide file tree
Showing 15 changed files with 230 additions and 136 deletions.
162 changes: 82 additions & 80 deletions compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
Expand All @@ -26,10 +28,14 @@ namespace {

using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;

/// Pattern to set a lowering configuration on an IGEMM convolution. Searches
/// for a contraction with a linalg_ext.im2col producer, and calls the configFn
/// to set the configuration.
/// TODO(Max191): Use a funcOp walk instead of a pattern for this.
struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern::OpRewritePattern;

SetIGEMMConfiguration(MLIRContext *context, ConfigFn configFn)
SetIGEMMConfiguration(MLIRContext *context, IGEMMConfigFn configFn)
: OpRewritePattern(context), configFn(configFn) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
Expand Down Expand Up @@ -67,99 +73,95 @@ struct SetIGEMMConfiguration final : OpRewritePattern<linalg::GenericOp> {
}

private:
ConfigFn configFn;
IGEMMConfigFn configFn;
};

class ConvolutionToIGEMMPass final
: public impl::ConvolutionToIGEMMPassBase<ConvolutionToIGEMMPass> {
public:
using ConvolutionToIGEMMPassBase::ConvolutionToIGEMMPassBase;

explicit ConvolutionToIGEMMPass(ConfigFn configFn) : configFn(configFn) {}
ConvolutionToIGEMMPass(std::optional<IGEMMConfigFn> configFn,
std::optional<IGEMMControlFn> controlFn)
: configFn(configFn), controlFn(controlFn) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();

// Rewrite convolutions into a im2col and GEMM.
{
auto conv2dToIm2colControlFn = [](Operation *conv) {
// Don't transform convolutions that have a preset lowering config.
if (getLoweringConfig(conv)) {
return false;
}
return true;
};
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
patterns, conv2dToIm2colControlFn);
patterns.add<SetIGEMMConfiguration>(context, configFn);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}

// The im2col transformation collapses some of the dimensions of the
// convolution operands. Try to push the reshape ops towards the boundaries
// of the function and fold with interface tensor ops.
//
// TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
// generate a multi-M dim contraction instead of collapsing and
// propagating reshapes. It should ultimately become a pass option to
// decide whether to collapse the contraction dimensions into a single
// M/N/K dimension.
{
RewritePatternSet bubbleCollapseShapePatterns(context);
linalg::ControlFusionFn bubbleUpExpansionControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

// Block only if one of the operations has a lowering configuration
// which means it likely expects tiling specific to its original
// shape.
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
return false;
}
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(
bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(bubbleCollapseShapePatterns)))) {
return signalPassFailure();
}
}
}
void runOnOperation() override;

private:
ConfigFn configFn = [](linalg::GenericOp genericOp,
IREE::LinalgExt::Im2colOp im2colOp) {
return failure();
};
std::optional<IGEMMConfigFn> configFn;
std::optional<IGEMMControlFn> controlFn;
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvolutionToIGEMMPass(ConfigFn configFn) {
return std::make_unique<ConvolutionToIGEMMPass>(configFn);
LogicalResult
convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
std::optional<IGEMMConfigFn> configFn,
std::optional<IGEMMControlFn> controlFn) {
// Rewrite convolutions into a im2col and GEMM.
MLIRContext *context = funcOp->getContext();
{
RewritePatternSet patterns(context);
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(patterns,
controlFn);
if (configFn.has_value()) {
patterns.add<SetIGEMMConfiguration>(context, configFn.value());
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return failure();
}
}

// The im2col transformation collapses some of the dimensions of the
// convolution operands. Try to push the reshape ops towards the boundaries
// of the function and fold with interface tensor ops.
//
// TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
// generate a multi-M dim contraction instead of collapsing and
// propagating reshapes. It should ultimately become a pass option to
// decide whether to collapse the contraction dimensions into a single
// M/N/K dimension.
{
RewritePatternSet bubbleCollapseShapePatterns(context);
linalg::ControlFusionFn bubbleUpExpansionControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

// Block only if one of the operations has a lowering configuration
// which means it likely expects tiling specific to its original
// shape.
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
return false;
}
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(
bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(bubbleCollapseShapePatterns)))) {
return failure();
}
}
return success();
}

void ConvolutionToIGEMMPass::runOnOperation() {
if (failed(convertToIGEMMAndSetConfig(getOperation()))) {
return signalPassFailure();
}
}

} // namespace mlir::iree_compiler
7 changes: 0 additions & 7 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,6 @@ std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvertToDestinationPassingStylePass(
bool useWARForCooperativeMatrixCodegen);

using ConfigFn =
std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
/// Pass to convert Conv2D ops into IGEMM (Im2colOp + matmul). `configFn` is
/// used to set lowering configurations on the resulting ops, if necessary.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass(ConfigFn configFn);

std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);

/// Pass to perform linalg on tensor bufferization. The function passed into
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def ConvolutionToIGEMMPass :
InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
let summary =
"Transforms convolution operations into an implicit GEMM format.";
let dependentDialects = [
"tensor::TensorDialect",
"iree_compiler::IREE::LinalgExt::IREELinalgExtDialect"
];
}

def DecomposeAffineOpsPass: Pass<"iree-codegen-decompose-affine-ops"> {
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ struct OneShotBufferizationOptions;

namespace mlir::iree_compiler {

using IGEMMConfigFn =
std::function<LogicalResult(linalg::GenericOp, IREE::LinalgExt::Im2colOp)>;
using IGEMMControlFn = std::function<bool(Operation *)>;

/// Converts conv_2d ops into linalg_ext.im2col + matmul, and sets a lowering
/// configuration on the matmul.
LogicalResult convertToIGEMMAndSetConfig(
FunctionOpInterface funcOp,
std::optional<IGEMMConfigFn> configFn = std::nullopt,
std::optional<IGEMMControlFn> controlFn = std::nullopt);

/// Eliminates tensor.empty ops to avoid buffer allocations.
LogicalResult eliminateEmptyTensors(
RewriterBase &rewriter, Operation *op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,6 @@ module {

// -----

#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<1x14x14x16xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
// CHECK: func.func public @conv_with_lowering_config
// CHECK-NOT: iree_linalg_ext.im2col
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: lowering_config

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ iree_compiler_cc_library(
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConfigureVectorLayouts.cpp",
"LLVMGPUConvolutionToIGEMM.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPrefetching.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ iree_cc_library(
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConfigureVectorLayouts.cpp"
"LLVMGPUConvolutionToIGEMM.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPrefetching.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#define DEBUG_TYPE "iree-llvmgpu-convolution-to-igemm"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_LLVMGPUCONVOLUTIONTOIGEMMPASS
#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"

namespace {

/// Function for setting lowering configurations on contractions resulting from
/// the IGEMM transformation. This currently uses the TileAndFuse pipeline, and
/// tries to target MMA intrinsics.
static LogicalResult llvmgpuConfigFn(linalg::GenericOp genericOp,
IREE::LinalgExt::Im2colOp im2colOp) {
auto funcOp = genericOp->getParentOfType<FunctionOpInterface>();
if (!funcOp) {
return genericOp.emitError("cannot find parent funcOp");
}
IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
if (!target) {
return funcOp.emitError("missing GPU target in parent funcOp");
}
if (failed(IREE::GPU::setMatmulLoweringConfig(target, funcOp, genericOp))) {
return IREE::GPU::setTileAndFuseLoweringConfig(target, funcOp, genericOp);
}
return success();
}

static bool llvmgpuControlFn(Operation *op) {
// Do not convert anything that already has a lowering configuration.
if (getLoweringConfig(op)) {
return false;
}
return true;
}

struct LLVMGPUConvolutionToIGEMMPass final
: impl::LLVMGPUConvolutionToIGEMMPassBase<LLVMGPUConvolutionToIGEMMPass> {
using impl::LLVMGPUConvolutionToIGEMMPassBase<
LLVMGPUConvolutionToIGEMMPass>::LLVMGPUConvolutionToIGEMMPassBase;

void runOnOperation() override;
};

void LLVMGPUConvolutionToIGEMMPass::runOnOperation() {
if (failed(convertToIGEMMAndSetConfig(getOperation(), llvmgpuConfigFn,
llvmgpuControlFn))) {
return signalPassFailure();
}
}

} // namespace
} // namespace mlir::iree_compiler
Loading

0 comments on commit 55c5562

Please sign in to comment.