From 7eeb0efff13e1898ca5d8e9362ec5e1a93d8f329 Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Tue, 1 Nov 2022 11:26:14 -0400 Subject: [PATCH 1/2] update for t5 --- textattack/datasets/helpers/ted_multi.py | 14 +++++++++++--- .../text/text_to_text_goal_function.py | 6 +++++- textattack/models/tokenizers/t5_tokenizer.py | 4 ++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/textattack/datasets/helpers/ted_multi.py b/textattack/datasets/helpers/ted_multi.py index 616a2e805..39574019c 100644 --- a/textattack/datasets/helpers/ted_multi.py +++ b/textattack/datasets/helpers/ted_multi.py @@ -11,6 +11,7 @@ import numpy as np from textattack.datasets import HuggingFaceDataset +from textattack.datasets.huggingface_dataset import get_datasets_dataset_columns class TedMultiTranslationDataset(HuggingFaceDataset): @@ -35,12 +36,19 @@ def __init__(self, source_lang="en", target_lang="de", split="test", shuffle=Fal self.source_lang = source_lang self.target_lang = target_lang self.shuffled = shuffle + self.label_map = None + self.output_scale_factor = None + self.label_names = None + # self.input_columns = ("Source",) + # self.output_column = "Translation" + if shuffle: self._dataset.shuffle() - def _format_raw_example(self, raw_example): - translations = np.array(raw_example["translation"]) - languages = np.array(raw_example["language"]) + def _format_as_dict(self, raw_example): + example = raw_example["translations"] + translations = np.array(example["translation"]) + languages = np.array(example["language"]) source = translations[languages == self.source_lang][0] target = translations[languages == self.target_lang][0] source_dict = collections.OrderedDict([("Source", source)]) diff --git a/textattack/goal_functions/text/text_to_text_goal_function.py b/textattack/goal_functions/text/text_to_text_goal_function.py index 9e4bac3be..341140768 100644 --- a/textattack/goal_functions/text/text_to_text_goal_function.py +++ b/textattack/goal_functions/text/text_to_text_goal_function.py @@ -4,6 +4,7 @@ ------------------------------------------------------- """ +import numpy as np from textattack.goal_function_results import TextToTextGoalFunctionResult from textattack.goal_functions import GoalFunction @@ -22,7 +23,10 @@ def _goal_function_result_type(self): def _process_model_outputs(self, _, outputs): """Processes and validates a list of model outputs.""" - return outputs.flatten() + if isinstance(outputs, np.ndarray): + return outputs.flatten() + else: + return outputs def _get_displayed_output(self, raw_output): return raw_output diff --git a/textattack/models/tokenizers/t5_tokenizer.py b/textattack/models/tokenizers/t5_tokenizer.py index a252e9134..f90aa04c4 100644 --- a/textattack/models/tokenizers/t5_tokenizer.py +++ b/textattack/models/tokenizers/t5_tokenizer.py @@ -38,7 +38,7 @@ def __init__(self, mode="english_to_german", max_length=64): self.tokenizer = transformers.AutoTokenizer.from_pretrained( "t5-base", use_fast=True ) - self.max_length = max_length + self.model_max_length = max_length def __call__(self, text, *args, **kwargs): """ @@ -55,7 +55,7 @@ def __call__(self, text, *args, **kwargs): else: for i in range(len(text)): text[i] = self.tokenization_prefix + text[i] - return self.tokenizer(text, *args, max_length=self.max_length, **kwargs) + return self.tokenizer(text, *args, **kwargs) def decode(self, ids): """Converts IDs (typically generated by the model) back to a string.""" From 6e60ae6a8d462f22a7cc01740b387d930383530d Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Wed, 2 Nov 2022 14:21:03 -0400 Subject: [PATCH 2/2] remove unnecessary import --- textattack/datasets/helpers/ted_multi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/textattack/datasets/helpers/ted_multi.py b/textattack/datasets/helpers/ted_multi.py index 39574019c..9e36c2694 100644 --- a/textattack/datasets/helpers/ted_multi.py +++ b/textattack/datasets/helpers/ted_multi.py @@ -11,7 +11,6 @@ import numpy as np from textattack.datasets import HuggingFaceDataset -from textattack.datasets.huggingface_dataset import get_datasets_dataset_columns class TedMultiTranslationDataset(HuggingFaceDataset):