From 5eda73c8a62b624e328ec371856af8c777431b34 Mon Sep 17 00:00:00 2001
From: Wenbing Li
Date: Thu, 19 Dec 2024 18:40:24 +0000
Subject: [PATCH] Add regex loading from tokenizer.json and code refinement

---
 onnxruntime_extensions/pp_api.py            |  4 +-
 operators/tokenizer/bpe_kernels.cc          | 14 ++--
 operators/tokenizer/bpe_tokenizer_model.hpp | 68 ++++++++++++++++++
 operators/tokenizer/bpe_utils.hpp           | 10 +--
 operators/tokenizer/tokenizer_jsconfig.hpp  |  1 +
 pyop/py_c_api.cc                            |  2 +-
 shared/api/tokenizer_impl.cc                | 77 ++++++++------------
 shared/api/tokenizer_impl.h                 |  3 -
 test/pp_api_test/test_tokenizer.cc          |  8 +--
 test/test_pp_api.py                         | 17 ++++-
 10 files changed, 128 insertions(+), 76 deletions(-)

diff --git a/onnxruntime_extensions/pp_api.py b/onnxruntime_extensions/pp_api.py
index e827f78a4..08ec07678 100644
--- a/onnxruntime_extensions/pp_api.py
+++ b/onnxruntime_extensions/pp_api.py
@@ -49,10 +49,12 @@ def __init__(self, tokenizer_dir):
         self.tokenizer = create_tokenizer(tokenizer_dir)
 
     def tokenize(self, text):
+        if isinstance(text, (list, tuple)):
+            return batch_tokenize(self.tokenizer, text)
         return batch_tokenize(self.tokenizer, [text])[0]
 
     def detokenize(self, tokens):
-        return batch_detokenize(self.tokenizer, [tokens])[0]
+        return batch_detokenize(self.tokenizer, [tokens])
 
     def __del__(self):
         if delete_object and self.tokenizer:
diff --git a/operators/tokenizer/bpe_kernels.cc b/operators/tokenizer/bpe_kernels.cc
index 4ca160c9f..7f8d4f64b 100644
--- a/operators/tokenizer/bpe_kernels.cc
+++ b/operators/tokenizer/bpe_kernels.cc
@@ -262,7 +262,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
   // Parse input
   auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
-  bpe::TokenWithRegularExp regcmp;
+  bpe::PreTokenizerWithRegEx reg_splitter;
 
   for (auto& seg_id : special_token_split_res) {
     if (static_cast<int64_t>(res.size()) >= max_length) break;
@@ -274,7 +274,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
 
     // Note: keep ptr to make sure the string_view is valid in the following process
     std::u32string str(seg_id.first);
-    regcmp.Set(str.c_str());
+    reg_splitter.Set(str.c_str());
 
     size_t offset = 0;
     OffsetMappingType offset_mapping;
@@ -287,14 +287,8 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
     }
 
     while (static_cast<int64_t>(res.size()) < max_length) {
-      std::string regex_expr = "";
-      if (ModelName() == kModel_Llama){
-        regex_expr = regcmp.LLAMA_REGEX_PATTERN;
-      } else {
-        // default to GPT2 regex
-        regex_expr = regcmp.GPT2_REGEX_PATTERN;
-      }
-      auto [b, tok] = regcmp.GetNextToken(regex_expr);
+      std::string regex_expr = bbpe_tokenizer_->GetPreTokenizerRegex(ModelName());
+      auto [b, tok] = reg_splitter.GetNextToken(regex_expr);
 
       if (!b) break;
diff --git a/operators/tokenizer/bpe_tokenizer_model.hpp b/operators/tokenizer/bpe_tokenizer_model.hpp
index 14e3c1547..8ad3aa7fd 100644
--- a/operators/tokenizer/bpe_tokenizer_model.hpp
+++ b/operators/tokenizer/bpe_tokenizer_model.hpp
@@ -44,6 +44,58 @@ class BpeModel {
     }
   }
 
+  OrtxStatus LoadPreTokenizer(const json& bpe_model) {
+    auto node_pre_tokenizer = bpe_model.find("pre_tokenizer");
+    if (node_pre_tokenizer == bpe_model.end() || node_pre_tokenizer->is_null()) {
+      return {};
+    }
+
+    auto iter_type = node_pre_tokenizer->find("type");
+    if (iter_type == node_pre_tokenizer->end() || iter_type->is_null()) {
+      return {};
+    }
+
+    if (iter_type->get<std::string>() != "Sequence") {
+      return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+    }
+
+    auto iter_node_list = node_pre_tokenizer->find("pretokenizers");
+
+    if (iter_node_list == node_pre_tokenizer->end() || iter_node_list->is_null()) {
+      return {};
+    }
+
+    for (const auto& node : *iter_node_list) {
+      auto iter_type = node.find("type");
+      if (iter_type == node.end() || iter_type->is_null()) {
+        continue;  // ignore unknown pre-tokenizer type
+      }
+
+
+      auto pre_type = iter_type->get<std::string>();
+      if (pre_type == "Split") {
+        auto iter_pattern = node.find("pattern");
+        if (iter_pattern == node.end() || iter_pattern->is_null()) {
+          continue;
+        }
+
+        auto regex_str = iter_pattern->find("Regex");
+        if (regex_str == iter_pattern->end() || regex_str->is_null()) {
+          continue;
+        }
+
+        pre_tokenizer_regex_ = regex_str->get<std::string>();
+      } else if (pre_type == "ByteLevel") {
+        ;  // need to add more flag support here in the future
+      }
+      else {
+        return {kOrtxErrorNotImplemented, "Unsupported pretokenizer type!"};
+      }
+    }
+
+    return {};
+  }
+
   OrtxStatus Load(std::istream& vocab_stream, std::istream& merges_stream, const char* unk_token,
                   const char* special_tokens, bool spm_converted) {
     nlohmann::json tok_json;
@@ -121,6 +173,8 @@ class BpeModel {
   }
 
   OrtxStatus Load(const json& bpe_model, const char* /* special_tokens */, bool spm_converted) {
+    ORTX_RETURN_IF_ERROR(LoadPreTokenizer(bpe_model));
+
     const json& vocab_json = bpe_model["vocab"];
     const json& merges_json = bpe_model["merges"];
     vocab_json.get_to(vocab_map_);
@@ -358,6 +412,19 @@ class BpeModel {
 
   const std::string& GetEndOfWordSuffix() const { return end_of_word_suffix_; }
 
+  std::string GetPreTokenizerRegex(const std::string& model_name) const {
+    if (!pre_tokenizer_regex_.empty()) {
+      return pre_tokenizer_regex_;
+    }
+
+    if (model_name == "Llama") {
+      return bpe::PreTokenizerWithRegEx::LLAMA_REGEX_PATTERN;
+    }
+
+    // by default, use the GPT2 pretokenizer regex
+    return bpe::PreTokenizerWithRegEx::GPT2_REGEX_PATTERN;
+  }
+
  private:
   struct BpeNode {
     uint32_t id;
@@ -379,6 +446,7 @@ class BpeModel {
   uint32_t unk_id_ = (std::numeric_limits<uint32_t>::max)();
   bpe::SpecialTokenMap special_tokens_;
   TrieTree<char32_t> added_tokens_;
+  std::string pre_tokenizer_regex_;
 };
 
 }  // namespace ort_extensions
diff --git a/operators/tokenizer/bpe_utils.hpp b/operators/tokenizer/bpe_utils.hpp
index 7e9d415d7..ce4bdfc01 100644
--- a/operators/tokenizer/bpe_utils.hpp
+++ b/operators/tokenizer/bpe_utils.hpp
@@ -97,8 +97,12 @@ class SpecialTokenMap {
   std::unordered_map<ustring, int> token_map_;
 };
 
-class TokenWithRegularExp {
+class PreTokenizerWithRegEx {
  public:
+  static constexpr const char* GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+  static constexpr const char* LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
   void Set(std::u32string_view val) {
     m_text = val;
   }
@@ -115,10 +119,6 @@
     return {false, {}};
   }
 
-  const std::string GPT2_REGEX_PATTERN = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-  const std::string LLAMA_REGEX_PATTERN_2 = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-
  public:
 
   // Although we have RegexMatchGeneral which performs regex matching given any general regex string
diff --git a/operators/tokenizer/tokenizer_jsconfig.hpp b/operators/tokenizer/tokenizer_jsconfig.hpp
index cd1a67b59..75fac5f46 100644
--- a/operators/tokenizer/tokenizer_jsconfig.hpp
+++ b/operators/tokenizer/tokenizer_jsconfig.hpp
@@ -26,6 +26,7 @@ constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
     {"GPT2Tokenizer", TokenType::kBPE},
     {"Qwen2Tokenizer", TokenType::kBPE},
    {"BaichuanTokenizer", TokenType::kBPE},
+    {"GPTNeoXTokenizer", TokenType::kBPE},
 
     {"", TokenType::kUnigram},
     {"T5Tokenizer", TokenType::kUnigram},
diff --git a/pyop/py_c_api.cc b/pyop/py_c_api.cc
index c2f57b561..46a36d799 100644
--- a/pyop/py_c_api.cc
+++ b/pyop/py_c_api.cc
@@ -121,7 +121,7 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
         OrtxTokenizer* tokenizer = nullptr;
         auto err = OrtxCreateTokenizer(&tokenizer, tokenizer_def_json.c_str());
         if (err != kOrtxOK) {
-          throw std::runtime_error(std::string("Failed to create tokenizer") + OrtxGetLastErrorMessage());
+          throw std::runtime_error(std::string("Failed to create tokenizer\n") + OrtxGetLastErrorMessage());
         }
         return reinterpret_cast<std::uintptr_t>(tokenizer);
       },
diff --git a/shared/api/tokenizer_impl.cc b/shared/api/tokenizer_impl.cc
index e8aca8e32..fe5ad5277 100644
--- a/shared/api/tokenizer_impl.cc
+++ b/shared/api/tokenizer_impl.cc
@@ -11,33 +11,15 @@
 
 namespace ort_extensions {
 
-std::set<std::string> TokenizerImpl::supported_bpe_models_ = {
-    "PreTrainedTokenizerFast",
-    "CLIPTokenizer",
-    "WhisperTokenizer",
-    "GemmaTokenizer",
-    "LlamaTokenizer",
-    "Phi3Tokenizer",
-    "CodeLlamaTokenizer",
-    "CodeGenTokenizer",
-    "GPT2Tokenizer",
-    "Qwen2Tokenizer",
-    "BaichuanTokenizer"
-};
-
-std::set<std::string> TokenizerImpl::supported_ugm_models_ = {
-    "XLMRobertaTokenizer",
-    "T5Tokenizer",
-    "ChatGLMTokenizer"
-};
 
 TokenizerImpl::TokenizerImpl() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTokenizer) {};
 TokenizerImpl::~TokenizerImpl() {};
 
 OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
-  if (tok_config_->tokenizer_class_.empty() ||
-      supported_ugm_models_.count(tok_config_->tokenizer_class_)) {
+
+  auto type = TokenJsonConfig::GetTokenType(tok_config_->tokenizer_class_);
+  if (type == TokenType::kUnigram) {
     auto tokenizer = std::make_unique<SpmUgmTokenizer>();
     auto status = tokenizer->Load(*tok_config_);
     if (!status.IsOk()) {
@@ -53,42 +35,39 @@ OrtxStatus TokenizerImpl::LoadTokenizer(const OrtxTokenizerBlob* blob) {
       tokenizer_ = std::move(tokenizer);
       detokenizer_ = std::move(detok);
     }
-    return status;
-  }
-
-  if (!supported_bpe_models_.count(tok_config_->tokenizer_class_)) {
-    return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
-  }
-
-  auto tokenizer = std::make_unique<JsonFastTokenizer>();
-  auto fx_load = &JsonFastTokenizer::Load;
-  if (blob == nullptr) {
-    auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
-    // vocab file is checked in TokenJsonConfig::Load
-    if (vocab_file_path.extension() != ".json") {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+  } else if (type == TokenType::kBPE) {
+    auto tokenizer = std::make_unique<JsonFastTokenizer>();
+    auto fx_load = &JsonFastTokenizer::Load;
+    if (blob == nullptr) {
+      auto vocab_file_path = ortx::path(tok_config_->GetVocabDataFile());
+      // vocab file is checked in TokenJsonConfig::Load
+      if (vocab_file_path.extension() != ".json") {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
+    } else {
+      if (blob->raw_model_blob_len > 0) {
+        fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
+      }
+    }
-  } else {
-    if (blob->raw_model_blob_len > 0) {
-      fx_load = &JsonFastTokenizer::LoadTikTokenBase64;
-    }
-  }
+
+    auto status = (tokenizer.get()->*fx_load)(*tok_config_);
+    if (!status.IsOk()) {
+      return status;
+    }
 
-  auto status = (tokenizer.get()->*fx_load)(*tok_config_);
-  if (!status.IsOk()) {
-    return status;
-  }
+    auto detok = std::make_unique<BpeStreamingDecoder>();
+    status = detok->Load(tok_config_, *tokenizer);
 
-  auto detok = std::make_unique<BpeStreamingDecoder>();
-  status = detok->Load(tok_config_, *tokenizer);
+    if (status.IsOk()) {
+      tokenizer_ = std::move(tokenizer);
+      detokenizer_ = std::move(detok);
+    }
 
-  if (status.IsOk()) {
-    tokenizer_ = std::move(tokenizer);
-    detokenizer_ = std::move(detok);
-  }
+    return status;
+  }
 
-  return status;
+  return OrtxStatus(kOrtxErrorNotImplemented, "Unsupported tokenizer class");
 }
 
 OrtxStatus TokenizerImpl::Load(const OrtxTokenizerBlob& blob) {
diff --git a/shared/api/tokenizer_impl.h b/shared/api/tokenizer_impl.h
index f1372c8f1..395d2cb7c 100644
--- a/shared/api/tokenizer_impl.h
+++ b/shared/api/tokenizer_impl.h
@@ -15,9 +15,6 @@ namespace ort_extensions {
 
 class TokenizerImpl : public OrtxObjectImpl {
  public:
-  static std::set<std::string> supported_bpe_models_;
-  static std::set<std::string> supported_ugm_models_;
-
   TokenizerImpl();
   virtual ~TokenizerImpl();
diff --git a/test/pp_api_test/test_tokenizer.cc b/test/pp_api_test/test_tokenizer.cc
index 527d832cc..0a5ea815f 100644
--- a/test/pp_api_test/test_tokenizer.cc
+++ b/test/pp_api_test/test_tokenizer.cc
@@ -67,7 +67,7 @@ TEST(CApiTest, StreamApiTest) {
 
 TEST(OrtxTokenizerTest, RegexTest) {
   std::u32string str = U"CAN'T \r\n 2413m";
-  auto regcmp = std::make_unique<bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::u32string> res;
   std::vector<std::u32string> out_tokens = {U"CAN", U"'T", U" \r\n", U" ", U"241", U"3", U"m"};
@@ -91,7 +91,7 @@ TEST(OrtxTokenizerTest, RegexMatchSTDTest) {
   std::vector<std::u32string> input_strings = {U"not its, or IT'S, but it's",
                                                U" ",
                                                U"AbCd"};
-  auto regcmp = std::make_unique<bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"'s"},
@@ -118,7 +118,7 @@ TEST(OrtxTokenizerTest, WrapStandaloneCategoriesTest) {
                                        "\\p{rn}\\p{L}\\p{N}\\p{L}",
                                        "\\p{Z}*[\\p{rn}]+",
                                        "\\p{Z}+"};
-  auto regcmp = std::make_unique<bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::string> res;
   std::vector<std::string> out_regex = {"[^\\p{rn}\\p{L}\\p{N}]?[\\p{L}]+",
@@ -152,7 +152,7 @@ TEST(OrtxTokenizerTest, RegexMatchGeneralTest) {
                                                U"241356m",
                                                U"Ich liebe München <3 \r\n ",
                                                U"生活的真谛是"};
-  auto regcmp = std::make_unique<bpe::TokenWithRegularExp>();
+  auto regcmp = std::make_unique<bpe::PreTokenizerWithRegEx>();
 
   std::vector<std::vector<std::u32string>> res_vector;
   std::vector<std::vector<std::u32string>> out_tokens = {{U"CAN", U"'T", U"", U""},
diff --git a/test/test_pp_api.py b/test/test_pp_api.py
index 4965f13f4..36f77affc 100644
--- a/test/test_pp_api.py
+++ b/test/test_pp_api.py
@@ -8,10 +8,12 @@
 
 is_pp_api_available = False
+hf_token_id = None
 try:
-    from transformers import AutoImageProcessor
+    from transformers import AutoImageProcessor, AutoTokenizer
     from onnxruntime_extensions import pp_api
     is_pp_api_available = True
+    hf_token_id = os.environ.get("HF_TOKEN", None)
 except ImportError:
     pass
 
@@ -46,7 +48,6 @@ def setUpClass(cls):
         else:
             cls.temp_dir = tempfile.mkdtemp()
             print(f"Created temp dir: {cls.temp_dir}")
-        cls.token_id = os.environ.get("HF_TOKEN", None)
 
     def test_CLIP_image_processing(self):
         model_id = "openai/clip-vit-large-patch14"
@@ -76,6 +77,7 @@ def test_CLIP_image_processing(self):
             a_image = regen_image(np.transpose(actual, (1, 2, 0)))
             a_image.save(f"{self.temp_dir}/CLIP_a_{i}.png")
 
+    @unittest.skipIf(hf_token_id is None, "HF_TOKEN is not available")
"HF_TOKEN is not available") def test_llama3_2_image_processing(self): model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" @@ -91,7 +93,7 @@ def test_llama3_2_image_processing(self): "test/data/processor/exceltable.png"] (image, image2, image3) = [Image.open(f) for f in image_list] - processor = AutoImageProcessor.from_pretrained(model_id, token=TestPPAPI.token_id) + processor = AutoImageProcessor.from_pretrained(model_id, token=hf_token_id) inputs = processor.preprocess( [image, image2, image3], return_tensors="np") print({k: v.shape if k == "pixel_values" else v for k, v in inputs.items()}) @@ -114,6 +116,15 @@ def test_llama3_2_image_processing(self): a_image = regen_image(np.transpose(actual, (1, 2, 0))) a_image.save(f"{self.temp_dir}/a_{idx}_{i}.png") + def test_OLMa_tokenizer(self): + test_sentence = ["I like walking my cute dog\n and\x17 then 生活的真谛是 \t\t\t\t \n\n61" + " |||IP_ADDRESS|||"] + model_id = "amd/AMD-OLMo-1B-SFT-DPO" + hf_enc = AutoTokenizer.from_pretrained(model_id) + inputs = hf_enc(test_sentence)["input_ids"] + tokenizer = pp_api.Tokenizer(model_id) + ortx_inputs = tokenizer.tokenize(test_sentence) + # self.assertEqual(inputs, ortx_inputs) + np.testing.assert_array_equal(ortx_inputs, inputs) if __name__ == '__main__': unittest.main()