
Add onnx op LRN lowering
This commit adds support for lowering the ONNX LRN op to aten ops.
manupak committed Jun 7, 2024
1 parent d59d0b6 commit 14eabd5
Showing 4 changed files with 277 additions and 0 deletions.
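For reference, the lowering targets the LRN definition given in the ONNX operator documentation (linked from the new code). It is restated here in LaTeX for readability; the spec itself remains the authoritative statement:

squareSum[n, c, d_1, \dots, d_k] = \sum_{i = \max(0,\; c - \lfloor (size - 1)/2 \rfloor)}^{\min(C - 1,\; c + \lceil (size - 1)/2 \rceil)} X[n, i, d_1, \dots, d_k]^2

Y[n, c, d_1, \dots, d_k] = \frac{X[n, c, d_1, \dots, d_k]}{\left( bias + \frac{\alpha}{size} \cdot squareSum[n, c, d_1, \dots, d_k] \right)^{\beta}}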
9 changes: 9 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -18,6 +18,8 @@ namespace mlir {
namespace torch {
namespace Torch {

class PrimListConstructOp;

int64_t toPositiveDim(int64_t dim, int64_t inputRank);
bool isValidDim(int64_t dim, int64_t inputRank);
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
@@ -150,6 +152,13 @@ Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);
LogicalResult getPermutedType(BaseTensorType inType,
SmallVector<int64_t> permuteDims,
Type &permutedType);
// A helper function to get a constant value vector from an int64_t ArrayRef
SmallVector<Value> getValueVector(OpBuilder &builder, Location loc,
ArrayRef<int64_t> intVector);

// A helper function to build a PrimListConstructOp (a !torch.list<int> value) from an int64_t ArrayRef
PrimListConstructOp getList(OpBuilder &builder, Location loc,
ArrayRef<int64_t> intVector);

} // namespace Torch
} // namespace torch
121 changes: 121 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1445,6 +1445,127 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand, constAlpha);
return success();
});
patterns.onOp(
"LRN", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t size;
float alpha;
float beta;
float bias;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(size, "size", 2) ||
binder.f32FloatAttr(alpha, "alpha", 0.0001f) ||
binder.f32FloatAttr(beta, "beta", 0.75f) ||
binder.f32FloatAttr(bias, "bias", 1.0f))
return failure();
Type dtype = resultType.getOptionalDtype();
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(alpha));
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(beta));
Value constBias = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(bias));
// Refer to the operator description for more details on this lowering:
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#LRN

// squared = operand^2
Location loc = binder.getLoc();
Torch::ValueTensorType inTy =
cast<Torch::ValueTensorType>(operand.getType());
Value sqOperand = rewriter.create<Torch::AtenMulTensorOp>(
loc, inTy, operand, operand);
// View it as N x 1 x C x D0 x (product of remaining dims).
if (!inTy.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expected input to have sizes");
}
ArrayRef<int64_t> inTyShape = inTy.getSizes();
if (inTyShape.size() < 3) {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported: the input dimensions should be >= 3");
}
SmallVector<int64_t, 5> viewShapeInt{inTyShape[0], 1, inTyShape[1],
inTyShape[2]};
if (inTyShape.size() > 3) {
int64_t tailFlatDimSize = 1;
for (int64_t dimSize : inTyShape.slice(3)) {
tailFlatDimSize *= dimSize;
}
viewShapeInt.push_back(tailFlatDimSize);
}
Torch::ValueTensorType reshapeType =
rewriter.getType<Torch::ValueTensorType>(viewShapeInt, dtype);
Torch::PrimListConstructOp viewShapeListVal =
Torch::getList(rewriter, loc, viewShapeInt);
auto view = rewriter.create<Torch::AtenViewOp>(
loc, reshapeType, sqOperand, viewShapeListVal);
// padding
int64_t highPad = (size - 1) / 2;
int64_t lowPad = (size - 1) - highPad;
SmallVector<int64_t> paddingInt{0, 0, 0, 0, lowPad, highPad};
auto constPadVal = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(0.0));
Torch::PrimListConstructOp paddingListVal =
Torch::getList(rewriter, loc, paddingInt);
SmallVector<int64_t, 5> paddedShapeInt = viewShapeInt;
paddedShapeInt[2] += size - 1;
Torch::ValueTensorType paddedType =
rewriter.getType<Torch::ValueTensorType>(paddedShapeInt, dtype);
auto padded = rewriter.create<Torch::AtenConstantPadNdOp>(
loc, paddedType, view, paddingListVal, constPadVal);
// avg_pool3d
SmallVector<int64_t, 3> kernelSize{size, 1, 1};
Torch::PrimListConstructOp kernelSizeList =
Torch::getList(rewriter, loc, kernelSize);
SmallVector<int64_t, 3> strides{1, 1, 1};
Torch::PrimListConstructOp stridesList =
Torch::getList(rewriter, loc, strides);
SmallVector<int64_t, 3> padding{0, 0, 0};
Torch::PrimListConstructOp paddingList =
Torch::getList(rewriter, loc, padding);
auto cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto cstCountIncludeMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
// The pooling output has the same type as the reshaped (view) input
// because of the padding applied to the pooled dimension.
auto pool = rewriter.create<Torch::AtenAvgPool3dOp>(
loc, reshapeType, padded, kernelSizeList, stridesList, paddingList,
cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone);
// squeeze
auto one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<int64_t, 5> squeezeShapeInt{
viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]};
Torch::ValueTensorType squeezeType =
rewriter.getType<Torch::ValueTensorType>(squeezeShapeInt, dtype);
auto squeeze = rewriter.create<Torch::AtenSqueezeDimOp>(
loc, squeezeType, pool, one);
// view as input Type
Torch::PrimListConstructOp intTyShapeList =
Torch::getList(rewriter, loc, inTyShape);
auto viewAsInput = rewriter.create<Torch::AtenViewOp>(
loc, inTy, squeeze, intTyShapeList);
// mul + add + pow + div
auto mul = rewriter.create<Torch::AtenMulScalarOp>(
loc, resultType, viewAsInput, constAlpha);
auto add = rewriter.create<Torch::AtenAddScalarOp>(loc, resultType, mul,
constBias, one);
auto pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
loc, resultType, add, constBeta);

rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
binder.op, resultType, operand, pow);
return success();
});
patterns.onOp(
"Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
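A short sketch of why the op sequence above reproduces the LRN formula, assuming PyTorch avg_pool3d semantics where count_include_pad = true keeps the divisor at the full kernel volume (size * 1 * 1) even over the zero padding: the pooled value at channel c is

pool[n, c, \cdot] = \frac{1}{size} \sum_{i \in window(c)} X[n, i, \cdot]^2

so the trailing mul.Scalar, add.Scalar, pow.Tensor_Scalar, and div.Tensor produce

Y = \frac{X}{\left( bias + \alpha \cdot pool \right)^{\beta}} = \frac{X}{\left( bias + \frac{\alpha}{size} \sum_{i} X_i^2 \right)^{\beta}}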
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -617,3 +617,19 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
return rewriter.getI64Type();
return inputType;
}

SmallVector<Value> Torch::getValueVector(OpBuilder &builder, Location loc,
ArrayRef<int64_t> intVector) {
SmallVector<Value> ret;
llvm::transform(intVector, std::back_inserter(ret), [&](int64_t n) {
return builder.create<ConstantIntOp>(loc, builder.getI64IntegerAttr(n));
});
return ret;
}

PrimListConstructOp Torch::getList(OpBuilder &builder, Location loc,
ArrayRef<int64_t> intVector) {
SmallVector<Value> vals = getValueVector(builder, loc, intVector);
return builder.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(builder.getType<Torch::IntType>()), vals);
}
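A minimal usage sketch of the two new helpers, mirroring how the LRN lowering above builds the avg_pool3d argument lists; the rewriter and loc names are stand-ins for whatever the calling pattern has in scope:

// Sketch only: assumes an OpBuilder `rewriter` and a Location `loc` are
// available, as they are inside a conversion pattern.
SmallVector<int64_t, 3> strides{1, 1, 1};
// getValueVector materializes one Torch::ConstantIntOp per element.
SmallVector<Value> strideVals = Torch::getValueVector(rewriter, loc, strides);
// getList wraps the same constants in a PrimListConstructOp that produces a
// !torch.list<int>, ready to feed ops such as Torch::AtenAvgPool3dOp.
Torch::PrimListConstructOp stridesList = Torch::getList(rewriter, loc, strides);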
131 changes: 131 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -310,6 +310,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor

// -----

// CHECK-LABEL: func.func @test_lrn_default
func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0

// CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
// CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I50:.+]] = torch.constant.int 50
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[I50]]

// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]

// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]

// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]

// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]

// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]

// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]

// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]

// CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
// CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]]

// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
// CHECK: return %[[OUTPUT]]
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
return %0 : !torch.vtensor<[20,10,3,50],f32>
}

// -----

// CHECK-LABEL: func.func @test_lrn_with_optionals
func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0

// CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
// CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
// CHECK-DAG: %[[I200:.+]] = torch.constant.int 200
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[I200]]

// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]

// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]

// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]

// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]

// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]

// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]

// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]

// CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
// CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
// CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
// CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]

// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
// CHECK: return %[[OUTPUT]]
%none = torch.constant.none
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
return %0 : !torch.vtensor<[13,19,100,200],f32>
}

// -----

// CHECK-LABEL: @test_matmul_2d
func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>
