Skip to content

Commit

Permalink
Lower data-tiled multi_mma to intrinsics. (iree-org#18547)
Browse files Browse the repository at this point in the history
This lowers `multi_mma` with `DataTiledMMAAttr` down to `amdgpu`
intrinsics.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
  • Loading branch information
bjacob authored Sep 25, 2024
1 parent 129ad45 commit 3773a48
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 25 deletions.
10 changes: 8 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ iree_compiler_cc_library(
"TileDispatchUsingForall.cpp",
"TileDispatchUsingInterface.cpp",
"TileSizeSelection.cpp",
"TileSwizzle.cpp",
"TypePropagationPass.cpp",
"UserConfig.cpp",
"VectorizeMemrefCopy.cpp",
Expand All @@ -155,14 +154,14 @@ iree_compiler_cc_library(
"PassUtils.h",
"Passes.h",
"TileSizeSelection.h",
"TileSwizzle.h",
"Transforms.h",
"UserConfig.h",
],
deps = [
":PassHeaders",
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen",
"//compiler/src/iree/compiler/Codegen/Common:TileSwizzle",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"//compiler/src/iree/compiler/Codegen/Interfaces:BufferizationInterfaces",
Expand Down Expand Up @@ -229,6 +228,13 @@ iree_compiler_cc_library(
],
)

iree_compiler_cc_library(
name = "TileSwizzle",
srcs = ["TileSwizzle.cpp"],
hdrs = ["TileSwizzle.h"],
deps = ["@llvm-project//llvm:Support"],
)

# TODO: If the layering causes concerns then the transform dialect interpreter
# should be one level above everything: it is a mechanism by which
# transformations are applied to any IR and needs to register all the dialects
Expand Down
15 changes: 13 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ iree_cc_library(
"PassUtils.h"
"Passes.h"
"TileSizeSelection.h"
"TileSwizzle.h"
"Transforms.h"
"UserConfig.h"
SRCS
Expand Down Expand Up @@ -135,7 +134,6 @@ iree_cc_library(
"TileDispatchUsingForall.cpp"
"TileDispatchUsingInterface.cpp"
"TileSizeSelection.cpp"
"TileSwizzle.cpp"
"TypePropagationPass.cpp"
"UserConfig.cpp"
"VectorizeMemrefCopy.cpp"
Expand Down Expand Up @@ -190,6 +188,7 @@ iree_cc_library(
MLIRVectorTransforms
MLIRViewLikeInterface
iree::compiler::Codegen::Common::FoldTensorExtractOpIncGen
iree::compiler::Codegen::Common::TileSwizzle
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
iree::compiler::Codegen::Interfaces::BufferizationInterfaces
Expand All @@ -207,6 +206,18 @@ iree_cc_library(
PUBLIC
)

iree_cc_library(
NAME
TileSwizzle
HDRS
"TileSwizzle.h"
SRCS
"TileSwizzle.cpp"
DEPS
LLVMSupport
PUBLIC
)

iree_cc_library(
NAME
TransformDialectInterpreterPass
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ iree_compiler_cc_library(
"GPUTensorTileToSerialLoops.cpp",
"GPUTile.cpp",
"GPUTileReduction.cpp",
"GPUTileSwizzleUtils.cpp",
"GPUVectorAlloc.cpp",
"GPUVectorDistribution.cpp",
"GPUVerifyDistribution.cpp",
Expand All @@ -86,7 +85,6 @@ iree_compiler_cc_library(
],
hdrs = [
"GPUPatterns.h",
"GPUTileSwizzleUtils.h",
"GPUVectorDistribution.h",
"Passes.h",
],
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ iree_cc_library(
CommonGPUPasses
HDRS
"GPUPatterns.h"
"GPUTileSwizzleUtils.h"
"GPUVectorDistribution.h"
"Passes.h"
SRCS
Expand Down Expand Up @@ -75,7 +74,6 @@ iree_cc_library(
"GPUTensorTileToSerialLoops.cpp"
"GPUTile.cpp"
"GPUTileReduction.cpp"
"GPUTileSwizzleUtils.cpp"
"GPUVectorAlloc.cpp"
"GPUVectorDistribution.cpp"
"GPUVerifyDistribution.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
Expand Down Expand Up @@ -55,16 +55,16 @@ chooseDataTiledMMAAttr(TypeRange elementTypes, IREE::GPU::TargetAttr target) {
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];
auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollMToThreads,
int unrollN, int unrollNToThreads,
auto match = [=](MMAIntrinsic intrinsic, int unrollM, int unrollMToSubgroups,
int unrollN, int unrollNToSubgroups,
int unrollK) -> std::optional<DataTiledMMAAttr> {
if (!hasIntrinsic(target, intrinsic)) {
return std::nullopt;
}
auto candidate = DataTiledMMAAttr::get(
ctx, MMAIntrinsicAttr::get(ctx, intrinsic), /*unroll_m=*/unrollM,
/*unroll_m_to_subgroups=*/unrollMToThreads, /*unroll_n=*/unrollN,
/*unroll_n_to_subgroups=*/unrollNToThreads, /*unroll_k=*/unrollK);
/*unroll_m_to_subgroups=*/unrollMToSubgroups, /*unroll_n=*/unrollN,
/*unroll_n_to_subgroups=*/unrollNToSubgroups, /*unroll_k=*/unrollK);
auto [candidateLhs, candidateRhs, candidateOut] =
candidate.getABCElementTypes();
if (candidateLhs != lhs || candidateRhs != rhs || candidateOut != out) {
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/TileSwizzle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
return os << "CrossThread";
case TileSwizzle::Dim::Kind::CrossIntrinsic:
return os << "CrossIntrinsic";
default:
// Required by GCC.
assert(false);
return os;
}
}

Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ iree_compiler_cc_library(
name = "IREEGPUDialect",
srcs = [
"DerivedConfigUtils.cpp",
"GPUTileSwizzleUtils.cpp",
"IREEGPUAttrs.cpp",
"IREEGPUDialect.cpp",
"IREEGPUInterfaces.cpp",
Expand All @@ -56,6 +57,7 @@ iree_compiler_cc_library(
],
hdrs = [
"DerivedConfigUtils.h",
"GPUTileSwizzleUtils.h",
"IREEGPUAttrs.h",
"IREEGPUDialect.h",
"IREEGPUEnums.h",
Expand All @@ -80,6 +82,7 @@ iree_compiler_cc_library(
":IREEGPUEnums",
":IREEGPUInterfaces",
":IREEGPUOpsGen",
"//compiler/src/iree/compiler/Codegen/Common:TileSwizzle",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_cc_library(
IREEGPUDialect
HDRS
"DerivedConfigUtils.h"
"GPUTileSwizzleUtils.h"
"IREEGPUAttrs.h"
"IREEGPUDialect.h"
"IREEGPUEnums.h"
Expand All @@ -33,6 +34,7 @@ iree_cc_library(
"IREEGPUOps.h.inc"
SRCS
"DerivedConfigUtils.cpp"
"GPUTileSwizzleUtils.cpp"
"IREEGPUAttrs.cpp"
"IREEGPUDialect.cpp"
"IREEGPUInterfaces.cpp"
Expand All @@ -59,6 +61,7 @@ iree_cc_library(
MLIRTilingInterface
MLIRVectorDialect
MLIRVectorInterfaces
iree::compiler::Codegen::Common::TileSwizzle
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
iree::compiler::Codegen::Utils::VectorOpUtils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUTileSwizzleUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"

namespace mlir::iree_compiler {

Expand Down Expand Up @@ -106,8 +106,8 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
}
// The layout strides decide the initial swizzle.permutation.
// Some WMMA intrinsics have tstrides=0 values, assert on that as that
// would defeat this algorithm. We'll need to solve that if and when we want
// to support data tiling on WMMA intrinsics.
// would defeat this algorithm.
// TODO(bjacob): Resolve that to support WMMA intrinsics.
for (auto s : layout.tstrides) {
(void)s;
assert(s != 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_GPU_GPUTILESWIZZLEUTILS_H_
#define IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_GPU_GPUTILESWIZZLEUTILS_H_
#ifndef IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_
#define IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_

#include "iree/compiler/Codegen/Common/TileSwizzle.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
Expand Down Expand Up @@ -55,4 +55,4 @@ void interleave(TileSwizzle &swizzle, int srcIndex,

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_GPU_GPUTILESWIZZLEUTILS_H_
#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_
Loading

0 comments on commit 3773a48

Please sign in to comment.