-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
70 lines (60 loc) · 2.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn.functional as F
import pickle
import os
import options
import logger
from rl.qlearning_trainer_gc import (
QLearningGraphCenteredTrainer,
QLearningPrioritizedBufferGraphCenteredTrainer,
)
from model.graph_centered import GraphCenteredNet, GraphCenteredNetV2
from data.embedding import DirectionalPositionalEmbedding, DirectionalEmbedding
opt = options.parse_options()
# The embedding is fixed in code rather than chosen on the command line;
# presumably consumed by the trainer/model constructed from `opt` below.
opt.embedding = DirectionalPositionalEmbedding()

# Set up logging for this run.
logger.setup_logger(opt.logs, training_id=opt.training_id)

# Per-metric histories: metric name -> list of per-epoch values.
history_train = {}
history_eval = {}

# Pick the trainer variant according to the replay-buffer option.
if opt.use_prioritised_replay:
    trainer = QLearningPrioritizedBufferGraphCenteredTrainer(opt)
else:
    trainer = QLearningGraphCenteredTrainer(opt)

# Optionally initialise both the policy and the target network from the same
# pretrained checkpoint.
if opt.pretrained:
    # Deserialize the checkpoint once and reuse it for both networks.
    # load_state_dict copies the values into each module's own parameters,
    # so the two nets do not end up sharing tensors.
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
    # machine; load_state_dict then copies onto whatever device each net
    # already lives on.
    # NOTE(review): torch.load unpickles arbitrary objects -- only point
    # weights_path at trusted checkpoints.
    state_dict = torch.load(opt.weights_path, map_location="cpu")
    trainer.policy_net.load_state_dict(state_dict)
    trainer.target_net.load_state_dict(state_dict)
def _record_epoch(history, epoch_info):
    """Append each metric of one epoch to its per-metric list in *history*.

    *epoch_info* is iterated for its keys and indexed by them (dict-like).
    A metric seen for the first time gets a fresh list.
    """
    for key in epoch_info:
        history.setdefault(key, []).append(epoch_info[key])


def _dump_pickle(obj, path):
    """Pickle *obj* to *path* without ever leaving a half-written file.

    The data goes to a sibling temporary file first, which is then renamed
    over *path*. An interruption mid-write (e.g. a killed training job)
    therefore cannot leave a truncated history file behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


def _save_checkpoint():
    """Save the model via the trainer, then the train/eval metric histories."""
    run_dir = os.path.join(opt.logs, opt.training_id)
    # Make sure the run directory exists before anything writes into it.
    os.makedirs(run_dir, exist_ok=True)
    trainer.save_model()
    _dump_pickle(history_train, os.path.join(run_dir, "train_history.pkl"))
    _dump_pickle(history_eval, os.path.join(run_dir, "eval_history.pkl"))


# Training loop
for epoch in range(opt.epochs):
    # Train, then record the metrics returned for this epoch.
    _record_epoch(history_train, trainer.train_one_epoch())
    # Evaluate, then record the metrics returned for this epoch.
    _record_epoch(history_eval, trainer.eval_one_epoch())
    # Periodic checkpoint (also fires after the first epoch, epoch == 0).
    # A falsy save_every (0/None) disables periodic saves instead of raising
    # ZeroDivisionError on the modulo.
    if opt.save_every and epoch % opt.save_every == 0:
        _save_checkpoint()

# Always persist the final state once training finishes.
_save_checkpoint()