-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
53 lines (47 loc) · 2.13 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# -*- coding: utf-8 -*-
import argparse
import os
from tagger.cmds import Evaluate, Predict, Train
from tagger.config import Config
import torch
def main():
    """Parse command-line arguments, configure torch, and run a subcommand.

    Builds one argparse sub-parser per entry in ``subcommands`` (evaluate,
    predict, train), each sharing the common options below. It then:

    * sets ``CUDA_VISIBLE_DEVICES`` to ``--device``, the torch thread count
      to ``--threads``, and the torch RNG seed to ``--seed``;
    * replaces ``args.device`` with ``'cuda'`` if CUDA is available under
      that visibility setting, otherwise ``'cpu'``;
    * loads a ``Config`` from ``--conf`` and overrides it with every parsed
      argument;
    * calls the selected subcommand object with that config.
    """
    parser = argparse.ArgumentParser(
        description='Create the tagger model.'
    )
    subparsers = parser.add_subparsers(title='Commands', dest='mode')
    # Sub-parsers are optional by default in Python 3. Without this, running
    # the script with no command leaves args.mode as None, and the dispatch
    # at the end fails with an unhelpful KeyError instead of a usage error.
    subparsers.required = True
    subcommands = {
        'evaluate': Evaluate(),
        'predict': Predict(),
        'train': Train(),
    }
    for name, subcommand in subcommands.items():
        subparser = subcommand.add_subparser(name, subparsers)
        subparser.add_argument('--conf', '-c', default='config.ini',
                               help='path to config file')
        subparser.add_argument('--model', '-m',
                               default='exp/HMM-EM-debug/model.char',
                               help='path to model file')
        subparser.add_argument('--vocab', '-v',
                               default='exp/HMM-EM-debug/vocab.char',
                               help='path to vocab file')
        # Kept as a string: it is written verbatim into CUDA_VISIBLE_DEVICES.
        subparser.add_argument('--device', '-d', default='-1',
                               help='ID of GPU to use')
        subparser.add_argument('--preprocess', '-p', action='store_true',
                               help='whether to preprocess the data first')
        subparser.add_argument('--seed', '-s', default=1, type=int,
                               help='seed for generating random numbers')
        subparser.add_argument('--threads', '-t', default=4, type=int,
                               help='max num of threads')
    args = parser.parse_args()

    print(f"Set the max num of threads to {args.threads}")
    print(f"Set the seed for generating random numbers to {args.seed}")
    print(f"Set the device with ID {args.device} visible")
    # CUDA_VISIBLE_DEVICES is only honored if set before CUDA initializes,
    # so assign it before any torch call that may touch CUDA (manual_seed
    # also seeds the CUDA generators).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    torch.set_num_threads(args.threads)
    torch.manual_seed(args.seed)
    # Replace the raw GPU ID with a device name; with the default '-1' no
    # GPU is visible, so this resolves to 'cpu'.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Override the default configs with parsed arguments")
    config = Config(args.conf)
    config.update(vars(args))
    print(config)

    print(f"Run the subcommand in mode {args.mode}")
    cmd = subcommands[args.mode]
    cmd(config)


if __name__ == '__main__':
    main()