forked from fawazsammani/chatbot-transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
chat.py
60 lines (48 loc) · 2.01 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
import torch
from torch.utils.data import Dataset
import torch.utils.data
from models import *
from utils import *
# Script configuration: whether to restore the trained model from disk at
# start-up, and the path of the checkpoint file to read it from.
load_checkpoint, ckpt_path = True, 'checkpoint.pth.tar'
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Greedily decode a reply to ``question`` (batch size 1) and return it as text.

    Args:
        transformer: model exposing ``eval()``, ``encode(src, src_mask)``,
            ``decode(tgt, tgt_mask, memory, src_mask)`` and ``logit(x)``
            (presumably the Transformer from ``models`` -- confirm).
        question: LongTensor of token ids, built by the caller with shape
            ``(1, question_len)``; it must already live on the model's device.
        question_mask: mask passed unchanged to ``encode`` and ``decode``.
        max_len: maximum decoded sequence length including ``<start>``, so at
            most ``max_len - 1`` reply tokens are generated (none if <= 1).
        word_map: dict mapping word -> token id; must contain ``'<start>'``
            and ``'<end>'``.

    Returns:
        The generated words joined by single spaces. ``<start>`` tokens are
        dropped, and decoding stops before the first predicted ``<end>``.
    """
    rev_word_map = {idx: word for word, idx in word_map.items()}
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    # Create new tokens on the same device as the input, instead of relying on
    # a module-level `device` that this function does not define.
    device = question.device

    transformer.eval()
    # Pure inference: don't build an autograd graph at every decoding step.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.tensor([[start_token]], dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            size = words.shape[1]
            # Causal (lower-triangular) mask, shape (1, 1, size, size): position i
            # may only attend to positions <= i. Kept as uint8 as before, since
            # the model's mask handling is defined elsewhere.
            target_mask = torch.tril(torch.ones(size, size, device=device)).to(torch.uint8)
            target_mask = target_mask.unsqueeze(0).unsqueeze(0)

            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            next_tensor = torch.tensor([[next_word]], dtype=torch.long, device=device)
            words = torch.cat([words, next_tensor], dim=1)  # (1, step + 2)

    # `words` is always (1, n); flatten it to a plain list of ids for lookup.
    token_ids = words.squeeze(0).tolist()
    return ' '.join(rev_word_map[idx] for idx in token_ids if idx != start_token)
# Interactive chat loop: read a question, encode it with `word_map`, greedily
# decode a reply with `evaluate` and print it. Type 'quit' (or send EOF /
# Ctrl-C at a prompt) to exit.
# NOTE(review): `word_map` and `device` are not defined in this file; they are
# presumably provided by the star imports from `utils`/`models` -- confirm.
if load_checkpoint:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source. The checkpoint holds
    # a whole model object, so on PyTorch >= 2.6 (weights_only=True by default)
    # this may additionally need weights_only=False.
    # map_location places the model on `device`, matching the input tensors
    # built below even if the checkpoint was saved on another device.
    checkpoint = torch.load(ckpt_path, map_location=device)
    transformer = checkpoint['transformer']
# NOTE(review): if load_checkpoint is False, `transformer` must be defined
# elsewhere, otherwise the first question raises NameError.

while True:
    try:
        question = input("Question: ")
    except (EOFError, KeyboardInterrupt):
        break
    if question.strip() == 'quit':
        break
    # Unknown words fall back to the '<unk>' id.
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
    if not enc_qus:
        # An empty question would feed a zero-length sequence to the encoder.
        print("Please enter a non-empty question.")
        continue
    try:
        max_len = int(input("Maximum Reply Length: "))
    except ValueError:
        print("Maximum Reply Length must be an integer.")
        continue
    except (EOFError, KeyboardInterrupt):
        break
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)  # (1, question_len)
    # Non-pad positions (id != 0), reshaped to (1, 1, 1, question_len).
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, max_len, word_map)
    print(sentence)