# base_trainer.py (generated from cavalleria/pytorch-template)
from time import time
import os
import math
import json
import logging
import datetime

import torch

from utils.visualization import WriterTensorboardX

class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, loss, metrics, optimizer, resume, config, train_logger=None):
        self.config = config
        # Setup directory for checkpoint saving
        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        self.checkpoint_dir = os.path.join(config['trainer']['save_dir'], config['name'], start_time)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        # Setup logger
        '''
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s %(message)s",
            handlers=[
                logging.FileHandler(os.path.join(self.checkpoint_dir, "train.log")),
            ])
        self.logger = logging.getLogger(self.__class__.__name__)
        '''
        fh = logging.FileHandler(os.path.join(self.checkpoint_dir, "train.log"))
        fh.setLevel(logging.INFO)
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.INFO)
        self.logger.addHandler(fh)
        # Setup GPU device if available, move model onto the configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.epochs = config['trainer']['epochs']
        self.save_freq = config['trainer']['save_freq']
        self.verbosity = config['trainer']['verbosity']
        self.train_logger = train_logger
        # Configuration to monitor model performance and save the best checkpoint
        self.monitor = config['trainer']['monitor']
        self.monitor_mode = config['trainer']['monitor_mode']
        assert self.monitor_mode in ['min', 'max', 'off']
        self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf
        self.start_epoch = 1
        # Setup visualization writer instances for the train and valid splits
        writer_train_dir = os.path.join(config['visualization']['log_dir'], config['name'], start_time, "train")
        writer_valid_dir = os.path.join(config['visualization']['log_dir'], config['name'], start_time, "valid")
        self.writer_train = WriterTensorboardX(writer_train_dir, self.logger, config['visualization']['tensorboardX'])
        self.writer_valid = WriterTensorboardX(writer_valid_dir, self.logger, config['visualization']['tensorboardX'])
        # Save configuration file into checkpoint directory
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(config, handle, indent=4, sort_keys=False)
        # Resume training from a checkpoint if a path is given
        if resume:
            self._resume_checkpoint(resume)
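    # The constructor above reads the nested config keys shown below. This is a minimal
    # illustrative sketch of the expected config.json, reconstructed from the lookups in
    # this file; the concrete values are placeholders, not the template's defaults.
    # {
    #     "name": "MyExperiment",
    #     "n_gpu": 1,
    #     "arch": "MyModel",
    #     "trainer": {
    #         "save_dir": "saved/", "epochs": 100, "save_freq": 1, "verbosity": 1,
    #         "monitor": "valid_loss", "monitor_mode": "min"
    #     },
    #     "visualization": {"log_dir": "runs/", "tensorboardX": true}
    # }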
    def _prepare_device(self, n_gpu_use):
        """
        Setup GPU device if available; return the torch device and the GPU ids to use
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There is no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            msg = "Warning: The number of GPUs configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu)
            self.logger.warning(msg)
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids
    def train(self):
        """
        Full training logic: run _train_epoch for every epoch, log the results and save checkpoints
        """
        for epoch in range(self.start_epoch, self.epochs + 1):
            self.logger.info("\n----------------------------------------------------------------")
            self.logger.info("[EPOCH %d]" % (epoch))
            start_time = time()
            result = self._train_epoch(epoch)
            finish_time = time()
            self.logger.info("Finished at {}, runtime: {:.3f} [s]".format(datetime.datetime.now(), finish_time - start_time))
            # Save logged information into the log dict
            log = {}
            for key, value in result.items():
                if key == 'train_metrics':
                    log.update({'train_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'valid_metrics':
                    log.update({'valid_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log[key] = value
            # Print logged information to the screen
            if self.train_logger is not None:
                self.train_logger.add_entry(log)
                if self.verbosity >= 1:
                    for key, value in sorted(list(log.items())):
                        self.logger.info('{:25s}: {}'.format(str(key), value))
            # Evaluate model performance according to the configured metric, save the best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off':
                try:
                    if (self.monitor_mode == 'min' and log[self.monitor] < self.monitor_best) or \
                       (self.monitor_mode == 'max' and log[self.monitor] > self.monitor_best):
                        self.logger.info("Monitor improved from %f to %f" % (self.monitor_best, log[self.monitor]))
                        self.monitor_best = log[self.monitor]
                        best = True
                except KeyError:
                    if epoch == 1:
                        msg = "Warning: Can't recognize metric named '{}' ".format(self.monitor) \
                              + "for performance monitoring. model_best checkpoint won't be updated."
                        self.logger.warning(msg)
            # Save checkpoint
            self._save_checkpoint(epoch, save_best=best)
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError
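    # Implementations of _train_epoch should return a dict of results; train() treats the
    # 'train_metrics' and 'valid_metrics' entries as lists aligned with self.metrics and
    # copies any other key/value pairs into the epoch log as-is.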
    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        """
        # Construct the state dict to save
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'logger': self.train_logger,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        # Save a checkpoint every save_freq epochs
        if self.save_freq is not None:  # set save_freq to None to avoid filling disk space with large models
            if epoch % self.save_freq == 0:
                filename = os.path.join(self.checkpoint_dir, 'epoch{}.pth'.format(epoch))
                torch.save(state, filename)
                self.logger.info("Saving checkpoint at {}".format(filename))
        # Save the best checkpoint
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best at {}".format(best_path))
        else:
            self.logger.info("Monitor has not improved from %f" % (self.monitor_best))
    def _resume_checkpoint(self, resume_path):
        """
        Resume from a saved checkpoint

        :param resume_path: Checkpoint path to be resumed
        """
        self.logger.info("Loading checkpoint: {}".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.monitor_best = checkpoint['monitor_best']
        # Load architecture params from checkpoint
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning('Warning: Architecture configuration given in config file is different from that of checkpoint. ' +
                                'This may yield an exception while state_dict is being loaded.')
        self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        # # Load optimizer state from checkpoint only when the optimizer type is not changed.
        # if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
        #     self.logger.warning('Warning: Optimizer type given in config file is different from that of checkpoint. ' +
        #                         'Optimizer parameters not being resumed.')
        # else:
        #     self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.train_logger = checkpoint['logger']
        self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch - 1))