Skip to content

Commit

Permalink
✨ added run_tokenizer in text_generation_local.py
Browse files Browse the repository at this point in the history
- 🎨 tox formatting
- 🚧 added a test to assert the length of the run_tokenizer output
- 🚧 made a more comprehensive test for the run tokenizer method
  • Loading branch information
m-misiura committed Dec 3, 2024
1 parent 56b7e18 commit cd36806
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
7 changes: 6 additions & 1 deletion caikit_nlp/modules/text_generation/text_generation_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl"


# pylint: disable=too-many-lines,too-many-instance-attributes
@module(
id="f9181353-4ccf-4572-bd1e-f12bcda26792",
Expand Down Expand Up @@ -590,7 +591,11 @@ def run_tokenizer(
TokenizationResults
The token count
"""
raise NotImplementedError("Tokenization not implemented for local")
error.type_check("<NLP48137045E>", str, text=text)
tokenized_output = self.model.tokenizer(text)
return TokenizationResults(
token_count=len(tokenized_output["input_ids"]),
)

################################## Private Functions ######################################

Expand Down
30 changes: 25 additions & 5 deletions tests/modules/text_generation/test_text_generation_local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for text-generation module
"""

# Standard
import os
import platform
Expand All @@ -10,7 +11,7 @@
import torch

# First Party
from caikit.interfaces.nlp.data_model import GeneratedTextResult
from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults
import caikit

# Local
Expand Down Expand Up @@ -211,7 +212,26 @@ def test_zero_epoch_case(disable_wip):
assert isinstance(model.model, HFAutoSeq2SeqLM)


def test_run_tokenizer_not_implemented():
    """run_tokenizer on the local text-generation module raises NotImplementedError.

    Bootstrap is done *outside* the pytest.raises block so the test cannot
    pass spuriously if model loading itself raised NotImplementedError;
    only the call under test is inside the context manager.
    """
    model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
    with pytest.raises(NotImplementedError):
        model.run_tokenizer("This text doesn't matter")
# ############################## Run Tokenizer ################################


def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
    """Exercise run_tokenizer on empty, normal, and very long inputs."""
    model = TextGeneration.bootstrap(CAUSAL_LM_MODEL)

    # An empty string should produce a result with zero tokens.
    result = model.run_tokenizer("")
    assert isinstance(result, TokenizationResults)
    assert result.token_count == 0

    # A normal sentence yields at least one token.
    result = model.run_tokenizer("This is a test sentence.")
    assert isinstance(result, TokenizationResults)
    assert result.token_count > 0

    # A very long input (repeated sentence) still tokenizes successfully.
    result = model.run_tokenizer("This is a test sentence. " * 1000)
    assert isinstance(result, TokenizationResults)
    assert result.token_count > 0

0 comments on commit cd36806

Please sign in to comment.