Skip to content

Commit

Permalink
simplify DataTiledMMAAttr::buildMmaOperation (iree-org#18597)
Browse files Browse the repository at this point in the history
No need to compute `crossIntrinsicCount` and size the result vector
accordingly when we can just have the `incrementIndices` helper tell
when we have exhausted the index space.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
  • Loading branch information
bjacob authored Sep 25, 2024
1 parent 672ae82 commit d583958
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,18 +981,18 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
}

/// Increment the mutable vector `indices` to traverse the index space below
/// `sizes`, with the last dimension moving fastest.
static void incrementIndices(MutableArrayRef<int64_t> indices,
/// `sizes`, with the last dimension moving fastest, or returns false if that
/// index space was exhausted.
static bool incrementIndices(MutableArrayRef<int64_t> indices,
ArrayRef<int64_t> sizes) {
int rank = indices.size();
for (int i = rank - 1; i >= 0; --i) {
++indices[i];
if (indices[i] == sizes[i]) {
for (int i = indices.size() - 1; i >= 0; --i) {
if (++indices[i] == sizes[i]) {
indices[i] = 0;
} else {
break;
return true; // Found an index that we could increment without wrapping.
}
}
return false; // All indices wrapped around.
}

/// Flattens the input vector `value` to 1-D.
Expand Down Expand Up @@ -1023,17 +1023,13 @@ distributeMmaFragmentToIntrinsics(OpBuilder &builder, Location loc, Value value,
});
int rank = internalShape.size();
auto strides = SmallVector<int64_t>(rank, 1);
int64_t crossIntrinsicCount =
std::reduce(crossIntrinsicShape.begin(), crossIntrinsicShape.end(), 1,
std::multiplies<int64_t>());
SmallVector<Value> distributedValues(crossIntrinsicCount);
SmallVector<Value> distributedValues;
SmallVector<int64_t> indices(rank, 0);
for (Value &distributedValue : distributedValues) {
do {
Value extract = builder.create<vector::ExtractStridedSliceOp>(
loc, value, indices, internalShape, strides);
distributedValue = flattenVector(builder, loc, extract);
incrementIndices(indices, crossIntrinsicShape);
}
distributedValues.push_back(flattenVector(builder, loc, extract));
} while (incrementIndices(indices, crossIntrinsicShape));
return distributedValues;
}

Expand All @@ -1057,11 +1053,11 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
}

// Prepare Lhs/Rhs/Acc operand slices to feed the intrinsic.
auto intrinsicsLhs = distributeMmaFragmentToIntrinsics(
SmallVector<Value> intrinsicsLhs = distributeMmaFragmentToIntrinsics(
builder, loc, lhs, getSwizzle(*this, MMAFragment::Lhs));
auto intrinsicsRhs = distributeMmaFragmentToIntrinsics(
SmallVector<Value> intrinsicsRhs = distributeMmaFragmentToIntrinsics(
builder, loc, rhs, getSwizzle(*this, MMAFragment::Rhs));
auto intrinsicsAcc = distributeMmaFragmentToIntrinsics(
SmallVector<Value> intrinsicsAcc = distributeMmaFragmentToIntrinsics(
builder, loc, acc, getSwizzle(*this, MMAFragment::Acc));

// Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
Expand Down Expand Up @@ -1092,12 +1088,10 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
});
SmallVector<int64_t> strides(intrinsicCType.getRank(), 1);
SmallVector<int64_t> indices(accCrossIntrinsicShape.size(), 0);
for (int mu = 0; mu < getUnrollM(); ++mu) {
for (int nu = 0; nu < getUnrollN(); ++nu) {
acc = builder.create<vector::InsertStridedSliceOp>(
loc, intrinsicsAcc[mu * getUnrollN() + nu], acc, indices, strides);
incrementIndices(indices, accCrossIntrinsicShape);
}
for (Value intrAcc : intrinsicsAcc) {
acc = builder.create<vector::InsertStridedSliceOp>(loc, intrAcc, acc,
indices, strides);
incrementIndices(indices, accCrossIntrinsicShape);
}
return acc;
}
Expand Down

0 comments on commit d583958

Please sign in to comment.