diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index c1eb71736..f2db1f10f 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -42,6 +42,8 @@
     SentenceSimilarityResult,
     SentenceSimilarityResults,
     SentenceSimilarityScores,
+    Token,
+    TokenizationResults,
 )
 from caikit.interfaces.nlp.tasks import (
     EmbeddingTask,
@@ -50,6 +52,7 @@
     RerankTasks,
     SentenceSimilarityTask,
     SentenceSimilarityTasks,
+    TokenizationTask,
 )
 
 import alog
@@ -120,6 +123,7 @@ class TruncatedTokensTuple(NamedTuple):
         SentenceSimilarityTasks,
         RerankTask,
         RerankTasks,
+        TokenizationTask,
     ],
 )
 class EmbeddingModule(ModuleBase):
@@ -192,6 +196,29 @@ def public_model_info(cls) -> Dict[str, Any]:  # pylint: disable=no-self-argumen
             "sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(),
         }
 
+    @TokenizationTask.taskmethod()
+    def run_tokenizer(
+        self,
+        text: str,
+    ) -> TokenizationResults:
+        """Run tokenization task against the model
+
+        Args:
+            text: str
+                Text to tokenize
+        Returns:
+            TokenizationResults
+                The token count
+        """
+        result = self.model.tokenizer(text, return_offsets_mapping=True)
+
+        mapping = [
+            interv for interv in result.offset_mapping if (interv[1] - interv[0]) > 0
+        ]
+        tokens = [Token(start=i[0], end=i[1], text=text[i[0] : i[1]]) for i in mapping]
+
+        return TokenizationResults(token_count=len(result.input_ids), results=tokens)
+
     @classmethod
     def _get_ipex(cls, ipex_flag):
         """Get IPEX optimization library if enabled and available, else return False
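
Usage note (not part of the patch): a minimal sketch of how the new TokenizationTask path could be exercised once the module is loaded. The model path and input text are placeholders, and loading via EmbeddingModule.load is an assumption about how the surrounding service obtains the module instance.

    # Minimal sketch; "path/to/embedding-model" is a placeholder artifact path.
    from caikit_nlp.modules.text_embedding import EmbeddingModule

    model = EmbeddingModule.load("path/to/embedding-model")  # assumed load entry point
    result = model.run_tokenizer(text="Hello world")

    # token_count is taken from input_ids, so it can include special tokens;
    # results only contains tokens with non-empty (start, end) offsets.
    print(result.token_count)
    for tok in result.results:
        print(tok.start, tok.end, tok.text)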