diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 900d220d7720..c82ce6892ba2 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -53,8 +53,11 @@ iree_compiler_cc_library( "KernelDispatch.cpp", "LLVMCPUAssignConstantOrdinals.cpp", "LLVMCPUAssignImportOrdinals.cpp", + "LLVMCPUBreakDownSubbyteExtend.cpp", "LLVMCPUCheckIRBeforeLLVMConversion.cpp", "LLVMCPUEmitVectorizationRemarks.cpp", + "LLVMCPUFoldMemRefAliasOps.cpp", + "LLVMCPUFoldVectorContractUnitDims.cpp", "LLVMCPULinkExecutables.cpp", "LLVMCPULowerExecutableTarget.cpp", "LLVMCPULowerToUKernels.cpp", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 8d5feafc38db..bd0378980295 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -54,8 +54,11 @@ iree_cc_library( "KernelDispatch.cpp" "LLVMCPUAssignConstantOrdinals.cpp" "LLVMCPUAssignImportOrdinals.cpp" + "LLVMCPUBreakDownSubbyteExtend.cpp" "LLVMCPUCheckIRBeforeLLVMConversion.cpp" "LLVMCPUEmitVectorizationRemarks.cpp" + "LLVMCPUFoldMemRefAliasOps.cpp" + "LLVMCPUFoldVectorContractUnitDims.cpp" "LLVMCPULinkExecutables.cpp" "LLVMCPULowerExecutableTarget.cpp" "LLVMCPULowerToUKernels.cpp" diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp new file mode 100644 index 000000000000..6d9b90384145 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUBreakDownSubbyteExtend.cpp @@ -0,0 +1,387 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-breakdown-subbyte-extend" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { +namespace { + +template +static Value shuffleMaskShift(PatternRewriter &rewriter, Location loc, + SmallVector shuffleInputs, + int64_t srcBitWidth, int64_t vectorSize) { + auto shuffleInType = llvm::cast(shuffleInputs[0].getType()); + auto shuffleResultType = + VectorType::get({vectorSize}, shuffleInType.getElementType()); + int64_t dstBitWidth = shuffleInType.getElementTypeBitWidth(); + T maskBase = (1u << srcBitWidth) - 1; + + SmallVector maskArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + maskArray[elemNum] = maskBase << (elemNum * srcBitWidth % dstBitWidth); + } + auto maskVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, maskArray)); + LDBG("maskVals: " << maskVals); + SmallVector shruiArray(shuffleResultType.getNumElements()); + for (T elemNum = 0; elemNum < shuffleResultType.getNumElements(); elemNum++) { + shruiArray[elemNum] = elemNum * srcBitWidth % dstBitWidth; + } + auto shruiVals = rewriter.create( + loc, shuffleResultType, + DenseIntElementsAttr::get(shuffleResultType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + + int64_t dstSize = vectorSize * shuffleInputs.size(); + auto newVectorType = + VectorType::get({dstSize}, shuffleResultType.getElementType()); + Value newVector = rewriter.create( + loc, newVectorType, rewriter.getZeroAttr(newVectorType)); + + for (auto shuffleIn : llvm::enumerate(shuffleInputs)) { + SmallVector shuffleArray(vectorSize); + for (int64_t elemNum = 0; elemNum < vectorSize; elemNum++) { + shuffleArray[elemNum] = + elemNum / (vectorSize / shuffleInType.getNumElements()); + } + Value shuffleResult = rewriter.create( + loc, shuffleIn.value(), shuffleIn.value(), shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value andResult = + rewriter.create(loc, shuffleResult, maskVals); + LDBG("andResult: " << andResult); + + Value shruiResult = + rewriter.create(loc, andResult, shruiVals); + LDBG("shruiResult: " << shruiResult); + + int64_t offset = shuffleIn.index() * vectorSize; + newVector = rewriter.create( + loc, shruiResult, newVector, offset, 1); + } + return newVector; +} + +static std::optional> +getLoadsForExtend(arith::ExtUIOp extOp) { + Value extSource = extOp.getIn(); + auto shapeCastOp = extSource.getDefiningOp(); + if (!shapeCastOp) { + return std::nullopt; + } + Value shapeCastSource = shapeCastOp.getSource(); + auto insertOp = shapeCastSource.getDefiningOp(); + if (!insertOp) { + return std::nullopt; + } + SmallVector loads; + while (insertOp) { + Value insert = insertOp.getSource(); + auto insertShapeCastOp = insert.getDefiningOp(); + if (!insertShapeCastOp) { + return std::nullopt; + } + auto loadOp = insertShapeCastOp.getSource().getDefiningOp(); + if (!loadOp) { + return std::nullopt; + } + loads.push_back(loadOp.getResult()); + insertOp = 
insertOp.getDest().getDefiningOp(); + } + return loads; +} + +struct BreakDownSubbyteExtend final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + + SmallVector sources{extOp.getIn()}; + if (auto loads = getLoadsForExtend(extOp)) { + sources = *loads; + } + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(dstElemBitwidth) || srcElemBitwidth != 4) + return failure(); + + if (dstElemBitwidth != 32 && dstElemBitwidth != 16) { + return failure(); + } + + int64_t vectorSizeBits = 512; + int64_t vectorSize = vectorSizeBits / dstElemBitwidth; + int64_t shuffleInputSizeBits = vectorSize * srcElemBitwidth; + int64_t shuffleInputSize = shuffleInputSizeBits / dstElemBitwidth; + auto shuffleInputType = + VectorType::get({shuffleInputSize}, extuiDstType.getElementType()); + Value shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + SmallVector shuffleInputs; + + for (int sourceIdx = 0; sourceIdx < sources.size(); sourceIdx++) { + Value source = sources[sourceIdx]; + VectorType sourceType = llvm::cast(source.getType()); + SmallVector sourceShape(sourceType.getShape()); + int64_t innerSize = sourceShape.back(); + if (!llvm::isPowerOf2_64(innerSize)) { + return failure(); + } + for (int64_t i = 0; i < sourceType.getNumElements() / innerSize; i++) { + SmallVector indices; + int64_t numElems = i; + SmallVector sourceOuterShape(sourceShape.begin(), + sourceShape.end() - 1); + for (int64_t size : llvm::reverse(sourceOuterShape)) { + indices.push_back(numElems % size); + numElems /= size; + } + std::reverse(indices.begin(), indices.end()); + + Value innerSlice; + if (indices.size()) { + innerSlice = rewriter.create(extOp.getLoc(), + source, indices); + } else { + innerSlice = source; + } + VectorType innerSliceType = + llvm::cast(innerSlice.getType()); + int64_t numExtractedBits = + innerSliceType.getNumElements() * srcElemBitwidth; + if (numExtractedBits / dstElemBitwidth < 1) { + LDBG("extract not big enough: " << numExtractedBits / + dstElemBitwidth); + return failure(); + } + auto bitCastType = VectorType::get({numExtractedBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, innerSlice); + LDBG("innerSlice: " << innerSlice); + // LDBG("bitCastResult: " << bitCastResult); + + if (numExtractedBits >= shuffleInputSizeBits) { + for (int64_t extractOffset = 0; + extractOffset < numExtractedBits / dstElemBitwidth; + extractOffset += shuffleInputSize) { + Value extractedSlice = + rewriter.create( + extOp.getLoc(), bitCastResult, extractOffset, + shuffleInputSize, 1); + shuffleInputs.push_back(extractedSlice); + LDBG("extractedSlice: " << extractedSlice); + // vector = + // rewriter.create(extOp.getLoc(), + // extractedSlice, vector, SmallVector{offset}, + // SmallVector{1}); + } + } else { + int64_t offset = + i * numExtractedBits / dstElemBitwidth % shuffleInputSize; + shuffleInput = rewriter.create( + extOp.getLoc(), bitCastResult, shuffleInput, + SmallVector{offset}, SmallVector{1}); + if (offset + 
numExtractedBits / dstElemBitwidth == shuffleInputSize) { + shuffleInputs.push_back(shuffleInput); + shuffleInput = rewriter.create( + extOp.getLoc(), shuffleInputType, + rewriter.getZeroAttr(shuffleInputType)); + } + } + } + } + + Value newVector; + if (dstElemBitwidth == 32) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } else if (dstElemBitwidth == 16) { + newVector = shuffleMaskShift( + rewriter, extOp.getLoc(), shuffleInputs, srcElemBitwidth, vectorSize); + } + rewriter.replaceOpWithNewOp(extOp, extuiDstType, + newVector); + + return success(); + } +}; + +struct BreakDownSubbyteExtendFlatten final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtUIOp extOp, + PatternRewriter &rewriter) const override { + VectorType extuiSrcType = + llvm::dyn_cast(extOp.getIn().getType()); + VectorType extuiDstType = llvm::dyn_cast(extOp.getType()); + if (!extuiSrcType || !extuiDstType) { + return failure(); + } + LDBG("extuiSrcType: " << extuiSrcType); + LDBG("extuiDstType: " << extuiDstType); + + // We only have power-of-two bitwidth cases for now. + if (!llvm::isPowerOf2_64(extuiSrcType.getNumElements())) + return failure(); + + int64_t srcElemBitwidth = extuiSrcType.getElementTypeBitWidth(); + int64_t dstElemBitwidth = extuiDstType.getElementTypeBitWidth(); + LDBG("srcElemBitwidth: " << srcElemBitwidth); + LDBG("dstElemBitwidth: " << dstElemBitwidth); + + int64_t numBits = srcElemBitwidth * extuiSrcType.getNumElements(); + if (numBits / dstElemBitwidth < 1) { + return failure(); + } + + VectorType flattenedType = VectorType::get({extuiSrcType.getNumElements()}, + extuiSrcType.getElementType()); + Value shapeCastFlatten = rewriter.create( + extOp.getLoc(), flattenedType, extOp.getIn()); + + auto bitCastType = VectorType::get({numBits / dstElemBitwidth}, + extuiDstType.getElementType()); + Value bitCastResult = rewriter.create( + extOp.getLoc(), bitCastType, shapeCastFlatten); + LDBG("bitCastResult: " << bitCastResult); + + SmallVector shuffleArray(extuiDstType.getNumElements()); + for (int64_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shuffleArray[elemNum] = elemNum / (extuiDstType.getNumElements() / + bitCastType.getNumElements()); + } + + Value shuffleResult = rewriter.create( + extOp.getLoc(), bitCastResult, bitCastResult, shuffleArray); + LDBG("shuffleResult: " << shuffleResult); + + Value shapeCastUnflatten = rewriter.create( + extOp.getLoc(), extuiDstType, shuffleResult); + Value maskVals, shruiVals; + if (dstElemBitwidth == 32) { + int32_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector maskArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int32_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else if (dstElemBitwidth == 16) { + int16_t maskBase = (1u << srcElemBitwidth) - 1; + SmallVector 
maskArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + maskArray[elemNum] = maskBase + << (elemNum * srcElemBitwidth % dstElemBitwidth); + } + maskVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, maskArray)); + LDBG("maskVals: " << maskVals); + + SmallVector shruiArray(extuiDstType.getNumElements()); + for (int16_t elemNum = 0; elemNum < extuiDstType.getNumElements(); + elemNum++) { + shruiArray[elemNum] = elemNum * srcElemBitwidth % dstElemBitwidth; + } + shruiVals = rewriter.create( + extOp.getLoc(), extuiDstType, + DenseIntElementsAttr::get(extuiDstType, shruiArray)); + LDBG("shruiVals: " << shruiVals); + } else { + return failure(); + } + + Value andResult = rewriter.create( + extOp.getLoc(), shapeCastUnflatten, maskVals); + LDBG("andResult: " << andResult); + + rewriter.replaceOpWithNewOp(extOp, andResult, shruiVals); + + return success(); + } +}; + +struct LLVMCPUBreakDownSubbyteExtendPass final + : public LLVMCPUBreakDownSubbyteExtendBase< + LLVMCPUBreakDownSubbyteExtendPass> { + void runOnOperation() override { + MLIRContext *context = &getContext(); + { + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + + // For the case when the innermost dimension of the src type is too small to + // fill a single element of the dst type. + // { + // RewritePatternSet patterns(context); + // patterns.add(context); + // vector::populateVectorShapeCastLoweringPatterns(patterns); + // if (failed(applyPatternsAndFoldGreedily(getOperation(), + // std::move(patterns)))) { + // return signalPassFailure(); + // } + // } + } +}; + +} // namespace + +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass() { + return std::make_unique(); +} + +void populateLLVMCPUBreakDownSubbyteExtendPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp new file mode 100644 index 000000000000..fc8c40dfb2e2 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldMemRefAliasOps.cpp @@ -0,0 +1,283 @@ +//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass folds loading/storing from/to subview ops into +// loading/storing from/to the original memref. 
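+//
+// In this LLVMCPU variant the folded producers are memref.expand_shape and
+// memref.collapse_shape feeding vector.load. Illustrative IR (hypothetical
+// example, assuming static shapes):
+//
+//   %e = memref.expand_shape %src [[0, 1]] : memref<128xf32> into memref<16x8xf32>
+//   %v = vector.load %e[%i, %j] : memref<16x8xf32>, vector<8xf32>
+//
+// is rewritten to load directly from the original memref with linearized
+// indices:
+//
+//   %idx = affine.apply affine_map<(d0, d1) -> (d0 * 8 + d1)>(%i, %j)
+//   %v = vector.load %src[%idx] : memref<128xf32>, vector<8xf32>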
+// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-memref-alias-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Merges expand_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfExpandShapeOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; + +/// Merges collapse_shape operation with load/transferRead operation. +template +class LLVMCPULoadOpOfCollapseShapeOpFolder final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy loadOp, + PatternRewriter &rewriter) const override; +}; +} // namespace + +static SmallVector +calculateExpandedAccessIndices(AffineMap affineMap, + const SmallVector &indices, Location loc, + PatternRewriter &rewriter) { + SmallVector indicesOfr(llvm::to_vector( + llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }))); + SmallVector expandedIndices; + for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, affineMap.getSubMap({i}), indicesOfr); + expandedIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return expandedIndices; +} + +static LogicalResult +resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, + memref::ExpandShapeOp expandShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + // The below implementation uses computeSuffixProduct method, which only + // allows int64_t values (i.e., static shape). Bail out if it has dynamic + // shapes. + if (!expandShapeOp.getResultType().hasStaticShape()) + return failure(); + + MLIRContext *ctx = rewriter.getContext(); + for (ArrayRef groups : expandShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + int64_t groupSize = groups.size(); + + // Construct the expression for the index value w.r.t to expand shape op + // source corresponding the indices wrt to expand shape op result. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + SmallVector dims(groupSize); + bindDimsList(ctx, MutableArrayRef{dims}); + AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct); + + /// Apply permutation and create AffineApplyOp. 
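+    // Illustrative example (hypothetical shapes): expanding one source dim
+    // into 3x4 gives sizes = [3, 4] and suffixProduct = [4, 1], so the
+    // source index for that group is d0 * 4 + d1 applied to the two expanded
+    // indices of the group.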
+ SmallVector dynamicIndices(groupSize); + for (int64_t i = 0; i < groupSize; i++) + dynamicIndices[i] = indices[groups[i]]; + + // Creating maximally folded and composd affine.apply composes better with + // other transformations without interleaving canonicalization passes. + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/groupSize, + /*numSymbols=*/0, srcIndexExpr), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + return success(); +} + +static LogicalResult +resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, + memref::CollapseShapeOp collapseShapeOp, + ValueRange indices, + SmallVectorImpl &sourceIndices) { + int64_t cnt = 0; + SmallVector tmp(indices.size()); + SmallVector dynamicIndices; + for (ArrayRef groups : collapseShapeOp.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + dynamicIndices.push_back(indices[cnt++]); + int64_t groupSize = groups.size(); + + // Calculate suffix product for all collapse op source dimension sizes. + SmallVector sizes(groupSize); + for (int64_t i = 0; i < groupSize; ++i) + sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); + SmallVector suffixProduct = computeSuffixProduct(sizes); + + // Derive the index values along all dimensions of the source corresponding + // to the index wrt to collapsed shape op output. + auto d0 = rewriter.getAffineDimExpr(0); + SmallVector delinearizingExprs = delinearize(d0, suffixProduct); + + // Construct the AffineApplyOp for each delinearizingExpr. + for (int64_t i = 0; i < groupSize; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, + AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, + delinearizingExprs[i]), + dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + dynamicIndices.clear(); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast(collapseShapeOp.getViewSource().getType()).getRank(); + for (int64_t i = 0; i < srcRank; i++) { + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, zeroAffineMap, dynamicIndices); + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + return success(); +} + +/// Helpers to access the memref operand for each op. +template +static Value getMemRefOperand(LoadOrStoreOpTy op) { + return op.getMemref(); +} + +static Value getMemRefOperand(vector::TransferReadOp op) { + return op.getSource(); +} + +static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } + +template +LogicalResult LLVMCPULoadOpOfExpandShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto expandShapeOp = + getMemRefOperand(loadOp).template getDefiningOp(); + + if (!expandShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. 
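+    // e.g. for a hypothetical affine.load %m[%i + 1, %j], the op's operands
+    // are (%i, %j) and its map is (d0, d1) -> (d0 + 1, d1); the helper below
+    // materializes (%i + 1, %j) as the actual access indices.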
+ if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesExpandShape( + loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), expandShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +template +LogicalResult LLVMCPULoadOpOfCollapseShapeOpFolder::matchAndRewrite( + OpTy loadOp, PatternRewriter &rewriter) const { + auto collapseShapeOp = getMemRefOperand(loadOp) + .template getDefiningOp(); + + if (!collapseShapeOp) + return failure(); + + SmallVector indices(loadOp.getIndices().begin(), + loadOp.getIndices().end()); + // For affine ops, we need to apply the map to get the operands to get the + // "actual" indices. + if (auto affineLoadOp = + dyn_cast(loadOp.getOperation())) { + AffineMap affineMap = affineLoadOp.getAffineMap(); + auto expandedIndices = calculateExpandedAccessIndices( + affineMap, indices, loadOp.getLoc(), rewriter); + indices.assign(expandedIndices.begin(), expandedIndices.end()); + } + SmallVector sourceIndices; + if (failed(resolveSourceIndicesCollapseShape( + loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) + return failure(); + llvm::TypeSwitch(loadOp) + .Case([&](vector::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), collapseShapeOp.getViewSource(), + sourceIndices); + }) + .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); + return success(); +} + +void populateLLVMCPUFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { + patterns.add, + LLVMCPULoadOpOfCollapseShapeOpFolder>( + patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +struct LLVMCPUFoldMemRefAliasOpsPass final + : public LLVMCPUFoldMemRefAliasOpsBase { + void runOnOperation() override; +}; + +} // namespace + +void LLVMCPUFoldMemRefAliasOpsPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(patterns); + populateLLVMCPUFoldMemRefAliasOpPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass() { + return std::make_unique(); +} + +} // namespace iree_compiler +} // namespace mlir \ No newline at end of file diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp new file mode 100644 index 000000000000..ac01d81ce392 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUFoldVectorContractUnitDims.cpp @@ -0,0 +1,347 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- LLVMCPUFoldVectorContractUnitDims.cpp - Pass to fold unit dims of +// vector.contract ops -===// +// +// Patterns to fold away unit dimensions on `vector.contract` ops +// +//===----------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" +#include "iree/compiler/Codegen/LLVMCPU/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-llvmcpu-fold-unit-reduction-dims" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace iree_compiler { + +// Given a `vector.contract` op and a set of indices to fold, this op rewrites +// the `vector.contract` op with surrounding `vector.shape_cast` ops to fold +// away the indicated indices. +static FailureOr +dropFoldableUnitIndices(PatternRewriter &rewriter, + vector::ContractionOp contractOp, + SmallVector foldIndices) { + SmallVector contractShape = *contractOp.getShapeForUnroll(); + SmallVector iteratorTypes = + contractOp.getIteratorTypesArray(); + auto indexingMaps = contractOp.getIndexingMapsArray(); + SmallVector> dstShapes; + SmallVector> dstExprs; + SmallVector inputs( + {contractOp.getLhs(), contractOp.getRhs(), contractOp.getAcc()}); + llvm::SetVector foldableDims; + for (int64_t dim : foldIndices) + foldableDims.insert(dim); + + for (AffineMap map : indexingMaps) { + SmallVector dstShape; + SmallVector dstExpr; + for (const auto &expr : enumerate(map.getResults())) { + if (auto dimExpr = expr.value().dyn_cast()) { + if (!foldableDims.contains(dimExpr.getPosition())) { + dstShape.push_back(contractShape[dimExpr.getPosition()]); + unsigned numSkipped = 0; + for (int64_t ind : foldIndices) { + if (dimExpr.getPosition() > ind) { + numSkipped++; + } + } + dstExpr.push_back( + rewriter.getAffineDimExpr(dimExpr.getPosition() - numSkipped)); + } + } else { + return failure(); + } + } + dstShapes.push_back(dstShape); + dstExprs.push_back(dstExpr); + } + + SmallVector newInputs; + SmallVector newIndexingMaps; + SmallVector newIteratorTypes; + for (auto iter : enumerate(iteratorTypes)) { + if (!foldableDims.contains(iter.index())) { + newIteratorTypes.push_back(iter.value()); + } + } + + for (int i = 0; i < 3; i++) { + // Shape unchanged + if (dstShapes[i].size() == indexingMaps[i].getResults().size()) { + newInputs.push_back(inputs[i]); + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newIndexingMaps.push_back(newIndexingMap); + continue; + } + if (dstShapes[i].size() == 0) { + return failure(); + } + VectorType inputVecType = llvm::cast(inputs[i].getType()); + VectorType dstType = + VectorType::get(dstShapes[i], inputVecType.getElementType()); + + Value result; + auto extsiop = inputs[i].getDefiningOp(); + auto extuiop = inputs[i].getDefiningOp(); + if (!extsiop && !extuiop) { + result = rewriter.create(contractOp.getLoc(), + dstType, inputs[i]); + } else { + Value extIn = extsiop ? 
extsiop.getIn() : extuiop.getIn(); + VectorType extInType = llvm::dyn_cast(extIn.getType()); + VectorType shapeCastOutType = + VectorType::get(dstType.getShape(), extInType.getElementType()); + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), shapeCastOutType, extIn); + result = extsiop ? rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult() + : rewriter + .create(contractOp.getLoc(), + dstType, shapeCastResult) + .getResult(); + } + AffineMap newIndexingMap = + AffineMap::get(/*dimCount=*/contractShape.size() - foldIndices.size(), + /*symCount=*/0, dstExprs[i], contractOp.getContext()); + newInputs.push_back(result); + newIndexingMaps.push_back(newIndexingMap); + } + auto newContract = + rewriter + .create( + contractOp.getLoc(), newInputs[0], newInputs[1], newInputs[2], + rewriter.getAffineMapArrayAttr(newIndexingMaps), + rewriter.getArrayAttr(llvm::to_vector(llvm::map_range( + newIteratorTypes, + [&](vector::IteratorType t) -> mlir::Attribute { + return vector::IteratorTypeAttr::get(rewriter.getContext(), + t); + })))) + .getResult(); + return newContract; +} + +// This pattern matches on a `vector.contract` op with unit size dimensions, and +// folds these dimensions away +class DropVectorContractUnitDims final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + LDBG("vector.contract op:\n" << contractOp); + VectorType outputType = + llvm::dyn_cast(contractOp.getAcc().getType()); + if (!outputType) { + return failure(); + } + + auto iteratorTypes = contractOp.getIteratorTypesArray(); + SmallVector contractDims = *contractOp.getShapeForUnroll(); + unsigned numParallel = 0; + unsigned numReduction = 0; + SmallVector unitParallelDims; + SmallVector unitReductionDims; + SmallVector foldableDims; + for (auto size : enumerate(contractDims)) { + if (iteratorTypes[size.index()] == vector::IteratorType::parallel) { + numParallel++; + if (size.value() == 1) { + unitParallelDims.push_back(size.index()); + } + } else { + numReduction++; + if (size.value() == 1) { + unitReductionDims.push_back(size.index()); + } + } + } + if (numReduction && numReduction == unitReductionDims.size()) { + foldableDims.append(unitReductionDims.begin(), + unitReductionDims.end() - 1); + } else { + foldableDims.append(unitReductionDims.begin(), unitReductionDims.end()); + } + if (numParallel && numParallel == unitParallelDims.size()) { + foldableDims.append(unitParallelDims.begin() + 1, unitParallelDims.end()); + } else { + foldableDims.append(unitParallelDims.begin(), unitParallelDims.end()); + } + if (!foldableDims.size()) { + return failure(); + } + + FailureOr maybeNewContract = + dropFoldableUnitIndices(rewriter, contractOp, foldableDims); + if (failed(maybeNewContract)) { + return failure(); + } + Value newContract = maybeNewContract.value(); + LDBG("Replaced vector.contract:\n" << newContract); + + VectorType newOutputType = + llvm::dyn_cast(newContract.getType()); + if (outputType != newOutputType) { + // Reshape output of new vector.contract if needed + Value shapeCastResult = rewriter.create( + contractOp.getLoc(), outputType, newContract); + rewriter.replaceOp(contractOp, shapeCastResult); + } else { + rewriter.replaceOp(contractOp, newContract); + } + + return success(); + } +}; + +// This pattern matches on a sequence of +// `vector.shape_cast->vector.contract->vector.shape_cast` within an `scf.for` +// op, where the 
shape cast ops are casting an argument of the `scf.for` op and +// the yielded result of the `scf.for` op. Once matched, the `vector.shape_cast` +// ops are hoisted out of the `scf.for` op. +class HoistShapeCastOutOfSCFFor final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override { + LDBG("forOp:\n" << forOp); + auto yieldOp = cast(forOp.getBody()->getTerminator()); + std::optional> + hoistableShapeCast = std::nullopt; + int initArgIdx; + for (Value result : yieldOp.getOperation()->getOperands()) { + auto outputShapeCastOp = result.getDefiningOp(); + if (!outputShapeCastOp) { + continue; + } + LDBG("outputShapeCastOp:\n" << outputShapeCastOp); + auto contractOp = + outputShapeCastOp.getSource().getDefiningOp(); + if (!contractOp) { + continue; + } + LDBG("contractOp:\n" << contractOp); + Value acc = contractOp.getAcc(); + auto inputShapeCastOp = acc.getDefiningOp(); + if (!inputShapeCastOp) { + continue; + } + LDBG("inputShapeCastOp:\n" << inputShapeCastOp); + Value input = inputShapeCastOp.getSource(); + auto blockArg = dyn_cast(input); + if (!blockArg) { + continue; + } + LDBG("blockArg:\n" << blockArg); + hoistableShapeCast = std::make_pair(inputShapeCastOp, outputShapeCastOp); + initArgIdx = blockArg.getArgNumber() - 1; + } + + if (!hoistableShapeCast) { + return failure(); + } + vector::ShapeCastOp inSC = hoistableShapeCast->first; + vector::ShapeCastOp outSC = hoistableShapeCast->second; + SmallVector forOpInitArgs = forOp.getInitArgs(); + Value source = forOpInitArgs[initArgIdx]; + Value sourceSC = + rewriter + .create(forOp.getLoc(), inSC.getType(), source) + .getResult(); + forOpInitArgs[initArgIdx] = sourceSC; + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), forOpInitArgs); + LDBG("newForOp:\n" << newForOp); + rewriter.mergeBlocks(forOp.getBody(), newForOp.getBody(), + newForOp.getBody()->getArguments()); + auto newYieldOp = cast(newForOp.getBody()->getTerminator()); + LDBG("newYieldOp:\n" << newYieldOp); + SmallVector newForOpResults = + newYieldOp.getOperation()->getOperands(); + int contractResultIndex; + for (auto result : llvm::enumerate(newForOpResults)) { + if (result.value() == outSC.getResult()) { + newForOpResults[result.index()] = outSC.getSource(); + contractResultIndex = result.index(); + } + } + rewriter.updateRootInPlace(newYieldOp, [&]() { + newYieldOp.getOperation()->setOperands(newForOpResults); + }); + LDBG("newForOp with body:\n" << newForOp); + SmallVector newResults = newForOp.getResults(); + Value hoistedOutputShapeCast = + rewriter + .create(forOp.getLoc(), outSC.getType(), + newResults[contractResultIndex]) + .getResult(); + LDBG("hoistedOutputShapeCast:\n" << hoistedOutputShapeCast); + newResults[contractResultIndex] = hoistedOutputShapeCast; + rewriter.replaceOp(forOp, newResults); + + return success(); + } +}; + +namespace { +struct LLVMCPUFoldVectorContractUnitDimsPass + : public LLVMCPUFoldVectorContractUnitDimsBase< + LLVMCPUFoldVectorContractUnitDimsPass> { + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; +} // namespace + +void LLVMCPUFoldVectorContractUnitDimsPass::runOnOperation() { + Operation *funcOp = getOperation(); + MLIRContext *context = &getContext(); + RewritePatternSet foldUnitDimsPatterns(context); + foldUnitDimsPatterns + .add(context); + if 
(failed(applyPatternsAndFoldGreedily(funcOp, + std::move(foldUnitDimsPatterns)))) { + return signalPassFailure(); + } +} + +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass() { + return std::make_unique(); +} + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp index 722db723c594..6ece8dcdf89f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSplitReduction.cpp @@ -38,8 +38,9 @@ namespace { /// TODO: support named ops, numInputs > 1, and modify lastDim check below /// accordingly. If fpReductionReordering is not enabled by default, it must /// be an integer or index type to proceed to allow associative reordering. -LogicalResult splitReductionPrecondition(Operation *op, - bool fpReductionReordering) { +LogicalResult +splitReductionPrecondition(Operation *op, bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { linalg::LinalgOp linalgOp = cast(op); if (!linalgOp.hasTensorSemantics()) { @@ -63,7 +64,11 @@ LogicalResult splitReductionPrecondition(Operation *op, LLVM_DEBUG(llvm::dbgs() << "is not a generic op\n"); return failure(); } - if (linalgOp.getNumDpsInputs() != 1) { + if (enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() > 2) { + LLVM_DEBUG(llvm::dbgs() << "doesn't have at most 2 inputs\n"); + return failure(); + } + if (!enableQuantizedMatmulReassociation && linalgOp.getNumDpsInputs() != 1) { LLVM_DEBUG(llvm::dbgs() << "doesn't have exactly 1 input\n"); return failure(); } @@ -102,8 +107,10 @@ LogicalResult splitReductionPrecondition(Operation *op, /// Converts an inner-reduction into outer reduction + inner-parallel dimension, /// followed by simple inner reduction. -LogicalResult splitReductionImpl(Operation *op, int64_t size, +LogicalResult splitReductionImpl(Operation *op, SmallVector tileSizes, + bool enableQuantizedMatmulReassociation, RewriterBase &rewriter) { + int64_t size = tileSizes.back(); IRRewriter::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(op); linalg::LinalgOp linalgOp = cast(op); @@ -119,8 +126,19 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, auto numLoops = linalgOp.getNumLoops(); // 1) Tile to extract a single vector-length array. - SmallVector tileSizesSVFirst(numLoops, - rewriter.getIndexAttr(1)); + SmallVector tileSizesSVFirst; + if (enableQuantizedMatmulReassociation) { + for (auto &s : tileSizes) { + if (!s) { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(1)); + } else { + tileSizesSVFirst.push_back(rewriter.getIndexAttr(s)); + } + } + } else { + tileSizesSVFirst = + SmallVector(numLoops, rewriter.getIndexAttr(1)); + } tileSizesSVFirst[numLoops - 1] = rewriter.getIndexAttr(0); auto options = scf::SCFTilingOptions().setTileSizes(tileSizesSVFirst); FailureOr tileResFirst = scf::tileUsingSCFForOp( @@ -147,7 +165,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, rewriter.getIndexAttr(0)); // The reduction happens only in the penultimate dimension, which we now // tile. 
- tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + if (enableQuantizedMatmulReassociation) { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(2); + } else { + tileSizesSV[numLoops - 1] = rewriter.getIndexAttr(1); + } options = scf::SCFTilingOptions().setTileSizes(tileSizesSV); FailureOr tileRes = scf::tileUsingSCFForOp( rewriter, cast(splitRes->splitLinalgOp.getOperation()), @@ -164,8 +186,11 @@ LogicalResult splitReductionImpl(Operation *op, int64_t size, class LLVMCPUSplitReductionPass : public LLVMCPUSplitReductionBase { public: - LLVMCPUSplitReductionPass(bool fpReductionReordering) { + LLVMCPUSplitReductionPass(bool fpReductionReordering, + bool enableQuantizedMatmulReassociation) { this->enableFpReductionReordering = fpReductionReordering; + this->enableQuantizedMatmulReassociation = + enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -183,8 +208,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); }); for (auto genericOp : candidates) { LLVM_DEBUG(llvm::dbgs() << "candidate: " << genericOp << "\n"); - if (failed(splitReductionPrecondition(genericOp, - enableFpReductionReordering))) { + if (failed( + splitReductionPrecondition(genericOp, enableFpReductionReordering, + enableQuantizedMatmulReassociation))) { continue; } @@ -208,8 +234,9 @@ void LLVMCPUSplitReductionPass::runOnOperation() { "skip SplitReduction"); continue; } - int64_t size = reductionSizes.back(); - if (failed(splitReductionImpl(genericOp, size, rewriter))) { + if (failed(splitReductionImpl(genericOp, reductionSizes, + enableQuantizedMatmulReassociation, + rewriter))) { return signalPassFailure(); } } @@ -217,9 +244,10 @@ void LLVMCPUSplitReductionPass::runOnOperation() { } // namespace std::unique_ptr> -createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering) { +createLLVMCPUSplitReductionPass(const bool enableFpReductionReordering, + const bool enableQuantizedMatmulReassociation) { return std::make_unique( - enableFpReductionReordering); + enableFpReductionReordering, enableQuantizedMatmulReassociation); } } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp index 42b46756c25d..ff4af02a3e26 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUVectorLowering.cpp @@ -46,6 +46,8 @@ class LLVMCPUVectorLoweringPass LLVMCPUVectorLoweringPass(const LLVMCPUVectorLoweringPassOptions &options) { this->splitVectorTransfersTo = options.splitVectorTransfersTo; this->lowerVectorTransposeToAVX2 = options.lowerVectorTransposeToAVX2; + this->enableQuantizedMatmulReassociation = + options.enableQuantizedMatmulReassociation; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -77,6 +79,27 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { .setVectorTransformsOptions(vectorContractLowering) .setVectorMultiReductionLowering(vectorMultiReductionLowering) .setVectorTransferSplit(vectorTransferSplit); + + { + if (enableQuantizedMatmulReassociation) { + // Special-case vector.contract codegen paths. This needs to happen + // just before the generic vector ops lowerings. 
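+      // With reassociation enabled, these patterns can map the reassociated
+      // (split-K) vector.contract ops produced earlier in the pipeline onto
+      // the inline-asm kernels in VectorContractCustomKernels.cpp, e.g. the
+      // AVX-512 VNNI vpdpwssd kernel added in this patch.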
+ RewritePatternSet patterns(ctx); + auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); + populateVectorContractCustomKernelsPatterns(target, patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After custom kernel lowering for " + "vector.contract ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + // Lower high level vector operations like contract or multidim reduce ops // to lower level vector ops. { @@ -158,6 +181,23 @@ void LLVMCPUVectorLoweringPass::runOnOperation() { llvm::dbgs() << "\n\n"; }); + // Break down subbyte `arith.extui` ops + { + if (enableQuantizedMatmulReassociation) { + RewritePatternSet patterns(&getContext()); + populateLLVMCPUBreakDownSubbyteExtendPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + + LLVM_DEBUG({ + llvm::dbgs() << "\n--- After breaking down subbyte extend ops ---\n"; + funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n\n"; + }); + } + } + // 'vector.shape_cast' are very expensive operations that are even generated // by some of the lowerings above (e.g., transpose lowering). There are // chances to cancel them out if they are not lowered too early so we lower diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index b563dbac0cd4..09467bc80e8a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -86,6 +86,13 @@ static llvm::cl::opt clInstrumentMemoryAccesses{ "instrumentation is enabled."), llvm::cl::init(false)}; +static llvm::cl::opt clEnableQuantizedMatmulReassociation( + "iree-llvmcpu-enable-quantized-matmul-reassociation", + llvm::cl::desc( + "Enables LLVMCPU codegen optimizations specific to reassociated " + "quantized matmuls (experimental)."), + llvm::cl::init(false)); + // MLIR file containing a top-level module that specifies the transformations to // apply to form dispatch regions. // Defined externally in KernelDispatch.cpp to control the codegen pass @@ -226,6 +233,19 @@ LogicalResult verifyDoubleTilingExpertPassPipelineConfig( << index << "-th tile size set"; } } + // if (!clEnableQuantizedMatmulReassociation) { + // SmallVector thirdLevelTileSizes; + // std::tie(thirdLevelTileSizes, std::ignore) = + // tilingConfig.getVectorReductionSizes(); + // for (auto [index, tileSize] : llvm::enumerate(thirdLevelTileSizes)) { + // if (tileSize != 0 && pLoopsSet.contains(index)) { + // return op->emitOpError("expected only reduction dims to be set in " + // "the third tiling " + // "level, got ") + // << index << "-th tile size set"; + // } + // } + // } } // Verify interchange @@ -471,7 +491,9 @@ void addMultiTilingExpertPassPipeline( // Run SplitReductionPass before the final reduction Fuse pass, because // SplitReductionPass takes care of banked-tiling. nestedModulePM.addNestedPass( - createLLVMCPUSplitReductionPass(clEnableReassociateFpReductions)); + createLLVMCPUSplitReductionPass( + clEnableReassociateFpReductions, + clEnableQuantizedMatmulReassociation)); nestedModulePM.addNestedPass(createLLVMCPUTilePass(i)); continue; } @@ -508,11 +530,17 @@ void addMultiTilingExpertPassPipeline( // Run IREE specific passes before vector lowering expert. 
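+  // All of the LLVMCPU passes gated on clEnableQuantizedMatmulReassociation
+  // below are opted into with the experimental
+  // --iree-llvmcpu-enable-quantized-matmul-reassociation flag.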
nestedModulePM.addNestedPass( createRemoveSingleIterationLoopPass()); + if (clEnableQuantizedMatmulReassociation) { + nestedModulePM.addNestedPass( + createLLVMCPUFoldVectorContractUnitDimsPass()); + } { LLVMCPUVectorLoweringPassOptions options; options.lowerVectorTransposeToAVX2 = lowerToAVX2; options.splitVectorTransfersTo = "linalg-copy"; + options.enableQuantizedMatmulReassociation = + clEnableQuantizedMatmulReassociation; nestedModulePM.addNestedPass( createLLVMCPUVectorLoweringPass(options)); } @@ -743,6 +771,9 @@ static void addLowerToLLVMPasses(OpPassManager &passManager) { passManager.addNestedPass(arith::createArithExpandOpsPass()); passManager.addNestedPass(memref::createExpandOpsPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); + if (clEnableQuantizedMatmulReassociation) { + passManager.addPass(createLLVMCPUFoldMemRefAliasOpsPass()); + } passManager.addPass(createEmulateNarrowTypePass()); passManager.addPass(createCanonicalizerPass()); passManager.addPass(createCSEPass()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h index 47dad29749e1..4527ec5a07a8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h @@ -20,6 +20,10 @@ namespace iree_compiler { class TilingConfig; +// Pass to breakdown subbyte extui +std::unique_ptr> +createLLVMCPUBreakDownSubbyteExtendPass(); + /// Performs the final conversion to LLVM dialect. std::unique_ptr> createConvertToLLVMPass(bool reassociateFpReordering = false); @@ -55,8 +59,9 @@ createLLVMCPUMmt4dVectorLoweringPass(); std::unique_ptr> createLLVMCPUPeelPass(); /// Pass to perform SplitReduction transformations of `LinalgOp`s. -std::unique_ptr> -createLLVMCPUSplitReductionPass(bool enableReassociateFpReductions = false); +std::unique_ptr> createLLVMCPUSplitReductionPass( + bool enableReassociateFpReductions = false, + bool enableQuantizedMatmulReassociation = false); /// Synchronizes LLVM linkage with MLIR symbol visibility. std::unique_ptr> @@ -82,6 +87,7 @@ std::unique_ptr> createLLVMCPUUnfuseFMAOpsPass(); struct LLVMCPUVectorLoweringPassOptions { std::string splitVectorTransfersTo = ""; bool lowerVectorTransposeToAVX2 = false; + bool enableQuantizedMatmulReassociation = false; }; std::unique_ptr> createLLVMCPUVectorLoweringPass(); std::unique_ptr> createLLVMCPUVectorLoweringPass( @@ -96,6 +102,11 @@ createVectorContractCustomKernelsPass(); std::unique_ptr> createVerifyLinalgTransformLegalityPass(); +std::unique_ptr> +createLLVMCPUFoldVectorContractUnitDimsPass(); + +std::unique_ptr createLLVMCPUFoldMemRefAliasOpsPass(); + //------------------------------------------------------------------------------ // LLVMCPU Codegen specific patterns. //------------------------------------------------------------------------------ @@ -108,6 +119,11 @@ void populateUnfusedFMAOpsPassPatterns(MLIRContext *context, void populateVectorContractCustomKernelsPatterns( IREE::HAL::ExecutableTargetAttr target, RewritePatternSet &patterns); +void populateLLVMCPUBreakDownSubbyteExtendPatterns(RewritePatternSet &patterns); + +void populateFoldVectorContractUnitDimsPass(RewritePatternSet &patterns, + MLIRContext *context); + //----------------------------------------------------------------------------// // LLVMCPU backend Pass Pipelines. 
//----------------------------------------------------------------------------// diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td index 69fc43ffc07d..0d4c95f70c1a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td @@ -45,6 +45,11 @@ def LLVMCPUAssignImportOrdinals : let constructor = "mlir::iree_compiler::createLLVMCPUAssignImportOrdinalsPass()"; } +def LLVMCPUBreakDownSubbyteExtend : Pass<"iree-llvmcpu-breakdown-subbyte-extend", "func::FuncOp"> { + let summary = "Pass to break down subbyte extui ops."; + let constructor = "mlir::iree_compiler::createLLVMCPUBreakDownSubbyteExtendPass()"; +} + def LLVMCPUCheckIRBeforeLLVMConversion : Pass<"iree-llvmcpu-check-ir-before-llvm-conversion", "ModuleOp"> { let summary = "Checks CPU backend specific IR constraints (like no allocas)"; @@ -58,6 +63,20 @@ def LLVMCPUEmitVectorizationRemarks : "mlir::iree_compiler::createLLVMCPUEmitVectorizationRemarksPass()"; } +def LLVMCPUFoldVectorContractUnitDims : + Pass<"iree-llvmcpu-fold-vector-contract-unit-dims", "func::FuncOp"> { + let summary = "Fold unit dims on vector.contract ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldVectorContractUnitDimsPass()"; +} + +def LLVMCPUFoldMemRefAliasOps : + Pass<"iree-llvmcpu-fold-memref-alias-ops", ""> { + let summary = "Fold combinations of memref ops"; + let constructor = + "mlir::iree_compiler::createLLVMCPUFoldMemRefAliasOpsPass()"; +} + def LLVMCPULinkExecutables : Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> { let summary = "Links LLVMCPU HAL executables within the top-level program module."; @@ -107,6 +126,9 @@ def LLVMCPUSplitReduction : Pass<"iree-llvmcpu-split-reduction", "func::FuncOp"> Option<"enableFpReductionReordering", "enable-fp-reduction-reordering", "bool", /*default=*/"false", "Flag to enable reduction reordering on floating points.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", + "bool", /*default=*/"false", + "Flag to enable optimizations for reassociated quantized matmuls.">, ]; } @@ -166,6 +188,9 @@ def LLVMCPUVectorLowering : Option<"lowerVectorTransposeToAVX2", "lower-vector-transpose-to-avx2", "bool", /*default=*/"false", "Add specific transpose to avx2 lowering patterns.">, + Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool", + /*default=*/"false", + "Add specific patterns for optimizing reassociated quantized matmuls.">, ]; let constructor = "mlir::iree_compiler::createLLVMCPUVectorLoweringPass()"; diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp index a4730ee1483a..fe6fd081e5ef 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp @@ -25,6 +25,10 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#define DEBUG_TYPE "iree-vector-contract-custom-kernels" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + namespace mlir { namespace iree_compiler { @@ -86,6 +90,73 @@ static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) { return true; } +static bool isVectorTimesMatrixTransposed(vector::ContractionOp contractionOp, + int64_t splitSize) { + 
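+  // Expects iterators (n, k) with indexing maps lhs: (k), rhs: (n, k),
+  // acc: (n). When splitSize != 0 an extra parallel "split" dimension s is
+  // expected, i.e. (illustrative maps, assuming dim order (n, k, s)):
+  //   lhs: affine_map<(n, k, s) -> (k, s)>
+  //   rhs: affine_map<(n, k, s) -> (n, k, s)>
+  //   acc: affine_map<(n, k, s) -> (n, s)>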
// Check that the reduction is additive. + if (contractionOp.getKind() != vector::CombiningKind::ADD) { + return false; + } + // Check that there are 1 parallel and 1 reduction iterators. + unsigned numIters = splitSize ? 3 : 2; + auto iteratorTypes = contractionOp.getIteratorTypes().getValue(); + if (iteratorTypes.size() != numIters) { + return false; + } + SmallVector parallelIterators; + SmallVector reductionIterators; + for (int i = 0; i < numIters; i++) { + if (vector::isParallelIterator(iteratorTypes[i])) { + parallelIterators.push_back(i); + } else if (vector::isReductionIterator(iteratorTypes[i])) { + reductionIterators.push_back(i); + } else { + return false; + } + } + if (parallelIterators.size() != numIters - 1 || + reductionIterators.size() != 1) { + return false; + } + // Give the found iterators some idiomatic names. + const int NIter = parallelIterators[0]; + const int KIter = reductionIterators[0]; + const int SplitIter = splitSize ? parallelIterators[1] : 0; + // Check that there are 3 indexing maps. + auto indexingMaps = contractionOp.getIndexingMapsArray(); + if (indexingMaps.size() != 3) { + return false; + } + // Check that the indexing maps have the expected form. + SmallVector> expectedMapResults; + if (splitSize) { + SmallVector> res = { + {KIter, SplitIter}, {NIter, KIter, SplitIter}, {NIter, SplitIter}}; + expectedMapResults = res; + numIters = 3; + } else { + SmallVector> res = {{KIter}, {NIter, KIter}, {NIter}}; + expectedMapResults = res; + numIters = 2; + } + for (int m = 0; m < 3; ++m) { + auto map = indexingMaps[m]; + auto expectedResults = expectedMapResults[m]; + if (map.getNumDims() != numIters || + map.getNumResults() != expectedResults.size()) { + return false; + } + for (int r = 0; r < expectedResults.size(); ++r) { + int actualMapResult = + map.getResults()[r].cast().getPosition(); + if (actualMapResult != expectedMapResults[m][r]) { + return false; + } + } + } + LDBG("passed isVectorTimesMatrixTransposed"); + return true; +} + // Returns true if `contractionOp` is of the form // matrix * transposed_matrix // where matrix is a vector<{mSize}x{kSize}xType>, and @@ -132,6 +203,31 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, return false; } +static bool matchVMT(vector::ContractionOp contractionOp, int64_t mSize, + int64_t kSize, int64_t nSize, int splitSize, + bool *transpose = nullptr) { + if (mSize != 1) { + return false; + } + if (!isVectorTimesMatrixTransposed(contractionOp, splitSize)) { + return false; + } + VectorType lhsType = llvm::cast(contractionOp.getLhs().getType()); + VectorType rhsType = llvm::cast(contractionOp.getRhs().getType()); + auto lhsShape = lhsType.getShape(); + auto rhsShape = rhsType.getShape(); + if (splitSize && (lhsShape[1] != splitSize || rhsShape[2] != splitSize)) { + return false; + } + if (lhsShape[0] != kSize || rhsShape[1] != kSize) { + return false; + } + if (rhsShape[0] == nSize) { + return true; + } + return false; +} + // `promotedResult` is required to be a Vector. // If its VectorType does not have `promotedType` as its element type, or // the operand to the type-promotion op is not `unpromotedType` returns a null @@ -143,8 +239,9 @@ static bool matchMMT(vector::ContractionOp contractionOp, int64_t mSize, // Note that this only looks at the immediately defining operation, so we likely // want to have earlier passes that sink widening operations as far down as // possible, which is probably just good regardless. 
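+// When `promoteSmallTypes` is set, an input whose element type is narrower
+// than `unpromotedType` (e.g. i4 or i8 feeding an i16 kernel) is not
+// rejected; it is first widened to `unpromotedType` with a matching
+// arith.extsi/arith.extui so the kernel's expected operand types are met.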
-static Value getUnpromotedInput(Type unpromotedType, Type promotedType, - Value promotedResult) { +static Value getUnpromotedInput(PatternRewriter &rewriter, Type unpromotedType, + Type promotedType, Value promotedResult, + bool promoteSmallTypes = false) { VectorType promotedResultVectorType = llvm::cast(promotedResult.getType()); if (promotedResultVectorType.getElementType() != promotedType) { @@ -156,13 +253,29 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // TODO: handle promotion of floating point types. Not doing it for now as // it wouldn't be exercised. auto extSIOp = promotedResult.getDefiningOp(); - if (!extSIOp) { + auto extUIOp = promotedResult.getDefiningOp(); + if (!extSIOp && !extUIOp) { return nullptr; } - Value extInput = extSIOp.getIn(); + Value extInput = extSIOp ? extSIOp.getIn() : extUIOp.getIn(); if (llvm::cast(extInput.getType()).getElementType() != unpromotedType) { - return nullptr; + if (promoteSmallTypes) { + VectorType unpromotedVectorType = + VectorType::get(llvm::cast(extInput.getType()).getShape(), + unpromotedType); + return extSIOp + ? rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult() + : rewriter + .create(extInput.getLoc(), + unpromotedVectorType, extInput) + .getResult(); + } else { + return nullptr; + } } return extInput; } @@ -170,12 +283,28 @@ static Value getUnpromotedInput(Type unpromotedType, Type promotedType, // Helper to create a 1D, contiguous slice of a 1D vector. static Value extract1DSlice(PatternRewriter &rewriter, Location loc, VectorType dstVecType, Value input, int position) { - assert(input.getType().cast().getRank() == 1); assert(dstVecType.getRank() == 1); - std::array offsets{position}; - std::array strides{1}; - return rewriter.create( - loc, input, offsets, dstVecType.getShape(), strides); + if (input.getType().cast().getRank() == 1) { + SmallVector offsets({position}); + SmallVector strides({1}); + SmallVector sizes(dstVecType.getShape()); + return rewriter.create(loc, input, offsets, + sizes, strides); + } else { + SmallVector inputShape( + llvm::cast(input.getType()).getShape()); + assert(inputShape.back() == dstVecType.getNumElements()); + std::reverse(inputShape.begin(), inputShape.end()); + int currentPos = position; + SmallVector indices; + for (auto size : inputShape) { + indices.push_back(currentPos % size); + currentPos = currentPos / size; + } + std::reverse(indices.begin(), indices.end()); + return rewriter.create( + loc, input, SmallVector(indices.begin(), indices.end() - 1)); + } } // Helper to extract an element of a 1D vector. @@ -189,8 +318,12 @@ static Value extract(PatternRewriter &rewriter, Location loc, Value input, } // Helper to flatten a N-dimensional vector to a 1D vector. -static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { +static Value flattenImperfectSize(PatternRewriter &rewriter, Location loc, + Value vector, VectorType regVectorType) { VectorType inputVecType = llvm::cast(vector.getType()); + if (regVectorType.getNumElements() == inputVecType.getShape().back()) { + return vector; + } VectorType dstType = VectorType::get(inputVecType.getNumElements(), inputVecType.getElementType()); return rewriter.create(loc, dstType, vector); @@ -207,20 +340,31 @@ static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) { // (2) Be explicit about the size of the vectors involved in the kernel's // "calling convention". 
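+// For example, the AVX-512 VNNI kernel added below
+// (MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm) uses
+// m0 = 1, k0 = 2, n0 = 4, split0 = 16: it consumes m0*k0*split0 = 32 i16 LHS
+// elements in one 512-bit register, n0*k0*split0 = 128 i16 RHS elements in
+// four registers, and accumulates m0*n0*split0 = 64 i32 elements in four
+// registers.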
 struct MMTKernel {
-  enum class ScalarType : int8_t { None, I8, I32, F32 };
+  enum class ScalarType : int8_t { None, I4, I8, I16, I32, F32 };
   // Element type of the LHS vectors.
   ScalarType lhsType = ScalarType::None;
   // Element type of the RHS vectors.
   ScalarType rhsType = ScalarType::None;
   // Element type of the Accumulator and output vectors.
   ScalarType accType = ScalarType::None;
+  // Optional user-defined constraint codes for input and output registers.
+  // This is useful when the constraint code is not the same for all operands.
+  std::optional<SmallVector<std::string>> lhsCode = std::nullopt;
+  std::optional<SmallVector<std::string>> rhsCode = std::nullopt;
+  std::optional<SmallVector<std::string>> accCode = std::nullopt;
+  // This flag indicates whether or not to promote inputs that have a smaller
+  // bitwidth than lhsType, rhsType, or accType to the appropriate bitwidth.
+  bool promoteSmallTypes = false;
   // Number of rows of the LHS and Accumulator tile.
-  int8_t m0 = 0;
+  int16_t m0 = 0;
   // Reduction dimension, i.e. number of columns of the LHS.
-  int8_t k0 = 0;
+  int16_t k0 = 0;
   // Number of rows of the RHS (note that the operation being targeted, MMT,
   // is matrix multiplication with a *transposed* RHS)
-  int8_t n0 = 0;
+  int16_t n0 = 0;
+  // Size of the added parallel dimension when the vector.contract op has been
+  // split with splitReduction.
+  int16_t split0 = 0;
   // Number of LHS elements in the type of register to be used for the LHS.
   // This is > 1 if SIMD registers are to be used.
   // Note: LHS/RHS/Accumulator may use registers of different sizes.
@@ -236,6 +380,8 @@ struct MMTKernel {
   int8_t rhsRegs = 0;
   // Number of registers needed to hold the Accumulator.
   int8_t accRegs = 0;
+  // Indicates whether to use Intel or AT&T syntax.
+  bool useIntel = false;
   // If not null, points to the inline asm code template for this kernel.
 // Register operands for the LHS, RHS and Accumulator are to be referenced as
 // $(lhs:<i>), $(rhs:<i>), $(acc:<i>) respectively, where i is a decimal
@@ -250,9 +396,15 @@ struct MMTKernel {
   const char *asmClobbers = nullptr;
 
   void validate() const {
-    assert(m0 * k0 == lhsRegSize * lhsRegs); // number of elements of LHS
-    assert(n0 * k0 == rhsRegSize * rhsRegs); // number of elements of RHS
-    assert(m0 * n0 == accRegSize * accRegs); // number of elements of Accum
+    assert(m0 * k0 == lhsRegSize * lhsRegs ||
+           m0 * k0 * split0 ==
+               lhsRegSize * lhsRegs); // number of elements of LHS
+    assert(n0 * k0 == rhsRegSize * rhsRegs ||
+           n0 * k0 * split0 ==
+               rhsRegSize * rhsRegs); // number of elements of RHS
+    assert(m0 * n0 == accRegSize * accRegs ||
+           m0 * n0 * split0 ==
+               accRegSize * accRegs); // number of elements of Accum
     assert(lhsType != ScalarType::None);
     assert(rhsType != ScalarType::None);
     assert(accType != ScalarType::None);
@@ -674,13 +826,75 @@ MMTKernel MMTKernel_8x1x1_f32f32f32_Aarch64_Baseline_InlineAsm() {
   return kernel;
 }
 
+MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm() {
+  MMTKernel kernel;
+  kernel.lhsType = MMTKernel::ScalarType::I16;
+  kernel.rhsType = MMTKernel::ScalarType::I16;
+  kernel.accType = MMTKernel::ScalarType::I32;
+  kernel.promoteSmallTypes = true;
+  kernel.useIntel = true;
+  kernel.m0 = 1;
+  kernel.k0 = 2;
+  kernel.n0 = 4;
+  kernel.split0 = 16;
+  kernel.lhsRegSize = 32;
+  kernel.rhsRegSize = 32;
+  kernel.accRegSize = 16;
+  kernel.lhsRegs = 1;
+  kernel.rhsRegs = 4;
+  kernel.accRegs = 4;
+  kernel.asmImpl = R"ASM(
+      vpdpwssd $(acc:0), $(rhs:0), $(lhs:0)
+      vpdpwssd $(acc:1), $(rhs:1), $(lhs:0)
+      vpdpwssd $(acc:2), $(rhs:2), $(lhs:0)
+      vpdpwssd $(acc:3), $(rhs:3), $(lhs:0)
+    )ASM";
+  kernel.asmClobbers = "";
+  return kernel;
+}
+
+MMTKernel MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm() {
+  MMTKernel kernel;
+  kernel.lhsType = MMTKernel::ScalarType::I16;
+  kernel.rhsType = MMTKernel::ScalarType::I16;
+  kernel.accType = MMTKernel::ScalarType::I32;
+  kernel.promoteSmallTypes = true;
+  kernel.useIntel = true;
+  kernel.m0 = 1;
+  kernel.k0 = 2;
+  kernel.n0 = 4;
+  kernel.split0 = 16;
+  kernel.lhsRegSize = 32;
+  kernel.rhsRegSize = 32;
+  kernel.accRegSize = 16;
+  kernel.lhsRegs = 1;
+  kernel.rhsRegs = 4;
+  kernel.accRegs = 4;
+  kernel.asmImpl = R"ASM(
+      vpmaddwd zmm17, $(rhs:0), $(lhs:0)
+      vpmaddwd zmm18, $(rhs:1), $(lhs:0)
+      vpmaddwd zmm19, $(rhs:2), $(lhs:0)
+      vpmaddwd zmm20, $(rhs:3), $(lhs:0)
+      vpaddw $(acc:0), $(acc:0), zmm17
+      vpaddw $(acc:1), $(acc:1), zmm18
+      vpaddw $(acc:2), $(acc:2), zmm19
+      vpaddw $(acc:3), $(acc:3), zmm20
+    )ASM";
+  kernel.asmClobbers = "zmm17,zmm18,zmm19,zmm20";
+  return kernel;
+}
+
 // Constructs the mlir::Type corresponding to a scalar type.
 Type mlirType(MLIRContext *context, MMTKernel::ScalarType t) {
   switch (t) {
   case MMTKernel::ScalarType::None:
     break;
+  case MMTKernel::ScalarType::I4:
+    return IntegerType::get(context, 4, IntegerType::Signless);
   case MMTKernel::ScalarType::I8:
     return IntegerType::get(context, 8, IntegerType::Signless);
+  case MMTKernel::ScalarType::I16:
+    return IntegerType::get(context, 16, IntegerType::Signless);
   case MMTKernel::ScalarType::I32:
     return IntegerType::get(context, 32, IntegerType::Signless);
   case MMTKernel::ScalarType::F32:
@@ -705,7 +919,7 @@ class MMTKernelGenerator {
                               ArrayRef<Value> acc) {
     validateOperands(lhs, rhs, acc);
     if (kernel.asmImpl) {
-      return generateAsm(rewriter, loc, lhs, rhs, acc);
+      return generateAsm(rewriter, loc, lhs, rhs, acc, kernel.useIntel);
     }
     // In the future we may have alternate generator paths, e.g. 1D intrinsics
     // or other asm paths with a different interface, e.g. handling also
@@ -755,10 +969,17 @@ class MMTKernelGenerator {
     validate(acc, kernel.accRegs, getAccRegVectorType());
   }
   // Helper for generateAsmCodeAndConstraints
-  std::string getConstraintCode() const {
+  std::string
+  getConstraintCode(std::optional<std::string> kernelConstraintCode) const {
+    if (kernelConstraintCode) {
+      return std::string(*kernelConstraintCode);
+    }
     if (isAArch64(target)) {
       return "w";
     }
+    if (isX86(target)) {
+      return "v";
+    }
     assert(false && "what constraint code to use on this arch?");
     return {};
   }
@@ -820,31 +1041,39 @@ class MMTKernelGenerator {
     // processedIdx is the index of a register in the processed asm.
     // Example: $5 => processedIdx == 5
     int processedIdx = 0;
-    auto processOperands = [&](Constraints::Kind constraintKind,
-                               const char *name, int count) {
-      const std::string &constraintCode = getConstraintCode();
-      // unprocessedIdx is the index of a register in the unprocessed asm.
-      // Example: $(lhs:1) => unprocessedIdx == 1
-      for (int unprocessedIdx = 0; unprocessedIdx < count;
-           ++unprocessedIdx, ++processedIdx) {
-        constraints.add(constraintKind, constraintCode);
-        // Perform the code replacement for the operand.
-        // Example: $(lhs:1) => $5
-        replaceAllSubstrsInPlace(
-            code, llvm::formatv("$({0}:{1})", name, unprocessedIdx),
-            llvm::formatv("${0}", processedIdx));
-      }
-    };
-    processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs);
-    processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs);
-    processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs);
+    auto processOperands =
+        [&](Constraints::Kind constraintKind, const char *name, int count,
+            std::optional<SmallVector<std::string>> kernelCodes) {
+          const std::string &constraintCode = getConstraintCode(std::nullopt);
+          // unprocessedIdx is the index of a register in the unprocessed asm.
+          // Example: $(lhs:1) => unprocessedIdx == 1
+          for (int unprocessedIdx = 0; unprocessedIdx < count;
+               ++unprocessedIdx, ++processedIdx) {
+            if (kernelCodes) {
+              constraints.add(constraintKind, (*kernelCodes)[unprocessedIdx]);
+            } else {
+              constraints.add(constraintKind, constraintCode);
+            }
+            // Perform the code replacement for the operand.
+            // Example: $(lhs:1) => $5
+            replaceAllSubstrsInPlace(
+                code, llvm::formatv("$({0}:{1})", name, unprocessedIdx),
+                llvm::formatv("${0}", processedIdx));
+          }
+        };
+    processOperands(Constraints::Kind::InputOutput, "acc", kernel.accRegs,
+                    kernel.accCode);
+    processOperands(Constraints::Kind::Input, "lhs", kernel.lhsRegs,
+                    kernel.lhsCode);
+    processOperands(Constraints::Kind::Input, "rhs", kernel.rhsRegs,
+                    kernel.rhsCode);
     constraints.setClobbers(kernel.asmClobbers);
     constraintsString = constraints.toString();
   }
   // Helper for generate(). Implements the asm path.
   SmallVector<Value> generateAsm(PatternRewriter &rewriter, Location loc,
                                  ArrayRef<Value> lhs, ArrayRef<Value> rhs,
-                                 ArrayRef<Value> acc) {
+                                 ArrayRef<Value> acc, bool useIntel) {
     SmallVector<Value> inputs;
     // First the input operands. Then the input-output operands, which, as far
     // as input constraints are concerned, are *tied* inputs, i.e. refer to
@@ -864,9 +1093,13 @@ class MMTKernelGenerator {
     SmallVector<Type> outputOperandTypes(
        llvm::map_range(acc, [](Value v) { return v.getType(); }));
     auto returnType =
-        LLVM::LLVMStructType::getLiteral(context, outputOperandTypes);
+        outputOperandTypes.size() == 1
+            ? outputOperandTypes[0]
+            : LLVM::LLVMStructType::getLiteral(context, outputOperandTypes);
     auto dialectAttr =
-        LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT);
+        useIntel
+            ? LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_Intel)
+            : LLVM::AsmDialectAttr::get(context, LLVM::AsmDialect::AD_ATT);
     std::string code;
     std::string constraints;
     generateAsmCodeAndConstraints(code, constraints);
@@ -876,10 +1109,14 @@ class MMTKernelGenerator {
                        /*operand_attrs=*/ArrayAttr());
     // Extract result vectors from the asm op.
     SmallVector<Value> resVec;
-    for (int i = 0; i < kernel.accRegs; ++i) {
-      SmallVector<int64_t> position = {i};
-      resVec.push_back(
-          rewriter.create<LLVM::ExtractValueOp>(loc, asmOp.getRes(), position));
+    if (outputOperandTypes.size() == 1) {
+      resVec.push_back(asmOp.getRes());
+    } else {
+      for (int i = 0; i < kernel.accRegs; ++i) {
+        SmallVector<int64_t> position = {i};
+        resVec.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, asmOp.getRes(), position));
+      }
     }
     return resVec;
   }
@@ -914,7 +1151,9 @@ class MMTCustomKernelPattern : public OpRewritePattern<vector::ContractionOp> {
     // Check if `contractionOp` matches, and obtain the (un-promoted) input
     // LHS and RHS vectors.
     bool transposeKernel = false;
-    if (!matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0,
+    if (!matchVMT(contractionOp, kernel.m0, kernel.k0, kernel.n0, kernel.split0,
+                  &transposeKernel) &&
+        !matchMMT(contractionOp, kernel.m0, kernel.k0, kernel.n0,
                   &transposeKernel)) {
       return failure();
     }
@@ -929,9 +1168,11 @@ class MMTCustomKernelPattern : public OpRewritePattern<vector::ContractionOp> {
      return failure();
    }
    Value unpromotedLhs =
-        getUnpromotedInput(lhsElemType, accElemType, contractionOp.getLhs());
+        getUnpromotedInput(rewriter, lhsElemType, accElemType,
+                           contractionOp.getLhs(), kernel.promoteSmallTypes);
    Value unpromotedRhs =
-        getUnpromotedInput(rhsElemType, accElemType, contractionOp.getRhs());
+        getUnpromotedInput(rewriter, rhsElemType, accElemType,
+                           contractionOp.getRhs(), kernel.promoteSmallTypes);
    if (!unpromotedLhs || !unpromotedRhs) {
      return failure();
    }
@@ -953,9 +1194,23 @@ class MMTCustomKernelPattern : public OpRewritePattern<vector::ContractionOp> {
     // `contractionOp` matches, start rewriting it.
     Location loc = contractionOp.getLoc();
     // Flatten the inputs to 1D vectors.
-    Value flatLhs = flatten(rewriter, loc, unpromotedLhs);
-    Value flatRhs = flatten(rewriter, loc, unpromotedRhs);
-    Value flatAcc = flatten(rewriter, loc, contractionOp.getAcc());
+    VectorType lhsRegVectorType = generator.getLhsRegVectorType();
+    VectorType rhsRegVectorType = generator.getRhsRegVectorType();
+    VectorType accRegVectorType = generator.getAccRegVectorType();
+    Value lhs, rhs;
+    if (transposeKernel) {
+      lhs =
+          flattenImperfectSize(rewriter, loc, unpromotedLhs, rhsRegVectorType);
+      rhs =
+          flattenImperfectSize(rewriter, loc, unpromotedRhs, lhsRegVectorType);
+    } else {
+      lhs =
+          flattenImperfectSize(rewriter, loc, unpromotedLhs, lhsRegVectorType);
+      rhs =
+          flattenImperfectSize(rewriter, loc, unpromotedRhs, rhsRegVectorType);
+    }
+    Value acc = flattenImperfectSize(rewriter, loc, contractionOp.getAcc(),
+                                     accRegVectorType);
     // Slice into SIMD-register-sized 1D input vectors ready to feed to the
     // target SIMD instructions.
     auto sliceIntoRegVectors = [&](int regsCount, VectorType regVectorType,
@@ -968,17 +1223,14 @@ class MMTCustomKernelPattern : public OpRewritePattern<vector::ContractionOp> {
       }
       return regVectors;
     };
-    VectorType lhsRegVectorType = generator.getLhsRegVectorType();
-    VectorType rhsRegVectorType = generator.getRhsRegVectorType();
-    VectorType accRegVectorType = generator.getAccRegVectorType();
-    Value flatLhsForKernel = transposeKernel ? flatRhs : flatLhs;
-    Value flatRhsForKernel = transposeKernel ? flatLhs : flatRhs;
+    Value lhsForKernel = transposeKernel ? rhs : lhs;
+    Value rhsForKernel = transposeKernel ? lhs : rhs;
     SmallVector<Value> lhsRegVectors =
-        sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, flatLhsForKernel);
+        sliceIntoRegVectors(kernel.lhsRegs, lhsRegVectorType, lhsForKernel);
     SmallVector<Value> rhsRegVectors =
-        sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, flatRhsForKernel);
+        sliceIntoRegVectors(kernel.rhsRegs, rhsRegVectorType, rhsForKernel);
     SmallVector<Value> accRegVectors =
-        sliceIntoRegVectors(kernel.accRegs, accRegVectorType, flatAcc);
+        sliceIntoRegVectors(kernel.accRegs, accRegVectorType, acc);
     // Generate the kernel!
     SmallVector<Value> resRegVectors = generator.generate(
         rewriter, loc, lhsRegVectors, rhsRegVectors, accRegVectors);
@@ -1037,8 +1289,8 @@ struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_Intrinsics
       return failure();
     }
 
-    Value inLhs = getUnpromotedInput(I8Type, I32Type, lhs);
-    Value inRhs = getUnpromotedInput(I8Type, I32Type, rhs);
+    Value inLhs = getUnpromotedInput(rewriter, I8Type, I32Type, lhs);
+    Value inRhs = getUnpromotedInput(rewriter, I8Type, I32Type, rhs);
 
     if (!inLhs || !inRhs)
       return failure();
@@ -1171,6 +1423,15 @@ void populateVectorContractCustomKernelsPatterns(
       patterns.add<MMTCustomKernelPattern>(
          context, MMTKernel_8x8x8_i8i8i32_Aarch64I8mm_InlineAsm());
     }
+  } else if (isX86(target)) {
+    if (hasFeature(target, "+avx512vnni")) {
+      patterns.add<MMTCustomKernelPattern>(
+          context,
+          MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512VNNI_InlineAsm());
+    } else if (hasFeature(target, "+avx512bw")) {
+      patterns.add<MMTCustomKernelPattern>(
+          context, MMTKernel_1x2x4_split16_i16i16i32_x86_AVX512_InlineAsm());
+    }
   }
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 5ee617cddac1..50eacb872ecb 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -102,9 +102,13 @@ def FormScalarDispatches :
 }
 
 def FuseDequantizationMatmul:
-    Pass<"iree-flow-fuse-dequantization-matmul", ""> {
+    InterfacePass<"iree-flow-fuse-dequantization-matmul", "mlir::FunctionOpInterface"> {
   let summary = "Fuse dequantization and matmul linalg.generic ops";
   let constructor = "mlir::iree_compiler::IREE::Flow::createFuseDequantizationMatmulPass()";
+  let options = [
+    Option<"enableQuantizedMatmulReassociation", "enable-quantized-matmul-reassociation", "bool",
+           /*default=*/"false", "Allow reassociation of quantized matmuls (experimental)">,
+  ];
 }
 
 def CollapseDimensions :
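Note (illustrative only, not part of the patch): a minimal sketch of the kind of contraction the new matchVMT / isVectorTimesMatrixTransposed path above is intended to catch for the 1x2x4 split16 kernels, assuming k0 = 2, n0 = 4, split0 = 16 after split-reduction. SSA names are hypothetical, and the i16 extension inputs could be narrower types when promoteSmallTypes is set.

  %lhs_ext = arith.extsi %lhs : vector<2x16xi16> to vector<2x16xi32>
  %rhs_ext = arith.extsi %rhs : vector<4x2x16xi16> to vector<4x2x16xi32>
  // 1 reduction iterator (k) and 2 parallel iterators (n, split);
  // the accumulator keeps the (n, split) dimensions.
  %res = vector.contract {
      indexing_maps = [affine_map<(n, k, s) -> (k, s)>,
                       affine_map<(n, k, s) -> (n, k, s)>,
                       affine_map<(n, k, s) -> (n, s)>],
      iterator_types = ["parallel", "reduction", "parallel"],
      kind = #vector.kind<add>}
      %lhs_ext, %rhs_ext, %acc
      : vector<2x16xi32>, vector<4x2x16xi32> into vector<4x16xi32>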