From 864fd5c706e8a448aa079f1b82c56e12ccc25328 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 23 Feb 2024 08:37:26 -0600 Subject: [PATCH] [Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr (#16599) * [Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr Update the `relax.transform.EliminateCommonSubexpr` pass to handle `R.match_cast` bindings, where the argument of the `R.match_cast` has also been de-duplicated. * Fix unit tests failures * Add unit test for avoiding leak of dataflow var * Track all legal de-duplications, in case the first is a DataflowVar * De-duplicate within an if/else, using bindings before the if/else --- .../transform/eliminate_common_subexpr.cc | 293 +++++++---------- tests/python/relax/test_transform_cse.py | 308 ++++++++++++++++-- 2 files changed, 411 insertions(+), 190 deletions(-) diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 7931d73b7be9..5804b1c5bb67 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -20,223 +20,180 @@ /*! * \file tvm/relax/transform/eliminate_common_subexpr.cc - * \brief Eliminrate common subexpression pass. + * \brief Eliminate common subexpression pass. * * Currently it removes common subexpressions within a Function. */ +#include #include #include #include -#include "utils.h" +#include "../../support/utils.h" namespace tvm { namespace relax { - -// Checks if a given expression contains an impure subexpression -// Caches the results of checks to avoid revisiting subexpressions -class ImpurityDetector : public ExprVisitor { - public: - bool Detect(const Expr& expr) { - impure_found_ = false; - VisitExpr(expr); - return impure_found_; +namespace { +/* \brief Lookup key for subexpression replacements + * + * The lookup key must contain the expression being bound, along with + * the struct info used for a match cast, if applicable. Using + * `MatchCast` with StructuralEqual and StructuralHash would be almost + * correct, but acts as a point of definition for symbolic variables + * within the output struct info. As a result, it would erroneously + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define + * different symbolic variables. + */ +struct ReplacementKey { + tvm::relax::Expr bound_value; + tvm::Optional match_cast = tvm::NullOpt; + + explicit ReplacementKey(const tvm::relax::Binding& binding) + : bound_value(GetBoundValue(binding)) { + if (const auto* ptr = binding.as()) { + match_cast = ptr->struct_info; + } } - void VisitExpr(const Expr& expr) { - // already checked: do not revisit - if (purity_map_.count(expr)) { - impure_found_ = impure_found_ || !purity_map_.at(expr); - return; - } + friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { + tvm::StructuralEqual eq; + return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); + } +}; - // in principle, we could stop checking once we find an impurity, - // but not doing so lets us fully populate the cache +} // namespace +} // namespace relax +} // namespace tvm - // store the previous state so we could assess the purity of this subexpression alone - bool prev_state = impure_found_; - impure_found_ = false; - ExprVisitor::VisitExpr(expr); - // if impure_found_ remains false, then the expression is pure - purity_map_[expr] = !impure_found_; - impure_found_ = prev_state || impure_found_; +/* \brief Definition of std::hash + * + * Specialization of std::hash must occur outside of tvm::relax + * namespace, and before its usage in the constructor of + * `CommonSubexprEliminator`. + */ +template <> +struct std::hash { + std::size_t operator()(const tvm::relax::ReplacementKey& key) const { + tvm::StructuralHash hasher; + return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); } +}; - void VisitExpr_(const CallNode* call) { - // the only possible impurities can come from call nodes - bool is_impure = IsImpureCall(GetRef(call)); - impure_found_ = impure_found_ || is_impure; - ExprVisitor::VisitExpr_(call); - } +namespace tvm { +namespace relax { - private: - bool impure_found_ = false; - std::unordered_map purity_map_; -}; +namespace { -class SubexprCounter : public ExprVisitor { +class CommonSubexprEliminator : public ExprMutator { public: - static std::unordered_map Count(const Expr& expr) { - SubexprCounter visitor; - visitor(expr); - return visitor.count_map_; - } + explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {} - // overriding VisitExpr ensures we do this for every subexpression - void VisitExpr(const Expr& e) override { - // Cases we ignore because we will not substitute them: - // 1. Vars of all kinds - // 2. Op nodes (nothing we can do) - // 3. PrimValue nodes (not much benefit from binding to a var) - // 4. StringImm nodes (not much benefit from binding to a var) - // 5. Scalar constants (not much benefit from binding to a var) - // 6. Shape expressions (exist to hold several PrimValue objects) - // 7. DataType nodes (no need to modify dtype nodes) - if (!(e->IsInstance() || e->IsInstance() || - e->IsInstance() || e->IsInstance() || - e->IsInstance() || e->IsInstance() || - e->IsInstance() || e->IsInstance() || - e->IsInstance() || e->IsInstance())) { - // also if e has an impure subexpression, we will not deduplicate it - if (!impurity_detector_.Detect(e)) { - int count = 0; - if (count_map_.count(e)) { - count = count_map_.at(e); - } - count_map_[e] = count + 1; - } - } + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + auto cache_vars = var_remap_; + auto output = ExprMutator::VisitBindingBlock_(block); - // Only visit the interior of objects that we might still keep - // around. Otherwise, double-counting these would lead to extra - // variable bindings. - // - // Before: - // y = f(a+b) - // z = f(a+b) - // - // Expected: - // y = f(a+b) // De-duped from (y==z) - // z = y - // - // Erroneous output: - // c = a+b // Incorrect, a+b only has a single usage. - // y = f(c) // De-duped from - // z = y - // - if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) { - ExprVisitor::VisitExpr(e); + for (auto& [key, replacements] : expr_replacements_) { + replacements.erase( + std::remove_if(replacements.begin(), replacements.end(), + [](const Var& var) -> bool { return var->IsInstance(); }), + replacements.end()); } + + var_remap_ = cache_vars; + return output; } - // do not visit inner functions: we will do CSE within those - void VisitExpr_(const FunctionNode* func) override {} + void VisitBinding(const Binding& binding) override { + Expr bound_value = VisitExpr(GetBoundValue(binding)); + + Binding output_binding = [&]() -> Binding { + if (binding.as()) { + return VarBinding(binding->var, bound_value); + } else if (auto match_cast = binding.as()) { + return MatchCast(binding->var, bound_value, match_cast->struct_info); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " + << "but was " << binding->GetTypeKey(); + } + }(); - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + ReplacementKey lookup_key(output_binding); - private: - std::unordered_map count_map_; - ImpurityDetector impurity_detector_; -}; + if (call_only_ && !bound_value->IsInstance()) { + VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; -class CommonSubexprEliminator : public ExprMutator { - public: - explicit CommonSubexprEliminator( - std::unordered_map count_map, - bool call_only = false) - : count_map_(std::move(count_map)), call_only_(call_only) {} - - // overriding here ensures we visit every subexpression - Expr VisitExpr(const Expr& e) override { - if (call_only_ && !e->IsInstance()) { - return ExprMutator::VisitExpr(e); - } - if (count_map_.count(e) && count_map_.at(e) > 1) { - // if we already have a mapping for it, get it - if (replacements_.count(e)) { - return replacements_.at(e); - } - // Otherwise, insert a new binding for the current expression. - // Visit before emitting to do inner replacements - Expr new_e = ExprMutator::VisitExpr(e); - Var v = builder_->Emit(new_e); - replacements_[e] = v; - return v; - } - return ExprMutator::VisitExpr(e); - } + } else if (ContainsImpureCall(bound_value)) { + VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value; - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { - return struct_info; - } + } else if (auto it = expr_replacements_.find(lookup_key); + it != expr_replacements_.end() && it->second.size()) { + VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0] + << ". The duplicate binding of this value to " << binding->var + << " will be replaced with a trivial binding, " + << "and occurrences of " << binding->var << " will be replaced with " + << it->second[0]; + output_binding = VarBinding(binding->var, it->second[0]); + var_remap_.insert({binding->var->vid, it->second[0]}); + it->second.push_back(binding->var); - Expr VisitExpr_(const FunctionNode* op) override { - Function func = GetRef(op); + } else { + VLOG(1) << "Value " << bound_value << " is bound to " << binding->var + << " and may be de-duplicated if it occurs again."; - auto cache = SubexprCounter::Count(op->body); - std::swap(cache, count_map_); - Expr output = ExprMutator::VisitExpr_(op); - std::swap(cache, count_map_); + expr_replacements_[lookup_key].push_back(binding->var); + } - return output; + builder_->EmitNormalized(output_binding); } - void VisitBinding_(const VarBindingNode* binding) override { - // no need to visit var def because the struct info isn't going to change - Expr new_value = RegisterBoundValue(binding->var, binding->value); - - if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + Expr VisitExpr_(const FunctionNode* op) override { + // If we have accumulated any state, visit the function in a fresh + // copy of the mutator, to avoid replacing a child-scope + // expression with a parent-scope binding, or vice versa. + if (expr_replacements_.size() || var_remap_.size()) { + return VisitWithCleanScope(GetRef(op)); } else { - // no need to renormalize new_value because all replacements are with vars - builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); + return ExprMutator::VisitExpr_(op); } } - void VisitBinding_(const MatchCastNode* binding) override { - // no need to visit var def because the struct info isn't going to change - Expr new_value = RegisterBoundValue(binding->var, binding->value); - - // re-emit old binding if nothing changes - if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef(binding)); + Expr VisitExpr_(const IfNode* op) override { + Expr cond = VisitExpr(op->cond); + Expr true_branch = VisitWithInnerScope(op->true_branch); + Expr false_branch = VisitWithInnerScope(op->false_branch); + if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) && + op->false_branch.same_as(false_branch) && + VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { + return GetRef(op); } else { - // no need to renormalize new_value because all replacements are with vars - builder_->EmitNormalized( - MatchCast(binding->var, new_value, binding->struct_info, binding->span)); + return If(cond, true_branch, false_branch, op->span); } } private: - Expr RegisterBoundValue(Var var, Expr bound_value) { - // special case: if we are processing a binding - // and this is the first time we've encountered it, - // we will use the binding's var for the mapping - bool newly_replaced = false; - if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 && - !replacements_.count(bound_value)) { - replacements_[bound_value] = var; - newly_replaced = true; - } + Expr VisitWithInnerScope(Expr expr) { + auto cached_vars = var_remap_; + auto cached_exprs = expr_replacements_; + auto output = VisitExpr(expr); + var_remap_ = cached_vars; + expr_replacements_ = cached_exprs; + return output; + } - if (newly_replaced) { - // If we've just added the mapping, using the overridden visitor will - // just return the var, which we don't want, so we will use - // the superclass VisitExpr to do inner substitutions - return ExprMutator::VisitExpr(bound_value); - } - return VisitExpr(bound_value); + Expr VisitWithCleanScope(Expr expr) { + CommonSubexprEliminator clean_mutator(call_only_); + return clean_mutator.VisitExpr(expr); } - std::unordered_map count_map_; - std::unordered_map replacements_; bool call_only_{false}; + std::unordered_map> expr_replacements_; }; +} // namespace + Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) { - CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only); + CommonSubexprEliminator mutator(call_only); return mutator(expr); } diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index 2a247c342cdf..b491577314ec 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -45,10 +45,8 @@ class Expected: def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): with R.dataflow(): lv0 = R.add(x, y) - # can combine with canonicalizing bindings - # and getting rid of unused bindings to eliminate this line too lv1 = lv0 - gv = R.multiply(lv0, lv1) + gv = R.multiply(lv0, lv0) R.output(gv) return gv @@ -90,6 +88,12 @@ def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32" def test_repeated_inner_tuples(): + """CSE is only applied at variable bindings + + To remain consistent with the behavior of the normalizer, tuples + are kept as-is, even if they contain repeated sub-tuples. + """ + @I.ir_module class Before: @R.function @@ -101,18 +105,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): R.output(gv) return gv - @I.ir_module - class Expected: - @R.function - def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - with R.dataflow(): - t1 = (x, x) - t2 = (x, t1) - t3 = (t1, t2) - t4 = (t3, t3, t2) - gv = t4[0][0][1] - R.output(gv) - return gv + Expected = Before verify(Before, Expected) @@ -160,7 +153,7 @@ def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): with R.dataflow(): lv0 = R.add(y, y) lv1 = lv0 - lv2 = R.add(lv0, lv1) + lv2 = R.add(lv0, lv0) gv = lv2 R.output(gv) return R.add(gv, gv) @@ -169,11 +162,11 @@ def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): # using canonicalize bindings, eliminate unused bindings, and CSE again lv0 = bar(x) lv1 = lv0 - lv2 = R.add(lv0, lv1) + lv2 = R.add(lv0, lv0) lv3 = lv0 lv4 = lv0 - lv5 = R.add(lv3, lv4) - lv6 = R.add(lv2, lv5) + lv5 = lv2 + lv6 = R.add(lv2, lv2) gv = lv6 R.output(gv) return gv @@ -202,7 +195,7 @@ def foo(x: R.Tensor((160,), dtype="float32")) -> R.Tensor((160,), dtype="float32 lv1 = R.arange(R.prim_value(0), R.prim_value(160), R.prim_value(1), dtype="float32") lv2 = lv1 lv3 = R.add(x, lv1) - out = R.add(lv3, lv2) + out = R.add(lv3, lv1) R.output(out) return out @@ -226,12 +219,112 @@ class Expected: def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): lv0 = R.add(x, y) lv1 = lv0 - gv = R.multiply(lv0, lv1) + gv = R.multiply(lv0, lv0) return gv verify(Before, Expected) +def test_no_cse_across_dataflow(): + # same example as previously but it will work without a dataflow wrapper + @I.ir_module + class Before: + @R.function(pure=False) + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv0 = R.add(x, y) + lv1 = R.add(x, y) + gv1 = R.multiply(lv0, lv1) + R.output(gv1) + + _ = R.print(format="Prevent dataflow block merging") + + with R.dataflow(): + lv2 = R.add(x, y) + lv3 = R.add(x, y) + gv2 = R.multiply(lv2, lv3) + R.output(gv2) + + gv3 = R.add(x, y) + gv4 = R.add(x, y) + gv5 = R.multiply(gv3, gv4) + + output = R.add(R.add(gv1, gv2), gv5) + return output + + @I.ir_module + class Expected: + @R.function(pure=False) + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + # The R.add(x,y) may be de-duplicated within a dataflow block + lv0 = R.add(x, y) + lv1 = lv0 + gv1 = R.multiply(lv0, lv0) + R.output(gv1) + + _ = R.print(format="Prevent dataflow block merging") + + with R.dataflow(): + # However, the later dataflow block may not be + # de-duplicated using variables in the earlier block. + lv2 = R.add(x, y) + lv3 = lv2 + gv2 = R.multiply(lv2, lv2) + R.output(gv2) + + # And while non-dataflow bindings can be de-duplicated, + # they cannot be de-duplicated using bindings that were + # valid in either of the earlier dataflow blocks. + gv3 = R.add(x, y) + gv4 = gv3 + gv5 = R.multiply(gv3, gv3) + + output = R.add(R.add(gv1, gv2), gv5) + return output + + verify(Before, Expected) + + +def test_no_replacement_across_dataflow_boundary(): + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + A = R.add(x, y) + # B has the same value as A, and so instances of B can be replaced with A. + B = R.add(x, y) + C = R.multiply(A, B) + + # However, B is exposed for use outside of the + # DataflowBlock, while A is not. Therefore, any + # additional uses of `B` must NOT be replaced with + # A. + R.output(B, C) + + # In addition, because `A` is only valid within the + # dataflow block, the `R.add(x,y)` cannot be de-duplicated + # as another usage of `A`. + D = R.add(x, y) + return (B, C, D) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + A = R.add(x, y) + B = A + C = R.multiply(A, A) + R.output(B, C) + + D = B + return (B, C, B) + + verify(Before, Expected) + + def test_do_not_eliminate_impure(): @I.ir_module class Before: @@ -256,7 +349,7 @@ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32 a1 = R.assert_op(R.const(False), format="Always fails") lv0 = R.add(x, y) lv1 = lv0 - gv = R.multiply(lv0, lv1) + gv = R.multiply(lv0, lv0) a2 = R.assert_op(R.const(False), format="Always fails") return gv @@ -363,5 +456,176 @@ def foo() -> R.Tensor((32, 64), "int32"): verify(Before, Expected) +def test_match_cast(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + A1 = R.add(x, y) + B1 = R.match_cast(A1, R.Tensor([2, 3], "float32")) + + A2 = R.add(x, y) + B2 = R.match_cast(A2, R.Tensor([2, 3], "float32")) + + gv = R.multiply(B1, B2) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + A1 = R.add(x, y) + B1 = R.match_cast(A1, R.Tensor([2, 3], "float32")) + + A2 = A1 + B2 = B1 + gv = R.multiply(B1, B1) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_match_cast_with_symbolic_vars(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): + with R.dataflow(): + A1 = R.add(x, y) + + n = T.int64() + m = T.int64() + B1 = R.match_cast(A1, R.Tensor([n, m], "float32")) + + A2 = R.add(x, y) + p = T.int64() + q = T.int64() + B2 = R.match_cast(A2, R.Tensor([p, q], "float32")) + + gv = R.multiply(B1, B2) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")): + with R.dataflow(): + A1 = R.add(x, y) + n = T.int64() + m = T.int64() + B1 = R.match_cast(A1, R.Tensor([n, m], "float32")) + + A2 = A1 + p = T.int64() + q = T.int64() + B2 = R.match_cast(A1, R.Tensor([p, q], "float32")) + + gv = R.multiply(B1, B2) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_replace_binding_within_branch_with_duplicate_before_branch(): + """Bindings before a branch may be used within the branch""" + + @I.ir_module + class Before: + @R.function + def foo( + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), + condition: R.Prim("bool"), + ): + A = R.add(x, y) + if condition: + B = R.add(x, y) + C = R.multiply(x, B) + D = R.multiply(A, C) + else: + B = R.add(x, y) + C = R.multiply(y, B) + D = R.multiply(A, C) + return D + + @I.ir_module + class Expected: + @R.function + def foo( + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), + condition: R.Prim("bool"), + ): + A = R.add(x, y) + if condition: + B = A + C = R.multiply(x, A) + D = R.multiply(A, C) + else: + B = A + C = R.multiply(y, A) + D = R.multiply(A, C) + return D + + verify(Before, Expected) + + +def test_keep_duplicate_across_if_and_then(): + """Bindings in `if` are not valid within `else`""" + + @I.ir_module + class Before: + @R.function + def foo( + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), + condition: R.Prim("bool"), + ): + if condition: + A = R.add(x, y) + B = R.multiply(x, A) + else: + A = R.add(x, y) + B = R.multiply(y, A) + return B + + Expected = Before + + verify(Before, Expected) + + +def test_keep_duplicate_after_branch(): + """Only the final binding is valid after a if/else branch""" + + @I.ir_module + class Before: + @R.function + def foo( + x: R.Tensor((2, 3), dtype="float32"), + y: R.Tensor((2, 3), dtype="float32"), + condition: R.Prim("bool"), + ): + if condition: + A = R.add(x, y) + B = R.multiply(x, A) + else: + A = R.add(x, y) + B = R.multiply(y, A) + + C = R.add(x, y) + D = R.multiply(B, C) + return D + + Expected = Before + + verify(Before, Expected) + + if __name__ == "__main__": tvm.testing.main()