Commit

remove dependency on shape assertions
sdasgup3 committed Jun 10, 2024
1 parent 64e763c commit 762f8b0
Showing 6 changed files with 7 additions and 122 deletions.
5 changes: 0 additions & 5 deletions docs/generated/stablehlo_passes.md
@@ -12,11 +12,6 @@ An experimental pass that legalizes shape-related ops to StableHLO ops.
Bringing shape and data computations together via an optional pass will
make it possible for the StableHLO ecosystem to potentially leverage the
compilation pipelines that use StableHLO operations to model dynamism.

-#### Options
-```
--legalize-constraints : Whether to legalize Cstr Ops to shape_assertion custom_call
-```
### `-stablehlo-aggressive-folder`

_Folds StableHLO operations_
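Note: before this commit, the removed option could be exercised from the command line; an illustrative invocation (assuming the repository's `stablehlo-opt` driver and standard MLIR pass-option syntax — neither is shown in this diff) would be:

```
stablehlo-opt --shape-legalize-to-stablehlo='legalize-constraints=true' input.mlir
```

After this commit the pass takes no options, and since the shape dialect stays illegal in its conversion target, `shape.cstr_*` ops presumably have to be eliminated (e.g. canonicalized away) before the pass runs.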
3 changes: 2 additions & 1 deletion stablehlo/reference/Api.cpp
@@ -68,7 +68,7 @@ FailureOr<func::FuncOp> getMainFunction(ModuleOp module, StringRef mainName) {
class DefaultInterpreterFallback : public InterpreterFallback {
public:
DefaultInterpreterFallback(const InterpreterConfiguration &config)
-      : config(config){};
+      : config(config) {};

virtual llvm::Error operator()(Operation &op, Scope &scope,
Process *process) final {
@@ -229,6 +229,7 @@ LogicalResult lowerQuantization(ModuleOp module, func::FuncOp func) {
return func.emitError("Failed to lower quantized types/ops in function: ")
<< func.getName();
}
+  module.dump();
return success();
}

3 changes: 2 additions & 1 deletion stablehlo/transforms/PassPipelines.cpp
@@ -42,8 +42,9 @@ void createStablehloLowerQuantPipeline(OpPassManager &pm) {
stablehlo::createStablehloLegalizeQuantToIntPass());
pm.addNestedPass<mlir::func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<mlir::func::FuncOp>(
-      stablehlo::createShapeLegalizeToStablehloPass(true));
+      stablehlo::createShapeLegalizeToStablehloPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}

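For orientation, a minimal sketch of driving this pipeline (assuming `createStablehloLowerQuantPipeline` is declared in `stablehlo/transforms/Passes.h` alongside the passes; illustrative, not project-blessed usage):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/transforms/Passes.h"

// Builds the quant-lowering pipeline and runs it on an already-parsed module.
mlir::LogicalResult runLowerQuantPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::stablehlo::createStablehloLowerQuantPipeline(pm);
  return pm.run(module);
}
```

As the hunk above suggests, the pipeline now canonicalizes both before and after shape legalization instead of threading a legalize-constraints flag into the pass.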
7 changes: 1 addition & 6 deletions stablehlo/transforms/Passes.h
@@ -83,15 +83,10 @@ void populateStablehloLegalizeDeprecatedOpsPatterns(

/// Collection of shape dialect to StableHLO patterns.
void populateShapeToStablehloPatterns(MLIRContext *context,
-                                      RewritePatternSet *patterns,
-                                      bool legalizeConstraints);
+                                      RewritePatternSet *patterns);

//// Additional pass constructors ////

-// Legalizes from the Shape dialect to the StableHLO dialect.
-std::unique_ptr<mlir::OperationPass<func::FuncOp>>
-createShapeLegalizeToStablehloPass(bool legalizeConstraints);
-
std::unique_ptr<OperationPass<ModuleOp>> createStablehloRefineArgumentsPass(
TypeRange refinedTypes);

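Call-site impact of the header change, as a before/after sketch (`pm` is any `OpPassManager`; the zero-argument factory is presumably the tablegen-generated declaration, which is why the hand-written one above can be deleted):

```cpp
// Before: the constraint flag was threaded through a hand-written factory.
pm.addNestedPass<mlir::func::FuncOp>(
    stablehlo::createShapeLegalizeToStablehloPass(/*legalizeConstraints=*/true));

// After: the pass has no options, so callers use the generated no-arg factory.
pm.addNestedPass<mlir::func::FuncOp>(
    stablehlo::createShapeLegalizeToStablehloPass());
```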
4 changes: 0 additions & 4 deletions stablehlo/transforms/Passes.td
@@ -125,10 +125,6 @@ def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::Fu
compilation pipelines that use StableHLO operations to model dynamism.
}];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
-  let options = [
-    Option<"legalize_constraints_", "legalize-constraints", "bool",
-           /*default=*/"false", "Whether to legalize Cstr Ops to shape_assertion custom_call">
-  ];
}

def StablehloLegalizeDeprecatedOpsPass : Pass<"stablehlo-legalize-deprecated-ops", "func::FuncOp"> {
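Removing the `let options` block also removes the pass member that MLIR's tablegen derives from it. Reconstructed from standard mlir-tblgen pass output (not copied from this repository), the generated base class held approximately:

```cpp
// Approximate shape of the tablegen-generated option member that disappears:
::mlir::Pass::Option<bool> legalize_constraints_{
    *this, "legalize-constraints",
    ::llvm::cl::desc(
        "Whether to legalize Cstr Ops to shape_assertion custom_call"),
    ::llvm::cl::init(false)};
```

Every use of `legalize_constraints_` in ShapeLegalizeToStablehlo.cpp below therefore has to go in the same commit.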
107 changes: 2 additions & 105 deletions stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
@@ -118,16 +118,6 @@ Value castToIndex(PatternRewriter& rewriter, Location loc, Value value) {
return cast.getResult(0);
}

-void insertShapeAssertionCustomCall(OpBuilder builder, Location loc,
-                                    Value assert) {
-  auto customCall = builder.create<stablehlo::CustomCallOp>(loc, TypeRange{},
-                                                            ValueRange{assert});
-  customCall.setCallTargetName("shape_assertion");
-  customCall.setHasSideEffect(true);
-  customCall->setAttr("error_message",
-                      builder.getStringAttr("Shape assertion failed"));
-}
-
Value maybeCastToIndex(Value result, Value value, PatternRewriter& rewriter) {
if (isShapedOfI32(result)) return value;
return castToIndex(rewriter, value.getLoc(), value);
@@ -501,75 +491,6 @@ struct ConvertTensorFromElementsPattern
}
};

-struct ConvertCstrBroadcastableOp
-    : public OpRewritePattern<shape::CstrBroadcastableOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
-                                PatternRewriter& rewriter) const override {
-    // The way CstrBroadcastableOp is defined, its inputs must be 1D tensors
-    // or !shape.shape. We only support inputs of two 1D tensors.
-    if (op.getShapes().size() != 2) return failure();
-    auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front());
-    auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back());
-    if (!shape1 || !shape2) return failure();
-    auto tensorType1 = dyn_cast<RankedTensorType>(shape1.getType());
-    auto tensorType2 = dyn_cast<RankedTensorType>(shape2.getType());
-    if (!tensorType1 || !tensorType2) return failure();
-
-    // If the two operand shapes are of different sizes, the smaller one is
-    // padded with 1's from the left.
-    if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) {
-      shape1 =
-          padFromLeft(rewriter, op.getLoc(), shape1,
-                      tensorType2.getDimSize(0) - tensorType1.getDimSize(0));
-    } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) {
-      shape2 =
-          padFromLeft(rewriter, op.getLoc(), shape2,
-                      tensorType1.getDimSize(0) - tensorType2.getDimSize(0));
-    }
-
-    // Compute if each dim is broadcastable. A dim is broadcastable iff
-    // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1
-    int32_t rank =
-        std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0));
-    auto allOne = rewriter.create<stablehlo::ConstantOp>(
-        op.getLoc(), DenseIntElementsAttr::get<int32_t>(
-                         RankedTensorType::get({rank}, rewriter.getI32Type()),
-                         static_cast<int32_t>(1)));
-    Value dimSize1Is1 = rewriter.create<stablehlo::CompareOp>(
-        op.getLoc(), shape1, allOne, ComparisonDirection::EQ);
-    Value dimSize2Is1 = rewriter.create<stablehlo::CompareOp>(
-        op.getLoc(), shape2, allOne, ComparisonDirection::EQ);
-    Value eitherDimSizeIs1 =
-        rewriter.create<stablehlo::OrOp>(op.getLoc(), dimSize1Is1, dimSize2Is1);
-    Value dimSizeEq = rewriter.create<stablehlo::CompareOp>(
-        op.getLoc(), shape1, shape2, ComparisonDirection::EQ);
-    Value dimBroadcastable = rewriter.create<stablehlo::OrOp>(
-        op.getLoc(), eitherDimSizeIs1, dimSizeEq);
-
-    // Iterate over each dim to check that all dims are broadcastable.
-    auto boolType = RankedTensorType::get({1}, rewriter.getI1Type());
-    Value allBroadcastable = rewriter.create<stablehlo::ConstantOp>(
-        op.getLoc(), DenseIntElementsAttr::get<bool>(boolType, true));
-    for (auto i = 0; i < rank; ++i) {
-      Value broadcastable =
-          rewriter.create<SliceOp>(op.getLoc(), dimBroadcastable, i, i + 1, 1);
-      allBroadcastable =
-          rewriter.create<AndOp>(op.getLoc(), allBroadcastable, broadcastable);
-    }
-    Value allBroadcastableScalar = rewriter.create<ReshapeOp>(
-        op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()),
-        allBroadcastable);
-
-    // Add CustomCallOp and replace Cstr op with const witness, which is useful
-    // for the canonicalizer to remove the shape.assuming region.
-    insertShapeAssertionCustomCall(rewriter, op->getLoc(),
-                                   allBroadcastableScalar);
-    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
-    return success();
-  }
-};
-
template <typename OpType>
struct CastOperandsPattern : public OpRewritePattern<OpType> {
using OpRewritePattern<OpType>::OpRewritePattern;
@@ -602,12 +523,6 @@ struct ShapeLegalizeToStablehloPass
ShapeLegalizeToStablehloPass> {
using ShapeLegalizeToStablehloPassBase::ShapeLegalizeToStablehloPassBase;

-  explicit ShapeLegalizeToStablehloPass(bool legalizeConstraints)
-      : impl::ShapeLegalizeToStablehloPassBase<
-            ShapeLegalizeToStablehloPass>::ShapeLegalizeToStablehloPassBase() {
-    this->legalize_constraints_ = legalizeConstraints;
-  }
-
LogicalResult initialize(MLIRContext* context) override {
// In order to make dynamic StableHLO programs compatible with HLO, we need
// to get rid of all non-StableHLO ops.
@@ -633,10 +548,6 @@ struct ShapeLegalizeToStablehloPass
// is able to remove unnecessary cruft. At the moment, this pass is a
// work in progress, so not all of these ops are supported.
//
-    // When legalize_constraints_ is set true, cstr* ops are also legalized.
-    // A shape_assertion custom_call is used to check the constraint, and the
-    // shape.assuming region will consume a shape.const_witness that evaluates
-    // to true, so that it can be removed later in a canonicalizer pass.
target = std::make_shared<ConversionTarget>(*context);
target->addIllegalDialect<shape::ShapeDialect>();
target->addIllegalDialect<tensor::TensorDialect>();
@@ -648,10 +559,6 @@ struct ShapeLegalizeToStablehloPass
});
target->addLegalOp<tensor::CastOp>();
target->addLegalOp<UnrealizedConversionCastOp>();
-    if (this->legalize_constraints_) {
-      target->addLegalOp<shape::ConstWitnessOp, shape::AssumingOp,
-                         shape::AssumingYieldOp>();
-    }

// The patterns do what one might expect, converting between MLIR-style
// and HLO-style shape computations.
@@ -662,8 +569,7 @@ struct ShapeLegalizeToStablehloPass
// to ultimately annihilate with each other upon canonicalization if
// everything went right.
RewritePatternSet patterns_(context);
-    populateShapeToStablehloPatterns(context, &patterns_,
-                                     this->legalize_constraints_);
+    populateShapeToStablehloPatterns(context, &patterns_);
patterns = std::move(patterns_);

return success();
@@ -682,8 +588,7 @@ struct ShapeLegalizeToStablehloPass
} // namespace

void populateShapeToStablehloPatterns(MLIRContext* context,
-                                      RewritePatternSet* patterns,
-                                      bool legalizeConstraints) {
+                                      RewritePatternSet* patterns) {
patterns->add<ConvertConstShapeOpPattern>(context);
patterns->add<ConvertMulIOpPattern>(context);
patterns->add<ConvertIndexCastOpPattern>(context);
@@ -695,14 +600,6 @@ void populateShapeToStablehloPatterns(MLIRContext* context,
patterns->add<ConvertTensorDimPattern>(context);
patterns->add<ConvertTensorExtractPattern>(context);
patterns->add<ConvertTensorFromElementsPattern>(context);
-  if (legalizeConstraints) {
-    patterns->add<ConvertCstrBroadcastableOp>(context);
-  }
}
-
-std::unique_ptr<mlir::OperationPass<func::FuncOp>>
-createShapeLegalizeToStablehloPass(bool legalizeConstraints) {
-  return std::make_unique<ShapeLegalizeToStablehloPass>(legalizeConstraints);
-}

} // namespace stablehlo
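For reference, the two deleted pieces (`insertShapeAssertionCustomCall` and `ConvertCstrBroadcastableOp`) together lowered `shape.cstr_broadcastable` into a runtime check: pad the shorter shape with 1's on the left, combine per-dimension "sizes equal, or either size is 1" tests, and feed the scalar i1 into a side-effecting `shape_assertion` custom_call. A plain-C++ model of those semantics (illustrative only; these helpers are hypothetical, not repository API):

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// True iff the two shapes are broadcast-compatible under the deleted
// pattern's rule: after left-padding with 1's, each dimension pair must be
// equal or contain a 1.
bool isBroadcastCompatible(std::vector<int64_t> s1, std::vector<int64_t> s2) {
  while (s1.size() < s2.size()) s1.insert(s1.begin(), 1);
  while (s2.size() < s1.size()) s2.insert(s2.begin(), 1);
  for (std::size_t i = 0; i < s1.size(); ++i)
    if (s1[i] != s2[i] && s1[i] != 1 && s2[i] != 1) return false;
  return true;
}

// Mirrors the shape_assertion custom_call: side-effecting, fixed message.
void shapeAssertion(bool allBroadcastable) {
  if (!allBroadcastable) throw std::runtime_error("Shape assertion failed");
}
```

For example, isBroadcastCompatible({3, 4}, {1, 3, 1}) pads the first shape to {1, 3, 4} and succeeds, while isBroadcastCompatible({2, 3}, {3, 2}) fails on both dimensions.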
