-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
train.py
273 lines (226 loc) · 9.01 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
# System libs
import os
import time
# import math
import random
import argparse
from distutils.version import LooseVersion
# Numerical libs
import torch
import torch.nn as nn
# Our libs
from mit_semseg.config import cfg
from mit_semseg.dataset import TrainDataset
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import AverageMeter, parse_devices, setup_logger
from mit_semseg.lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback
# train one epoch
def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
    """Run one training epoch of ``cfg.TRAIN.epoch_iters`` iterations.

    Args:
        segmentation_module: model called on one batch from ``iterator``;
            must return a ``(loss, acc)`` pair of tensors, which are reduced
            here with ``.mean()``.
        iterator: iterator of training batches consumed with ``next()``; it
            must yield at least ``cfg.TRAIN.epoch_iters`` batches.
        optimizers: ``(encoder, decoder)`` optimizer pair. Every optimizer is
            stepped each iteration, and ``adjust_learning_rate`` resets their
            learning rates first.
        history: dict holding ``history['train']['epoch'|'loss'|'acc']``
            lists; one entry is appended to each every
            ``cfg.TRAIN.disp_iter`` iterations.
        epoch: 1-based epoch index, used to compute the global iteration.
        cfg: config node; ``cfg.TRAIN`` supplies the loop settings and
            receives the running learning rates.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    # train(False) puts the whole module in eval mode; presumably used so
    # BatchNorm statistics stay fixed when fix_bn is set.
    segmentation_module.train(not cfg.TRAIN.fix_bn)

    # main loop
    tic = time.time()
    for i in range(cfg.TRAIN.epoch_iters):
        # load a batch of data
        batch_data = next(iterator)
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # learning-rate schedule is keyed on the global iteration index
        cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, cfg)

        # forward pass; .mean() reduces the returned tensors to scalars
        # (presumably one entry per GPU replica when data-parallel)
        loss, acc = segmentation_module(batch_data)
        loss = loss.mean()
        acc = acc.mean()

        # backward pass and parameter update
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # Read each scalar back to Python once; .item() replaces the legacy
        # .data accessor and avoids a second device read for the history.
        loss_value = loss.item()
        acc_value = acc.item()

        # update average loss and acc (accuracy is averaged as a percentage)
        ave_total_loss.update(loss_value)
        ave_acc.update(acc_value * 100)

        # display running averages and record this iteration's values
        if i % cfg.TRAIN.disp_iter == 0:
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'
                  .format(epoch, i, cfg.TRAIN.epoch_iters,
                          batch_time.average(), data_time.average(),
                          cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder,
                          ave_acc.average(), ave_total_loss.average()))

            fractional_epoch = epoch - 1 + 1. * i / cfg.TRAIN.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(loss_value)
            history['train']['acc'].append(acc_value)
def checkpoint(nets, history, cfg, epoch):
    """Save the training history and encoder/decoder weights for ``epoch``.

    Writes ``history_epoch_{epoch}.pth``, ``encoder_epoch_{epoch}.pth`` and
    ``decoder_epoch_{epoch}.pth`` into ``cfg.DIR``. The criterion in ``nets``
    is not saved.

    Args:
        nets: ``(net_encoder, net_decoder, crit)`` tuple.
        history: object passed unchanged to ``torch.save``.
        cfg: config node; only ``cfg.DIR`` is read.
        epoch: epoch number embedded in the file names.
    """
    print('Saving checkpoints...')
    net_encoder, net_decoder, _ = nets

    # os.path.join mirrors how the resume paths are built when a run is
    # restarted from cfg.TRAIN.start_epoch.
    torch.save(
        history,
        os.path.join(cfg.DIR, 'history_epoch_{}.pth'.format(epoch)))
    torch.save(
        net_encoder.state_dict(),
        os.path.join(cfg.DIR, 'encoder_epoch_{}.pth'.format(epoch)))
    torch.save(
        net_decoder.state_dict(),
        os.path.join(cfg.DIR, 'decoder_epoch_{}.pth'.format(epoch)))
def group_weight(module):
    """Split a module's parameters into weight-decay and no-decay groups.

    Weights of linear and convolution layers are decayed. Their biases, and
    the affine weight/bias of normalization layers (anything deriving from
    ``_BatchNorm``, plus InstanceNorm, GroupNorm and LayerNorm), are not.

    Args:
        module: ``nn.Module`` whose parameters are grouped.

    Returns:
        ``[{'params': decay}, {'params': no_decay, 'weight_decay': 0.0}]``,
        usable directly as the parameter list of a ``torch.optim`` optimizer.
        The first group inherits the optimizer's default weight decay.

    Raises:
        ValueError: if the groups do not cover every parameter exactly once,
            e.g. a parameter belongs to a layer type not handled here.
    """
    norm_types = (
        nn.modules.batchnorm._BatchNorm,
        nn.modules.instancenorm._InstanceNorm,
        nn.GroupNorm,
        nn.LayerNorm,
    )
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, norm_types):
            # affine=False / elementwise_affine=False leaves these as None
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)

    # Explicit check rather than assert: under `python -O` an assert is
    # stripped and any ungrouped parameter would silently never be optimized.
    n_params = len(list(module.parameters()))
    n_grouped = len(group_decay) + len(group_no_decay)
    if n_grouped != n_params:
        grouped_ids = {id(p) for p in group_decay + group_no_decay}
        ungrouped = [name for name, p in module.named_parameters()
                     if id(p) not in grouped_ids]
        raise ValueError(
            'group_weight: grouped {} of {} parameters; ungrouped: {}'.format(
                n_grouped, n_params, ungrouped))

    return [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
def create_optimizers(nets, cfg):
    """Create one SGD optimizer for the encoder and one for the decoder.

    Both use the parameter groups from ``group_weight`` (see it for which
    parameters skip weight decay), momentum ``cfg.TRAIN.beta1`` and weight
    decay ``cfg.TRAIN.weight_decay``. They differ only in base learning rate:
    ``cfg.TRAIN.lr_encoder`` versus ``cfg.TRAIN.lr_decoder``.

    Args:
        nets: ``(net_encoder, net_decoder, crit)``; the criterion is ignored.
        cfg: config node providing the ``TRAIN`` hyper-parameters.

    Returns:
        ``(optimizer_encoder, optimizer_decoder)`` tuple.
    """
    encoder, decoder, _ = nets
    train_cfg = cfg.TRAIN

    def make_sgd(net, base_lr):
        # shared settings; only the base learning rate differs per network
        return torch.optim.SGD(
            group_weight(net),
            lr=base_lr,
            momentum=train_cfg.beta1,
            weight_decay=train_cfg.weight_decay)

    encoder_opt = make_sgd(encoder, train_cfg.lr_encoder)
    decoder_opt = make_sgd(decoder, train_cfg.lr_decoder)
    return encoder_opt, decoder_opt
def adjust_learning_rate(optimizers, cur_iter, cfg):
    """Apply the polynomial ("poly") learning-rate schedule.

    ``lr = base_lr * (1 - cur_iter / max_iters) ** lr_pow``. The base is
    clamped at 0, so iterations at or past ``max_iters`` get a zero rate.

    Args:
        optimizers: ``(optimizer_encoder, optimizer_decoder)`` pair; every
            param group of each gets the new rate.
        cur_iter: global iteration index.
        cfg: config node. Reads ``TRAIN.max_iters``, ``TRAIN.lr_pow``,
            ``TRAIN.lr_encoder`` and ``TRAIN.lr_decoder``. Writes the new
            rates to ``TRAIN.running_lr_encoder`` and
            ``TRAIN.running_lr_decoder``.
    """
    # Clamp: past max_iters the base would be negative, and a negative float
    # raised to a fractional lr_pow is a complex number in Python 3.
    remaining = max(0., 1. - float(cur_iter) / cfg.TRAIN.max_iters)
    scale_running_lr = remaining ** cfg.TRAIN.lr_pow
    cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder * scale_running_lr
    cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder * scale_running_lr

    (optimizer_encoder, optimizer_decoder) = optimizers
    for param_group in optimizer_encoder.param_groups:
        param_group['lr'] = cfg.TRAIN.running_lr_encoder
    for param_group in optimizer_decoder.param_groups:
        param_group['lr'] = cfg.TRAIN.running_lr_decoder
def main(cfg, gpus):
    """Build the model, data pipeline and optimizers, then train.

    Runs epochs ``cfg.TRAIN.start_epoch + 1`` through ``cfg.TRAIN.num_epoch``
    (1-based) and saves a checkpoint after each one.

    Args:
        cfg: config node (reads the MODEL, DATASET and TRAIN sections and DIR).
        gpus: list of integer CUDA device ids. The model is placed on
            ``gpus[0]``, and wrapped for data-parallel training when more than
            one id is given.
    """
    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    # targets equal to -1 are excluded from the loss
    crit = nn.NLLLoss(ignore_index=-1)

    # decoders whose name ends in 'deepsup' also take deep_sup_scale
    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit)

    # Dataset and Loader
    dataset_train = TrainDataset(
        cfg.DATASET.root_dataset,
        cfg.DATASET.list_train,
        cfg.DATASET,
        batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

    # The loader draws len(gpus) dataset items per step. Each item is
    # presumably already a per-GPU batch (batch_per_gpu above) that
    # user_scattered_collate keeps separate for the scattered wrapper below.
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # A single iterator is shared by all epochs; train() pulls
    # cfg.TRAIN.epoch_iters batches from it per epoch.
    iterator_train = iter(loader_train)

    # load nets into gpu. Make gpus[0] the current device first so .cuda()
    # lands there instead of on the default device 0. A data-parallel
    # wrapper needs its module on device_ids[0], and a single non-zero
    # --gpus id would otherwise be ignored.
    torch.cuda.set_device(gpus[0])
    if len(gpus) > 1:
        segmentation_module = UserScatteredDataParallel(
            segmentation_module,
            device_ids=gpus)
        # For sync bn
        patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_encoder, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history, epoch+1, cfg)

        # checkpointing
        checkpoint(nets, history, cfg, epoch+1)

    print('Training Done!')
if __name__ == '__main__':
    # Explicit checks (not assert) throughout so they also run under `python -O`.
    if LooseVersion(torch.__version__) < LooseVersion('0.4.0'):
        raise RuntimeError('PyTorch>=0.4.0 is required')

    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Training"
    )
    parser.add_argument(
        "--cfg",
        default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--gpus",
        default="0-3",
        help="gpus to use, e.g. 0-3 or 0,1,2,3"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)
    # cfg is deliberately left unfrozen: checkpoint paths, batch size,
    # max_iters and running learning rates are written into it below.

    logger = setup_logger(distributed_rank=0)  # TODO
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    # Output directory
    os.makedirs(cfg.DIR, exist_ok=True)
    logger.info("Outputting checkpoints to: {}".format(cfg.DIR))
    with open(os.path.join(cfg.DIR, 'config.yaml'), 'w') as f:
        f.write("{}".format(cfg))

    # Start from checkpoint: resume from the weights saved for start_epoch
    if cfg.TRAIN.start_epoch > 0:
        cfg.MODEL.weights_encoder = os.path.join(
            cfg.DIR, 'encoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
        cfg.MODEL.weights_decoder = os.path.join(
            cfg.DIR, 'decoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
        missing = [path for path in (cfg.MODEL.weights_encoder,
                                     cfg.MODEL.weights_decoder)
                   if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(
                "checkpoint does not exist: {}".format(', '.join(missing)))

    # Parse gpu ids; any 'gpu' prefix in parse_devices' output is stripped
    gpus = [int(x.replace('gpu', '')) for x in parse_devices(args.gpus)]
    num_gpus = len(gpus)
    cfg.TRAIN.batch_size = num_gpus * cfg.TRAIN.batch_size_per_gpu

    # total iterations for the poly schedule; running rates start at base lr
    cfg.TRAIN.max_iters = cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch
    cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder
    cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder

    random.seed(cfg.TRAIN.seed)
    torch.manual_seed(cfg.TRAIN.seed)

    main(cfg, gpus)