Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Improve well-formed check's handling of match buffer #16655

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
using Verifier::Verifier;

private:
using Verifier::Visit;
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
redefine_allowed_within_function_.clear();
Expand Down
78 changes: 35 additions & 43 deletions src/tir/ir/tir_visitor_with_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
// variable has occurred. Therefore, to ensure that we only avoid
// duplicate calls to VisitVarDef, these semantics need to be
// checked.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_params;
std::vector<std::variant<DefContext<Var>, DefContext<Buffer>>> context;

auto ppath = path->Attr("params");
for (size_t i = 0; i < func->params.size(); i++) {
context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i)));
defined_params.insert(func->params[i]);
}

auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr,
ObjectPath path) {
if (auto opt = expr.as<Var>()) {
auto var = opt.value();
if (!defined_params.count(var)) {
context.push_back(WithDef(var, path));
defined_params.insert(var);
}
}
};
auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array<PrimExpr>& arr,
ObjectPath path) {
for (size_t i = 0; i < arr.size(); i++) {
try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
}
};

auto buffer_map_path = path->Attr("buffer_map");
for (size_t i = 0; i < func->params.size(); i++) {
if (auto opt = func->buffer_map.Get(func->params[i])) {
auto buf = opt.value();
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));

// A buffer in the buffer_map always defines its data pointer
context.push_back(WithDef(buf->data, buf_path->Attr("data")));

// But other implicit definitions only apply if they weren't
// provided as explicit parameters, and they weren't defined
// implicitly by any previous buffer.
try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape"));
try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides"));
try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset"));
for (auto& def : WithMatchBufferDefs(buf, buf_path)) {
context.push_back(std::move(def));
}
}
}

Expand All @@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
for (size_t i = 0; i < func->params.size(); i++) {
if (auto opt = func->buffer_map.Get(func->params[i])) {
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
EnterDef(opt.value(), buf_path);
context.push_back(WithDef(opt.value(), buf_path));
}
}

Expand Down Expand Up @@ -199,16 +174,40 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) {
void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) {
Visit(op->value, path->Attr("value"));

std::optional<DefContext<IterVar>> context = std::nullopt;
std::vector<std::variant<DefContext<IterVar>, DefContext<Var>>> context;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this ever used? I don't see any reads from it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There aren't any reads from it, as it holds a scoped context manager. On destruction, the DefContext<T> object removes items from TIRVisitorWithPath::in_scope_definitions_, and calls the ExitDef handler of the child class.

Also, thank you for pointing this one out. When switching from std::optional to std::vector, I forgot to add a while(context.size()) context.pop_back(); loop in case child classes rely on ExitDef being called in the reverse order from EnterDef.

if (auto iter_var = op->node.as<IterVar>();
iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) {
// Some attributes serve as a source of definition for the
// tir::Var they annotate.
context = WithDef(iter_var.value(), path->Attr("node"));
context.push_back(WithDef(iter_var.value(), path->Attr("node")));

} else if (op->attr_key == attr::buffer_bind_scope) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth commenting that this acts as an older form of MatchBuffer, per the PR description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, and added a comment with description.

// The `attr::buffer_bind_scope` attribute defines a view into an
// existing buffer, similar to the newer
// `BlockNode::match_buffers` field. It requires the buffer being
// viewed to be defined prior to the attribute. The
// `attr::buffer_bind_scope` is the point of definition for the
// `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any
// symbolic shapes used within `buffer_view` that are not already
// defined.
Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
ICHECK_EQ(arr.size(), 2U);
Buffer buffer_view = Downcast<Buffer>(arr[0]);
Buffer orig_buffer = Downcast<Buffer>(arr[1]);
Visit(orig_buffer, path->Attr("node")->ArrayIndex(1));

for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) {
context.push_back(std::move(var));
}

} else if (auto expr = op->node.as<PrimExpr>()) {
Visit(expr.value(), path->Attr("node"));
}
Visit(op->body, path->Attr("body"));

while (context.size()) {
Copy link
Contributor

@slyubomirsky slyubomirsky Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change. It's probably worth also putting in a comment that this is to ensure that the defs expire, otherwise it seems like spooky action at a distance. Not 100% sure, as the name "DefContext" does imply that it's an RAII sort of thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I tried to follow the existing FooContext naming structure (e.g. arith::ConstraintContext).

context.pop_back();
}
}

void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) {
Expand Down Expand Up @@ -250,7 +249,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path)
void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) {
Visit(op->condition, path->Attr("condition"));
Visit(op->bounds, path->Attr("bounds"));
auto context = WithDef(op->buffer, path->Attr("buffer"));
auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine this accounts for the case where a BufferRealize can act as a point of definition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. In cases where the buffer's backing allocation is defined externally, the BufferRealize is an annotation of the bounds where the external buffer is accessed. Otherwise, BufferRealize is an allocation. Prior to this commit, only the external backing allocation was handled.

Visit(op->buffer, path->Attr("buffer"));
Visit(op->body, path->Attr("body"));
}

Expand Down Expand Up @@ -318,18 +318,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
for (size_t i = 0; i < op->match_buffers.size(); i++) {
auto buf = op->match_buffers[i]->buffer;
auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
auto buffer_strides_path = buffer_path->Attr("strides");
context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
// Define buffer strides and elem_offset if they are vars
if (const auto* v = buf->elem_offset.as<VarNode>()) {
context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
}
for (size_t i = 0; i < buf->strides.size(); ++i) {
if (const auto* v = buf->strides[i].as<VarNode>()) {
context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
}

for (auto& def : WithMatchBufferDefs(buf, buffer_path)) {
context.push_back(std::move(def));
}
context.push_back(WithDef(buf, buffer_path));
}
}

Expand Down
43 changes: 43 additions & 0 deletions src/tir/ir/tir_visitor_with_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
#include <tvm/tir/stmt_functor.h>

#include <exception>
#include <optional>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -173,6 +176,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
// construction of the DefContext and the destruction, we avoid
// this case and allow the first error to propagate upward.
if (self_ && std::uncaught_exceptions() == uncaught_exceptions_) {
self_->in_scope_definitions_.erase(obj_);
self_->ExitDef(obj_, path_);
}
}
Expand All @@ -182,6 +186,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat

DefContext(TIRVisitorWithPath* self, T obj, ObjectPath path)
: self_(self), obj_(obj), path_(path), uncaught_exceptions_(std::uncaught_exceptions()) {
self_->in_scope_definitions_.insert(obj_);
self_->EnterDef(obj_, path_);
}

Expand All @@ -203,6 +208,44 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
DefContext<T> WithDef(T obj, ObjectPath path) {
return DefContext(this, obj, path);
}

/*! \brief Open a scoped definition only when `obj` is not already in scope.
 *
 * Looks up `obj` in `in_scope_definitions_`.  If it is already
 * present, no new definition is opened and `std::nullopt` is
 * returned.  Otherwise the result of `WithDef(obj, path)` is
 * returned, wrapped in the optional.
 *
 * \param obj The node that may be defined at this location.
 * \param path The path to the candidate point of definition.
 * \return The definition context if one was opened, otherwise
 * `std::nullopt`.
 */
template <typename T>
std::optional<DefContext<T>> WithDefIfUndefined(T obj, ObjectPath path) {
  const bool already_in_scope = in_scope_definitions_.count(obj) != 0;
  if (already_in_scope) {
    return std::nullopt;
  }
  return WithDef(obj, path);
}

/*! \brief Collect the variable definitions implied by a buffer's fields.
 *
 * Inspects, in order, the buffer's `data`, each element of `shape`,
 * each element of `strides`, and `elem_offset`.  For every one of
 * these that is a bare `Var` and for which `WithDefIfUndefined`
 * opens a definition, the resulting context is appended to the
 * returned vector.  Expressions that are not a bare `Var`, and
 * variables already in scope, contribute nothing.
 *
 * \param buf The buffer whose fields may introduce definitions.
 * \param path The path to `buf`; each field's path is derived from it.
 * \return The definition contexts that were opened, in the order
 * the fields were inspected.
 */
std::vector<DefContext<Var>> WithMatchBufferDefs(Buffer buf, ObjectPath path) {
  std::vector<DefContext<Var>> opened_defs;

  // Opens a definition for `expr` if it is a variable that is not yet in scope.
  auto define_if_new_var = [this, &opened_defs](const PrimExpr& expr, ObjectPath expr_path) {
    if (auto maybe_var = expr.as<Var>()) {
      auto maybe_def = WithDefIfUndefined(maybe_var.value(), expr_path);
      if (maybe_def.has_value()) {
        opened_defs.push_back(std::move(maybe_def).value());
      }
    }
  };

  define_if_new_var(buf->data, path->Attr("data"));

  ObjectPath shape_path = path->Attr("shape");
  for (size_t dim = 0; dim < buf->shape.size(); ++dim) {
    define_if_new_var(buf->shape[dim], shape_path->ArrayIndex(dim));
  }

  ObjectPath strides_path = path->Attr("strides");
  for (size_t dim = 0; dim < buf->strides.size(); ++dim) {
    define_if_new_var(buf->strides[dim], strides_path->ArrayIndex(dim));
  }

  define_if_new_var(buf->elem_offset, path->Attr("elem_offset"));

  return opened_defs;
}

std::unordered_set<ObjectRef, ObjectPtrHash, ObjectPtrEqual> in_scope_definitions_;
};

} // namespace tir
Expand Down
149 changes: 149 additions & 0 deletions tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,154 @@ def kernel_2(A: T.Buffer([256], "float32")):
tvm.tir.analysis.verify_well_formed(mod)


def test_multiple_buffer_arguments_may_share_allocation():
"""T.match_buffer may re-use a data argument

Like the shape/strides/elem_offset fields in a buffer, the first
occurrence of a `buffer->data` field defines it, and the
subsequent occurrences are usages of that definition.
"""

@I.ir_module
class mod:
@T.prim_func
def func(A_handle: T.handle, B_handle: T.handle):
A = T.match_buffer(A_handle, [256], "float32")
B = T.match_buffer(B_handle, [256], "float32", data=A.data)

pass

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_bind_scope_defines_buffer_obj():
"""The "buffer_bind_scope" attribute defines a buffer view"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "float32")):

for tile_i, tile_j in T.grid(16, 16):
B = T.Buffer([16, 16], "float32")
T.attr(
[B, A],
"buffer_bind_scope",
T.tvm_tuple(
tile_i * 16,
16,
tile_j * 16,
16,
dtype="handle",
),
)
for i, j in T.grid(16, 16):
B[i, j] = 0.0

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_bind_scope_defines_symbolic_variables():
"""The "buffer_bind_scope" attribute may define symbolic variables"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "int32")):

for tile_i, tile_j in T.grid(16, 16):
elem_offset = T.int32()
B = T.Buffer([16, 16], "int32", elem_offset=elem_offset)
T.attr(
[B, A],
"buffer_bind_scope",
T.tvm_tuple(
tile_i * 16,
16,
tile_j * 16,
16,
dtype="handle",
),
)
for i, j in T.grid(16, 16):
B[i, j] = elem_offset

tvm.tir.analysis.verify_well_formed(mod)


def test_block_match_buffer_defines_buffer_obj():
"""In a block, T.match_buffer defines a buffer view"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "float32")):
for iters in T.grid(16, 16, 16, 16):
with T.block("compute"):
tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
B = T.match_buffer(
A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
dtype="float32",
)
B[i, j] = 0.0

tvm.tir.analysis.verify_well_formed(mod)


def test_block_match_buffer_defines_symbolic_variables():
"""In a block, T.match_buffer may define symbolic variables"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "int32")):

for iters in T.grid(16, 16, 16, 16):
with T.block("compute"):
tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)

elem_offset = T.int32()
B = T.match_buffer(
A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
dtype="float32",
elem_offset=elem_offset,
)

B[i, j] = elem_offset

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_realize_on_external_buffer_is_annotation():
"""A T.realize statement on an existing buffer annotates the region used"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer(256, "int32")):
T.realize(A[0:16], "global")

for i in range(16):
A[i] = 1

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_realize_is_allocation():
"""A T.realize statement on a fresh buffer allocates the buffer"""

@I.ir_module
class mod:
@T.prim_func
def func():
A = T.Buffer(256, "int32")
T.realize(A[0:16], "global")

for i in range(16):
A[i] = 1

tvm.tir.analysis.verify_well_formed(mod)


if __name__ == "__main__":
tvm.testing.main()
Loading