From d50ec2367bf2124f2958e561a7ac8d39931023f7 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 10 Oct 2024 10:36:11 +0800 Subject: [PATCH 01/20] [Relax] Add NonZero op (#17453) this PR adds the NonZero op to Relax, together with ONNX frontend support --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 10 ++++- python/tvm/relax/op/__init__.py | 2 +- python/tvm/relax/op/set.py | 37 +++++++++++++++++++ src/relax/op/tensor/set.cc | 23 ++++++++++++ src/relax/op/tensor/set.h | 28 ++++++++++++++ tests/python/relax/test_frontend_onnx.py | 5 +++ tests/python/relax/test_op_set.py | 34 +++++++++++++++++ 7 files changed, 137 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index aa156a025fef..b9eb141bd14e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -2482,6 +2482,14 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.unique(data, sorted=sorted, axis=axis) +class NonZero(OnnxOpConverter): + """Converts an onnx NonZero node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + return relax.op.nonzero(inputs[0]) + + class HardSigmoid(OnnxOpConverter): """Converts an onnx HardSigmoid node into an equivalent Relax expression.""" @@ -2867,7 +2875,7 @@ def _get_convert_map(): "Range": Range, "OneHot": OneHot, "Unique": Unique, - # "NonZero": NonZero, + "NonZero": NonZero, # "If": If, # "LRN": LRN, # "MaxRoiPool": MaxRoiPool, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index c99201e969b5..efd9997698ee 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -101,7 +101,7 @@ from .qdq import dequantize, quantize from .sampling import multinomial_from_uniform from .search import argmax, argmin, where -from .set import unique +from .set import nonzero, unique from .sorting import argsort, sort, topk from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance from .ternary import ewise_fma diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py index 0b86e19ce53f..c5db852ddd5d 100644 --- a/python/tvm/relax/op/set.py +++ b/python/tvm/relax/op/set.py @@ -110,3 +110,40 @@ def numpy_unique( return tvm.nd.array(output_sorted_numpy) output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis) return tvm.nd.array(output_numpy) + + +def nonzero(x: Expr) -> Expr: + """Find the indices of elements of a tensor that are non-zero. + + Parameters + ---------- + x : relax.Expr + The input data tensor. + + Returns + ------- + result : relax.Expr + A (n+1)-D tensor containing indices of non-zero elements. + + Note + ---- + This function is equivalent to `onnx.nonzero`. + + Examples + -------- + + .. code-block:: python + + x = [[0, 1], + [2, 0]] + nonzero(x) = [[0, 1], + [1, 0]] + + """ + return _ffi_api.nonzero(x) # type: ignore + + +@tvm.register_func("relax.run.nonzero") +def numpy_nonzero(x: tvm.nd.array) -> tvm.nd.array: + np_result = np.atleast_1d(x.numpy()).nonzero() + return tvm.nd.array(np.stack(np_result, axis=0)) diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index 29d9d52c6077..c659a49afd12 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -24,6 +24,7 @@ #include "set.h" +#include #include #include @@ -137,5 +138,27 @@ TVM_REGISTER_OP("relax.unique") .set_attr("FCallPacked", "relax.run.unique") .set_attr("FPurity", Bool(true)); +/* relax.nonzero */ +Expr nonzero(Expr x) { + static const Op& op = Op::Get("relax.nonzero"); + return Call(op, {std::move(x)}); +} + +TVM_REGISTER_GLOBAL("relax.op.nonzero").set_body_typed(nonzero); + +StructInfo InferStructInfoNonzero(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx); + // Cheat zero dim scalar as 1-dim. + int dim = data_sinfo->IsUnknownNdim() ? kUnknownNDim : std::max(1, data_sinfo->ndim) + 1; + return TensorStructInfo(DataType::Int(64), dim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.nonzero") + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoNonzero) + .set_attr("FCallPacked", "relax.run.nonzero") + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/set.h b/src/relax/op/tensor/set.h index a5c7ee85bfb2..251dd1975e9f 100644 --- a/src/relax/op/tensor/set.h +++ b/src/relax/op/tensor/set.h @@ -29,8 +29,36 @@ namespace tvm { namespace relax { +/*! + * \brief Find the unique elements in a given tensor. + * In addition, it optionally returns + * - the indices of the input tensor that give the unique values; + * - the indices of the unique tensor that reconstruct the input tensor; + * - the number of times each unique value comes up in the input tensor. + * \param x The input tensor. + * \param sorted Whether to sort the unique elements in ascending order before + * returning as output. + * \param return_index Whether to return an additional tensor with indices for where elements in + * the unique tensor come from the original input. + * \param return_inverse Whether to return an additional tensor with indices for where elements in + * the original input ended up in the returned unique list. + * \param return_counts Whether to return an additional tensor with counts of each unique elements. + * \param axis The dimension to apply unique. + * If not specified, the unique values of the flattened input are returned. + * \return The unique elements of the array. The returned array will be sorted if `sorted` is True. + * Additional return values depend on `return_index`, `return_inverse`, and `return_counts`. + */ Expr unique(Expr x, PrimValue sorted, PrimValue return_index, PrimValue return_inverse, PrimValue return_counts, Optional axis); + +/*! + * \brief Returns the indices of the non-zero elements of the input tensor. + * \param x The input tensor. + * \return a list of 1-D tensors containing indices of non-zero elements for each dimension. + * \note This function behaves similarly to numpy.nonzero(), but return a multi-dimensional array + * instead of a tuple of 1-D arrays. + */ +Expr nonzero(Expr x); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index e3ed3a3a9d4d..57f94c8442f7 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2162,6 +2162,11 @@ def test_unique(axis: Optional[int], sorted: int): check_correctness(model) +@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (4, 5, 6)]) +def test_nonzero(shape): + verify_unary("NonZero", shape, input_dtype=TensorProto.BOOL, output_dtype=TensorProto.INT64) + + @pytest.mark.parametrize("mode", ["DCR", "CRD"]) def test_depth_to_space(mode: Literal["DCR", "CRD"]): in_shape = [1, 8, 2, 3] diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py index 741d7869d52f..e9070f99fc3f 100644 --- a/tests/python/relax/test_op_set.py +++ b/tests/python/relax/test_op_set.py @@ -867,5 +867,39 @@ def test_unique_infer_struct_info_wrong_input_dtype(): bb.normalize(relax.op.unique(x1)) +@pytest.mark.parametrize("shape", [(1,), (2, 3), (4, 5, 6)]) +def test_nonzero_infer_struct_info(shape): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor(shape, "bool")) + + _check_inference( + bb, + relax.op.nonzero(x0), + relax.TensorStructInfo(ndim=len(shape) + 1, dtype="int64"), + ) + + +def test_nonzero_infer_struct_info_ndim_zero(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((), "bool")) + + _check_inference( + bb, + relax.op.nonzero(x), + relax.TensorStructInfo(ndim=2, dtype="int64"), + ) + + +def test_nonzero_infer_struct_info_wrong_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.nonzero(x1)) + + if __name__ == "__main__": tvm.testing.main() From 910ee0e852e32dd9a6e7c495229aa37847a7e473 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 10 Oct 2024 10:36:30 +0800 Subject: [PATCH 02/20] [Relax] Add scatter_nd op support (#17449) Add relax scatter_nd op support and ONNX frontend support. --- include/tvm/relax/attrs/manipulate.h | 12 ++ .../tvm/relax/frontend/onnx/onnx_frontend.py | 32 ++++- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/manipulate.py | 39 +++++ .../transform/legalize_ops/manipulate.py | 17 +++ python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/op/tensor/manipulate.cc | 134 ++++++++++++++++++ src/relax/op/tensor/manipulate.h | 33 +++++ tests/python/relax/test_frontend_onnx.py | 33 ++++- tests/python/relax/test_op_manipulate.py | 25 ++++ .../test_transform_legalize_ops_manipulate.py | 62 +++++++- 11 files changed, 387 insertions(+), 3 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index ef4265d73b4b..e53ba3c36e7f 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -164,6 +164,18 @@ struct ScatterElementsAttrs : public tvm::AttrsNode { "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\"."); } }; // struct ScatterElementsAttrs + +/*! \brief Attributes used in scatter_nd operators */ +struct ScatterNDAttrs : public tvm::AttrsNode { + String reduction; + + TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") { + TVM_ATTR_FIELD(reduction).set_default("update").describe( + "Accumulation mode of the ScatterND, " + "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); + } +}; // struct ScatterNDAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index b9eb141bd14e..f1fa67546c2a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -692,6 +692,36 @@ def _impl_v11(cls, bb, inputs, attr, params): return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis) +class ScatterND(OnnxOpConverter): + """Convert an onnx ScatterND node into an equivalent Relax expression.""" + + @staticmethod + def _reduction_check(attr, valid_reductions: List[str]): + reduction = attr.get("reduction", None) + reduction = reduction or b"update" + reduction = reduction.decode("utf-8") + reduction = "update" if reduction == "none" else reduction + assert ( + reduction in valid_reductions + ), f"Only {valid_reductions} reductions are supported, but {reduction} is gotten" + + return reduction + + @classmethod + def _impl_v11(cls, bb, inputs, attr, params): + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2]) + + @classmethod + def _impl_v16(cls, bb, inputs, attr, params): + reduction = cls._reduction_check(attr, ["update", "add", "mul"]) + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction) + + @classmethod + def _impl_v18(cls, bb, inputs, attr, params): + reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"]) + return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction) + + class Size(OnnxOpConverter): """Convert an onnx Size node into an equivalent Relax expression.""" @@ -2827,7 +2857,7 @@ def _get_convert_map(): # "GatherND": GatherND, "Scatter": Scatter, "ScatterElements": ScatterElements, - # "ScatterND": ScatterND, + "ScatterND": ScatterND, # "Compress": Compress, "Size": Size, # "EyeLike": EyeLike, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index efd9997698ee..84b31ccec01e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -93,6 +93,7 @@ repeat, reshape, scatter_elements, + scatter_nd, split, squeeze, tile, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index da0a09cc7b51..1673a79b08c2 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -511,3 +511,42 @@ def scatter_elements( """ return _ffi_api.scatter_elements(data, indices, updates, axis, reduction) # type: ignore + + +def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr: + """Scatter updates into an array according to indices. + + Parameters + ---------- + data: relax.Expr + The input data to be updated. + + indices: relax.Expr + The index positions to update in `data`. + + updates: relax.Expr + Values to replace to. + + reduction: str + Type of reduction to apply: update, add, mul, max, min. + It is "update" by default. + + Returns + ------- + result : relax.Expr + The result has the same shape as data. + + Examples + -------- + .. code-block:: python + + # inputs + data = [1, 2, 3, 4, 5, 6, 7, 8] + indices = [[4], [3], [1], [7]] + updates = [9, 10, 11, 12] + + # output + output = [1, 11, 3, 10, 9, 6, 7, 12] + + """ + return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 1efa78c069ad..105d763403af 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -168,6 +168,23 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr: ) +@register_legalize("relax.scatter_nd") +def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr: + # TODO(relax-team): Support native scatter_nd without te extern + def scatter_nd(data, indices, updates, reduction): + axes = list(range(len(indices.shape))) + indices = topi.transpose(indices, axes[-1:] + axes[:-1]) + return topi.scatter_nd(data, indices, updates, reduction) + + return bb.call_te( + scatter_nd, + call.args[0], + call.args[1], + call.args[2], + call.attrs.reduction, + ) + + @register_legalize("relax.layout_transform") def _layout_transform(bb: BlockBuilder, call: Call) -> Expr: def te_layout_transform(data, name): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e6ff35ebe56b..f7847e2af8ed 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -138,6 +138,7 @@ round, rsqrt, scatter_elements, + scatter_nd, shape_of, shape_to_tensor, sigmoid, @@ -738,6 +739,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "cumsum", "einsum", "scatter_elements", + "scatter_nd", "dataflow", "device", "divide", diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 2b1c6eafb652..ca7d0a0945bc 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1531,5 +1531,139 @@ TVM_REGISTER_OP("relax.scatter_elements") .set_attr("FInferStructInfo", InferStructInfoScatterElements) .set_attr("FPurity", Bool(true)); +/* relax.scatter_nd */ +TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); + +Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) { + auto attrs = make_object(); + attrs->reduction = std::move(reduction); + static const Op& op = Op::Get("relax.scatter_nd"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd); + +StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { + // `call->args` contains: [data, indices, updates] + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + ICHECK_EQ(call->args.size(), 3); + const auto* data_sinfo = GetStructInfoAs(call->args[0]); + const auto* indices_sinfo = GetStructInfoAs(call->args[1]); + const auto* updates_sinfo = GetStructInfoAs(call->args[2]); + + if (data_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input data to be a tensor. However, the given type is " + << call->args[0]->GetTypeKey()); + } + if (indices_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input indices to be a tensor. However, the given type is " + << call->args[1]->GetTypeKey()); + } + if (updates_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the input updates to be a tensor. However, the given type is " + << call->args[2]->GetTypeKey()); + } + + if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input data and updates to have known dtype. " + "However, the given types are " + << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype); + } + + if (data_sinfo->dtype != updates_sinfo->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input data to have same type with updates. " + "However, the given types are " + << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype); + } + + if (indices_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; + } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the input indices to have integer dtype. However, " + "the given indices dtype is " + << indices_sinfo->dtype); + } + + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + const auto* updates_shape = updates_sinfo->shape.as(); + + if (data_shape && indices_shape && updates_shape) { + const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as(); + if (!k_dim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND needs a static shape for the last axis of indices, got " + << indices_shape->values); + } + const size_t data_ndim = data_sinfo->ndim; + const size_t indices_ndim = indices_sinfo->ndim; + const size_t updates_ndim = updates_sinfo->ndim; + if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the updates tensor to have the rank of " + "`data tensor + indices tensor - last axis of indices tensor - 1`. " + "However, the given shapes are " + << "data: " << ShapeExpr(data_shape->values) + << ", indices: " << ShapeExpr(indices_shape->values) + << ", updates: " << ShapeExpr(updates_shape->values)); + } + if (k_dim->value > static_cast(data_ndim)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "ScatterND op requires the last axis of indices tensor to be less than " + "or equal to the rank of data tensor. However, the given shapes are " + << "data: " << ShapeExpr(data_shape->values) + << ", indices: " << ShapeExpr(indices_shape->values)); + } + Array expected_updates_shape; + for (size_t i = 0; i < indices_ndim - 1; i++) { + expected_updates_shape.push_back(indices_shape->values[i]); + } + for (size_t i = k_dim->value; i < data_ndim; i++) { + expected_updates_shape.push_back(data_shape->values[i]); + } + auto check_shape = [&](const Array& expected, const Array& actual) { + if (expected.size() != actual.size()) { + return false; + } + for (size_t i = 0; i < expected.size(); i++) { + if (!analyzer->CanProve(expected[i] == actual[i])) { + return false; + } + } + return true; + }; + if (!check_shape(expected_updates_shape, updates_shape->values)) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "ScatterND op requires the updates tensor to have the shape with constraint: " + << "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got " + << "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: " + << ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values)); + } + } + if (data_shape) { + return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice); + } + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.scatter_nd") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") + .set_attr("FInferStructInfo", InferStructInfoScatterND) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 68622f1359e0..e9fa1131e803 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -173,6 +173,39 @@ Expr tile(Expr data, Array repeats); */ Expr flip(Expr data, Integer axis); +/*! + * \brief Scatter updates into an array according to indices. + * \param data The input tensor. + * \param indices The index positions to update in `data`. + * \param updates The values to replace to. + * \param axis The axis along which to scatter the elements. + * \param reduction The reduction mode of the scatter elements, + * either "update", "add", "mul", "mean", "max" or "min". + * \return The computed result. + */ +Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction); + +/*! + * \brief Scatter updates into an array according to indices. + * \param data The input tensor to be updated. + * \param indices The index positions to update in `data`. + * \param updates The values to replace to. + * \param reduction The reduction mode of the scatter operation. + * Supported modes are: + * - "update": Replace the values at the indices with the update values. + * - "add": Add the update values to the existing values at the indices. + * - "mul": Multiply the existing values at the indices by the update values. + * - "max": Take the maximum of the existing value and the update value at each index. + * - "min": Take the minimum of the existing value and the update value at each index. + * \return The computed result tensor with the same shape as `data`. + * + * \note The shape of `indices` defines the shape of the scattered tensor. + * The last dimension of `indices` corresponds to the depth of each index vector. + * The shape of `updates` must match the shape of `indices` except for the last dimension, + * which must match the slice shape at each index. + */ +Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 57f94c8442f7..9ac520c58e14 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -118,7 +118,6 @@ def check_correctness( tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model) # Legalize any relax ops into tensorir. tvm_model = relax.transform.LegalizeOps()(tvm_model) - print(tvm_model) # Separate model from parameters. tvm_model, params = relax.frontend.detach_params(tvm_model) @@ -523,6 +522,38 @@ def test_scatter(axis: int, name: str, opset: int): check_correctness(model, inputs={"indices": indices}, opset=opset) +@pytest.mark.parametrize("reduction", ["none", "add", "mul"]) +def test_scatter_nd(reduction): + def verify_scatter_nd(data_shape, indices_shape, updates_shape): + scatter_nd_node = helper.make_node( + "ScatterND", + ["data", "indices", "updates"], + ["output"], + reduction=reduction, + ) + + graph = helper.make_graph( + [scatter_nd_node], + "scatter_nd_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape), + helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)], + ) + + model = helper.make_model(graph, producer_name="scatter_nd_test") + + indices = np.random.choice(data_shape[0], indices_shape) + check_correctness(model, inputs={"indices": indices}, opset=16) + + verify_scatter_nd([8], [4, 1], [4]) + verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4]) + verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6]) + verify_scatter_nd([10], [5, 1], [5]) + + def test_size(): test_node = helper.make_node("Size", ["x"], ["y"]) graph = helper.make_graph( diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index ddb92725d438..e958b03e4ce6 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -45,6 +45,7 @@ def test_op_correctness(): assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum") assert relax.op.flip(x, axis=1).op == Op.get("relax.flip") assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements") + assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -3352,5 +3353,29 @@ def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): bb.normalize(relax.op.scatter_elements(d0, i0, u4)) +def test_scatter_nd_infer_struct_info(): + bb = relax.BlockBuilder() + + d0 = relax.Var("data", R.Tensor((8,), "float32")) + i0 = relax.Var("indices", R.Tensor((4, 1), "int64")) + u0 = relax.Var("updates", R.Tensor((4,), "float32")) + + _check_inference( + bb, + relax.op.scatter_nd(d0, i0, u0, "update"), + relax.TensorStructInfo((8,), dtype="float32"), + ) + + d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32")) + i1 = relax.Var("indices", R.Tensor((2, 1), "int64")) + u1 = relax.Var("updates", R.Tensor((2, 4, 4), "float32")) + + _check_inference( + bb, + relax.op.scatter_nd(d1, i1, u1, "update"), + relax.TensorStructInfo((4, 4, 4), dtype="float32"), + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index a0ecd3c73dc9..0565b7a5790a 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import pytest import tvm from tvm import relax from tvm.relax.transform import LegalizeOps @@ -1739,5 +1738,66 @@ def te_layout_transform( tvm.ir.assert_structural_equal(Expected, After) +def test_scatter_nd(): + + # fmt: off + @I.ir_module + class Before: + @R.function + def main( + data: R.Tensor((8,), "float32"), + indices: R.Tensor((4, 1), "int64"), + updates: R.Tensor((4,), "float32"), + ) -> R.Tensor((8,), "float32"): + gv: R.Tensor((8,), "float32") = R.scatter_nd(data, indices, updates, reduction="update") + return gv + + After = relax.transform.LegalizeOps()(Before) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((8,), "float32"), + indices: R.Tensor((4, 1), "int64"), + updates: R.Tensor((4,), "float32"), + ) -> R.Tensor((8,), "float32"): + gv = R.call_tir( + Expected.scatter_nd, (data, indices, updates), R.Tensor((8,), dtype="float32") + ) + return gv + + @T.prim_func(private=True) + def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, var_scatter_nd_generic: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + data = T.match_buffer(var_data, (T.int64(8),), offset_factor=1) + indices = T.match_buffer(var_indices, (T.int64(4), T.int64(1)), "int64") + updates = T.match_buffer(var_updates, (T.int64(4),), offset_factor=1) + out_buf = T.match_buffer(var_scatter_nd_generic, (T.int64(8),)) + with T.block("root"): + T.reads() + T.writes() + T_transpose = T.alloc_buffer((T.int64(1), T.int64(4)), "int64") + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(4)): + with T.block("T_transpose"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(4), ax1) + T.reads(indices[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = indices[v_ax1, v_ax0] + with T.block("scatter_nd_generic"): + T.reads() + T.writes() + for i in range(T.int64(8)): + out_buf[i] = data[i] + for j in range(T.int64(4)): + for k in T.parallel(T.int64(1)): + out_buf[k + T_transpose[j // T.int64(4), j % T.int64(4)]] = updates[j + k] + + # fmt: on + tvm.ir.assert_structural_equal(After, Expected) + + if __name__ == "__main__": tvm.testing.main() From 74ed86b5df128dffeedac1eb6bbd345b1a756327 Mon Sep 17 00:00:00 2001 From: Honglin Zhu Date: Thu, 10 Oct 2024 10:37:02 +0800 Subject: [PATCH 03/20] [Relax][Frontend][Onnx] Add support for pad-2 (#17431) * fix params name bug * add support for onnx pad_v2 * Update test_frontend_onnx.py * Update onnx_frontend.py --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 29 ++++++++++ tests/python/relax/test_frontend_onnx.py | 57 +++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f1fa67546c2a..4770b7ce5cc5 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1582,6 +1582,35 @@ def _impl_v13(cls, bb, inputs, attr, params): class Pad(OnnxOpConverter): """Converts an onnx Pad node into an equivalent Relax expression.""" + @classmethod + def _impl_v2(cls, bb, inputs, attr, params): + pads = attr.get("pads") + pads = relax.const(_np.array(pads), inputs[0].struct_info.shape[0].dtype) + constant_value = attr.get("value") + if constant_value is None: + constant_value = 0.0 + + if isinstance(pads, relax.Constant): + pad_before, pad_after = _np.split(pads.data.numpy(), 2) + pad_before = _np.ndarray.tolist(pad_before) + pad_after = _np.ndarray.tolist(pad_after) + else: + raise ValueError("Dynamic pads are not supported yet.") + + pad_mode = attr.get("mode", b"constant").decode("utf-8") + if not pad_mode in ["constant", "edge", "reflect"]: + raise tvm.error.OpAttributeInvalid( + "Value " + pad_mode + ' in attribute "mode" is invalid for operator Pad.' + ) + + if pad_mode == "constant": + return bb.emit_te(topi.nn.pad, inputs[0], pad_before, pad_after, constant_value) + elif pad_mode == "reflect": + return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT") + else: + # TODO(gigiblender) Support edge mode. + raise NotImplementedError("Pad mode {} not implemented".format(pad_mode)) + @classmethod def _impl_v11(cls, bb, inputs, attr, params): pads = get_constant(inputs[1], params) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 9ac520c58e14..1b4c5d281abb 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -1696,6 +1696,63 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0): verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") +@pytest.mark.parametrize("dynamic", [True, False]) +def test_pad_v2(dynamic): + + if dynamic: + pytest.skip("Dynamic pad not supported") + + def verify_pad(input_shape, pads, mode="constant", value=0.0): + indata = np.random.normal(size=input_shape).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i + len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ["edge", "reflect"]: + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + "Pad", inputs=["input"], outputs=["output"], mode=mode, pads=pads + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)) + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + else: + outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) + node = helper.make_node( + "Pad", + inputs=["input"], + outputs=["output"], + mode="constant", + pads=pads, + value=value, + ) + graph = helper.make_graph( + [node], + "pad_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(indata.shape)) + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(outdata.shape)) + ], + ) + model = helper.make_model(graph, producer_name="pad_test") + check_correctness(model=model, opset=10) + + verify_pad((2, 2), [0, 1, 0, 0], "constant", 0.0) + verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0) + verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0) + verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") + + @pytest.mark.parametrize("fp_arith", [np.float16, np.float32]) @pytest.mark.parametrize("dynamic", [True, False]) def test_split(fp_arith, dynamic): From 7d2fa11bd16972368bfbaab0a872541fa76745a7 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 10 Oct 2024 23:02:51 +0800 Subject: [PATCH 04/20] Try to fix windows CI conda build issue (#17457) try fix ci --- conda/build-environment.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 8eb25ce01ac7..de4e6f4234d7 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -26,7 +26,8 @@ channels: # The packages to install to the environment dependencies: - python=3.9 - - conda-build + - conda < 24.9.0 + - conda-build < 24.9.0 - git - llvmdev >=11 - numpy From 22a9d388d441dbfd917d032564e2a1bccacd5f8c Mon Sep 17 00:00:00 2001 From: ysh329 Date: Fri, 11 Oct 2024 09:17:59 +0000 Subject: [PATCH 05/20] [release] Update version to 0.18.0 on main branch --- conda/recipe/meta.yaml | 2 +- include/tvm/runtime/c_runtime_api.h | 2 +- python/tvm/_ffi/libinfo.py | 2 +- version.py | 2 +- web/package-lock.json | 4 ++-- web/package.json | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/conda/recipe/meta.yaml b/conda/recipe/meta.yaml index d4477468c79d..c5e3840ff613 100644 --- a/conda/recipe/meta.yaml +++ b/conda/recipe/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = '0.18.dev0' %} +{% set version = '0.18.0' %} {% set pkg_name = 'tvm' %} {% set cuda_tag = cuda_version | replace('.', '') %} # [cuda] {% set pkg_name = pkg_name + '-cu' + cuda_tag %} # [cuda] diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index d26c95e4f53c..8071020cef28 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -73,7 +73,7 @@ #endif // TVM version -#define TVM_VERSION "0.18.dev0" +#define TVM_VERSION "0.18.0" // TVM Runtime is DLPack compatible. #include diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 2ec4ba8e31be..6e39d5b33a99 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -247,4 +247,4 @@ def find_include_path(name=None, search_path=None, optional=False): # We use the version of the incoming release for code # that is under development. # The following line is set by tvm/python/update_version.py -__version__ = "0.18.dev0" +__version__ = "0.18.0" diff --git a/version.py b/version.py index a827571c6cdf..cea1ba306c57 100644 --- a/version.py +++ b/version.py @@ -44,7 +44,7 @@ # Two tag formats are supported: # - vMAJ.MIN.PATCH (e.g. v0.8.0) or # - vMAJ.MIN.devN (e.g. v0.8.dev0) -__version__ = "0.18.dev0" +__version__ = "0.18.0" # --------------------------------------------------- diff --git a/web/package-lock.json b/web/package-lock.json index 751aaf2ef442..6c7e024f2236 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.18.0-dev2", + "version": "0.18.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.18.0-dev2", + "version": "0.18.0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/package.json b/web/package.json index a63997bb2f1c..c8d33be0b5e9 100644 --- a/web/package.json +++ b/web/package.json @@ -3,7 +3,7 @@ "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", "homepage": "https://github.com/apache/tvm/tree/main/web", - "version": "0.18.0-dev2", + "version": "0.18.0", "files": [ "lib" ], From ab648358178a1c8a8a5116fc975f4618b3ede8aa Mon Sep 17 00:00:00 2001 From: ysh329 Date: Fri, 11 Oct 2024 10:14:24 +0000 Subject: [PATCH 06/20] [release] Update version to 0.19.dev0 on main branch --- conda/recipe/meta.yaml | 2 +- include/tvm/runtime/c_runtime_api.h | 2 +- python/tvm/_ffi/libinfo.py | 2 +- version.py | 2 +- web/package-lock.json | 4 ++-- web/package.json | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/conda/recipe/meta.yaml b/conda/recipe/meta.yaml index c5e3840ff613..e340b25e5ba1 100644 --- a/conda/recipe/meta.yaml +++ b/conda/recipe/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = '0.18.0' %} +{% set version = '0.19.dev0' %} {% set pkg_name = 'tvm' %} {% set cuda_tag = cuda_version | replace('.', '') %} # [cuda] {% set pkg_name = pkg_name + '-cu' + cuda_tag %} # [cuda] diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 8071020cef28..438d049ed4a1 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -73,7 +73,7 @@ #endif // TVM version -#define TVM_VERSION "0.18.0" +#define TVM_VERSION "0.19.dev0" // TVM Runtime is DLPack compatible. #include diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 6e39d5b33a99..f29ddaab72a9 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -247,4 +247,4 @@ def find_include_path(name=None, search_path=None, optional=False): # We use the version of the incoming release for code # that is under development. # The following line is set by tvm/python/update_version.py -__version__ = "0.18.0" +__version__ = "0.19.dev0" diff --git a/version.py b/version.py index cea1ba306c57..c8151769ba68 100644 --- a/version.py +++ b/version.py @@ -44,7 +44,7 @@ # Two tag formats are supported: # - vMAJ.MIN.PATCH (e.g. v0.8.0) or # - vMAJ.MIN.devN (e.g. v0.8.dev0) -__version__ = "0.18.0" +__version__ = "0.19.dev0" # --------------------------------------------------- diff --git a/web/package-lock.json b/web/package-lock.json index 6c7e024f2236..ddc14c7f134d 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "tvmjs", - "version": "0.18.0", + "version": "0.19.0-dev0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tvmjs", - "version": "0.18.0", + "version": "0.19.0-dev0", "license": "Apache-2.0", "devDependencies": { "@rollup/plugin-commonjs": "^20.0.0", diff --git a/web/package.json b/web/package.json index c8d33be0b5e9..a89b078cd776 100644 --- a/web/package.json +++ b/web/package.json @@ -3,7 +3,7 @@ "description": "TVM WASM/WebGPU runtime for JS/TS", "license": "Apache-2.0", "homepage": "https://github.com/apache/tvm/tree/main/web", - "version": "0.18.0", + "version": "0.19.0-dev0", "files": [ "lib" ], From 43f6c08f9db04adc73a17d3d99efdc6135ff0d3d Mon Sep 17 00:00:00 2001 From: sunzj Date: Mon, 14 Oct 2024 21:04:06 +0800 Subject: [PATCH 07/20] Show the record if the escape sequence is unsupported (#17458) * Show the record if the escape sequence is unsupported Show the record if the escape sequence is unspported. so we can find and check it. * Show the record if the escape sequence is unsupported Show the record if the escape sequence is unspported. so we can find and check it. --- src/meta_schedule/database/database_utils.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index ce025540e496..22b0933db4b4 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -236,7 +236,8 @@ class JSONTokenizer { str.push_back('\t'); break; default: - LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_; + LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_ + << ". record:" << std::string(cur_, end_); } } if (cur_ == end_) { From e3faa55573977300ccc4530331700eac65560b2e Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 15 Oct 2024 06:26:42 +0800 Subject: [PATCH 08/20] [JVM] Align Java GraphModule Initialization with Python API (#17464) [JVM] Align Java GraphModule initialization with Python API Java API is still using the outdated initialization method for `GraphModule`, which has led to issues where the old API no longer works as expected. This PR updates the Java API for `GraphModule` initialization to match the simplified method used in the Python API. --- .../main/java/org/apache/tvm/Function.java | 12 +++++++++++ .../src/main/java/org/apache/tvm/LibInfo.java | 2 ++ .../org/apache/tvm/contrib/GraphModule.java | 2 +- jvm/native/src/main/native/jni_helper_func.h | 21 +++++++++++++++++++ .../native/org_apache_tvm_native_c_api.cc | 15 +++++++++++++ 5 files changed, 51 insertions(+), 1 deletion(-) diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index df535a87aa85..594b35b0af68 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -222,6 +222,16 @@ public Function pushArg(byte[] arg) { return this; } + /** + * Push argument to the function. + * @param arg Device. + * @return this + */ + public Function pushArg(Device arg) { + Base._LIB.tvmFuncPushArgDevice(arg); + return this; + } + /** * Invoke function with arguments. * @param args Can be Integer, Long, Float, Double, String, NDArray. @@ -255,6 +265,8 @@ private static void pushArgToStack(Object arg) { Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); + } else if (arg instanceof Device) { + Base._LIB.tvmFuncPushArgDevice((Device) arg); } else if (arg instanceof TVMValue) { TVMValue tvmArg = (TVMValue) arg; switch (tvmArg.typeCode) { diff --git a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java index 62b8c901bd71..aede9be334c8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/LibInfo.java +++ b/jvm/core/src/main/java/org/apache/tvm/LibInfo.java @@ -37,6 +37,8 @@ class LibInfo { native void tvmFuncPushArgHandle(long arg, int argType); + native void tvmFuncPushArgDevice(Device device); + native int tvmFuncListGlobalNames(List funcNames); native int tvmFuncFree(long handle); diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java index 737fdef24ae8..0a0bc7efc46d 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphModule.java @@ -41,7 +41,7 @@ public class GraphModule { private Function fdebugGetOutput; private Function floadParams; - GraphModule(Module module, Device dev) { + public GraphModule(Module module, Device dev) { this.module = module; this.device = dev; fsetInput = module.getFunction("set_input"); diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index d60a1a4230b7..3e44f757392d 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -214,4 +214,25 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { return NULL; } +// Helper function to pack two int32_t values into an int64_t +inline int64_t deviceToInt64(const int32_t device_type, const int32_t device_id) { + int64_t result; + int32_t* parts = reinterpret_cast(&result); + + // Lambda function to check endianness + const auto isLittleEndian = []() -> bool { + uint32_t i = 1; + return *reinterpret_cast(&i) == 1; + }; + + if (isLittleEndian()) { + parts[0] = device_type; + parts[1] = device_id; + } else { + parts[1] = device_type; + parts[0] = device_id; + } + return result; +} + #endif // TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index 09522381f181..c039508b4b7f 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -112,6 +112,21 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* e->tvmFuncArgTypes.push_back(static_cast(argType)); } +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDevice(JNIEnv* env, jobject obj, + jobject arg) { + jclass deviceClass = env->FindClass("org/apache/tvm/Device"); + jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I"); + jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I"); + jint deviceType = env->GetIntField(arg, deviceTypeField); + jint deviceId = env->GetIntField(arg, deviceIdField); + + TVMValue value; + value.v_int64 = deviceToInt64(deviceType, deviceId); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); + e->tvmFuncArgValues.push_back(value); + e->tvmFuncArgTypes.push_back(kDLDevice); +} + JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg)); From 0c67cd8d294bbe683ef8cfbd50adefe9b2573b3a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 15 Oct 2024 12:49:20 -0400 Subject: [PATCH 09/20] Revert "[KVCACHE] Improved schedule for prefill attention" (#17466) Revert "[KVCACHE] Improved schedule for prefill attention (#17432)" This reverts commit 79abc0356ee66f3dbdd8bde3cbfcbf88a2ed746e. --- python/tvm/relax/frontend/nn/llm/kv_cache.py | 60 ++++---------------- 1 file changed, 11 insertions(+), 49 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index fd866ae06c16..9b16fc2fbfee 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -925,12 +925,8 @@ def _attention_decode( THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - # Keeping lower thread limit for this kernel on adreno target - # to avoid register spill - THREAD_LIMIT = 256 + if target.kind.name == "opencl" and "android" in str(target.host): + THREAD_LIMIT = 256 if H_kv < 8 else 512 TILE_SIZE_PER_BDX = 1 max_num_threads_per_block = get_max_num_threads_per_block(target) thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) @@ -1574,11 +1570,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = ( - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - d, - 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), - ) + tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 # Otherwise we would exceed maxComputeWorkgroupStorageSize if ( @@ -1588,12 +1580,6 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], tile_z = 8 num_warps = 2 - if target.kind.name == "opencl" and ( - ("android" in str(target.host)) or ("adreno" in str(target.attrs)) - ): - LOAD_VEC = 16 // ((DataType(dtype).bits + 7) // 8) # 16 bytes - NUM_BLKS = group_size * 8 - # fmt: off @T.prim_func def batch_prefill_ragged_kv( # pylint: disable=too-many-branches @@ -1722,6 +1708,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches for lz, ly in T.grid(tile_z, tile_y): with T.block("K_load"): i, j = T.axis.remap("SS", [lz, ly]) + T.reads() + T.writes() cur_L = L_kv_start + i if cur_L < kv_chunk_len[0]: K_smem[i, j] = T.if_then_else( @@ -1836,14 +1824,6 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # fmt: on # pylint: enable=line-too-long,too-many-branches sch = tir.Schedule(batch_prefill_ragged_kv) - get_extent = lambda *lps: [int(sch.get(lp).extent) for lp in lps] - - def get_vecsize(extent): - return min(LOAD_VEC, (extent & ~(extent - 1))) - - def getxy_vecsize(x, y, t): - assert (x * y) % t == 0 - return min(get_vecsize(y), get_vecsize(x * y // t)) def get_tile_size(x, y, t): cnt = (x * y) // t @@ -1857,37 +1837,26 @@ def get_tile_size(x, y, t): def apply_to_qkv_load(sch: tir.Schedule, block): loop_x, loop_y = sch.get_loops(block)[-2:] - x_extent, y_extent = get_extent(loop_x, loop_y) - vec_size = getxy_vecsize(x_extent, y_extent, bdx * num_warps) - yo, yv = sch.split(loop_y, [None, vec_size]) - yo_extent = y_extent // vec_size - tile_x, tile_y = get_tile_size(x_extent, yo_extent, (bdx * num_warps)) - xo, xi = sch.split(loop_x, [tile_x, None]) - yo, yi = sch.split(yo, [tile_y, None]) - sch.reorder(xi, yi, xo, yo) - t = sch.fuse(xi, yi) - ty, tx = sch.split(t, [num_warps, bdx]) + loop = sch.fuse(loop_x, loop_y) + _, ty, tx, vec = sch.split( + loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True + ) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.vectorize(yv) + sch.vectorize(vec) def apply_to_so_ewise(sch: tir.Schedule, block, tile): loop_x, loop_y = sch.get_loops(block)[-2:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) yo, yi = sch.split(loop_y, factors=[None, tile[1]]) sch.reorder(xo, yo, xi, yi) - sch.unroll(xi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) t = sch.fuse(xo, yo) ty, tx = sch.split(t, factors=[None, bdx]) sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") def apply_to_gemm( # pylint: disable=unused-argument - sch: tir.Schedule, block, tile, read_0, read_1, r_len=16, k_major=False + sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False ): loop_x, loop_y, loop_z = sch.get_loops(block)[-3:] xo, xi = sch.split(loop_x, factors=[None, tile[0]]) @@ -1903,12 +1872,6 @@ def apply_to_gemm( # pylint: disable=unused-argument sch.reorder(ko, xi, yi, ki) else: sch.reorder(ko, ki, xi, yi) - yiv_extent = get_vecsize(tile[1]) - yio, yiv = sch.split(yi, [None, yiv_extent]) - sch.unroll(yio) - sch.vectorize(yiv) - sch.unroll(xi) - sch.unroll(ki) sch.decompose_reduction(block, ty) def apply_to_md(sch, block): @@ -1917,7 +1880,6 @@ def apply_to_md(sch, block): sch.bind(ty, "threadIdx.y") sch.bind(tx, "threadIdx.x") - sch.transform_layout("K_load", ("write", 0), lambda i, j: (j, i)) tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps) tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps) apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True) From 02172c3a5e36433257f83cf8bd0c7f48c993363d Mon Sep 17 00:00:00 2001 From: Hussein Taher <6496177+Husenap@users.noreply.github.com> Date: Wed, 16 Oct 2024 04:48:57 +0200 Subject: [PATCH 10/20] [FIX][RELAX][ONNX] Fix typo in onnx frontend (#17467) Fixed typo in onnx_frontend.py --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4770b7ce5cc5..43c1ec681a2f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -260,7 +260,7 @@ def base_impl(cls, bb, inputs, attr, params): else inputs[0].data.numpy() ) y = ( - _np.array(inputs[0].value) + _np.array(inputs[1].value) if isinstance(inputs[1], relax.PrimValue) else inputs[1].data.numpy() ) From 35d6a1b9d27f1128bd00edef541be0d1f9f61dd9 Mon Sep 17 00:00:00 2001 From: albert qing <2628869@qq.com> Date: Wed, 16 Oct 2024 10:50:32 +0800 Subject: [PATCH 11/20] [TIR][Schedule] Add annotate_buffer_access primitive (#17423) Co-authored-by: qsqqsqqsq-intellif --- include/tvm/tir/schedule/schedule.h | 11 + include/tvm/tir/stmt.h | 10 + python/tvm/tir/schedule/schedule.py | 136 +++++++ src/tir/schedule/concrete_schedule.cc | 10 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 10 + .../primitive/annotate_buffer_access.cc | 167 +++++++++ src/tir/schedule/schedule.cc | 7 + src/tir/schedule/traced_schedule.cc | 12 + src/tir/schedule/traced_schedule.h | 2 + src/tir/transforms/compact_buffer_region.cc | 43 ++- ...est_tir_schedule_annotate_buffer_access.py | 332 ++++++++++++++++++ 12 files changed, 736 insertions(+), 6 deletions(-) create mode 100644 src/tir/schedule/primitive/annotate_buffer_access.cc create mode 100644 tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 092bd52d5634..e4b13888f948 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object { */ virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; + /*! + * \brief Annotate the buffer access of a block + * \param block_rv The block to be annotated + * \param buffer_index The index of the buffer in block's read or write region + * \param buffer_index_type The type of the buffer index, kRead or kWrite. + * \param index_map The index map that defines the new read or write region + */ + virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index c77254ed34cb..38289af463d5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution"; /*! \brief Mark that a block is disallowed in auto inline. */ constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule"; +/*! \brief Mark that a block has an explicitly specified read region. + * This is used to override the default read region inference in TIR. + */ +constexpr const char* explicit_read_region = "explicit_read_region"; + +/*! \brief Mark that a block has an explicitly specified write region. + * This is used to override the default write region inference in TIR. + */ +constexpr const char* explicit_write_region = "explicit_write_region"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index be88e234634f..17c256be3538 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -3907,3 +3907,139 @@ def unsafe_hide_buffer_access( buf_type, buf_index_array, ) + + @type_checked + def annotate_buffer_access( + self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable + ) -> None: + """Annotate the read or write region of a block + + Parameters + ---------- + block : BlockRV + The block to be annotated + buffer_index : int + The index of the buffer in block's read or write region + buf_type : str + The buffer type: "read" or "write" + gen_new_ranges : Callable + A function that takes the block's iter_vars and returns a + Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...] + which defines the new read or write region for the buffer. + Each element in the tuple can be: + - A single PrimExpr representing the iter_var itself + - A tuple of two PrimExprs representing the range (begin, end) + + Examples + -------- + Annotate a 2D read region for a buffer. + Before annotate_buffer_access, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_annotate_buffer_access( + A: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32") + ) -> None: + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do annotate_buffer_access: + + .. code-block:: python + + sch = tir.Schedule(before_annotate_buffer_access) + block = sch.get_block("B") + sch.annotate_buffer_access(block, 0, "read", + lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1))) + print(sch.mod["main"].script()) + + After applying annotate_buffer_access, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_annotate_buffer_access( + A: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32") + ) -> None: + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1]) + T.writes(B[vi, vj]) + T.block_attr({"explicit_read_region": 0}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + This annotates the read region for buffer A (index 0) in block "B" to be + [vi-1:vi+1, vj-1:vj+1] for each (vi, vj) in the block's iteration domain. + + Note + ---- + This function allows manual specification of read or write regions, which + can be useful in cases where the compiler cannot accurately infer the + access pattern, such as complex data-dependent accesses. + It overrides the automatically inferred region for the specified buffer. + The function adds an annotation to the block, indicating that an explicit + region has been provided for the buffer at the given index. This annotation + is used in the CompactBufferAllocation pass to respect the manually specified + region instead of relying on automatic inference. + + Caution should be exercised when using this function, as incorrect annotations + may lead to incorrect code generation or runtime errors. It's crucial to + ensure that the specified region covers all actual reads or writes performed + by the block for the given buffer. + + """ + block_obj = self.get(block) + iter_vars = [x.var for x in block_obj.iter_vars] + new_ranges_spec = gen_new_ranges(*iter_vars) + if len(iter_vars) != len(new_ranges_spec): + raise ValueError( + f"Number of iter_vars ({len(iter_vars)}) must match " + f"number of new_ranges_spec ({len(new_ranges_spec)})" + ) + + result = [] + for rng in new_ranges_spec: + if isinstance(rng, (tuple, list)): + if len(rng) != 2: + raise ValueError( + "Tuple must have exactly 2 elements to represent (begin, end)." + ) + result.extend(rng) + elif isinstance(rng, PrimExpr): + result.extend([rng, rng + 1]) # Single point represented as (rng, rng + 1) + else: + raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}") + + # Create index_map using IndexMap constructor + index_map = IndexMap( + initial_indices=iter_vars, + final_indices=result, + inverse_index_map=None, + ) + + if buf_type == "read": + buffer_index_type = 0 + elif buf_type == "write": + buffer_index_type = 1 + else: + raise ValueError(f"Invalid buf_type: {buf_type}. Expected 'read' or 'write'.") + + return _ffi_api.ScheduleAnnotateBufferAccess( + self, block, buffer_index, buffer_index_type, index_map + ) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 73b5ff3fafd4..f6cb1f05ef6e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -1059,5 +1059,15 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const this->state_->DebugVerify(); } +void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, + index_map); + TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_); + this->state_->DebugVerify(); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 092bcf0c79f9..b8ad56d2ab56 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -183,6 +183,8 @@ class ConcreteScheduleNode : public ScheduleNode { void EnterPostproc() override {} void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, const Array& buf_index_array) override; + void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map) override; protected: /******** Utility functions ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index fd1349e4a3ec..cf1ac957c89f 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -718,6 +718,16 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, const String& buf_type, const Array& buf_index_array); +/*! + * \brief Annotate the read or write region of a specific buffer in a block + * \param self The state of the schedule + * \param block_sref The sref of the block to be annotated + * \param buffer_index The index of the buffer in block's read or write region + * \param buffer_index_type The type of the buffer index, kRead or kWrite + * \param index_map The IndexMap that defines the new read or write region for the buffer + */ +TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/annotate_buffer_access.cc b/src/tir/schedule/primitive/annotate_buffer_access.cc new file mode 100644 index 000000000000..2c5976b035dd --- /dev/null +++ b/src/tir/schedule/primitive/annotate_buffer_access.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class AnnotateRegionRewriter : public StmtExprMutator { + public: + AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion new_region, + BufferIndexType buffer_index_type) + : buffer_(buffer), + buffer_index_(buffer_index), + new_region_(new_region), + buffer_index_type_(buffer_index_type) {} + + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + Array regions = + buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads; + ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative"; + ICHECK_LT(buffer_index_, static_cast(regions.size())) << "Buffer index out of range"; + regions.Set(buffer_index_, new_region_); + + ObjectPtr n = CopyOnWrite(block.get()); + if (buffer_index_type_ == BufferIndexType::kWrite) { + n->writes = std::move(regions); + } else { + n->reads = std::move(regions); + } + + // Annotate the block with explicit_read_region or explicit_write_region + Map new_annotations = n->annotations; + String annotation_key = buffer_index_type_ == BufferIndexType::kWrite + ? attr::explicit_write_region + : attr::explicit_read_region; + if (new_annotations.count(annotation_key)) { + Array buffer_indices = Downcast>(new_annotations[annotation_key]); + bool found = false; + for (const Integer& index : buffer_indices) { + if (index->value == buffer_index_) { + found = true; + break; + } + } + if (!found) { + buffer_indices.push_back(Integer(buffer_index_)); + new_annotations.Set(annotation_key, buffer_indices); + } + } else { + new_annotations.Set(annotation_key, Array{Integer(buffer_index_)}); + } + n->annotations = std::move(new_annotations); + + return Block(n); + } + + private: + Buffer buffer_; + int buffer_index_; + BufferRegion new_region_; + BufferIndexType buffer_index_type_; +}; + +void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, buffer_index_type); + + arith::Analyzer analyzer; + Array block_iter_vars; + for (const IterVar& iter_var : block->iter_vars) { + block_iter_vars.push_back(iter_var->var); + } + Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; + Array new_ranges; + for (size_t i = 0; i < new_indices.size(); i += 2) { + // (begin, end) represents a region + new_ranges.push_back(Range::FromMinExtent( + new_indices[i], analyzer.Simplify(new_indices[i + 1] - new_indices[i]))); + } + + BufferRegion new_region(buffer, new_ranges); + + AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type); + Stmt new_stmt = mutator(GetRef(block_sref->stmt)); + + self->Replace(block_sref, new_stmt, {{GetRef(block), Downcast(new_stmt)}}); +} + +struct AnnotateBufferAccessTraits : public UnpackedInstTraits { + static constexpr const char* kName = "AnnotateBufferAccess"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 4; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index, + Integer buffer_index_type, IndexMap index_map) { + return sch->AnnotateBufferAccess(block, buffer_index->value, + static_cast(buffer_index_type->value), + index_map); + } + + static String IndexMap2GenNewRangesLambda(const IndexMap& index_map) { + std::ostringstream oss; + oss << "lambda "; + for (size_t i = 0; i < index_map->initial_indices.size(); ++i) { + if (i != 0) oss << ", "; + oss << index_map->initial_indices[i]; + } + oss << ": ["; + for (size_t i = 0; i < index_map->final_indices.size(); i += 2) { + if (i != 0) oss << ", "; + if (index_map->final_indices[i].same_as(index_map->final_indices[i + 1])) { + oss << index_map->final_indices[i]; + } else { + oss << "(" << index_map->final_indices[i] << ", " << index_map->final_indices[i + 1] << ")"; + } + } + oss << "]"; + return String(oss.str()); + } + + static String UnpackedAsPython(Array outputs, String block, Integer buffer_index, + Integer buffer_index_type, IndexMap index_map) { + PythonAPICall py("annotate_buffer_access"); + py.Input("block", block); + py.Input("buffer_index", buffer_index->value); + + std::ostringstream os; + os << "\"" << BufferIndexType2Str(static_cast(buffer_index_type->value)) + << "\""; + py.Input("buf_type", os.str()); + + py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map)); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 44f9b8f42c68..2c3661d17ecc 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -310,6 +310,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess") .set_body_method(&ScheduleNode::UnsafeHideBufferAccess); +/******** (FFI) Annotate buffer access ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess") + .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, + int buffer_index_type, const IndexMap& index_map) { + return self->AnnotateBufferAccess(block_rv, buffer_index, + static_cast(buffer_index_type), index_map); + }); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 1611109d7735..d790f21e671a 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -769,5 +769,17 @@ void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const S /*outputs=*/{})); } +void TracedScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, + const IndexMap& index_map) { + ConcreteScheduleNode::AnnotateBufferAccess(block_rv, buffer_index, buffer_index_type, index_map); + static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, Integer(buffer_index), Integer(buffer_index_type), index_map}, + /*attrs=*/{}, + /*outputs=*/{})); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 78629e84f039..1c21c3e2c894 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -142,6 +142,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { void EnterPostproc() final; void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type, const Array& buf_index_array) final; + void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index, + BufferIndexType buffer_index_type, const IndexMap& index_map) final; }; } // namespace tir diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index f562a057e595..7385af49528b 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -136,7 +136,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + auto explicit_it = explicit_access_annotations_.find(op->buffer); + if (explicit_it != explicit_access_annotations_.end()) { + VisitBufferAccess(explicit_it->second); + } else { + VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + } StmtExprVisitor::VisitExpr_(op); } @@ -235,17 +240,38 @@ class BufferAccessRegionCollector : public StmtExprVisitor { auto& regions = access_annotations_[p.first]; p.second.swap(regions); } - // Step 2. Record relax position of ancestor_loops_ + + // Step 2. Record explicit read/write region annotations + auto record_explicit_region = [&](const String& attr_key, BufferIndexType index_type) { + auto it = op->annotations.find(attr_key); + if (it != op->annotations.end()) { + Array buffer_indices = Downcast>((*it).second); + for (const auto& index : buffer_indices) { + int buffer_index = index->value; + if (buffer_index >= 0 && buffer_index < static_cast(op->reads.size())) { + const BufferRegion& explicit_region = index_type == BufferIndexType::kRead + ? op->reads[buffer_index] + : op->writes[buffer_index]; + explicit_access_annotations_[explicit_region->buffer] = explicit_region; + } + } + } + }; + + record_explicit_region(attr::explicit_read_region, BufferIndexType::kRead); + record_explicit_region(attr::explicit_write_region, BufferIndexType::kWrite); + + // Step 3. Record relax position of ancestor_loops_ for (const Buffer& buffer : op->alloc_buffers) { VisitBufferDef(buffer->data); } - // Step 3. Visit match buffers + // Step 4. Visit match buffers for (const MatchBufferRegion& region : op->match_buffers) { VisitBufferAccess(region->source); } - // Step 4. Visit block body recursively + // Step 5. Visit block body recursively StmtExprVisitor::VisitStmt_(op); - // Step 5. Recover read/write region annotations + // Step 6. Recover read/write region annotations for (auto& p : cur_access_annotations) { auto& regions = access_annotations_[p.first]; if (p.second.empty()) { @@ -254,7 +280,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor { regions.swap(p.second); } } - // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers. + // Step 7. Clear explicit access annotations + explicit_access_annotations_.clear(); + // Step 8. Update buffer_access_region_ from relaxed_accesses_ for inner buffers. for (const Buffer& buffer : op->alloc_buffers) { ICHECK_EQ(var2buffer_[buffer->data].size(), 1) << "Block allocation buffer shoud not be alised"; @@ -489,6 +517,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief The map from Buffer to it's access regions annotated by current block. */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> access_annotations_; + /*! \brief The map from Buffer to its explicit access region annotated by the block. */ + std::unordered_map + explicit_access_annotations_; }; /*! \brief The storage alignment for a dimension */ diff --git a/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py new file mode 100644 index 000000000000..cc09a807dcac --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_annotate_buffer_access.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) + + +def test_annotate_read_buffer_access(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 1 : vi - 1 + 2, vj - 1 : vj - 1 + 2]) + T.writes(B[vi, vj]) + T.block_attr({"explicit_read_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)) + ) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_annotate_write_buffer_access(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi : vi + 2, vj : vj + 2]) + T.block_attr({"explicit_write_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2))) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_annotate_buffer_access_for_resize(): + # fmt: off + @T.prim_func + def resize_before(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")): + for i0, i1, i2, i3 in T.grid(1, 1, 16, 16): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, 0:32, 0:32]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)])) + + @T.prim_func + def resize_expected(x: T.Buffer((1, 1, 32, 32), "float16"), resize: T.Buffer((1, 1, 16, 16), "float16")): + for i0, i1, i2, i3 in T.grid(1, 1, 16, 16): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 + 3, v_i3 * 2 - 3:v_i3 * 2 + 3]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"explicit_read_region": [0]}) + resize[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", T.Cast("float32", x[v_i0, v_i1, T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i2) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0), T.max(T.min(T.Cast("int32", T.floor((T.Cast("float32", v_i3) + T.float32(0.5)) * T.float32(2) - T.float32(0.5) + T.float32(1.0000000000000001e-05))), 31), 0)])) + # fmt: on + sch = tir.Schedule(resize_before, debug_mask="all") + block = sch.get_block("resize") + sch.annotate_buffer_access( + block, + 0, + "read", + gen_new_ranges=lambda v_i0, v_i1, v_i2, v_i3: [ + v_i0, + v_i1, + (v_i2 * 2 - 3, v_i2 * 2 + 3), + (v_i3 * 2 - 3, v_i3 * 2 + 3), + ], + ) + assert_structural_equal_ignore_global_symbol(sch.mod["main"], resize_expected) + verify_trace_roundtrip(sch=sch, mod=resize_before) + + +def test_annotate_buffer_access_read_and_write(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 1 : vi + 2, vj - 1 : vj + 2]) + T.writes(B[vi : vi + 2, vj : vj + 2]) + T.block_attr({"explicit_read_region": [0], "explicit_write_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2)) + ) + + sch.annotate_buffer_access(block, 0, "write", lambda vi, vj: ((vi, vi + 2), (vj, vj + 2))) + + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_double_annotate_buffer_access_read(): + @T.prim_func + def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi - 2 : vi + 3, vj - 2 : vj + 3]) + T.writes(B[vi, vj]) + T.block_attr({"explicit_read_region": [0]}) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("B") + + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 1, vi + 2), (vj - 1, vj + 2)) + ) + + sch.annotate_buffer_access( + block, 0, "read", lambda vi, vj: ((vi - 2, vi + 3), (vj - 2, vj + 3)) + ) + + assert_structural_equal_ignore_global_symbol(sch.mod["main"], expected) + verify_trace_roundtrip(sch=sch, mod=before) + + +def test_annotate_buffer_access_with_compute_at_for_resize(): + # fmt: off + @T.prim_func + def before(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): + x_global = T.alloc_buffer([1, 3, 200, 200], dtype="float32") + for ax0, ax1, ax2, ax3 in T.grid(1, 3, 200, 200): + with T.block("cache"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 3, 100, 100): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(v_i2 * 2 + 0.5)), T.Cast("int32", T.floor(v_i3 * 2 + 0.5))] + + @T.prim_func + def after(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): + x_global = T.alloc_buffer((1, 3, 200, 200)) + for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10): + for ax0, ax1 in T.grid(24, 24): + with T.block("cache"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(3, i1) + v2 = T.axis.spatial(200, i2_0 * 20 - 3 + ax0) + v3 = T.axis.spatial(200, i3_0 * 20 - 3 + ax1) + T.where(3 <= i2_0 * 20 + ax0 and i2_0 * 20 + ax0 < 203 and 3 <= i3_0 * 20 + ax1 and i3_0 * 20 + ax1 < 203) + T.reads(x[v0, v1, v2, v3]) + T.writes(x_global[v0, v1, v2, v3]) + x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] + for i2_1, i3_1 in T.grid(10, 10): + with T.block("resize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1) + v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1) + T.reads(x_global[v_i0, v_i1, v_i2 * 2 - 3:v_i2 * 2 - 3 + 6, v_i3 * 2 - 3:v_i3 * 2 - 3 + 6]) + T.writes(y[v_i0, v_i1, v_i2, v_i3]) + T.block_attr({"explicit_read_region": [0]}) + y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))] + + @T.prim_func + def after_without_annotate_buffer_access(x: T.Buffer((1, 3, 200, 200), "float32"), y: T.Buffer((1, 3, 100, 100), "float32")): + x_global = T.alloc_buffer((1, 3, 200, 200)) + for i0, i1, i2_0, i3_0 in T.grid(1, 3, 10, 10): + for ax0, ax1 in T.grid(200, 200): + with T.block("cache"): + v0 = T.axis.spatial(1, 0) + v1, v2, v3 = T.axis.remap("SSS", [i1, ax0, ax1]) + T.reads(x[v0, v1, v2, v3]) + T.writes(x_global[v0, v1, v2, v3]) + x_global[v0, v1, v2, v3] = x[v0, v1, v2, v3] + for i2_1, i3_1 in T.grid(10, 10): + with T.block("resize"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + v_i2 = T.axis.spatial(100, i2_0 * 10 + i2_1) + v_i3 = T.axis.spatial(100, i3_0 * 10 + i3_1) + T.reads(x_global[v_i0, v_i1, 0:200, 0:200]) + T.writes(y[v_i0, v_i1, v_i2, v_i3]) + y[v_i0, v_i1, v_i2, v_i3] = x_global[v_i0, v_i1, T.Cast("int32", T.floor(T.Cast("float32", v_i2 * 2) + T.float32(0.5))), T.Cast("int32", T.floor(T.Cast("float32", v_i3 * 2) + T.float32(0.5)))] + # fmt: on + + # Schedule with annotate_buffer_access + sch = tir.Schedule(before, debug_mask="all") + block = sch.get_block("resize") + cache_block = sch.get_block("cache") + + # Annotate buffer access + sch.annotate_buffer_access( + block, + 0, + "read", + lambda vn, vc, vh, vw: (vn, vc, (vh * 2 - 3, vh * 2 + 3), (vw * 2 - 3, vw * 2 + 3)), + ) + + h, w = sch.get_loops(block)[-2:] + ho, hi = sch.split(h, factors=[10, 10]) + wo, wi = sch.split(w, factors=[10, 10]) + sch.reorder(ho, wo, hi, wi) + sch.compute_at(cache_block, wo) + + assert_structural_equal_ignore_global_symbol(sch.mod["main"], after) + verify_trace_roundtrip(sch=sch, mod=before) + + # Schedule without annotate_buffer_access + sch_without_annotate = tir.Schedule(before, debug_mask="all") + block_without_annotate = sch_without_annotate.get_block("resize") + cache_block_without_annotate = sch_without_annotate.get_block("cache") + + h, w = sch_without_annotate.get_loops(block_without_annotate)[-2:] + ho, hi = sch_without_annotate.split(h, factors=[10, 10]) + wo, wi = sch_without_annotate.split(w, factors=[10, 10]) + sch_without_annotate.reorder(ho, wo, hi, wi) + sch_without_annotate.compute_at(cache_block_without_annotate, wo) + + assert_structural_equal_ignore_global_symbol( + sch_without_annotate.mod["main"], after_without_annotate_buffer_access + ) + + +if __name__ == "__main__": + tvm.testing.main() From 58a43c87245e58ee09f2cdbde26fb2cc5167df9d Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 16 Oct 2024 11:04:37 +0800 Subject: [PATCH 12/20] [MetaSchedule] Fix a multilevel tiling error on dynamic relax workload (#17465) fix meta-schedule tiling primitive segfault on dynamic workload Co-authored-by: wrongtest --- src/tir/schedule/analysis/analysis.cc | 4 +-- src/tir/schedule/concrete_schedule.cc | 4 ++- src/tir/schedule/concrete_schedule.h | 12 +++++-- src/tir/schedule/trace.cc | 4 ++- src/tir/schedule/traced_schedule.cc | 8 +++-- .../test_tir_schedule_sampling.py | 28 +++++++++++++++ .../test_tir_schedule_split_fuse.py | 35 +++++++++++++++++++ 7 files changed, 86 insertions(+), 9 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index b60e60c3cfc9..6195313fddae 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1581,14 +1581,14 @@ std::pair GetCumulativeSpaceAndReductionLength(const tir::Sche tir::IterVarType type = GetLoopIterType(loop_sref); if (type == tir::kDataPar) { const int64_t* extent = GetLoopIntExtent(loop_sref); - if (*extent != -1) { + if (extent && *extent != -1) { cum_space_len *= *extent; } else { return std::make_pair(-1, -1); } } else if (type == tir::kCommReduce) { const int64_t* extent = GetLoopIntExtent(loop_sref); - if (*extent != -1) { + if (extent && *extent != -1) { cum_reduce_len *= *extent; } else { return std::make_pair(-1, -1); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index f6cb1f05ef6e..dd1a376deaf8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -246,8 +246,10 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int int max_innermost_factor, Optional> decision) { TVM_TIR_SCHEDULE_BEGIN(); + // use None RV object to denotes auto-infer tile factors. return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, - max_innermost_factor, &decision)); + max_innermost_factor, &decision), + /*convert_negone_to_none=*/true); TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); throw; } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index b8ad56d2ab56..4aebe3036cf2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -219,9 +219,12 @@ class ConcreteScheduleNode : public ScheduleNode { /*! * \brief Add a list of integers as random variables into the symbol table * \param value The list of integers to be added to the symbol table + * \param convert_negone_to_none Convert negative one to none RV. + * Which is convention of certain primitives. * \return The new random variables created */ - inline Array CreateRV(const std::vector& value); + inline Array CreateRV(const std::vector& value, + bool convert_negone_to_none = false); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); /*! @@ -362,10 +365,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return std::move(rv); } -inline Array ConcreteScheduleNode::CreateRV(const std::vector& value) { +inline Array ConcreteScheduleNode::CreateRV(const std::vector& value, + bool convert_negone_to_none) { Array results; results.reserve(value.size()); for (int64_t v : value) { + if (convert_negone_to_none && v == -1) { + results.push_back(ExprRV(nullptr)); + continue; + } results.push_back(CreateRV(v)); } return results; diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 6e243bf19198..7421cbbf32df 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -227,7 +227,9 @@ Array TranslateAddOutputRVs( ICHECK(!rv_names->count(output)) << "ValueError: The random variable has been produced once: " << rv_names->at(output); String result{ObjectPtr{nullptr}}; - if (output->IsInstance()) { + if (!output.defined()) { + result = "_"; + } else if (output->IsInstance()) { result = "b" + std::to_string(i); } else if (output->IsInstance()) { result = "l" + std::to_string(i); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d790f21e671a..784ecdeb32cb 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -70,9 +70,11 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidat Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision) { - Array results = CreateRV(tir::SamplePerfectTile( - &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); - + // use None RV object to denotes auto-infer tile factors. + Array results = + CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision), + /*convert_negone_to_none=*/true); static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{loop_rv}, diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index 8ae576e9b922..f37c818e7992 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -212,5 +212,33 @@ def test_sample_perfect_tile_after_copy(): sch_copy.sample_perfect_tile(i, n=4) +def test_sample_perfect_tile_on_dynamic_loops(): + """Currently dynamic loop is trivially tiled""" + + @T.prim_func + def workload(a: T.handle) -> None: + n = T.int32() + A = T.match_buffer(a, (n, 1024)) + for i, j in T.grid(n, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + A[vi, vj] = 1.0 + + sch = tir.Schedule(workload, debug_mask="all") + di, si = sch.get_loops(sch.get_block("B")) + + factors = sch.sample_perfect_tile(si, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 1024 + + factors = sch.sample_perfect_tile(di, n=4) + assert factors[0] is None + factors = [sch.get(i) for i in factors[1:]] + prod = factors[0] * factors[1] * factors[2] + assert prod == 1 + verify_trace_roundtrip(sch, mod=workload) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py index f5e5b3b54e76..22344acfe1d4 100644 --- a/tests/python/tir-schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/tir-schedule/test_tir_schedule_split_fuse.py @@ -389,6 +389,41 @@ def test_split_with_inferred_factor(): verify_trace_roundtrip(sch=sch, mod=elementwise) +def test_split_with_dynamic_inferred_factor(): + @T.prim_func + def before(a: T.handle, b: T.handle) -> None: + N = T.int32() + M = T.int32() + A = T.match_buffer(a, (N, 128, M)) + B = T.match_buffer(b, (N, 128, M)) + for i, j, k in T.grid(N, 128, M): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + @T.prim_func + def expected(a: T.handle, b: T.handle) -> None: + N, M = T.int32(), T.int32() + A = T.match_buffer(a, (N, 128, M)) + B = T.match_buffer(b, (N, 128, M)) + for i_0, i_1, j_0, j_1, k_0, k_1 in T.grid((N + 15) // 16, 16, 4, 32, 16, (M + 15) // 16): + with T.block("B"): + vi = T.axis.spatial(N, i_0 * 16 + i_1) + vj = T.axis.spatial(128, j_0 * 32 + j_1) + vk = T.axis.spatial(M, k_0 * ((M + 15) // 16) + k_1) + T.where(i_0 * 16 + i_1 < N and k_0 * ((M + 15) // 16) + k_1 < M) + B[vi, vj, vk] = A[vi, vj, vk] * T.float32(2.0) + + sch = tir.Schedule(before, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 16]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, None]) + assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=before) + + def test_split_with_predicate(): sch = tir.Schedule(elementwise, debug_mask="all") block_b = sch.get_block("B") From c6a5b7869023f7fd7b2926be847d39d363c13def Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 17 Oct 2024 01:05:34 +0800 Subject: [PATCH 13/20] [Relax] Enhance Relax op and ONNX frontend (#17462) --- include/tvm/relax/attrs/manipulate.h | 11 +++ .../tvm/relax/frontend/onnx/onnx_frontend.py | 66 +++++++++++++-- python/tvm/relax/op/__init__.py | 5 ++ python/tvm/relax/op/binary.py | 26 ++++++ python/tvm/relax/op/create.py | 68 +++++++++++++++ python/tvm/relax/op/manipulate.py | 44 ++++++++++ .../relax/transform/legalize_ops/binary.py | 3 +- .../relax/transform/legalize_ops/create.py | 30 +++++++ .../transform/legalize_ops/manipulate.py | 19 +++++ python/tvm/script/ir_builder/relax/ir.py | 10 +++ python/tvm/topi/tensor.py | 35 +++++++- src/relax/op/distributed/binary.cc | 2 + src/relax/op/tensor/binary.cc | 2 + src/relax/op/tensor/binary.h | 6 ++ src/relax/op/tensor/create.cc | 84 +++++++++++++++++++ src/relax/op/tensor/create.h | 40 ++++++++- src/relax/op/tensor/manipulate.cc | 75 +++++++++++++++++ src/relax/op/tensor/manipulate.h | 12 +++ tests/python/relax/test_frontend_onnx.py | 26 +++++- tests/python/relax/test_op_create.py | 58 +++++++++++++ tests/python/relax/test_op_manipulate.py | 52 ++++++++++++ 21 files changed, 657 insertions(+), 17 deletions(-) diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index e53ba3c36e7f..ea41488354d8 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -176,6 +176,17 @@ struct ScatterNDAttrs : public tvm::AttrsNode { } }; // struct ScatterNDAttrs +/*! \brief Attributes used in one_hot operator */ +struct OneHotAttrs : public tvm::AttrsNode { + int depth; + int axis; + + TVM_DECLARE_ATTRS(OneHotAttrs, "relax.attrs.OneHotAttrs") { + TVM_ATTR_FIELD(depth).describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill."); + } +}; // struct OneHotAttrs + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 43c1ec681a2f..6c9225070d3f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -287,7 +287,7 @@ class Sub(BinaryBase): relax_op = relax.op.subtract @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) @@ -298,7 +298,7 @@ class Mul(BinaryBase): relax_op = relax.op.multiply @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) @@ -309,7 +309,7 @@ class Div(BinaryBase): relax_op = relax.op.divide @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) @@ -320,7 +320,24 @@ class Pow(BinaryBase): relax_op = relax.op.power @classmethod - def _impl_v1(cls, bb, inputs, attr, params): + def _impl_v7(cls, bb, inputs, attr, params): + return cls.base_impl(bb, inputs, attr, params) + + +class Mod(BinaryBase): + """Converts an onnx Mod node into an equivalent Relax expression.""" + + numpy_op = _np.mod + relax_op = relax.op.mod + + @classmethod + def _impl_v10(cls, bb, inputs, attr, params): + if attr.get("fmod", 0) == 0: + cls.numpy_op = _np.fmod + cls.relax_op = relax.op.floor_mod + else: + cls.numpy_op = _np.mod + cls.relax_op = relax.op.mod return cls.base_impl(bb, inputs, attr, params) @@ -523,6 +540,23 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.nn.log_softmax(inputs[0], axis=axis) +class Hardmax(OnnxOpConverter): + """Converts an onnx Hardmax node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr, params): + axis = attr.get("axis", -1) + indices = inputs[0] + dtype = indices.struct_info.dtype + axis_len = int(inputs[0].struct_info.shape[axis]) + argmax = relax.op.argmax(indices, axis=axis) + on_value = relax.PrimValue(tvm.tir.const(1.0, dtype)) + off_value = relax.PrimValue(tvm.tir.const(0.0, dtype)) + + one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis) + return one_hot + + class Transpose(OnnxOpConverter): """Converts an onnx Transpose node into an equivalent Relax expression.""" @@ -731,6 +765,20 @@ def _impl_v1(cls, bb, inputs, attr, params): return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0]))) +class EyeLike(OnnxOpConverter): + """Convert an onnx EyeLike node into an equivalent Relax expression.""" + + @classmethod + def _impl_v9(cls, bb, inputs, attr, params): + k = attr.get("k", 0) + input_dtype = inputs[0].struct_info.dtype + if "dtype" in attr and get_type(attr["dtype"]) != input_dtype: + raise ValueError( + f"dtype mismatch between input ({input_dtype}) and attribute ({attr['dtype']})" + ) + return relax.op.eye_like(inputs[0], k, input_dtype) + + class Gemm(OnnxOpConverter): """Convert an onnx Gemm node into an equivalent Relax expression.""" @@ -2520,13 +2568,13 @@ def _impl_v11(cls, bb, inputs, attr, params): depth = get_constant(inputs[1], params) values = get_constant(inputs[2], params) axis = attr.get("axis", -1) - dtype = values.struct_info.dtype assert isinstance(depth, relax.Constant), "Only constant depth currently supported." depth = depth.data.numpy().tolist() assert isinstance(values, relax.Constant), "Only constant values currently supported." values = values.data.numpy().tolist() off_value, on_value = values - return bb.emit_te(topi.one_hot, indices, on_value, off_value, depth, axis, dtype) + off_value, on_value = relax.PrimValue(off_value), relax.PrimValue(on_value) + return relax.op.one_hot(indices, on_value, off_value, depth, axis) class Unique(OnnxOpConverter): @@ -2800,7 +2848,7 @@ def _get_convert_map(): "Sub": Sub, "Mul": Mul, "Div": Div, - # "Mod": Mod, + "Mod": Mod, "Less": Less, "LessOrEqual": LessOrEqual, "Greater": Greater, @@ -2870,7 +2918,7 @@ def _get_convert_map(): "Sigmoid": Sigmoid, "Softmax": Softmax, "LogSoftmax": LogSoftmax, - # "Hardmax": Hardmax, + "Hardmax": Hardmax, "Transpose": Transpose, "Unsqueeze": Unsqueeze, "Where": Where, @@ -2889,7 +2937,7 @@ def _get_convert_map(): "ScatterND": ScatterND, # "Compress": Compress, "Size": Size, - # "EyeLike": EyeLike, + "EyeLike": EyeLike, # Normalization "BatchNormalization": BatchNormalization, "LayerNormalization": LayerNormalization, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 84b31ccec01e..1603ea2f0f7e 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -50,6 +50,7 @@ divide, equal, floor_divide, + floor_mod, greater, greater_equal, left_shift, @@ -60,6 +61,7 @@ logical_xor, maximum, minimum, + mod, multiply, not_equal, power, @@ -72,6 +74,8 @@ full_like, ones, ones_like, + eye, + eye_like, tril, triu, zeros, @@ -89,6 +93,7 @@ flatten, flip, layout_transform, + one_hot, permute_dims, repeat, reshape, diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 7632235cb32c..7a41c8b0953c 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -139,6 +139,32 @@ def subtract(x1: Expr, x2: Expr) -> Expr: return _ffi_api.subtract(x1, x2) # type: ignore +def mod(x1: Expr, x2: Expr) -> Expr: + """Modulo with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + """ + return _ffi_api.mod(x1, x2) # type: ignore + + +def floor_mod(x1: Expr, x2: Expr) -> Expr: + """Floor modulo with numpy-style broadcasting. + + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + """ + return _ffi_api.floor_mod(x1, x2) # type: ignore + + ###################### Comparison operators ###################### diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index 092d79a74dc4..c61d9521a41d 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -163,6 +163,74 @@ def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: return _ffi_api.zeros_like(x, dtype) # type: ignore +def eye( + n: Union[PrimExprLike, PrimValue], + m: Optional[Union[PrimExprLike, PrimValue]] = None, + k: Union[PrimExprLike, PrimValue] = 0, + dtype: Union[str, DataType] = "float32", +) -> Expr: + """Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. + + Parameters + ---------- + n : Union[PrimExprLike, PrimValue] + Number of rows in the output. + + m : Optional[Union[PrimExprLike, PrimValue]] + Number of columns in the output. If None, defaults to n. + + k : Union[PrimExprLike, PrimValue] + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + m = n if m is None else m + n = n if isinstance(n, PrimValue) else PrimValue(n) + m = m if isinstance(m, PrimValue) else PrimValue(m) + k = k if isinstance(k, PrimValue) else PrimValue(k) + return _ffi_api.eye(n, m, k, dtype) # type: ignore + + +def eye_like( + x: Expr, + k: Union[PrimExprLike, PrimValue] = 0, + dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Return a 2-D tensor with ones on the diagonal and zeros elsewhere, + with the same shape as the input tensor. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + k : Union[PrimExprLike, PrimValue] + Index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + k = k if isinstance(k, PrimValue) else PrimValue(k) + return _ffi_api.eye_like(x, k, dtype) # type: ignore + + def arange( start: Union[PrimExprLike, PrimValue], end: Optional[Union[PrimExprLike, PrimValue]] = None, diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 1673a79b08c2..3210cc821689 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -550,3 +550,47 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat """ return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore + + +def one_hot( + indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1 +) -> Expr: + """Returns a one-hot tensor. + + Parameters + ---------- + indices : relax.Expr + The indices to set to `on_value`. + + on_value : relax.PrimValue + The value to fill at `indices`. + + off_value : relax.PrimValue + The value to fill at other locations. + + depth : int + The depth of the one-hot dimension. + + axis : int, optional + The axis to fill. Default is -1 which adds a new dimension at the end. + + Returns + ------- + result : relax.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + indices = [0, 1, 2] + depth = 3 + on_value = 1 + off_value = 0 + + one_hot(indices, on_value, off_value, depth) = + [[1, 0, 0], + [0, 1, 0], + [0, 0, 1]] + """ + return _ffi_api.one_hot(indices, on_value, off_value, depth, axis) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index d28e100edb9f..41e317f1e0ef 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -48,7 +48,8 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.power", _binary(topi.power)) register_legalize("relax.subtract", _binary(topi.subtract)) register_legalize("relax.equal", _binary(topi.equal)) - +register_legalize("relax.mod", _binary(topi.mod)) +register_legalize("relax.floor_mod", _binary(topi.floor_mod)) register_legalize("relax.greater", _binary(topi.greater)) register_legalize("relax.greater_equal", _binary(topi.greater_equal)) register_legalize("relax.less", _binary(topi.less)) diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index 1b022672d0bd..8bf85e34dee8 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -70,6 +70,36 @@ def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.triu", _tril_triu(is_upper=True, primfunc_name="triu")) +def _eye(is_like: bool, primfunc_name: str) -> LegalizeFunc: + def eye_call_te(bb: BlockBuilder, call: Call) -> Expr: + _convert_to_scalar_const = lambda x: _try_convert_to_scalar_const(x, python_native=True) + if is_like: + x = call.args[0] + k = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else 0 + n, m = x.struct_info.shape + dtype = x.struct_info.dtype + else: + n = _convert_to_scalar_const(call.args[0]) + m = _convert_to_scalar_const(call.args[1]) if len(call.args) > 1 else n + k = _convert_to_scalar_const(call.args[2]) if len(call.args) > 2 else 0 + dtype = call.attrs.dtype + + return bb.call_te( + topi.eye, + n, + m, + k, + dtype, + primfunc_name_hint=primfunc_name, + ) + + return eye_call_te + + +register_legalize("relax.eye", _eye(is_like=False, primfunc_name="eye")) +register_legalize("relax.eye_like", _eye(is_like=True, primfunc_name="eye_like")) + + @register_legalize("relax.arange") def _arange(bb: BlockBuilder, call: Call) -> Expr: assert len(call.args) == 3 diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 105d763403af..163085a07c34 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -185,6 +185,25 @@ def scatter_nd(data, indices, updates, reduction): ) +@register_legalize("relax.one_hot") +def _one_hot(bb: BlockBuilder, call: Call) -> Expr: + indices, on_value, off_value = call.args + if not (isinstance(on_value, relax.PrimValue) and isinstance(off_value, relax.PrimValue)): + raise ValueError("on_value and off_value must be PrimValue") + on_value, off_value = on_value.value, off_value.value + if on_value.dtype != off_value.dtype: + raise ValueError("on_value and off_value must have the same dtype") + return bb.call_te( + topi.one_hot, + indices, + on_value, + off_value, + call.attrs.depth, + call.attrs.axis, + on_value.dtype, + ) + + @register_legalize("relax.layout_transform") def _layout_transform(bb: BlockBuilder, call: Call) -> Expr: def te_layout_transform(data, name): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index f7847e2af8ed..049345fcb10d 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -85,10 +85,13 @@ ewise_fma, exp, expand_dims, + eye, + eye_like, flatten, flip, floor, floor_divide, + floor_mod, full, full_like, grad, @@ -119,6 +122,7 @@ memory, min, minimum, + mod, multinomial_from_uniform, multiply, negative, @@ -127,6 +131,7 @@ null_value, ones, ones_like, + one_hot, permute_dims, power, print, @@ -753,10 +758,13 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "exp", "expand_dims", "ext_dev", + "eye", + "eye_like", "flatten", "flip", "floor", "floor_divide", + "floor_mod", "full", "full_like", "func_attr", @@ -795,6 +803,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "metal", "min", "minimum", + "mod", "multinomial_from_uniform", "multiply", "negative", @@ -802,6 +811,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "null_value", "ones", "ones_like", + "one_hot", "opencl", "output", "permute_dims", diff --git a/python/tvm/topi/tensor.py b/python/tvm/topi/tensor.py index 31ebe86760cb..449c599deaf3 100644 --- a/python/tvm/topi/tensor.py +++ b/python/tvm/topi/tensor.py @@ -16,7 +16,11 @@ # under the License. # pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition """Elementwise operators""" -from __future__ import absolute_import as _abs + +from typing import Optional + +from tvm import te + from . import cpp @@ -73,3 +77,32 @@ def full_like(x, fill_value): The result. """ return cpp.full_like(x, fill_value) + + +def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: str = "float32") -> te.Tensor: + """Generate an identity matrix or a matrix with ones on the k-th diagonal. + + Parameters + ---------- + n : int + Number of rows + m : int, optional + Number of columns. If None, defaults to n. + k : int, optional + Index of the diagonal. 0 (default) refers to the main diagonal. + A positive value refers to an upper diagonal, and a negative value + to a lower diagonal. + dtype : str, optional + Data type of the returned array. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + m = m if m is not None else n + return te.compute( + (n, m), + lambda i, j: te.if_then_else(i == j - k, te.const(1, dtype), te.const(0, dtype)), + name="eye", + ) diff --git a/src/relax/op/distributed/binary.cc b/src/relax/op/distributed/binary.cc index 6ad71e0f85bf..1e7fa8172718 100644 --- a/src/relax/op/distributed/binary.cc +++ b/src/relax/op/distributed/binary.cc @@ -42,6 +42,8 @@ RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(multiply); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(power); RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(subtract); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(mod); +RELAX_REGISTER_BINARY_BROADCAST_DIST_INFER_STRUCT_INFO(floor_mod); /***************** Comparison operators *****************/ diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index f1dc3d4904c8..bd4c681c7925 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -181,6 +181,8 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod); /***************** Comparison operators *****************/ diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 003bcb7e27cf..b66eb96f8452 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -79,6 +79,12 @@ Expr power(Expr x1, Expr x2); /*! \brief Subtraction with numpy-style broadcasting. */ Expr subtract(Expr x1, Expr x2); +/*! \brief Modulo with numpy-style broadcasting. */ +Expr mod(Expr x1, Expr x2); + +/*! \brief Floor modulo with numpy-style broadcasting. */ +Expr floor_mod(Expr x1, Expr x2); + /***************** Comparison operators *****************/ /*! \brief Broadcasted element-wise test for (lhs == rhs). */ diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 7aca1470aee4..8696d85f7756 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -228,6 +228,90 @@ TVM_REGISTER_OP("relax.zeros_like") .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike) .set_attr("FPurity", Bool(true)); +/* relax.eye & relax.eye_like */ +Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.eye"); + return Call(op, {std::move(n), std::move(m), std::move(k)}, Attrs(attrs), {}); +} + +Expr eye_like(Expr x, PrimValue k, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.eye_like"); + return Call(op, {std::move(x), std::move(k)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.eye").set_body_typed(eye); +TVM_REGISTER_GLOBAL("relax.op.eye_like").set_body_typed(eye_like); + +StructInfo InferStructInfoEye(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 3) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye op should have 3 arguments: n, m, and k, but got " << call->args.size() + << " arguments"); + } + + auto get_prim_value = [&ctx](const Expr& expr, std::string key) { + if (!expr->IsInstance()) { + ctx->ReportFatal(Diagnostic::Error(expr) + << "Eye expects the `" << key << "` to be a PrimValue, but got " + << expr->GetTypeKey()); + } + return expr.as()->value; + }; + + PrimExpr n = get_prim_value(call->args[0], "n"); + PrimExpr m = get_prim_value(call->args[1], "m"); + + DataType dtype = call->attrs.as()->dtype; + return TensorStructInfo(ShapeExpr({n, m}), dtype); +} + +StructInfo InferStructInfoEyeLike(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye_like op should have 2 arguments: x and k, but got " + << call->args.size() << " arguments"); + } + + const auto* x_sinfo = GetStructInfoAs(call->args[0]); + if (x_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye_like expects the input `x` to be a Tensor, but got " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (x_sinfo->ndim != 2 && x_sinfo->ndim != kUnknownNDim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Eye_like expects the input tensor to be 2-dimensional, but got " + << x_sinfo->ndim << " dimensions"); + } + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->dtype.is_void() ? x_sinfo->dtype : attrs->dtype; + + return TensorStructInfo(x_sinfo->shape.value(), out_dtype, x_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.eye") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("n", "PrimValue", "Number of rows in the output.") + .add_argument("m", "PrimValue", "Number of columns in the output.") + .add_argument("k", "PrimValue", "Index of the diagonal.") + .set_attr("FInferStructInfo", InferStructInfoEye) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + +TVM_REGISTER_OP("relax.eye_like") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("k", "PrimValue", "Index of the diagonal.") + .set_attr("FInferStructInfo", InferStructInfoEyeLike) + .set_attr("FPurity", Bool(true)); + /* relax.arange */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype) { ObjectPtr attrs = make_object(); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 6e7c8255238a..d88336146d44 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -72,12 +72,48 @@ Expr ones(Expr shape, DataType dtype); */ Expr ones_like(Expr x, DataType dtype); -/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */ +/*! + * \brief Construct a tensor of all zeros, with the input shape and dtype. + * \param shape The shape of the created tensor. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ Expr zeros(Expr shape, DataType dtype); -/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */ +/*! + * \brief Construct a tensor with all zeros, with shape of the input tensor shape. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ Expr zeros_like(Expr x, DataType dtype); +/*! + * \brief Construct a 2-D tensor with ones on the diagonal and zeros elsewhere. + * \param n The number of rows and columns in the output. + * \param m The number of columns in the output. If None, defaults to n. + * \param k The index of the diagonal. A positive value refers to an upper diagonal, + * a negative value to a lower diagonal, and 0 to the main diagonal. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ +Expr eye(PrimValue n, PrimValue m, PrimValue k, DataType dtype); + +/*! + * \brief Construct a tensor with ones on the diagonal and zeros elsewhere, + * with shape and dtype similar to the input tensor. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param k The index of the diagonal. A positive value refers to an upper diagonal, + * a negative value to a lower diagonal, and 0 to the main diagonal. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr eye_like(Expr x, PrimValue k, DataType dtype); + /*! \brief Construct a tensor with evenly spaced elements. */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index ca7d0a0945bc..ba443413025a 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -30,6 +30,8 @@ #include #include +#include "tvm/runtime/data_type.h" + namespace tvm { namespace relax { @@ -1665,5 +1667,78 @@ TVM_REGISTER_OP("relax.scatter_nd") .set_attr("FInferStructInfo", InferStructInfoScatterND) .set_attr("FPurity", Bool(true)); +/* relax.one_hot */ +TVM_REGISTER_NODE_TYPE(OneHotAttrs); +Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) { + ObjectPtr attrs = make_object(); + attrs->depth = depth; + attrs->axis = axis; + + // Check if on_value and off_value have the same dtype + DataType on_dtype = on_value->value->dtype; + DataType off_dtype = off_value->value->dtype; + ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " + << "but got " << on_dtype << " and " << off_dtype; + + ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth; + + static const Op& op = Op::Get("relax.one_hot"); + return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); +} // namespace relax + +TVM_REGISTER_GLOBAL("relax.op.one_hot").set_body_typed(one_hot); + +StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo indices_sinfo = GetInputTensorStructInfo(call, 0, ctx); + const auto* attrs = call->attrs.as(); + PrimValue on_value = Downcast(call->args[1]); + PrimValue off_value = Downcast(call->args[2]); + // Check if on_value and off_value have the same dtype + ICHECK(on_value->value->dtype == off_value->value->dtype) + << "one_hot: on_value and off_value must have the same dtype, " + << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; + DataType dtype = on_value->value->dtype; + + // Check if indices has an integer dtype + if (indices_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type."; + } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "one_hot op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_sinfo->dtype); + } + // Check if indices has unknown dimension + if (indices_sinfo->IsUnknownNdim()) { + return TensorStructInfo(dtype, kUnknownNDim, indices_sinfo->vdevice); + } + // Get the shape of indices + const auto* indices_shape = indices_sinfo->shape.as(); + if (indices_shape == nullptr) { + return TensorStructInfo(dtype, indices_sinfo->ndim + 1, indices_sinfo->vdevice); + } + + Array output_shape = indices_shape->values; + int axis = attrs->axis; + if (axis < 0) { + axis += output_shape.size() + 1; + } + ICHECK(0 <= axis && axis <= static_cast(output_shape.size())) + << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " + << "but got " << axis; + output_shape.insert(output_shape.begin() + axis, attrs->depth); + + return TensorStructInfo(ShapeExpr(output_shape), dtype, indices_sinfo->vdevice); +} + +TVM_REGISTER_OP("relax.one_hot") + .set_attrs_type() + .set_num_inputs(3) + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("on_value", "PrimValue", "The value to fill at specified indices.") + .add_argument("off_value", "PrimValue", "The value to fill at other indices.") + .set_attr("FInferStructInfo", InferStructInfoOneHot) + .set_attr("FPurity", Bool(true)); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index e9fa1131e803..010ceb663ef3 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -27,6 +27,7 @@ #include #include "../op_common.h" +#include "tvm/relax/expr.h" namespace tvm { namespace relax { @@ -206,6 +207,17 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re */ Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction); +/*! + * \brief Returns a one-hot tensor. + * \param indices The indices to set to `on_value`. + * \param on_value The value to fill at `indices`. + * \param off_value The value to fill at other locations. + * \param depth The depth of the one hot dimension. + * \param axis The axis to fill. + * \return The computed result. + */ +Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 1b4c5d281abb..46373510b101 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -63,8 +63,11 @@ def generate_random_inputs( if dtype == "bool": # random_value = np.random.choice(a=[False, True], size=shape) random_value = rg.choice(a=[False, True], size=shape) + elif dtype.startswith("int"): + # Keep non-zero values + random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype) + random_value[random_value <= 0] -= 1 else: - # random_value = np.random.normal(size=shape).astype(dtype) random_value = rg.standard_normal(size=shape).astype(dtype) input_values[i.name] = random_value @@ -246,7 +249,6 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32 ) model = helper.make_model(graph, producer_name="binary_test") - # NOTE: explicitly pass inputs to avoid numerical error check_correctness(model, opset=opset) @@ -327,6 +329,16 @@ def test_binary(op_name: str): verify_binary_scalar(op_name) +@pytest.mark.parametrize("int_mode", [True, False]) +def test_mod(int_mode: bool): + if int_mode: + dtype, fmod = TensorProto.INT32, 0 + else: + dtype, fmod = TensorProto.FLOAT, 1 + verify_binary("Mod", [1, 32], [1, 32], [1, 32], attrs={"fmod": fmod}, dtype=dtype) + verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype) + + @pytest.mark.parametrize("num_inputs", [1, 2, 4]) @pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"]) def test_multi_input(op_name: str, num_inputs: int): @@ -430,6 +442,7 @@ def test_bitwise_shift(direction: str): "Sigmoid", "Softmax", "LogSoftmax", + "Hardmax", "Identity", ], ) @@ -445,7 +458,7 @@ def test_unary(op_name: str): output_dtype = TensorProto.BOOL else: output_dtype = TensorProto.FLOAT - verify_unary(op_name, [32, 32], input_dtype=input_dtype, output_dtype=output_dtype) + verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype) @pytest.mark.parametrize("from_type", [TensorProto.INT32, TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -567,6 +580,11 @@ def test_size(): check_correctness(model) +@pytest.mark.parametrize("k", [-1, 0, 1]) +def test_eye_like(k: int): + verify_unary("EyeLike", [32, 32], attrs={"k": k}) + + @pytest.mark.parametrize("alpha", [None, 0.25, 1.0]) @pytest.mark.parametrize("beta", [None, 0.35, 1.0]) @pytest.mark.parametrize("useC", [False, True]) @@ -966,7 +984,7 @@ def test_cumsum1(): ) model = helper.make_model(graph, producer_name="cumsum_graph") - check_correctness(model) + check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)}) @pytest.mark.parametrize("axis", [[0, 2], None]) diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py index 1e895169f620..67f347019163 100644 --- a/tests/python/relax/test_op_create.py +++ b/tests/python/relax/test_op_create.py @@ -545,6 +545,64 @@ def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.zeros_like(x1)) +def test_eye_infer_struct_info(): + bb = relax.BlockBuilder() + + _check_inference(bb, relax.op.eye(3), relax.TensorStructInfo((3, 3), "float32")) + _check_inference(bb, relax.op.eye(2, 4), relax.TensorStructInfo((2, 4), "float32")) + _check_inference(bb, relax.op.eye(3, dtype="int64"), relax.TensorStructInfo((3, 3), "int64")) + _check_inference(bb, relax.op.eye(3, 5, k=1), relax.TensorStructInfo((3, 5), "float32")) + _check_inference(bb, relax.op.eye(3, 5, k=-2), relax.TensorStructInfo((3, 5), "float32")) + + +def test_eye_infer_struct_info_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + k = tir.Var("k", "int64") + + _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32")) + _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32")) + _check_inference(bb, relax.op.eye(n, k=k), relax.TensorStructInfo((n, n), "float32")) + + +def test_eye_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((2, 5), "int64")) + x2 = relax.Var("x", R.Tensor((3, 3))) + + _check_inference(bb, relax.op.eye_like(x0), relax.TensorStructInfo((3, 4), "float32")) + _check_inference(bb, relax.op.eye_like(x1), relax.TensorStructInfo((2, 5), "int64")) + _check_inference(bb, relax.op.eye_like(x2), relax.TensorStructInfo((3, 3), dtype="")) + _check_inference(bb, relax.op.eye_like(x0, k=1), relax.TensorStructInfo((3, 4), "float32")) + _check_inference( + bb, relax.op.eye_like(x1, dtype="float32"), relax.TensorStructInfo((2, 5), "float32") + ) + + +def test_eye_like_infer_struct_info_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + m = tir.Var("m", "int64") + x = relax.Var("x", R.Tensor((n, m), "float32")) + k = tir.Var("k", "int64") + + _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32")) + _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32")) + + +def test_eye_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.eye_like(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.eye_like(x1)) + + def test_arange_infer_struct_info(): bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index e958b03e4ce6..f6aefc859114 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -3377,5 +3377,57 @@ def test_scatter_nd_infer_struct_info(): ) +def test_one_hot_infer_struct_info(): + bb = relax.BlockBuilder() + + # Test case 1: Basic usage + i0 = relax.Var("indices", R.Tensor((3,), "int32")) + _check_inference( + bb, + relax.op.one_hot(i0, relax.PrimValue(1.0), relax.PrimValue(0.0), 5), + relax.TensorStructInfo((3, 5), "float32"), + ) + + # Test case 2: With specified axis + i1 = relax.Var("indices", R.Tensor((2, 2), "int32")) + _check_inference( + bb, + relax.op.one_hot(i1, relax.PrimValue(1), relax.PrimValue(0), 3, axis=1), + relax.TensorStructInfo((2, 3, 2), "int64"), + ) + + # Test case 3: With symbolic shape + n = tir.Var("n", "int64") + i2 = relax.Var("indices", R.Tensor((n,), "int32")) + _check_inference( + bb, + relax.op.one_hot(i2, relax.PrimValue(1.0), relax.PrimValue(0.0), 4), + relax.TensorStructInfo((n, 4), "float32"), + ) + + # Test case 4: With unknown shape + i3 = relax.Var("indices", R.Tensor("int32")) + _check_inference( + bb, + relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0.0), 6), + relax.TensorStructInfo(dtype="float32"), + ) + + # Test case 5: With different on_value and off_value dtypes + i3 = relax.Var("indices", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.one_hot(i3, relax.PrimValue(1.0), relax.PrimValue(0), 5)) + + # Test case 6: With invalid indices dtype + i4 = relax.Var("indices", R.Tensor((2, 3), "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.one_hot(i4, relax.PrimValue(1.0), relax.PrimValue(0.0), 5)) + + # Test case 7: With invalid depth + i5 = relax.Var("indices", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.one_hot(i5, relax.PrimValue(1.0), relax.PrimValue(0.0), -1)) + + if __name__ == "__main__": tvm.testing.main() From 80250411e706509fef499e0defe0e625bf6fab28 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 17 Oct 2024 04:36:41 +0800 Subject: [PATCH 14/20] [Relax][MetaSchedule] Support CPU weight prepack (#17445) This PR adds support for CPU weight prepacking. To be specific, this PR adds a new pass `AttachAttrLayoutFreeBuffers` to attach layout free buffers to the weight parameters, so that we can leverage MetaSchedule to optimize the prepacking process. After the pass and tuning, we introduce a new pass `SplitLayoutRewritePreproc` to split the layout rewrite pass into multiple functions, so that we can lift the parameters transform pass function with existing pass. --- include/tvm/relax/transform.h | 21 ++ python/tvm/relax/frontend/nn/__init__.py | 2 + python/tvm/relax/pipeline.py | 50 ++- python/tvm/relax/transform/__init__.py | 2 + python/tvm/relax/transform/transform.py | 29 ++ src/meta_schedule/postproc/rewrite_layout.cc | 8 +- .../attach_attr_layout_free_buffers.cc | 113 ++++++ .../transform/split_layout_rewrite_preproc.cc | 327 ++++++++++++++++++ ...t_meta_schedule_postproc_rewrite_layout.py | 3 +- ...ansform_attach_attr_layout_free_buffers.py | 311 +++++++++++++++++ ..._transform_split_layout_rewrite_preproc.py | 220 ++++++++++++ 11 files changed, 1083 insertions(+), 3 deletions(-) create mode 100644 src/relax/transform/attach_attr_layout_free_buffers.cc create mode 100644 src/relax/transform/split_layout_rewrite_preproc.cc create mode 100644 tests/python/relax/test_transform_attach_attr_layout_free_buffers.py create mode 100644 tests/python/relax/test_transform_split_layout_rewrite_preproc.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5a7b85ac1376..eaad44a93ace 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -253,6 +253,27 @@ TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_war */ TVM_DLL Pass RealizeVDevice(); +/*! + * \brief Attach layout free buffers to the tir::PrimFunc. + * + * This pass is used to attach layout free buffers to the tir::PrimFunc according to + * the function usage in the relax function. Currently, the layout free buffers are the model + * weights and relax constants. + * + * \note We recommend applying CanonicalizeBindings before this pass. + * \return The Pass. + */ +TVM_DLL Pass AttachAttrLayoutFreeBuffers(); + +/*! + * \brief Split the layout rewrite preproc block to a separate tir::PrimFunc. + * + * This pass is used in the prepack weight after meta_schedule tuning. + * + * \return The Pass. + */ +TVM_DLL Pass SplitLayoutRewritePreproc(); + /*! * \brief Lift transformation of the parameters of a function. * diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py index a8200d8dd627..f490af7062b0 100644 --- a/python/tvm/relax/frontend/nn/__init__.py +++ b/python/tvm/relax/frontend/nn/__init__.py @@ -23,6 +23,8 @@ from .modules import ( GELU, Conv1D, + Conv2D, + Conv3D, ConvTranspose1D, Embedding, GroupNorm, diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 582f5111aaf5..fe3dbc99fc15 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -109,6 +109,7 @@ def static_shape_tuning_pipeline( total_trials: int, target: Union[str, tvm.target.Target], work_dir: str = "tuning_logs", + cpu_weight_prepack: bool = False, ): """Tune the static shape model and store the log to database. @@ -122,18 +123,65 @@ def static_shape_tuning_pipeline( work_dir : str The directory to store the tuning logs. + + cpu_weight_prepack : bool + Whether to enable the cpu weight prepack feature. + + Note + ---- + `cpu_weight_prepack` is expected to be `True` when running on CPU for + better performance. However, it requires an explicit layout transformation + step by calling the corresponding vm function, which changes the interface + of deployment. So we disable it by default. Here is an example to enable it: + + .. code-block:: python + + mod = relax.pipeline.static_shape_tuning_pipeline( + total_trials=1000, + target="llvm -num-cores 16", + work_dir="tuning_logs", + cpu_weight_prepack=True, + )(mod) + + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device=tvm.cpu()) + + # Transform the params using the vm function + # the name should be f"{func_name}_transform_params" + params = vm["main_transform_params"](params["main"]) + + input_data = tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32")) + out = vm["main"](input_data, *params).numpy() """ @tvm.transform.module_pass(opt_level=0) def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + if cpu_weight_prepack: + pre_tuning_layout_rewrite = [transform.AttachAttrLayoutFreeBuffers()] + post_tuning_layout_rewrite = [ + transform.SplitLayoutRewritePreproc(), + transform.LiftTransformParams(), + transform.FoldConstant(), + ] + else: + pre_tuning_layout_rewrite = [] + post_tuning_layout_rewrite = [] + with tvm.target.Target(target): mod = tvm.transform.Sequential( [ transform.DecomposeOpsForInference(), transform.CanonicalizeBindings(), zero_pipeline(), - transform.MetaScheduleTuneIRMod({}, work_dir, total_trials), + *pre_tuning_layout_rewrite, + # Skip tuning if total_trials is 0 + ( + transform.MetaScheduleTuneIRMod({}, work_dir, total_trials) + if total_trials > 0 + else tvm.transform.Sequential([]) + ), transform.MetaScheduleApplyDatabase(work_dir), + *post_tuning_layout_rewrite, ] )(mod) diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 1ce864651cd9..16e4800ca33d 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -21,6 +21,7 @@ AllocateWorkspace, AlterOpImpl, AnnotateTIROpPattern, + AttachAttrLayoutFreeBuffers, AttachGlobalSymbol, BindParams, BindSymbolicVars, @@ -73,6 +74,7 @@ RewriteDataflowReshape, RunCodegen, SplitCallTIRByPattern, + SplitLayoutRewritePreproc, StaticPlanBlockMemory, ToMixedPrecision, ToNonDataflow, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 3330d4098734..603211b59ebc 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -970,6 +970,35 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass: return _ffi_api.MergeCompositeFunctions() # type: ignore +def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass: + """Attach layout free buffers to the tir::PrimFunc. + + This pass is used to attach layout free buffers to the tir::PrimFunc according to + the function usage in the relax function. Currently, the layout free buffers are the model + weights and relax constants. + + Note that we recommend applying CanonicalizeBindings before this pass. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for attaching layout free buffers. + """ + return _ffi_api.AttachAttrLayoutFreeBuffers() # type: ignore + + +def SplitLayoutRewritePreproc() -> tvm.ir.transform.Pass: + """Split the TIR layout rewrite into multiple TIR functions. + This pass is used in the prepack weight after meta_schedule tuning. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for splitting TIR layout rewrite. + """ + return _ffi_api.SplitLayoutRewritePreproc() # type: ignore + + def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) -> tvm.ir.transform.Pass: """Lift transformation of the parameters of a function. diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 71ae43387112..87fa96f67ceb 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -249,7 +249,13 @@ class RewriteLayoutNode : public PostprocNode { void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode - bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); } + bool Apply(const tir::Schedule& sch) final { + try { + return tir::RewriteLayout(sch); + } catch (const std::runtime_error& e) { + return false; + } + } Postproc Clone() const { ObjectPtr n = make_object(*this); diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc new file mode 100644 index 000000000000..64062e224372 --- /dev/null +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/attach_attr_layout_free_buffers.cc + * \brief Attach layout_free_buffers for layout-free buffers. + */ + +#include +#include +#include + +namespace tvm { +namespace relax { + +class AttrAttacher : public ExprMutator { + public: + static IRModule Transform(const IRModule& mod) { + AttrAttacher mutator(mod); + for (auto [gvar, func] : mod->functions) { + if (func->IsInstance()) { + // clear the layout_free_exprs_ for each function + mutator.layout_free_exprs_.clear(); + mutator.builder_->UpdateFunction(gvar, Downcast(mutator.VisitExpr(func))); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit AttrAttacher(IRModule mod) : ExprMutator(mod), mod_(mod) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const FunctionNode* op) final { + if (auto opt_num_input = op->attrs.GetAttr(attr::kNumInput)) { + ICHECK(layout_free_exprs_.empty()) << "meet a non-global function with num_input attr"; + size_t num_input = opt_num_input.value()->value; + for (size_t i = num_input; i < op->params.size(); i++) { + layout_free_exprs_.insert(op->params[i].get()); + } + } + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const ConstantNode* op) final { + layout_free_exprs_.insert(op); + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const CallNode* op) final { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + Call call = Downcast(ExprMutator::VisitExpr_(op)); + if (call->op != call_tir_op_) { + return call; + } + GlobalVar gv = Downcast(call->args[0]); + Array call_tir_args = Downcast(call->args[1])->fields; + // Compute the layout free buffers + Array layout_free_buffers; + for (size_t i = 0; i < call_tir_args.size(); i++) { + if (layout_free_exprs_.count(call_tir_args[i].get())) { + layout_free_buffers.push_back(Integer(i)); + } + } + // Attach the layout free buffers to the tir::PrimFunc + tir::PrimFunc func = WithAttr(Downcast(mod_->Lookup(gv)), "layout_free_buffers", + layout_free_buffers); + // Renew defs + func = tir::RenewDefs(func); + // Add the updated tir::PrimFunc in the IRModule + // Note the blockbuilder would automatically combine the same tir function + // So we don't need to worry about the duplicate insertion + GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint); + // Create a new call node with the updated tir::PrimFunc + auto n = make_object(*op); + n->args = {new_gv, Tuple(call_tir_args)}; + return Call(n); + } + + private: + IRModule mod_; + std::unordered_set layout_free_exprs_; +}; +namespace transform { + +Pass AttachAttrLayoutFreeBuffers() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod); }; + auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers", {}); + // Apply DeadCodeElimination to remove unused tir::PrimFunc + return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); +} + +TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers") + .set_body_typed(AttachAttrLayoutFreeBuffers); +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc new file mode 100644 index 000000000000..5fee946c26dd --- /dev/null +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/split_tir_layout_rewrite.cc + * \brief Use for rewriting the TIRs after meta_schedule layout rewrite post process. + */ +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tir { +class SplitPrimFuncLayoutRewrite : public StmtMutator { + public: + explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} + std::tuple, PrimFunc> Transform(const PrimFunc& func) { + ICHECK(func->body.as()) << "The body of the primfunc should be a root block."; + const auto& block = func->body.as()->block; + visit_root_block(block.get()); + if (layout_rewrite_preproc_stmts_.size() > 0) { + return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func()); + } else { + return std::make_tuple(NullOpt, func); + } + } + + private: + void sort_rewrite_infos() { + std::sort( + rewrite_infos_.begin(), rewrite_infos_.end(), + [](const RewriteInfo& a, const RewriteInfo& b) { return a.buffer_index < b.buffer_index; }); + } + + PrimFunc create_layout_rewrite_preproc_func() const { + // Step 1: Check the number of pre_rewrite_buffers and post_rewrite_buffers + ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; + + // Step 2: Create the params for the new PrimFunc + Array params; + Map buffer_map; + + for (const auto& info : rewrite_infos_) { + params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle())); + buffer_map.Set(params.back(), info.pre_rewrite_buffer); + } + for (const auto& info : rewrite_infos_) { + params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle())); + buffer_map.Set(params.back(), info.post_rewrite_buffer); + } + + // Step 3: Create the body for the new PrimFunc + ICHECK(layout_rewrite_preproc_stmts_.size() > 0) + << "There should be at least one layout rewrite preproc stmt."; + Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] + : SeqStmt(layout_rewrite_preproc_stmts_); + body = BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"root", body)); + + PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map); + + return RenewDefs(func); + } + + PrimFunc create_compute_func() const { + // Step 1: Create the params for the new PrimFunc + Array params = original_func_->params; + Map buffer_map = original_func_->buffer_map; + for (const auto& info : rewrite_infos_) { + const Var& param = params[info.buffer_index]; + ICHECK(buffer_map[param] == info.pre_rewrite_buffer); + buffer_map.Set(param, info.post_rewrite_buffer); + } + + // Step 2: Create the body for the new PrimFunc + Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] : SeqStmt(compute_stmts_); + Block original_block = original_func_->body.as()->block; + Array alloc_buffers; + for (const auto& buffer : original_block->alloc_buffers) { + auto it = + std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(), + [&](const RewriteInfo& info) { return info.post_rewrite_buffer == buffer; }); + if (it == rewrite_infos_.end()) { + alloc_buffers.push_back(buffer); + } + } + + body = BlockRealize( + /*iter_values=*/Array(), + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, + /*name_hint=*/"root", body, + /*init=*/NullOpt, + /*alloc_buffers=*/alloc_buffers)); + + PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map); + return RenewDefs(func); + } + + void visit_root_block(const BlockNode* op) { + Stmt body = op->body; + if (const auto* seq_stmt = body.as()) { + for (const auto& stmt : seq_stmt->seq) { + current_subtree_ = 0; + Stmt new_stmt = this->VisitStmt(stmt); + ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree."; + if (current_subtree_ == 1) { + layout_rewrite_preproc_stmts_.push_back(new_stmt); + } else { + compute_stmts_.push_back(new_stmt); + } + } + } else { + current_subtree_ = 0; + this->VisitStmt(body); + ICHECK(current_subtree_ == -1) + << "There should be a compute block if there is only one subtree under the root."; + } + } + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtMutator::VisitStmt_(op)); + auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc); + bool is_layout_rewrite_preproc = + it != op->annotations.end() && is_one(Downcast((*it).second)); + + if (current_subtree_ == 0) { + current_subtree_ = is_layout_rewrite_preproc ? 1 : -1; + } else if (current_subtree_ == 1) { + CHECK(is_layout_rewrite_preproc) + << "There is a layout rewrite block in the subtree, but meet a non-layout rewrite block."; + } else { + CHECK(!is_layout_rewrite_preproc) + << "There is a non-layout rewrite block in the subtree, but meet a layout rewrite block."; + } + + if (is_layout_rewrite_preproc) { + ICHECK(op->reads.size() == 1) << "There should be only one read buffer in the layout rewrite"; + ICHECK(op->writes.size() == 1) + << "There should be only one write buffer in the layout rewrite"; + ICHECK(op->alloc_buffers.empty()) << "There should be no alloc buffer in the layout rewrite"; + ICHECK(op->match_buffers.empty()) << "There should be no match buffer in the layout rewrite"; + const Buffer& preproc_buffer = op->reads[0]->buffer; + int buffer_index = -1; + for (size_t i = 0; i < original_func_->params.size(); ++i) { + const Buffer& buffer = original_func_->buffer_map[original_func_->params[i]]; + if (buffer == preproc_buffer) { + buffer_index = i; + break; + } + } + ICHECK(buffer_index != -1) << "The preproc buffer is not found in the original primfunc."; + rewrite_infos_.push_back( + RewriteInfo{buffer_index, op->reads[0]->buffer, op->writes[0]->buffer}); + + auto new_annotations = op->annotations; + new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc); + auto n = make_object(*block.get()); + n->annotations = new_annotations; + return Block(n); + } + return block; + } + + public: + struct RewriteInfo { + int buffer_index; + Buffer pre_rewrite_buffer; + Buffer post_rewrite_buffer; + }; + std::vector rewrite_infos_; + + private: + /*! \brief The stmts that are used for layout rewrite preproc*/ + Array layout_rewrite_preproc_stmts_; + /*! \brief The stmts that are other than layout rewrite preproc*/ + Array compute_stmts_; + /*! + \brief Whether the current subtree is a layout rewrite preproc subtree. + -1: visited a non-layout rewrite preproc block + 0: unsure, not visited any block + 1: visited a layout rewrite preproc block + */ + int current_subtree_; + /*! \brief The original primfunc*/ + PrimFunc original_func_; +}; +} // namespace tir + +namespace relax { +class SplitLayoutRewritePreproc : public ExprMutator { + public: + static IRModule Transform(const IRModule& mod) { + SplitLayoutRewritePreproc mutator(mod); + + // Step 1: Split the primfunc into preproc and compute + for (auto [gv, func] : mod->functions) { + if (func->IsInstance()) { + tir::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast(func)); + auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast(func)); + if (preproc_func.defined()) { + mutator.split_funcs_.emplace(gv.get(), + std::make_tuple(preproc_func.value(), compute_func)); + mutator.rewrite_infos_.emplace(gv.get(), tir_rewriter.rewrite_infos_); + } + } + } + + for (auto [gv, func] : mod->functions) { + if (func->IsInstance()) { + auto relax_func = Downcast(func); + mutator.builder_->UpdateFunction(gv, Downcast(mutator(relax_func))); + } + } + return mutator.builder_->GetContextIRModule(); + } + + private: + explicit SplitLayoutRewritePreproc(const IRModule& mod) : ExprMutator(mod) {} + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* op) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + Call call = Downcast(ExprMutator::VisitExpr_(op)); + + // Step 1: Skip call to other than `tir.call_tir` + if (!call->op.same_as(call_tir_op)) { + return call; + } + + // Step 2: Skip if there is no preproc stage + const GlobalVar gv = Downcast(call->args[0]); + auto it = split_funcs_.find(gv.get()); + if (it == split_funcs_.end()) { + return call; + } + + // Step 3: Get the preproc and compute functions and update the module + const auto& [preproc_func, compute_func] = it->second; + GlobalVar preproc_gv = builder_->AddFunction(preproc_func, gv->name_hint + "_weight_prepack"); + GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint + "_prepacked"); + // Step 4. Get rewrite infos + auto rewrite_infos_it = rewrite_infos_.find(gv.get()); + ICHECK(rewrite_infos_it != rewrite_infos_.end()) + << "Rewrite infos are not found for " << gv->name_hint; + const auto& rewrite_infos = rewrite_infos_it->second; + + // Step 5: Emit the preproc call + Array call_tir_args = Downcast(call->args[1])->fields; + Array preproc_args; + Array preproc_sinfo_list; + for (const auto& info : rewrite_infos) { + preproc_args.push_back(call_tir_args[info.buffer_index]); + tir::Buffer rewritten_buffer = info.post_rewrite_buffer; + for (const auto& shape_expr : rewritten_buffer->shape) { + CHECK(shape_expr.as()) << "Currently does not support rewrite buffer with " + "dynamic shape."; + } + preproc_sinfo_list.push_back( + TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype)); + } + StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1 // + ? TupleStructInfo(preproc_sinfo_list) // + : preproc_sinfo_list[0]; + + // Step 6: Call the preproc function + Expr preproc_call = + builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_sinfo})); + if (rewrite_infos.size() == 1) { + call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call); + } else { + for (size_t i = 0; i < rewrite_infos.size(); ++i) { + call_tir_args.Set(rewrite_infos[i].buffer_index, TupleGetItem(preproc_call, i)); + } + } + Expr main_call = + builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->sinfo_args)); + + return main_call; + } + + private: + std::unordered_map> split_funcs_; + std::unordered_map> + rewrite_infos_; +}; + +} // namespace relax + +namespace transform { +Pass SplitLayoutRewritePreproc() { + auto pass_func = [](IRModule mod, PassContext pc) { + return relax::SplitLayoutRewritePreproc::Transform(mod); + }; + auto pass = CreateModulePass(pass_func, 0, "SplitLayoutRewritePreproc", {}); + return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()}, + "SplitLayoutRewritePreproc"); +} +TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc") + .set_body_typed(SplitLayoutRewritePreproc); +} // namespace transform +} // namespace tvm diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py index e2305de2afaf..8348c57c1949 100644 --- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py @@ -61,7 +61,8 @@ def inner(mod): ) sch = tvm.tir.Schedule(mod, debug_mask="all") sch.enter_postproc() - assert ctx.space_generator.postprocs[0].apply(sch) + if not ctx.space_generator.postprocs[0].apply(sch): + raise tvm.TVMError("RewriteLayout postproc failed") return sch.mod return inner diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py new file mode 100644 index 000000000000..46f7c8aa87be --- /dev/null +++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py @@ -0,0 +1,311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm.testing + +from tvm import relax, tir +from tvm.script import relax as R, tir as T, ir as I +from tvm.relax.transform import CombineParallelMatmul +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def test_param(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((32, 32), "float32")) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv = R.call_tir(cls.matmul1, (x, y), out_sinfo=R.Tensor((32, 32), "float32")) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_const(): + const_value = np.ones((32, 32), dtype="float32") + + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.matmul, + (x, relax.const(const_value)), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main(x: R.Tensor((32, 32), "float32")): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + gv = R.call_tir( + cls.matmul1, + (x, relax.const(const_value)), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_multiple_same_func(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul, + (lv1, w2), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul1, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul1, + (lv1, w2), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +def test_multiple_same_func_with_different_free_buffers(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def matmul( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul, + (w2, lv1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def matmul1( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @T.prim_func(private=True) + def matmul2( + A: T.Buffer((T.int64(32), T.int64(32)), "float32"), + B: T.Buffer((T.int64(32), T.int64(32)), "float32"), + C: T.Buffer((T.int64(32), T.int64(32)), "float32"), + ): + T.func_attr({"layout_free_buffers": [0]}) + for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)): + with T.block("C"): + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + @R.function + def main( + x: R.Tensor((32, 32), "float32"), + w1: R.Tensor((32, 32), "float32"), + w2: R.Tensor((32, 32), "float32"), + ): + R.func_attr({"num_input": 1}) + cls = Expected + with R.dataflow(): + lv1 = R.call_tir( + cls.matmul1, + (x, w1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + gv = R.call_tir( + cls.matmul2, + (w2, lv1), + out_sinfo=R.Tensor((32, 32), "float32"), + ) + R.output(gv) + return gv + + after = relax.transform.AttachAttrLayoutFreeBuffers()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py new file mode 100644 index 000000000000..e6b4c8ec4e2a --- /dev/null +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing +from tvm import relax +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + + +def test_single_buffer(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def tir_func( + X: T.Buffer((224, 224), "float32"), + W: T.Buffer((224, 224), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + T.func_attr({"layout_free_buffers": [1]}) + W_rewrite = T.alloc_buffer((4, 4, 56, 56)) + for i, j in T.grid(224, 224): + with T.block("W_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def tir_func_prepacked( + X: T.Buffer((224, 224), "float32"), + W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + @T.prim_func(private=True) + def tir_func_weight_prepack( + W: T.Buffer((224, 224), "float32"), + W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + ): + for i, j in T.grid(224, 224): + with T.block("W_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = After + with R.dataflow(): + lv = R.call_tir( + cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32") + ) + lv1 = R.call_tir( + cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32") + ) + gv: R.Tensor((224, 224), dtype="float32") = lv1 + R.output(gv) + return gv + + mod = relax.transform.SplitLayoutRewritePreproc()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +def test_multiple_buffers(): + @I.ir_module + class Before: + @T.prim_func(private=True) + def tir_func( + X: T.Buffer((224, 224), "float32"), + W1: T.Buffer((224, 224), "float32"), + W2: T.Buffer((224, 224), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + W1_rewrite = T.alloc_buffer((4, 4, 56, 56)) + W2_rewrite = T.alloc_buffer((4, 4, 56, 56)) + for i, j in T.grid(224, 224): + with T.block("W1_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj] + for i, j in T.grid(224, 224): + with T.block("W2_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)}) + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj] + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = ( + X[vi, vj] + + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + ) + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w1: R.Tensor((224, 224), dtype="float32"), + w2: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = Before + with R.dataflow(): + gv = R.call_tir( + cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224), dtype="float32") + ) + R.output(gv) + return gv + + @I.ir_module + class After: + @T.prim_func(private=True) + def tir_func_prepacked( + X: T.Buffer((224, 224), "float32"), + W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + Out: T.Buffer((224, 224), "float32"), + ): + for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): + with T.block("Out"): + vi = T.axis.spatial(224, i0 * 56 + i1) + vj = T.axis.spatial(224, j0 * 56 + j1) + Out[vi, vj] = ( + X[vi, vj] + + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] + ) + + @T.prim_func(private=True) + def tir_func_weight_prepack( + W1: T.Buffer((224, 224), "float32"), + W2: T.Buffer((224, 224), "float32"), + W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"), + ): + for i, j in T.grid(224, 224): + with T.block("W1_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj] + for i, j in T.grid(224, 224): + with T.block("W2_rewrite"): + vi, vj = T.axis.remap("SS", [i, j]) + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj] + + @R.function + def forward( + x: R.Tensor((224, 224), dtype="float32"), + w1: R.Tensor((224, 224), dtype="float32"), + w2: R.Tensor((224, 224), dtype="float32"), + ) -> R.Tensor((224, 224), dtype="float32"): + R.func_attr({"num_input": 1}) + cls = After + with R.dataflow(): + lv0 = R.call_tir( + cls.tir_func_weight_prepack, + (w1, w2), + out_sinfo=[ + R.Tensor((4, 4, 56, 56), "float32"), + R.Tensor((4, 4, 56, 56), "float32"), + ], + ) + lv1 = R.call_tir( + cls.tir_func_prepacked, + (x, lv0[0], lv0[1]), + out_sinfo=R.Tensor((224, 224), "float32"), + ) + gv: R.Tensor((224, 224), dtype="float32") = lv1 + R.output(gv) + return gv + + mod = relax.transform.SplitLayoutRewritePreproc()(Before) + tvm.ir.assert_structural_equal(mod, After) + + +if __name__ == "__main__": + tvm.testing.main() From f75b563e19d9652b57a6be7286fbb1b28df09ed4 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Thu, 17 Oct 2024 21:17:17 +0300 Subject: [PATCH 15/20] [LLVM][Arith] Presburger compile fix for MLIR/LLVM 19.x (#17469) --- src/arith/presburger_set.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 3798ba190446..4f4d7e18578f 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -215,7 +215,9 @@ PresburgerSet Intersect(const Array& sets) { IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { Array tvm_coeffs = DetectLinearEquation(e, set->GetVars()); -#if TVM_MLIR_VERSION >= 160 +#if TVM_MLIR_VERSION >= 190 + SmallVector coeffs; +#elif TVM_MLIR_VERSION >= 160 SmallVector coeffs; #else SmallVector coeffs; @@ -223,7 +225,9 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { coeffs.reserve(tvm_coeffs.size()); for (const PrimExpr& it : tvm_coeffs) { -#if TVM_MLIR_VERSION >= 160 +#if TVM_MLIR_VERSION >= 190 + coeffs.push_back(llvm::DynamicAPInt(*as_const_int(it))); +#elif TVM_MLIR_VERSION >= 160 coeffs.push_back(mlir::presburger::MPInt(*as_const_int(it))); #else coeffs.push_back(*as_const_int(it)); From 031508394802a96090ada8314e9ef698a359a42d Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 18 Oct 2024 06:53:23 +0900 Subject: [PATCH 16/20] [CI] Pin cpplint==1.6.1 (#17470) use cpplint==1.6.1 --- docker/Dockerfile.ci_lint | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index bab0cd0ebf9c..89749b75bca8 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -38,7 +38,7 @@ ENV PYTHONNOUSERSITE 1 # Disable .local directory from affecting CI. RUN apt-get update && apt-install-and-clear -y doxygen graphviz curl shellcheck -RUN pip3 install cpplint pylint==2.17.2 mypy==0.902 black==22.12.0 flake8==3.9.2 blocklint==0.2.3 jinja2==3.0.3 +RUN pip3 install cpplint==1.6.1 pylint==2.17.2 mypy==0.902 black==22.12.0 flake8==3.9.2 blocklint==0.2.3 jinja2==3.0.3 # Rust env (build early; takes a while) COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh From 72f5d98e19c2d2cf2203441ca2f665109b290fbd Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Mon, 21 Oct 2024 21:49:06 +0900 Subject: [PATCH 17/20] Pin pytest-profiling==1.7.0 (#17476) --- docker/install/ubuntu2004_install_python_package.sh | 2 +- docker/install/ubuntu_install_python_package.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install/ubuntu2004_install_python_package.sh b/docker/install/ubuntu2004_install_python_package.sh index f1c03cf1c0e2..c72ea5d4fa66 100644 --- a/docker/install/ubuntu2004_install_python_package.sh +++ b/docker/install/ubuntu2004_install_python_package.sh @@ -35,7 +35,7 @@ pip3 install --upgrade \ psutil \ pytest \ git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@768ec1dce349fe4708f6ad68be1ebb3f3dabafa1 \ - pytest-profiling \ + pytest-profiling==1.7.0 \ pytest-xdist \ pytest-rerunfailures==10.2 \ requests \ diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 593ba15f5947..7fe82a1db414 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -35,7 +35,7 @@ pip3 install --upgrade \ psutil \ pytest \ git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@768ec1dce349fe4708f6ad68be1ebb3f3dabafa1 \ - pytest-profiling \ + pytest-profiling!=1.8.0 \ pytest-xdist \ pytest-rerunfailures==10.2 \ requests \ From b38417cd0047dc27d562b63bfac9f93227db3491 Mon Sep 17 00:00:00 2001 From: Piotr eF Date: Mon, 21 Oct 2024 14:49:17 +0200 Subject: [PATCH 18/20] =?UTF-8?q?[Device][OpenCL]=20add=20CL=5FEXEC=5FSTAT?= =?UTF-8?q?US=5FERROR=5FFOR=5FEVENTS=5FIN=5FWAIT=5FLIST=20to=20=E2=80=A6?= =?UTF-8?q?=20(#17472)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [Device][OpenCL] add CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST to check function Co-authored-by: pfk-beta --- src/runtime/opencl/opencl_common.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 8c1607c4e56f..f752a487ea7e 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -171,6 +171,8 @@ inline const char* CLGetErrorString(cl_int error) { return "CL_INVALID_BUFFER_SIZE"; case CL_INVALID_MIP_LEVEL: return "CL_INVALID_MIP_LEVEL"; + case CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST: + return "CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST"; default: return "Unknown OpenCL error code"; } From 3219b49c2f985440d5b35868f37a2f141ebc5359 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 23 Oct 2024 08:31:58 +0900 Subject: [PATCH 19/20] [CI] Revert jax, keras, tensorflow, and tflite upgrades introduced #17425 (#17485) Revert part of "[CI] Upgrade CI (#17425)" change the versions of jax, tensorflow, tflite back to what we've been using before --- docker/install/ubuntu_install_jax.sh | 18 ++++----- docker/install/ubuntu_install_tensorflow.sh | 4 +- .../ubuntu_install_tensorflow_aarch64.sh | 4 +- docker/install/ubuntu_install_tflite.sh | 40 +++++++++---------- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh index 17114e0efce8..19149909161e 100644 --- a/docker/install/ubuntu_install_jax.sh +++ b/docker/install/ubuntu_install_jax.sh @@ -20,18 +20,16 @@ set -e set -u set -o pipefail -JAX_VERSION=0.4.30 - -# Install jaxlib +# Install jax and jaxlib if [ "$1" == "cuda" ]; then - pip install -U \ - "jax[cuda12]~=${JAX_VERSION}" \ - jaxlib~=${JAX_VERSION} + pip3 install --upgrade \ + jaxlib~=0.4.9 \ + "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else - pip3 install -U \ - jax~=${JAX_VERSION} \ - jaxlib~=${JAX_VERSION} + pip3 install --upgrade \ + jaxlib~=0.4.9 \ + "jax[cpu]~=0.4.9" fi # Install flax -pip3 install flax~=0.8.5 +pip3 install flax~=0.6.9 diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 012b678916b3..2225b7aef3b8 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -21,5 +21,5 @@ set -u set -o pipefail pip3 install \ - keras==3.5 \ - tensorflow==2.17.0 + keras==2.9 \ + tensorflow==2.9.1 diff --git a/docker/install/ubuntu_install_tensorflow_aarch64.sh b/docker/install/ubuntu_install_tensorflow_aarch64.sh index 4b158948387b..fcd912a4478a 100755 --- a/docker/install/ubuntu_install_tensorflow_aarch64.sh +++ b/docker/install/ubuntu_install_tensorflow_aarch64.sh @@ -25,5 +25,5 @@ apt-install-and-clear -y --no-install-recommends libhdf5-dev # h5py wheel tries to use the wrong .so file pip3 install \ numpy==1.23.5 \ - keras==3.5 \ - tensorflow-aarch64~=2.16.1 + keras==2.9 \ + tensorflow-aarch64~=2.9.3 diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index 8faabc022640..36e6dfc42794 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -26,11 +26,11 @@ set -o pipefail TENSORFLOW_VERSION=$(python3 -c "import tensorflow; print(tensorflow.__version__)" 2> /dev/null) # Download, build and install flatbuffers -git clone --branch=v24.3.25 --depth=1 --recursive https://github.com/google/flatbuffers.git -pushd flatbuffers - cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess" - ninja install -j8 -popd +git clone --branch=v1.12.0 --depth=1 --recursive https://github.com/google/flatbuffers.git +cd flatbuffers +cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-Wno-class-memaccess" +make install -j8 +cd .. # Install flatbuffers python packages. pip3 install flatbuffers @@ -41,22 +41,22 @@ pip3 install flatbuffers git clone https://github.com/tensorflow/tensorflow --branch=v${TENSORFLOW_VERSION} --depth 1 mkdir -p /opt/tflite -pushd /opt/tflite - cmake -G Ninja \ - -DTFLITE_ENABLE_XNNPACK=OFF \ - /tensorflow/tensorflow/lite +cd /opt/tflite +cmake \ + -DTFLITE_ENABLE_XNNPACK=OFF \ + /tensorflow/tensorflow/lite + +cmake --build . +cd - - cmake --build . -popd # Setup tflite from schema mkdir tflite -find / -name "schema.fbs" -cp /tensorflow/tensorflow/lite/stablehlo/schema/schema.fbs tflite -pushd tflite - flatc --python schema.fbs +cp tensorflow/tensorflow/lite/schema/schema.fbs tflite +cd tflite +flatc --python schema.fbs - cat <setup.py +cat <setup.py import setuptools setuptools.setup( @@ -77,12 +77,12 @@ setuptools.setup( ) EOM - cat <__init__.py +cat <__init__.py name = "tflite" EOM - # Install tflite over python3 - python3 setup.py install +# Install tflite over python3 +python3 setup.py install -popd +cd .. rm -rf tflite From d973b33f7f1b5a0244593260ee807b7dc64a1333 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 23 Oct 2024 08:32:16 +0900 Subject: [PATCH 20/20] Replace `np.int` with `np.int32` (#17484) --- tests/python/topi/test_topi_depthwise_conv2d_back_input.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/test_topi_depthwise_conv2d_back_input.py b/tests/python/topi/test_topi_depthwise_conv2d_back_input.py index b0a263172010..5087b0047315 100644 --- a/tests/python/topi/test_topi_depthwise_conv2d_back_input.py +++ b/tests/python/topi/test_topi_depthwise_conv2d_back_input.py @@ -36,8 +36,8 @@ def verify_depthwise_conv2d_back_input( stride_w = stride_h padding_w = padding_h - out_h = np.int((in_h + 2 * padding_h - filter_h) / stride_h + 1) - out_w = np.int((in_w + 2 * padding_w - filter_w) / stride_w + 1) + out_h = np.int32((in_h + 2 * padding_h - filter_h) / stride_h + 1) + out_w = np.int32((in_w + 2 * padding_w - filter_w) / stride_w + 1) out_channel = in_channel * channel_multiplier ishape = [batch, in_h, in_w, in_channel]