Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a generic tokenizer operator to support Hugging Face Tokenizer with JSON data files. #859

Merged
merged 7 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/ext_ortlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ else()
if (OCOS_ONNXRUNTIME_VERSION)
set(ONNXRUNTIME_VER ${OCOS_ONNXRUNTIME_VERSION})
else()
set(ONNXRUNTIME_VER "1.17.1")
set(ONNXRUNTIME_VER "1.19.2") # need to check if android package of this version is available too.
endif()

if (ANDROID)
Expand Down
2 changes: 1 addition & 1 deletion include/ort_c_to_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size)
return false;
}

#define ORTX_RETURN_IF_ERROR(expr) \
#define ORTW_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (_status != nullptr) { \
Expand Down
4 changes: 2 additions & 2 deletions include/ortx_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ struct OrtxTokenizerBlob {
#ifdef __cplusplus
OrtxTokenizerBlob(const std::string_view& config_json_blob,
const std::string_view& vocab_json_blob,
const std::string_view& token_module_blob,
const std::string_view& raw_model_blob)
const std::string_view& token_module_blob = {},
const std::string_view& raw_model_blob = {})
: config_json_blob(config_json_blob.data()), vocab_json_blob(vocab_json_blob.data()),
token_module_blob(token_module_blob.data()), raw_model_blob(raw_model_blob.data()),
config_blob_len(config_json_blob.size()),
Expand Down
8 changes: 8 additions & 0 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ typedef OrtxObject OrtxTensorResult;
// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
#define ORTX_DISPOSE(obj) OrtxDispose((OrtxObject**)&obj)
// Evaluates |expr| once; if the resulting status object reports failure via
// IsOk(), returns that status from the enclosing function. Intended for
// OrtxStatus-style return types. Counterpart of ORTW_RETURN_IF_ERROR (in
// ort_c_to_cpp.h), which instead checks a raw OrtStatusPtr against nullptr.
#define ORTX_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if (!_status.IsOk()) { \
return _status; \
} \
} while (0)


typedef uint32_t extTokenId_t;

Expand Down
16 changes: 8 additions & 8 deletions operators/tokenizer/bpe_decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,26 @@ struct KernelBpeDecoder {
}

std::string added_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_tokens", added_tokens));
if (!added_tokens.empty()) {
auto um = ParseId2String(added_tokens);
added_tokens_ = std::map<int64_t, std::string>(um.begin(), um.end());
}

std::string all_special_ids;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "all_special_ids", all_special_ids));
if (!all_special_ids.empty()) {
auto um = ParseId2String(all_special_ids);
std::transform(um.begin(), um.end(), std::inserter(all_special_ids_, all_special_ids_.end()),
[](const auto& p) { return p.first; });
}

ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "en_normalization", en_normalization_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "skip_special_tokens", skip_special_tokens_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "whitespace_token", whitespace_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "bos_token", bos_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "eos_token", eos_token_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "unk_token", unk_token_));

return status;
}
Expand Down
10 changes: 5 additions & 5 deletions operators/tokenizer/bpe_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,24 @@ KernelBpeTokenizer::KernelBpeTokenizer(const BpeModelConf& conf)
OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
// note: if the attribute doesn't exist in op node, GetOpAttribute doesn't return a failed status;
std::string vocab;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", vocab));
if (vocab.empty()) {
return OrtW::CreateStatus("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

std::string merges;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "merges", merges));
if (merges.empty()) {
return OrtW::CreateStatus("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
}

ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "padding_length", padding_length_));
if (padding_length_ != -1 && padding_length_ <= 0) {
return OrtW::CreateStatus("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
}

std::string model_name;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model_name", model_name));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model_name", model_name));
if (!model_name.empty()) {
model_name_ = model_name;
}
Expand All @@ -159,7 +159,7 @@ OrtStatusPtr KernelBpeTokenizer::OnModelAttach(const OrtApi& api, const OrtKerne
}

std::string added_token;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "added_token", added_token));
status = bbpe_tokenizer_->LoadAddedTokens(added_token.c_str());
if (!status.IsOk()) {
return (OrtStatusPtr)status;
Expand Down
2 changes: 1 addition & 1 deletion operators/tokenizer/sentencepiece_decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
struct KernelSentencepieceDecoder {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string model_blob;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_blob));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_blob));

sentencepiece::ModelProto model_proto;
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
Expand Down
2 changes: 1 addition & 1 deletion operators/tokenizer/sentencepiece_tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

OrtStatusPtr KernelSentencepieceTokenizer::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string model_as_string;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_as_string));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "model", model_as_string));

sentencepiece::ModelProto model_proto;
std::vector<uint8_t> model_as_bytes;
Expand Down
32 changes: 32 additions & 0 deletions operators/tokenizer/tokenizer_jsconfig.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,30 @@

namespace ort_extensions {

// Tokenizer model families the generic JSON tokenizer loader can instantiate.
enum class TokenType {
kUnknown, kUnigram, kBPE
};

// Maps a Hugging Face `tokenizer_class` name (as found in
// tokenizer_config.json) to the tokenizer model family used to select an
// implementation. Looked up via TokenJsonConfig::GetTokenType below.
// NOTE(review): the empty-string entry presumably makes a missing/empty
// tokenizer_class default to Unigram — confirm against the config loader.
constexpr std::pair<const char*, TokenType> kTokenizerDict[] = {
{"PreTrainedTokenizerFast", TokenType::kBPE},
{"CLIPTokenizer", TokenType::kBPE},
{"WhisperTokenizer", TokenType::kBPE},
{"GemmaTokenizer", TokenType::kBPE},
{"LlamaTokenizer", TokenType::kBPE},
{"Phi3Tokenizer", TokenType::kBPE},
{"CodeLlamaTokenizer", TokenType::kBPE},
{"CodeGenTokenizer", TokenType::kBPE},
{"GPT2Tokenizer", TokenType::kBPE},
{"Qwen2Tokenizer", TokenType::kBPE},
{"BaichuanTokenizer", TokenType::kBPE},

{"", TokenType::kUnigram},
{"T5Tokenizer", TokenType::kUnigram},
{"ChatGLMTokenizer", TokenType::kUnigram},
{"XLMRobertaTokenizer", TokenType::kUnigram}
};


// TokenJsonConfig: Handles loading and parsing of JSON configuration files for tokenizers
class TokenJsonConfig final {
public:
Expand Down Expand Up @@ -230,6 +254,14 @@ class TokenJsonConfig final {
return added_token;
}

// Resolves a Hugging Face `tokenizer_class` name to its model family.
// Returns TokenType::kUnknown for names absent from kTokenizerDict.
static TokenType GetTokenType(const std::string& tok) {
  // Build the hash index over the static name->type table once, lazily.
  static const std::unordered_map<std::string, TokenType> lookup(
      std::begin(kTokenizerDict), std::end(kTokenizerDict));

  auto hit = lookup.find(tok);
  if (hit == lookup.end()) {
    return TokenType::kUnknown;
  }
  return hit->second;
}

private:
void LoadAddedTokens(const json& tok_json) {
auto added_tokens = tok_json.find("added_tokens");
Expand Down
56 changes: 56 additions & 0 deletions operators/tokenizer/tokenizer_op_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <variant>

#include "bpe_kernels.h"
#include "ugm_kernels.hpp"

#include "ext_status.h"
#include "op_def_struct.h"
#include "ort_c_to_cpp.h"

namespace ort_extensions {

// Generic custom-op kernel (registered as "HfJsonTokenizer") that tokenizes
// strings using a Hugging Face tokenizer described entirely by JSON attributes.
// Based on the `tokenizer_class` in the config it instantiates either a BPE
// backend (JsonFastTokenizer) or a Unigram backend (SpmUgmTokenizer) and
// dispatches to the chosen one through a std::variant of unique_ptrs.
class JsonTokenizerOpKernel {
public:
// Reads the node attributes
//   "tokenizer_config" - contents of tokenizer_config.json
//   "tokenizer_vocab"  - contents of the tokenizer vocabulary JSON
// selects and loads the matching tokenizer backend.
// Returns a non-null OrtStatusPtr on failure.
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string config_json;
ORTW_RETURN_IF_ERROR(OrtW::API::GetOpAttributeString(api, info, "tokenizer_config", config_json));

std::string vocab_json;
ORTW_RETURN_IF_ERROR(OrtW::API::GetOpAttributeString(api, info, "tokenizer_vocab", vocab_json));

// Wrap the two JSON strings as an in-memory blob; the token-module and
// raw-model parts are omitted (they default to empty in OrtxTokenizerBlob).
TokenJsonConfig cfg;
OrtxTokenizerBlob blob({config_json.c_str(), config_json.length()},
{vocab_json.c_str(), vocab_json.length()});

ORTX_RETURN_IF_ERROR(cfg.LoadFromBlob(blob));

// Pick the backend from the tokenizer_class declared in the config.
auto type = TokenJsonConfig::GetTokenType(cfg.tokenizer_class_);
if (type == TokenType::kUnigram) {
tokenizer_ = std::make_unique<SpmUgmTokenizer>();
} else if (type == TokenType::kBPE) {
tokenizer_ = std::make_unique<JsonFastTokenizer>();
} else {
return OrtxStatus(kOrtxErrorCorruptData, "Unknown tokenizer type");
}

// Load through whichever backend was selected.
// NOTE(review): returning the backend's status here relies on OrtxStatus
// converting to OrtStatusPtr — confirm the conversion is implicit.
return std::visit([&](auto& ptr) { return ptr->Load(cfg); }, tokenizer_);
}

// Tokenizes a rank-1 tensor of strings into token ids, forwarding the
// optional attention-mask / offset-mapping outputs to the active backend.
// NOTE(review): the Unigram backend rejects the optional outputs (see
// SpmUgmTokenizer::Compute), so only the BPE path appears to fill them.
OrtxStatus Compute(const ortc::Tensor<std::string>& input,
ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
std::optional<ortc::Tensor<int64_t>*> offset_mapping = std::nullopt) const {

return std::visit([&](auto& ptr) {
return ptr->Compute(input, tokenize_output, attention_mask, offset_mapping);
}, tokenizer_);
}

private:
// Exactly one backend is active after a successful OnModelAttach.
std::variant<std::unique_ptr<JsonFastTokenizer>, std::unique_ptr<SpmUgmTokenizer>> tokenizer_;
};

} // namespace ort_extensions
3 changes: 3 additions & 0 deletions operators/tokenizer/tokenizers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "bpe_kernels.h"
#include "bpe_tokenizer_model.hpp"
#include "bpe_decoder.hpp"
#include "tokenizer_op_impl.hpp"
using namespace ort_extensions;
#endif

#ifdef ENABLE_SPM_TOKENIZER
Expand Down Expand Up @@ -40,6 +42,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer = []() -> CustomOpArray& {
CustomCpuStructV2("RobertaTokenizer", RobertaTokenizer),
CustomCpuStructV2("BpeDecoder", KernelBpeDecoder),
CustomCpuStructV2("SpmTokenizer", SpmTokenizer),
CustomCpuStructV2("HfJsonTokenizer", JsonTokenizerOpKernel),
#endif

#ifdef ENABLE_SPM_TOKENIZER
Expand Down
4 changes: 2 additions & 2 deletions operators/tokenizer/trie_tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct KernelTrieTokenizer {
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string text_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
return nullptr;
};
Expand Down Expand Up @@ -156,7 +156,7 @@ struct KernelTrieDetokenizer {
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string text_tokens;
ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
ORTW_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "vocab", text_tokens));
tokenizer = std::make_shared<TrieTokenizer>(text_tokens);
return nullptr;
};
Expand Down
8 changes: 7 additions & 1 deletion operators/tokenizer/ugm_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,13 @@ struct SpmUgmTokenizer {
return std::get<0>(iter->second);
}

OrtxStatus Compute(const ortc::Tensor<std::string>& input, ortc::Tensor<int64_t>& tokenize_output) const {
OrtxStatus Compute(const ortc::Tensor<std::string>& input, ortc::Tensor<int64_t>& tokenize_output,
std::optional<ortc::Tensor<int64_t>*> attention_mask = std::nullopt,
std::optional<ortc::Tensor<int64_t>*> offset_mapping = std::nullopt) const {
if (attention_mask.has_value() || offset_mapping.has_value()) {
return {kOrtxErrorInvalidArgument, "attention-mask or offset-mapping was supported in unigram tokenizer"};
}

if (input.Shape().size() != 1) {
return OrtxStatus(extError_t::kOrtxErrorInvalidArgument, "Input tensor must have rank 1.");
}
Expand Down
Loading
Loading