Skip to content

Commit

Permalink
[Unity][VM] Recursively visit match bindings in VMShapeLowerMutator (#…
Browse files Browse the repository at this point in the history
…16583)

Prior to this commit, the `MatchBinding` visitor in
`VMShapeLowerMutator`.  If the RHS of the `MatchBinding` is a
`ShapeExpr` that uses symbolic variables, that RHS must be visited in
order to have the symbolic variable updated.
  • Loading branch information
Lunderberg authored Feb 17, 2024
1 parent 9ed9f7a commit 94e83f2
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ class VMShapeLowerMutator

// These checks are emitted as extra, in codegen
// match-cast is simply ignored and treated as a normal binding.
builder_->EmitNormalized(GetRef<MatchCast>(binding));
ExprMutator::VisitBinding_(binding);
}

// Do not override shape in struct info fields
Expand Down
78 changes: 78 additions & 0 deletions tests/python/relax/test_backend_transform_shape_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,5 +731,83 @@ def main(
assert_structural_equal(after, expected)


def test_update_symbolic_vars_in_match_cast_rhs():
"""Symbolic variables may be used on the RHS of match_cast"""

@I.ir_module
class Before:
@R.function
def main(
arg_prim_value: R.Prim(value="n"),
):
R.func_attr({"relax.force_pure": 1})
n = T.int64()
shape = R.shape([n])
m = T.int64()
_ = R.match_cast(shape, R.Shape([m]))
return R.prim_value(m)

@I.ir_module
class Expected:
@R.function
def main(arg_prim_value: R.Prim(value="n")) -> R.Prim("int64"):
R.func_attr({"relax.force_pure": 1})
n = T.int64()

shape_heap = R.call_builtin_with_ctx(
"vm.builtin.alloc_shape_heap",
[2],
sinfo_args=(R.Tensor(dtype="int64", ndim=1),),
)
_ = R.call_packed(
"vm.builtin.check_prim_value_info",
arg_prim_value,
R.dtype("int64"),
"",
sinfo_args=[R.Tuple],
)
_ = R.call_packed(
"vm.builtin.match_prim_value",
arg_prim_value,
shape_heap,
MatchShapeCode.STORE_TO_HEAP,
0,
"",
sinfo_args=[R.Tuple],
)
shape = R.call_packed(
"vm.builtin.make_shape",
shape_heap,
1,
MakeShapeCode.LOAD_SHAPE,
0,
sinfo_args=[R.Shape(ndim=1)],
)
_ = R.call_packed(
"vm.builtin.match_shape",
shape,
shape_heap,
1,
MatchShapeCode.STORE_TO_HEAP,
1,
"",
sinfo_args=[R.Tuple],
)

m = T.int64()
_ = R.match_cast(shape, R.Shape([m]))
gv = R.call_packed(
"vm.builtin.make_prim_value",
shape_heap,
MakeShapeCode.LOAD_SHAPE,
1,
sinfo_args=[R.Prim(value=m)],
)
return gv

After = relax.transform.VMShapeLower(emit_err_ctx=False)(Before)
assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 94e83f2

Please sign in to comment.