diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 290551ba..ba19a585 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -55,6 +55,7 @@ TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl" + # pylint: disable=too-many-lines,too-many-instance-attributes @module( id="f9181353-4ccf-4572-bd1e-f12bcda26792", @@ -590,7 +591,11 @@ def run_tokenizer( TokenizationResults The token count """ - raise NotImplementedError("Tokenization not implemented for local") + error.type_check("", str, text=text) + tokenized_output = self.model.tokenizer(text, return_attention_mask=True) + return TokenizationResults( + token_count=len(tokenized_output["input_ids"]), + ) ################################## Private Functions ###################################### diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index f76400f1..5e91bea0 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -1,5 +1,6 @@ """Tests for text-generation module """ + # Standard import os import platform @@ -10,7 +11,7 @@ import torch # First Party -from caikit.interfaces.nlp.data_model import GeneratedTextResult +from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults import caikit # Local @@ -211,7 +212,26 @@ def test_zero_epoch_case(disable_wip): assert isinstance(model.model, HFAutoSeq2SeqLM) -def test_run_tokenizer_not_implemented(): - with pytest.raises(NotImplementedError): - model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) - model.run_tokenizer("This text doesn't matter") +# ############################## Run Tokenizer ################################ + + +def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device): + """Test tokenizer on edge cases like empty strings and long input.""" + model = TextGeneration.bootstrap(CAUSAL_LM_MODEL) + + # Edge case: Empty string + empty_result = model.run_tokenizer("") + assert isinstance(empty_result, TokenizationResults) + assert empty_result.token_count == 0 + + # Normal case: short sentence + short_text = "This is a test sentence." + short_result = model.run_tokenizer(short_text) + assert isinstance(short_result, TokenizationResults) + assert short_result.token_count == len(model.model.tokenizer.encode(short_text)) + + # Edge case: Long input + long_text = "This is a test sentence. " * 1000 + long_result = model.run_tokenizer(long_text) + assert isinstance(long_result, TokenizationResults) + assert long_result.token_count == len(model.model.tokenizer.encode(long_text))