-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
65 lines (51 loc) · 2.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import argparse
from omegaconf import OmegaConf
from utils.general_utils import random_seed
from exp.stage import first_stage_train, second_stage_train
def _load_shared_settings(args, config):
    """Copy the config entries that both training stages read onto ``args``."""
    model = config.model
    params = model.params
    args.data_config = config.data
    args.ddconfig = params.ddconfig
    args.mlpconfig = params.mlpconfig
    args.loss_config = params.lossconfig
    args.embed_dim = model.embed_dim
    args.lr = model.lr
    args.resolution = params.ddconfig.resolution
    args.resume = model.resume
    args.use_fp16 = model.use_fp16
    args.amp = model.amp
    args.domain = config.data.domain
    args.mode = config.data.mode


def main(args):
    """Fill ``args`` from the config file and start the selected training stage.

    The config at ``args.configs`` is read with ``OmegaConf.load`` and its
    entries are copied onto ``args`` as attributes. ``'d2c-vae'`` then calls
    ``first_stage_train(args)``; ``'ldm'`` additionally copies the
    ``unetconfig``, ``ddpmconfig``, ``pretrained`` and ``DiT`` entries and
    calls ``second_stage_train(args)``.

    Args:
        args: Namespace-like object with ``exp`` (``'d2c-vae'`` or ``'ldm'``)
            and ``configs`` (path of the config file). Mutated in place.

    Raises:
        ValueError: If ``args.exp`` is not a supported experiment name.
        TypeError: If ``args.configs`` is ``None``.
    """
    # Reject unknown experiments before touching the config file.
    if args.exp not in ('d2c-vae', 'ldm'):
        raise ValueError('Undefined Type!')

    # The path may be None when no config was supplied; report that directly
    # instead of handing None to the config loader.
    if args.configs is None:
        raise TypeError(
            'No config file given: pass --configs <path> '
            '(args.configs must not be None)'
        )

    config = OmegaConf.load(args.configs)
    _load_shared_settings(args, config)

    if args.exp == 'd2c-vae':
        first_stage_train(args)
        return

    # Remaining case is 'ldm': it needs a few entries beyond the shared ones.
    params = config.model.params
    args.unetconfig = params.unetconfig
    args.ddpmconfig = params.ddpmconfig
    args.pretrained = config.model.pretrained
    args.DiT = config.model.DiT
    second_stage_train(args)
if __name__ == '__main__':
    # Script entry point: parse the command line, call random_seed with the
    # chosen seed, then dispatch to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp', type=str, required=True, choices=['d2c-vae', 'ldm'],
        help="experiment to run: 'd2c-vae' runs first_stage_train, "
             "'ldm' runs second_stage_train",
    )
    # Every permitted --exp value loads this file, so it is mandatory; argparse
    # then reports a usage error when it is omitted.
    parser.add_argument(
        '--configs', type=str, required=True,
        help='path to the config file loaded with OmegaConf',
    )
    parser.add_argument(
        '--seed', type=int, default=777,
        help='seed passed to random_seed (default: 777)',
    )
    args = parser.parse_args()

    # seed
    random_seed(args.seed)

    main(args)