-
Notifications
You must be signed in to change notification settings - Fork 65
/
train_moglow.py
66 lines (53 loc) · 2.01 KB
/
train_moglow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Train script.
Usage:
train_moglow.py <hparams> <dataset>
"""
import os
import motion
import numpy as np
import datetime
from docopt import docopt
from torch.utils.data import DataLoader, Dataset
from glow.builder import build
from glow.trainer import Trainer
from glow.generator import Generator
from glow.config import JsonConfig
from torch.utils.data import DataLoader
if __name__ == "__main__":
    # Parse the command line using the module docstring as the docopt usage
    # specification.
    args = docopt(__doc__)
    hparams = args["<hparams>"]
    dataset = args["<dataset>"]

    # Validate CLI inputs with explicit exceptions: `assert` statements are
    # removed when Python runs with -O, which would let bad arguments through.
    if dataset not in motion.Datasets:
        raise ValueError("`{}` is not supported, use `{}`".format(
            dataset, list(motion.Datasets.keys())))
    if not os.path.exists(hparams):
        raise FileNotFoundError(
            "Failed to find hparams json `{}`".format(hparams))

    hparams = JsonConfig(hparams)
    dataset = motion.Datasets[dataset]

    # Per-run log directory tagged with the current local time at minute
    # resolution, e.g. "log_20240102_1345".
    date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    log_dir = os.path.join(hparams.Dir.log_root, "log_" + date)
    # exist_ok=True replaces the racy exists()-then-makedirs() pattern.
    os.makedirs(log_dir, exist_ok=True)
    print("log_dir:" + str(log_dir))

    # An empty Infer.pre_trained selects training; any other value selects
    # synthesis (presumably a checkpoint path -- TODO confirm in glow.builder).
    is_training = hparams.Infer.pre_trained == ""
    data = dataset(hparams, is_training)
    x_channels, cond_channels = data.n_channels()

    # Build the model graph. `built` is a mapping: it is forwarded to Trainer
    # as keyword arguments and provides the 'graph' and 'data_device' entries
    # used for synthesis below.
    built = build(x_channels, cond_channels, hparams, is_training)

    if is_training:
        trainer = Trainer(**built, data=data, log_dir=log_dir, hparams=hparams)
        trainer.train()
    else:
        # Synthesize samples from the pre-trained model.
        generator = Generator(data, built['data_device'], log_dir, hparams)
        # Sampling temperature (passed as eps_std); defaults to 1 when the
        # config does not set Infer.temperature.
        if "temperature" in hparams.Infer:
            temp = hparams.Infer.temperature
        else:
            temp = 1
        # Generate several times to get different variations for each input;
        # `counter` distinguishes the runs.
        n_variations = 5
        for i in range(n_variations):
            generator.generate_sample(built['graph'], eps_std=temp, counter=i)