
Commit 79e1366

Enable using kwargs for selecting pad-to-max-length strategy for tokenizer in embeddings
kcirred committed Oct 4, 2024
1 parent 1695c3b commit 79e1366
Showing 1 changed file with 35 additions and 15 deletions.
caikit_nlp/modules/text_embedding/embedding.py (35 additions, 15 deletions)
@@ -37,6 +37,7 @@
 from torch import nn
 from torch.backends import mps
 from transformers import BatchEncoding
+from transformers.tokenization_utils import PaddingStrategy
 import numpy as np
 import torch

@@ -976,6 +977,7 @@ def _tokenize_plus(
         truncate_input_tokens: int,
         texts: List[str],
         implicit_truncation_errors: bool = True,
+        **kwargs,
     ) -> TruncatedTokensTuple:
         """Tokenize with support for truncation handling and returning the token count
         Args:
@@ -1015,21 +1017,21 @@ def _tokenize_plus(
         texts = [str(s).strip() for s in texts]

         # Call tokenizer with the same truncation parameters every time
-        tokenized = self._get_tokenized(texts)
+        tokenized = self._get_tokenized(texts, **kwargs)

         # Custom truncation and/or error raise if needed
         truncation_needed = self._truncation_needed(tokenized, max_length, texts)
         if truncation_needed and okay_to_truncate:
             # Truncate texts in place
             _truncate_texts(texts, tokenized, max_length, truncation_needed)
             # Re-tokenize the truncated texts
-            tokenized = self._get_tokenized(texts)
+            tokenized = self._get_tokenized(texts, **kwargs)
             truncation_needed = []  # truncation accomplished

         input_token_count = sum_token_count(tokenized)
         return TruncatedTokensTuple(tokenized, input_token_count, truncation_needed)

-    def _get_tokenized(self, texts):
+    def _get_tokenized(self, texts, **kwargs):
         """Intentionally always call tokenizer the same way to avoid thread issues.

         Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID).
@@ -1039,6 +1041,8 @@ def _get_tokenized(self, texts):
         the fast tokenizer with different truncation settings.
         """

+        pad_to_max_length = kwargs.pop("pad_to_max_length", None)
+
         # Keep copies of tokenizer per thread (in each wrapped model instance)
         thread_id = threading.get_ident()
         tokenizer = (
@@ -1047,18 +1051,32 @@ def _get_tokenized(self, texts):
             else self.tokenizers.setdefault(thread_id, deepcopy(self.tokenizer))
         )

-        return tokenizer(
-            texts,
-            return_attention_mask=True,  # Used for determining token count
-            return_token_type_ids=False,
-            return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
-            return_offsets_mapping=True,  # Used for truncation
-            return_length=False,
-            return_tensors="pt",
-            truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
-            padding=True,  # DO NOT CHANGE else "Already borrowed" errors
-            max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
-        )
+        if pad_to_max_length:
+            return tokenizer(
+                texts,
+                return_attention_mask=True,  # Used for determining token count
+                return_token_type_ids=False,
+                return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
+                return_offsets_mapping=True,  # Used for truncation
+                return_length=False,
+                return_tensors="pt",
+                truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
+                padding=PaddingStrategy.MAX_LENGTH,  # DO NOT CHANGE else "Already borrowed" errors
+                max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
+            )
+        else:
+            return tokenizer(
+                texts,
+                return_attention_mask=True,  # Used for determining token count
+                return_token_type_ids=False,
+                return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
+                return_offsets_mapping=True,  # Used for truncation
+                return_length=False,
+                return_tensors="pt",
+                truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
+                padding=True,  # DO NOT CHANGE else "Already borrowed" errors
+                max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
+            )

     def encode(
         self,
@@ -1077,6 +1095,7 @@ def encode(
         return_token_count: bool = False,
         implicit_truncation_errors: bool = True,
         autocast: bool = False,
+        tokenizer_kwargs: Dict[str, Any] = {},
     ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
@@ -1161,6 +1180,7 @@ def encode(
                 truncate_input_tokens,
                 sentences_batch,
                 implicit_truncation_errors=implicit_truncation_errors,
+                **tokenizer_kwargs
             )

             if truncation_needed:  # truncation was needed and was not done/not allowed
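With this change, a caller can opt into fixed-length padding through `encode`. A minimal usage sketch follows; the `model` object and its inputs are assumptions for illustration, while the `tokenizer_kwargs` parameter and the `pad_to_max_length` key are taken from the diff above:

# Hypothetical caller (not part of the commit): opt into fixed-length padding.
# `model` stands in for a loaded text-embedding module exposing encode().
embeddings = model.encode(
    ["first sentence", "second sentence"],
    tokenizer_kwargs={"pad_to_max_length": True},  # forwarded via _tokenize_plus to _get_tokenized
)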

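For reference, a standalone sketch of what the two branches of `_get_tokenized` ask the tokenizer to do; the checkpoint name and max_length here are arbitrary examples, not something this commit prescribes:

from transformers import AutoTokenizer
from transformers.tokenization_utils import PaddingStrategy

tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")  # example checkpoint
texts = ["short", "a somewhat longer example sentence"]

# padding=True pads only to the longest sequence in the batch (dynamic shapes).
dynamic = tok(texts, padding=True, truncation=True, max_length=256, return_tensors="pt")

# PaddingStrategy.MAX_LENGTH pads every sequence out to max_length (static shapes).
fixed = tok(texts, padding=PaddingStrategy.MAX_LENGTH, truncation=True, max_length=256, return_tensors="pt")

print(dynamic["input_ids"].shape)  # batch x longest-in-batch, e.g. torch.Size([2, 8])
print(fixed["input_ids"].shape)    # batch x max_length: torch.Size([2, 256])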