Skip to content

Commit

Permalink
[DPL] Support tir_vars field in is_call_tir pattern (#16494)
Browse files Browse the repository at this point in the history
* [DPL] Support tir_vars field in is_call_tir pattern

* add doc

* lint

* lint
  • Loading branch information
vinx13 authored Feb 1, 2024
1 parent 00a2c7d commit f21c5a4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
11 changes: 8 additions & 3 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,19 +871,23 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
def _is_call_tir(
func_pattern: DFPattern,
args: Union[List, Tuple, TuplePattern] = None,
tir_vars: Optional[DFPattern] = None,
) -> CallPattern:
if args is None:
args = wildcard()
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
if tir_vars is None:
return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False)


# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
def is_call_tir(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
tir_vars: Optional[DFPattern] = None,
) -> CallPattern:
"""
Syntax sugar for creating a CallPattern for call_tir that calls an function through global var.
Expand All @@ -894,14 +898,15 @@ def is_call_tir(
Name of the CPS function to call.
args : Union[List[DFPattern], Tuple[DFPattern]], optional
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments
tir_vars : Optional[DFPattern]
Pattern to match the tuple of integers that are unpacked when calling the tir func.
Returns
-------
CallPattern
The resulting CallPattern
"""
func_pattern = GlobalVarPattern(func_name)
return _is_call_tir(func_pattern, args)
return _is_call_tir(func_pattern, args, tir_vars)


def _is_call_dps_packed(
Expand Down
6 changes: 5 additions & 1 deletion src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,18 @@ ConstantPattern IsConst() { return ConstantPattern(make_object<ConstantPatternNo
WildcardPattern Wildcard() { return WildcardPattern(make_object<WildcardPatternNode>()); }
ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); }
ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); }
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args,
Optional<DFPattern> tir_vars) {
DFPattern arg_pattern;
if (!var_args.defined()) {
arg_pattern = Wildcard();
} else {
arg_pattern = var_args.value();
}

if (tir_vars.defined()) {
return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, tir_vars.value());
}
return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern);
}

Expand Down
21 changes: 18 additions & 3 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,27 @@ def tir_relu(x: T.handle, y: T.handle):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.max(A[vi, vj], 0.0)

@T.prim_func
def tir_zeros(x: T.handle, n: T.int64):
T.func_attr({"global_symbol": "tir_zeros"})
A = T.match_buffer(x, [n])
for i in range(n):
with T.block():
vi = T.axis.remap("S", [i])
A[vi] = 1.0

@R.function
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
cls = Module
with R.dataflow():
lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
R.output(lv1)
return lv1
lv2 = R.call_tir(
cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
)
gv = (lv1, lv2)
R.output(gv)
return gv


main_fn = Module["main"]
Expand Down Expand Up @@ -293,10 +306,12 @@ def test_match_call_attr():

def test_is_call_tir():
lv1_val = bindings[1].value
lv2_val = bindings[2].value
var2val = get_var2val(Module["main"])
assert is_call_tir("tir_relu").match(lv1_val)
assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val)
assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val)
assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val)


@R.function
Expand Down

0 comments on commit f21c5a4

Please sign in to comment.