From 89cc09c62103d74dce02e03754261b1e205cadab Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Fri, 23 Feb 2024 08:41:26 -0600
Subject: [PATCH] [Unity][Transform] Handle dynamic shapes in
 CombineParallelMatmul (#16591)

* [Unity][Transform] Handle dynamic shapes in CombineParallelMatmul

Prior to this commit, if the weight of a matmul has a dynamic shape,
and that matmul is being combined by the `CombineParallelMatmul` pass,
it could cause a segfault when `dim.as<IntImmNode>()` returns a null
pointer.

This commit adds explicit test cases for these dynamic shapes, and
updates `CombineParallelMatmul` to handle the dynamic shapes.

* Add Tuple constructor for PR-16589
---
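A minimal sketch of the scenario this patch addresses, kept below the `---`
fold where `git am` ignores it.  It assumes only the public
`tvm.relax.transform.CombineParallelMatmul` entry point and the TVMScript
frontend; the shapes and names mirror the new unit tests rather than any
particular model.

    import tvm
    from tvm.relax.transform import CombineParallelMatmul
    from tvm.script import relax as R, tir as T

    @R.function(private=True)
    def before(
        x: R.Tensor((2, 1024, 640), "float32"),
        w0: R.Tensor((640, 640), "float32"),
        w1: R.Tensor((640, "M"), "float32"),  # dynamic inner dimension
    ):
        M = T.int64()
        with R.dataflow():
            lv0 = R.matmul(x, w0)
            lv1 = R.matmul(x, w1)
            out = (lv0, lv1)
            R.output(out)
        return out

    # Both matmuls share the LHS `x`, so the pass concatenates the weights,
    # performs a single matmul, and splits the result.  With this patch the
    # dynamic-width weight is handled (previously this case could segfault),
    # and it is always placed as the trailing chunk of the R.split, since the
    # split indices must be static integers.
    mod = CombineParallelMatmul()(tvm.IRModule.from_expr(before))
    mod.show()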
 include/tvm/relax/expr.h                      |  18 ++
 .../transform/combine_parallel_matmul.cc      | 160 +++++++++++-------
 .../test_transform_combine_parallel_matmul.py | 123 +++++++++++++-
 3 files changed, 240 insertions(+), 61 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index bb1b2c8dd74a..23262ea81794 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -320,6 +320,24 @@ class Tuple : public Expr {
    */
   TVM_DLL explicit Tuple(tvm::Array<Expr> fields, Span span = Span());
 
+  /*!
+   * \brief Utility constructor to handle conversion to relax::Expr
+   *
+   * If the calling scope already has an array of a specific type of
+   * relax expression (e.g. `Array<relax::Var>`), it must be converted
+   * into an array of base type.  This constructor handles the
+   * conversion to the base `Array<Expr>`.
+   *
+   * \tparam RelaxExpr The type of relax expression passed in as an argument.
+   *
+   * \param fields The fields of a tuple.
+   *
+   * \param span The source span of the expression.
+   */
+  template <typename RelaxExpr, typename = std::enable_if_t<std::is_base_of_v<Expr, RelaxExpr>>>
+  TVM_DLL explicit Tuple(tvm::Array<RelaxExpr> fields, Span span = Span())
+      : Tuple(fields.Map([](const RelaxExpr& expr) -> Expr { return expr; }), span) {}
+
   TVM_DEFINE_OBJECT_REF_METHODS(Tuple, Expr, TupleNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
 };
diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
index 3ea17fdd70ea..7e6aa6277b0b 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -71,7 +71,16 @@ struct Patterns {
   WildcardPattern input;
   std::vector rhs;
   std::vector bias;
-  std::vector matmul, bias_add, activation;
+  std::vector matmul;
+  std::vector bias_add;
+  std::vector activation;
+};
+
+struct SplitInfo {
+  Var rhs;
+  Optional<Var> bias;
+  PrimExpr split_size;
+  DFPattern pattern_to_replace;
 };
 
 Patterns CreatePatterns(const BranchInfo& branch_info) {
@@ -140,40 +149,68 @@ runtime::TypedPackedFunc(Map, Map)> Ge
   for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) {
     if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue;
 
-    auto inp = matchings[patterns.input];
+    auto lhs = matchings[patterns.input];
+
+    const auto& patterns_to_replace = [&patterns, &branch_info]() {
+      if (branch_info.activation) return patterns.activation;
+      if (branch_info.bias_dim) return patterns.bias_add;
+      return patterns.matmul;
+    }();
 
-    Array<Var> rhs, bias;
-    for (auto ind : indices) {
-      rhs.push_back(matchings[patterns.rhs[ind]]);
-      if (branch_info.bias_dim) {
-        ICHECK(matchings.count(patterns.bias[ind]));
-        bias.push_back(matchings[patterns.bias[ind]]);
+    std::vector<SplitInfo> splits;
+    for (auto index : indices) {
+      Var rhs = matchings[patterns.rhs[index]];
+      Optional<Var> bias = NullOpt;
+      if (branch_info.bias_dim.has_value()) {
+        bias = matchings[patterns.bias[index]];
       }
+      PrimExpr split_size = GetTensorSInfo(rhs)->GetShape().value()[rhs_dim - 1];
+      DFPattern pattern_to_replace = patterns_to_replace[index];
+      splits.push_back(SplitInfo{rhs, bias, split_size, pattern_to_replace});
+    }
+
+    // At most one dynamic output shape can be part of the combined
+    // matmul, and it must be the last item in the split.  Use
+    // `std::stable_sort` instead of `std::sort` to maintain a
+    // consistent order for all static shapes, and to consistently
+    // select the same dynamic weight to participate.
+    auto is_dynamic_split = [](const SplitInfo& split) -> bool {
+      return !split.split_size->IsInstance<IntImmNode>();
+    };
+    std::stable_sort(splits.begin(), splits.end(),
+                     [&is_dynamic_split](const auto& a, const auto& b) {
+                       return is_dynamic_split(a) < is_dynamic_split(b);
+                     });
+    // Remove anything after the first dynamic shape participating
+    // in the combined matmul.
+    if (auto it = std::find_if(splits.begin(), splits.end(), is_dynamic_split);
+        it != splits.end()) {
+      splits.erase(it + 1, splits.end());
     }
 
-    if (!check(inp, rhs, bias, bindings)) {
+    if (splits.size() == 1) {
       continue;
     }
 
-    auto make_tuple = [](const Array<Var>& var_array) {
-      Array<Expr> exp_array;
-      for (auto v : var_array) exp_array.push_back(v);
-      return Tuple(exp_array);
-    };
+    Array<Var> rhs;
+    Array<Var> bias;
+    for (const auto& split : splits) {
+      rhs.push_back(split.rhs);
+      if (split.bias) {
+        bias.push_back(split.bias.value());
+      }
+    }
 
-    auto concat_rhs = concat(make_tuple(rhs), Integer(rhs_dim - 1));
-    auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
-    auto matmul_combined = matmul(inp, concat_rhs, out_dtype);
+    if (!check(lhs, rhs, bias, bindings)) {
+      continue;
+    }
 
-    const auto& pattern_to_replace = [&patterns, &branch_info]() {
-      if (branch_info.activation) return patterns.activation;
-      if (branch_info.bias_dim) return patterns.bias_add;
-      return patterns.matmul;
-    }();
+    auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1));
+    auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype;
+    auto matmul_combined = matmul(lhs, concat_rhs, out_dtype);
 
     if (branch_info.bias_dim) {
       auto bias_dim = GetTensorSInfo(bias[0])->ndim;
-      auto concat_bias = concat(make_tuple(bias), Integer(bias_dim - 1));
+      auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1));
       matmul_combined = add(matmul_combined, concat_bias);
     }
@@ -191,20 +228,23 @@ runtime::TypedPackedFunc(Map, Map)> Ge
       }
     }
 
-    int ind = 0;
+    int split_index = 0;
     Array sections;
-    for (int i = 0; i < static_cast<int>(indices.size()) - 1; ++i) {
-      auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1].as<IntImmNode>();
-      ind += width->value;
-      sections.push_back(IntImm(DataType::Int(64), ind));
+    for (size_t i = 0; i + 1 < splits.size(); i++) {
+      auto width = splits[i].split_size.as<IntImmNode>();
+      ICHECK(width) << "InternalError: "
+                    << "All splits except the last one must have a static shape";
+      split_index += width->value;
+      sections.push_back(IntImm(DataType::Int(64), split_index));
    }
 
-    int lhs_dim = GetTensorSInfo(inp)->ndim;
+    int lhs_dim = GetTensorSInfo(lhs)->ndim;
     int split_axis = std::max(lhs_dim, rhs_dim) - 1;
     auto chunks = split(matmul_combined, sections, split_axis);
 
-    for (size_t i = 0; i < indices.size(); ++i) {
-      auto bound_var = matchings[pattern_to_replace[indices[i]]];
+    for (size_t i = 0; i < splits.size(); i++) {
+      const auto& split = splits[i];
+      auto bound_var = matchings[split.pattern_to_replace];
       replacements.Set(bound_var, TupleGetItem(chunks, i));
     }
   }
@@ -244,43 +284,43 @@ std::vector GetBranchInfo(Function f) {
   PostOrderVisit(f, [&](const Expr& e) {
     if (!e->IsInstance<CallNode>()) return;
 
-    if (auto match = ExtractMatchedExpr(pat, e, bindings)) {
-      auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
-      auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
-      auto it = groups.find(matmul_lhs.get());
-      BranchInfo* branch = it != groups.end() ? &it->second : nullptr;
-      std::optional bias_dim = std::nullopt;
-      std::optional activation = std::nullopt;
+    auto match = ExtractMatchedExpr(pat, e, bindings);
+    if (!match) return;
 
-      if (match.value().count(bias_pat)) {
-        bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
-      }
+    auto matmul_call = Downcast<Call>(match.value()[matmul_pat]);
+    auto matmul_lhs = Downcast<Var>(matmul_call->args[0]);
 
-      for (size_t i = 0; i < activations.size(); ++i) {
-        if (match.value().count(activation_pat[i]) ||
-            match.value().count(bias_activation_pat[i])) {
-          activation = activations[i];
-        }
+    std::optional bias_dim = std::nullopt;
+    std::optional activation = std::nullopt;
+
+    if (match.value().count(bias_pat)) {
+      bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim;
+    }
+
+    for (size_t i = 0; i < activations.size(); ++i) {
+      if (match.value().count(activation_pat[i]) || match.value().count(bias_activation_pat[i])) {
+        activation = activations[i];
       }
+    }
 
-      if (!branch) {
-        // Create a new subgraph with one matmul
-        groups[matmul_lhs.get()] = {1, bias_dim, activation};
-      } else {
-        // Create a new branch in the existing parallel matmul subtree, and
-        // invalidate bias and activation information when needed.
-        branch->num_branches += 1;
+    if (auto it = groups.find(matmul_lhs.get()); it != groups.end()) {
+      // Create a new branch in the existing parallel matmul subtree, and
+      // invalidate bias and activation information when needed.
+      BranchInfo* branch = &it->second;
+
+      branch->num_branches += 1;
 
-        if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
-          branch->bias_dim = std::nullopt;
-        }
+      if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
+        branch->bias_dim = std::nullopt;
+      }
 
-        if (!activation || (branch->activation && *branch->activation != *activation)) {
-          branch->activation = std::nullopt;
-        }
+      if (!activation || (branch->activation && *branch->activation != *activation)) {
+        branch->activation = std::nullopt;
       }
-      return;
+    } else {
+      // Create a new subgraph with one matmul
+      groups[matmul_lhs.get()] = {1, bias_dim, activation};
     }
   });
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index 7e7f2328f3b3..6168d0c58d24 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -525,7 +525,16 @@ def expected(
     tvm.ir.assert_structural_equal(after, expected)
 
 
-def test_dynamic_rhs():
+def test_combine_matmul_of_static_and_dynamic_shapes():
+    """Combine two matmuls, one with dynamic shape
+
+    The `R.split` operator must have a static list of integer indices
+    at which to split the matmul output, because these integer indices
+    are stored as operator attributes.  However, the last output can
+    still have a dynamic shape.
+
+    """
+
     @R.function(private=True)
     def before(
         x: R.Tensor((2, 1024, 640), "float32"),
@@ -572,5 +581,117 @@ def expected(
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_combine_matmul_of_dynamic_and_static_shapes():
+    """Combine two matmuls, one with dynamic shape
+
+    Like `test_combine_matmul_of_static_and_dynamic_shapes`, but the
+    dynamic-shaped matmul is encountered first.  Due to the
+    requirements imposed by `R.split` storing the split indices as
+    static integers, the static-shaped weights must occur first in the
+    concatenated weights.
+    """
+
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, "M"), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            out = (lv0, lv1)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, "M"), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, "M"), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv: R.Tensor((640, 640 + M), dtype="float32") = R.concat((w1, w0), axis=1)
+            lv1: R.Tensor((2, 1024, 640 + M), dtype="float32") = R.matmul(
+                x, lv, out_dtype="float32"
+            )
+            lv2: R.Tuple(
+                R.Tensor((2, 1024, 640), dtype="float32"),
+                R.Tensor((2, 1024, M), dtype="float32"),
+            ) = R.split(lv1, indices_or_sections=[640], axis=2)
+            lv0: R.Tensor((2, 1024, M), dtype="float32") = lv2[1]
+            lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv2[0]
+            out: R.Tuple(
+                R.Tensor((2, 1024, M), dtype="float32"),
+                R.Tensor((2, 1024, 640), dtype="float32"),
+            ) = (lv0, lv1_1)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_limit_one_dynamic_shape_in_combined_matmul():
+    """Combine two matmuls, one with dynamic shape
+
+    Like `test_combine_matmul_of_static_and_dynamic_shapes`, but with
+    two dynamic weights that could, in principle, be merged together.
+    Because `R.split` must have integer indices at which to split,
+    only one of the dynamic outputs can be part of the combined
+    matmul.
+    """
+
+    @R.function(private=True)
+    def before(
+        x: R.Tensor((2, 1024, 640), "float32"),
+        w0: R.Tensor((640, "M"), "float32"),
+        w1: R.Tensor((640, 640), "float32"),
+        w2: R.Tensor((640, "N"), "float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            lv0 = R.matmul(x, w0)
+            lv1 = R.matmul(x, w1)
+            lv2 = R.matmul(x, w2)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    @R.function(private=True)
+    def expected(
+        x: R.Tensor((2, 1024, 640), dtype="float32"),
+        w0: R.Tensor((640, "M"), dtype="float32"),
+        w1: R.Tensor((640, 640), dtype="float32"),
+        w2: R.Tensor((640, "N"), "float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 1024, "M"), dtype="float32"),
+        R.Tensor((2, 1024, 640), dtype="float32"),
+        R.Tensor((2, 1024, "N"), dtype="float32"),
+    ):
+        M = T.int64()
+        with R.dataflow():
+            concat_weights = R.concat((w1, w0), axis=1)
+            concat_output = R.matmul(x, concat_weights, out_dtype="float32")
+            split_output: R.Tuple(
+                [R.Tensor([2, 1024, 640], dtype="float32"), R.Tensor([2, 1024, M], dtype="float32")]
+            ) = R.split(concat_output, indices_or_sections=[640], axis=2)
+            lv0 = split_output[1]
+            lv1 = split_output[0]
+            lv2 = R.matmul(x, w2)
+            out = (lv0, lv1, lv2)
+            R.output(out)
+        return out
+
+    after = CombineParallelMatmul()(tvm.IRModule.from_expr(before))["main"]
+
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
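
The new test docstrings rely on the fact that `R.split` stores its split
points as static operator attributes, which is why only the trailing chunk of
a combined matmul may have a dynamic width.  A small sketch of that
constraint, assuming only the public `tvm.relax.op.split` helper and
struct-info API; the variable names are illustrative only.

    import tvm
    from tvm import relax

    # Build a standalone relax split call and inspect it.  The split points
    # live in the call's attrs as concrete integers, so every boundary between
    # chunks is fixed; only the extent of the final chunk can stay symbolic.
    x = relax.Var("x", relax.TensorStructInfo([2, 1024, 1280], "float32"))
    call = relax.op.split(x, indices_or_sections=[640], axis=2)
    print(call.attrs.indices_or_sections)  # [640], stored as integer attributes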