Skip to content

Commit

Permalink
Merge pull request #701 from QData/oct-bug-fixes
Browse files Browse the repository at this point in the history
Fix bugs with t5
  • Loading branch information
jxmorris12 authored Nov 2, 2022
2 parents 5ac125a + 6e60ae6 commit bb56f61
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
13 changes: 10 additions & 3 deletions textattack/datasets/helpers/ted_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,19 @@ def __init__(self, source_lang="en", target_lang="de", split="test", shuffle=Fal
self.source_lang = source_lang
self.target_lang = target_lang
self.shuffled = shuffle
self.label_map = None
self.output_scale_factor = None
self.label_names = None
# self.input_columns = ("Source",)
# self.output_column = "Translation"

if shuffle:
self._dataset.shuffle()

def _format_raw_example(self, raw_example):
translations = np.array(raw_example["translation"])
languages = np.array(raw_example["language"])
def _format_as_dict(self, raw_example):
example = raw_example["translations"]
translations = np.array(example["translation"])
languages = np.array(example["language"])
source = translations[languages == self.source_lang][0]
target = translations[languages == self.target_lang][0]
source_dict = collections.OrderedDict([("Source", source)])
Expand Down
6 changes: 5 additions & 1 deletion textattack/goal_functions/text/text_to_text_goal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
-------------------------------------------------------
"""

import numpy as np

from textattack.goal_function_results import TextToTextGoalFunctionResult
from textattack.goal_functions import GoalFunction
Expand All @@ -22,7 +23,10 @@ def _goal_function_result_type(self):

def _process_model_outputs(self, _, outputs):
"""Processes and validates a list of model outputs."""
return outputs.flatten()
if isinstance(outputs, np.ndarray):
return outputs.flatten()
else:
return outputs

def _get_displayed_output(self, raw_output):
return raw_output
4 changes: 2 additions & 2 deletions textattack/models/tokenizers/t5_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, mode="english_to_german", max_length=64):
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
"t5-base", use_fast=True
)
self.max_length = max_length
self.model_max_length = max_length

def __call__(self, text, *args, **kwargs):
"""
Expand All @@ -55,7 +55,7 @@ def __call__(self, text, *args, **kwargs):
else:
for i in range(len(text)):
text[i] = self.tokenization_prefix + text[i]
return self.tokenizer(text, *args, max_length=self.max_length, **kwargs)
return self.tokenizer(text, *args, **kwargs)

def decode(self, ids):
"""Converts IDs (typically generated by the model) back to a string."""
Expand Down

0 comments on commit bb56f61

Please sign in to comment.