Commit

[Relax][Frontend][Onnx] support MaxPool1/2/3D and AveragePool1/2/3D (apache#16681)

support MaxPool1/2/3D and AveragePool1/2/3D

Co-authored-by: cheng wen <chengven027-intellif>
chengven027 authored and thaisacs committed Apr 3, 2024
1 parent 8987c7c commit 50c8bbf
Showing 17 changed files with 1,123 additions and 124 deletions.
80 changes: 80 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -254,13 +254,51 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
}
}; // struct Conv2DTransposeAttrs

/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
struct Pool1DAttrs : public tvm::AttrsNode<Pool1DAttrs> {
Array<IntImm> pool_size;
Array<IntImm> strides;
Array<IntImm> padding;
Array<IntImm> dilation;
bool ceil_mode;
bool count_include_pad;
String layout;
String out_layout;

TVM_DECLARE_ATTRS(Pool1DAttrs, "relax.attrs.Pool1DAttrs") {
TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).describe("Specifies the strides of the pooling window.");
TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the pooling window.");
TVM_ATTR_FIELD(padding).describe(
"If padding is non-zero, then the input is implicitly zero-padded. "
"Padding supports both symmetric and asymmetric forms: "
"one int : same padding used on both sides; "
"two ints : padding widths in the order (left, right).");
TVM_ATTR_FIELD(ceil_mode).describe(
"A boolean indicating whether to use ceil or floor to compute the output shape. By using ceil, "
"every element in the input tensor will be covered by a sliding window.");
TVM_ATTR_FIELD(count_include_pad)
.describe("When true, padding elements are included when computing the average.");
TVM_ATTR_FIELD(layout).set_default("NCW").describe(
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
"'N', 'C', 'W' stand for batch, channel, and width "
"dimensions respectively. Pooling is applied on the 'W' dimension.");
TVM_ATTR_FIELD(out_layout)
.describe(
"Dimension ordering of output data. Can be 'NCW', 'NWC', etc. "
"'N', 'C', 'W' stand for batch, channel, and width "
"dimensions respectively. Pooling is applied on the 'W' dimension.");
}
}; // struct Pool1DAttrs

/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
Array<IntImm> pool_size;
Array<IntImm> strides;
Array<IntImm> padding;
Array<IntImm> dilation;
bool ceil_mode;
bool count_include_pad;
String layout;
String out_layout;

@@ -277,6 +315,8 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
TVM_ATTR_FIELD(ceil_mode).describe(
"A boolean indicating whether to use ceil or floor to compute the output shape. By using ceil, "
"every element in the input tensor will be covered by a sliding window.");
TVM_ATTR_FIELD(count_include_pad)
.describe("When true, padding elements are included when computing the average.");
TVM_ATTR_FIELD(layout).describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
@@ -291,6 +331,46 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
}
}; // struct Pool2DAttrs

/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
struct Pool3DAttrs : public tvm::AttrsNode<Pool3DAttrs> {
Array<IntImm> pool_size;
Array<IntImm> strides;
Array<IntImm> padding;
Array<IntImm> dilation;
bool ceil_mode;
bool count_include_pad;
String layout;
String out_layout;

TVM_DECLARE_ATTRS(Pool3DAttrs, "relax.attrs.Pool3DAttrs") {
TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).describe("Specifies the strides of the pooling window.");
TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the pooling window.");
TVM_ATTR_FIELD(padding).describe(
"If padding is non-zero, then the input is implicitly zero-padded. "
"Padding supports both symmetric and asymmetric forms: "
"one int : same padding used on all sides; "
"three ints : back, bottom, right use the same padding as front, top, left; "
"six ints : padding widths in the order (front, top, left, back, bottom, right).");
TVM_ATTR_FIELD(ceil_mode).describe(
"A boolean indicating whether to use ceil or floor to compute the output shape. By using ceil, "
"every element in the input tensor will be covered by a sliding window.");
TVM_ATTR_FIELD(count_include_pad)
.describe("When true, padding elements are included when computing the average.");
TVM_ATTR_FIELD(layout).describe(
"Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
"'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
"dimensions respectively. Pooling is applied on the 'D', 'H' and "
"'W' dimensions.");
TVM_ATTR_FIELD(out_layout)
.describe(
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc. "
"'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
"dimensions respectively. Pooling is applied on the 'D', 'H' and "
"'W' dimensions.");
}
}; // struct Pool3DAttrs

/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
Optional<Array<IntImm>> output_size;
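On the Python side these attribute fields surface as keyword arguments on the corresponding Relax operators. A minimal usage sketch, assuming the new 1D variant mirrors the existing max_pool2d/avg_pool2d Python signature (the exact Python-level defaults are not part of this diff):

from tvm import relax

# Symbolic 1D input in NCW layout: batch=1, channels=16, width=32.
x = relax.Var("x", relax.TensorStructInfo((1, 16, 32), "float32"))

# Pool1DAttrs fields map onto these keyword arguments.
y = relax.op.nn.avg_pool1d(
    x,
    pool_size=(3,),
    strides=(2,),
    padding=(1, 1),            # (left, right), per the Pool1DAttrs padding description
    dilation=(1,),
    ceil_mode=False,
    count_include_pad=False,   # the attribute added by this change
    layout="NCW",
)
# y is a relax.Call node whose attrs are an instance of relax.attrs.Pool1DAttrs.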
102 changes: 70 additions & 32 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1438,21 +1438,40 @@ def _impl_v15(cls, bb, inputs, attr, params):
)


class MaxPool(OnnxOpConverter):
"""Converts an onnx MaxPool node into an equivalent Relax expression."""
class Pool(OnnxOpConverter):
"""A helper class for pool op converters."""

name = ""

@classmethod
def _impl_v12(cls, bb, inputs, attr, params):
def get_pad_pair(cls, input1d, kernel1d, stride1d, mode):
"""infer pad size"""
if input1d % stride1d == 0:
pad = max(kernel1d - stride1d, 0)
else:
pad = max(kernel1d - (input1d % stride1d), 0)
pad_before = pad // 2
pad_after = pad - pad_before
if "LOWER" in mode:
return [pad_after, pad_before]
return [pad_before, pad_after]
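# Illustrative worked example (added for exposition, not part of this commit):
# for input width 8, kernel 3, stride 2 and auto_pad="SAME_UPPER",
#   8 % 2 == 0                    ->  pad = max(3 - 2, 0) = 1
#   pad_before, pad_after = 0, 1  ->  returns [0, 1]
# while "SAME_LOWER" returns the same total padding reversed: [1, 0].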

@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
# Unpack inputs and attributes.
data = inputs[0]
input_shape = data.struct_info.shape
ndim = len(input_shape)

auto_pad = attr.get("auto_pad", b"NOTSET").decode("utf-8")
ceil_mode = attr.get("ceil_mode", 0)
dilations = attr.get("dilations", [1, 1])
dilations = attr.get("dilations", [1] * (ndim - 2))
kernel_shape = attr.get("kernel_shape")
pads = attr.get("pads", 0)
strides = attr.get("strides", [1, 1])
strides = attr.get("strides", [1] * (ndim - 2))

assert len(kernel_shape) in [1, 2, 3], "Currently only 1D/2D/3D pooling is supported."

assert len(kernel_shape) == 2, "Currently only 2D pooling is supported."
assert auto_pad in [
"NOTSET",
"SAME_UPPER",
@@ -1461,41 +1480,59 @@ def _impl_v12(cls, bb, inputs, attr, params):
], f"Value {auto_pad} in attribute auto_pad is invalid."

if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
input_spatial_shape = cls._get_input_spatial_shape(data)
output_spatial_shape = [0 for _ in input_spatial_shape]

pads = _np.array([(0, 0) for _ in range(len(kernel_shape))])
pads = []
if cls.name == "avg_pool":
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
stride = strides[axis]
kernel = kernel_shape[axis]
pad = cls.get_pad_pair(axis_shape, kernel, stride, auto_pad)
pads.append(pad)
else:
input_spatial_shape = cls._get_input_spatial_shape(data)
output_spatial_shape = [0 for _ in input_spatial_shape]

for i, _ in enumerate(input_spatial_shape):
if auto_pad == "SAME_UPPER":
output_spatial_shape[i] = int(_np.ceil(input_spatial_shape[i] / strides[i]))
else:
output_spatial_shape[i] = int(
_np.floor(input_spatial_shape[i] / strides[i])
)
pad_i = (
(output_spatial_shape[i] - 1) * strides[i]
+ ((kernel_shape[i] - 1) * dilations[i] + 1)
- input_spatial_shape[i]
)

for i, _ in enumerate(input_spatial_shape):
if auto_pad == "SAME_UPPER":
output_spatial_shape[i] = int(_np.ceil(input_spatial_shape[i] / strides[i]))
else:
output_spatial_shape[i] = int(_np.floor(input_spatial_shape[i] / strides[i]))
pad_i = (
(output_spatial_shape[i] - 1) * strides[i]
+ ((kernel_shape[i] - 1) * dilations[i] + 1)
- input_spatial_shape[i]
)
if auto_pad == "SAME_UPPER":
pads[i, 0] = pad_i // 2
pads[i, 1] = pad_i - pads[i, 0]
else:
pads[i, 1] = pad_i // 2
pads[i, 0] = pad_i - pads[i, 1]
if auto_pad == "SAME_UPPER":
pads.append([pad_i // 2, pad_i - pad_i // 2])
else:
pads.append([pad_i - pad_i // 2, pad_i // 2])

# TODO(agladyshev): for now we support only 2D kernel
# (top, left, bottom, right)
flatten_pads = [pads[0][0], pads[1][0], pads[0][1], pads[1][1]]
pads = tuple(flatten_pads)
pads = tuple([val for pair in zip(*pads) for val in pair])

return relax.op.nn.max_pool2d(data, kernel_shape, strides, pads, dilations, ceil_mode)
op = getattr(relax.op.nn, cls.name + str(len(kernel_shape)) + "d")
return op(data, kernel_shape, strides, pads, dilations, ceil_mode)

@classmethod
def _get_input_spatial_shape(cls, tensor):
# shape is (N x C x D1 x D2 ... Dn)
return _np.array([int(d) for d in tensor.struct_info.shape], dtype="int64")[2:]
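# Illustrative sketch (added for exposition, not part of this commit): a
# standalone replica of the SAME_UPPER padding math and pad flattening used
# in _impl_v1 above, with concrete numbers. Pure Python, no TVM required.
import math

def _same_upper_pad(in_size, kernel, stride, dilation=1):
    out = math.ceil(in_size / stride)
    pad = (out - 1) * stride + ((kernel - 1) * dilation + 1) - in_size
    return [pad // 2, pad - pad // 2]  # (before, after)

# 2D case: an 8x8 input with a 3x3 kernel and stride 2 on both axes.
pads = [_same_upper_pad(8, 3, 2) for _ in range(2)]       # [[0, 1], [0, 1]]
flat = tuple(val for pair in zip(*pads) for val in pair)  # (0, 0, 1, 1) = (top, left, bottom, right)

# The target operator is then resolved dynamically, e.g. for a 3D AveragePool:
# getattr(relax.op.nn, "avg_pool" + "3" + "d") -> relax.op.nn.avg_pool3d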


class MaxPool(Pool):
"""Converts an onnx MaxPool node into an equivalent Relax expression."""

name = "max_pool"


class AveragePool(Pool):
"""Converts an onnx MaxPool node into an equivalent Relax expression."""

name = "avg_pool"


class GlobalAveragePool(OnnxOpConverter):
"""Converts an onnx GlobalAveragePool node into an equivalent Relax expression."""

@@ -1922,9 +1959,10 @@ def _get_convert_map():
"Split": Split,
"Tile": Tile,
"BatchNormalization": BatchNormalization,
"MaxPool": MaxPool,
"AveragePool": AveragePool,
"GlobalAveragePool": GlobalAveragePool,
"Flatten": Flatten,
"MaxPool": MaxPool,
"Identity": Identity,
"Resize": Resize,
"Einsum": Einsum,
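A rough end-to-end sketch of how the new converters are exercised from ONNX (illustrative only: the shapes, the default opset, and the SAME_UPPER output width of ceil(32 / 2) = 16 are assumptions of this example):

from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

# A one-node ONNX graph: MaxPool over an NCW tensor with SAME_UPPER padding.
node = helper.make_node(
    "MaxPool", ["x"], ["y"], kernel_shape=[3], strides=[2], auto_pad="SAME_UPPER"
)
graph = helper.make_graph(
    [node],
    "maxpool1d_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 16, 32])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 16, 16])],
)
model = helper.make_model(graph)

mod = from_onnx(model)  # the MaxPool node is lowered to relax.op.nn.max_pool1d
mod.show()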
2 changes: 2 additions & 0 deletions python/tvm/relax/op/_op_gradient.py
@@ -1279,6 +1279,7 @@ def max_pool2d_grad(
orig_call.attrs.padding,
orig_call.attrs.dilation,
orig_call.attrs.ceil_mode,
orig_call.attrs.count_include_pad,
orig_call.attrs.layout,
orig_call.attrs.out_layout,
)
@@ -1310,6 +1311,7 @@ def avg_pool2d_grad(
orig_call.attrs.padding,
orig_call.attrs.dilation,
orig_call.attrs.ceil_mode,
orig_call.attrs.count_include_pad,
orig_call.attrs.layout,
orig_call.attrs.out_layout,
)
24 changes: 22 additions & 2 deletions python/tvm/relax/op/grad/grad.py
@@ -130,6 +130,7 @@ def max_pool2d_backward(
padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
dilation: Tuple[int, int] = (1, 1),
ceil_mode: bool = False,
count_include_pad: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -147,7 +148,16 @@ def max_pool2d_backward(
The gradient w.r.t. data.
"""
return _ffi_api.max_pool2d_backward( # type: ignore
output_grad, data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
output_grad,
data,
pool_size,
strides,
padding,
dilation,
ceil_mode,
count_include_pad,
layout,
out_layout,
)


@@ -159,6 +169,7 @@ def avg_pool2d_backward(
padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
dilation: Tuple[int, int] = (1, 1),
ceil_mode: bool = False,
count_include_pad: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -176,7 +187,16 @@ def avg_pool2d_backward(
The gradient w.r.t. data.
"""
return _ffi_api.avg_pool2d_backward( # type: ignore
output_grad, data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
output_grad,
data,
pool_size,
strides,
padding,
dilation,
ceil_mode,
count_include_pad,
layout,
out_layout,
)


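A minimal sketch of exercising the extended gradient signature (assuming avg_pool2d_backward is re-exported under tvm.relax.op.grad and that the keyword names match the parameters shown above):

from tvm import relax
from tvm.relax.op.grad import avg_pool2d_backward

data = relax.Var("data", relax.TensorStructInfo((1, 3, 32, 32), "float32"))
# Gradient w.r.t. the output of a 2x2, stride-2 avg_pool2d.
out_grad = relax.Var("out_grad", relax.TensorStructInfo((1, 3, 16, 16), "float32"))

grad_call = avg_pool2d_backward(
    out_grad,
    data,
    pool_size=(2, 2),
    strides=(2, 2),
    padding=(0, 0, 0, 0),
    dilation=(1, 1),
    ceil_mode=False,
    count_include_pad=False,  # the newly threaded-through attribute
    layout="NCHW",
)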
4 changes: 4 additions & 0 deletions python/tvm/relax/op/nn/__init__.py
@@ -19,7 +19,9 @@
adaptive_avg_pool2d,
attention,
attention_var_len,
avg_pool1d,
avg_pool2d,
avg_pool3d,
batch_norm,
conv1d,
conv1d_transpose,
@@ -34,7 +36,9 @@
layer_norm,
leakyrelu,
log_softmax,
max_pool1d,
max_pool2d,
max_pool3d,
nll_loss,
pad,
relu,
