Simplifications around narrow dimensions in encodings. (iree-org#18607)
* Drop the `kNarrowThreshold` constant, relying instead on the default
padding value.
* When reading an `encoding` attribute to tell whether a `round_dims_to`
entry should be considered narrow, rely on the fact that at most one
dimension of a given matmul ever needs to be treated as narrow: the
smallest `round_dims_to` entry is the narrow one, and if all
`round_dims_to` entries are equal, the matmul is not narrow.
* Introduce a `MatmulNarrowDim` struct to unify helpers and group them
in `EncodingOps.{h,cpp}`.
* This enforces in the type system that at most one of the M or N
dimensions may be narrow, not both. Previously, we had different
structs/tuples, none of which enforced that, so we felt compelled to
write comments about the unenforced contract, and the code concerned was
scattered across different files.
* Remove the `getMatmulNarrow{M,N}` getters on `EncodingAttr`.
* Generally, we are over-relying on TableGen class methods, which only
obfuscates things compared to functions declared manually in C++ files.
The new `MatmulNarrowDim` struct allows replacing both of these methods
with a single `getMatmulNarrowDim` helper, which also simplifies callers
(see the sketch below).
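
For illustration, a minimal standalone sketch of the `round_dims_to`-based rule
described above. The names below are simplified stand-ins for the actual
`MatmulNarrowDim` / `getMatmulNarrowDim` helpers in `EncodingOps.{h,cpp}`, and
the snippet assumes a plain `{M, N, K}` vector rather than the real attribute:

```cpp
#include <cstdint>
#include <vector>

// Simplified stand-in for IREE::Encoding::MatmulNarrowDim (illustrative only).
struct NarrowDim {
  enum class Dim { None, M, N };
  Dim dim = Dim::None; // Which dimension, if any, is narrow.
  int64_t size = 0;    // Rounded size of that dimension.
  explicit operator bool() const { return dim != Dim::None; }
};

// The strictly smaller of the M/N entries of round_dims_to is the narrow one;
// if the two entries are equal, the matmul is not narrow at all.
NarrowDim inferNarrowDim(const std::vector<int64_t> &roundDimsTo) {
  if (roundDimsTo.size() < 2)
    return {};
  int64_t m = roundDimsTo[0];
  int64_t n = roundDimsTo[1];
  if (m < n)
    return {NarrowDim::Dim::M, m};
  if (n < m)
    return {NarrowDim::Dim::N, n};
  return {};
}
```

For example, `round_dims_to = {16, 128, 128}` would yield a narrow M of size
16, while `{128, 128, 128}` would yield no narrow dimension.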

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob authored Sep 28, 2024
1 parent 34641dd commit 9e09115
Showing 6 changed files with 95 additions and 92 deletions.
@@ -47,8 +47,7 @@ enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims,
// codegen.query_tile_sizes op, so we disable dynamic tile shapes for
// batch_matmul. Also, they are not set up for narrow M/N matmul, so it is
// disabled when it is the case.
if (!cDims.batch.empty() || encoding.getMatmulNarrowM() ||
encoding.getMatmulNarrowN()) {
if (!cDims.batch.empty() || getMatmulNarrowDim(encoding)) {
hasUkernelSupport = false;
}
if (hasUkernelSupport) {
@@ -294,19 +293,20 @@ enumerateMatmulTileX86_64(TypeRange elementTypes,
/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
/// information to host. For now, they are defined by host.
static TileMxNxK
chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles, int64_t matmulNarrowM,
int64_t matmulNarrowN,
chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
IREE::Encoding::MatmulNarrowDim narrowDim,
ArrayRef<int64_t> hostDefinedUpperBound = {}) {
assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
"expected hostDefinedUpperBound is empty or has upper bound for {M, "
"N, K}");
// Handle narrow-N by transposing to reduce to narrow-M. Note: the
// enumeratedTiles currently only enumerate narrow-M cases.
if (matmulNarrowN && (!matmulNarrowM || matmulNarrowN < matmulNarrowM)) {
if (narrowDim.isN()) {
SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
TileMxNxK tile = chooseMatmulTile(enumeratedTiles, matmulNarrowN, 0,
newHostDefinedUpperBound);
narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
TileMxNxK tile =
chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
std::swap(tile.M, tile.N);
return tile;
}
@@ -367,9 +367,9 @@ chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles, int64_t matmulNarrowM,
// are OK with the tile that has M==8 even though it requires some padding.
// Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
// end up selecting the vecmat tile (M==1) for that case!
if (matmulNarrowM) {
if (narrowDim) {
ratedTile.paddingPenalty =
std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(matmulNarrowM), 0);
std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
}
ratedTile.productMxNxK = tile.M * tile.N * tile.K;
ratedTiles.push_back(ratedTile);
@@ -438,13 +438,11 @@ materializeEncodingForTarget(RankedTensorType tensorType,
if (enumeratedTileMxNxK.empty()) {
return failure();
}
int64_t matmulNarrowM = encoding.getMatmulNarrowM();
int64_t matmulNarrowN = encoding.getMatmulNarrowN();
auto narrowDim = IREE::Encoding::getMatmulNarrowDim(encoding);
// Choose a final matmul TileMxNxK from the above-enumarated tile shapes,
// taking narrow dimensions into account.
TileMxNxK chosenTileMxNxK =
chooseMatmulTile(enumeratedTileMxNxK, matmulNarrowM, matmulNarrowN,
encoding.getRoundDimsToArray());
TileMxNxK chosenTileMxNxK = chooseMatmulTile(enumeratedTileMxNxK, narrowDim,
encoding.getRoundDimsToArray());

// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
// based on its operand index in the matmul.
4 changes: 1 addition & 3 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp
@@ -199,9 +199,7 @@ bool isNarrowNResult(EncodingAttr encoding) {
return false;
}

int64_t narrowM = encoding.getMatmulNarrowM();
int64_t narrowN = encoding.getMatmulNarrowN();
return narrowN && (!narrowM || narrowM > narrowN);
return IREE::Encoding::getMatmulNarrowDim(encoding).isN();
}

SmallVector<int64_t>
@@ -110,14 +110,6 @@ def EncodingAttr :

/// Clones an encoding with a new bcast_map
EncodingAttr clone(AffineMap bcastMap);

/// Returns the M size from `round_dims_to` if the value is less than
/// kNarrowThreshold. Otherwise, returns zero.
int64_t getMatmulNarrowM();

/// Returns the N size from `round_dims_to` if the value is less than
/// kNarrowThreshold. Otherwise, returns zero.
int64_t getMatmulNarrowN();
}];

let genVerifyDecl = 0;
54 changes: 39 additions & 15 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.cpp
@@ -137,6 +137,33 @@ std::optional<unsigned> EncodingAttr::mapDimToOperandIndex(int64_t dimPos) {
getAffineDimExpr(dimPos, getContext()));
}

MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp,
int narrowThreshold) {
linalg::ContractionDimensions cDims =
linalg::inferContractionDims(linalgOp).value();
auto map = linalgOp.getIndexingMapsArray().back();
auto outType = llvm::cast<ShapedType>(linalgOp.getDpsInits()[0].getType());
auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t {
return outType.getDimSize(
map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext()))
.value());
};
// M or N can be empty instead of having an explicit dim size of 1 for matvec
// and vecmat, so set to 1 if empty.
int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]);
int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]);

MatmulNarrowDim narrowM, narrowN;
if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) {
narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize};
}
if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) {
narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize};
}

return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN;
}

ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() {
auto roundDimsTo = getRoundDimsTo();
if (!roundDimsTo) {
@@ -151,26 +178,23 @@ EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
AffineMapAttr::get(bcastMap), getRoundDimsTo());
}

int64_t EncodingAttr::getMatmulNarrowM() {
if (getOpType().getValue() != EncodingOpType::matmul) {
return 0;
MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
if (encoding.getOpType().getValue() != EncodingOpType::matmul) {
return {};
}
ArrayRef<int64_t> roundDimsTo = getRoundDimsToArray();
ArrayRef<int64_t> roundDimsTo = encoding.getRoundDimsToArray();
if (roundDimsTo.empty()) {
return 0;
return {};
}
return roundDimsTo[0] < kNarrowThreshold ? roundDimsTo[0] : 0;
}

int64_t EncodingAttr::getMatmulNarrowN() {
if (getOpType().getValue() != EncodingOpType::matmul) {
return 0;
int m = roundDimsTo[0];
int n = roundDimsTo[1];
if (m < n) {
return {MatmulNarrowDim::Dim::M, m};
}
ArrayRef<int64_t> roundDimsTo = getRoundDimsToArray();
if (roundDimsTo.empty()) {
return 0;
if (n < m) {
return {MatmulNarrowDim::Dim::N, n};
}
return roundDimsTo[1] < kNarrowThreshold ? roundDimsTo[1] : 0;
return {};
}

//===---------------------------------------------------------------------===//
42 changes: 38 additions & 4 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingOps.h
@@ -35,9 +35,6 @@

namespace mlir::iree_compiler::IREE::Encoding {

/// Threadshold that determines if a dimension is considered "narrow" or not.
constexpr int64_t kNarrowThreshold = 32;

/// Returns the encoding attribute from the type if there is an encoding.
/// Otherwise, returns null.
EncodingAttr getEncodingAttr(RankedTensorType type);
@@ -46,13 +43,50 @@ EncodingAttr getEncodingAttr(RankedTensorType type);
FailureOr<linalg::ContractionDimensions>
getEncodingContractionDims(EncodingAttr encoding);

// Assign a name to operand indices for clarity
/// Assign a name to operand indices for clarity
const int64_t MATMUL_LHS = 0;
const int64_t MATMUL_RHS = 1;
const int64_t MATMUL_RESULT = 2;

/// Convert operand index to strings for printing
std::string stringifyOperandIndex(IntegerAttr);

/// Designates a dimension in a matmul (either the M or the N dimension) as
/// being "narrow", i.e. small enough that we bother lowering the amount of
/// padding along that dimension compared to how much padding we apply to
/// sufficiently large dimensions.
struct MatmulNarrowDim {
// Enumerates dimensions of a matmul that may be labelled as narrow.
enum class Dim {
None,
M,
N,
};
Dim dim = Dim::None; // Which dimension is designated by *this.
int64_t size = 0; // Size of the designated dimension, or kDynamic.

explicit operator bool() const { return dim != Dim::None; }
bool isM() const { return dim == Dim::M; }
bool isN() const { return dim == Dim::N; }
};

/// Returns the narrow dim in a given `linalgOp`, with respect to the given
/// `narrowThreshold` below which a dimension is eligible to be considered
/// narrow. If both M and N are narrow, M is returned. If neither M nor N are
/// narrow, this returns a default-constructed falsish value.
MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp,
int narrowThreshold);

/// Returns the narrow dim in a given `encoding`. This works by inspecting
/// the `round_dims_to` array attribute in the `encoding`. If the
/// `round_dims_to` of one dimension (M or N) is smaller than the other, then
/// that's the narrow dimension, because the only way it would have been set
/// to be smaller in the first place, is if we previously flagged that dimension
/// as narrow. If the `round_dims_to` of the M and N dimensions agree, then
/// neither is a narrow dimension and this returns a default-constructed falsish
/// value.
MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding);

} // namespace mlir::iree_compiler::IREE::Encoding

#endif // IREE_COMPILER_DIALECT_ENCODING_IR_ENCODINGOPS_H_
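
As an aside, a hedged sketch of the shape-based rule documented in the
`getMatmulNarrowDim(linalgOp, narrowThreshold)` comment above, with the MLIR
types stripped out; `classify` and `kDynamic` are placeholder names for this
sketch, not part of the actual API:

```cpp
#include <cstdint>

enum class Narrow { None, M, N };

// Placeholder for ShapedType::kDynamic; any sentinel value works for the sketch.
constexpr int64_t kDynamic = INT64_MIN;

// A static M or N extent below the threshold is narrow; if both are, the
// smaller one wins, with M preferred on ties.
Narrow classify(int64_t mSize, int64_t nSize, int64_t narrowThreshold) {
  bool narrowM = mSize != kDynamic && mSize < narrowThreshold;
  bool narrowN = nSize != kDynamic && nSize < narrowThreshold;
  if (narrowM && (!narrowN || mSize <= nSize))
    return Narrow::M;
  if (narrowN)
    return Narrow::N;
  return Narrow::None;
}
```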
53 changes: 5 additions & 48 deletions compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -45,47 +45,6 @@ Value setEncoding(OpBuilder &builder, Location loc, Value source,
return builder.create<IREE::Encoding::SetEncodingOp>(loc, resultType, source);
};

struct MatmulNarrowSizes {
std::optional<int64_t> M, N;
};

// Returns the minimum of static sizes of the M/N-dimensions in the types of the
// Ouput.
static MatmulNarrowSizes getMatmulNarrowSizes(ShapedType outType,
linalg::LinalgOp linalgOp) {
linalg::ContractionDimensions cDims =
linalg::inferContractionDims(linalgOp).value();
auto map = linalgOp.getIndexingMapsArray().back();
auto getOutputSizeAtDimPos = [&](unsigned dimPos) -> int64_t {
return outType.getDimSize(
map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext()))
.value());
};
// M or N can be empty instead of having an explicit dim size of 1 for matvec
// and vecmat, so set to 1 if empty.
int64_t M = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]);
int64_t N = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]);

MatmulNarrowSizes narrow;
if (!ShapedType::isDynamic(M) && M < IREE::Encoding::kNarrowThreshold) {
narrow.M = M;
}
if (!ShapedType::isDynamic(N) && N < IREE::Encoding::kNarrowThreshold) {
narrow.N = N;
}

// Only pick 1 if both are present
if (narrow.M && narrow.N) {
if (*narrow.M <= *narrow.N) {
narrow.N.reset();
} else {
narrow.M.reset();
}
}

return narrow;
}

static Value unsetEncodingAndExtractSlice(OpBuilder &builder, Location loc,
Value source,
SmallVector<OpFoldResult> sizes) {
@@ -247,22 +206,20 @@ class setContractionOpEncoding
}
SmallVector<Type> elemTypes = {lhsElemType, rhsElemType, outElemType};

MatmulNarrowSizes narrowSizes =
getMatmulNarrowSizes(cast<ShapedType>(out.getType()), linalgOp);
auto narrowDim = IREE::Encoding::getMatmulNarrowDim(linalgOp, padFactor);

Location loc = linalgOp.getLoc();
SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();

auto opType = IREE::Encoding::EncodingOpType::matmul;
auto setEncodingWrapper = [&](Value src, int64_t operandIndex) -> Value {
SmallVector<int64_t> roundDimsTo(3, padFactor);
if (narrowSizes.M) {
roundDimsTo[0] = llvm::PowerOf2Ceil(narrowSizes.M.value());
if (narrowDim.isM()) {
roundDimsTo[0] = llvm::PowerOf2Ceil(narrowDim.size);
}
if (narrowSizes.N) {
roundDimsTo[1] = llvm::PowerOf2Ceil(narrowSizes.N.value());
if (narrowDim.isN()) {
roundDimsTo[1] = llvm::PowerOf2Ceil(narrowDim.size);
}

auto encoding = EncodingAttr::get(linalgOp.getContext(), operandIndex,
opType, elemTypes, maps,
/*bcastMap=*/std::nullopt, roundDimsTo);