-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_task_dota_auto.py
161 lines (145 loc) · 6.73 KB
/
train_task_dota_auto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import argparse
import os
import sys
from datetime import datetime
import signal
from mmengine.config import Config
import logging
import shutil
import pynvml
from tools.GauS_tools import (print_color_str, get_latest_dir)
from tools.model_converters.publish_model import process_checkpoint
def parse_args(argv=None):
    """Parse and validate the command-line arguments of the batch trainer.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is parsed, as argparse does.

    Returns:
        argparse.Namespace with the attributes ``config`` (str), ``gpus``
        (list of str, possibly empty), ``single`` (bool), ``multipy`` (bool)
        and ``outfile`` (str or None).

    Raises:
        ValueError: If neither or both of ``--single`` / ``--multipy`` are
            given, or if ``--single`` is used without exactly one GPU id.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config list file path')
    parser.add_argument('gpus', nargs='*', help='gpus id')
    parser.add_argument('--single', action='store_true', default=False,
                        help='train on a single GPU; exactly one gpu id must be given')
    parser.add_argument('--multipy', action='store_true', default=False,
                        help='train with the distributed launcher')
    parser.add_argument('--outfile', default=None,
                        help="log file path without the '.log' suffix")
    args = parser.parse_args(argv)

    # The two modes are mutually exclusive and one of them is mandatory.
    if args.multipy == args.single:
        raise ValueError('Exactly one of --single or --multipy must be specified.')
    if args.single and len(args.gpus) != 1:
        raise ValueError(
            f'--single requires exactly one gpu id, got {len(args.gpus)}.')
    return args
def logEnd():
    """Write an 'End time <YYYYmmdd_HHMMSS>' entry to the root logger."""
    finished_at = datetime.now().strftime("%Y%m%d_%H%M%S")
    logging.info('End time ' + finished_at)
class HandlerStopExp(Exception):
    """Raised to unwind the program when the user requests a stop.

    The exception always carries the fixed text in ``message``. When it is
    constructed without positional arguments that text is also handed to
    ``Exception`` so that ``str(exc)`` is not empty.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments are accepted for backward compatibility and ignored.
        self.message = 'Handle stop process.'
        # Explicit positional arguments keep their usual meaning; otherwise
        # fall back to the default message so str()/logging show something.
        super().__init__(*(args or (self.message,)))
def handler_signal(signal, frame):
    """Signal callback: announce the stop request, then raise HandlerStopExp.

    Args:
        signal: Signal number passed by the interpreter (unused).
        frame: Interrupted stack frame passed by the interpreter (unused).

    Raises:
        HandlerStopExp: Always.
    """
    notice = 'Handle stop process.\n'
    print_color_str(notice, 'r')
    raise HandlerStopExp()
def main():
    """Train, publish and test every config listed in a task file, in order.

    The task file (``args.config``) holds one job per line: a config path
    optionally followed by extra command-line options. Blank lines are
    ignored and lines starting with ``#`` are skipped. For each job that
    parses, the function:

    1. runs the training command (single-GPU script or distributed launcher);
    2. copies the checkpoint named on the last line of
       ``<work_dir>/last_checkpoint`` into the latest result directory and
       publishes it with ``process_checkpoint``;
    3. runs the test command on the published checkpoint;
    4. removes the directory left by the test run and moves the produced
       ``<name>.zip`` into the result directory.

    Progress is written to ``<outfile>.log``; by default that is
    ``./log/<date>/<task-file-stem>_<timestamp>.log``.

    A failure in one job is logged and the next job is started. A
    ``HandlerStopExp`` (raised by the SIGINT handler) logs the end time and
    exits the process with status 0.
    """
    signal.signal(signal.SIGINT, handler_signal)
    args = parse_args()
    path = args.config

    n_gpus = None  # only meaningful (and only used) in distributed mode
    if args.single:
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpus[0]}"
    else:
        pynvml.nvmlInit()
        n_gpus = pynvml.nvmlDeviceGetCount()

    outfile = args.outfile
    time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    date_str = time_str.split('_')[0]
    if outfile is None:
        task_file = os.path.splitext(os.path.basename(path))[0]
        outdir = os.path.join(os.getcwd(), 'log', date_str)
        os.makedirs(outdir, exist_ok=True)
        outfile = os.path.join(outdir, f'{task_file}_{time_str}')
    logging.basicConfig(filename=f'{outfile}.log', level=logging.INFO,
                        format='%(asctime)s %(message)s', datefmt='%Y%m%d-%H%M%S')
    logging.info(f'Begin time {time_str}')

    with open(path, 'r') as f:
        task_lines = [line.strip() for line in f if line.strip()]

    # Each entry is (train command, config path, work dir). Keeping the three
    # values in one tuple, appended only after the config loaded successfully,
    # guarantees they can never get out of step with each other.
    tasks = []
    for task_line in task_lines:
        if task_line.startswith('#'):
            print_color_str(f'Skip: {task_line[1:].strip()}', 'b')
            logging.info(f'Skip: {task_line[1:].strip()}')
            continue
        # First token is the config path, the rest are pass-through options.
        config_file, *option_parts = task_line.split()
        options = ' '.join(option_parts)
        try:
            cfg = Config.fromfile(config_file)
            work_dir = cfg.work_dir
        except HandlerStopExp as hse:
            logging.critical(f'{hse.message}')
            logEnd()
            sys.exit(0)
        except Exception as e:
            logging.error(f'There are something wrong when parse {config_file}.\n\t{e}')
            print_color_str(f'There are something wrong when parse {config_file}.\n\t{e}', 'm')
            continue
        if args.single:
            cmd = f'python3 ./tools/train.py {config_file} {options}'
        else:
            cmd = f'./tools/dist_train.sh {config_file} {n_gpus} {options}'
        tasks.append((cmd, config_file, work_dir))

    logging.info(f'Valid config file: {len(tasks)}')
    for cmd, _, _ in tasks:
        logging.info(f'\t{cmd}')

    for cmd, config_file, work_dir in tasks:
        try:
            os.makedirs(work_dir, exist_ok=True)
            print_color_str(cmd, 'c')
            logging.info(f'Training: {config_file}')
            status = os.system(cmd)
            if status != 0:
                # Do not publish/test whatever stale checkpoint may be lying
                # around in work_dir when the training command itself failed.
                raise RuntimeError(f'training command exited with status {status}')
            logging.info(f'Done: {config_file}')
            logging.info(f'{cmd}\tSucceed')

            # NOTE(review): get_latest_dir presumably returns the newest
            # sub-directory of work_dir (the run that just finished) - confirm.
            result_dir = get_latest_dir(work_dir)
            file_name = os.path.splitext(os.path.basename(config_file))[0]
            # The last non-empty line of 'last_checkpoint' is taken as the
            # checkpoint path.
            with open(os.path.join(work_dir, "last_checkpoint"), 'r') as f:
                pth_path = [line.strip() for line in f if line.strip()][-1]
            shutil.copy(pth_path, result_dir)
            print_color_str(f'Copy {pth_path}\tTo\t{result_dir}', 'g')
            logging.info(f'Copy {pth_path}\tTo\t{result_dir}')

            # Shared prefix for the published checkpoint and the test outputs.
            result_prefix = os.path.join(result_dir, file_name)
            pubulish_checkpoint = process_checkpoint(pth_path, f'{result_prefix}.pth')
            print_color_str(f'Publish {pth_path}\tTo\t{pubulish_checkpoint}', 'g')
            logging.info(f'Publish {pth_path}\tTo\t{pubulish_checkpoint}')
            shutil.copy(os.path.join(work_dir, f'{file_name}.py'), result_dir)

            if args.single:
                test_cmd = f'python3 ./tools/test.py {config_file} {pubulish_checkpoint} ' \
                           f'--cfg-options test_evaluator.outfile_prefix={result_prefix}'
            else:
                test_cmd = f'./tools/dist_test.sh {config_file} {pubulish_checkpoint} {n_gpus} ' \
                           f'--cfg-options test_evaluator.outfile_prefix={result_prefix}'
            print_color_str(test_cmd, 'g')
            logging.info(test_cmd)
            test_status = os.system(test_cmd)
            if test_status != 0:
                raise RuntimeError(f'test command exited with status {test_status}')

            # Remove the directory the test run left in work_dir. Never remove
            # result_dir itself: if the test created no new directory the
            # "latest" one is still the training result.
            test_dir = get_latest_dir(work_dir)
            if test_dir is not None and test_dir != result_dir:
                shutil.rmtree(test_dir)
                print_color_str(f'Remove cache: {test_dir}', 'm')
                logging.info(f'Remove cache: {test_dir}')

            # Keep only the zip from the <result_dir>/<file_name>/ output
            # folder, then delete the folder.
            zip_path = os.path.join(result_prefix, f'{file_name}.zip')
            shutil.move(zip_path, result_dir)
            print_color_str(f"Move {zip_path} To {result_dir}", 'g')
            logging.info(f"Move {zip_path} To {result_dir}")
            shutil.rmtree(result_prefix)
            print_color_str(f'Remove cache: {result_prefix}', 'm')
            logging.info(f'Remove cache: {result_prefix}')
            print_color_str(f'Test succeed! {file_name}')
            logging.info(f'Test succeed! {file_name}')
        except HandlerStopExp as hse:
            logging.critical(f'{hse.message}')
            logEnd()
            sys.exit(0)
        except Exception as e:
            # One failing job must not stop the remaining ones.
            logging.error(f'{cmd}\tFailed\n\tDetail: {e}')
    logEnd()
# Script entry point: start the batch run only when executed directly.
if __name__ == "__main__":
    main()