Add regex loading from tokenizer.json and code refinement
wenbingl committed Dec 19, 2024
1 parent 378bbef commit 5eda73c
Showing 10 changed files with 128 additions and 76 deletions.
4 changes: 3 additions & 1 deletion onnxruntime_extensions/pp_api.py
@@ -49,10 +49,12 @@ def __init__(self, tokenizer_dir):
         self.tokenizer = create_tokenizer(tokenizer_dir)
 
     def tokenize(self, text):
+        if isinstance(text, (list, tuple)):
+            return batch_tokenize(self.tokenizer, text)
         return batch_tokenize(self.tokenizer, [text])[0]
 
     def detokenize(self, tokens):
-        return batch_detokenize(self.tokenizer, [tokens])[0]
+        return batch_detokenize(self.tokenizer, [tokens])
 
     def __del__(self):
         if delete_object and self.tokenizer:
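
With this change, Tokenizer.tokenize accepts either a single string or a list/tuple of strings, and detokenize now returns the decoded batch rather than only its first element. A minimal usage sketch; "path/to/tokenizer_dir" is a placeholder, not a path from this commit:

    # Hedged sketch of the updated pp_api surface; the directory stands in for
    # any folder holding files that create_tokenizer accepts (e.g. tokenizer.json).
    from onnxruntime_extensions import pp_api

    tokenizer = pp_api.Tokenizer("path/to/tokenizer_dir")
    ids = tokenizer.tokenize("Hello world")         # str -> one id list
    batch = tokenizer.tokenize(["Hello", "world"])  # list/tuple -> one id list per input
    text = tokenizer.detokenize(ids)                # decoded batch (a list of strings)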
14 changes: 4 additions & 10 deletions operators/tokenizer/bpe_kernels.cc
@@ -262,7 +262,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
   // Parse input
   auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
-  bpe::TokenWithRegularExp regcmp;
+  bpe::PreTokenizerWithRegEx reg_splitter;
 
   for (auto& seg_id : special_token_split_res) {
     if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -274,7 +274,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
     // Note: keep ptr to make sure the string_view is valid in the following process
     std::u32string str(seg_id.first);
-    regcmp.Set(str.c_str());
+    reg_splitter.Set(str.c_str());
 
     size_t offset = 0;
     OffsetMappingType offset_mapping;
@@ -287,14 +287,8 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
     }
 
     while (static_cast<int64_t>(res.size()) < max_length) {
-      std::string regex_expr = "";
-      if (ModelName() == kModel_Llama){
-        regex_expr = regcmp.LLAMA_REGEX_PATTERN;
-      } else {
-        // default to GPT2 regex
-        regex_expr = regcmp.GPT2_REGEX_PATTERN;
-      }
-      auto [b, tok] = regcmp.GetNextToken(regex_expr);
+      std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
+      auto [b, tok] = reg_splitter.GetNextToken(regex_expr);
 
       if (!b) break;
 
68 changes: 68 additions & 0 deletions operators/tokenizer/bpe_tokenizer_model.hpp
@@ -44,6 +44,58 @@ class BpeModel {
     }
   }
 
+  OrtxStatus LoadPreTokenizer(const json& bpe_model) {
+    auto node_pre_tokenizer = bpe_model.find("pre_tokenizer");
+    if (node_pre_tokenizer == bpe_model.end() || node_pre_tokenizer->is_null()) {
+      return {};
+    }
+
+    auto iter_type = node_pre_tokenizer->find("type");
+    if (iter_type == node_pre_tokenizer->end() || iter_type->is_null()) {
+      return {};
+    }
+
+    if (iter_type->get<std::string>() != "Sequence") {
+      return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+    }
+
+    auto iter_node_list = node_pre_tokenizer->find("pretokenizers");
+    if (iter_node_list == node_pre_tokenizer->end() || iter_node_list->is_null()) {
+      return {};
+    }
+
+    for (const auto& node : *iter_node_list) {
+      auto iter_type = node.find("type");
+      if (iter_type == node.end() || iter_type->is_null()) {
+        continue;  // skip pre-tokenizer entries that do not declare a type
+      }
+
+      auto pre_type = iter_type->get<std::string>();
+      if (pre_type == "Split") {
+        auto iter_pattern = node.find("pattern");
+        if (iter_pattern == node.end() || iter_pattern->is_null()) {
+          continue;
+        }
+
+        auto regex_str = iter_pattern->find("Regex");
+        if (regex_str == iter_pattern->end() || regex_str->is_null()) {
+          continue;
+        }
+
+        pre_tokenizer_regex_ = regex_str->get<std::string>();
+      } else if (pre_type == "ByteLevel") {
+        ;  // need to add more flag support here in the future
+      } else {
+        return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+      }
+    }
+
+    return {};
+  }
+
   OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
                   const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
@@ -121,6 +173,8 @@
     }
   }
 
   OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
+    ORTX_RETURN_IF_ERROR(LoadPreTokenizer(bpe_model));
+
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -358,6 +412,19 @@
 
   const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
 
+  std::string GetPreTokenizerRegex(const std::string& model_name) const {
+    if (!pre_tokenizer_regex_.empty()) {
+      return pre_tokenizer_regex_;
+    }
+
+    if (model_name == "Llama") {
+      return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
+    }
+
+    // by default, use the GPT2 pretokenizer regex
+    return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
+  }
+
  private:
   struct BpeNode {
     uint32_t id;
@@ -379,6 +446,7 @@
   uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
   bpe::SpecialTokenMap special_tokens_;
   TrieTree<char32_t> added_tokens_;
+  std::string pre_tokenizer_regex_;
 };
 
 }  // namespace ort_extensions
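
LoadPreTokenizer reads the optional pre_tokenizer node from tokenizer.json and records the first Split pattern's Regex, which GetPreTokenizerRegex then prefers over the built-in fallbacks. A sketch of the JSON shape it accepts, with the traversal mirrored in Python; the Regex value and the "behavior"/"add_prefix_space" fields are illustrative, not taken from any particular model:

    # Python mirror of LoadPreTokenizer's traversal over tokenizer.json.
    import json

    tokenizer_json = json.loads("""
    {
      "pre_tokenizer": {
        "type": "Sequence",
        "pretokenizers": [
          {"type": "Split", "pattern": {"Regex": "\\\\p{N}{1,3}|\\\\p{L}+"}, "behavior": "Isolated"},
          {"type": "ByteLevel", "add_prefix_space": false}
        ]
      },
      "model": {"type": "BPE"}
    }
    """)

    pre_tokenizer_regex = ""
    node = tokenizer_json.get("pre_tokenizer")
    if node and node.get("type") == "Sequence":
        for sub in node.get("pretokenizers", []):
            if sub.get("type") == "Split":
                pre_tokenizer_regex = sub.get("pattern", {}).get("Regex", "")
    print(pre_tokenizer_regex)  # -> \p{N}{1,3}|\p{L}+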
10 changes: 5 additions & 5 deletions operators/tokenizer/bpe_utils.hpp
@@ -97,8 +97,12 @@ class SpecialTokenMap {
   std::unordered_map<ustring, int> token_map_;
 };
 
-class TokenWithRegularExp {
+class PreTokenizerWithRegEx {
  public:
+  static constexpr const char* GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
   void Set(std::u32string_view val) {
     m_text = val;
   }
@@ -115,10 +119,6 @@
     return {false, {}};
   }
 
-  const std::string GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-
  public:
 
   // Although we have RegexMatchGeneral which performs regex matching given any general regex string
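
The built-in patterns are promoted from instance members to public constants so that BpeModel::GetPreTokenizerRegex can fall back to them when tokenizer.json supplies no Split regex. One way to see what LLAMA_REGEX_PATTERN does is to replay it with Python's third-party regex module, which understands the same \p{L}/\p{N} classes; the expected split below matches the RegexTest expectations in test_tokenizer.cc:

    # Requires `pip install regex`; the stdlib `re` module lacks \p{...} classes.
    import regex

    LLAMA_REGEX_PATTERN = (
        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}"
        r"| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
    )

    tokens = regex.findall(LLAMA_REGEX_PATTERN, "CAN'T \r\n 2413m")
    # Digits split into runs of at most three, contractions split off, and
    # trailing whitespace grouped with the newline run.
    assert tokens == ["CAN", "'T", " \r\n", " ", "241", "3", "m"]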
1 change: 1 addition & 0 deletions operators/tokenizer/tokenizer_jsconfig.hpp
@@ -26,6 +26,7 @@ constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
     {"GPT2Tokenizer", TokenType::kBPE},
     {"Qwen2Tokenizer", TokenType::kBPE},
     {"BaichuanTokenizer", TokenType::kBPE},
+    {"GPTNeoXTokenizer", TokenType::kBPE},
 
     {"", TokenType::kUnigram},
     {"T5Tokenizer", TokenType::kUnigram},
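
Mapping GPTNeoXTokenizer to TokenType::kBPE is what lets the new test_OLMa_tokenizer case below load amd/AMD-OLMo-1B-SFT-DPO, whose config is expected to declare that class. A hedged way to confirm from Python, assuming huggingface_hub is installed and network access is available:

    # Prints the tokenizer class the checkpoint declares; "GPTNeoXTokenizer"
    # is the expected value, not something asserted by this commit.
    import json
    from huggingface_hub import hf_hub_download

    path = hf_hub_download("amd/AMD-OLMo-1B-SFT-DPO", "tokenizer_config.json")
    with open(path) as f:
        print(json.load(f).get("tokenizer_class"))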
2 changes: 1 addition & 1 deletion pyop/py_c_api.cc
@@ -121,7 +121,7 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
         OrtxTokenizer* tokenizer = nullptr;
         auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
         if (err != kOrtxOK) {
-          throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
+          throw std::runtime_error(std::string("Failed to create tokenizer\n") + OrtxGetLastErrorMessage());
         }
         return reinterpret_cast<std::uintptr_t>(tokenizer);
       },
77 changes: 28 additions & 49 deletions shared/api/tokenizer_impl.cc
@@ -11,33 +11,15 @@
 
 namespace ort_extensions {
 
-std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
-    "PreTrainedTokenizerFast",
-    "CLIPTokenizer",
-    "WhisperTokenizer",
-    "GemmaTokenizer",
-    "LlamaTokenizer",
-    "Phi3Tokenizer",
-    "CodeLlamaTokenizer",
-    "CodeGenTokenizer",
-    "GPT2Tokenizer",
-    "Qwen2Tokenizer",
-    "BaichuanTokenizer"
-};
-
-std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
-    "XLMRobertaTokenizer",
-    "T5Tokenizer",
-    "ChatGLMTokenizer"
-};
-
 TokenizerImpl::TokenizerImpl()
     : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
 TokenizerImpl::~TokenizerImpl() {};
 
 OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
-  if (tok_config_->tokenizer_class_.empty() ||
-      supported_ugm_models_.count(tok_config_->tokenizer_class_)) {
-
+  auto type = TokenJsonConfig::GetTokenType(tok_config_->tokenizer_class_);
+  if (type == TokenType::kUnigram) {
     auto tokenizer = std::make_unique<SpmUgmTokenizer>();
     auto status = tokenizer->Load(*tok_config_);
     if (!status.IsOk()) {
@@ -53,42 +35,39 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
       tokenizer_ = std::move(tokenizer);
       detokenizer_ = std::move(detok);
     }
 
     return status;
-  }
-
-  if (!supported_bpe_models_.count(tok_config_->tokenizer_class_)) {
-    return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
-  }
-
-  auto tokenizer = std::make_unique<JsonFastTokenizer>();
-  auto fx_load = &JsonFastTokenizer::Load;
-  if (blob == nullptr) {
-    auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
-    // vocab file is checked in TokenJsonConfig::Load
-    if (vocab_file_path.extension() != ".json") {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+  } else if (type == TokenType::kBPE) {
+    auto tokenizer = std::make_unique<JsonFastTokenizer>();
+    auto fx_load = &JsonFastTokenizer::Load;
+    if (blob == nullptr) {
+      auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
+      // vocab file is checked in TokenJsonConfig::Load
+      if (vocab_file_path.extension() != ".json") {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
+    } else {
+      if (blob->raw_model_blob_len > 0) {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
     }
-  } else {
-    if (blob->raw_model_blob_len > 0) {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+
+    auto status = (tokenizer.get()->*fx_load)(*tok_config_);
+    if (!status.IsOk()) {
+      return status;
     }
-  }
 
-  auto status = (tokenizer.get()->*fx_load)(*tok_config_);
-  if (!status.IsOk()) {
-    return status;
-  }
+    auto detok = std::make_unique<BpeStreamingDecoder>();
+    status = detok->Load(tok_config_, *tokenizer);
 
-  auto detok = std::make_unique<BpeStreamingDecoder>();
-  status = detok->Load(tok_config_, *tokenizer);
+    if (status.IsOk()) {
+      tokenizer_ = std::move(tokenizer);
+      detokenizer_ = std::move(detok);
+    }
 
-  if (status.IsOk()) {
-    tokenizer_ = std::move(tokenizer);
-    detokenizer_ = std::move(detok);
+    return status;
   }
 
-  return status;
+  return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
 }
 
 OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
3 changes: 0 additions & 3 deletions shared/api/tokenizer_impl.h
@@ -15,9 +15,6 @@ namespace ort_extensions {
 
 class TokenizerImpl : public OrtxObjectImpl {
  public:
-  static std::set<std::string> supported_bpe_models_;
-  static std::set<std::string> supported_ugm_models_;
-
   TokenizerImpl();
   virtual ~TokenizerImpl();
 
8 changes: 4 additions & 4 deletions test/pp_api_test/test_tokenizer.cc
@@ -67,7 +67,7 @@ TEST(CApiTest, StreamApiTest) {
 
 TEST(OrtxTokenizerTest, RegexTest) {
   std::u32string str = U"CAN'T \r\n 2413m";
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::u32string> res;
   std::vector<std::u32string> out_tokens = {U"CAN", U"'T", U" \r\n", U" ", U"241", U"3", U"m"};
@@ -91,7 +91,7 @@ TEST(OrtxTokenizerTest, RegexMatchSTDTest) {
   std::vector<std::u32string> input_strings = {U"not its, or IT'S, but it's",
                                                U" ",
                                                U"AbCd"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"'s"},
@@ -118,7 +118,7 @@ TEST(OrtxTokenizerTest, WrapStandaloneCategoriesTest) {
       "\\p{rn}\\p{L}\\p{N}\\p{L}",
       "\\p{Z}*[\\p{rn}]+",
       "\\p{Z}+"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::string> res;
   std::vector<std::string> out_regex = {"[^\\p{rn}\\p{L}\\p{N}]?[\\p{L}]+",
@@ -152,7 +152,7 @@ TEST(OrtxTokenizerTest, RegexMatchGeneralTest) {
       U"241356m",
      U"Ich liebe München <3 \r\n ",
       U"生活的真谛是"};
-  auto regcmp = std::make_unique<ort_extensions::bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<ort_extensions::bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"CAN", U"'T", U"", U""},
17 changes: 14 additions & 3 deletions test/test_pp_api.py
@@ -8,10 +8,12 @@
 
 
 is_pp_api_available = False
+hf_token_id = None
 try:
-    from transformers import AutoImageProcessor
+    from transformers import AutoImageProcessor, AutoTokenizer
     from onnxruntime_extensions import pp_api
     is_pp_api_available = True
+    hf_token_id = os.environ.get("HF_TOKEN", None)
 except ImportError:
     pass
 
@@ -46,7 +48,6 @@ def setUpClass(cls):
         else:
             cls.temp_dir = tempfile.mkdtemp()
             print(f"Created temp dir: {cls.temp_dir}")
-        cls.token_id = os.environ.get("HF_TOKEN", None)
 
     def test_CLIP_image_processing(self):
         model_id = "openai/clip-vit-large-patch14"
@@ -76,6 +77,7 @@ def test_CLIP_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/CLIP_a_{i}.png")
 
+    @unittest.skipIf(hf_token_id is None, "HF_TOKEN is not available")
     def test_llama3_2_image_processing(self):
         model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
@@ -91,7 +93,7 @@ def test_llama3_2_image_processing(self):
                       "test/data/processor/exceltable.png"]
         (image, image2, image3) = [Image.open(f) for f in image_list]
 
-        processor = AutoImageProcessor.from_pretrained(model_id, token=TestPPAPI.token_id)
+        processor = AutoImageProcessor.from_pretrained(model_id, token=hf_token_id)
         inputs = processor.preprocess(
             [image, image2, image3], return_tensors="np")
         print({k: v.shape if k == "pixel_values" else v for k, v in inputs.items()})
@@ -114,6 +116,15 @@ def test_llama3_2_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png")
 
+    def test_OLMa_tokenizer(self):
+        test_sentence = ["I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61" + " |||IP_ADDRESS|||"]
+        model_id = "amd/AMD-OLMo-1B-SFT-DPO"
+        hf_enc = AutoTokenizer.from_pretrained(model_id)
+        inputs = hf_enc(test_sentence)["input_ids"]
+        tokenizer = pp_api.Tokenizer(model_id)
+        ortx_inputs = tokenizer.tokenize(test_sentence)
+        # self.assertEqual(inputs, ortx_inputs)
+        np.testing.assert_array_equal(ortx_inputs, inputs)
 
 
 if __name__ == '__main__':
    unittest.main()