Skip to content

Commit

Permalink
[Arith] Fix handling of overlapping predicates (#15555)
Browse files Browse the repository at this point in the history
When a lower bound predicate is present, `TryFuseIters` rewrites an
expression to an `IterMarkWithOffset`, where the iter mark (structured
form) has a minimum value of zero. During rewriting, when a subexpression
matched a previously rewritten result (`IterMarkWithOffset`), the offset
was not correctly added to the resulting `IterMarkWithOffset`, which
caused an inconsistency between the flattened form and the structured
(normal) form.
  • Loading branch information
vinx13 authored Aug 16, 2023
1 parent 0c6fbb8 commit 760c030
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
34 changes: 24 additions & 10 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ class IterMapRewriter : public ExprMutator {
bool requires_padding_{false};

// The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
// an extra offset)
// an extra offset). The normal form always has minimum value of zero.
// Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=9),
Expand All @@ -497,6 +497,7 @@ class IterMapRewriter : public ExprMutator {
// IterSplit(k, scale=1)),
// extent=9)
// scale=1))
// offset = 0
// Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: 1 <= j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=8),
Expand All @@ -507,9 +508,15 @@ class IterMapRewriter : public ExprMutator {
// IterSplit(k, scale=1), base=-1),
// extent=9-1)
// scale=1),
// base=1)
// base=0)
// offset = 1
std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
// The map for sum that maps normal form to flattened form
// For sum_fuse_map_ and flattened_map_ the following invariants hold:
// for any IterSumExpr e in the flattened_form, we have
// iter_mark, mark_offset = sum_fuse_map_[e]
// flattened_map_[normal_form] = e where normal_form = iter_mark->args[0] and
// iter_mark->args.size() = 1
std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
// The flattened forms of constrained iters
std::vector<IterSumExpr> constrained_iters_flattened_;
Expand Down Expand Up @@ -685,7 +692,10 @@ class IterMapRewriter : public ExprMutator {
PrimExpr mark_offset = it_mark->second.offset;
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
// the delta of iter_min when it is updated when the lower bound predicate is present
PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0);
if (predicate_induced_min.defined()) {
iter_min_delta = predicate_induced_min.value() - iter_min;
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
Expand All @@ -704,10 +714,12 @@ class IterMapRewriter : public ExprMutator {
iter_max = min(predicate_induced_max.value(), iter_max);
}
}
if (!is_zero(iter_min)) {
// When iter_min_delta is present, we need to normalize the structured form to have minimum of
// 0, and add the delta to the mark_offset
if (!is_zero(iter_min_delta)) {
// structured form's offset should be updated
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
structured_form.CopyOnWrite()->base -= iter_min_delta;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
}
Expand All @@ -716,8 +728,9 @@ class IterMapRewriter : public ExprMutator {
// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
expr.CopyOnWrite()->base = base + iter_min;
IterSumExprNode* normalized_expr = expr.CopyOnWrite();
normalized_expr->args = Array<IterSplitExpr>({split});
normalized_expr->base = base;
return expr;
}
ErrorLogger(this) << "Could not normalize iterators using the constraints given.";
Expand Down Expand Up @@ -1089,8 +1102,8 @@ class IterMapRewriter : public ExprMutator {
std::vector<IterSplitExpr> flattened_iters, grouped_iters;

// check if it can be remapped into a fused pattern.
PrimExpr expected_extra_base = 0;
PrimExpr tail_extent = 0;
PrimExpr expected_extra_base = make_const(expr.dtype(), 0);
PrimExpr tail_extent = make_const(expr.dtype(), 0);
PrimExpr expected_scale = base_scale;
int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr);

Expand Down Expand Up @@ -1143,8 +1156,9 @@ class IterMapRewriter : public ExprMutator {
size_t k = 0;
for (; k < expr->args.size(); ++k) {
if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale)) {
break;
}
}
}
if (k == expr->args.size()) {
Expand Down Expand Up @@ -1201,7 +1215,7 @@ class IterMapRewriter : public ExprMutator {
} else {
// new iter, form a new mark
IterMark mark = IterMark(structured_form, div(expected_scale, base_scale) + tail_extent);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, expected_extra_base);
flattened_map_[structured_form] = flattened_form;
return IterSumExpr({IterSplitExpr(mark, base_scale)}, expr->base + expected_extra_base);
}
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_compound_floormod_two_regression():
def test_predicate():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")

# available constraints
# upper bound only
Expand Down Expand Up @@ -269,6 +270,12 @@ def test_predicate():
predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127),
)

assert_iter_sum_pattern(
{x * 64 + y * 4 + z: (16, 16)},
var_dom([(x, 16), (y, 16), (z, 4)]),
predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, 4 <= x * 16 + y),
)

# constraint on one fused iter
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
Expand Down

0 comments on commit 760c030

Please sign in to comment.