-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
78 lines (57 loc) · 2.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/python
# -*- coding:utf-8 -*-
import os
import argparse
import yaml
import torch
from utils.logger import print_log
from utils.random_seed import setup_seed, SEED
from utils.config_utils import overwrite_values
from utils import register as R
########### Import your packages below ##########
import models
from trainer import create_trainer
from data import create_dataset, create_dataloader
from utils.nn_utils import count_parameters
def parse():
parser = argparse.ArgumentParser(description='training')
# device
parser.add_argument('--gpus', type=int, nargs='+', required=True, help='gpu to use, -1 for cpu')
parser.add_argument("--local_rank", type=int, default=-1,
help="Local rank. Necessary for using the torch.distributed.launch utility.")
# config
parser.add_argument('--config', type=str, required=True, help='Path to the yaml configure')
parser.add_argument('--seed', type=int, default=SEED, help='Random seed')
return parser.parse_known_args()
def load_ckpt(model, ckpt):
trained_model = torch.load(ckpt, map_location='cpu')
model.load_state_dict(trained_model.state_dict())
return model
def main(args, opt_args):
# load config
config = yaml.safe_load(open(args.config, 'r'))
config = overwrite_values(config, opt_args)
########## define your model #########
model = R.construct(config['model'])
if 'load_ckpt' in config:
model = load_ckpt(model, config['load_ckpt'])
########### load your train / valid set ###########
train_set, valid_set, _ = create_dataset(config['dataset'])
########## define your trainer/trainconfig #########
if len(args.gpus) > 1:
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', world_size=len(args.gpus))
else:
args.local_rank = -1
if args.local_rank <= 0:
print_log(f'Number of parameters: {count_parameters(model) / 1e6} M')
train_loader = create_dataloader(train_set, config['dataloader'], len(args.gpus))
valid_loader = create_dataloader(valid_set, config['dataloader'], validation=True)
trainer = create_trainer(config, model, train_loader, valid_loader)
trainer.train(args.gpus, args.local_rank)
if __name__ == '__main__':
args, opt_args = parse()
print_log(f'Overwritting args: {opt_args}')
setup_seed(args.seed)
main(args, opt_args)