This repository has been archived by the owner on Sep 15, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
352 lines (262 loc) · 9.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
# Standard library
from datetime import datetime
import os
from time import time
from typing import List, Tuple

# Third-party
import torch
from torch import nn, optim
# `from torch.tensor import Tensor` relied on a private module that newer
# torch releases removed; `Tensor` is exported from the top-level package.
from torch import Tensor

# Local
from corpus import corpus, dictionary
from model import rnn
from utils import duration_since, plot
def init_corpus() -> None:
    '''
    Initialize the global corpus and load the yearly datasets from JSON.
    '''
    global cp
    cp = corpus()
    cp.get_all_text_data(all_in_one=False)
    # Read the three yearly datasets into the corpus.
    for year_name in ('2019', '2020', '2021'):
        cp.read_data(year_name)
    print(f'Dictionary size: {cp.dictionary.len()}')
def init_model() -> None:
    '''
    Build the global RNN model, loss criterion and optimizer.
    Previously saved weights, if any, are loaded into the model.
    '''
    global m, criterion, optimizer
    vocab_size: int = cp.dictionary.len()
    # Input and output sizes both equal the dictionary size.
    m = rnn(vocab_size, hidden_size, vocab_size, num_layers, dropout).to(device)
    load_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(m.parameters(), learning_rate)
def words_to_tensor(words: List[str]) -> Tensor:
    '''
    Encode a word list as a tensor of dictionary indices.
    Words missing from the dictionary map to a random in-range index.
    :param words: a preprocessed word list of the sentence
    '''
    encoded: Tensor = torch.zeros(len(words), device=device).long()
    for pos, word in enumerate(words):
        # Fallback index for out-of-vocabulary words.
        fallback: int = torch.randint(
            cp.dictionary.start_pos, cp.dictionary.len(), (1,),
        )[0]
        encoded[pos] = cp.dictionary.word2idx.get(word, fallback)
    return encoded
def get_random_words(count: int = 1, dataset: str = 'dev') -> List[str]:
    '''
    Return a random contiguous run of words from the chosen dataset.
    :param count: how many words are required
    :param dataset: which dataset, can be `'train'`, `'dev'` or `'test'`
    '''
    # Anything other than 'dev'/'test' falls back to the training set.
    sources = {'dev': cp.dev_set, 'test': cp.test_set}
    src = sources.get(dataset, cp.train_set)
    upper: int = len(src) - count
    start: int = torch.randint(0, upper, (1,))[0]
    return src[start:start + count]
def get_random_pair(dataset: str = 'train') -> Tuple[Tensor, Tensor]:
    '''
    Return a random (input, target) tensor pair from the chosen dataset.
    The target is the same word window shifted right by one position.
    :param dataset: which dataset, can be `'train'`, `'dev'` or `'test'`
    '''
    # Anything other than 'dev'/'test' falls back to the training set.
    sources = {'dev': cp.dev_set, 'test': cp.test_set}
    src = sources.get(dataset, cp.train_set)
    upper: int = len(src) - chunk_size
    start: int = torch.randint(0, upper, (1,))[0]
    inp: Tensor = words_to_tensor(src[start:start + chunk_size])
    tar: Tensor = words_to_tensor(src[start + 1:start + 1 + chunk_size])
    return inp, tar
def train(inp: Tensor, tar: Tensor) -> float:
    '''
    Run one training step over a single (input, target) chunk.
    Return the average per-word loss.
    :param inp: input tensor
    :param tar: target tensor
    '''
    m.train()
    m.zero_grad()
    hid: Tensor = m.init_hidden()
    loss = 0
    # Feed the chunk one word at a time, accumulating the loss.
    for step in range(inp.size(0)):
        out, hid = m(inp[step], hid)
        loss = loss + criterion(out, tar[step].view(-1))
    loss.backward()
    # Clip gradient norm to guard against exploding gradients in the RNN.
    nn.utils.clip_grad_norm_(m.parameters(), clip)
    optimizer.step()
    return loss.item() / chunk_size
def validate(inp: Tensor, tar: Tensor) -> float:
    '''
    Compute the loss on one (input, target) chunk without weight updates.
    Return the average per-word loss.
    :param inp: input tensor
    :param tar: target tensor
    '''
    m.eval()
    hid: Tensor = m.init_hidden()
    total = 0
    with torch.no_grad():
        # Same word-by-word pass as train(), minus backprop.
        for step in range(inp.size(0)):
            out, hid = m(inp[step], hid)
            total = total + criterion(out, tar[step].view(-1))
    return total.item() / chunk_size
def train_model() -> Tuple[List[float], List[float]]:
    '''
    The main training loop over `num_epochs` random chunks.
    Return the averaged training- and validation-loss curves.
    '''
    train_curve: List[float] = []
    valid_curve: List[float] = []
    train_sum: float = 0.0
    valid_sum: float = 0.0
    best_valid: float = 4.0
    for epoch in range(1, num_epochs + 1):
        t_loss: float = train(*get_random_pair('train'))
        v_loss: float = validate(*get_random_pair('dev'))
        train_sum += t_loss
        valid_sum += v_loss
        # Checkpoint whenever validation loss improves on the best seen.
        if v_loss < best_valid:
            save_model(v_loss)
            best_valid = v_loss
        if epoch % print_every == 0:
            progress: float = epoch / num_epochs * 100
            print(
                '{}: ({} {:.1f}%) train_loss: {:.3f}, valid_loss: {:.3f}'
                .format(
                    duration_since(start_time), epoch, progress,
                    t_loss, v_loss,
                )
            )
            # Print a sample generation so progress is visible.
            evaluate_model()
        if epoch % plot_every == 0:
            # Record window averages, then reset the accumulators.
            train_curve.append(train_sum / plot_every)
            valid_curve.append(valid_sum / plot_every)
            train_sum = 0.0
            valid_sum = 0.0
    return train_curve, valid_curve
def evaluate(prime_words: List[str] = None, predict_len: int = 30,
             temperature: float = 0.8) -> List[str]:
    '''
    Evaluate the network by generating a sentence from priming words.
    The priming words are fed one at a time to warm up the hidden state;
    then each output distribution is sampled to pick the next input word.
    Return the predicted words (priming words included).
    :param prime_words: priming words to start; defaults to the SOS token
    :param predict_len: expected length of words to predict
    :param temperature: randomness of predictions; higher value results in
        more diversity
    '''
    hid: Tensor = m.init_hidden()
    if not prime_words:
        prime_words = [cp.dictionary.sos]
    with torch.no_grad():
        prime_inp: Tensor = words_to_tensor(prime_words)
        # Copy the list: the original aliased `prime_words`, so the appends
        # below silently mutated the caller's argument.
        predicted_words: List[str] = list(prime_words)
        # Warm up the hidden state on all but the last priming word.
        for p in range(len(prime_words) - 1):
            _, hid = m(prime_inp[p], hid)
        inp: Tensor = prime_inp[-1]
        for _ in range(predict_len):
            out, hid = m(inp, hid)
            # Sample from the network output as a multinomial distribution.
            out_dist: Tensor = out.view(-1).div(temperature).exp()
            top_i: int = torch.multinomial(out_dist, 1)[0]
            # Record the predicted word and feed it back as the next input.
            predicted_word: str = cp.dictionary.idx2word[top_i]
            predicted_words.append(predicted_word)
            inp.fill_(top_i)
    return predicted_words
def evaluate_model(save: bool = False) -> None:
    '''
    The main evaluating function: sample priming words from the dev set,
    generate a continuation, and print it or append it to the output file.
    :param save: save the output to local file instead of printing
    '''
    m.eval()
    prime_words: List[str] = get_random_words(prime_len, 'dev')
    predicted_words: List[str] = evaluate(
        prime_words, predict_len, temperature,
    )
    # Joined words form a single sentence string (annotation was List[str]).
    output: str = ' '.join(predicted_words)
    if save:
        # Timestamp each saved generation so runs can be told apart.
        current_time: str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        with open(output_path, 'a') as f:
            f.write(f'{current_time}:\n{output}\n\n')
    else:
        print(output)
def generate() -> None:
    '''
    Generate new sentences using the best model, and save to local file.
    '''
    load_model()
    i = 0
    # Emit batch_size sentences, overwriting one progress line in place.
    while i < batch_size:
        i += 1
        progress: float = i / batch_size * 100
        print(f'({i} {progress:.1f}%)', end='\r', flush=True)
        evaluate_model(save=True)
def save_model(loss: float) -> None:
    '''
    Persist the current model weights to `model_path`.
    :param loss: current loss
    '''
    # Only the state dict is saved, not the full module object.
    with open(model_path, 'wb') as fp:
        torch.save(m.state_dict(), fp)
    print(duration_since(start_time) + f': Model saved, {loss:.3f}')
def load_model() -> None:
    '''
    Load the best saved weights from file into the model.
    A missing checkpoint is ignored so training can start fresh.
    '''
    try:
        with open(model_path, 'rb') as fp:
            m.load_state_dict(torch.load(fp))
    except FileNotFoundError:
        # Deliberate best-effort: no checkpoint means a fresh start.
        pass
def main() -> None:
    '''
    The main function of Trump-bot: load data, train, plot, generate.
    '''
    init_corpus()
    print(duration_since(start_time) + ': Reading dataset done.')
    init_model()
    print(duration_since(start_time) + ': Training model initialized.')
    train_losses, valid_losses = train_model()
    print(duration_since(start_time) + ': Training model done.')
    plot(num_epochs, plot_every, train_losses, valid_losses)
    print(duration_since(start_time) + ': Plotting done.')
    generate()
    print(duration_since(start_time) + ': New sentences generated.')
if __name__ == '__main__':
    # Parameters (shared as module-level globals by the functions above)
    hidden_size = 1500      # RNN hidden layer width
    num_layers = 3          # stacked RNN layers
    dropout = 0.4
    learning_rate = 0.0003  # Adam learning rate
    num_epochs = 10000      # training iterations, one random chunk each
    batch_size = 50         # sentences produced by generate()
    chunk_size = 20         # words per training chunk
    prime_len = 5           # priming words sampled for each generation
    predict_len = 100       # words generated per sentence
    temperature = 0.7       # sampling temperature; higher = more diverse
    clip = 0.25             # gradient-norm clipping threshold
    random_seed = 1234
    print_every = 250       # epochs between progress log lines
    plot_every = 250        # epochs averaged per plotted loss point
    model_path = os.path.realpath('model/model.pt')
    output_path = os.path.realpath('output/output.txt')
    # Set the random seed manually for reproducibility.
    torch.manual_seed(random_seed)
    # Enable CUDA if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    start_time = time()
    try:
        main()
    except KeyboardInterrupt:
        # Allow a clean manual abort with Ctrl-C.
        print('\nAborted.')