Skip to content

Commit

Permalink
[Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in ConvertSSA (#16682)
Browse files Browse the repository at this point in the history
In some cases, an `AttrStmt` may legally refer to a TIR variable that
hasn't yet been defined.  For example, the
`"pragma_parallel_launch_point"` attribute, which annotates a variable
that is about to occur in a ForNode.  Prior to this commit,
`ConvertSSA` treated the `AttrStmt` as the usage of a variable,
followed by a nested definition to be de-duplicated.  This resulted in
the output `AttrStmt` containing a reference to an undefined variable.

This commit updates `ConvertSSA` to handle this case.  If an
`AttrStmt` refers to a not-yet-defined variable, the body is visited
before marking it as defined.

This implementation may be simplified in the future by
moving "pragma_parallel_launch_point" to be an annotation
on the `ForNode`, rather than an `AttrStmt`.
  • Loading branch information
Lunderberg authored Mar 9, 2024
1 parent 7b7677f commit 898f87f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 5 deletions.
34 changes: 30 additions & 4 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ class IRConvertSSA final : public StmtExprMutator {
}

Var var = iter_var->var;
bool delayed_define = false;
if (auto it = function_scope_var_remap_.find(var.get());
it != function_scope_var_remap_.end()) {
var = it->second;
Expand All @@ -373,8 +374,23 @@ class IRConvertSSA final : public StmtExprMutator {
function_scope_var_remap_.insert({var.get(), new_var});
var = new_var;
} else {
function_scope_var_remap_.insert({var.get(), var});
defined_.insert(var.get());
// The AttrStmt refers to an undefined variable. This is
// allowed for some attributes, such as
// "pragma_parallel_launch_point", which annotates a variable
// that is about to occur in a ForNode. In these cases, the
// ForNode and the AttrStmt must continue using the same
// variable defintion.
//
// However, other AttrStmt, such as "thread_extent", act as
// points of definition for the variable they annotate. If
// the variable has not been defined after visiting the body,
// we should mark it as defined before exiting. This ensures
// correct de-duplication between multiple functions.
//
// This implementation may be simplified in the future by
// moving "pragma_parallel_launch_point" to be an annotation
// on the `ForNode`, rather than an `AttrStmt`.
delayed_define = true;
}

IterVar new_iter_var;
Expand All @@ -387,12 +403,22 @@ class IRConvertSSA final : public StmtExprMutator {
auto value = VisitExpr(op->value);
auto body = VisitStmt(op->body);

Stmt output;
if (new_iter_var.get() == iter_var && body.same_as(op->body) && value.same_as(op->value)) {
return GetRef<Stmt>(op);
output = GetRef<Stmt>(op);
} else {
return AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
output = AttrStmt(new_iter_var, op->attr_key, value, body, iter_var->span);
}

if (delayed_define) {
if (!defined_.count(var.get())) {
function_scope_var_remap_.insert({var.get(), var});
defined_.insert(var.get());
}
}

return output;

} else if (const VarNode* v = op->node.as<VarNode>()) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
Expand Down
61 changes: 60 additions & 1 deletion tests/python/tir-transform/test_tir_transform_convert_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import tvm
import tvm.testing
from tvm import tir
from tvm import tir, ir
from tvm.script import tir as T, ir as I


Expand Down Expand Up @@ -485,5 +485,64 @@ def kernel_2(A: T.Buffer([256], "float32")):
return mod


class TestTrackForwardDeclarationsInAttrStmt(BaseBeforeAfter):
"""T.attr statements may refer to a about-to-be-defined tir.Var"""

def before(self):
"""Generate the PrimFunc, which is already SSA
This is constructed directly, rather than using TVMScript or
the `tvm.tir.ir_builder`. This test case requires a
`tir.AttrStmt` that references a variable, followed by the
`tir.For` defining that variable. This is not expressible in
either TVMScript or `tvm.tir.ir_builder`, as they only provide
the loop iterator within the body of the loop.
"""
i0_outer_outer = tir.Var("i0_outer_outer", "int32")
i0_outer_inner = tir.Var("i0_outer_inner", "int32")
i0_inner = tir.Var("i0_inner", "int32")

A = tir.decl_buffer(1024, "float32", "A")
B = tir.decl_buffer(1024, "float32", "B")

index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner

stmt = tir.BufferStore(B, tir.BufferLoad(A, [index]), [index])
stmt = tir.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None)
stmt = tir.For(i0_inner, 0, 4, tir.ForKind.VECTORIZED, stmt)
stmt = tir.For(i0_outer_inner, 0, 13, tir.ForKind.PARALLEL, stmt)
stmt = tir.AttrStmt(
T.iter_var(i0_outer_inner, None, "DataPar", ""),
"pragma_parallal_barrier_when_finish",
1,
stmt,
)
stmt = tir.AttrStmt(
T.iter_var(i0_outer_inner, None, "DataPar", ""),
"pragma_parallal_stride_pattern",
1,
stmt,
)
stmt = tir.For(i0_outer_outer, 0, 20, tir.ForKind.SERIAL, stmt)
stmt = tir.AttrStmt(
T.iter_var(i0_outer_outer, None, "DataPar", ""),
"pragma_parallal_launch_point",
1,
stmt,
)

A_handle = tir.Var("A_handle", "handle")
B_handle = tir.Var("B_handle", "handle")

func = tir.PrimFunc(
[A_handle, B_handle],
stmt,
buffer_map={A_handle: A, B_handle: B},
)
return func

expected = before


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 898f87f

Please sign in to comment.