run_chitchat_gpt.py
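# Minimal console chit-chat demo: the accumulated dialog history is fed to the
# rugpt_chitchat causal LM and a sampled continuation is returned as the bot's reply.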
import os.path

import torch
import transformers


class Chitchat(object):
    """Chit-chat reply generator on top of a locally stored rugpt_chitchat causal LM."""

    def __init__(self, device, models_dir):
        model_name = os.path.join(models_dir, 'rugpt_chitchat')
        self.device = device
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()
    def reply(self, history, num_return_sequences):
        # The prompt is the dialog history followed by the 'чатбот:' cue,
        # so the model continues with the bot's next utterance.
        prompt = '<s>' + '\n'.join(history) + '\nчатбот:'
        encoded_prompt = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(self.device)
        output_sequences = self.model.generate(input_ids=encoded_prompt,
                                               max_length=len(prompt) + 120,
                                               temperature=0.90,
                                               typical_p=None,
                                               top_k=0,
                                               top_p=0.8,
                                               do_sample=True,
                                               num_return_sequences=num_return_sequences,
                                               pad_token_id=self.tokenizer.pad_token_id)
        replies = []
        for o in output_sequences:
            reply = self.tokenizer.decode(o.tolist(), clean_up_tokenization_spaces=True)
            reply = reply[len(prompt):]  # cut off the prompt
            reply = reply[: reply.find('</s>')]
            if '\nчеловек:' in reply:
                # Drop any continuation where the model starts generating the human's next turn.
                reply = reply[:reply.index('\nчеловек:')]
            reply = reply.strip()
            if reply not in replies:  # keep only unique replies, preserving generation order
                replies.append(reply)

        return replies
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models_dir = os.path.expanduser('~/polygon/chatbot/tmp')

    chitchat = Chitchat(device, models_dir)

    while True:
        dialog = []
        while True:
            msg = input('H:> ').strip()
            if msg:
                dialog.append('человек: ' + msg)
                reply = chitchat.reply(dialog, num_return_sequences=1)[0]
                print(f'B:> {reply}')
                dialog.append('чатбот: ' + reply)
            else:
                # An empty input line resets the dialog history.
                dialog = []
                print('-' * 100)