diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 3a68c567eef6d..f182cd9bfd2f2 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -41,3 +41,5 @@
 from . import executor
 from . import disco
+
+from .support import _regex_match
diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py
new file mode 100644
index 0000000000000..3716460a2709d
--- /dev/null
+++ b/python/tvm/runtime/support.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Runtime support infrastructure of TVM."""
+
+import re
+
+import tvm._ffi
+
+
+@tvm._ffi.register_func("tvm.runtime.regex_match")
+def _regex_match(regex_pattern: str, match_against: str) -> bool:
+    """Check if a pattern matches a regular expression
+
+    This function should be used instead of `std::regex` within C++
+    call sites, to avoid ABI incompatibilities with pytorch.
+
+    Currently, the pytorch wheels available through pip install use
+    the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0].  If TVM were to
+    use the pre-C++11 ABI, this would cause breakages with
+    dynamically-linked LLVM environments.
+
+    Use of the `<regex>` header in TVM should be avoided, as its
+    implementation is not supported by gcc's dual ABI.  This ABI
+    incompatibility results in runtime errors either when `std::regex`
+    is called from TVM, or when `std::regex` is called from pytorch,
+    depending on which library was loaded first.  This restriction can
+    be removed when a version of pytorch compiled using
+    `-DUSE_CXX11_ABI=1` is available from PyPI.
+
+    This is exposed as part of `libtvm_runtime.so` as it is used by
+    the DNNL runtime.
+
+    [0] https://github.com/pytorch/pytorch/issues/51039
+
+    Parameters
+    ----------
+    regex_pattern: str
+
+        The regular expression
+
+    match_against: str
+
+        The string against which to match the regular expression
+
+    Returns
+    -------
+    match_result: bool
+
+        True if `match_against` matches the pattern defined by
+        `regex_pattern`, and False otherwise.
+
+    """
+    match = re.match(regex_pattern, match_against)
+    return match is not None
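One subtlety of this implementation is worth a quick illustration: `re.match` anchors a pattern only at the start of the string, whereas the `std::regex_match` calls being replaced required the entire string to match; `re.fullmatch` would be the exact equivalent. The patterns in this patch mostly end in `.*`, which hides the difference, but a bare pattern such as `"\d*"` behaves differently at the end of the string. A minimal sketch using only the standard `re` module:

    import re

    # re.match anchors at the start of the string, but not at the end:
    assert re.match(r".*_relu.*", "dnnl.conv2d_relu") is not None
    assert re.match(r"\d*", "12ab") is not None  # matches the leading "12"

    # std::regex_match requires the full string to match; re.fullmatch
    # is the exact Python equivalent of that behavior:
    assert re.fullmatch(r"\d*", "12ab") is None
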
diff --git a/python/tvm/support.py b/python/tvm/support.py
index 4fa95fac89218..a50a5e7b57328 100644
--- a/python/tvm/support.py
+++ b/python/tvm/support.py
@@ -19,7 +19,6 @@
 import textwrap
 import ctypes
 import os
-import re
 import sys
 
 import tvm
@@ -88,46 +87,3 @@ def add_function(self, name, func):
 
     def __setitem__(self, key, value):
         self.add_function(key, value)
-
-
-@tvm._ffi.register_func("tvm.support.regex_match")
-def _regex_match(regex_pattern: str, match_against: str) -> bool:
-    """Check if a pattern matches a regular expression
-
-    This function should be used instead of `std::regex` within C++
-    call sites, to avoid ABI incompatibilities with pytorch.
-
-    Currently, the pytorch wheels available through pip install use
-    the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0].  If TVM were to
-    use the pre-C++11 ABI, this would cause breakages with
-    dynamically-linked LLVM environments.
-
-    Use of the `<regex>` header in TVM should be avoided, as its
-    implementation is not supported by gcc's dual ABI.  This ABI
-    incompatibility results in runtime errors either when `std::regex`
-    is called from TVM, or when `std::regex` is called from pytorch,
-    depending on which library was loaded first.  This restriction can
-    be removed when a version of pytorch compiled using
-    `-DUSE_CXX11_ABI=1` is available from PyPI.
-
-    [0] https://github.com/pytorch/pytorch/issues/51039
-
-    Parameters
-    ----------
-    regex_pattern: str
-
-        The regular expression
-
-    match_against: str
-
-        The string against which to match the regular expression
-
-    Returns
-    -------
-    match_result: bool
-
-        True if `match_against` matches the pattern defined by
-        `regex_pattern`, and False otherwise.
-    """
-    match = re.match(regex_pattern, match_against)
-    return match is not None
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 766bd28875c5d..3eb64fec84fe8 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -35,6 +35,7 @@
 #include
 
 #include "../runtime/object_internal.h"
+#include "../runtime/regex.h"
 
 namespace tvm {
 namespace transform {
@@ -538,17 +539,11 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
                              .str();
 
   auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
-    const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
-    CHECK(regex_match_func)
-        << "RuntimeError: "
-        << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
-        << "This can occur if the TVM Python library has not yet been imported.";
-
     IRModule subset;
 
     for (const auto& [gvar, func] : mod->functions) {
       std::string name = gvar->name_hint;
-      if ((*regex_match_func)(func_name_regex, name)) {
+      if (tvm::runtime::regex_match(name, func_name_regex)) {
         subset->Add(gvar, func);
       }
     }
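For intuition, the `pass_func` above runs the wrapped pass only on the functions whose names match `func_name_regex`. A rough Python sketch of that selection step, with a plain dict standing in for the `IRModule` (the names here are illustrative, not TVM API):

    import re

    def apply_pass_to_function(transform, functions, func_name_regex):
        # Select the subset of functions whose names match the regex ...
        subset = {name: func for name, func in functions.items()
                  if re.match(func_name_regex, name)}
        # ... transform only that subset, leaving the rest untouched.
        transformed = {name: transform(func) for name, func in subset.items()}
        return {**functions, **transformed}
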
valid_pat("(\\d*)(,(\\d*))*"); + bool checked = tvm::runtime::regex_match(shapes[0], valid_pat); for (size_t i = 1; i < shapes.size() - 1; i++) { - checked &= std::regex_match(shapes[i], valid_pat); + checked &= tvm::runtime::regex_match(shapes[i], valid_pat); } - checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*")); + checked &= tvm::runtime::regex_match(shapes[shapes.size() - 1], "\\d*"); if (!checked) { LOG(FATAL) << "Invalid input args for query dnnl optimal layout."; } @@ -194,8 +194,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker std::string weight_shape, std::string out_shape, std::string paddings, std::string strides, std::string dilates, std::string G, std::string dtype) { - check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); - check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true); + check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true); + check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)OI(D?)(H?)W"), true); check_shapes({weight_shape, out_shape, paddings, strides, dilates, G}); dnnl::engine eng(dnnl::engine::kind::cpu, 0); @@ -278,8 +278,8 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string paddings, std::string output_paddings, std::string strides, std::string dilates, std::string G, std::string dtype) { - check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); - check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true); + check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true); + check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)((IO)|(OI))(D?)(H?)W"), true); check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G}); dnnl::engine eng(dnnl::engine::kind::cpu, 0); diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc index 527b53acf4980..d395de6694ff0 100644 --- a/src/relay/backend/contrib/mrvl/codegen.cc +++ b/src/relay/backend/contrib/mrvl/codegen.cc @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 0b674f08f2fd4..f29628d56b800 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -26,10 +26,10 @@ #include #include -#include #include #include +#include "../../../runtime/regex.h" #include "../json/json_node.h" #include "../json/json_runtime.h" @@ -194,45 +194,45 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr; // Define RegExp. 
diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc
index 527b53acf4980..d395de6694ff0 100644
--- a/src/relay/backend/contrib/mrvl/codegen.cc
+++ b/src/relay/backend/contrib/mrvl/codegen.cc
@@ -31,7 +31,6 @@
 #include
 #include
 #include
-#include <regex>
 #include
 #include
 #include
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 0b674f08f2fd4..f29628d56b800 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -26,10 +26,10 @@
 #include
 #include
 
-#include <regex>
 #include
 #include
 
+#include "../../../runtime/regex.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
@@ -194,45 +194,45 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;
 
     // Define RegExp.
-    std::regex bias_add_pat(".*_bias.*");
-    std::regex relu_pat(".*_relu.*");
-    std::regex tanh_pat(".*_tanh.*");
-    std::regex sigmoid_pat(".*_sigmoid.*");
-    std::regex clip_pat(".*_clip.*");
-    std::regex gelu_pat(".*_gelu.*");
-    std::regex swish_pat(".*_swish.*");
-    std::regex sum_pat(".*_sum.*");
-    std::regex mish_pat(".*_mish.*");
+    std::string bias_add_pat(".*_bias.*");
+    std::string relu_pat(".*_relu.*");
+    std::string tanh_pat(".*_tanh.*");
+    std::string sigmoid_pat(".*_sigmoid.*");
+    std::string clip_pat(".*_clip.*");
+    std::string gelu_pat(".*_gelu.*");
+    std::string swish_pat(".*_swish.*");
+    std::string sum_pat(".*_sum.*");
+    std::string mish_pat(".*_mish.*");
 
     // parsing of name to extract attributes
     auto op_name = nodes_[nid].GetOpName();
 
     // Parsing post-ops.
     dnnl::post_ops ops;
-    if (std::regex_match(op_name, sum_pat)) {
+    if (tvm::runtime::regex_match(op_name, sum_pat)) {
       ops.append_sum(1.f);
     }
-    if (std::regex_match(op_name, relu_pat)) {
+    if (tvm::runtime::regex_match(op_name, relu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, tanh_pat)) {
+    if (tvm::runtime::regex_match(op_name, tanh_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, clip_pat)) {
+    if (tvm::runtime::regex_match(op_name, clip_pat)) {
       float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
       float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
     }
-    if (std::regex_match(op_name, sigmoid_pat)) {
+    if (tvm::runtime::regex_match(op_name, sigmoid_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, swish_pat)) {
+    if (tvm::runtime::regex_match(op_name, swish_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
     }
-    if (std::regex_match(op_name, gelu_pat)) {
+    if (tvm::runtime::regex_match(op_name, gelu_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, mish_pat)) {
+    if (tvm::runtime::regex_match(op_name, mish_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
     }
     if (ops.len() != 0) {
@@ -240,7 +240,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     }
 
     // Parsing bias_add.
-    *bias_tr = std::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
+    *bias_tr =
+        tvm::runtime::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
 
     return attr;
   }
@@ -253,12 +254,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
     tensor_registry_ = TensorRegistry(engine_, io_eid_set);
 
-    std::regex conv_pat(".*conv[1-3]d.*");
-    std::regex deconv_pat(".*deconv[1-3]d.*");
-    std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
-    std::regex dense_pat(".*dense.*");
-    std::regex max_pool_pat(".*max_pool[1-3]d");
-    std::regex avg_pool_pat(".*avg_pool[1-3]d");
+    std::string conv_pat(".*conv[1-3]d.*");
+    std::string deconv_pat(".*deconv[1-3]d.*");
+    std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
+    std::string dense_pat(".*dense.*");
+    std::string max_pool_pat(".*max_pool[1-3]d");
+    std::string avg_pool_pat(".*avg_pool[1-3]d");
 
     // Build subgraph engine.
     for (size_t nid = 0; nid < nodes_.size(); ++nid) {
@@ -266,18 +267,18 @@
       const auto& node = nodes_[nid];
       if (node.GetOpType() == "kernel") {
         ICHECK_EQ(node.GetOpType(), "kernel");
         auto op_name = node.GetOpName();
-        if (std::regex_match(op_name, deconv_pat) ||
-            std::regex_match(op_name, conv_transpose_pat)) {
+        if (tvm::runtime::regex_match(op_name, deconv_pat) ||
+            tvm::runtime::regex_match(op_name, conv_transpose_pat)) {
           Deconvolution(nid);
-        } else if (std::regex_match(op_name, conv_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, conv_pat)) {
           Convolution(nid);
-        } else if (std::regex_match(op_name, dense_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, dense_pat)) {
          Dense(nid);
         } else if ("nn.batch_norm" == op_name) {
           BatchNorm(nid);
-        } else if (std::regex_match(op_name, max_pool_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, max_pool_pat)) {
           Pooling(nid, dnnl::algorithm::pooling_max);
-        } else if (std::regex_match(op_name, avg_pool_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, avg_pool_pat)) {
           Pooling(nid, dnnl::algorithm::pooling_avg);
         } else if (elt_name2algo.count(op_name)) {
           Eltwise(nid);
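The patterns above parse fused-op names by substring, so a single name can enable several post-ops at once. A small sketch of the same dispatch in Python (the op name is hypothetical, chosen for the example):

    import re

    op_name = "dnnl.conv2d_bias_relu"  # hypothetical fused op name
    patterns = {"sum": r".*_sum.*", "relu": r".*_relu.*",
                "tanh": r".*_tanh.*", "sigmoid": r".*_sigmoid.*"}
    post_ops = [name for name, pat in patterns.items() if re.match(pat, op_name)]
    assert post_ops == ["relu"]
    assert re.match(r".*_bias.*", op_name)  # enables the bias_add parsing
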
diff --git a/src/runtime/regex.cc b/src/runtime/regex.cc
new file mode 100644
index 0000000000000..ef6c068edfe0b
--- /dev/null
+++ b/src/runtime/regex.cc
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/regex.cc
+ * \brief Exposes calls to python's `re` library.
+ */
+
+#include "./regex.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace runtime {
+
+bool regex_match(const std::string& match_against, const std::string& regex_pattern) {
+  const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.runtime.regex_match");
+  CHECK(regex_match_func) << "RuntimeError: "
+                          << "The PackedFunc 'tvm.runtime.regex_match' has not been registered. "
+                          << "This can occur if the TVM Python library has not yet been imported.";
+  return (*regex_match_func)(regex_pattern, match_against);
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/regex.h b/src/runtime/regex.h
new file mode 100644
index 0000000000000..a072700c911af
--- /dev/null
+++ b/src/runtime/regex.h
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file regex.h
+ * \brief Exposes calls to python's `re` library.
+ */
+#ifndef TVM_RUNTIME_REGEX_H_
+#define TVM_RUNTIME_REGEX_H_
+
+#include <string>
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief Check if a pattern matches a regular expression
+ *
+ * This function should be used instead of `std::regex` within C++
+ * call sites, to avoid ABI incompatibilities with pytorch.
+ *
+ * Currently, the pytorch wheels available through pip install use
+ * the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0].  If TVM were to
+ * use the pre-C++11 ABI, this would cause breakages with
+ * dynamically-linked LLVM environments.
+ *
+ * Use of the `<regex>` header in TVM should be avoided, as its
+ * implementation is not supported by gcc's dual ABI.  This ABI
+ * incompatibility results in runtime errors either when `std::regex`
+ * is called from TVM, or when `std::regex` is called from pytorch,
+ * depending on which library was loaded first.  This restriction can
+ * be removed when a version of pytorch compiled using
+ * `-DUSE_CXX11_ABI=1` is available from PyPI.
+ *
+ * [0] https://github.com/pytorch/pytorch/issues/51039
+ *
+ * \param match_against The string against which to match the regular expression
+ *
+ * \param regex_pattern The regular expression
+ *
+ * \returns True if `match_against` matches the pattern defined by
+ *     `regex_pattern`, and False otherwise.
+ */
+
+bool regex_match(const std::string& match_against, const std::string& regex_pattern);
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_REGEX_H_
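Because the C++ helper resolves through the PackedFunc registry, the same entry point is visible from Python, which gives a quick way to sanity-check the plumbing end to end (this assumes a TVM build with the Python bindings importable, so that `tvm.runtime.regex_match` is registered):

    import tvm  # importing tvm pulls in tvm.runtime, which registers the func

    regex_match = tvm.get_global_func("tvm.runtime.regex_match")
    assert regex_match(".*_relu.*", "dnnl.conv2d_relu")  # pattern first,
    assert not regex_match("NC(D?)(H?)W", "NHWC")        # then the string
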
diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh
index 38c30b2ed6c69..b948c91c1edd9 100755
--- a/tests/lint/cpplint.sh
+++ b/tests/lint/cpplint.sh
@@ -28,3 +28,13 @@ python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp \
     "src/runtime/hexagon/rpc/hexagon_rpc_skel.c" \
     "src/runtime/hexagon/rpc/hexagon_rpc_stub.c" \
     "src/relay/backend/contrib/libtorch/libtorch_codegen.cc"
+
+
+if find src -name "*.cc" -exec grep -Hn '^#include <regex>$' {} +; then
+    echo "The <regex> header file may not be used in TVM," 1>&2
+    echo "because it causes ABI incompatibility with most pytorch installations." 1>&2
+    echo "Pytorch packages on PyPI currently set '-DUSE_CXX11_ABI=0'," 1>&2
+    echo "which causes an ABI incompatibility when calling C++ functions." 1>&2
+    echo "See https://github.com/pytorch/pytorch/issues/51039 for more details." 1>&2
+    exit 1
+fi