diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index cb4bb23f88981..b9894e427af81 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -197,14 +197,14 @@ struct CollectInfo { } } - struct SuppressCompileTime : ExprMutator { - std::unordered_set to_suppress; + class SuppressCompileTime : public ExprMutator { + public: explicit SuppressCompileTime( - std::unordered_set to_suppress) - : to_suppress(to_suppress) {} + const std::unordered_set& to_suppress) + : to_suppress_(to_suppress) {} void VisitBinding(const Binding& binding) override { - if (!to_suppress.count(binding->var)) { + if (!to_suppress_.count(binding->var)) { ExprMutator::VisitBinding(binding); } } @@ -218,8 +218,11 @@ struct CollectInfo { return ExprMutator::VisitExpr_(call); } } + + private: + const std::unordered_set& to_suppress_; }; - Expr body = SuppressCompileTime(std::move(to_suppress))(orig_func->body); + Expr body = SuppressCompileTime{to_suppress}(orig_func->body); body = SeqExpr({DataflowBlock(bindings)}, body); Function func(params, body, orig_func->ret_struct_info, orig_func->is_pure, orig_func->attrs); @@ -335,6 +338,32 @@ class LiftableBindingCollector : ExprVisitor { info_.computable_at_compile_time.push_back(binding); liftable_vars_.insert(binding->var); + // There are three types of variables we want to distinguish. + // + // 1. Depend on runtime parameters + // + // Must remain within the original function, cannot be + // lifted out into the `transform_params` function. + // + // 2. Depend on model weights, but not runtime parameters. + // + // Legal to lift out into the `transform_params` function. + // Doing so is beneficial, as it reduces the work performed + // in the inference function. + // + // 3. Depend on neither model weights nor runtime parameters + // (e.g. `R.zeros(shape,dtype)`) + // + // Legal to lift out into the `transform_params` function. 
+ // However, doing so would increase the memory footprint of + // the pre-computed parameters, for little to no benefit. + // These may be duplicated between the `transform_params` + // function and the original function, as they typically + // initialize a tensor to an easy-to-compute state. + // + // Tracking whether a variable depends on the model weights, + // either directly or indirectly, allows us to distinguish + // between categories (2) and (3). auto upstream_vars = FreeVars(bound_value); bool depends_on_compile_time_param = std::any_of( upstream_vars.begin(), upstream_vars.end(), diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index f0c4ae0bd2a3a..0cae5101a755a 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -122,5 +122,28 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): assert_structural_equal(Actual, Expected) +def test_structural_equal_of_call_nodes(): + """relax.Call must be compared by structural equality, not reference""" + + # Three identical calls to relax.op.zeros + calls_to_op_zero = [relax.op.zeros([16], "int32") for _ in range(3)] + + @R.function(private=True) + def uses_same_object_twice(): + A = calls_to_op_zero[0] + B = calls_to_op_zero[0] + C = R.add(A, B) + return C + + @R.function(private=True) + def uses_two_different_objects(): + A = calls_to_op_zero[1] + B = calls_to_op_zero[2] + C = R.add(A, B) + return C + + tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects) + + if __name__ == "__main__": pytest.main([__file__])