From 9a3acd978c279e437bae488db1379cc7910ca0e4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Li=20Yudong=20=28=E6=9D=8E=E7=85=9C=E4=B8=9C=29?=
Date: Thu, 27 Jul 2023 16:22:57 +0800
Subject: [PATCH] Update generate_lm.py

---
 scripts/generate_lm.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/scripts/generate_lm.py b/scripts/generate_lm.py
index 07511f6b..daeab336 100644
--- a/scripts/generate_lm.py
+++ b/scripts/generate_lm.py
@@ -83,6 +83,9 @@ def top_k_top_p_filtering(logits, top_k, top_p):
 
     model = GenerateLm(args)
     model = load_model(model, args.load_model_path)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
+
     model.eval()
 
     with open(args.test_path, mode="r", encoding="utf-8") as f:
@@ -93,7 +96,7 @@ def top_k_top_p_filtering(logits, top_k, top_p):
         if len(src) > args.seq_length:
             src = src[:args.seq_length]
             seg = seg[:args.seq_length]
-    src_tensor, seg_tensor = torch.LongTensor([src]), torch.LongTensor([seg])
+    src_tensor, seg_tensor = torch.LongTensor([src]).to(device), torch.LongTensor([seg]).to(device)
 
     with open(args.prediction_path, mode="w", encoding="utf-8") as f:
         for i in range(args.seq_length - beginning_length):
@@ -103,10 +106,13 @@ def top_k_top_p_filtering(logits, top_k, top_p):
             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
 
             src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
-            seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]])], dim=1)
+            seg_tensor = torch.cat([seg_tensor, torch.tensor([[1]]).to(device)], dim=1)
 
         f.write(line + "\n")
-        generated_sentence = "".join(
-            args.tokenizer.convert_ids_to_tokens([token_id.item() for token_id in src_tensor[0]])
-        )
+        tokens = [token_id.item() for token_id in src_tensor[0]]
+        if args.tokenizer.sp_model is not None:
+            generated_sentence = args.tokenizer.sp_model.decode(tokens)
+        else:
+            generated_sentence = "".join(args.tokenizer.convert_ids_to_tokens(tokens))
+
        f.write(generated_sentence)
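
Note: the patch makes two changes: (1) it places the model and every tensor fed to it on the same device (CUDA when available, otherwise CPU), and (2) it decodes the generated ids through the tokenizer's SentencePiece backend when one is present, instead of naively joining token strings. Below is a minimal, self-contained sketch of that pattern, assuming a generic PyTorch model called as model(src_tensor, seg_tensor) and a tokenizer exposing sp_model / convert_ids_to_tokens; the generate() helper and its signature are illustrative, not the script's actual interface.

    import torch
    import torch.nn.functional as F

    def generate(model, tokenizer, src_ids, seq_length, temperature=1.0):
        # Pick GPU when available, otherwise fall back to CPU, and move the
        # model's weights there once before generation starts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.eval()

        # Inputs must live on the same device as the weights.
        src_tensor = torch.LongTensor([src_ids]).to(device)
        seg_tensor = torch.ones_like(src_tensor)

        with torch.no_grad():
            for _ in range(seq_length - src_tensor.size(1)):
                logits = model(src_tensor, seg_tensor)          # assumed call signature
                next_token_logits = logits[0][-1] / temperature
                next_token = torch.multinomial(
                    F.softmax(next_token_logits, dim=-1), num_samples=1
                )
                # New tokens/segments are created on `device` so torch.cat
                # never mixes CPU and CUDA tensors.
                src_tensor = torch.cat([src_tensor, next_token.view(1, 1)], dim=1)
                seg_tensor = torch.cat(
                    [seg_tensor, torch.tensor([[1]], device=device)], dim=1
                )

        token_ids = [t.item() for t in src_tensor[0]]
        # SentencePiece merges subword pieces itself; joining raw pieces with ""
        # would leave "▁" markers in the text, hence the branch.
        if getattr(tokenizer, "sp_model", None) is not None:
            return tokenizer.sp_model.decode(token_ids)
        return "".join(tokenizer.convert_ids_to_tokens(token_ids))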