# data.py — OpenAssistant dataset loading utilities.
import os
from typing import Any
import jsonlines
import tqdm
from datasets import Dataset
from datasets import Split
from transformers import AutoTokenizer
def load_data(
    path: str,
    tokenizer: AutoTokenizer,
    split: str = 'train',
    return_answers: bool = False,
    max_size: int | None = None,
    max_token: int = 1024,
) -> Dataset:
    """Load the OpenAssistant dataset.

    Args:
        path: Directory containing the ``{split}.jsonl`` files.
        tokenizer: Tokenizer used to encode the formatted prompts.
        split: Split to load ('train', 'test' or 'all').
        return_answers: Whether to also include the reference answers
            ('openassistant-answer' / 'chatgpt-answer') as columns.
        max_size: Maximum number of examples to load.
            Defaults to None (load everything).
        max_token: Examples whose tokenized prompt exceeds this length
            are filtered out. Defaults to 1024.

    Returns:
        Dataset: The dataset with 'query' and 'input_ids' columns (plus
        the two answer columns when ``return_answers`` is True),
        formatted as torch tensors.

    Raises:
        ValueError: If ``split`` is not 'train', 'test' or 'all'.
        FileNotFoundError: If the jsonl file does not exist.
    """
    # `assert` is stripped under `python -O`; validate explicitly.
    if split not in ('train', 'test', 'all'):
        raise ValueError('split must be either train, test or all.')
    path = os.path.join(path, f'{split}.jsonl')
    if not os.path.exists(path):
        raise FileNotFoundError(f'{path} does not exist.')
    with jsonlines.open(path) as reader:
        data = list(reader)
    # Slice up front so the tqdm total and the loop always agree
    # (the original mid-loop `break` made the progress total wrong).
    if max_size is not None:
        data = data[:max_size]
    prompts, input_ids = [], []
    oa_ans, cgpt_ans = [], []
    # Prompt template expected by OpenAssistant-style chat models.
    qa_prompt: str = '<|prompter|>{}<|endoftext|><|assistant|>'
    for obj in tqdm.tqdm(data, total=len(data), desc='Loading data'):
        prompt = qa_prompt.format(obj['prompt'])
        prompts.append(prompt)
        tokenized_prompt = tokenizer(prompt, truncation=True)
        input_ids.append(tokenized_prompt['input_ids'])
        if return_answers:
            oa_ans.append(obj['openassistant-answer'])
            cgpt_ans.append(obj['chatgpt-answer'])
    mapping = {
        'query': prompts,
        'input_ids': input_ids,
    }
    if return_answers:
        mapping['openassistant-answer'] = oa_ans
        mapping['chatgpt-answer'] = cgpt_ans
    # 'all' is materialized as a single TRAIN split.
    split = 'train' if split == 'all' else split
    ds = Dataset.from_dict(
        mapping,
        split=Split.TRAIN if split == 'train' else Split.TEST,
    )
    # Drop prompts longer than the allowed context length.
    ds = ds.filter(lambda x: len(x['input_ids']) <= max_token, batched=False)
    ds.set_format(type='torch')  # , columns=['input_ids'])
    return ds
def get_tokenizer(
    tokenizer_name: str,
    pad_token_as_eos: bool = True,
    padding_side: str | None = None,
) -> AutoTokenizer:
    """Get the tokenizer.

    Args:
        tokenizer_name: Name of the pretrained tokenizer.
        pad_token_as_eos: Whether to use the eos token as the pad token
            when the tokenizer defines no pad token. Defaults to True.
        padding_side: Optional padding-side override ('left' or 'right').
            Defaults to None (keep the tokenizer's own setting).

    Returns:
        AutoTokenizer: The configured tokenizer.

    Raises:
        ValueError: If ``padding_side`` is neither 'left' nor 'right'.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Many causal-LM tokenizers (GPT-style) ship without a pad token.
    if pad_token_as_eos and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if padding_side is not None:
        # `assert` is stripped under `python -O`; validate explicitly.
        if padding_side not in ('left', 'right'):
            raise ValueError('padding_side must be either left or right.')
        tokenizer.padding_side = padding_side
    return tokenizer
# def collator(
# tokenizer: AutoTokenizer,
# max_token: int = 1024,
# ) -> Callable[..., Any]:
# """Collator function for the dataset.
#
# Args:
# tokenizer: The tokenizer.
# max_token: Maximum number of tokens in the input.
# Defaults to 1024.
#
# Returns:
# callable: The collator function.
#
# """
# def collate_fn(batch: list[dict[str, Any]]) -> list[dict[str, Any]]:
# input_ids = [obj['input_ids'] for obj in batch]
# input_ids = tokenizer.pad(
# input_ids,
# padding=True,
# max_length=max_token,
# return_tensors='pt',
# )
#
# return input_ids
#
# return collate_fn
def collator(data: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
    """Transpose a batch from row-oriented to column-oriented form.

    Args:
        data (list[dict[str, Any]]): List of per-example dictionaries;
            every dictionary is expected to share the keys of the first.

    Returns:
        dict: Mapping from each key to the list of its values across
        the whole batch, in input order.
    """
    batched: dict[str, list[Any]] = {}
    for column in data[0]:
        batched[column] = [example[column] for example in data]
    return batched