From 79e13660331f89d0e6464d484477605bd454590f Mon Sep 17 00:00:00 2001
From: kcirred
Date: Fri, 4 Oct 2024 15:56:34 -0400
Subject: [PATCH] Enable using kwargs for selecting pad-to-max-length strategy for tokenizer in embeddings

---
 .../modules/text_embedding/embedding.py | 50 +++++++++++++------
 1 file changed, 35 insertions(+), 15 deletions(-)

diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index 06779d0d3..97d479768 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -37,6 +37,7 @@
 from torch import nn
 from torch.backends import mps
 from transformers import BatchEncoding
+from transformers.tokenization_utils import PaddingStrategy
 import numpy as np
 import torch
@@ -976,6 +977,7 @@ def _tokenize_plus(
         truncate_input_tokens: int,
         texts: List[str],
         implicit_truncation_errors: bool = True,
+        **kwargs,
     ) -> TruncatedTokensTuple:
         """Tokenize with support for truncation handling and returning the token count
         Args:
@@ -1015,7 +1017,7 @@ def _tokenize_plus(
         texts = [str(s).strip() for s in texts]
 
         # Call tokenizer with the same truncation parameters every time
-        tokenized = self._get_tokenized(texts)
+        tokenized = self._get_tokenized(texts, **kwargs)
 
         # Custom truncation and/or error raise if needed
         truncation_needed = self._truncation_needed(tokenized, max_length, texts)
@@ -1023,13 +1025,13 @@ def _tokenize_plus(
             # Truncate texts in place
             _truncate_texts(texts, tokenized, max_length, truncation_needed)
             # Re-tokenize the truncated texts
-            tokenized = self._get_tokenized(texts)
+            tokenized = self._get_tokenized(texts, **kwargs)
             truncation_needed = []  # truncation accomplished
 
         input_token_count = sum_token_count(tokenized)
         return TruncatedTokensTuple(tokenized, input_token_count, truncation_needed)
 
-    def _get_tokenized(self, texts):
+    def _get_tokenized(self, texts, **kwargs):
         """Intentionally always call tokenizer the same way to avoid thread issues.
 
         Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID).
 
@@ -1039,6 +1041,8 @@
         the fast tokenizer with different truncation settings.
         """
+        pad_to_max_length = kwargs.pop("pad_to_max_length", None)
+
         # Keep copies of tokenizer per thread (in each wrapped model instance)
         thread_id = threading.get_ident()
         tokenizer = (
@@ -1047,18 +1051,32 @@
             if thread_id in self.tokenizers
             else self.tokenizers.setdefault(thread_id, deepcopy(self.tokenizer))
         )
 
-        return tokenizer(
-            texts,
-            return_attention_mask=True,  # Used for determining token count
-            return_token_type_ids=False,
-            return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
-            return_offsets_mapping=True,  # Used for truncation
-            return_length=False,
-            return_tensors="pt",
-            truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
-            padding=True,  # DO NOT CHANGE else "Already borrowed" errors
-            max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
-        )
+        if pad_to_max_length:
+            return tokenizer(
+                texts,
+                return_attention_mask=True,  # Used for determining token count
+                return_token_type_ids=False,
+                return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
+                return_offsets_mapping=True,  # Used for truncation
+                return_length=False,
+                return_tensors="pt",
+                truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
+                padding=PaddingStrategy.MAX_LENGTH,  # DO NOT CHANGE else "Already borrowed" errors
+                max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
+            )
+        else:
+            return tokenizer(
+                texts,
+                return_attention_mask=True,  # Used for determining token count
+                return_token_type_ids=False,
+                return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
+                return_offsets_mapping=True,  # Used for truncation
+                return_length=False,
+                return_tensors="pt",
+                truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
+                padding=True,  # DO NOT CHANGE else "Already borrowed" errors
+                max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
+            )
 
     def encode(
         self,
@@ -1077,6 +1095,7 @@
         return_token_count: bool = False,
         implicit_truncation_errors: bool = True,
         autocast: bool = False,
+        tokenizer_kwargs: Dict[str, Any] = {},
     ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
@@ -1161,6 +1180,7 @@
                     truncate_input_tokens,
                     sentences_batch,
                     implicit_truncation_errors=implicit_truncation_errors,
+                    **tokenizer_kwargs
                 )
 
                 if truncation_needed:  # truncation was needed and was not done/not allowed