From 8987c7c2787e2e779c29d34ea56a691c380fe966 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 9 Mar 2024 07:13:45 -0600 Subject: [PATCH] [Unity][TIR] Clear struct info when specializing PrimFunc (#16584) In rare cases, a `PrimFunc` may be annotated with `StructInfo`, to indicate that it is an impure function with specific shapes for the parameters. If struct info is present, it is invalidated when specializing a `PrimFunc`, and should be cleared. --- src/tir/ir/specialize.cc | 2 ++ tests/python/tir-base/test_tir_specialize.py | 38 ++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 5964f02932997..8095b3141fbf5 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -109,6 +109,8 @@ class PrimFuncSpecializer : public StmtExprMutator { f_ptr->params = std::move(params); f_ptr->buffer_map = std::move(buffer_map); f_ptr->body = std::move(body); + f_ptr->struct_info_ = NullOpt; + f_ptr->checked_type_ = Type(nullptr); } return f; } diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index f695b8522594f..fd2843f743be3 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=missing-function-docstring, missing-module-docstring +import pytest + import tvm from tvm.script import tir as T from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol @@ -324,5 +326,41 @@ def expected(A_data: T.handle("float32")): tvm.ir.assert_structural_equal(expected, after) +def test_specialization_removes_struct_info(): + """Reset struct info in specialization + + While a PrimFunc usually doesn't have a `relax.StructInfo`, the + field can be populated in some edge cases. If that PrimFunc is + specialized, the struct info should be reset. + """ + + @T.prim_func(private=True) + def before(n: T.int32) -> T.int32: + T.ret(n * 10) + + @T.prim_func(private=True) + def expected() -> T.int32: + T.ret(50) + + sinfo = tvm.relax.FuncStructInfo( + [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32") + ) + tvm.relax.expr._update_struct_info(before, sinfo) + + n = before.params[0] + param_map = {n: 5} + after = before.specialize(param_map) + + tvm.ir.assert_structural_equal(expected, after) + assert before.struct_info is not None + + # PrimFuncs do not expose the `struct_info_` field. Checking the + # `struct_info` field when it isn't set raises an exception. This + # is the desired behavior, since the struct info before + # specialization is no longer valid. + with pytest.raises(tvm.TVMError): + after.struct_info + + if __name__ == "__main__": tvm.testing.main()