Skip to content

Commit

Permalink
Update generate_lm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ydli-ai authored Jul 27, 2023
1 parent 5680879 commit 9a3acd9
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions scripts/generate_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def top_k_top_p_filtering(logits, top_k, top_p):

model = GenerateLm(args)
model = load_model(model, args.load_model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

model.eval()

with open(args.test_path, mode="r", encoding="utf-8") as f:
Expand All @@ -93,7 +96,7 @@ def top_k_top_p_filtering(logits, top_k, top_p):
if len(src) > args.seq_length:
src = src[:args.seq_length]
seg = seg[:args.seq_length]
src_tensor, seg_tensor = torch.LongTensor([src]), torch.LongTensor([seg])
src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)

with open(args.prediction_path, mode="w", encoding="utf-8") as f:
for i in range(args.seq_length - beginning_length):
Expand All @@ -103,10 +106,13 @@ def top_k_top_p_filtering(logits, top_k, top_p):
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]])], dim=1)
seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)

f.write(line + "\n")
generated_sentence = "".join(
args.tokenizer.convert_ids_to_tokens([token_id.item() for token_id in src_tensor[0]])
)
tokens = [token_id.item() for token_id in src_tensor[0]]
if args.tokenizer.sp_model is not None:
generated_sentence = args.tokenizer.sp_model.decode(tokens)
else:
generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens))

f.write(generated_sentence)

0 comments on commit 9a3acd9

Please sign in to comment.