Skip to content

Commit

Permalink
[TVMScript][Bugfix] Check for StructInfoProxy in R.match_cast
Browse files Browse the repository at this point in the history
Prior to this commit, bare `StructInfoProxy` annotations could be used
to annotate variables (e.g. `var: R.Tensor`).  However, they could
not be used as the argument of a match cast (e.g. `R.match_cast(obj,
R.Tensor)`).  This breaks round-trips, as the `R.match_cast` printing
generates base `StructInfoProxy` objects.

This commit updates TVMScript parsing to handle bare `StructInfoProxy`
annotations as an argument to `R.match_cast`.
  • Loading branch information
Lunderberg committed Mar 12, 2024
1 parent 69995d1 commit 94ce183
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 13 deletions.
24 changes: 24 additions & 0 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

############################## R.function ##############################


# this formulation allows us to support having @R.function
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
Expand Down Expand Up @@ -488,8 +489,31 @@ def __init__(self, value: Expr, struct_info: StructInfo) -> None:


def match_cast(value: Expr, struct_info: StructInfo):
struct_info = _normalize_struct_info(struct_info)

if value is None:
raise ValueError("value of match_cast cannot be None")
if struct_info is None:
raise ValueError("struct_info of match_cast cannot be None")
return MatchCastPair(value, struct_info)


def _normalize_struct_info_proxy(annotation) -> StructInfoProxy:
if annotation is None:
return TupleProxy([])
elif callable(annotation):
return annotation()
elif isinstance(annotation, StructInfoProxy):
return annotation
else:
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")


def _normalize_struct_info(
struct_info, dict_globals: Optional[Dict[str, Any]] = None
) -> StructInfo:
if isinstance(struct_info, StructInfo):
return struct_info
else:
proxy = _normalize_struct_info_proxy(struct_info)
return proxy.as_struct_info(dict_globals)
19 changes: 9 additions & 10 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
from ...ir_builder import relax as R
from ...ir_builder.base import IRBuilder
from .._core import Parser, dispatch, doc
from .entry import MatchCastPair, StructInfoProxy, TupleProxy
from .entry import (
MatchCastPair,
StructInfoProxy,
_normalize_struct_info_proxy,
_normalize_struct_info,
)


def bind_assign_value(
Expand Down Expand Up @@ -91,13 +96,7 @@ def bind_assign_value(
def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
try:
annotation = self.eval_expr(node)
if annotation is None:
return TupleProxy([])
if callable(annotation):
annotation = annotation()
if isinstance(annotation, StructInfoProxy):
return annotation
raise TypeError(f"Expected StructInfoProxy but got {type(annotation)}.")
return _normalize_struct_info_proxy(annotation)
except Exception as err:
self.report_error(node, str(err))
raise err
Expand All @@ -106,7 +105,8 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo:
var_table = self.var_table.get() if eval_str else None
try:
return eval_struct_info_proxy(self, node).as_struct_info(var_table)
struct_info = self.eval_expr(node)
return _normalize_struct_info(struct_info, var_table)
except Exception as err:
self.report_error(node, str(err))
raise err
Expand Down Expand Up @@ -367,7 +367,6 @@ def visit_if(self: Parser, node: doc.If) -> None:
@dispatch.register(token="relax", type_name="enter_token")
def enter_token(self: Parser) -> Dict[str, Any]:
def relax_call(self, *args) -> Expr:

args = [convert_to_expr(arg) if isinstance(arg, tuple) else arg for arg in args]

if all(isinstance(x, Expr) for x in args):
Expand Down
46 changes: 43 additions & 3 deletions tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
T.ramp((x_c * 32), 1, 32)
] + (
T.broadcast(
A_1[
(((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)),
],
A_1[(((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4))],
32,
)
* packedB[T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32)]
Expand Down Expand Up @@ -4023,6 +4021,47 @@ def func(A: R.Tensor([10, 20], "float32")):
return func


def relax_match_cast_struct_info_proxy():
"""StructInfoProxy subclasses may be used as expressions
This is a regression test. The TVMScript parser allows StructInfo
to be specified using a default-constructible class
(e.g. `R.Tensor` or `R.Shape`) rather than an instance of that
class (e.g. `R.Tensor()` or `R.Shape()`). In previous
implementations, this was only handled when the `StructInfo` was
used in an annotation context. However, a `StructInfo` may also
appear as an argument, which is passed to `R.match_cast`. Use of
a default-constructible class must be handled in this context as
well.
"""

def make_ir_generator(proxy_subclass):
def inner():
@R.function
def func(A: R.Object):
B = R.match_cast(A, proxy_subclass)
return B

return func

inner.__name__ = subclass.__name__
return inner

# Not all subclasses of StructInfoProxy are default-constructible.
# This list is a subset of `StructInfoProxy.__subclasses__()`,
# excluding `PrimProxy` and `DTensorProxy`.
subclasses = [
tvm.script.parser.relax.entry.ObjectProxy,
tvm.script.parser.relax.entry.TensorProxy,
tvm.script.parser.relax.entry.CallableProxy,
tvm.script.parser.relax.entry.TupleProxy,
tvm.script.parser.relax.entry.ShapeProxy,
]

for subclass in subclasses:
yield make_ir_generator(subclass)


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
Expand Down Expand Up @@ -4106,6 +4145,7 @@ def func(A: R.Tensor([10, 20], "float32")):
return_zero_private,
return_zero_private_with_attr,
*op_of_literal(),
*relax_match_cast_struct_info_proxy(),
)

relax_ir_generator = tvm.testing.parameter(
Expand Down

0 comments on commit 94ce183

Please sign in to comment.