Skip to content

Commit

Permalink
[Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr (#…
Browse files Browse the repository at this point in the history
…16599)

* [Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr

Update the `relax.transform.EliminateCommonSubexpr` pass to handle
`R.match_cast` bindings, where the argument of the `R.match_cast` has
also been de-duplicated.

* Fix unit tests failures

* Add unit test for avoiding leak of dataflow var

* Track all legal de-duplications, in case the first is a DataflowVar

* De-duplicate within an if/else, using bindings before the if/else
  • Loading branch information
Lunderberg authored Feb 23, 2024
1 parent e715814 commit 864fd5c
Show file tree
Hide file tree
Showing 2 changed files with 411 additions and 190 deletions.
293 changes: 125 additions & 168 deletions src/relax/transform/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,223 +20,180 @@

/*!
* \file tvm/relax/transform/eliminate_common_subexpr.cc
* \brief Eliminrate common subexpression pass.
* \brief Eliminate common subexpression pass.
*
* Currently it removes common subexpressions within a Function.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>

#include "utils.h"
#include "../../support/utils.h"

namespace tvm {
namespace relax {

// Checks if a given expression contains an impure subexpression
// Caches the results of checks to avoid revisiting subexpressions
class ImpurityDetector : public ExprVisitor {
public:
bool Detect(const Expr& expr) {
impure_found_ = false;
VisitExpr(expr);
return impure_found_;
namespace {
/* \brief Lookup key for subexpression replacements
*
* The lookup key must contain the expression being bound, along with
* the struct info used for a match cast, if applicable. Using
* `MatchCast` with StructuralEqual and StructuralHash would be almost
* correct, but acts as a point of definition for symbolic variables
* within the output struct info. As a result, it would erroneously
* de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and
* `R.match_cast(A, R.Tensor([p,q]))`, even though they define
* different symbolic variables.
*/
struct ReplacementKey {
tvm::relax::Expr bound_value;
tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt;

explicit ReplacementKey(const tvm::relax::Binding& binding)
: bound_value(GetBoundValue(binding)) {
if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) {
match_cast = ptr->struct_info;
}
}

void VisitExpr(const Expr& expr) {
// already checked: do not revisit
if (purity_map_.count(expr)) {
impure_found_ = impure_found_ || !purity_map_.at(expr);
return;
}
friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) {
tvm::StructuralEqual eq;
return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast);
}
};

// in principle, we could stop checking once we find an impurity,
// but not doing so lets us fully populate the cache
} // namespace
} // namespace relax
} // namespace tvm

// store the previous state so we could assess the purity of this subexpression alone
bool prev_state = impure_found_;
impure_found_ = false;
ExprVisitor::VisitExpr(expr);
// if impure_found_ remains false, then the expression is pure
purity_map_[expr] = !impure_found_;
impure_found_ = prev_state || impure_found_;
/* \brief Definition of std::hash<ReplacementKey>
*
* Specialization of std::hash must occur outside of tvm::relax
* namespace, and before its usage in the constructor of
* `CommonSubexprEliminator`.
*/
template <>
struct std::hash<tvm::relax::ReplacementKey> {
std::size_t operator()(const tvm::relax::ReplacementKey& key) const {
tvm::StructuralHash hasher;
return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast));
}
};

void VisitExpr_(const CallNode* call) {
// the only possible impurities can come from call nodes
bool is_impure = IsImpureCall(GetRef<Call>(call));
impure_found_ = impure_found_ || is_impure;
ExprVisitor::VisitExpr_(call);
}
namespace tvm {
namespace relax {

private:
bool impure_found_ = false;
std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
};
namespace {

class SubexprCounter : public ExprVisitor {
class CommonSubexprEliminator : public ExprMutator {
public:
static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) {
SubexprCounter visitor;
visitor(expr);
return visitor.count_map_;
}
explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {}

// overriding VisitExpr ensures we do this for every subexpression
void VisitExpr(const Expr& e) override {
// Cases we ignore because we will not substitute them:
// 1. Vars of all kinds
// 2. Op nodes (nothing we can do)
// 3. PrimValue nodes (not much benefit from binding to a var)
// 4. StringImm nodes (not much benefit from binding to a var)
// 5. Scalar constants (not much benefit from binding to a var)
// 6. Shape expressions (exist to hold several PrimValue objects)
// 7. DataType nodes (no need to modify dtype nodes)
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
int count = 0;
if (count_map_.count(e)) {
count = count_map_.at(e);
}
count_map_[e] = count + 1;
}
}
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
auto cache_vars = var_remap_;
auto output = ExprMutator::VisitBindingBlock_(block);

// Only visit the interior of objects that we might still keep
// around. Otherwise, double-counting these would lead to extra
// variable bindings.
//
// Before:
// y = f(a+b)
// z = f(a+b)
//
// Expected:
// y = f(a+b) // De-duped from (y==z)
// z = y
//
// Erroneous output:
// c = a+b // Incorrect, a+b only has a single usage.
// y = f(c) // De-duped from
// z = y
//
if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) {
ExprVisitor::VisitExpr(e);
for (auto& [key, replacements] : expr_replacements_) {
replacements.erase(
std::remove_if(replacements.begin(), replacements.end(),
[](const Var& var) -> bool { return var->IsInstance<DataflowVarNode>(); }),
replacements.end());
}

var_remap_ = cache_vars;
return output;
}

// do not visit inner functions: we will do CSE within those
void VisitExpr_(const FunctionNode* func) override {}
void VisitBinding(const Binding& binding) override {
Expr bound_value = VisitExpr(GetBoundValue(binding));

Binding output_binding = [&]() -> Binding {
if (binding.as<VarBindingNode>()) {
return VarBinding(binding->var, bound_value);
} else if (auto match_cast = binding.as<MatchCastNode>()) {
return MatchCast(binding->var, bound_value, match_cast->struct_info);
} else {
LOG(FATAL) << "Binding must be either VarBinding or MatchCast, "
<< "but was " << binding->GetTypeKey();
}
}();

// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
ReplacementKey lookup_key(output_binding);

private:
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
ImpurityDetector impurity_detector_;
};
if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) {
VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value;

class CommonSubexprEliminator : public ExprMutator {
public:
explicit CommonSubexprEliminator(
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
bool call_only = false)
: count_map_(std::move(count_map)), call_only_(call_only) {}

// overriding here ensures we visit every subexpression
Expr VisitExpr(const Expr& e) override {
if (call_only_ && !e->IsInstance<CallNode>()) {
return ExprMutator::VisitExpr(e);
}
if (count_map_.count(e) && count_map_.at(e) > 1) {
// if we already have a mapping for it, get it
if (replacements_.count(e)) {
return replacements_.at(e);
}
// Otherwise, insert a new binding for the current expression.
// Visit before emitting to do inner replacements
Expr new_e = ExprMutator::VisitExpr(e);
Var v = builder_->Emit(new_e);
replacements_[e] = v;
return v;
}
return ExprMutator::VisitExpr(e);
}
} else if (ContainsImpureCall(bound_value)) {
VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value;

// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
return struct_info;
}
} else if (auto it = expr_replacements_.find(lookup_key);
it != expr_replacements_.end() && it->second.size()) {
VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0]
<< ". The duplicate binding of this value to " << binding->var
<< " will be replaced with a trivial binding, "
<< "and occurrences of " << binding->var << " will be replaced with "
<< it->second[0];
output_binding = VarBinding(binding->var, it->second[0]);
var_remap_.insert({binding->var->vid, it->second[0]});
it->second.push_back(binding->var);

Expr VisitExpr_(const FunctionNode* op) override {
Function func = GetRef<Function>(op);
} else {
VLOG(1) << "Value " << bound_value << " is bound to " << binding->var
<< " and may be de-duplicated if it occurs again.";

auto cache = SubexprCounter::Count(op->body);
std::swap(cache, count_map_);
Expr output = ExprMutator::VisitExpr_(op);
std::swap(cache, count_map_);
expr_replacements_[lookup_key].push_back(binding->var);
}

return output;
builder_->EmitNormalized(output_binding);
}

void VisitBinding_(const VarBindingNode* binding) override {
// no need to visit var def because the struct info isn't going to change
Expr new_value = RegisterBoundValue(binding->var, binding->value);

if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<VarBinding>(binding));
Expr VisitExpr_(const FunctionNode* op) override {
// If we have accumulated any state, visit the function in a fresh
// copy of the mutator, to avoid replacing a child-scope
// expression with a parent-scope binding, or vice versa.
if (expr_replacements_.size() || var_remap_.size()) {
return VisitWithCleanScope(GetRef<Expr>(op));
} else {
// no need to renormalize new_value because all replacements are with vars
builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
return ExprMutator::VisitExpr_(op);
}
}

void VisitBinding_(const MatchCastNode* binding) override {
// no need to visit var def because the struct info isn't going to change
Expr new_value = RegisterBoundValue(binding->var, binding->value);

// re-emit old binding if nothing changes
if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<MatchCast>(binding));
Expr VisitExpr_(const IfNode* op) override {
Expr cond = VisitExpr(op->cond);
Expr true_branch = VisitWithInnerScope(op->true_branch);
Expr false_branch = VisitWithInnerScope(op->false_branch);
if (op->cond.same_as(cond) && op->true_branch.same_as(true_branch) &&
op->false_branch.same_as(false_branch) &&
VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) {
return GetRef<Expr>(op);
} else {
// no need to renormalize new_value because all replacements are with vars
builder_->EmitNormalized(
MatchCast(binding->var, new_value, binding->struct_info, binding->span));
return If(cond, true_branch, false_branch, op->span);
}
}

private:
Expr RegisterBoundValue(Var var, Expr bound_value) {
// special case: if we are processing a binding
// and this is the first time we've encountered it,
// we will use the binding's var for the mapping
bool newly_replaced = false;
if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 &&
!replacements_.count(bound_value)) {
replacements_[bound_value] = var;
newly_replaced = true;
}
Expr VisitWithInnerScope(Expr expr) {
auto cached_vars = var_remap_;
auto cached_exprs = expr_replacements_;
auto output = VisitExpr(expr);
var_remap_ = cached_vars;
expr_replacements_ = cached_exprs;
return output;
}

if (newly_replaced) {
// If we've just added the mapping, using the overridden visitor will
// just return the var, which we don't want, so we will use
// the superclass VisitExpr to do inner substitutions
return ExprMutator::VisitExpr(bound_value);
}
return VisitExpr(bound_value);
Expr VisitWithCleanScope(Expr expr) {
CommonSubexprEliminator clean_mutator(call_only_);
return clean_mutator.VisitExpr(expr);
}

std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
bool call_only_{false};
std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_;
};

} // namespace

Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) {
CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only);
CommonSubexprEliminator mutator(call_only);
return mutator(expr);
}

Expand Down
Loading

0 comments on commit 864fd5c

Please sign in to comment.