add python api hf embedded json tokenizer support
Sayan Shaw committed Dec 12, 2024
1 parent 1a21d45 commit 64a0350
Showing 3 changed files with 305 additions and 26 deletions.
12 changes: 11 additions & 1 deletion onnxruntime_extensions/_cuops.py
@@ -491,6 +491,16 @@ def get_outputs(cls):
]


class HfJsonTokenizer(CustomOp):
@classmethod
def get_inputs(cls):
return [cls.io_def('str', onnx_proto.TensorProto.STRING, ['N'])]

@classmethod
def get_outputs(cls):
return [cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])]


# TODO: have a C++ impl.
def _argsort_op(x, dim):
d = numpy.argsort(x, dim)
@@ -544,4 +554,4 @@ def build_graph(cls, op_class, *args, **kwargs):

@staticmethod
def get_op_class(op_type):
return globals()[op_type]
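
For context, the HfJsonTokenizer op defined above takes a 1-D batch of strings named `str` and produces a ragged int64 tensor named `ids`. Below is a minimal sketch of the single-node graph this definition describes; the ai.onnx.contrib domain, opset versions, and placeholder JSON strings are assumptions for illustration, with the attribute names mirroring the ones wired up in cvt.py below:

    import onnx
    from onnx import helper, TensorProto

    # One node: a string batch in, an int64 ids tensor out, matching HfJsonTokenizer's I/O definition.
    node = helper.make_node(
        "HfJsonTokenizer",
        inputs=["str"], outputs=["ids"],
        domain="ai.onnx.contrib",                   # assumed custom-op domain
        tokenizer_vocab='{"version": "1.0"}',       # placeholder for the tokenizer.json contents
        tokenizer_config='{"model_type": "demo"}',  # placeholder for the tokenizer_config.json contents
    )
    graph = helper.make_graph(
        [node], "hf_json_tokenizer",
        [helper.make_tensor_value_info("str", TensorProto.STRING, ["N"])],
        [helper.make_tensor_value_info("ids", TensorProto.INT64, ["N", None])],
    )
    model = helper.make_model(graph, opset_imports=[
        helper.make_opsetid("", 18), helper.make_opsetid("ai.onnx.contrib", 1)])
    onnx.save(model, "hf_json_tokenizer.onnx")
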
260 changes: 235 additions & 25 deletions onnxruntime_extensions/cvt.py
@@ -12,6 +24,24 @@
from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph # noqa
from ._ortapi2 import make_onnx_model, SingleOpGraph

import os
import numpy as np
import tempfile
import shutil

# force the pure-Python protobuf implementation to avoid a protobuf version mismatch
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from transformers.convert_slow_tokenizer import SpmConverter # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from tokenizers import decoders, normalizers, pre_tokenizers, Regex # noqa: E402


OrtxTokenizer = None
try:
from onnxruntime_extensions.pp_api import Tokenizer as OrtxTokenizer
except ImportError:
pass

_is_torch_available = False
try:
@@ -24,11 +42,150 @@

_PRE_POST_PAIR = {'TrieTokenizer': "TrieDetokenizer"}

def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
if add_prefix_space:
prepend_scheme = "always"
if not getattr(original_tokenizer, "legacy", True):
prepend_scheme = "first"
else:
prepend_scheme = "never"
return prepend_scheme


class Baichuan2Converter(SpmConverter):
handle_byte_fallback = True

def __init__(self, original_tokenizer):
super().__init__(original_tokenizer)
original_tokenizer.add_prefix_space = False

def vocab(self, proto):
vocab = [
(self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
(self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
(self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
return vocab

def unk_id(self, proto):
unk_id = 0
return unk_id

def decoder(self, replacement, add_prefix_space):
sequence = [
decoders.Replace("▁", " "),
decoders.ByteFallback(),
decoders.Fuse(),
]
if add_prefix_space:
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)

def normalizer(self, proto):
if getattr(self.original_tokenizer, "legacy", True):
sequence = []
if getattr(self.original_tokenizer, "add_prefix_space", True):
sequence += [normalizers.Prepend(prepend="▁")]
sequence += [normalizers.Replace(pattern=" ", content="▁")]
return normalizers.Sequence(sequence)
return None # non-legacy, no normalizer

def pre_tokenizer(self, replacement, add_prefix_space):
if not getattr(self.original_tokenizer, "legacy", True): # non-legacy, we need a replace
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
else:
return super().pre_tokenizer(replacement, add_prefix_space)


class ChatGlmConverter(SpmConverter):
def normalizer(self, proto):
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
_normalizers = [
normalizers.Strip(left=False, right=True), # stripping is important
normalizers.Replace(Regex(" {2,}"), "▁"),
]
return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = "always"
if hasattr(self.original_tokenizer, "legacy") and not self.original_tokenizer.legacy:
prepend_scheme = "first"
return pre_tokenizers.Metaspace(
replacement=replacement, add_prefix_space=add_prefix_space, prepend_scheme=prepend_scheme
)


JSON_TOKEN_CONVERTERS = {
"BaichuanTokenizer": Baichuan2Converter,
"ChatGLMTokenizer": ChatGlmConverter,
}
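
As a usage sketch for the converter table above (the checkpoint name is illustrative; the call mirrors what convert_tokenizer below automates): a converter turns a slow SentencePiece tokenizer into a fast tokenizers.Tokenizer that serializes to tokenizer.json.

    from transformers import AutoTokenizer

    # Hypothetical example; any checkpoint whose tokenizer class appears in JSON_TOKEN_CONVERTERS works.
    slow_tok = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan2-7B-Chat", use_fast=False, trust_remote_code=True)
    converter_cls = JSON_TOKEN_CONVERTERS[type(slow_tok).__name__]  # -> Baichuan2Converter
    fast_tok = converter_cls(slow_tok).converted()                  # a tokenizers.Tokenizer
    fast_tok.save("tokenizer.json")                                 # the JSON blob embedded into the model
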

# Convert the tokenizer and save its JSON files using HuggingFace AutoTokenizer
def convert_tokenizer(model_path, output_dir=None):
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if output_dir is None:
if os.path.isdir(model_path):
output_dir = model_path
else:
# create a temporary directory
output_dir = tempfile.mkdtemp()
tokenizer.save_pretrained(output_dir)
json_path = os.path.join(output_dir, "tokenizer.json")

if type(tokenizer).__name__ in JSON_TOKEN_CONVERTERS:
GenericSpmConverter = JSON_TOKEN_CONVERTERS[type(tokenizer).__name__]

converted = GenericSpmConverter(tokenizer).converted()
converted.save(json_path)
print(f"**Tokenizer saved to {json_path}")
return output_dir

# Validate the saved tokenizer JSON files against the original HF (slow) tokenizer
def validate_tokenizer(model_path, output_dir):
test_sentence = "I like walking my cute dog\n and\x17 then, 生活的真谛是 \t\t\t\t \n\n61"
if OrtxTokenizer is None:
print("onnxruntime_extensions package was not built with the C API enabled, skipping tokenization test")
return
ortx_tokenizer = OrtxTokenizer(output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
expected_ids = tokenizer(test_sentence, return_tensors="np")["input_ids"]
ortx_ids = np.asarray(ortx_tokenizer.tokenize(test_sentence))
assert np.array_equal(expected_ids[0], ortx_ids), f"Tokenization mismatch: {expected_ids[0]} != {ortx_ids}"
print("Tokenization test passed")

# Download tokenizer JSON files from the HuggingFace hub or local cache
def download_tokenizer(tokenizer_dir, output_dir):
try:
from transformers.utils import cached_file

resolved_full_file = cached_file(tokenizer_dir, "tokenizer.json")
resolved_config_file = cached_file(tokenizer_dir, "tokenizer_config.json")
except ImportError:
raise ValueError(f"Directory '{tokenizer_dir}' not found and transformers is not available")
if not os.path.exists(resolved_full_file):
raise FileNotFoundError(f"Downloaded HF file '{resolved_full_file}' cannot be found")
if os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file):
raise FileNotFoundError(
f"Downloaded HF files '{resolved_full_file}' " f"and '{resolved_config_file}' are not in the same directory"
)

if output_dir is None or len(output_dir) == 0:
output_dir = os.path.dirname(resolved_full_file)
print(f"Using {output_dir} as output directory")
return output_dir
else:
# copy the files to the output directory
shutil.copy(resolved_full_file, output_dir)
shutil.copy(resolved_config_file, output_dir)
return output_dir


def gen_processing_models(processor: Union[str, object],
pre_kwargs: dict = None,
post_kwargs: dict = None,
opset: int = None,
schema_v2: bool = False,
**kwargs):
"""
Generate the pre- and post-processing ONNX model, based on the name or HF class.
Expand All @@ -47,6 +204,9 @@ def gen_processing_models(processor: Union[str, object],
Keyword arguments for generating the post-processing model
opset: int
the target opset version of the model
schema_v2: bool
the flag for using embedded tokenizer files; when set, the HF tokenizer JSON files are embedded
as string attributes of the ONNX model and loaded from memory (blob loading) rather than read from disk.
kwargs:
The additional arguments for generating models
@@ -58,39 +218,89 @@
if pre_kwargs is None and post_kwargs is None:
raise ValueError(
"Either pre_kwargs or post_kwargs should be provided. None means no processing graph output.")

# If schema_v2 is set, we get the tokenizer JSON files either by downloading them from the HF cache
# or by converting the tokenizer with HuggingFace AutoTokenizer, and then create an ONNX model
# with the JSON files embedded as strings in the model attributes (attrs).
if schema_v2:
model_name = processor if isinstance(processor, str) else type(processor).__name__

# tokenizers without a prebuilt tokenizer.json need to be converted first
converted_tokenizer = {"Baichuan2", "chatglm"}
need_convert = any(token in model_name for token in converted_tokenizer)

if need_convert:
model_dir = convert_tokenizer(model_name)
validate_tokenizer(model_name, model_dir)
else:
model_dir = download_tokenizer(model_name, None)

# Load the content of tokenizer.json into a string
with open(f"{model_dir}/tokenizer.json", "r", encoding="utf-8") as f:
tokenizer_vocab = f.read()

# Load the content of tokenizer_config.json into a string
with open(f"{model_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = f.read()

# Create an ONNX model with these JSON file strings in attrs
g_pre, g_post = (None, None)
if pre_kwargs is not None:
# Add tokenizer_vocab and tokenizer_config to the kwargs
# so they are added to attrs in build_graph
pre_kwargs['tokenizer_vocab'] = tokenizer_vocab
pre_kwargs['tokenizer_config'] = tokenizer_config
g_pre = SingleOpGraph.build_graph("HfJsonTokenizer", **pre_kwargs)
if post_kwargs is not None:
if pre_kwargs is None:
cls_name = processor
else:
if processor not in _PRE_POST_PAIR:
raise RuntimeError(
f"Cannot locate the post processing operator name from {processor}")
cls_name = _PRE_POST_PAIR[processor]
# Add tokenizer_vocab and tokenizer_config to the kwargs
# so they are added to attrs in build_graph
post_kwargs['tokenizer_vocab'] = tokenizer_vocab
post_kwargs['tokenizer_config'] = tokenizer_config
g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None

if isinstance(processor, str):
g_pre, g_post = (None, None)
if pre_kwargs:
g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
if post_kwargs:
if pre_kwargs is None:
cls_name = processor
else:
if processor not in _PRE_POST_PAIR:
raise RuntimeError(
f"Cannot locate the post processing operator name from {processor}")
cls_name = _PRE_POST_PAIR[processor]
g_post = SingleOpGraph.build_graph(cls_name, **post_kwargs)
return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None

cls_name = type(processor).__name__
if cls_name == "WhisperProcessor":
if WhisperDataProcGraph is None:
raise ValueError(
"The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
_converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
pre_m = _converter.pre_processing(
**pre_kwargs) if pre_kwargs is not None else None
post_m = _converter.post_processing(
**post_kwargs) if post_kwargs is not None else None
return pre_m, post_m
elif HFTokenizerOnnxGraph.is_supported(processor):
_converter = HFTokenizerOnnxGraph(processor)
pre_g = _converter.pre_processing(
**pre_kwargs) if pre_kwargs is not None else None
post_g = _converter.post_processing(
**post_kwargs) if post_kwargs is not None else None
return make_onnx_model(pre_g) if pre_g else None, \
make_onnx_model(post_g) if post_g else None
else:
raise ValueError(f"Unsupported processor/tokenizer: {cls_name}")
59 changes: 59 additions & 0 deletions test/test_embedded_tokenizer.py
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import numpy as np
from transformers import AutoTokenizer, GPT2Tokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models, ort_inference


class TestEmbeddedTokenizer(unittest.TestCase):
def test_clip_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
"openai/clip-vit-base-patch32", use_fast=False)
text = """
1. Testing long text with multiple lines to check newline handling
2. As well as words with apostrophes such as you're, i'm, don't, etc.
3. And weird characters such as . , ~ ? ( ) " [ ] ! : - .
"""
ids = tokenizer.encode(text, return_tensors="np")

ort_tok = OrtPyFunction.from_model(gen_processing_models(
tokenizer,
pre_kwargs={"WITH_DEFAULT_INPUTS": True},
schema_v2=True)[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids, actual_ids)

def test_gpt2_tokenizer(self):
tokenizer = GPT2Tokenizer.from_pretrained(
"Xenova/gpt-4", use_fast=False)
text = "Testing words with apostrophes such as you're, i'm, don't, etc."
ids = tokenizer.encode(text, return_tensors="np")

ort_tok = OrtPyFunction.from_model(gen_processing_models(
tokenizer,
pre_kwargs={"WITH_DEFAULT_INPUTS": True},
schema_v2=True)[0])
actual_ids = ort_tok([text])[0]
np.testing.assert_array_equal(ids, actual_ids)

def test_xlm_roberta_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
"xlm-roberta-base", use_fast=False)
# TODO: if there is <unk> in the text, the result does not match.
text = (
'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will'
" add words that should not exist and be tokenized to , such as saoneuhaoesuth")
ids = tokenizer.encode(text, return_tensors="np")

ort_tok, _ = gen_processing_models(
tokenizer,
pre_kwargs={"WITH_DEFAULT_INPUTS": True},
schema_v2=True)
actual_ids, *_ = ort_inference(ort_tok, [text])
np.testing.assert_array_equal(ids[0], actual_ids)


if __name__ == '__main__':
unittest.main()
