diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index fe6e4fb94fd9..c25434d19b29 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -1273,8 +1273,9 @@ static void getMatmulVectorSizesUsingFullVectorHeuristics(
 }
 
 /// Utility to compute the tile sizes for AArch64 SME. Unlike other targets, the
-/// tile sizes picked here must exactly match the SME hardware virtual tiles, as
-/// there is currently no support for lowering non-standard shapes.
+/// tile sizes picked here must exactly match multiples of the SME hardware
+/// virtual tiles, as there is currently no support for lowering non-standard
+/// shapes.
 static void
 getMatmulAArch64SMEVectorSizes(linalg::LinalgOp op,
                                SmallVectorImpl<int64_t> &sizes,
@@ -1287,13 +1288,21 @@ getMatmulAArch64SMEVectorSizes(linalg::LinalgOp op,
   if (failed(elementType))
     return;
 
+  // TODO(macdue): Come up with some heuristics to pick the appropriate tiling
+  // for SME, i.e. optimal layout based on static sizes.
+
   if (elementType->isF32()) {
-    sizes.append({4, 4, 1});
+    // Tile for [8]x[8]; this results in equal loads from both the A and B
+    // matrices and will use all four [4]x[4] 32-bit SME accumulators.
+    sizes.append({8, 8, 1});
     scalableSizeFlags.append({true, true, false});
   }
 
   if (elementType->isF64()) {
-    sizes.append({2, 2, 1});
+    // Tile for [4]x[8]; this results in loading twice as much from matrix B
+    // as from matrix A, and will use all eight [2]x[2] 64-bit SME accumulators.
+    // The B dimension is larger as it is known to be contiguous.
+    sizes.append({4, 8, 1});
     scalableSizeFlags.append({true, true, false});
   }
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
index 688ce7145f90..e32f260bd5fd 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_sve_lowering_strategy.mlir
@@ -108,7 +108,7 @@ module {
 // SSVE-WITHOUT-SME: linalg.matmul
 // SSVE-WITHOUT-SME-SAME: lowering_config = #[[CONFIG]]
 
-// WITH-SME-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [64, 64, 0], [0, 0, 0], {{\[}}[4], [4], 0], [0, 0, 1], [0, 0, 0]]>
+// WITH-SME-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[64, 64, 0], [64, 64, 0], [0, 0, 0], {{\[}}[8], [8], 0], [0, 0, 1], [0, 0, 0]]>
 // WITH-SME-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<CPUDoubleTilingExpert>
 // WITH-SME: func.func @matmul_tensors()
 // WITH-SME-SAME: translation_info = #[[TRANSLATION]]