[Lint] Add check to prevent usage of #include <regex> (#16412)

Currently, the pytorch wheels available through `pip install` use the
pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0].  If TVM were to use
the pre-C++11 ABI, this would cause breakages with dynamically-linked
LLVM environments.

This commit adds a lint check to search for use of `#include <regex>`
in any C++ file.  Use of this header should be avoided, as its
implementation is not supported by gcc's dual ABI.  This ABI
incompatibility results in runtime errors either when `std::regex` is
called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first.

This restriction can be removed when a version of pytorch compiled
using `-DUSE_CXX11_ABI=1` is available from PyPI.

[0] pytorch/pytorch#51039
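The lint script itself is among the files truncated from this page. A minimal sketch of the check being described — the file name, search root, and message wording here are hypothetical, not the commit's actual script:

```python
#!/usr/bin/env python3
"""Sketch of a lint that flags `#include <regex>` in C++ sources."""
import pathlib
import re
import sys

FORBIDDEN = re.compile(r"^\s*#\s*include\s*<regex>")
CXX_SUFFIXES = {".h", ".hpp", ".cc", ".cpp"}


def main() -> int:
    errors = []
    for path in pathlib.Path("src").rglob("*"):
        if path.suffix not in CXX_SUFFIXES:
            continue
        for lineno, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
            if FORBIDDEN.match(line):
                errors.append(
                    f"{path}:{lineno}: use tvm::runtime::regex_match "
                    "instead of #include <regex>"
                )
    print("\n".join(errors))
    return 1 if errors else 0


if __name__ == "__main__":
    sys.exit(main())
```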
Lunderberg authored Mar 10, 2024
1 parent 9b3621b commit 7ab970d
Showing 12 changed files with 230 additions and 94 deletions.
2 changes: 2 additions & 0 deletions python/tvm/runtime/__init__.py
@@ -41,3 +41,5 @@

from . import executor
from . import disco

from .support import _regex_match
69 changes: 69 additions & 0 deletions python/tvm/runtime/support.py
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Runtime support infra of TVM."""

import re

import tvm._ffi


@tvm._ffi.register_func("tvm.runtime.regex_match")
def _regex_match(regex_pattern: str, match_against: str) -> bool:
"""Check if a pattern matches a regular expression
This function should be used instead of `std::regex` within C++
call sites, to avoid ABI incompatibilities with pytorch.
Currently, the pytorch wheels available through pip install use
the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
user the pre-C++11 ABI, this would cause breakages with
dynamically-linked LLVM environments.
Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex`
is called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first. This restriction can
be removed when a version of pytorch compiled using
`-DUSE_CXX11_ABI=1` is available from PyPI.
This is exposed as part of `libtvm_runtime.so` as it is used by
the DNNL runtime.
[0] https://github.com/pytorch/pytorch/issues/51039
Parameters
----------
regex_pattern: str
The regular expression
match_against: str
The string against which to match the regular expression
Returns
-------
match_result: bool
True if `match_against` matches the pattern defined by
`regex_pattern`, and False otherwise.
"""
match = re.match(regex_pattern, match_against)
return match is not None
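Because `_regex_match` is registered as a global PackedFunc, it can be exercised from Python exactly as the C++ call sites reach it. A small sanity check, assuming a TVM build with the Python library imported:

```python
import tvm

# Look up the PackedFunc the same way the C++ call sites do.
regex_match = tvm.get_global_func("tvm.runtime.regex_match")

# Arguments are (pattern, string), matching the Python signature above.
assert regex_match(r".*_relu.*", "dnnl.conv2d_bias_relu")
assert not regex_match(r"NC(D?)(H?)W", "OIHW")

# Note: `re.match` anchors at the start of the string but, unlike
# `std::regex_match`, does not require the whole string to match.
assert regex_match(r"NC(D?)(H?)W", "NCHW")
```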
44 changes: 0 additions & 44 deletions python/tvm/support.py
@@ -19,7 +19,6 @@
import textwrap
import ctypes
import os
import re
import sys

import tvm
@@ -88,46 +87,3 @@ def add_function(self, name, func):

    def __setitem__(self, key, value):
        self.add_function(key, value)


@tvm._ffi.register_func("tvm.support.regex_match")
def _regex_match(regex_pattern: str, match_against: str) -> bool:
"""Check if a pattern matches a regular expression
This function should be used instead of `std::regex` within C++
call sites, to avoid ABI incompatibilities with pytorch.
Currently, the pytorch wheels available through pip install use
the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
user the pre-C++11 ABI, this would cause breakages with
dynamically-linked LLVM environments.
Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex`
is called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first. This restriction can
be removed when a version of pytorch compiled using
`-DUSE_CXX11_ABI=1` is available from PyPI.
[0] https://github.com/pytorch/pytorch/issues/51039
Parameters
----------
regex_pattern: str
The regular expression
match_against: str
The string against which to match the regular expression
Returns
-------
match_result: bool
True if `match_against` matches the pattern defined by
`regex_pattern`, and False otherwise.
"""
match = re.match(regex_pattern, match_against)
return match is not None
9 changes: 2 additions & 7 deletions src/ir/transform.cc
@@ -35,6 +35,7 @@
#include <unordered_set>

#include "../runtime/object_internal.h"
#include "../runtime/regex.h"

namespace tvm {
namespace transform {
@@ -538,17 +539,11 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
.str();

auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
CHECK(regex_match_func)
<< "RuntimeError: "
<< "The PackedFunc 'tvm.support.regex_match' has not been registered. "
<< "This can occur if the TVM Python library has not yet been imported.";

IRModule subset;

for (const auto& [gvar, func] : mod->functions) {
std::string name = gvar->name_hint;
if ((*regex_match_func)(func_name_regex, name)) {
if (tvm::runtime::regex_match(name, func_name_regex)) {
subset->Add(gvar, func);
}
}
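With this change, the registry lookup and the accompanying CHECK previously inlined here move behind `tvm::runtime::regex_match`, declared in the new `src/runtime/regex.h` (its diff is truncated from this page). In terms of the Python shim above, the pass's selection step behaves like the following sketch (the helper name is illustrative):

```python
import re


def select_functions(functions, func_name_regex):
    """Keep only the functions whose name matches the regex,
    anchored at the start of the name per `re.match`."""
    return {name: func
            for name, func in functions.items()
            if re.match(func_name_regex, name)}


mod = {"main": None, "fused_add": None, "fused_multiply": None}
assert set(select_functions(mod, "fused_.*")) == {"fused_add", "fused_multiply"}
```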
2 changes: 1 addition & 1 deletion src/relax/transform/update_param_struct_info.cc
@@ -27,10 +27,10 @@
#include <tvm/relax/transform.h>

#include <optional>
#include <regex>
#include <unordered_map>
#include <vector>

#include "../../runtime/regex.h"
#include "utils.h"

namespace tvm {
1 change: 0 additions & 1 deletion src/relay/backend/contrib/dnnl/codegen.cc
@@ -31,7 +31,6 @@

#include <fstream>
#include <numeric>
#include <regex>
#include <sstream>

#include "../../utils.h"
18 changes: 9 additions & 9 deletions src/relay/backend/contrib/dnnl/query_layout.cc
@@ -31,10 +31,10 @@

#include <fstream>
#include <numeric>
#include <regex>
#include <sstream>

#include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
#include "../../../../runtime/regex.h"
#include "../../utils.h"
#include "dnnl.hpp"
namespace tvm {
@@ -173,12 +173,12 @@ dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
}

void check_shapes(const std::vector<std::string> shapes) {
std::regex valid_pat("(\\d*)(,(\\d*))*");
bool checked = std::regex_match(shapes[0], valid_pat);
std::string valid_pat("(\\d*)(,(\\d*))*");
bool checked = tvm::runtime::regex_match(shapes[0], valid_pat);
for (size_t i = 1; i < shapes.size() - 1; i++) {
checked &= std::regex_match(shapes[i], valid_pat);
checked &= tvm::runtime::regex_match(shapes[i], valid_pat);
}
checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
checked &= tvm::runtime::regex_match(shapes[shapes.size() - 1], "\\d*");
if (!checked) {
LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
}
@@ -194,8 +194,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
std::string weight_shape, std::string out_shape,
std::string paddings, std::string strides,
std::string dilates, std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true);
check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)OI(D?)(H?)W"), true);
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
@@ -278,8 +278,8 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
std::string paddings, std::string output_paddings,
std::string strides, std::string dilates,
std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true);
check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)((IO)|(OI))(D?)(H?)W"), true);
check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
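Because matching is now routed through Python's `re` module, the layout and shape patterns used above can be checked against it directly. A few illustrative cases:

```python
import re

# Shape strings are comma-separated runs of digits.
assert re.match(r"(\d*)(,(\d*))*", "1,3,224,224")

# Layout checks: the depth and height dimensions are optional.
assert re.match(r"NC(D?)(H?)W", "NCHW")
assert re.match(r"NC(D?)(H?)W", "NCDHW")
assert not re.match(r"NC(D?)(H?)W", "OIHW")
assert re.match(r"(G?)OI(D?)(H?)W", "GOIHW")
```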
1 change: 0 additions & 1 deletion src/relay/backend/contrib/mrvl/codegen.cc
@@ -31,7 +31,6 @@
#include <iostream>
#include <limits>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <utility>
63 changes: 32 additions & 31 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -26,10 +26,10 @@
#include <tvm/runtime/registry.h>

#include <cstddef>
#include <regex>
#include <string>
#include <vector>

#include "../../../runtime/regex.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"

@@ -194,53 +194,54 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

// Define RegExp.
std::regex bias_add_pat(".*_bias.*");
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");
std::regex swish_pat(".*_swish.*");
std::regex sum_pat(".*_sum.*");
std::regex mish_pat(".*_mish.*");
std::string bias_add_pat(".*_bias.*");
std::string relu_pat(".*_relu.*");
std::string tanh_pat(".*_tanh.*");
std::string sigmoid_pat(".*_sigmoid.*");
std::string clip_pat(".*_clip.*");
std::string gelu_pat(".*_gelu.*");
std::string swish_pat(".*_swish.*");
std::string sum_pat(".*_sum.*");
std::string mish_pat(".*_mish.*");

// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();

// Parsing post-ops.
dnnl::post_ops ops;
if (std::regex_match(op_name, sum_pat)) {
if (tvm::runtime::regex_match(op_name, sum_pat)) {
ops.append_sum(1.f);
}
if (std::regex_match(op_name, relu_pat)) {
if (tvm::runtime::regex_match(op_name, relu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
}
if (std::regex_match(op_name, tanh_pat)) {
if (tvm::runtime::regex_match(op_name, tanh_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
if (std::regex_match(op_name, clip_pat)) {
if (tvm::runtime::regex_match(op_name, clip_pat)) {
float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
if (std::regex_match(op_name, sigmoid_pat)) {
if (tvm::runtime::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
if (std::regex_match(op_name, swish_pat)) {
if (tvm::runtime::regex_match(op_name, swish_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
}
if (std::regex_match(op_name, gelu_pat)) {
if (tvm::runtime::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
if (std::regex_match(op_name, mish_pat)) {
if (tvm::runtime::regex_match(op_name, mish_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
}
if (ops.len() != 0) {
attr.set_post_ops(ops);
}

// Parsing bias_add.
*bias_tr = std::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
*bias_tr =
tvm::runtime::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};

return attr;
}
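The fused op name encodes which post-ops to append, so the parsing above is a series of marker matches against the name. In terms of the underlying Python shim, for example:

```python
import re

op_name = "dnnl.conv2d_bias_sum_relu"

# Each pattern simply looks for a marker anywhere in the fused name.
present = [marker
           for marker in ("_bias", "_sum", "_relu", "_tanh", "_sigmoid")
           if re.match(f".*{marker}.*", op_name)]
assert present == ["_bias", "_sum", "_relu"]
```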
@@ -253,31 +254,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
tensor_registry_ = TensorRegistry(engine_, io_eid_set);

std::regex conv_pat(".*conv[1-3]d.*");
std::regex deconv_pat(".*deconv[1-3]d.*");
std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
std::regex dense_pat(".*dense.*");
std::regex max_pool_pat(".*max_pool[1-3]d");
std::regex avg_pool_pat(".*avg_pool[1-3]d");
std::string conv_pat(".*conv[1-3]d.*");
std::string deconv_pat(".*deconv[1-3]d.*");
std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
std::string dense_pat(".*dense.*");
std::string max_pool_pat(".*max_pool[1-3]d");
std::string avg_pool_pat(".*avg_pool[1-3]d");

// Build subgraph engine.
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
ICHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if (std::regex_match(op_name, deconv_pat) ||
std::regex_match(op_name, conv_transpose_pat)) {
if (tvm::runtime::regex_match(op_name, deconv_pat) ||
tvm::runtime::regex_match(op_name, conv_transpose_pat)) {
Deconvolution(nid);
} else if (std::regex_match(op_name, conv_pat)) {
} else if (tvm::runtime::regex_match(op_name, conv_pat)) {
Convolution(nid);
} else if (std::regex_match(op_name, dense_pat)) {
} else if (tvm::runtime::regex_match(op_name, dense_pat)) {
Dense(nid);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
} else if (std::regex_match(op_name, max_pool_pat)) {
} else if (tvm::runtime::regex_match(op_name, max_pool_pat)) {
Pooling(nid, dnnl::algorithm::pooling_max);
} else if (std::regex_match(op_name, avg_pool_pat)) {
} else if (tvm::runtime::regex_match(op_name, avg_pool_pat)) {
Pooling(nid, dnnl::algorithm::pooling_avg);
} else if (elt_name2algo.count(op_name)) {
Eltwise(nid);