-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
415 lines (331 loc) · 16.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
import os
import time
import numpy as np
import argparse
import torch
import wandb
import matplotlib.pyplot as plt
from collections import OrderedDict
from typing import Callable, Any
# opt
import torch.cuda.amp as amp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from apex import optimizers
# logging, yparams
import logging
from utils import logging_utils
logging_utils.config_logger()
from utils.YParams import YParams
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
# metrics, utils, data
from utils import get_data_loader_distributed
from utils.weighted_acc_rmse import weighted_rmse_torch
from utils.img_utils import vis
from utils.preprocess_utils import PreProcessor
from utils.losses import LossHandler
from networks.helpers import get_model
def ckpt_identity(layer: Callable, *args: Any, **kwargs: Any) -> Any:
"""Identity function for when activation checkpointing is not needed"""
return layer(*args)
def set_seed(params, world_size):
seed = params.seed
if seed is None:
seed = np.random.randint(10000)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if world_size > 0:
torch.cuda.manual_seed_all(seed)
class Trainer():
def count_parameters(self):
return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
def __init__(self, params, args):
self.sweep_id = args.sweep_id
self.root_dir = params['exp_dir']
self.config = args.config
params['enable_amp'] = args.enable_amp
self.world_size = 1
if 'WORLD_SIZE' in os.environ:
self.world_size = int(os.environ['WORLD_SIZE'])
self.local_rank = 0
self.world_rank = 0
if self.world_size > 1:
dist.init_process_group(backend='nccl',
init_method='env://')
self.world_rank = dist.get_rank()
self.local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(self.local_rank)
torch.backends.cudnn.benchmark = True
self.log_to_screen = params.log_to_screen and self.world_rank==0
self.log_to_wandb = params.log_to_wandb and self.world_rank==0
self.device = torch.cuda.current_device()
self.params = params
self.params['name'] = args.config + '_' + str(args.run_num)
self.params['group'] = args.config
# for dali data loader, set the actual number of data shards and id
self.params['data_num_shards'] = self.world_size
self.params['data_shard_id'] = self.world_rank
self.config = args.config
self.run_num = args.run_num
self.ckpt_fn = torch.utils.checkpoint.checkpoint if hasattr(params, 'activation_ckpt') and params.activation_ckpt else ckpt_identity
def build_and_launch(self):
self.params['in_channels'] = np.array(self.params['in_channels'])
self.params['out_channels'] = np.array(self.params['out_channels'])
self.params['n_in_channels'] = len(self.params['in_channels'])
self.params['n_out_channels'] = len(self.params['out_channels'])
if self.params.add_zenith:
self.params.n_in_channels += 1
if self.params.add_landmask:
self.params.n_in_channels += 2
if self.params.add_orography:
self.params.n_in_channels += 1
# init wandb
if self.sweep_id:
jid = os.environ['SLURM_JOBID'] # so different sweeps dont resume
exp_dir = os.path.join(*[self.root_dir, 'sweeps', self.sweep_id, self.config, jid])
else:
exp_dir = os.path.join(*[self.root_dir, self.config, self.run_num])
if self.world_rank==0:
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir)
os.makedirs(os.path.join(exp_dir, 'training_checkpoints/'))
os.makedirs(os.path.join(exp_dir, 'wandb/'))
self.params['experiment_dir'] = os.path.abspath(exp_dir)
self.params['checkpoint_path'] = os.path.join(exp_dir, 'training_checkpoints/ckpt.tar')
self.params['best_checkpoint_path'] = os.path.join(exp_dir, 'training_checkpoints/best_ckpt.tar')
self.params['resuming'] = True if os.path.isfile(self.params.checkpoint_path) else False
if self.log_to_wandb:
if self.sweep_id:
wandb.init(dir=os.path.join(exp_dir, "wandb"))
hpo_config = wandb.config
self.params.update_params(hpo_config)
logging.info('HPO sweep %s, trial params:'%self.sweep_id)
logging.info(self.params.log())
else:
wandb.init(dir=os.path.join(exp_dir, "wandb"),
config=self.params.params, name=self.params.name, group=self.params.group, project=self.params.project,
entity=self.params.entity, resume=self.params.resuming)
logging.info(self.params.log())
if self.sweep_id and dist.is_initialized():
# broadcast the params to all ranks since the sweep agent has changed it
if self.world_rank == 0: # where the wandb agent has changed params
objects = [self.params]
else:
self.params = None
objects = [None]
dist.broadcast_object_list(objects, src=0)
self.params = objects[0]
# set_seed(self.params, self.world_size)
if self.world_rank==0:
logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(exp_dir, 'out.log'))
logging_utils.log_versions()
self.params['global_batch_size'] = self.params.batch_size
self.params['local_batch_size'] = int(self.params.batch_size//self.world_size)
self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader_distributed(self.params, self.params.train_data_path, dist.is_initialized(), train=True)
self.valid_data_loader, self.valid_dataset = get_data_loader_distributed(self.params, self.params.valid_data_path, dist.is_initialized(), train=False)
self.params['img_shape_x'] = self.train_dataset.img_shape_x
self.params['img_shape_y'] = self.train_dataset.img_shape_y
# dump the yaml used
if self.world_rank == 0:
hparams = ruamelDict()
yaml = YAML()
for key, value in self.params.params.items():
hparams[str(key)] = value.tolist() if isinstance(value, np.ndarray) else value
with open(os.path.join(self.params['experiment_dir'], 'hyperparams.yaml'), 'w') as hpfile:
yaml.dump(hparams, hpfile)
self.loss_obj = LossHandler(self.params).to(self.device)
self.model = get_model(self.params).to(self.device)
# data preprocessing
self.preprocessor = PreProcessor(self.params, self.device).to(self.device)
if self.log_to_wandb:
wandb.watch(self.model)
if self.params.optimizer_type == 'adam':
self.optimizer = torch.optim.Adam(self.model.parameters(), lr =self.params.lr, betas=(0.9, 0.95), fused=True)
elif self.params.optimizer_type == 'FusedLAMB':
self.optimizer = optimizers.FusedLAMB(self.model.parameters(), lr = self.params.lr, max_grad_norm=5.)
else:
raise Exception(f"optimizer type {self.params.optimizer_type} not implemented")
if self.params.enable_amp == True:
self.gscaler = amp.GradScaler()
if dist.is_initialized():
self.model = DistributedDataParallel(self.model,
device_ids=[self.local_rank],
output_device=[self.local_rank],
static_graph=(params.checkpointing>0))
self.iters = 0
self.startEpoch = 0
if self.params.finetune and not self.params.resuming:
assert (
params.pretrained_checkpoint_path is not None
), "error, please specify a valid pretrained checkpoint path"
if self.log_to_screen:
logging.info("Loading checkpoint %s"%self.params.pretrained_checkpoint_path)
self.restore_checkpoint(params.pretrained_checkpoint_path)
if self.params.resuming:
if self.log_to_screen:
logging.info("Loading checkpoint %s"%self.params.checkpoint_path)
self.restore_checkpoint(self.params.checkpoint_path)
self.epoch = self.startEpoch
if self.params.scheduler == 'ReduceLROnPlateau':
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.2, patience=5, mode='min')
elif self.params.scheduler == 'CosineAnnealingLR':
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.params.max_epochs, last_epoch=self.startEpoch-1)
else:
self.scheduler = None
num_p = self.count_parameters()
if self.log_to_screen:
logging.info(self.model)
logging.info("Number of parameters = {}".format(num_p))
# launch training
self.train()
def train(self):
if self.log_to_screen:
logging.info("Starting Training Loop...")
best_valid_loss = 1.e6
for epoch in range(self.startEpoch, self.params.max_epochs):
if dist.is_initialized() and (self.train_sampler is not None):
self.train_sampler.set_epoch(epoch)
start = time.time()
tr_time, data_time, train_logs = self.train_one_epoch()
valid_time, valid_logs = self.validate_one_epoch()
if self.params.scheduler == 'ReduceLROnPlateau':
self.scheduler.step(valid_logs['valid_loss'])
elif self.params.scheduler == 'CosineAnnealingLR':
self.scheduler.step()
if self.log_to_wandb:
for pg in self.optimizer.param_groups:
lr = pg['lr']
wandb.log({'lr': lr})
if self.world_rank == 0:
if self.params.save_checkpoint:
#checkpoint at the end of every epoch
self.save_checkpoint(self.params.checkpoint_path)
if valid_logs['valid_loss'] <= best_valid_loss:
#logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss']))
self.save_checkpoint(self.params.best_checkpoint_path)
best_valid_loss = valid_logs['valid_loss']
if self.log_to_screen:
logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
logging.info('Training time = {}, Valid time = {}'.format(tr_time, valid_time))
logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['loss'], valid_logs['valid_loss']))
def train_one_epoch(self):
self.epoch += 1
tr_time = 0
data_time = 0
tr_loss = []
self.model.train()
st = time.time()
for i, data in enumerate(self.train_data_loader, 0):
tr_start = time.time()
inp, tar, coszen = self.preprocessor(data)
self.model.zero_grad()
with amp.autocast(self.params.enable_amp):
gen = self.model(inp, coszen=coszen).to(self.device, dtype = torch.float)
loss = self.loss_obj(gen, tar, inp)
if self.params.enable_amp:
self.gscaler.scale(loss).backward()
self.gscaler.step(self.optimizer)
else:
loss.backward()
self.optimizer.step()
if self.params.enable_amp:
self.gscaler.update()
# logs
if dist.is_initialized():
dist.all_reduce(loss)
tr_loss.append(loss.item()/dist.get_world_size())
tr_time += time.time() - tr_start
logs = {'loss': np.mean(tr_loss)}
if self.log_to_wandb:
wandb.log(logs, step=self.epoch)
return tr_time, data_time, logs
def validate_one_epoch(self):
self.model.eval()
mult = torch.as_tensor(np.load(self.params.global_stds_path)[0,self.params.out_channels,0,0]).to(self.device)
valid_buff = torch.zeros((3), dtype=torch.float32, device=self.device)
valid_loss = valid_buff[0].view(-1)
valid_steps = valid_buff[2].view(-1)
valid_weighted_rmse = torch.zeros((self.params.n_out_channels), dtype=torch.float32, device=self.device)
valid_weighted_acc = torch.zeros((self.params.n_out_channels), dtype=torch.float32, device=self.device)
valid_start = time.time()
sample_idx = np.random.randint(len(self.valid_data_loader))
with torch.no_grad():
for i, data in enumerate(self.valid_data_loader, 0):
inp, tar, coszen = self.preprocessor(data)
gen = self.model(inp, coszen=coszen).to(self.device, dtype = torch.float)
valid_loss += self.loss_obj(gen, tar, inp)
valid_steps += 1.
# compute metrics on final step of rollout when n_future > 1
# TODO fix this for dali dataloader
tar = tar[:,-self.params.n_out_channels:]
gen = gen[:,-self.params.n_out_channels:]
valid_weighted_rmse += weighted_rmse_torch(gen, tar)
if (i == sample_idx) and self.log_to_wandb:
fields = [gen[0,0].detach().cpu().numpy(), tar[0,0].detach().cpu().numpy()]
if dist.is_initialized():
dist.all_reduce(valid_buff)
dist.all_reduce(valid_weighted_rmse)
valid_time = time.time() - valid_start
# divide by number of steps
valid_buff[0:2] = valid_buff[0:2] / valid_buff[2]
valid_weighted_rmse = valid_weighted_rmse / valid_buff[2]
valid_weighted_rmse *= mult
# download buffers
valid_buff_cpu = valid_buff.detach().cpu().numpy()
valid_weighted_rmse_cpu = valid_weighted_rmse.detach().cpu().numpy()
valid_time = time.time() - valid_start
valid_weighted_rmse = mult*torch.mean(valid_weighted_rmse, axis = 0)
logs = {'valid_loss': valid_buff_cpu[0]}
# track specific variables
if hasattr(self.params, 'track_channels'):
idxes = [self.params.channel_names.index(varname) for varname in self.params.track_channels]
track_channels = self.params.track_channels
else:
track_channels = ['u10m', 'v10m']
idxes = [0, 1]
for idx,var in zip(idxes,track_channels):
logs.update({f'valid_rmse_{var}': valid_weighted_rmse_cpu[idx]})
if self.log_to_wandb:
fig = vis(fields)
logs['vis'] = wandb.Image(fig)
plt.close(fig)
wandb.log(logs, step=self.epoch)
return valid_time, logs
def save_checkpoint(self, checkpoint_path, model=None):
if not model:
model = self.model
torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)
def restore_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
try:
self.model.load_state_dict(checkpoint['model_state'])
except:
new_state_dict = OrderedDict()
for key, val in checkpoint['model_state'].items():
name = key[7:]
new_state_dict[name] = val
self.model.load_state_dict(new_state_dict)
if self.params.resuming:
self.iters = checkpoint['iters']
self.startEpoch = checkpoint['epoch']
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--run_num", default='00', type=str)
parser.add_argument("--yaml_config", default='./config/afno.yaml', type=str)
parser.add_argument("--config", default='default', type=str)
parser.add_argument("--enable_amp", action='store_true')
parser.add_argument("--sweep_id", default=None, type=str, help='sweep config from ./configs/sweeps.yaml')
args = parser.parse_args()
params = YParams(os.path.abspath(args.yaml_config), args.config)
trainer = Trainer(params, args)
if args.sweep_id and trainer.world_rank==0:
wandb.agent(args.sweep_id, function=trainer.build_and_launch, count=1, entity=trainer.params.entity, project=trainer.params.project)
else:
trainer.build_and_launch()
if dist.is_initialized():
dist.barrier()
logging.info('DONE')