From d5839584547ec2a258b12399f56ea3d006668dbb Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Wed, 25 Sep 2024 15:30:01 -0400 Subject: [PATCH] simplify `DataTiledMMAAttr::buildMmaOperation` (#18597) No need to compute `crossIntrinsicCount` and size the result vector accordingly when we can just have the `incrementIndices` helper tell when we have exhausted the index space. --------- Signed-off-by: Benoit Jacob --- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index cd70060b9a6c..d6d9e0e9d0fc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -981,18 +981,18 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( } /// Increment the mutable vector `indices` to traverse the index space below -/// `sizes`, with the last dimension moving fastest. -static void incrementIndices(MutableArrayRef indices, +/// `sizes`, with the last dimension moving fastest, or returns false if that +/// index space was exhausted. +static bool incrementIndices(MutableArrayRef indices, ArrayRef sizes) { - int rank = indices.size(); - for (int i = rank - 1; i >= 0; --i) { - ++indices[i]; - if (indices[i] == sizes[i]) { + for (int i = indices.size() - 1; i >= 0; --i) { + if (++indices[i] == sizes[i]) { indices[i] = 0; } else { - break; + return true; // Found an index that we could increment without wrapping. } } + return false; // All indices wrapped around. } /// Flattens the input vector `value` to 1-D. @@ -1023,17 +1023,13 @@ distributeMmaFragmentToIntrinsics(OpBuilder &builder, Location loc, Value value, }); int rank = internalShape.size(); auto strides = SmallVector(rank, 1); - int64_t crossIntrinsicCount = - std::reduce(crossIntrinsicShape.begin(), crossIntrinsicShape.end(), 1, - std::multiplies()); - SmallVector distributedValues(crossIntrinsicCount); + SmallVector distributedValues; SmallVector indices(rank, 0); - for (Value &distributedValue : distributedValues) { + do { Value extract = builder.create( loc, value, indices, internalShape, strides); - distributedValue = flattenVector(builder, loc, extract); - incrementIndices(indices, crossIntrinsicShape); - } + distributedValues.push_back(flattenVector(builder, loc, extract)); + } while (incrementIndices(indices, crossIntrinsicShape)); return distributedValues; } @@ -1057,11 +1053,11 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, } // Prepare Lhs/Rhs/Acc operand slices to feed the intrinsic. - auto intrinsicsLhs = distributeMmaFragmentToIntrinsics( + SmallVector intrinsicsLhs = distributeMmaFragmentToIntrinsics( builder, loc, lhs, getSwizzle(*this, MMAFragment::Lhs)); - auto intrinsicsRhs = distributeMmaFragmentToIntrinsics( + SmallVector intrinsicsRhs = distributeMmaFragmentToIntrinsics( builder, loc, rhs, getSwizzle(*this, MMAFragment::Rhs)); - auto intrinsicsAcc = distributeMmaFragmentToIntrinsics( + SmallVector intrinsicsAcc = distributeMmaFragmentToIntrinsics( builder, loc, acc, getSwizzle(*this, MMAFragment::Acc)); // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation @@ -1092,12 +1088,10 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, }); SmallVector strides(intrinsicCType.getRank(), 1); SmallVector indices(accCrossIntrinsicShape.size(), 0); - for (int mu = 0; mu < getUnrollM(); ++mu) { - for (int nu = 0; nu < getUnrollN(); ++nu) { - acc = builder.create( - loc, intrinsicsAcc[mu * getUnrollN() + nu], acc, indices, strides); - incrementIndices(indices, accCrossIntrinsicShape); - } + for (Value intrAcc : intrinsicsAcc) { + acc = builder.create(loc, intrAcc, acc, + indices, strides); + incrementIndices(indices, accCrossIntrinsicShape); } return acc; }