From 0f28d441bc07f137fea23fcc516463e461e02d93 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Tue, 8 Oct 2024 13:20:59 -0700
Subject: [PATCH] Use upstream dataflow tooling to build an arithmetic opt
 pass. (#18702)

This combines several things into one fixpoint iteration:

* Upstream IntRangeOptimizations for taking care of things like constant
  replacement for unit ranges.
* Arith canonicalizations.
* Local adaptation of signed -> unsigned conversion (upstream's version
  can't compose since it is based on dialect conversion). It also has
  32-bit bugs that have been corrected locally.
* Int64/unsigned index conversion.
* Common factor elision for integer division.
* Making the util.assume ops implement InferIntRangeInterface.

I have some additional advanced patterns on the side which simplify a lot
of Torch cases, but they need some more baking/testing, so I'm just
landing the basic pass for now to start.

---------

Signed-off-by: Stella Laurenzo
---
 .../bazel_to_cmake/bazel_to_cmake_targets.py  |   1 +
 .../iree/compiler/Dialect/Util/IR/BUILD.bazel |   1 +
 .../iree/compiler/Dialect/Util/IR/UtilOps.cpp |  74 ++++
 .../iree/compiler/Dialect/Util/IR/UtilOps.h   |   1 +
 .../iree/compiler/Dialect/Util/IR/UtilOps.td  |  22 +-
 .../Dialect/Util/Transforms/BUILD.bazel       |   1 +
 .../Dialect/Util/Transforms/CMakeLists.txt    |   1 +
 .../Util/Transforms/OptimizeIntArithmetic.cpp | 291 ++++++++++++++
 .../compiler/Dialect/Util/Transforms/Passes.h |   1 +
 .../Dialect/Util/Transforms/Passes.td         |   7 +
 .../Dialect/Util/Transforms/test/BUILD.bazel  |   1 +
 .../Util/Transforms/test/CMakeLists.txt       |   1 +
 .../test/optimize_int_arithmetic.mlir         | 361 ++++++++++++++++++
 13 files changed, 762 insertions(+), 1 deletion(-)
 create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
 create mode 100644 compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir

diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 2f3f057e3a7f..cecc21777f5f 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -58,6 +58,7 @@ def __init__(self, repo_map: Dict[str, str]):
         ],
         # MLIR
         "@llvm-project//mlir:AllPassesAndDialects": ["MLIRAllDialects"],
+        "@llvm-project//mlir:ArithOpsIncGen": ["MLIRArithDialect"],
         "@llvm-project//mlir:BufferizationInterfaces": [""],
         "@llvm-project//mlir:CommonFolders": [""],
         "@llvm-project//mlir:ConversionPasses": [""],
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
index da0d9ae71189..ce3726e0dac3 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/BUILD.bazel
@@ -32,6 +32,7 @@ iree_td_library(
         "@llvm-project//mlir:CallInterfacesTdFiles",
         "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
         "@llvm-project//mlir:FunctionInterfacesTdFiles",
+        "@llvm-project//mlir:InferIntRangeInterfaceTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
         "@llvm-project//mlir:OpBaseTdFiles",
         "@llvm-project//mlir:SideEffectInterfacesTdFiles",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 73230ac1418f..201ce8ac5a90 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -8,6 +8,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include 
"llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/SMLoc.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" @@ -22,6 +23,8 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include + namespace mlir::iree_compiler { //===----------------------------------------------------------------------===// @@ -1102,6 +1105,77 @@ namespace mlir::iree_compiler::IREE::Util { // util.assume.int //===----------------------------------------------------------------------===// +SmallVector +AssumeIntOp::getOperandAssumptions(unsigned operandIndex) { + assert(operandIndex < getNumOperands() && + "getUnionedUnsignedRange operand out of range"); + auto assumptions = cast(getAssumptions()[operandIndex]); + SmallVector results; + for (auto assumption : assumptions) { + results.push_back(cast(assumption)); + } + return results; +} + +std::pair, std::optional> +AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) { + auto assumptions = getOperandAssumptions(operandIndex); + std::optional uminUnion; + std::optional umaxUnion; + + for (auto assumption : assumptions) { + auto umin = assumption.getUmin(); + auto umax = assumption.getUmax(); + if (umin) { + uminUnion = std::min( + *umin, uminUnion ? *uminUnion : std::numeric_limits::max()); + } + if (umax) { + umaxUnion = std::max( + *umax, umaxUnion ? *umaxUnion : std::numeric_limits::min()); + } + } + return std::make_pair(uminUnion, umaxUnion); +} + +// Gets the unioned divisor for an operand. If there are multiple divisor +// assumptions, the gcd of all of them is returned. If there are no +// divisor assumptions, std::nullopt is returned. +std::optional AssumeIntOp::getUnionedDivisor(unsigned operandIndex) { + auto assumptions = getOperandAssumptions(operandIndex); + std::optional divisorUnion; + for (auto assumption : assumptions) { + auto divisor = assumption.getDivisor(); + if (divisor) { + if (divisorUnion) + divisorUnion = std::gcd(*divisor, *divisorUnion); + else + divisorUnion = *divisor; + } + } + return divisorUnion; +} + +void AssumeIntOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + for (auto [index, result] : llvm::enumerate(getResults())) { + Type type = result.getType(); + unsigned bitWidth; + if (isa(type)) + bitWidth = 64; + else if (auto intType = dyn_cast(type)) + bitWidth = intType.getWidth(); + else + continue; + auto [umin, umax] = getUnionedUnsignedRange(index); + if (umin && umax) { + APInt uminAp(bitWidth, *umin); + APInt umaxAp(bitWidth, *umax); + setResultRange(result, ConstantIntRanges::fromUnsigned(uminAp, umaxAp)); + } + } +} + void AssumeIntOp::build(OpBuilder &builder, OperationState &state, Value singleOperand, IntAssumptionAttr singleAssumption) { diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h index 4a58a1eb2252..1623b8e9a5bb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.h @@ -19,6 +19,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td index 
2b62cd2d5455..3a7179b16412 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/FunctionInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -458,7 +459,9 @@ def OpGroupCompilerHintOps : OpDocGroup {
 
 let opDocGroup = OpGroupCompilerHintOps in {
 
-def Util_AssumeIntOp : Util_PureOp<"assume.int", []> {
+def Util_AssumeIntOp : Util_PureOp<"assume.int", [
+  DeclareOpInterfaceMethods<InferIntRangeInterface>
+]> {
   let summary = "memorializes assumptions about index/integer values.";
   let description = [{
     This op is used to memorialize the result of some integer analysis or
@@ -490,6 +493,23 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", []> {
     )>,
   ];
 
+  let extraClassDeclaration = [{
+    // Gets the list of assumptions for an operand.
+    SmallVector<IntAssumptionAttr> getOperandAssumptions(unsigned operandIndex);
+
+    // Gets the unioned unsigned range for an operand. If there are multiple
+    // assumptions for the operand, this will return the bounding range for
+    // them all. If there is no umin/umax, then std::nullopt will be returned
+    // for that position.
+    std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
+    getUnionedUnsignedRange(unsigned operandIndex);
+
+    // Gets the unioned divisor for an operand. If there are multiple divisor
+    // assumptions, the gcd of all of them is returned. If there are no
+    // divisor assumptions, std::nullopt is returned.
+    std::optional<uint64_t> getUnionedDivisor(unsigned operandIndex);
+  }];
+
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index eb10119c2c83..a98135d57bdb 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -26,6 +26,7 @@ iree_compiler_cc_library(
         "HoistIntoGlobals.cpp",
         "IPO.cpp",
         "ImportResources.cpp",
+        "OptimizeIntArithmetic.cpp",
         "PassDetail.h",
         "Passes.cpp",
         "Patterns.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index acb7a4e169d9..d233f11e0278 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ iree_cc_library(
     "HoistIntoGlobals.cpp"
     "IPO.cpp"
     "ImportResources.cpp"
+    "OptimizeIntArithmetic.cpp"
    "PassDetail.h"
    "Passes.cpp"
    "Patterns.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
new file mode 100644
index 000000000000..04ae8b707baf
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
@@ -0,0 +1,291 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-util-optimize-arithmetic"
+using llvm::dbgs;
+
+using namespace mlir::dataflow;
+
+namespace mlir::iree_compiler::IREE::Util {
+
+namespace {
+
+// An index_cast from i64 to index is a no-op on targets where index is
+// 64 bits, but on targets where index is 32 bits, it is a truncate. On those
+// platforms, demoting to an index is only conservatively correct if all
+// operands and all results are within the unsigned 32-bit bounds.
+// While there is a good chance that arithmetic exceeding these bounds is
+// simply wrong/overflow-ridden, we opt to do no harm and preserve the exact
+// results. This optimization is targeted at "small" sequences anyway, and
+// this catches everything known to exist. This rule could be dropped if it
+// is ever appropriate to unconditionally assume 64-bit semantics.
+static constexpr uint64_t SAFE_INDEX_UNSIGNED_MAX_VALUE =
+    std::numeric_limits<uint32_t>::max();
+
+//===----------------------------------------------------------------------===//
+// Signed -> Unsigned patterns
+// Note that there is an upstream UnsignedWhenEquivalent pass, but it uses
+// DialectConversion and legality instead of simple patterns, so we cannot
+// use it. Some support code has been adapted from that pass, though.
+//===----------------------------------------------------------------------===//
+
+/// Returns true when a value is statically non-negative: it has a known lower
+/// bound when treated as signed, and that bound is non-negative.
+static bool staticallyLegalToConvertToUnsigned(DataFlowSolver &solver,
+                                               Value v) {
+  auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
+  if (!result || result->getValue().isUninitialized()) {
+    return false;
+  }
+  const ConstantIntRanges &range = result->getValue().getValue();
+  bool isNonNegative = range.smin().isNonNegative();
+  Type type = v.getType();
+  if (isa<IndexType>(type)) {
+    bool canSafelyTruncate =
+        range.umin().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE &&
+        range.umax().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE;
+    return isNonNegative && canSafelyTruncate;
+  } else {
+    return isNonNegative;
+  }
+}
+
+/// Succeeds if an op can be converted to its unsigned equivalent without
+/// changing its semantics. This is the case when none of its operands or
+/// results can be below 0 when analyzed from a signed perspective.
+static LogicalResult
+staticallyLegalToConvertToUnsignedOp(DataFlowSolver &solver, Operation *op) {
+  auto nonNegativePred = [&solver](Value v) -> bool {
+    return staticallyLegalToConvertToUnsigned(solver, v);
+  };
+  return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
+                 llvm::all_of(op->getResults(), nonNegativePred));
+}
+
+template <typename Signed, typename Unsigned>
+struct ConvertOpToUnsigned : public OpRewritePattern<Signed> {
+  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &solver)
+      : OpRewritePattern<Signed>(context), solver(solver) {}
+
+  LogicalResult matchAndRewrite(Signed op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(staticallyLegalToConvertToUnsignedOp(solver, op)))
+      return failure();
+    rewriter.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
+                                          op->getOperands(), op->getAttrs());
+    return success();
+  }
+
+  DataFlowSolver &solver;
+};
+
+//===----------------------------------------------------------------------===//
+// Int64 -> unsigned index demotion
+// Torch does a lot of indexy manipulation using scalar i64 ops. We undo these
+// here and treat them as index when it is safe to do so. Since the casts can
+// block optimizations, it can be useful to eliminate them when possible.
+//===----------------------------------------------------------------------===//
+
+struct ConvertUnsignedI64IndexCastProducerToIndex
+    : public OpRewritePattern<arith::IndexCastUIOp> {
+  ConvertUnsignedI64IndexCastProducerToIndex(MLIRContext *context,
+                                             DataFlowSolver &solver)
+      : OpRewritePattern(context), solver(solver) {}
+
+  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+                                PatternRewriter &rewriter) const override {
+    Type inType = op.getIn().getType();
+    Type outType = op.getOut().getType();
+    if (!inType.isSignlessInteger(64) && isa<IndexType>(outType))
+      return failure();
+
+    auto pred = [&](Value v) -> bool {
+      auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
+      if (!result || result->getValue().isUninitialized()) {
+        return false;
+      }
+      const ConstantIntRanges &range = result->getValue().getValue();
+      bool isInBounds =
+          range.umin().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE &&
+          range.umax().getZExtValue() <= SAFE_INDEX_UNSIGNED_MAX_VALUE;
+      return isInBounds;
+    };
+    auto isOpStaticallyLegal = [&](Operation *op) -> bool {
+      return llvm::all_of(op->getOperands(), pred) &&
+             llvm::all_of(op->getResults(), pred);
+    };
+
+    Operation *producer = op.getIn().getDefiningOp();
+    if (!isa_and_present<arith::AddIOp, arith::CeilDivUIOp, arith::DivUIOp,
+                         arith::MaxUIOp, arith::MinUIOp, arith::MulIOp,
+                         arith::RemUIOp, arith::SubIOp>(producer))
+      return failure();
+    if (!isOpStaticallyLegal(producer))
+      return failure();
+
+    rewriter.modifyOpInPlace(producer, [&]() {
+      rewriter.setInsertionPoint(producer);
+      for (auto &operand : producer->getOpOperands()) {
+        if (operand.get().getType() != inType)
+          continue;
+        Value newOperand = rewriter.create<arith::IndexCastUIOp>(
+            producer->getLoc(), outType, operand.get());
+        operand.set(newOperand);
+      }
+      producer->getResult(0).setType(outType);
+    });
+
+    return success();
+  }
+
+  DataFlowSolver &solver;
+};
+
+//===----------------------------------------------------------------------===//
+// Pass setup
+//===----------------------------------------------------------------------===//
+
+class DataFlowListener : public RewriterBase::Listener {
+public:
+  DataFlowListener(DataFlowSolver &s) : s(s) {}
+
+protected:
+  void notifyOperationErased(Operation *op) override {
+    s.eraseState(op);
+    for (Value res : op->getResults())
+      flushValue(res);
+  }
+  void notifyOperationModified(Operation *op) override {
+    for (Value res : op->getResults())
+      flushValue(res);
+  }
+  void notifyOperationReplaced(Operation *op, Operation 
*replacement) override {
+    for (Value res : op->getResults())
+      flushValue(res);
+  }
+
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override {
+    for (Value res : op->getResults())
+      flushValue(res);
+  }
+
+  void flushValue(Value value) {
+    SmallVector<Value> worklist;
+    SmallVector<Value> process;
+    worklist.push_back(value);
+
+    while (!worklist.empty()) {
+      process.clear();
+      process.swap(worklist);
+      for (Value childValue : process) {
+        auto *state = s.lookupState<IntegerValueRangeLattice>(childValue);
+        if (!state) {
+          continue;
+        }
+        s.eraseState(childValue);
+        for (auto user : childValue.getUsers()) {
+          for (Value result : user->getResults()) {
+            worklist.push_back(result);
+          }
+        }
+      }
+    }
+  }
+
+  DataFlowSolver &s;
+};
+
+class OptimizeIntArithmeticPass
+    : public OptimizeIntArithmeticBase<OptimizeIntArithmeticPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<UtilDialect>();
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    DataFlowListener listener(solver);
+    RewritePatternSet patterns(ctx);
+
+    // Populate upstream arith patterns.
+    arith::populateIntRangeOptimizationsPatterns(patterns, solver);
+
+    // Populate arith canonicalization patterns.
+    auto arithDialectTypeID =
+        ctx->getOrLoadDialect<arith::ArithDialect>()->getTypeID();
+    for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) {
+      if (name.getDialect().getTypeID() == arithDialectTypeID)
+        name.getCanonicalizationPatterns(patterns, ctx);
+    }
+
+    // Populate unsigned conversion patterns.
+    patterns.add<ConvertUnsignedI64IndexCastProducerToIndex,
+                 ConvertOpToUnsigned<arith::CeilDivSIOp, arith::CeilDivUIOp>,
+                 ConvertOpToUnsigned<arith::DivSIOp, arith::DivUIOp>,
+                 ConvertOpToUnsigned<arith::FloorDivSIOp, arith::DivUIOp>,
+                 ConvertOpToUnsigned<arith::IndexCastOp, arith::IndexCastUIOp>,
+                 ConvertOpToUnsigned<arith::RemSIOp, arith::RemUIOp>,
+                 ConvertOpToUnsigned<arith::MinSIOp, arith::MinUIOp>,
+                 ConvertOpToUnsigned<arith::MaxSIOp, arith::MaxUIOp>,
+                 ConvertOpToUnsigned<arith::ExtSIOp, arith::ExtUIOp>>(ctx,
+                                                                      solver);
+
+    GreedyRewriteConfig config;
+    // Results in fewer recursive data flow flushes/cycles on modification.
+    config.useTopDownTraversal = false;
+    config.listener = &listener;
+
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+    for (int i = 0;; ++i) {
+      if (failed(solver.initializeAndRun(op))) {
+        emitError(op->getLoc()) << "failed to perform int range analysis";
+        return signalPassFailure();
+      }
+
+      bool changed = false;
+      if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config,
+                                              &changed))) {
+        emitError(op->getLoc())
+            << "int arithmetic optimization failed to converge on iteration "
+            << i;
+        return signalPassFailure();
+      }
+
+      if (!changed)
+        break;
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<void>> createOptimizeIntArithmetic() {
+  return std::make_unique<OptimizeIntArithmeticPass>();
+}
+
+} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 0e3f66756ef5..9b5550de0625 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -30,6 +30,7 @@ createFixedPointIteratorPass(OpPassManager pipeline);
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createFoldGlobalsPass();
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createFuseGlobalsPass();
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createIPOPass();
+std::unique_ptr<OperationPass<void>> createOptimizeIntArithmetic();
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createPropagateSubrangesPass();
 std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index adec1254b4a3..390cdd48c722 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -65,6 +65,13 @@ def IPO : Pass<"iree-util-ipo", "mlir::ModuleOp"> {
   }];
 }
 
+def OptimizeIntArithmetic : Pass<"iree-util-optimize-int-arithmetic", ""> {
+  let summary = "Optimizes integer arithmetic using a variety of dataflow analyses and patterns.";
+  let constructor = [{
+    mlir::iree_compiler::IREE::Util::createOptimizeIntArithmetic()
+  }];
+}
+
 def PropagateSubranges : Pass<"iree-util-propagate-subranges", "mlir::ModuleOp"> {
   let summary = "Propagates resource subranges across the program.";
   let constructor = [{
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index ef098856f558..4a1135868c72 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -25,6 +25,7 @@ iree_lit_test_suite(
             "hoist_into_globals_linalg.mlir",
             "import_resources.mlir",
             "ipo.mlir",
+            "optimize_int_arithmetic.mlir",
             "patterns.mlir",
             "propagate_subranges.mlir",
             "simplify_global_accesses.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index f865da9c5df6..2109f2abdb53 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -23,6 +23,7 @@ iree_lit_test_suite(
     "hoist_into_globals_linalg.mlir"
     "import_resources.mlir"
    "ipo.mlir"
+    "optimize_int_arithmetic.mlir"
    "patterns.mlir"
    "propagate_subranges.mlir"
    "simplify_global_accesses.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir 
b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
new file mode 100644
index 000000000000..39ff74c2cc96
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
@@ -0,0 +1,361 @@
+// RUN: iree-opt --split-input-file --iree-util-optimize-int-arithmetic %s | FileCheck %s
+// We inherit a number of patterns from upstream for optimizing specific arith
+// operations. Those are not the focus of testing, but we may test some of them
+// here incidentally as part of verifying that the overall pass and local
+// patterns are effective.
+// Many of these tests take advantage of the fact that if a value can be
+// inferred for arith.cmpi, a constant i1 will be substituted for it.
+
+// CHECK-LABEL: @index_upper_bound
+util.func @index_upper_bound(%arg0 : index) -> i1 {
+  // CHECK: %[[RESULT:.*]] = arith.constant true
+  // CHECK: util.return %[[RESULT]]
+  %cst = arith.constant 101 : index
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  %1 = arith.cmpi ult, %0, %cst : index
+  util.return %1 : i1
+}
+
+// -----
+// CHECK-LABEL: @index_lower_bound
+util.func @index_lower_bound(%arg0 : index) -> i1 {
+  // CHECK: %[[RESULT:.*]] = arith.constant true
+  // CHECK: util.return %[[RESULT]]
+  %cst = arith.constant 5 : index
+  %0 = util.assume.int %arg0<umin = 6, umax = 100> : index
+  %1 = arith.cmpi ugt, %0, %cst : index
+  util.return %1 : i1
+}
+
+// -----
+// CHECK-LABEL: @index_indeterminate
+util.func @index_indeterminate(%arg0 : index) -> i1 {
+  // CHECK: arith.cmpi
+  %cst = arith.constant 50 : index
+  %0 = util.assume.int %arg0<umin = 10, umax = 100> : index
+  %1 = arith.cmpi ugt, %0, %cst : index
+  util.return %1 : i1
+}
+
+// -----
+// CHECK-LABEL: @index_multi_assumptions_unioned
+util.func @index_multi_assumptions_unioned(%arg0 : index) -> i1, i1, i1 {
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  // CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+  // CHECK-DAG: %[[C51:.*]] = arith.constant 51
+  %cst5 = arith.constant 5 : index
+  %cst51 = arith.constant 51 : index
+  %cst101 = arith.constant 101 : index
+  %0 = util.assume.int %arg0[
+    <umin = 5, umax = 50>,
+    <umin = 51, umax = 100>
+  ] : index
+  %1 = arith.cmpi ult, %0, %cst5 : index    // Statically false
+  // CHECK: %[[DYNAMIC:.*]] = arith.cmpi ult, {{.*}}, %[[C51]]
+  %2 = arith.cmpi ult, %0, %cst51 : index   // Cannot be determined
+  %3 = arith.cmpi ult, %0, %cst101 : index  // Statically true
+  // CHECK: return %[[FALSE]], %[[DYNAMIC]], %[[TRUE]]
+  util.return %1, %2, %3 : i1, i1, i1
+}
+
+// -----
+// This checks a corner case that has to line up with how util.assume.int
+// signals its range to the int range analysis. Here, if the umax is
+// interpreted as a signed value (which is what is used for evaluating an sgt
+// in the arith.cmpi op), it poisons the analysis by widening the signed range
+// to the entire signed range of the data type. This means that any signed
+// evaluation will fail, whereas an unsigned one will succeed.
+// CHECK-LABEL: @index_unsigned_overflow_signed
+util.func @index_unsigned_overflow_signed(%arg0 : index) -> i1, i1 {
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  %cst = arith.constant 5 : index
+  %0 = util.assume.int %arg0<umin = 10, umax = 9223372036854775808> : index
+  // CHECK: %[[POISON:.*]] = arith.cmpi sgt
+  %1 = arith.cmpi sgt, %0, %cst : index
+  %2 = arith.cmpi ugt, %0, %cst : index
+  // CHECK: util.return %[[POISON]], %[[TRUE]]
+  util.return %1, %2 : i1, i1
+}
+
+// -----
+// Minimal testing to ensure that integer data types < 64 bits do the right
+// thing. This exercises some APInt bit manipulation in our interface
+// implementations.
+// CHECK-LABEL: @int_upper_bound
+util.func @int_upper_bound(%arg0 : i32) -> i1 {
+  // CHECK: %[[RESULT:.*]] = arith.constant true
+  // CHECK: util.return %[[RESULT]]
+  %cst = arith.constant 101 : i32
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : i32
+  %1 = arith.cmpi ult, %0, %cst : i32
+  util.return %1 : i1
+}
+
+// -----
+// Validate the signed/unsigned mismatch corner case on a type narrower than
+// 64 bits.
+// CHECK-LABEL: @int32_unsigned_overflow_signed
+util.func @int32_unsigned_overflow_signed(%arg0 : i32) -> i1, i1 {
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  %cst = arith.constant 5 : i32
+  // Max is one greater than the i32 signed positive range.
+  %0 = util.assume.int %arg0<umin = 10, umax = 2147483648> : i32
+  // CHECK: %[[POISON:.*]] = arith.cmpi sgt
+  %1 = arith.cmpi sgt, %0, %cst : i32
+  %2 = arith.cmpi ugt, %0, %cst : i32
+  // CHECK: util.return %[[POISON]], %[[TRUE]]
+  util.return %1, %2 : i1, i1
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_ceildivsi
+util.func @to_unsigned_ceildivsi(%arg0 : i64, %arg1 : i64) -> i64, i64, i64 {
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : i64
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 1, umax = 9223372036854775808> : i64
+  // CHECK: ceildivui
+  // CHECK: ceildivsi
+  // CHECK: ceildivsi
+  %2 = arith.ceildivsi %0, %0 : i64
+  %3 = arith.ceildivsi %0, %1 : i64
+  %4 = arith.ceildivsi %1, %0 : i64
+  util.return %2, %3, %4 : i64, i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_divsi
+util.func @to_unsigned_divsi(%arg0 : i64, %arg1 : i64) -> i64, i64, i64 {
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : i64
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 1, umax = 9223372036854775808> : i64
+  // CHECK: divui
+  // CHECK: divsi
+  // CHECK: divsi
+  %2 = arith.divsi %0, %0 : i64
+  %3 = arith.divsi %0, %1 : i64
+  %4 = arith.divsi %1, %0 : i64
+  util.return %2, %3, %4 : i64, i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_floordivsi
+util.func @to_unsigned_floordivsi(%arg0 : i64, %arg1 : i64) -> i64, i64, i64 {
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : i64
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 1, umax = 9223372036854775808> : i64
+  // CHECK: divui
+  // CHECK: divsi
+  // CHECK: divsi
+  %2 = arith.floordivsi %0, %0 : i64
+  %3 = arith.floordivsi %0, %1 : i64
+  %4 = arith.floordivsi %1, %0 : i64
+  util.return %2, %3, %4 : i64, i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_index_cast
+util.func @to_unsigned_index_cast(%arg0 : index, %arg1 : index) -> i64, i64 {
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 0, umax = 9223372036854775808> : index
+  // CHECK: index_castui
+  %2 = arith.index_cast %0 : index to i64
+  // CHECK: index_cast
+  %3 = arith.index_cast %1 : index to i64
+  util.return %2, %3 : i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_remsi
+util.func @to_unsigned_remsi(%arg0 : i64, %arg1 : i64) -> i64, i64, i64 {
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : i64
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 1, umax = 9223372036854775808> : i64
+  // CHECK: remui
+  // CHECK: remsi
+  // CHECK: remsi
+  %2 = arith.remsi %0, %0 : i64
+  %3 = arith.remsi %0, %1 : i64
+  %4 = arith.remsi %1, %0 : i64
+  util.return %2, %3, %4 : i64, i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_minsi
+util.func @to_unsigned_minsi(%arg0 : i64, %arg1 : i64) -> i64, i64, i64 {
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : i64
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 0, umax = 9223372036854775808> : i64
+  // Note that the first is converted to unsigned and then can be elided
+  // entirely.
+  // CHECK-NOT: minui
+  // CHECK: minsi
+  // CHECK: minsi
+  %2 = arith.minsi %0, %0 : i64
+  %3 = arith.minsi %0, %1 : i64
+  %4 = arith.minsi %1, %0 : i64
+  util.return %2, %3, %4 : i64, i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_maxsi
+util.func @to_unsigned_maxsi(%arg0 : i64, %arg1 : i64) -> i64, i64, i64 {
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : i64
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 0, umax = 9223372036854775808> : i64
+  // Note that the first is converted to unsigned and then can be elided
+  // entirely.
+  // CHECK-NOT: maxui
+  // CHECK: maxsi
+  // CHECK: maxsi
+  %2 = arith.maxsi %0, %0 : i64
+  %3 = arith.maxsi %0, %1 : i64
+  %4 = arith.maxsi %1, %0 : i64
+  util.return %2, %3, %4 : i64, i64, i64
+}
+
+// -----
+// CHECK-LABEL: @to_unsigned_extsi
+util.func @to_unsigned_extsi(%arg0 : i32, %arg1 : i32) -> i64, i64 {
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : i32
+  // One greater than the signed maximum.
+  %1 = util.assume.int %arg1<umin = 0, umax = 2147483648> : i32
+  // CHECK: extui
+  %2 = arith.extsi %0 : i32 to i64
+  // CHECK: extsi
+  %3 = arith.extsi %1 : i32 to i64
+  util.return %2, %3 : i64, i64
+}
+
+// -----
+// Tests the ConvertUnsignedI64IndexCastProducerToIndex pattern and its
+// composition with other patterns to collapse entire sequences of
+// index_cast (signed) -> i64 -> index_cast (signed) -> index.
+// This sequence of tests uses signed ops where they exist so as to ensure that
+// the cascade of rewrites and additional analysis composes together. This
+// specifically tests that the listener properly erases/flushes and triggers
+// additional cycles.
+// CHECK-LABEL: @index_cast_i64_to_index_addi
+util.func @index_cast_i64_to_index_addi(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // CHECK: arith.addi %[[ASSUME]], %[[ASSUME]] : index
+  %2 = arith.addi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_ceildivsi
+util.func @index_cast_i64_to_index_ceildivsi(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // CHECK: arith.ceildivui %[[ASSUME]], %[[ASSUME]] : index
+  %2 = arith.ceildivsi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_floordivsi
+util.func @index_cast_i64_to_index_floordivsi(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // CHECK: arith.divui %[[ASSUME]], %[[ASSUME]] : index
+  %2 = arith.floordivsi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_maxsi
+util.func @index_cast_i64_to_index_maxsi(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // Note that the entire sequence is inferred to be removed.
+  // CHECK: util.return %[[ASSUME]]
+  %2 = arith.maxsi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_minsi
+util.func @index_cast_i64_to_index_minsi(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // Note that the entire sequence is inferred to be removed.
+  // CHECK: util.return %[[ASSUME]]
+  %2 = arith.minsi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_muli
+util.func @index_cast_i64_to_index_muli(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // CHECK: arith.muli %[[ASSUME]], %[[ASSUME]] : index
+  %2 = arith.muli %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_remsi
+util.func @index_cast_i64_to_index_remsi(%arg0 : index, %arg1 : index) -> index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 1, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // CHECK: arith.remui %[[ASSUME]], %[[ASSUME]] : index
+  %2 = arith.remsi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_subi
+util.func @index_cast_i64_to_index_subi(%arg0 : index, %arg1 : index) -> index {
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  %1 = arith.index_cast %0 : index to i64
+  // Note that the subtraction should be inferred as elided.
+  // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK: util.return %[[ZERO]]
+  %2 = arith.subi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// CHECK-LABEL: @index_cast_i64_to_index_addi_bad_signed_bounds
+util.func @index_cast_i64_to_index_addi_bad_signed_bounds(%arg0 : i64) -> index {
+  %cst1 = arith.constant 1 : i64
+  // Maximum 32-bit unsigned value; the +1 from the addi should reject.
+  %0 = util.assume.int %arg0<umin = 0, umax = 4294967295> : i64
+  %2 = arith.addi %0, %cst1 : i64
+  // Out of bounds of conservative 32-bit values, so do not convert.
+  // CHECK: arith.addi
+  // CHECK: arith.index_castui
+  %3 = arith.index_castui %2 : i64 to index
+  util.return %3 : index
+}
+
+// -----
+// Validate the index unsigned 32-bit overflow case.
+// CHECK-LABEL: @index_unsigned_overflow_signed
+util.func @index_unsigned_overflow_signed(%arg0 : index) -> index {
+  %cst = arith.constant 5 : index
+  // Max is one greater than the i32 unsigned range.
+  %0 = util.assume.int %arg0<umin = 0, umax = 4294967296> : index
+  // Should not convert to unsigned.
+  // CHECK: arith.divsi
+  %1 = arith.divsi %0, %cst : index
+  util.return %1 : index
+}
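
For anyone wanting to try the pass locally, here is a minimal end-to-end sketch (not part of the patch; the function name and the specific bounds are made up for illustration, but the op syntax and flag match the tests above). Running it through `iree-opt --iree-util-optimize-int-arithmetic` exercises the full fixpoint: the range implied by `util.assume.int` feeds the dataflow analysis, the signed division is rewritten to its unsigned form, and the provable comparison folds to a constant.

util.func @example(%arg0 : index) -> index, i1 {
  %c3 = arith.constant 3 : index
  %c101 = arith.constant 101 : index
  // The analysis learns that %0 lies in [1, 100] unsigned.
  %0 = util.assume.int %arg0<umin = 1, umax = 100> : index
  // Non-negative and within the conservative 32-bit bounds: becomes arith.divui.
  %1 = arith.divsi %0, %c3 : index
  // Provably true: folds to arith.constant true.
  %2 = arith.cmpi ult, %0, %c101 : index
  util.return %1, %2 : index, i1
}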