Skip to content

Commit

Permalink
Update run_text2text.py (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric8932 authored Aug 22, 2023
1 parent 3b2ba28 commit 461373d
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions finetune/run_text2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def __init__(self, args):
tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab))
self.tgt_embedding.update(tmp_emb, embedding_name)
self.decoder = str2decoder[args.decoder](args)
self.target = LmTarget(args, len(args.tokenizer.vocab))
self.target = Target()
self.target.update(LmTarget(args, len(args.tokenizer.vocab)), "lm")
if args.tie_weights:
self.target.output_layer.weight = self.embedding.word.embedding.weight
self.target.lm.output_layer.weight = self.embedding.word.embedding.weight
if args.share_embedding:
self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight

Expand All @@ -40,28 +41,28 @@ def encode(self, src, seg):
memory_bank = self.encoder(emb, seg)
return memory_bank

def decode(self, src, memory_bank, tgt):
def decode(self, src, memory_bank, tgt, tgt_seg):
tgt_in, tgt_out, _ = tgt
decoder_emb = self.tgt_embedding(tgt_in, None)
decoder_emb = self.tgt_embedding(tgt_in, tgt_seg)
hidden = self.decoder(memory_bank, decoder_emb, (src,))
output = self.target.output_layer(hidden)
output = self.target.lm.output_layer(hidden)
return output

def forward(self, src, tgt, seg, memory_bank=None, only_use_encoder=False):
def forward(self, src, tgt, seg, tgt_seg, memory_bank=None, only_use_encoder=False):
if only_use_encoder:
return self.encode(src, seg)
if memory_bank is not None:
return self.decode(src, memory_bank, tgt)
return self.decode(src, memory_bank, tgt, tgt_seg)
tgt_in, tgt_out, _ = tgt
memory_bank = self.encode(src, seg)
output = self.decode(src, memory_bank, tgt)
if tgt_out is None:
output = self.decode(src, memory_bank, tgt)
return None, output
else:
decoder_emb = self.tgt_embedding(tgt_in, None)
decoder_emb = self.tgt_embedding(tgt_in, tgt_seg)
hidden = self.decoder(memory_bank, decoder_emb, (seg,))
loss = self.target(hidden, tgt_out, seg)[0]
return loss, output
loss = self.target(hidden, tgt_out, tgt_seg)[0]
return loss, None


def read_dataset(args, path):
Expand All @@ -84,12 +85,14 @@ def read_dataset(args, path):
tgt_in = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(label) + [SEP_TOKEN])
PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
seg = [1] * len(src)
tgt_seg = [1] * len(tgt_in)

if len(src) > args.seq_length:
src = src[: args.seq_length]
seg = seg[: args.seq_length]
if len(tgt_in) > args.tgt_seq_length:
tgt_in = tgt_in[: args.tgt_seq_length]
tgt_seg = tgt_seg[: args.tgt_seq_length]
tgt_out = tgt_in[1:] + [PAD_ID]

while len(src) < args.seq_length:
Expand All @@ -98,38 +101,42 @@ def read_dataset(args, path):
while len(tgt_in) < args.tgt_seq_length:
tgt_in.append(PAD_ID)
tgt_out.append(PAD_ID)
tgt_seg.append(0)

dataset.append((src, tgt_in, tgt_out, seg))
dataset.append((src, tgt_in, tgt_out, seg, tgt_seg))

return dataset


def batch_loader(batch_size, src, tgt_in, tgt_out, seg):
def batch_loader(batch_size, src, tgt_in, tgt_out, seg, tgt_seg):
instances_num = src.size()[0]
for i in range(instances_num // batch_size):
src_batch = src[i * batch_size : (i + 1) * batch_size, :]
tgt_in_batch = tgt_in[i * batch_size : (i + 1) * batch_size, :]
tgt_out_batch = tgt_out[i * batch_size : (i + 1) * batch_size, :]
seg_batch = seg[i * batch_size : (i + 1) * batch_size, :]
yield src_batch, tgt_in_batch, tgt_out_batch, seg_batch, None
tgt_seg_batch = tgt_seg[i * batch_size : (i + 1) * batch_size, :]
yield src_batch, tgt_in_batch, tgt_out_batch, seg_batch, tgt_seg_batch

if instances_num > instances_num // batch_size * batch_size:
src_batch = src[instances_num // batch_size * batch_size :, :]
tgt_in_batch = tgt_in[instances_num // batch_size * batch_size :, :]
tgt_out_batch = tgt_out[instances_num // batch_size * batch_size :, :]
seg_batch = seg[instances_num // batch_size * batch_size :, :]
yield src_batch, tgt_in_batch, tgt_out_batch, seg_batch, None
tgt_seg_batch = tgt_seg[instances_num // batch_size * batch_size :, :]
yield src_batch, tgt_in_batch, tgt_out_batch, seg_batch, tgt_seg_batch


def train_model(args, model, optimizer, scheduler, src_batch, tgt_in_batch, tgt_out_batch, seg_batch):
def train_model(args, model, optimizer, scheduler, src_batch, tgt_in_batch, tgt_out_batch, seg_batch, tgt_seg_batch):
model.zero_grad()

src_batch = src_batch.to(args.device)
tgt_in_batch = tgt_in_batch.to(args.device)
tgt_out_batch = tgt_out_batch.to(args.device)
seg_batch = seg_batch.to(args.device)
tgt_seg_batch = tgt_seg_batch.to(args.device)

loss, _ = model(src_batch, (tgt_in_batch, tgt_out_batch, src_batch), seg_batch)
loss, _ = model(src_batch, (tgt_in_batch, tgt_out_batch, src_batch), seg_batch, tgt_seg_batch)

if torch.cuda.device_count() > 1:
loss = torch.mean(loss)
Expand All @@ -152,31 +159,33 @@ def evaluate(args, dataset):
tgt_in = torch.LongTensor([example[1] for example in dataset])
tgt_out = torch.LongTensor([example[2] for example in dataset])
seg = torch.LongTensor([example[3] for example in dataset])
tgt_seg = torch.LongTensor([example[4] for example in dataset])

generated_sentences = []
args.model.eval()

for i, (src_batch, tgt_in_batch, tgt_out_batch, seg_batch, _) in enumerate(batch_loader(args.batch_size, src, tgt_in, tgt_out, seg)):
for i, (src_batch, tgt_in_batch, tgt_out_batch, seg_batch, tgt_seg_batch) in enumerate(batch_loader(args.batch_size, src, tgt_in, tgt_out, seg, tgt_seg)):

src_batch = src_batch.to(args.device)
tgt_in_batch = torch.zeros(tgt_in_batch.size()[0], 1, dtype=torch.long, device=args.device)
tgt_seg_batch = torch.ones(tgt_in_batch.size()[0], 1, dtype=torch.long, device=args.device)
for j in range(tgt_in_batch.size()[0]):
tgt_in_batch[j][-1] = args.tokenizer.vocab.get(CLS_TOKEN)

seg_batch = seg_batch.to(args.device)

with torch.no_grad():
memory_bank = args.model(src_batch, None, seg_batch, only_use_encoder=True)
memory_bank = args.model(src_batch, None, seg_batch, tgt_seg_batch, only_use_encoder=True)

for _ in range(args.tgt_seq_length):
tgt_out_batch = tgt_in_batch
with torch.no_grad():
outputs = args.model(src_batch, (tgt_in_batch, tgt_out_batch, src_batch), None, memory_bank=memory_bank)
outputs = args.model(src_batch, (tgt_in_batch, tgt_out_batch, src_batch), None, tgt_seg_batch, memory_bank=memory_bank)

next_token_logits = outputs[:, -1]
next_tokens = torch.argmax(next_token_logits, dim=1).unsqueeze(1)
tgt_in_batch = torch.cat([tgt_in_batch, next_tokens], dim=1)

tgt_seg_batch = torch.ones(tgt_in_batch.size()[0], tgt_in_batch.size()[1], dtype=torch.long, device=args.device)
for j in range(len(outputs)):
sentence = " ".join([args.tokenizer.inv_vocab[token_id.item()] for token_id in tgt_in_batch[j][1:]])
generated_sentences.append(sentence)
Expand Down Expand Up @@ -276,10 +285,11 @@ def main():
tgt_in = torch.LongTensor([example[1] for example in trainset])
tgt_out = torch.LongTensor([example[2] for example in trainset])
seg = torch.LongTensor([example[3] for example in trainset])
tgt_seg = torch.LongTensor([example[4] for example in trainset])

model.train()
for i, (src_batch, tgt_in_batch, tgt_out_batch, seg_batch, _) in enumerate(batch_loader(batch_size, src, tgt_in, tgt_out, seg)):
loss = train_model(args, model, optimizer, scheduler, src_batch, tgt_in_batch, tgt_out_batch, seg_batch)
for i, (src_batch, tgt_in_batch, tgt_out_batch, seg_batch, tgt_seg_batch) in enumerate(batch_loader(batch_size, src, tgt_in, tgt_out, seg, tgt_seg)):
loss = train_model(args, model, optimizer, scheduler, src_batch, tgt_in_batch, tgt_out_batch, seg_batch, tgt_seg_batch)
total_loss += loss.item()
if (i + 1) % args.report_steps == 0:
args.logger.info("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i + 1, total_loss / args.report_steps))
Expand Down

0 comments on commit 461373d

Please sign in to comment.