Skip to content

Commit

Permalink
[Unity][TIR] Clear struct info when specializing PrimFunc (apache#16584)
Browse files Browse the repository at this point in the history
In rare cases, a `PrimFunc` may be annotated with `StructInfo`, to
indicate that it is an impure function with specific shapes for the
parameters.  If struct info is present, it is invalidated when
specializing a `PrimFunc`, and should be cleared.
  • Loading branch information
Lunderberg authored and thaisacs committed Apr 3, 2024
1 parent fb506fd commit 8987c7c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class PrimFuncSpecializer : public StmtExprMutator {
f_ptr->params = std::move(params);
f_ptr->buffer_map = std::move(buffer_map);
f_ptr->body = std::move(body);
f_ptr->struct_info_ = NullOpt;
f_ptr->checked_type_ = Type(nullptr);
}
return f;
}
Expand Down
38 changes: 38 additions & 0 deletions tests/python/tir-base/test_tir_specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=missing-function-docstring, missing-module-docstring

import pytest

import tvm
from tvm.script import tir as T
from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol
Expand Down Expand Up @@ -324,5 +326,41 @@ def expected(A_data: T.handle("float32")):
tvm.ir.assert_structural_equal(expected, after)


def test_specialization_removes_struct_info():
"""Reset struct info in specialization
While a PrimFunc usually doesn't have a `relax.StructInfo`, the
field can be populated in some edge cases. If that PrimFunc is
specialized, the struct info should be reset.
"""

@T.prim_func(private=True)
def before(n: T.int32) -> T.int32:
T.ret(n * 10)

@T.prim_func(private=True)
def expected() -> T.int32:
T.ret(50)

sinfo = tvm.relax.FuncStructInfo(
[tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32")
)
tvm.relax.expr._update_struct_info(before, sinfo)

n = before.params[0]
param_map = {n: 5}
after = before.specialize(param_map)

tvm.ir.assert_structural_equal(expected, after)
assert before.struct_info is not None

# PrimFuncs do not expose the `struct_info_` field. Checking the
# `struct_info` field when it isn't set raises an exception. This
# is the desired behavior, since the struct info before
# specialization is no longer valid.
with pytest.raises(tvm.TVMError):
after.struct_info


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 8987c7c

Please sign in to comment.