New support for resize op (#844)
* shape inference for sizes

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* lowering to Krnl

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* change constant

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* add backend test

Signed-off-by: Tong Chen <chentong@us.ibm.com>

* update lit checks

Signed-off-by: Tong Chen <chentong@us.ibm.com>

Co-authored-by: Tung D. Le <tung@jp.ibm.com>
chentong319 and tungld authored Sep 7, 2021
1 parent 3deff97 commit 977c765
Showing 4 changed files with 103 additions and 57 deletions.
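
For orientation before the diffs: this commit adds support for the form of onnx.Resize that takes its output shape from the sizes operand instead of scales. A minimal sketch of that form (hypothetical function name and constant values, modeled on the lit test removed below):

func @resize_using_sizes(%arg0 : tensor<3x4x5x6xf32>) -> tensor<*xf32> {
  // scales is none; the output shape comes entirely from the sizes operand.
  %cst = constant unit
  %roi = "onnx.Constant"() {value = dense<[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00]> : tensor<8xf32>} : () -> tensor<8xf32>
  %sizes = "onnx.Constant"() {value = dense<[3, 4, 10, 12]> : tensor<4xi64>} : () -> tensor<4xi64>
  %0 = "onnx.Resize"(%arg0, %roi, %cst, %sizes) {coordinate_transformation_mode = "asymmetric", mode = "nearest", nearest_mode = "floor"} : (tensor<3x4x5x6xf32>, tensor<8xf32>, none, tensor<4xi64>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()
}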
82 changes: 60 additions & 22 deletions src/Conversion/ONNXToKrnl/Tensor/Resize.cpp
@@ -32,6 +32,7 @@ struct ONNXResizeOpLowering : public ConversionPattern {
ONNXResizeOpAdaptor operandAdaptor(operands);
Value data = operandAdaptor.X();
Value scales = operandAdaptor.scales();
Value sizes = operandAdaptor.sizes();
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
int64_t rank = memRefType.getShape().size();

@@ -43,33 +44,52 @@ struct ONNXResizeOpLowering : public ConversionPattern {
resizeOp.nearest_mode() != "round_prefer_floor"))
return emitError(loc, "not implemented yet");

// Get the scales
// SymbolIndexExpr was tried but got runtime error
// Attribute::cast() const [with U = mlir::IntegerAttr]
// The reason seems to be that IntegerAttr is assumed
//
SmallVector<Value, 4> scaleValues;
DenseElementsAttr scalesAttrs =
getDenseElementAttributeFromONNXValue(resizeOp.scales());
SmallVector<float, 4> scalesConstant;
if (scalesAttrs) {
for (auto scaleAttr : scalesAttrs.getValues<FloatAttr>()) {
Value scaleConstant = emitConstantOp(
rewriter, loc, rewriter.getF32Type(), scaleAttr.getValueAsDouble());
scaleValues.emplace_back(scaleConstant);
bool fromScale = !isFromNone(resizeOp.scales());
IndexExprScope outerloopContex(rewriter, loc);
DimsExpr outputDims(rank);
MemRefBoundsIndexCapture dataBounds(data);
if (fromScale) {
// Get the scales
// SymbolIndexExpr was tried but got runtime error
// Attribute::cast() const [with U = mlir::IntegerAttr]
// The reason seems to be that IntegerAttr is assumed
//
DenseElementsAttr scalesAttrs =
getDenseElementAttributeFromONNXValue(resizeOp.scales());
SmallVector<float, 4> scalesConstant;
if (scalesAttrs) {
for (auto scaleAttr : scalesAttrs.getValues<FloatAttr>()) {
Value scaleConstant = emitConstantOp(rewriter, loc,
rewriter.getF32Type(), scaleAttr.getValueAsDouble());
scaleValues.emplace_back(scaleConstant);
}
} else {
for (decltype(rank) i = 0; i < rank; i++) {
Value indexValue =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
Value scaleVal = rewriter.create<KrnlLoadOp>(loc, scales, indexValue);
scaleValues.emplace_back(scaleVal);
}
}
} else {
for (decltype(rank) i = 0; i < rank; i++) {
Value indexValue =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
Value scaleVal = rewriter.create<KrnlLoadOp>(loc, scales, indexValue);
Value resizedVal = rewriter.create<KrnlLoadOp>(loc, sizes, indexValue);
Value resizedFVal =
rewriter.create<SIToFPOp>(loc, rewriter.getF32Type(), resizedVal);
Value inputDim = dataBounds.getDim(i).getValue();
Value inputDimInteger = rewriter.create<IndexCastOp>(
loc, inputDim, rewriter.getIntegerType(64));
Value inputDimFloat = rewriter.create<SIToFPOp>(
loc, rewriter.getF32Type(), inputDimInteger);
Value scaleVal =
rewriter.create<DivFOp>(loc, resizedFVal, inputDimFloat);
scaleValues.emplace_back(scaleVal);
}
}

IndexExprScope outerloopContex(rewriter, loc);
DimsExpr outputDims(rank);
MemRefBoundsIndexCapture dataBounds(data);
// Keep the code using IndexExpr for bug fixing
// ArrayValueIndexCapture scaleIEs(op, scales,
// getDenseElementAttributeFromKrnlValue,
@@ -80,7 +100,7 @@ struct ONNXResizeOpLowering : public ConversionPattern {

if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else {
else if (fromScale) {
for (decltype(rank) i = 0; i < rank; i++) {
if (memRefType.getShape()[i] != -1) {
outputDims[i] = LiteralIndexExpr(memRefType.getShape()[i]);
@@ -102,6 +122,24 @@ struct ONNXResizeOpLowering : public ConversionPattern {
}
alloc = insertAllocAndDeallocSimple(
rewriter, op, memRefType, loc, outputDims, insertDealloc);
} else {
// Output is determined by sizes()
for (decltype(rank) i = 0; i < rank; i++) {
if (memRefType.getShape()[i] != -1) {
outputDims[i] = LiteralIndexExpr(memRefType.getShape()[i]);
} else {
Value indexValue =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
Value resizedVal =
rewriter.create<KrnlLoadOp>(loc, sizes, indexValue);
Value outDim = rewriter.create<IndexCastOp>(
loc, rewriter.getIndexType(), resizedVal);
SymbolIndexExpr outDimIE(outDim);
outputDims[i] = SymbolIndexExpr(outDimIE);
}
}
alloc = insertAllocAndDeallocSimple(
rewriter, op, memRefType, loc, outputDims, insertDealloc);
}

// Create loops
@@ -124,8 +162,8 @@ struct ONNXResizeOpLowering : public ConversionPattern {
// FPToSIOp is round-to-zero, same as floor for positive
// round_prefer_floor will round 2.5 to 2, not 3
if (resizeOp.nearest_mode() == "round_prefer_floor") {
Value deltaConstant = emitConstantOp(
rewriter, loc, rewriter.getF32Type(), 0.4999999999);
Value deltaConstant =
emitConstantOp(rewriter, loc, rewriter.getF32Type(), 0.499999);
inIndexFloat =
rewriter.create<AddFOp>(loc, inIndexFloat, deltaConstant);
} else if (resizeOp.nearest_mode() == "floor") {
@@ -153,8 +191,8 @@ struct ONNXResizeOpLowering : public ConversionPattern {
scaleValues[i]),
halfPixelConstant);
if (resizeOp.nearest_mode() == "round_prefer_floor") {
Value deltaConstant = emitConstantOp(
rewriter, loc, rewriter.getF32Type(), 0.4999999999);
Value deltaConstant =
emitConstantOp(rewriter, loc, rewriter.getF32Type(), 0.499999);
inIndexFloat =
rewriter.create<AddFOp>(loc, inIndexFloat, deltaConstant);
} else if (resizeOp.nearest_mode() == "floor") {
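Two notes on the lowering above. When only sizes is given, each scale is computed at runtime as scale[i] = float(sizes[i]) / float(dim(X, i)). Conceptually, the IR emitted for one dimension looks like the sketch below (illustrative SSA names; exact op spellings vary with the MLIR version in use):

// Load sizes[0], cast it to f32, and divide by dimension 0 of the input.
%i0     = constant 0 : index
%size0  = krnl.load %sizes[%i0] : memref<4xi64>
%size0f = sitofp %size0 : i64 to f32
%dim0   = memref.dim %data, %i0 : memref<?x4x5x6xf32>
%dim0i  = index_cast %dim0 : index to i64
%dim0f  = sitofp %dim0i : i64 to f32
%scale0 = divf %size0f, %dim0f : f32

Separately, the round_prefer_floor delta constant shrinks from 0.4999999999 to 0.499999, presumably because 0.4999999999 is not representable in f32 and rounds to exactly 0.5, which would make the add-then-truncate sequence send 2.5 to 3 instead of 2.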
64 changes: 41 additions & 23 deletions src/Dialect/ONNX/ONNXOps.cpp
@@ -2592,12 +2592,14 @@ LogicalResult ONNXResizeOp::inferShapes(
}
auto inputTy = X().getType().cast<RankedTensorType>();

if (isFromNone(scales()) == isFromNone(sizes())) {
return emitError("scales() and sizes() can not both None/not None");
// Output should at least have the same rank as the X input
if (!getResult().getType().isa<RankedTensorType>()) {
SmallVector<int64_t, 4> dims(inputTy.getRank(), -1);
getResult().setType(RankedTensorType::get(dims, inputTy.getElementType()));
}

if (isFromNone(scales())) {
return emitError("using sizes() not implemented yet");
if (isFromNone(scales()) == isFromNone(sizes())) {
return emitError("scales() and sizes() can not both None/not None");
}

if (!(mode() == "nearest" &&
@@ -2608,28 +2610,44 @@ LogicalResult ONNXResizeOp::inferShapes(
}

// Current implementation handles constant scales only
DenseElementsAttr scalesAttrs =
getDenseElementAttributeFromONNXValue(scales());
if (!scalesAttrs) {
return success();
}
if (!isFromNone(scales())) {
DenseElementsAttr scalesAttrs =
getDenseElementAttributeFromONNXValue(scales());
if (!scalesAttrs) {
return success();
}

SmallVector<float, 4> scalesConstant;
for (auto scaleAttr : scalesAttrs.getValues<FloatAttr>()) {
scalesConstant.emplace_back(scaleAttr.getValueAsDouble());
}
SmallVector<float, 4> scalesConstant;
for (auto scaleAttr : scalesAttrs.getValues<FloatAttr>()) {
scalesConstant.emplace_back(scaleAttr.getValueAsDouble());
}

SmallVector<int64_t, 4> dims;
for (int i = 0; i < inputTy.getRank(); i++) {
int newDim;
if (inputTy.getShape()[i] == -1)
newDim = -1;
else
newDim = inputTy.getShape()[i] * scalesConstant[i];
dims.emplace_back(newDim);
}
SmallVector<int64_t, 4> dims;
for (int i = 0; i < inputTy.getRank(); i++) {
int newDim;
if (inputTy.getShape()[i] == -1)
newDim = -1;
else
newDim = inputTy.getShape()[i] * scalesConstant[i];
dims.emplace_back(newDim);
}

getResult().setType(RankedTensorType::get(dims, inputTy.getElementType()));
getResult().setType(RankedTensorType::get(dims, inputTy.getElementType()));
} else {
DenseElementsAttr sizesAttrs =
getDenseElementAttributeFromONNXValue(sizes());
if (!sizesAttrs) {
return success();
}

SmallVector<int64_t, 4> sizesConstant;
for (auto sizeAttr : sizesAttrs.getValues<IntegerAttr>()) {
sizesConstant.emplace_back(sizeAttr.getInt());
}

getResult().setType(
RankedTensorType::get(sizesConstant, inputTy.getElementType()));
}
return success();
}

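With a compile-time constant sizes input, inferShapes can now produce a fully static result type. For the sketch at the top of this page, sizes = [3, 4, 10, 12] would refine the result like this (types shown for illustration):

%0 = "onnx.Resize"(%arg0, %roi, %cst, %sizes) {coordinate_transformation_mode = "asymmetric", mode = "nearest", nearest_mode = "floor"} : (tensor<3x4x5x6xf32>, tensor<8xf32>, none, tensor<4xi64>) -> tensor<3x4x10x12xf32>

When sizes is not a compile-time constant, only the rank is known, so the new fallback at the top of inferShapes keeps a ranked result with dynamic (-1) dimensions.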
2 changes: 2 additions & 0 deletions test/backend/test.py
@@ -633,6 +633,8 @@
# Resize
"test_resize_upsample_scales_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
"test_resize_downsample_scales_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
"test_resize_upsample_sizes_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},
"test_resize_downsample_sizes_nearest_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE: {0:{-1}}, CONSTANT_INPUT:{-1}},

# Reverse Sequence

12 changes: 0 additions & 12 deletions test/mlir/onnx/onnx_shape_inference_error.mlir
@@ -167,18 +167,6 @@ func @unsupport_pad_unknown_pad_values(%arg0 : tensor<16x13xf32>, %arg1 : tensor
/// Unsupported configurations for ONNXResizeOp.
//===----------------------------------------------------------------------===//

func @unsupport_resize_using_sizes(%arg0 : tensor<3x4x5x6xf32>, %arg1 : tensor<1xi64>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Constant"() {value = dense<[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00]> : tensor<8xf32>} : () -> tensor<8xf32>

// expected-error @+2 {{using sizes() not implemented yet}}
// expected-error @+1 {{shape inference failed}}
%1 = "onnx.Resize"(%arg0, %0, %cst, %arg1) {coordinate_transformation_mode = "asymmetric", mode = "nearest", nearest_mode = "floor", onnx_node_name = "Resize1"} : (tensor<3x4x5x6xf32>, tensor<8xf32>, none, tensor<1xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
}

// -----

func @unsupport_resize_linear_mode(%arg0 : tensor<3x4x5x6xf32>) -> tensor<*xf32> {
%cst = constant unit
%0 = "onnx.Constant"() {value = dense<[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00]> : tensor<8xf32>} : () -> tensor<8xf32>
