-
Notifications
You must be signed in to change notification settings - Fork 0
/
collator.py
74 lines (65 loc) · 3.39 KB
/
collator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import warnings
from typing import Any, Dict, List, Union
import numpy as np
from transformers.data.data_collator import DataCollatorForLanguageModeling
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
"""
Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
when they do not come from the assistant. This ensure that the loss is only
calculated on the completion made by the assistant.
Args:
response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
'### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
differently if it does not have proper context.
mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
for flexibility and backwards-compatibility.
ignore_index (`int`, *optional*, defaults to `-100`):
The index to use to ignore the initial tokens with
"""
def __init__(
self,
response_template: Union[str, List[int]],
*args,
mlm: bool = False,
ignore_index: int = -100,
**kwargs,
):
super().__init__(*args, mlm=mlm, **kwargs)
self.response_template = response_template
if isinstance(response_template, str):
# The user provides a string, must tokenize
self.response_token_ids = self.tokenizer.encode(
self.response_template, add_special_tokens=False
)
else:
# The user already provides the token ids
self.response_token_ids = response_template
self.ignore_index = ignore_index
def torch_call(self, examples: List[Union[List, Any, Dict]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
for i in range(len(examples)):
response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
if (
self.response_token_ids
== batch["labels"][i][
idx : idx + len(self.response_token_ids) # noqa: E203
].tolist()
):
response_token_ids_start_idx = idx
if response_token_ids_start_idx is None:
warnings.warn(
f"Could not find response key `{self.response_template}` in the instance. "
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
else:
response_token_ids_end_idx = response_token_ids_start_idx + len(
self.response_token_ids
)
# Make pytorch loss function ignore all tokens up through the end of the response key
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
return batch