Skip to content

Commit

Permalink
Various tweaks to numeric optimizations found while looking at progra…
Browse files Browse the repository at this point in the history
…ms. (iree-org#18765)

* Expands affine.apply ops at the program level. These get introduced
from various places that should possibly be eliminated at this level.
For now expanding them gets the job done, making the program
transformation friendly again.
* Fixes a multi-use issue in i64->index promotion.
* Adds a pattern to fold trunc of an index cast.
* Uses a conservative limit to bound all dynamic dims at the torch
level, even when coming to us as unbounded.
* Implements analysis interfaces on util.align.

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
  • Loading branch information
stellaraccident authored Oct 17, 2024
1 parent 8e54ed5 commit 1500641
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 38 deletions.
24 changes: 13 additions & 11 deletions compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ namespace mlir::iree_compiler::TorchInput {

namespace {

// We aribtrarily say that unbounded dimensions in a torch program cannot
// exceed 53bits, making the maximum safe dimension 9007199254740991. The
// astute reader will note that this is also the maximum safe value in
// JavaScript, which also "happens" to be the largest mantissa value in a
// 64bit double. We need a maximum and in the absence of a better choice,
// with this one we are at least in good company.
static constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;

// Torch "binds" symbolic shape information to all tensors in the program
// which are not static. It does this by emitting side-effecting
// torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops
Expand Down Expand Up @@ -95,15 +103,9 @@ class BindSymbolicShapesPass final
auto maxVal = symbolDefOp.getMaxValAttr();
if (minVal && maxVal) {
uint64_t minValInt = minVal.getValue().getZExtValue();
uint64_t maxValInt = maxVal.getValue().getZExtValue();
// Note that torch represents open ranges in strange ways with various
// magic numbers in the high range of the uint64_t type. We somewhat
// arbitrarily say that anything over a fourth of the uint64_t
// range (which is half of the positive int64_t range, should these have
// originated as signed quantities), is a ridiculously large number not
// suitable as a shape dimension, and we drop the hint.
if (maxValInt >= minValInt &&
maxValInt < std::numeric_limits<uint64_t>::max() / 4) {
uint64_t maxValInt =
std::min(maxVal.getValue().getZExtValue(), MAX_DIM_VALUE);
if (maxValInt >= minValInt) {
// Note that in Torch, min values are "weird" because they encode
// some special cases about broadcast behavior. Here we just discard
// them, but in the future, there may be more to derive here.
Expand Down Expand Up @@ -220,8 +222,8 @@ class BindSymbolicShapesPass final
for (auto [pos, symbolValue] : llvm::enumerate(symbols)) {
const SymbolInfo &symbolInfo = symbolInfos.at(symbolValue);
if (!symbolInfo.minMaxBounds) {
lowerBounds.push_back({});
upperBounds.push_back({});
lowerBounds.push_back(1);
upperBounds.push_back(MAX_DIM_VALUE);
} else {
lowerBounds.push_back(symbolInfo.minMaxBounds->first);
upperBounds.push_back(symbolInfo.minMaxBounds->second);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,16 @@ module @unsupported_non_symbolic {

// -----
// Torch uses high values to signal unbounded ranges. Ensure they are
// suppressed.
// clamped.
// CHECK-LABEL: @torch_unbounded_max_range
module @torch_unbounded_max_range {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK-NOT: util.assume.int<umin
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 4611686018427387903} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775806} : !torch.int
// CHECK: util.assume.int {{.*}}<umin = 1, umax = 9007199254740991>
// CHECK: util.assume.int {{.*}}<umin = 1, umax = 9007199254740991>
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
// CHECK: util.assume.int {{.*}}<umin = 10, umax = 90071992547409910, udiv = 10>
torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32>
return
}
Expand Down
68 changes: 54 additions & 14 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,40 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op,

namespace mlir::iree_compiler::IREE::Util {

//===----------------------------------------------------------------------===//
// util.align
//===----------------------------------------------------------------------===//

void AlignOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
auto constantAlignment = argRanges[1].getConstantValue();
// Note that for non constant alignment, there may still be something we
// want to infer, but this is left for the future.
if (constantAlignment) {
// We can align the range directly.
// (value + (alignment - 1)) & ~(alignment - 1)
// https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
APInt umin = argRanges[0].umin();
APInt umax = argRanges[0].umax();
APInt one(constantAlignment->getBitWidth(), 1);
APInt alignmentM1 = *constantAlignment - one;
APInt alignmentM1Inv = ~alignmentM1;
auto align = [&](APInt value) -> APInt {
return (value + alignmentM1) & alignmentM1Inv;
};
setResultRange(getResult(),
ConstantIntRanges::fromUnsigned(align(umin), align(umax)));
}
}

void AlignOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
SetIntDivisibilityFn setResultDivs) {
auto alignmentDiv = argDivs[1];
if (alignmentDiv.isUninitialized())
return;
setResultDivs(getResult(), alignmentDiv.getValue());
}

//===----------------------------------------------------------------------===//
// util.assume.int
//===----------------------------------------------------------------------===//
Expand All @@ -1120,39 +1154,45 @@ AssumeIntOp::getOperandAssumptions(unsigned operandIndex) {
std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) {
auto assumptions = getOperandAssumptions(operandIndex);
std::optional<uint64_t> uminUnion;
std::optional<uint64_t> umaxUnion;
uint64_t uminUnion = std::numeric_limits<uint64_t>::max();
int uminCount = 0;
uint64_t umaxUnion = std::numeric_limits<uint64_t>::min();
int umaxCount = 0;

for (auto assumption : assumptions) {
auto umin = assumption.getUmin();
auto umax = assumption.getUmax();
if (umin) {
uminUnion = std::min(
*umin, uminUnion ? *uminUnion : std::numeric_limits<uint64_t>::max());
*umin, uminUnion ? uminUnion : std::numeric_limits<uint64_t>::max());
uminCount += 1;
}
if (umax) {
umaxUnion = std::max(
*umax, umaxUnion ? *umaxUnion : std::numeric_limits<uint64_t>::min());
*umax, umaxUnion ? umaxUnion : std::numeric_limits<uint64_t>::min());
umaxCount += 1;
}
}
return std::make_pair(uminUnion, umaxUnion);
return std::make_pair(uminCount > 0 && uminCount == assumptions.size()
? std::optional<uint64_t>(uminUnion)
: std::nullopt,
umaxCount > 0 && umaxCount == assumptions.size()
? std::optional<uint64_t>(umaxUnion)
: std::nullopt);
}

// Gets the unioned divisor for an operand. If there are multiple divisor
// assumptions, the gcd of all of them is returned. If there are no
// divisor assumptions, std::nullopt is returned.
std::optional<uint64_t>
AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) {
auto assumptions = getOperandAssumptions(operandIndex);
std::optional<uint64_t> divisorUnion;
for (auto assumption : assumptions) {
auto divisor = assumption.getUdiv();
if (divisor) {
if (divisorUnion)
divisorUnion = std::gcd(*divisor, *divisorUnion);
else
divisorUnion = *divisor;
}
if (!divisor)
return std::nullopt;
if (divisorUnion)
divisorUnion = std::gcd(*divisor, *divisorUnion);
else
divisorUnion = *divisor;
}
return divisorUnion;
}
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def OpGroupAddressOffsetArithmeticOps : OpDocGroup {
let opDocGroup = OpGroupAddressOffsetArithmeticOps in {

def Util_AlignOp : Util_PureOp<"align", [
SameOperandsAndResultType
SameOperandsAndResultType,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<InferIntDivisibilityOpInterface, ["inferResultRanges"]>
]> {
let summary = "Aligns up to a power-of-two alignment if required";
let description = [{
Expand Down Expand Up @@ -504,14 +506,15 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [

// Gets the unioned unsigned range for an operand. If there are multiple
// assumptions for the operand, this will return the bounding range for
// them all. If there is no umin/umax, then std::nullopt will be returned
// for that position.
// them all. If there is no umin/umax for any row in the set, then
// std::nullopt will be returned for that position.
std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
getUnionedUnsignedRange(unsigned operandIndex);

// Gets the unioned divisor for an operand. If there are multiple divisor
// assumptions, the gcd of all of them is returned. If there are no
// divisor assumptions, std::nullopt is returned.
// divisor assumptions or if there is not a udiv for any row, std::nullopt
// is returned.
std::optional<uint64_t> getUnionedUnsignedDivisor(unsigned operandIndex);
}];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineTransforms",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithTransforms",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ iree_cc_library(
::PassesIncGen
LLVMSupport
MLIRAffineDialect
MLIRAffineTransforms
MLIRAffineUtils
MLIRAnalysis
MLIRArithDialect
MLIRArithTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -108,19 +111,35 @@ struct ConvertOpToUnsigned : public OpRewritePattern<Signed> {
// optimizations, it can be useful to eliminate them when possible.
//===----------------------------------------------------------------------===//

// Matches IR like:
// %5 = arith.addi %0, %1 : int64
// %6 = arith.index_castui %5 : int64 to index
//
// And moves the index_castui to the producer's operands:
// %3 = arith.index_castui %0 : int64 to index
// %4 = arith.index_castui %1 : int64 to index
// %5 = arith.addi %3, %4 : index
//
struct ConvertUnsignedI64IndexCastProducerToIndex
: public OpRewritePattern<arith::IndexCastUIOp> {
ConvertUnsignedI64IndexCastProducerToIndex(MLIRContext *context,
DataFlowSolver &solver)
: OpRewritePattern(context), solver(solver) {}

LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
LogicalResult matchAndRewrite(arith::IndexCastUIOp origIndexOp,
PatternRewriter &rewriter) const override {
Type inType = op.getIn().getType();
Type outType = op.getOut().getType();
Type inType = origIndexOp.getIn().getType();
Type outType = origIndexOp.getOut().getType();
if (!inType.isSignlessInteger(64) && isa<IndexType>(outType))
return failure();

Operation *producer = origIndexOp.getIn().getDefiningOp();
if (!producer)
return failure();
auto producerResult = producer->getResult(0);
if (!producerResult.hasOneUse())
return failure();

auto pred = [&](Value v) -> bool {
auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
if (!result || result->getValue().isUninitialized()) {
Expand All @@ -137,14 +156,14 @@ struct ConvertUnsignedI64IndexCastProducerToIndex
llvm::all_of(op->getResults(), pred);
};

Operation *producer = op.getIn().getDefiningOp();
if (!isa_and_present<arith::AddIOp, arith::CeilDivUIOp, arith::DivUIOp,
arith::MaxUIOp, arith::MinUIOp, arith::MulIOp,
arith::RemUIOp, arith::SubIOp>(producer))
return failure();
if (!isOpStaticallyLegal(producer))
return failure();

// Make modifications.
rewriter.modifyOpInPlace(producer, [&]() {
rewriter.setInsertionPoint(producer);
for (auto &operand : producer->getOpOperands()) {
Expand All @@ -156,6 +175,8 @@ struct ConvertUnsignedI64IndexCastProducerToIndex
}
producer->getResult(0).setType(outType);
});
origIndexOp.getOut().replaceAllUsesWith(producer->getResult(0));
rewriter.eraseOp(origIndexOp);

return success();
}
Expand Down Expand Up @@ -206,6 +227,52 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern<arith::RemUIOp> {
DataFlowSolver &solver;
};

//===----------------------------------------------------------------------===//
// Affine expansion
// affine.apply expansion can fail after producing a lot of IR. Since this is
// a bad thing to be doing as part of our overall iteration, we do it as a
// preprocessing walk. This also lets it be well behaved with respect to
// error messaging, etc. We will likely replace this with a more integrated
// version at some point which can use the bounds analysis to avoid corners
// of the original.
//===----------------------------------------------------------------------===//

void expandAffineOps(Operation *rootOp) {
IRRewriter rewriter(rootOp->getContext());
rootOp->walk([&](affine::AffineApplyOp op) {
LLVM_DEBUG(dbgs() << "** Expand affine.apply: " << op << "\n");
rewriter.setInsertionPoint(op);
auto maybeExpanded =
mlir::affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
llvm::to_vector<4>(op.getOperands()));
if (!maybeExpanded) {
LLVM_DEBUG(dbgs() << "** ERROR: Failed to expand affine.apply\n");
return;
}
rewriter.replaceOp(op, *maybeExpanded);
});
}

//===----------------------------------------------------------------------===//
// General optimization patterns
//===----------------------------------------------------------------------===//

struct ElideTruncOfIndexCast : public OpRewritePattern<arith::TruncIOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
PatternRewriter &rewriter) const override {
Operation *producer = truncOp.getOperand().getDefiningOp();
if (!producer)
return failure();
if (!isa<arith::IndexCastOp, arith::IndexCastUIOp>(producer))
return failure();
rewriter.replaceOpWithNewOp<arith::IndexCastUIOp>(
truncOp, truncOp.getResult().getType(), producer->getOperand(0));
return success();
}
};

//===----------------------------------------------------------------------===//
// Pass setup
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -270,6 +337,9 @@ class OptimizeIntArithmeticPass
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();

expandAffineOps(op);

DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<IntegerRangeAnalysis>();
Expand All @@ -281,13 +351,15 @@ class OptimizeIntArithmeticPass
arith::populateIntRangeOptimizationsPatterns(patterns, solver);

// Populate canonicalization patterns.
auto arithDialectTypeID =
ctx->getOrLoadDialect<arith::ArithDialect>()->getTypeID();
auto arithDialect = ctx->getOrLoadDialect<arith::ArithDialect>();
for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) {
if (name.getDialect().getTypeID() == arithDialectTypeID)
if (&name.getDialect() == arithDialect)
name.getCanonicalizationPatterns(patterns, ctx);
}

// General optimization patterns.
patterns.add<ElideTruncOfIndexCast>(ctx);

// Populate unsigned conversion patterns.
patterns.add<ConvertUnsignedI64IndexCastProducerToIndex,
ConvertOpToUnsigned<arith::CeilDivSIOp, arith::CeilDivUIOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,14 @@ util.func @remui_div_by_unrelated(%arg0 : index) -> index {
%1 = arith.remui %0, %cst : index
util.return %1 : index
}

// -----
// A missing udiv in a multi-row assumption is treated as an unknown.
// CHECK-LABEL: @missing_udiv_skipped
util.func @missing_udiv_skipped(%arg0 : index) -> index {
// CHECK: arith.remui
%cst = arith.constant 16 : index
%0 = util.assume.int %arg0[<udiv = 16>, <>] : index
%1 = arith.remui %0, %cst : index
util.return %1 : index
}
Loading

0 comments on commit 1500641

Please sign in to comment.