Pass pipeline for lowering quantized type/ops for evaluation
sdasgup3 committed Jun 8, 2024
1 parent 0ead75b commit 88c639f
Showing 8 changed files with 257 additions and 6 deletions.
5 changes: 5 additions & 0 deletions docs/generated/stablehlo_passes.md
@@ -12,6 +12,11 @@ An experimental pass that legalizes shape-related ops to StableHLO ops.
Bringing shape and data computations together via an optional pass will
make it possible for the StableHLO ecosystem to potentially leverage the
compilation pipelines that use StableHLO operations to model dynamism.

#### Options
```
-legalize-constraints : Whether to legalize Cstr Ops to shape_assertion custom_call
```
### `-stablehlo-aggressive-folder`

_Folds StableHLO operations_
69 changes: 66 additions & 3 deletions stablehlo/reference/Api.cpp
@@ -68,7 +68,7 @@ FailureOr<func::FuncOp> getMainFunction(ModuleOp module, StringRef mainName) {
class DefaultInterpreterFallback : public InterpreterFallback {
public:
DefaultInterpreterFallback(const InterpreterConfiguration &config)
    : config(config) {};

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
@@ -170,19 +170,82 @@ LogicalResult removeDynamism(ModuleOp module, func::FuncOp func,
return success();
}

bool isQuantizedTypes(TypeRange types) {
return llvm::any_of(types, [](Type type) {
return isa<quant::QuantizedType>(getElementTypeOrSelf(type));
});
}

// Recursively checks if an operation or any of its nested operations use
// quantized types.
//
// Args:
// op: The operation to check for quantized type usage.
//
// Returns:
// True if the operation or any nested operation uses quantized types,
// false otherwise.
bool operationUsesQuantType(Operation &op) {
for (Region &region : op.getRegions()) {
for (Block &block : region) {
for (Operation &op : block) {
TypeRange operandTypes = op.getOperandTypes();
TypeRange resultTypes = op.getResultTypes();
if (isQuantizedTypes(operandTypes) || isQuantizedTypes(resultTypes) ||
operationUsesQuantType(op)) {
return true;
}
}
}
}
return false;
}

// Lowers quantization-related operations and types within a function to
// primitive math operations.
//
// This function checks if a function uses quantized types in its inputs,
// outputs, or internal operations. If so, it creates and runs a StableHLO
// quantization lowering pipeline to transform those quantized constructs into
// primitive math operations. If the lowering process fails, an error is
// emitted.
//
// Args:
// module: The module containing the function `func`.
// func: The function to lower quantized types/operations in.
//
// Returns:
// A `LogicalResult` indicating success or failure of the lowering process.
LogicalResult lowerQuantization(ModuleOp module, func::FuncOp func) {
if (!(isQuantizedTypes(func.getFunctionType().getInputs()) ||
operationUsesQuantType(*func.getOperation()) ||
isQuantizedTypes(func.getFunctionType().getResults()))) {
return success();
}

PassManager pm(func.getContext());
stablehlo::createStablehloLowerQuantPipeline(pm);
if (failed(pm.run(module))) {
return func.emitError("Failed to lower quantized types/ops in function: ")
<< func.getName();
}
return success();
}

} // namespace

FailureOr<SmallVector<InterpreterValue>> evalModule(
ModuleOp module, ArrayRef<InterpreterValue> inputs,
const InterpreterConfiguration &config) {
// Additional error checking at main function boundary.
// This is most likely user error, where future errors during interpreting
// are more likely invalid IR or interpreter bugs.
if (module.getOps<func::FuncOp>().empty())
return SmallVector<InterpreterValue>();

auto mainFunc = getMainFunction(module, config.mainFunction);
if (failed(mainFunc) || failed(removeDynamism(module, *mainFunc, inputs)) ||
failed(lowerQuantization(module, *mainFunc)) ||
failed(validateEntrySignature(*mainFunc, inputs))) {
return failure();
}
54 changes: 54 additions & 0 deletions stablehlo/tests/interpret/quantized_ops.mlir
@@ -0,0 +1,54 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @uniform_quantize() {
%operand = stablehlo.constant dense<[4.0, 15.0]> : tensor<2xf32>
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%bitcast_result = "stablehlo.bitcast_convert"(%result) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xi8>
check.expect_eq_const %bitcast_result, dense<[10, 10]> : tensor<2xi8>
func.return
}
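For reference, the expected bytes above follow the usual per-axis uniform-quantization formula `q = round(x / scale) + zero_point`, saturated to the storage range. A minimal standalone sketch of that arithmetic (not part of this change; the helper below is hypothetical):

```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring uniform quantization:
// q = round(x / scale) + zero_point, saturated to the i8 storage range.
int8_t quantize(float x, float scale, int32_t zeroPoint) {
  int32_t q = static_cast<int32_t>(std::lround(x / scale)) + zeroPoint;
  return static_cast<int8_t>(std::clamp(q, -128, 127));
}

int main() {
  // Axis-0 parameters from the test: {scale, zero_point} = {0.1, -30} and {0.5, -20}.
  std::printf("%d %d\n", quantize(4.0f, 0.1f, -30),   // 40 - 30 = 10
              quantize(15.0f, 0.5f, -20));            // 30 - 20 = 10
}
```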

// -----

func.func @uniform_quantize() {
%operand = stablehlo.constant dense<[10, 10]> : tensor<2xi8>
%bitcast_operand = "stablehlo.bitcast_convert"(%operand) : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%result = "stablehlo.uniform_quantize"(%bitcast_operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
%bitcast_result = "stablehlo.bitcast_convert"(%result) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>) -> tensor<2xi8>
check.expect_eq_const %bitcast_result, dense<[20, 45]> : tensor<2xi8>
func.return
}
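The requantization test above can be read as a dequantize followed by a quantize; under that assumption the expected bytes work out as below (a sketch with hypothetical helpers; saturation is omitted since the values stay in range):

```
#include <cmath>
#include <cstdint>
#include <cstdio>

float dequantize(int8_t q, float scale, int32_t zeroPoint) {
  return (q - zeroPoint) * scale;
}

int8_t requantize(int8_t q, float inScale, int32_t inZp, float outScale,
                  int32_t outZp) {
  return static_cast<int8_t>(
      std::lround(dequantize(q, inScale, inZp) / outScale) + outZp);
}

int main() {
  // (10 + 30) * 0.1 = 4.0  -> round(4.0 / 0.1)  - 20 = 20
  // (10 + 20) * 0.5 = 15.0 -> round(15.0 / 0.2) - 30 = 45
  std::printf("%d %d\n", requantize(10, 0.1f, -30, 0.1f, -20),
              requantize(10, 0.5f, -20, 0.2f, -30));
}
```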

// -----

func.func @uniform_dequantize() {
%operand = stablehlo.constant dense<[10, 10]> : tensor<2xi8>
%bitcast_operand = "stablehlo.bitcast_convert"(%operand) : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%result = "stablehlo.uniform_dequantize"(%bitcast_operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
check.expect_almost_eq_const %result, dense<[4.0, 15.0]> : tensor<2xf32>
func.return
}


// -----

func.func @uniform_qdq() {
%operand = stablehlo.constant dense<[4.0, 15.0]> : tensor<2xf32>
%quantize = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
%result = "stablehlo.uniform_dequantize"(%quantize) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
check.expect_almost_eq_const %result, dense<[4.0, 15.0]> : tensor<2xf32>
func.return
}

// -----

func.func @quantized_add() {
%operand1 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
%operand2 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
%q_operand1 = "stablehlo.uniform_quantize"(%operand1) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:-30>>
%q_operand2 = "stablehlo.uniform_quantize"(%operand2) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
%result = "stablehlo.add"(%q_operand1, %q_operand2) : (tensor<2x!quant.uniform<i8:f32, 0.1:-30>>, tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
%bitcast_result = "stablehlo.bitcast_convert"(%result) : (tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2xi8>
check.expect_eq_const %bitcast_result, dense<[-12, -8]> : tensor<2xi8>
func.return
}
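Under the reference semantics, an add on quantized tensors dequantizes both operands, adds in the expressed type, and requantizes into the result type. A sketch of that arithmetic for the constants above (hypothetical helpers; the lowered integer-only math is expected to produce the same bytes):

```
#include <cmath>
#include <cstdint>
#include <cstdio>

float dequant(int8_t q, float scale, int32_t zp) { return (q - zp) * scale; }

int8_t quant(float x, float scale, int32_t zp) {
  return static_cast<int8_t>(std::lround(x / scale) + zp);
}

int main() {
  // %q_operand1 = quantize([1, 2], scale 0.1, zp -30) = [-20, -10]
  // %q_operand2 = quantize([3, 4], scale 0.5, zp -20) = [-14, -12]
  for (int i = 0; i < 2; ++i) {
    int8_t q1 = quant(1.0f + i, 0.1f, -30);
    int8_t q2 = quant(3.0f + i, 0.5f, -20);
    float sum = dequant(q1, 0.1f, -30) + dequant(q2, 0.5f, -20);  // 4.0, 6.0
    std::printf("%d ", quant(sum, 0.5f, -20));  // prints -12 then -8
  }
}
```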
1 change: 1 addition & 0 deletions stablehlo/tools/StablehloTranslateMain.cpp
@@ -184,6 +184,7 @@ TranslateFromMLIRRegistration interpretRegistration(
},
[](DialectRegistry &registry) {
registry.insert<func::FuncDialect>();
registry.insert<quant::QuantizationDialect>();
registry.insert<stablehlo::check::CheckDialect>();
registry.insert<stablehlo::interpreter::InterpreterDialect>();
registry.insert<stablehlo::StablehloDialect>();
11 changes: 11 additions & 0 deletions stablehlo/transforms/PassPipelines.cpp
@@ -13,6 +13,7 @@ limitations under the License.
==============================================================================*/

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"

@@ -36,6 +37,16 @@ void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
stablehlo::createStablehloCanonicalizeDynamismPass());
}

void createStablehloLowerQuantPipeline(OpPassManager &pm) {
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createStablehloLegalizeQuantToIntPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createShapeLegalizeToStablehloPass(true));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
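The new pipeline is only invoked programmatically in this change (see `lowerQuantization` in Api.cpp). If it were also to be exposed to opt tools, a registration along the lines of the existing entries in `registerPassPipelines()` might look like the sketch below; the pipeline name `stablehlo-lower-quant` is hypothetical and not part of this commit:

```
// Hypothetical sketch: registering the pipeline for opt tools, mirroring the
// PassPipelineRegistration calls already made in registerPassPipelines().
PassPipelineRegistration<> lowerQuantPipeline(
    "stablehlo-lower-quant",
    "Lower quantized types/ops to primitive math operations.",
    [](OpPassManager &pm) { createStablehloLowerQuantPipeline(pm); });
```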

void registerPassPipelines() {
PassPipelineRegistration<>("stablehlo-deserialize",
"Run an example pipeline.",
12 changes: 11 additions & 1 deletion stablehlo/transforms/Passes.h
@@ -83,10 +83,15 @@ void populateStablehloLegalizeDeprecatedOpsPatterns(

/// Collection of shape dialect to StableHLO patterns.
void populateShapeToStablehloPatterns(MLIRContext *context,
RewritePatternSet *patterns,
bool legalizeConstraints);

//// Additional pass constructors ////

// Legalizes from the Shape dialect to the StableHLO dialect.
std::unique_ptr<mlir::OperationPass<func::FuncOp>>
createShapeLegalizeToStablehloPass(bool legalizeConstraints);

std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes);

@@ -116,6 +121,11 @@ void createStablehloDeserializePipeline(OpPassManager &pm);
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);

// Decomposes quantized operations within a StableHLO module by applying a
// series of MLIR passes that break the quantized operations down into
// primitive math operations.
void createStablehloLowerQuantPipeline(OpPassManager &pm);

// Adds `stablehlo-deserialize` pipeline as a registered pass pipeline
// for opt tools.
void registerPassPipelines();
4 changes: 4 additions & 0 deletions stablehlo/transforms/Passes.td
@@ -125,6 +125,10 @@ def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::FuncOp"> {
compilation pipelines that use StableHLO operations to model dynamism.
}];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
let options = [
Option<"legalize_constraints_", "legalize-constraints", "bool",
/*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call">
];
}

def StablehloLegalizeDeprecatedOpsPass : Pass<"stablehlo-legalize-deprecated-ops", "func::FuncOp"> {
107 changes: 105 additions & 2 deletions stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
@@ -118,6 +118,16 @@ Value castToIndex(PatternRewriter& rewriter, Location loc, Value value) {
return cast.getResult(0);
}

void insertShapeAssertionCustomCall(OpBuilder builder, Location loc,
Value assert) {
auto customCall = builder.create<stablehlo::CustomCallOp>(loc, TypeRange{},
ValueRange{assert});
customCall.setCallTargetName("shape_assertion");
customCall.setHasSideEffect(true);
customCall->setAttr("error_message",
builder.getStringAttr("Shape assertion failed"));
}

Value maybeCastToIndex(Value result, Value value, PatternRewriter& rewriter) {
if (isShapedOfI32(result)) return value;
return castToIndex(rewriter, value.getLoc(), value);
@@ -491,6 +501,75 @@ struct ConvertTensorFromElementsPattern
}
};

struct ConvertCstrBroadcastableOp
: public OpRewritePattern<shape::CstrBroadcastableOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter& rewriter) const override {
// The way CstrBroadcastableOp is defined, its inputs must be 1D tensors or
// !shape.shape. We only support inputs of two 1D tensors.
if (op.getShapes().size() != 2) return failure();
auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front());
auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back());
if (!shape1 || !shape2) return failure();
auto tensorType1 = dyn_cast<RankedTensorType>(shape1.getType());
auto tensorType2 = dyn_cast<RankedTensorType>(shape2.getType());
if (!tensorType1 || !tensorType2) return failure();

// If the two operand shapes are of different sizes, the smaller one is
// padded with 1's from the left.
if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) {
shape1 =
padFromLeft(rewriter, op.getLoc(), shape1,
tensorType2.getDimSize(0) - tensorType1.getDimSize(0));
} else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) {
shape2 =
padFromLeft(rewriter, op.getLoc(), shape2,
tensorType1.getDimSize(0) - tensorType2.getDimSize(0));
}

// Compute if each dim is broadcastable. A dim is broadcastable iff
// dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1
int32_t rank =
std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0));
auto allOne = rewriter.create<stablehlo::ConstantOp>(
op.getLoc(), DenseIntElementsAttr::get<int32_t>(
RankedTensorType::get({rank}, rewriter.getI32Type()),
static_cast<int32_t>(1)));
Value dimSize1Is1 = rewriter.create<stablehlo::CompareOp>(
op.getLoc(), shape1, allOne, ComparisonDirection::EQ);
Value dimSize2Is1 = rewriter.create<stablehlo::CompareOp>(
op.getLoc(), shape2, allOne, ComparisonDirection::EQ);
Value eitherDimSizeIs1 =
rewriter.create<stablehlo::OrOp>(op.getLoc(), dimSize1Is1, dimSize2Is1);
Value dimSizeEq = rewriter.create<stablehlo::CompareOp>(
op.getLoc(), shape1, shape2, ComparisonDirection::EQ);
Value dimBroadcastable = rewriter.create<stablehlo::OrOp>(
op.getLoc(), eitherDimSizeIs1, dimSizeEq);

// Iterate over each dim to check that all dims are broadcastable.
auto boolType = RankedTensorType::get({1}, rewriter.getI1Type());
Value allBroadcastable = rewriter.create<stablehlo::ConstantOp>(
op.getLoc(), DenseIntElementsAttr::get<bool>(boolType, true));
for (auto i = 0; i < rank; ++i) {
Value broadcastable =
rewriter.create<SliceOp>(op.getLoc(), dimBroadcastable, i, i + 1, 1);
allBroadcastable =
rewriter.create<AndOp>(op.getLoc(), allBroadcastable, broadcastable);
}
Value allBroadcastableScalar = rewriter.create<ReshapeOp>(
op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()),
allBroadcastable);

// Add the CustomCallOp and replace the Cstr op with a const witness, which
// lets the canonicalizer remove the shape.assuming region.
insertShapeAssertionCustomCall(rewriter, op->getLoc(),
allBroadcastableScalar);
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
return success();
}
};
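The pattern above encodes the standard broadcastability rule: after left-padding the shorter shape with 1s, each dimension pair must be equal or contain a 1. A scalar reference check (a sketch of the rule the emitted StableHLO ops evaluate element-wise, not the lowering itself):

```
#include <algorithm>
#include <cstdint>
#include <vector>

// Reference predicate mirroring ConvertCstrBroadcastableOp: pad the shorter
// shape with 1s on the left, then require d1 == d2 || d1 == 1 || d2 == 1
// for every dimension pair.
bool isBroadcastable(std::vector<int64_t> s1, std::vector<int64_t> s2) {
  size_t rank = std::max(s1.size(), s2.size());
  s1.insert(s1.begin(), rank - s1.size(), 1);
  s2.insert(s2.begin(), rank - s2.size(), 1);
  for (size_t i = 0; i < rank; ++i)
    if (s1[i] != s2[i] && s1[i] != 1 && s2[i] != 1) return false;
  return true;
}
// e.g. isBroadcastable({3, 1, 5}, {4, 5}) == true; isBroadcastable({3}, {4}) == false.
```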

template <typename OpType>
struct CastOperandsPattern : public OpRewritePattern<OpType> {
using OpRewritePattern<OpType>::OpRewritePattern;
@@ -523,6 +602,12 @@ struct ShapeLegalizeToStablehloPass
ShapeLegalizeToStablehloPass> {
using ShapeLegalizeToStablehloPassBase::ShapeLegalizeToStablehloPassBase;

explicit ShapeLegalizeToStablehloPass(bool legalizeConstraints)
: impl::ShapeLegalizeToStablehloPassBase<
ShapeLegalizeToStablehloPass>::ShapeLegalizeToStablehloPassBase() {
this->legalize_constraints_ = legalizeConstraints;
}

LogicalResult initialize(MLIRContext* context) override {
// In order to make dynamic StableHLO programs compatible with HLO, we need
// to get rid of all non-StableHLO ops.
@@ -548,6 +633,10 @@ struct ShapeLegalizeToStablehloPass
// is able to remove unnecessary cruft. At the moment, this pass is a
// work in progress, so not all of these ops are supported.
//
// When legalize_constraints_ is set to true, cstr* ops are also legalized.
// A shape_assertion custom_call is used to check the constraint, and the
// shape.assuming region consumes a shape.const_witness that evaluates to
// true, so that the region can be removed later by a canonicalizer pass.
target = std::make_shared<ConversionTarget>(*context);
target->addIllegalDialect<shape::ShapeDialect>();
target->addIllegalDialect<tensor::TensorDialect>();
Expand All @@ -559,6 +648,10 @@ struct ShapeLegalizeToStablehloPass
});
target->addLegalOp<tensor::CastOp>();
target->addLegalOp<UnrealizedConversionCastOp>();
if (this->legalize_constraints_) {
target->addLegalOp<shape::ConstWitnessOp, shape::AssumingOp,
shape::AssumingYieldOp>();
}

// The patterns do what one might expect, converting between MLIR-style
// and HLO-style shape computations.
@@ -569,7 +662,8 @@ struct ShapeLegalizeToStablehloPass
// to ultimately annihilate with each other upon canonicalization if
// everything went right.
RewritePatternSet patterns_(context);
populateShapeToStablehloPatterns(context, &patterns_,
this->legalize_constraints_);
patterns = std::move(patterns_);

return success();
@@ -588,7 +682,8 @@ struct ShapeLegalizeToStablehloPass
} // namespace

void populateShapeToStablehloPatterns(MLIRContext* context,
RewritePatternSet* patterns,
bool legalizeConstraints) {
patterns->add<ConvertConstShapeOpPattern>(context);
patterns->add<ConvertMulIOpPattern>(context);
patterns->add<ConvertIndexCastOpPattern>(context);
Expand All @@ -600,6 +695,14 @@ void populateShapeToStablehloPatterns(MLIRContext* context,
patterns->add<ConvertTensorDimPattern>(context);
patterns->add<ConvertTensorExtractPattern>(context);
patterns->add<ConvertTensorFromElementsPattern>(context);
if (legalizeConstraints) {
patterns->add<ConvertCstrBroadcastableOp>(context);
}
}

std::unique_ptr<mlir::OperationPass<func::FuncOp>>
createShapeLegalizeToStablehloPass(bool legalizeConstraints) {
return std::make_unique<ShapeLegalizeToStablehloPass>(legalizeConstraints);
}

} // namespace stablehlo