[CPU][ArmSME] Update tiling to use all SME accumulators (iree-org#16389)
Previously, we only tiled for a single SME accumulator. This patch
updates the lowering_config to make use of all SME accumulators.

This is done by increasing the tile size to [8]x[8] for f32 and to
[4]x[8] for f64. This lowers to four [4]x[4] 32-bit accumulators and
eight [2]x[2] 64-bit accumulators respectively.
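For illustration only (not part of the commit message): a minimal C++ sketch of the accumulator arithmetic, assuming the architectural SME ZA layout of four [4]x[4] 32-bit tiles and eight [2]x[2] 64-bit tiles, where [N] denotes N x vscale scalable elements.

// Sketch: check that the new tile sizes cover all available SME accumulator
// tiles. The vscale factor cancels out of the tile/accumulator ratio, so
// plain integers suffice here.
#include <cstdio>

int main() {
  struct Shape { const char *type; int tileM, tileN, accM, accN; };
  const Shape shapes[] = {
      {"f32", 8, 8, 4, 4}, // [8]x[8] tile over [4]x[4] accumulators
      {"f64", 4, 8, 2, 2}, // [4]x[8] tile over [2]x[2] accumulators
  };
  for (const Shape &s : shapes) {
    int numAccumulators = (s.tileM / s.accM) * (s.tileN / s.accN);
    std::printf("%s: [%d]x[%d] tile -> %d [%d]x[%d] accumulators\n",
                s.type, s.tileM, s.tileN, numAccumulators, s.accM, s.accN);
  }
  // Prints 4 accumulators for f32 and 8 for f64, matching the four 32-bit
  // and eight 64-bit SME accumulators described above.
  return 0;
}

For f64 the larger tile dimension is given to the B matrix because, as the code comment in the diff notes, B is known to be contiguous.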

Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>

ci-extra: build_test_all_arm64

---------

Signed-off-by: Benjamin Maxwell <benjamin.maxwell@arm.com>
MacDue authored May 17, 2024
1 parent 4132d2e commit a3b74bc
Showing 2 changed files with 14 additions and 5 deletions.
compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp: 13 additions & 4 deletions
@@ -1273,8 +1273,9 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics(
 }
 
 /// Utility to compute the tile sizes for AArch64 SME. Unlike other targets, the
-/// tile sizes picked here must exactly match the SME hardware virtual tiles, as
-/// there is currently no support for lowering non-standard shapes.
+/// tile sizes picked here must exactly match multiples of the SME hardware
+/// virtual tiles, as there is currently no support for lowering non-standard
+/// shapes.
 static void
 getMatmulAArch64SMEVectorSizes(linalg::LinalgOp op,
                                SmallVectorImpl<int64_t> &sizes,
@@ -1287,13 +1288,21 @@ getMatmulAArch64SMEVectorSizes(linalg::LinalgOp op,
   if (failed(elementType))
     return;
 
+  // TODO(macdue): Come up with some heuristics to pick the appropriate tiling
+  // for SME, i.e. optimal layout based on static sizes.
+
   if (elementType->isF32()) {
-    sizes.append({4, 4, 1});
+    // Tile for [8]x[8], this results in equal loads from both the A and B
+    // matrices and will use all four [4]x[4] 32-bit SME accumulators.
+    sizes.append({8, 8, 1});
     scalableSizeFlags.append({true, true, false});
   }
 
   if (elementType->isF64()) {
-    sizes.append({2, 2, 1});
+    // Tile for [4]x[8], this results in loading twice as much from matrix B
+    // as from matrix A, and will use all eight [2]x[2] 64-bit SME accumulators.
+    // The B dimension is larger as it is known to be contiguous.
+    sizes.append({4, 8, 1});
     scalableSizeFlags.append({true, true, false});
   }

Expand Down
Second changed file: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ module {
 // SSVE-WITHOUT-SME: linalg.matmul
 // SSVE-WITHOUT-SME-SAME: lowering_config = #[[CONFIG]]
 
-// WITH-SME-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], {{\[}}[4], [4], 0], [0, 0, 1], [0, 0, 0]]>
+// WITH-SME-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], {{\[}}[8], [8], 0], [0, 0, 1], [0, 0, 0]]>
 // WITH-SME-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 // WITH-SME: func.func @matmul_tensors()
 // WITH-SME-SAME: translation_info = #[[TRANSLATION]]