From f21c5a424518ab516209931e0208e8dc1835aa41 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 1 Feb 2024 10:10:48 -0800 Subject: [PATCH] [DPL] Support tir_vars field in is_call_tir pattern (#16494) * [DPL] Support tir_vars field in is_call_tir pattern * add doc * lint * lint --- python/tvm/relax/dpl/pattern.py | 11 ++++++++--- src/relax/ir/dataflow_pattern.cc | 6 +++++- tests/python/relax/test_dataflow_pattern.py | 21 ++++++++++++++++++--- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index e5670dee4b7e..5594dea3ad74 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -871,19 +871,23 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern": def _is_call_tir( func_pattern: DFPattern, args: Union[List, Tuple, TuplePattern] = None, + tir_vars: Optional[DFPattern] = None, ) -> CallPattern: if args is None: args = wildcard() elif isinstance(args, (list, tuple)): args = TuplePattern(args) - return is_op("relax.call_tir")(func_pattern, args, add_constraint=False) + if tir_vars is None: + return is_op("relax.call_tir")(func_pattern, args, add_constraint=False) + return is_op("relax.call_tir")(func_pattern, args, tir_vars, add_constraint=False) # Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo def is_call_tir( func_name: str, args: Union[List, Tuple, TuplePattern] = None, + tir_vars: Optional[DFPattern] = None, ) -> CallPattern: """ Syntax sugar for creating a CallPattern for call_tir that calls an function through global var. @@ -894,14 +898,15 @@ def is_call_tir( Name of the CPS function to call. args : Union[List[DFPattern], Tuple[DFPattern]], optional Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments - + tir_vars : Optional[DFPattern] + Pattern to match the tuple of integers that are unpacked when calling the tir func. Returns ------- CallPattern The resulting CallPattern """ func_pattern = GlobalVarPattern(func_name) - return _is_call_tir(func_pattern, args) + return _is_call_tir(func_pattern, args, tir_vars) def _is_call_dps_packed( diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 1286a32e4cb8..ca81b910126a 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -574,7 +574,8 @@ ConstantPattern IsConst() { return ConstantPattern(make_object()); } ExprPattern IsExpr(const Expr& expr) { return ExprPattern(expr); } ExprPattern IsOp(const String& op_name) { return IsExpr(Op::Get(op_name)); } -CallPattern IsCallTIR(const String& name, Optional var_args) { +CallPattern IsCallTIR(const String& name, Optional var_args, + Optional tir_vars) { DFPattern arg_pattern; if (!var_args.defined()) { arg_pattern = Wildcard(); @@ -582,6 +583,9 @@ CallPattern IsCallTIR(const String& name, Optional var_args) { arg_pattern = var_args.value(); } + if (tir_vars.defined()) { + return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern, tir_vars.value()); + } return IsOp("relax.call_tir")(GlobalVarPattern(name), arg_pattern); } diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 7f2cb241bb1d..39b3c40c2676 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -56,14 +56,27 @@ def tir_relu(x: T.handle, y: T.handle): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = T.max(A[vi, vj], 0.0) + @T.prim_func + def tir_zeros(x: T.handle, n: T.int64): + T.func_attr({"global_symbol": "tir_zeros"}) + A = T.match_buffer(x, [n]) + for i in range(n): + with T.block(): + vi = T.axis.remap("S", [i]) + A[vi] = 1.0 + @R.function - def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple: cls = Module with R.dataflow(): lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) - R.output(lv1) - return lv1 + lv2 = R.call_tir( + cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) + ) + gv = (lv1, lv2) + R.output(gv) + return gv main_fn = Module["main"] @@ -293,10 +306,12 @@ def test_match_call_attr(): def test_is_call_tir(): lv1_val = bindings[1].value + lv2_val = bindings[2].value var2val = get_var2val(Module["main"]) assert is_call_tir("tir_relu").match(lv1_val) assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(lv1_val, var2val=var2val) assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(lv1_val, var2val=var2val) + assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(lv2_val, var2val=var2val) @R.function