diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index 943a11971115..c001d35054f3 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -228,6 +228,7 @@ class UndefinedVarVerifier : public Verifier { using Verifier::Verifier; private: + using Verifier::Visit; void Visit(const PrimFunc& prim_func, ObjectPath path) override { Verifier::Visit(prim_func, path); redefine_allowed_within_function_.clear(); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index a80f2300e2c8..37b3ce55a2ca 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { // variable has occurred. Therefore, to ensure that we only avoid // duplicate calls to VisitVarDef, these semantics need to be // checked. - std::unordered_set defined_params; std::vector, DefContext>> context; auto ppath = path->Attr("params"); for (size_t i = 0; i < func->params.size(); i++) { context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i))); - defined_params.insert(func->params[i]); } - auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr, - ObjectPath path) { - if (auto opt = expr.as()) { - auto var = opt.value(); - if (!defined_params.count(var)) { - context.push_back(WithDef(var, path)); - defined_params.insert(var); - } - } - }; - auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array& arr, - ObjectPath path) { - for (size_t i = 0; i < arr.size(); i++) { - try_visit_implicit_var_def(arr[i], path->ArrayIndex(i)); - } - }; - auto buffer_map_path = path->Attr("buffer_map"); for (size_t i = 0; i < func->params.size(); i++) { if (auto opt = func->buffer_map.Get(func->params[i])) { auto buf = opt.value(); auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); - // A buffer in the buffer_map 
always defines its data pointer - context.push_back(WithDef(buf->data, buf_path->Attr("data"))); - - // But other implicit definitions only apply if they weren't - // provided as explicit parameters, and they weren't defined - // implicitly by any previous buffer. - try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape")); - try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides")); - try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset")); + for (auto& def : WithMatchBufferDefs(buf, buf_path)) { + context.push_back(std::move(def)); + } } } @@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) { for (size_t i = 0; i < func->params.size(); i++) { if (auto opt = func->buffer_map.Get(func->params[i])) { auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i)); - EnterDef(opt.value(), buf_path); + context.push_back(WithDef(opt.value(), buf_path)); } } @@ -199,16 +174,40 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) { void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) { Visit(op->value, path->Attr("value")); - std::optional> context = std::nullopt; + std::vector, DefContext>> context; if (auto iter_var = op->node.as(); iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) { // Some attributes serve as a source of definition for the // tir::Var they annotate. - context = WithDef(iter_var.value(), path->Attr("node")); + context.push_back(WithDef(iter_var.value(), path->Attr("node"))); + + } else if (op->attr_key == attr::buffer_bind_scope) { + // The `attr::buffer_bind_scope` attribute defines a view into an + // existing buffer, similar to the newer + // `BlockNode::match_buffers` field. It requires the buffer being + // viewed to be defined prior to the attribute. 
The + // `attr::buffer_bind_scope` is the point of definition for the + // `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any + // symbolic shapes used within `buffer_view` that are not already + // defined. + Array arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2U); + Buffer buffer_view = Downcast(arr[0]); + Buffer orig_buffer = Downcast(arr[1]); + Visit(orig_buffer, path->Attr("node")->ArrayIndex(1)); + + for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) { + context.push_back(std::move(var)); + } + } else if (auto expr = op->node.as()) { Visit(expr.value(), path->Attr("node")); } Visit(op->body, path->Attr("body")); + + while (context.size()) { + context.pop_back(); + } } void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) { @@ -250,7 +249,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path) void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) { Visit(op->condition, path->Attr("condition")); Visit(op->bounds, path->Attr("bounds")); - auto context = WithDef(op->buffer, path->Attr("buffer")); + auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data")); + Visit(op->buffer, path->Attr("buffer")); Visit(op->body, path->Attr("body")); } @@ -318,18 +318,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) { for (size_t i = 0; i < op->match_buffers.size(); i++) { auto buf = op->match_buffers[i]->buffer; auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer"); - auto buffer_strides_path = buffer_path->Attr("strides"); - context.push_back(WithDef(buf->data, buffer_path->Attr("data"))); - // Define buffer strides and elem_offset if they are vars - if (const auto* v = buf->elem_offset.as()) { - context.push_back(WithDef(GetRef(v), buffer_path->Attr("elem_offset"))); - } - for (size_t i = 0; i < buf->strides.size(); ++i) { - if (const auto* v = buf->strides[i].as()) { - 
context.push_back(WithDef(GetRef(v), buffer_strides_path->ArrayIndex(i))); - } + + for (auto& def : WithMatchBufferDefs(buf, buffer_path)) { + context.push_back(std::move(def)); } - context.push_back(WithDef(buf, buffer_path)); } } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index dd0da1fe77a9..1ae6df58f760 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -29,7 +29,10 @@ #include #include +#include +#include #include +#include namespace tvm { namespace tir { @@ -173,6 +176,7 @@ class TIRVisitorWithPath : protected ExprFunctorin_scope_definitions_.erase(obj_); self_->ExitDef(obj_, path_); } } @@ -182,6 +186,7 @@ class TIRVisitorWithPath : protected ExprFunctorin_scope_definitions_.insert(obj_); self_->EnterDef(obj_, path_); } @@ -203,6 +208,44 @@ class TIRVisitorWithPath : protected ExprFunctor WithDef(T obj, ObjectPath path) { return DefContext(this, obj, path); } + + /* \brief Utility to track the scope of a node's definition. 
*/ + template + std::optional> WithDefIfUndefined(T obj, ObjectPath path) { + if (in_scope_definitions_.count(obj)) { + return std::nullopt; + } else { + return WithDef(obj, path); + } + } + + std::vector> WithMatchBufferDefs(Buffer buf, ObjectPath path) { + std::vector> context; + + auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, ObjectPath path) { + if (auto opt = expr.as()) { + auto var = opt.value(); + if (auto var_def = WithDefIfUndefined(var, path)) { + context.push_back(std::move(var_def).value()); + } + } + }; + auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def]( + const Array& arr, ObjectPath path) { + for (size_t i = 0; i < arr.size(); i++) { + try_visit_implicit_var_def(arr[i], path->ArrayIndex(i)); + } + }; + + try_visit_implicit_var_def(buf->data, path->Attr("data")); + try_visit_implicit_var_def_array(buf->shape, path->Attr("shape")); + try_visit_implicit_var_def_array(buf->strides, path->Attr("strides")); + try_visit_implicit_var_def(buf->elem_offset, path->Attr("elem_offset")); + + return context; + } + + std::unordered_set in_scope_definitions_; }; } // namespace tir diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index 8c153afc9de9..a1b3bee1b282 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -199,5 +199,154 @@ def kernel_2(A: T.Buffer([256], "float32")): tvm.tir.analysis.verify_well_formed(mod) +def test_multiple_buffer_arguments_may_share_allocation(): + """T.match_buffer may re-use a data argument + + Like the shape/strides/elem_offset fields in a buffer, the first + occurrence of a `buffer->data` field defines it, and the + subsequent occurrences are usages of that definition. 
+ """ + + @I.ir_module + class mod: + @T.prim_func + def func(A_handle: T.handle, B_handle: T.handle): + A = T.match_buffer(A_handle, [256], "float32") + B = T.match_buffer(B_handle, [256], "float32", data=A.data) + + pass + + tvm.tir.analysis.verify_well_formed(mod) + + +def test_buffer_bind_scope_defines_buffer_obj(): + """The "buffer_bind_scope" attribute defines a buffer view""" + + @I.ir_module + class mod: + @T.prim_func + def func(A: T.Buffer([256, 256], "float32")): + + for tile_i, tile_j in T.grid(16, 16): + B = T.Buffer([16, 16], "float32") + T.attr( + [B, A], + "buffer_bind_scope", + T.tvm_tuple( + tile_i * 16, + 16, + tile_j * 16, + 16, + dtype="handle", + ), + ) + for i, j in T.grid(16, 16): + B[i, j] = 0.0 + + tvm.tir.analysis.verify_well_formed(mod) + + +def test_buffer_bind_scope_defines_symbolic_variables(): + """The "buffer_bind_scope" attribute may define symbolic variables""" + + @I.ir_module + class mod: + @T.prim_func + def func(A: T.Buffer([256, 256], "int32")): + + for tile_i, tile_j in T.grid(16, 16): + elem_offset = T.int32() + B = T.Buffer([16, 16], "int32", elem_offset=elem_offset) + T.attr( + [B, A], + "buffer_bind_scope", + T.tvm_tuple( + tile_i * 16, + 16, + tile_j * 16, + 16, + dtype="handle", + ), + ) + for i, j in T.grid(16, 16): + B[i, j] = elem_offset + + tvm.tir.analysis.verify_well_formed(mod) + + +def test_block_match_buffer_defines_buffer_obj(): + """In a block, T.match_buffer defines a buffer view""" + + @I.ir_module + class mod: + @T.prim_func + def func(A: T.Buffer([256, 256], "float32")): + for iters in T.grid(16, 16, 16, 16): + with T.block("compute"): + tile_i, tile_j, i, j = T.axis.remap("SSSS", iters) + B = T.match_buffer( + A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16], + dtype="float32", + ) + B[i, j] = 0.0 + + tvm.tir.analysis.verify_well_formed(mod) + + +def test_block_match_buffer_defines_symbolic_variables(): + """In a block, T.match_buffer may define symbolic variables""" + + @I.ir_module 
+ class mod: + @T.prim_func + def func(A: T.Buffer([256, 256], "int32")): + + for iters in T.grid(16, 16, 16, 16): + with T.block("compute"): + tile_i, tile_j, i, j = T.axis.remap("SSSS", iters) + + elem_offset = T.int32() + B = T.match_buffer( + A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16], + dtype="float32", + elem_offset=elem_offset, + ) + + B[i, j] = elem_offset + + tvm.tir.analysis.verify_well_formed(mod) + + +def test_buffer_realize_on_external_buffer_is_annotation(): + """A T.realize statement on an existing buffer annotates the region used""" + + @I.ir_module + class mod: + @T.prim_func + def func(A: T.Buffer(256, "int32")): + T.realize(A[0:16], "global") + + for i in range(16): + A[i] = 1 + + tvm.tir.analysis.verify_well_formed(mod) + + +def test_buffer_realize_is_allocation(): + """A T.realize statement on a fresh buffer allocates the buffer""" + + @I.ir_module + class mod: + @T.prim_func + def func(): + A = T.Buffer(256, "int32") + T.realize(A[0:16], "global") + + for i in range(16): + A[i] = 1 + + tvm.tir.analysis.verify_well_formed(mod) + + if __name__ == "__main__": tvm.testing.main()