From 308599ae9c656051837dfe96b322321ad781bf05 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 29 Feb 2024 08:07:10 -0600 Subject: [PATCH] [Transform] Check for zero-param operators in LiftTransformParams (#16595) Prior to this commit, `LiftTransformParams` would extract out all variable bindings that have no runtime dependencies. As a result, expressions such as `R.zeros([16], "int32")` would be extracted out into the parameter transformation, even though they do not depend on any parameters. This commit updates `LiftTransformParams` to only output variables that depend on at least one compile-time parameter. The unit test for this functionality also found that `relax::Call` was erroneously calling `MarkGraphNode` in `SEqualReduce` and `SHashReduce`. This should only be called for nodes that have reference equality, such as `relax::Var`, and not for composite objects. This caused erroneous failures in the unit test when two instances of `R.zeros([16], "int32")` were being compared by reference equality in `StructuralEqual`. These extra calls to `MarkGraphNode` have been removed. --- include/tvm/relax/expr.h | 2 - src/relax/transform/lift_transform_params.cc | 75 ++++++++++++++++--- .../test_transform_lift_transform_params.py | 54 +++++++++++++ tests/python/relax/test_utils.py | 23 ++++++ 4 files changed, 142 insertions(+), 12 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 23262ea817946..fdbd7bd8eb2c6 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -169,13 +169,11 @@ class CallNode : public ExprNode { bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { // skip sinfo_args check for primitive ops. 
- equal->MarkGraphNode(); return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && equal(sinfo_args, other->sinfo_args) && equal(struct_info_, other->struct_info_); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce->MarkGraphNode(); hash_reduce(op); hash_reduce(args); hash_reduce(attrs); diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 15b60f5492c22..724ec2f7abc80 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -58,6 +58,15 @@ struct CollectInfo { */ std::vector computable_at_compile_time; + /*! \brief Variables that require a compile-time parameter + * + * Used to distinguish between computed tensors that depend on the + * model weights, and computed tensors that require neither model + * weights nor runtime arguments (e.g. `R.zeros([16], "float16")`). + */ + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + requires_compile_time_param; + /*! \brief Variables that are required at runtime */ std::unordered_set, ObjectPtrHash, ObjectPtrEqual> required_at_runtime; @@ -114,7 +123,8 @@ struct CollectInfo { // Any variable that is computed at compile-time, but is required // at runtime, must be provided as a parameter. for (const auto& binding : computable_at_compile_time) { - if (required_at_runtime.count(binding->var)) { + if (requires_compile_time_param.count(binding->var) && + required_at_runtime.count(binding->var)) { params.push_back(binding->var); } } @@ -182,16 +192,21 @@ struct CollectInfo { // Any binding that is computable at compile-time should be // suppressed at run-time. 
- struct SuppressCompileTime : ExprMutator { - std::unordered_set to_suppress; - explicit SuppressCompileTime(const std::vector& bindings) { - for (const auto& binding : bindings) { - to_suppress.insert(binding->var); - } + std::unordered_set to_suppress; + for (const auto& binding : computable_at_compile_time) { + if (requires_compile_time_param.count(binding->var)) { + to_suppress.insert(binding->var); } + } + + class SuppressCompileTime : public ExprMutator { + public: + explicit SuppressCompileTime( + const std::unordered_set& to_suppress) + : to_suppress_(to_suppress) {} void VisitBinding(const Binding& binding) override { - if (!to_suppress.count(binding->var)) { + if (!to_suppress_.count(binding->var)) { ExprMutator::VisitBinding(binding); } } @@ -205,8 +220,11 @@ struct CollectInfo { return ExprMutator::VisitExpr_(call); } } + + private: + const std::unordered_set& to_suppress_; }; - Expr body = SuppressCompileTime(computable_at_compile_time)(orig_func->body); + Expr body = SuppressCompileTime(to_suppress)(orig_func->body); body = SeqExpr({DataflowBlock(bindings)}, body); Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs); @@ -300,6 +318,7 @@ class LiftableBindingCollector : ExprVisitor { for (size_t i = num_runtime_params; i < func->params.size(); i++) { liftable_vars_.insert(func->params[i]); + info_.requires_compile_time_param.insert(func->params[i]); for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func->params[i]))) { liftable_vars_.insert(tir_var); } @@ -315,12 +334,48 @@ class LiftableBindingCollector : ExprVisitor { } void VisitBinding(const Binding& binding) override { + auto bound_value = GetBoundValue(binding); + if (CanLiftBinding(binding)) { info_.computable_at_compile_time.push_back(binding); liftable_vars_.insert(binding->var); + + // There are three type of variables we want to distinguish. + // + // 1. 
Depend on runtime parameters + // + // Must remain within the original function, cannot be + // lifted out into the `transform_params` function. + // + // 2. Depend on model weights, but not runtime parameters. + // + // Legal to lift out into the `transform_params` function. + // Doing so is beneficial, as it reduces the work performed + // in the inference function. + // + // 3. Depend on neither model weights nor runtime parameters + // (e.g. `R.zeros(shape,dtype)`) + // + // Legal to lift out into the `transform_params` function. + // However, doing so would increase the memory footprint of + // the pre-computed parameters, for little to no benefit. + // These may be duplicated between the `transform_params` + // function and the original function, as they typically + // initialize a tensor to an easy-to-compute state. + // + // Tracking whether a variable depends on the model weights, + // either directly or indirectly, allows us to distinguish + // between categories (2) and (3). + auto upstream_vars = FreeVars(bound_value); + bool depends_on_compile_time_param = std::any_of( + upstream_vars.begin(), upstream_vars.end(), + [&](const Var& var) -> bool { return info_.requires_compile_time_param.count(var); }); + if (depends_on_compile_time_param) { + info_.requires_compile_time_param.insert(binding->var); + } + } else { info_.required_at_runtime.insert(binding->var); - auto bound_value = GetBoundValue(binding); for (const auto& upstream_var : FreeVars(bound_value)) { info_.required_at_runtime.insert(upstream_var); } diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 8042765d40517..ce2dffcb51785 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -795,5 +795,59 @@ def main( tvm.ir.assert_structural_equal(Expected, After) +def test_only_lift_when_variable_uses_constants(): + """A variable 
that has no inputs should not be lifted + + For example, `R.zeros`, or the result of allocation function + calls. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main( + A: R.Tensor([16], "int32"), + B: R.Tensor([16], "int32"), + ): + R.func_attr({"num_input": 1}) + with R.dataflow(): + offset = R.ones([16], "int32") + A_offset = R.add(A, offset) + B_offset = R.add(B, offset) + output = R.multiply(A_offset, B_offset) + R.output(output) + return output + + @tvm.script.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([16], "int32"), + B_offset: R.Tensor([16], "int32"), + ): + R.func_attr({"num_input": 1}) + with R.dataflow(): + offset = R.ones([16], "int32") + A_offset = R.add(A, offset) + output = R.multiply(A_offset, B_offset) + R.output(output) + return output + + @R.function + def main_transform_params(params: R.Tuple([R.Tensor([16], "int32")])): + R.func_attr({"num_input": 0}) + with R.dataflow(): + offset = R.ones([16], "int32") + B = params[0] + B_offset = R.add(B, offset) + output = (B_offset,) + R.output(output) + return output + + mod = Before + after = relax.transform.LiftTransformParams()(mod) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index f0c4ae0bd2a3a..0cae5101a755a 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -122,5 +122,28 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): assert_structural_equal(Actual, Expected) +def test_structural_equal_of_call_nodes(): + """relax.Call must be compared by structural equality, not reference""" + + # Three identical calls to relax.op.zeros + calls_to_op_zero = [relax.op.zeros([16], "int32") for _ in range(3)] + + @R.function(private=True) + def uses_same_object_twice(): + A = calls_to_op_zero[0] + B = calls_to_op_zero[0] + C = R.add(A, B) + return C + + 
@R.function(private=True) + def uses_two_different_objects(): + A = calls_to_op_zero[1] + B = calls_to_op_zero[2] + C = R.add(A, B) + return C + + tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects) + + if __name__ == "__main__": pytest.main([__file__])