Skip to content

Commit

Permalink
[Unity][TVMScript] Parse R.Object return type from call_pure_packed (#…
Browse files Browse the repository at this point in the history
…16593)

Prior to this commit, `R.call_packed` and `R.call_pure_packed` had
different normalization for the `sinfo_args` argument.  While
`R.call_packed` checked if the struct info needed to be converted using
`ObjectGeneric.asobject()`, `R.call_pure_packed` did not.

This commit updates the `R.call_pure_packed` to handle `sinfo_args`
in the same manner as `R.call_packed`.
  • Loading branch information
Lunderberg authored Feb 19, 2024
1 parent 1e6482c commit dd70941
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
import tvm.runtime
from tvm.runtime.object import Object
from tvm.runtime import ObjectGeneric

from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
Expand Down Expand Up @@ -709,12 +710,24 @@ def call_pure_packed(
func = func.global_symbol

op = ExternFunc(func)

if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")

if isinstance(sinfo_args, tuple): # type: ignore
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
sinfo_args = [sinfo_args]

sinfo_args = [
sinfo()
if callable(sinfo)
else sinfo.asobject()
if isinstance(sinfo, ObjectGeneric)
else sinfo
for sinfo in sinfo_args
]

# note: if we need attributes, we can also take them here

return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member
Expand Down
16 changes: 9 additions & 7 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,15 @@ def call_packed(
sinfo_args = list(sinfo_args)
elif not isinstance(sinfo_args, list):
sinfo_args = [sinfo_args]
for i, sinfo_arg in enumerate(sinfo_args):
if callable(sinfo_arg):
sinfo_arg = sinfo_arg()
# Convert possible StructInfoProxy to StructInfo
if isinstance(sinfo_arg, ObjectGeneric):
sinfo_arg = sinfo_arg.asobject()
sinfo_args[i] = sinfo_arg

sinfo_args = [
sinfo()
if callable(sinfo)
else sinfo.asobject()
if isinstance(sinfo, ObjectGeneric)
else sinfo
for sinfo in sinfo_args
]

is_default = False
if "attrs_type_key" in kwargs:
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,20 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
_check(foo, bb.get()["foo"])


def test_call_pure_packed_returning_object():
@R.function
def foo() -> R.Object:
z = R.call_pure_packed("dummy_func", sinfo_args=R.Object)
return z

bb = relax.BlockBuilder()
with bb.function("foo", params=[]):
z = bb.emit(R.call_pure_packed("dummy_func", sinfo_args=[relax.ObjectStructInfo()]))
bb.emit_func_output(z)

_check(foo, bb.get()["foo"])


def test_private_function():
@I.ir_module
class Addition:
Expand Down

0 comments on commit dd70941

Please sign in to comment.