TaskForPretraining.py
import os
import logging
import sys
sys.path.append('../')
from utils import logger_init
from model import BertConfig
from model import BertForPretrainingModel
from utils import LoadBertPretrainingDataset
from transformers import BertTokenizer
from transformers import AdamW
from transformers import get_polynomial_decay_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy
import torch
import time
class ModelConfig:
    def __init__(self):
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        # ========== WikiText2 dataset configuration
        # self.dataset_dir = os.path.join(self.project_dir, 'data', 'WikiText')
        # self.pretrained_model_dir = os.path.join(self.project_dir, "bert_base_uncased_english")
        # self.train_file_path = os.path.join(self.dataset_dir, 'wiki.train.tokens')
        # self.val_file_path = os.path.join(self.dataset_dir, 'wiki.valid.tokens')
        # self.test_file_path = os.path.join(self.dataset_dir, 'wiki.test.tokens')
        # self.data_name = 'wiki2'
        # ========== SongCi dataset configuration
        self.dataset_dir = os.path.join(self.project_dir, 'data', 'SongCi')
        self.pretrained_model_dir = os.path.join(self.project_dir, "bert_base_chinese")
        self.train_file_path = os.path.join(self.dataset_dir, 'songci.train.txt')
        self.val_file_path = os.path.join(self.dataset_dir, 'songci.valid.txt')
        self.test_file_path = os.path.join(self.dataset_dir, 'songci.test.txt')
        self.data_name = 'songci'
        # To switch datasets, only the block above needs to be changed.
        self.vocab_path = os.path.join(self.pretrained_model_dir, 'vocab.txt')
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_save_dir = os.path.join(self.project_dir, 'cache')
        self.logs_save_dir = os.path.join(self.project_dir, 'logs')
        self.model_save_path = os.path.join(self.model_save_dir, f'model_{self.data_name}.bin')
        self.writer = SummaryWriter(f"runs/{self.data_name}")
        self.is_sample_shuffle = True
        self.use_embedding_weight = True
        self.batch_size = 16
        self.max_sen_len = None  # if None, each batch is padded to the length of its longest sample
        self.pad_index = 0
        self.random_state = 2022
        self.learning_rate = 4e-5
        self.weight_decay = 0.1
        self.masked_rate = 0.15  # proportion of tokens selected for MLM prediction
        self.masked_token_rate = 0.8  # of the selected tokens, 80% are replaced with [MASK]
        self.masked_token_unchanged_rate = 0.5  # of the rest, half stay unchanged, half become random tokens
        self.log_level = logging.DEBUG
        self.use_torch_multi_head = False  # False: use the multi-head implementation in model/BasicBert/MyTransformer
        self.epochs = 200
        self.model_val_per_epoch = 1
        logger_init(log_file_name=self.data_name, log_level=self.log_level,
                    log_dir=self.logs_save_dir)
        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)
        bert_config_path = os.path.join(self.pretrained_model_dir, "config.json")
        bert_config = BertConfig.from_json_file(bert_config_path)
        for key, value in bert_config.__dict__.items():
            self.__dict__[key] = value
        # Log the current configuration to the log file
        logging.info(" ### Logging the current configuration to the log file ")
        for key, value in self.__dict__.items():
            logging.info(f"### {key} = {value}")

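# Note: after ModelConfig.__init__ runs, every field from the pretrained model's config.json
# (e.g. max_position_embeddings, hidden_size) is also available as an attribute of the config
# object, which is why train() below can pass config.max_position_embeddings to the data loader.
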
def train(config):
    model = BertForPretrainingModel(config,
                                    config.pretrained_model_dir)
    last_epoch = -1
    if os.path.exists(config.model_save_path):
        checkpoint = torch.load(config.model_save_path)
        last_epoch = checkpoint['last_epoch']
        loaded_paras = checkpoint['model_state_dict']
        model.load_state_dict(loaded_paras)
        logging.info("## Successfully loaded the existing model, continuing training......")
    model = model.to(config.device)
    model.train()
    bert_tokenize = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    data_loader = LoadBertPretrainingDataset(vocab_path=config.vocab_path,
                                             tokenizer=bert_tokenize,
                                             batch_size=config.batch_size,
                                             max_sen_len=config.max_sen_len,
                                             max_position_embeddings=config.max_position_embeddings,
                                             pad_index=config.pad_index,
                                             is_sample_shuffle=config.is_sample_shuffle,
                                             random_state=config.random_state,
                                             data_name=config.data_name,
                                             masked_rate=config.masked_rate,
                                             masked_token_rate=config.masked_token_rate,
                                             masked_token_unchanged_rate=config.masked_token_unchanged_rate)
    train_iter, test_iter, val_iter = \
        data_loader.load_train_val_test_data(test_file_path=config.test_file_path,
                                             train_file_path=config.train_file_path,
                                             val_file_path=config.val_file_path)
    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": config.weight_decay,
            "initial_lr": config.learning_rate
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
            "initial_lr": config.learning_rate
        },
    ]
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = AdamW(optimizer_grouped_parameters)
    scheduler = get_polynomial_decay_schedule_with_warmup(optimizer,
                                                          num_warmup_steps=int(len(train_iter) * 0),
                                                          num_training_steps=int(config.epochs * len(train_iter)),
                                                          last_epoch=last_epoch)
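    # The warmup argument evaluates to 0 here (len(train_iter) * 0), so the learning rate simply
    # decays polynomially over config.epochs * len(train_iter) steps; passing last_epoch lets the
    # schedule resume from the step count stored in the checkpoint when training is continued.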
    max_acc = 0
    state_dict = None
    for epoch in range(config.epochs):
        losses = 0
        start_time = time.time()
        for idx, (b_token_ids, b_segs, b_mask, b_mlm_label, b_nsp_label) in enumerate(train_iter):
            b_token_ids = b_token_ids.to(config.device)  # [src_len, batch_size]
            b_segs = b_segs.to(config.device)
            b_mask = b_mask.to(config.device)
            b_mlm_label = b_mlm_label.to(config.device)
            b_nsp_label = b_nsp_label.to(config.device)
            loss, mlm_logits, nsp_logits = model(input_ids=b_token_ids,
                                                 attention_mask=b_mask,
                                                 token_type_ids=b_segs,
                                                 masked_lm_labels=b_mlm_label,
                                                 next_sentence_labels=b_nsp_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            losses += loss.item()
            mlm_acc, _, _, nsp_acc, _, _ = accuracy(mlm_logits, nsp_logits, b_mlm_label,
                                                    b_nsp_label, data_loader.PAD_IDX)
            if idx % 20 == 0:
                logging.info(f"Epoch: [{epoch + 1}/{config.epochs}], Batch[{idx}/{len(train_iter)}], "
                             f"Train loss: {loss.item():.3f}, Train mlm acc: {mlm_acc:.3f}, "
                             f"nsp acc: {nsp_acc:.3f}")
                config.writer.add_scalar('Training/Loss', loss.item(), scheduler.last_epoch)
                config.writer.add_scalar('Training/Learning Rate', scheduler.get_last_lr()[0], scheduler.last_epoch)
                config.writer.add_scalars(main_tag='Training/Accuracy',
                                          tag_scalar_dict={'NSP': nsp_acc,
                                                           'MLM': mlm_acc},
                                          global_step=scheduler.last_epoch)
        end_time = time.time()
        train_loss = losses / len(train_iter)
        logging.info(f"Epoch: [{epoch + 1}/{config.epochs}], Train loss: "
                     f"{train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")
        if (epoch + 1) % config.model_val_per_epoch == 0:
            mlm_acc, nsp_acc = evaluate(config, val_iter, model, data_loader.PAD_IDX)
            logging.info(f" ### MLM Accuracy on val: {round(mlm_acc, 4)}, "
                         f"NSP Accuracy on val: {round(nsp_acc, 4)}")
            config.writer.add_scalars(main_tag='Testing/Accuracy',
                                      tag_scalar_dict={'NSP': nsp_acc,
                                                       'MLM': mlm_acc},
                                      global_step=scheduler.last_epoch)
            # mlm_acc, nsp_acc = evaluate(config, train_iter, model, data_loader.PAD_IDX)
            if mlm_acc > max_acc:
                max_acc = mlm_acc
                state_dict = deepcopy(model.state_dict())
            torch.save({'last_epoch': scheduler.last_epoch,
                        'model_state_dict': state_dict},
                       config.model_save_path)

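# The checkpoint written by train() is a plain dict of the form
# {'last_epoch': <scheduler step count>, 'model_state_dict': <model.state_dict()>};
# both the resume branch at the top of train() and inference() below read it back via torch.load.
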
def accuracy(mlm_logits, nsp_logits, mlm_labels, nsp_label, PAD_IDX):
    """
    :param mlm_logits: [src_len, batch_size, src_vocab_size]
    :param mlm_labels: [src_len, batch_size]
    :param nsp_logits: [batch_size, 2]
    :param nsp_label: [batch_size]
    :param PAD_IDX:
    :return:
    """
    mlm_pred = mlm_logits.transpose(0, 1).argmax(axis=2).reshape(-1)
    # reshape [src_len, batch_size, src_vocab_size] into [batch_size, src_len, src_vocab_size]
    mlm_true = mlm_labels.transpose(0, 1).reshape(-1)
    # reshape [src_len, batch_size] into [batch_size, src_len]
    mlm_acc = mlm_pred.eq(mlm_true)  # compare predictions with labels; at this point every position is counted, including non-masked ones
    mask = torch.logical_not(mlm_true.eq(PAD_IDX))  # True at positions with real MLM labels (the masked positions), False at padded/unmasked positions
    mlm_acc = mlm_acc.logical_and(mask)  # keep correctness only at the masked positions
    mlm_correct = mlm_acc.sum().item()
    mlm_total = mask.sum().item()
    mlm_acc = float(mlm_correct) / mlm_total
    nsp_correct = (nsp_logits.argmax(1) == nsp_label).float().sum()
    nsp_total = len(nsp_label)
    nsp_acc = float(nsp_correct) / nsp_total
    return [mlm_acc, mlm_correct, mlm_total, nsp_acc, nsp_correct, nsp_total]

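# Label convention (assumed from how accuracy() filters above): the MLM labels produced by
# LoadBertPretrainingDataset hold PAD_IDX at every position that was not selected for masking and
# the original token id at the masked positions, so only the predicted positions are scored.
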
def evaluate(config, data_iter, model, PAD_IDX):
    model.eval()
    mlm_corrects, mlm_totals, nsp_corrects, nsp_totals = 0, 0, 0, 0
    with torch.no_grad():
        for idx, (b_token_ids, b_segs, b_mask, b_mlm_label, b_nsp_label) in enumerate(data_iter):
            b_token_ids = b_token_ids.to(config.device)  # [src_len, batch_size]
            b_segs = b_segs.to(config.device)
            b_mask = b_mask.to(config.device)
            b_mlm_label = b_mlm_label.to(config.device)
            b_nsp_label = b_nsp_label.to(config.device)
            mlm_logits, nsp_logits = model(input_ids=b_token_ids,
                                           attention_mask=b_mask,
                                           token_type_ids=b_segs)
            result = accuracy(mlm_logits, nsp_logits, b_mlm_label, b_nsp_label, PAD_IDX)
            _, mlm_cor, mlm_tot, _, nsp_cor, nsp_tot = result
            mlm_corrects += mlm_cor
            mlm_totals += mlm_tot
            nsp_corrects += nsp_cor
            nsp_totals += nsp_tot
    model.train()
    return [float(mlm_corrects) / mlm_totals, float(nsp_corrects) / nsp_totals]

def inference(config, sentences=None, masked=False, language='en', random_state=None):
    """
    :param config:
    :param sentences:
    :param masked: whether the input sentences are masked
    :param language: language of the input sentences ('en' or 'zh')
    :param random_state: random state controlling which tokens get masked
    :return:
    """
    bert_tokenize = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    data_loader = LoadBertPretrainingDataset(vocab_path=config.vocab_path,
                                             tokenizer=bert_tokenize,
                                             pad_index=config.pad_index,
                                             random_state=config.random_state,
                                             masked_rate=0.15)  # mask 15% of the tokens
    token_ids, pred_idx, mask = data_loader.make_inference_samples(sentences,
                                                                   masked=masked,
                                                                   language=language,
                                                                   random_state=random_state)
    model = BertForPretrainingModel(config,
                                    config.pretrained_model_dir)
    if os.path.exists(config.model_save_path):
        checkpoint = torch.load(config.model_save_path)
        loaded_paras = checkpoint['model_state_dict']
        model.load_state_dict(loaded_paras)
        logging.info("## Successfully loaded the existing model for inference......")
    else:
        raise ValueError(f"Model {config.model_save_path} does not exist!")
    model = model.to(config.device)
    model.eval()
    with torch.no_grad():
        token_ids = token_ids.to(config.device)  # [src_len, batch_size]
        mask = mask.to(config.device)
        mlm_logits, _ = model(input_ids=token_ids,
                              attention_mask=mask)
    pretty_print(token_ids, mlm_logits, pred_idx,
                 data_loader.vocab.itos, sentences, language)

def pretty_print(token_ids, logits, pred_idx, itos, sentences, language):
    """
    Pretty-print the prediction results.
    :param token_ids: [src_len, batch_size]
    :param logits: [src_len, batch_size, vocab_size]
    :param pred_idx: a nested list; each inner list records the masked positions in the original sentence
    :param itos:
    :param sentences: the original sentences
    :return:
    """
    token_ids = token_ids.transpose(0, 1)  # [batch_size, src_len]
    logits = logits.transpose(0, 1)  # [batch_size, src_len, vocab_size]
    y_pred = logits.argmax(axis=2)  # [batch_size, src_len]
    sep = " " if language == 'en' else ""
    for token_id, sentence, y, y_idx in zip(token_ids, sentences, y_pred, pred_idx):
        sen = [itos[id] for id in token_id]
        sen_mask = sep.join(sen).replace(" ##", "").replace("[PAD]", "").replace(" ,", ",")
        sen_mask = sen_mask.replace(" .", ".").replace("[SEP]", "").replace("[CLS]", "").lstrip()
        logging.info(f" ### Original: {sentence}")
        logging.info(f" ## Masked: {sen_mask}")
        for idx in y_idx:
            sen[idx] = itos[y[idx]].replace("##", "")
        sen = sep.join(sen).replace("[PAD]", "").replace(" ,", ",")
        sen = sen.replace(" .", ".").replace("[SEP]", "").replace("[CLS]", "").lstrip()
        logging.info(f" ## Predicted: {sen}")
        logging.info("===============")

if __name__ == '__main__':
    config = ModelConfig()
    train(config)
    sentences_1 = ["I no longer love her, true, but perhaps I love her.",
                   "Love is so short and oblivion so long."]
    sentences_2 = ["十年生死两茫茫。不思量。自难忘。千里孤坟,无处话凄凉。",
                   "红酥手。黄藤酒。满园春色宫墙柳。"]
    inference(config, sentences_2, masked=False, language='zh', random_state=2022)
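    # Usage sketch (assumption, not executed in the original script): with the commented-out
    # WikiText2 configuration enabled in ModelConfig, the English sentences_1 defined above could
    # be run through the same pipeline, e.g.
    # inference(config, sentences_1, masked=False, language='en', random_state=2022)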