Skip to content

Commit

Permalink
[TIR] Improve well-formed check's handling of match buffer (apache#16655
Browse files Browse the repository at this point in the history
)

* [TIR] Improve well-formed check's handling of match buffer

- The `T.match_buffer` at the start of a function may contain repeated
  use of the same data var.  For example, a function that must accept
  two `DLTensor` objects with the same backing allocation.

- The `"buffer_bind_scope"` is an older style of match buffer, and may
  be the point of definition for variables.

* Improved comment, added context.pop_back()
  • Loading branch information
Lunderberg authored and thaisacs committed Apr 3, 2024
1 parent 34122bf commit 44828dd
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 43 deletions.
1 change: 1 addition & 0 deletions src/tir/analysis/verify_well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
using Verifier::Verifier;

private:
using Verifier::Visit;
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
redefine_allowed_within_function_.clear();
Expand Down
78 changes: 35 additions & 43 deletions src/tir/ir/tir_visitor_with_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
// variable has occurred. Therefore, to ensure that we only avoid
// duplicate calls to VisitVarDef, these semantics need to be
// checked.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_params;
std::vector<std::variant<DefContext<Var>, DefContext<Buffer>>> context;

auto ppath = path->Attr("params");
for (size_t i = 0; i < func->params.size(); i++) {
context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i)));
defined_params.insert(func->params[i]);
}

auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr,
ObjectPath path) {
if (auto opt = expr.as<Var>()) {
auto var = opt.value();
if (!defined_params.count(var)) {
context.push_back(WithDef(var, path));
defined_params.insert(var);
}
}
};
auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array<PrimExpr>& arr,
ObjectPath path) {
for (size_t i = 0; i < arr.size(); i++) {
try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
}
};

auto buffer_map_path = path->Attr("buffer_map");
for (size_t i = 0; i < func->params.size(); i++) {
if (auto opt = func->buffer_map.Get(func->params[i])) {
auto buf = opt.value();
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));

// A buffer in the buffer_map always defines its data pointer
context.push_back(WithDef(buf->data, buf_path->Attr("data")));

// But other implicit definitions only apply if they weren't
// provided as explicit parameters, and they weren't defined
// implicitly by any previous buffer.
try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape"));
try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides"));
try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset"));
for (auto& def : WithMatchBufferDefs(buf, buf_path)) {
context.push_back(std::move(def));
}
}
}

Expand All @@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
for (size_t i = 0; i < func->params.size(); i++) {
if (auto opt = func->buffer_map.Get(func->params[i])) {
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
EnterDef(opt.value(), buf_path);
context.push_back(WithDef(opt.value(), buf_path));
}
}

Expand Down Expand Up @@ -199,16 +174,40 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) {
void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) {
Visit(op->value, path->Attr("value"));

std::optional<DefContext<IterVar>> context = std::nullopt;
std::vector<std::variant<DefContext<IterVar>, DefContext<Var>>> context;
if (auto iter_var = op->node.as<IterVar>();
iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) {
// Some attributes serve as a source of definition for the
// tir::Var they annotate.
context = WithDef(iter_var.value(), path->Attr("node"));
context.push_back(WithDef(iter_var.value(), path->Attr("node")));

} else if (op->attr_key == attr::buffer_bind_scope) {
// The `attr::buffer_bind_scope` attribute defines a view into an
// existing buffer, similar to the newer
// `BlockNode::match_buffers` field. It requires the buffer being
// viewed to be defined prior to the attribute. The
// `attr::buffer_bind_scope` is the point of definition for the
// `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any
// symbolic shapes used within `buffer_view that are not already
// defined.
Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
ICHECK_EQ(arr.size(), 2U);
Buffer buffer_view = Downcast<Buffer>(arr[0]);
Buffer orig_buffer = Downcast<Buffer>(arr[1]);
Visit(orig_buffer, path->Attr("node")->ArrayIndex(1));

for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) {
context.push_back(std::move(var));
}

} else if (auto expr = op->node.as<PrimExpr>()) {
Visit(expr.value(), path->Attr("node"));
}
Visit(op->body, path->Attr("body"));

while (context.size()) {
context.pop_back();
}
}

void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) {
Expand Down Expand Up @@ -250,7 +249,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path)
void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) {
Visit(op->condition, path->Attr("condition"));
Visit(op->bounds, path->Attr("bounds"));
auto context = WithDef(op->buffer, path->Attr("buffer"));
auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data"));
Visit(op->buffer, path->Attr("buffer"));
Visit(op->body, path->Attr("body"));
}

Expand Down Expand Up @@ -318,18 +318,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
for (size_t i = 0; i < op->match_buffers.size(); i++) {
auto buf = op->match_buffers[i]->buffer;
auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
auto buffer_strides_path = buffer_path->Attr("strides");
context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
// Define buffer strides and elem_offset if they are vars
if (const auto* v = buf->elem_offset.as<VarNode>()) {
context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
}
for (size_t i = 0; i < buf->strides.size(); ++i) {
if (const auto* v = buf->strides[i].as<VarNode>()) {
context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
}

for (auto& def : WithMatchBufferDefs(buf, buffer_path)) {
context.push_back(std::move(def));
}
context.push_back(WithDef(buf, buffer_path));
}
}

Expand Down
43 changes: 43 additions & 0 deletions src/tir/ir/tir_visitor_with_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
#include <tvm/tir/stmt_functor.h>

#include <exception>
#include <optional>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -173,6 +176,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
// construction of the DefContext and the destruction, we avoid
// this case and allow the first error to propagate upward.
if (self_ && std::uncaught_exceptions() == uncaught_exceptions_) {
self_->in_scope_definitions_.erase(obj_);
self_->ExitDef(obj_, path_);
}
}
Expand All @@ -182,6 +186,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat

DefContext(TIRVisitorWithPath* self, T obj, ObjectPath path)
: self_(self), obj_(obj), path_(path), uncaught_exceptions_(std::uncaught_exceptions()) {
self_->in_scope_definitions_.insert(obj_);
self_->EnterDef(obj_, path_);
}

Expand All @@ -203,6 +208,44 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
DefContext<T> WithDef(T obj, ObjectPath path) {
return DefContext(this, obj, path);
}

/* \brief Utility to track the scope of a node's definition. */
template <typename T>
std::optional<DefContext<T>> WithDefIfUndefined(T obj, ObjectPath path) {
if (in_scope_definitions_.count(obj)) {
return std::nullopt;
} else {
return WithDef(obj, path);
}
}

std::vector<DefContext<Var>> WithMatchBufferDefs(Buffer buf, ObjectPath path) {
std::vector<DefContext<Var>> context;

auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, ObjectPath path) {
if (auto opt = expr.as<Var>()) {
auto var = opt.value();
if (auto var_def = WithDefIfUndefined(var, path)) {
context.push_back(std::move(var_def).value());
}
}
};
auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](
const Array<PrimExpr>& arr, ObjectPath path) {
for (size_t i = 0; i < arr.size(); i++) {
try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
}
};

try_visit_implicit_var_def(buf->data, path->Attr("data"));
try_visit_implicit_var_def_array(buf->shape, path->Attr("shape"));
try_visit_implicit_var_def_array(buf->strides, path->Attr("strides"));
try_visit_implicit_var_def(buf->elem_offset, path->Attr("elem_offset"));

return context;
}

std::unordered_set<ObjectRef, ObjectPtrHash, ObjectPtrEqual> in_scope_definitions_;
};

} // namespace tir
Expand Down
149 changes: 149 additions & 0 deletions tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,154 @@ def kernel_2(A: T.Buffer([256], "float32")):
tvm.tir.analysis.verify_well_formed(mod)


def test_multiple_buffer_arguments_may_share_allocation():
"""T.match_buffer may re-use a data argument
Like the shape/strides/elem_offset fields in a buffer, the first
occurrence of a `buffer->data` field defines it, and the
occurrences are usages of that definition.
"""

@I.ir_module
class mod:
@T.prim_func
def func(A_handle: T.handle, B_handle: T.handle):
A = T.match_buffer(A_handle, [256], "float32")
B = T.match_buffer(B_handle, [256], "float32", data=A.data)

pass

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_bind_scope_defines_buffer_obj():
"""The "buffer_bind_scope" attribute defines a buffer view"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "float32")):

for tile_i, tile_j in T.grid(16, 16):
B = T.Buffer([16, 16], "float32")
T.attr(
[B, A],
"buffer_bind_scope",
T.tvm_tuple(
tile_i * 16,
16,
tile_j * 16,
16,
dtype="handle",
),
)
for i, j in T.grid(16, 16):
B[i, j] = 0.0

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_bind_scope_defines_symbolic_variables():
"""The "buffer_bind_scope" attribute may define symbolic variables"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "int32")):

for tile_i, tile_j in T.grid(16, 16):
elem_offset = T.int32()
B = T.Buffer([16, 16], "int32", elem_offset=elem_offset)
T.attr(
[B, A],
"buffer_bind_scope",
T.tvm_tuple(
tile_i * 16,
16,
tile_j * 16,
16,
dtype="handle",
),
)
for i, j in T.grid(16, 16):
B[i, j] = elem_offset

tvm.tir.analysis.verify_well_formed(mod)


def test_block_match_buffer_defines_buffer_obj():
"""In a block, T.match_buffer defines a buffer view"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "float32")):
for iters in T.grid(16, 16, 16, 16):
with T.block("compute"):
tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
B = T.match_buffer(
A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
dtype="float32",
)
B[i, j] = 0.0

tvm.tir.analysis.verify_well_formed(mod)


def test_block_match_buffer_defines_symbolic_variables():
"""In a block, T.match_buffer may define symbolic variables"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer([256, 256], "int32")):

for iters in T.grid(16, 16, 16, 16):
with T.block("compute"):
tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)

elem_offset = T.int32()
B = T.match_buffer(
A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
dtype="float32",
elem_offset=elem_offset,
)

B[i, j] = elem_offset

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_realize_on_external_buffer_is_annotation():
"""A T.realize statement on an existing buffer annotates the region used"""

@I.ir_module
class mod:
@T.prim_func
def func(A: T.Buffer(256, "int32")):
T.realize(A[0:16], "global")

for i in range(16):
A[i] = 1

tvm.tir.analysis.verify_well_formed(mod)


def test_buffer_realize_is_allocation():
"""A T.realize statement on an fresh buffer allocates the buffer"""

@I.ir_module
class mod:
@T.prim_func
def func():
A = T.Buffer(256, "int32")
T.realize(A[0:16], "global")

for i in range(16):
A[i] = 1

tvm.tir.analysis.verify_well_formed(mod)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 44828dd

Please sign in to comment.