Skip to content

Commit f13829c

Browse files
committed
Update_dataset
1 parent 6867187 commit f13829c

File tree

3 files changed

+205
-7
lines changed

3 files changed

+205
-7
lines changed

config.py

+197
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import collections
2+
import functools
3+
import os
4+
import re
5+
6+
import yaml
7+
from util.distributed import master_only_print as print
8+
9+
10+
class AttrDict(dict):
11+
"""Dict as attribute trick."""
12+
13+
def __init__(self, *args, **kwargs):
14+
super(AttrDict, self).__init__(*args, **kwargs)
15+
self.__dict__ = self
16+
for key, value in self.__dict__.items():
17+
if isinstance(value, dict):
18+
self.__dict__[key] = AttrDict(value)
19+
elif isinstance(value, (list, tuple)):
20+
if isinstance(value[0], dict):
21+
self.__dict__[key] = [AttrDict(item) for item in value]
22+
else:
23+
self.__dict__[key] = value
24+
25+
def yaml(self):
26+
"""Convert object to yaml dict and return."""
27+
yaml_dict = {}
28+
for key, value in self.__dict__.items():
29+
if isinstance(value, AttrDict):
30+
yaml_dict[key] = value.yaml()
31+
elif isinstance(value, list):
32+
if isinstance(value[0], AttrDict):
33+
new_l = []
34+
for item in value:
35+
new_l.append(item.yaml())
36+
yaml_dict[key] = new_l
37+
else:
38+
yaml_dict[key] = value
39+
else:
40+
yaml_dict[key] = value
41+
return yaml_dict
42+
43+
def __repr__(self):
44+
"""Print all variables."""
45+
ret_str = []
46+
for key, value in self.__dict__.items():
47+
if isinstance(value, AttrDict):
48+
ret_str.append('{}:'.format(key))
49+
child_ret_str = value.__repr__().split('\n')
50+
for item in child_ret_str:
51+
ret_str.append(' ' + item)
52+
elif isinstance(value, list):
53+
if isinstance(value[0], AttrDict):
54+
ret_str.append('{}:'.format(key))
55+
for item in value:
56+
# Treat as AttrDict above.
57+
child_ret_str = item.__repr__().split('\n')
58+
for item in child_ret_str:
59+
ret_str.append(' ' + item)
60+
else:
61+
ret_str.append('{}: {}'.format(key, value))
62+
else:
63+
ret_str.append('{}: {}'.format(key, value))
64+
return '\n'.join(ret_str)
65+
66+
67+
class Config(AttrDict):
68+
r"""Configuration class. This should include every human specifiable
69+
hyperparameter values for your training."""
70+
71+
def __init__(self, filename=None, args=None, verbose=False, is_train=True):
72+
super(Config, self).__init__()
73+
# Set default parameters.
74+
# Logging.
75+
76+
large_number = 1000000000
77+
self.snapshot_save_iter = large_number
78+
self.snapshot_save_epoch = large_number
79+
self.snapshot_save_start_iter = 0
80+
self.snapshot_save_start_epoch = 0
81+
self.image_save_iter = large_number
82+
self.eval_epoch = large_number
83+
self.start_eval_epoch = large_number
84+
self.eval_epoch = large_number
85+
self.max_epoch = large_number
86+
self.max_iter = large_number
87+
self.logging_iter = 100
88+
self.image_to_tensorboard=False
89+
self.which_iter = args.which_iter
90+
self.resume = not args.no_resume
91+
92+
93+
self.checkpoints_dir = args.checkpoints_dir
94+
self.name = args.name
95+
self.phase = 'train' if is_train else 'test'
96+
97+
# Networks.
98+
self.gen = AttrDict(type='generators.dummy')
99+
self.dis = AttrDict(type='discriminators.dummy')
100+
101+
# Optimizers.
102+
103+
# Data.
104+
self.data = AttrDict(name='dummy',
105+
type='datasets.images',
106+
num_workers=0)
107+
self.test_data = AttrDict(name='dummy',
108+
type='datasets.images',
109+
num_workers=0,
110+
test=AttrDict(is_lmdb=False,
111+
roots='',
112+
batch_size=1))
113+
self.trainer = AttrDict(
114+
model_average=False,
115+
model_average_beta=0.9999,
116+
model_average_start_iteration=1000,
117+
model_average_batch_norm_estimation_iteration=30,
118+
model_average_remove_sn=True,
119+
image_to_tensorboard=False,
120+
hparam_to_tensorboard=False,
121+
distributed_data_parallel='pytorch',
122+
delay_allreduce=True,
123+
gan_relativistic=False,
124+
gen_step=1,
125+
dis_step=1)
126+
127+
# # Cudnn.
128+
self.cudnn = AttrDict(deterministic=False,
129+
benchmark=True)
130+
131+
# Others.
132+
self.pretrained_weight = ''
133+
self.inference_args = AttrDict()
134+
135+
136+
# Update with given configurations.
137+
assert os.path.exists(filename), 'File {} not exist.'.format(filename)
138+
loader = yaml.SafeLoader
139+
loader.add_implicit_resolver(
140+
u'tag:yaml.org,2002:float',
141+
re.compile(u'''^(?:
142+
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
143+
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
144+
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
145+
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
146+
|[-+]?\\.(?:inf|Inf|INF)
147+
|\\.(?:nan|NaN|NAN))$''', re.X),
148+
list(u'-+0123456789.'))
149+
try:
150+
with open(filename, 'r') as f:
151+
cfg_dict = yaml.load(f, Loader=loader)
152+
except EnvironmentError:
153+
print('Please check the file with name of "%s"', filename)
154+
recursive_update(self, cfg_dict)
155+
156+
# Put common opts in both gen and dis.
157+
if 'common' in cfg_dict:
158+
self.common = AttrDict(**cfg_dict['common'])
159+
self.gen.common = self.common
160+
self.dis.common = self.common
161+
162+
163+
if verbose:
164+
print(' config '.center(80, '-'))
165+
print(self.__repr__())
166+
print(''.center(80, '-'))
167+
168+
169+
def rsetattr(obj, attr, val):
170+
"""Recursively find object and set value"""
171+
pre, _, post = attr.rpartition('.')
172+
return setattr(rgetattr(obj, pre) if pre else obj, post, val)
173+
174+
175+
def rgetattr(obj, attr, *args):
176+
"""Recursively find object and return value"""
177+
178+
def _getattr(obj, attr):
179+
r"""Get attribute."""
180+
return getattr(obj, attr, *args)
181+
182+
return functools.reduce(_getattr, [obj] + attr.split('.'))
183+
184+
185+
def recursive_update(d, u):
186+
"""Recursively update AttrDict d with AttrDict u"""
187+
for key, value in u.items():
188+
if isinstance(value, collections.abc.Mapping):
189+
d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
190+
elif isinstance(value, (list, tuple)):
191+
if isinstance(value[0], dict):
192+
d.__dict__[key] = [AttrDict(item) for item in value]
193+
else:
194+
d.__dict__[key] = value
195+
else:
196+
d.__dict__[key] = value
197+
return d

inference_refine_1D_cam.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,16 @@ def write2video(results_dir, *video_list):
9090
opt.data.cross_id = args.cross_id
9191
opt.data.cross_id_target = args.cross_id_target
9292
if args.multi_view:
93-
from data.multiface_video_dataset_inv_fix_target import MultifaceVideoDataset
94-
dataset = MultifaceVideoDataset(opt.data, is_inference=True)
95-
# opt.trainer.inversion.iterations = 300
93+
# TODO: add multi-view dataset
94+
raise NotImplementedError
95+
# from data.multiface_video_dataset_inv_fix_target import MultifaceVideoDataset
96+
# dataset = MultifaceVideoDataset(opt.data, is_inference=True)
9697
else:
9798
if args.cross_id_target is not None:
9899
assert args.cross_id
99-
from data.hdtf_video_dataset_inv_fix_target import HDTFVideoDataset
100+
from data.hdtf_cross_id import HDTFVideoDataset
100101
else:
101-
from data.hdtf_video_dataset_inv import HDTFVideoDataset
102+
from data.dataset import HDTFVideoDataset
102103
dataset = HDTFVideoDataset(opt.data, is_inference=True)
103104

104105
# create a model

scripts/inference.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export CUDA_VISIBLE_DEVICES=1
22
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12345 inference_refine_1D_cam.py \
3-
--config ./config/config/otavatar.yaml \
4-
--name config/otavatar.yaml \
3+
--config ./config/otavatar.yaml \
4+
--name result/otavatar/animation \
55
--no_resume \
66
--which_iter 2000 \
77
--image_size 512 \

0 commit comments

Comments
 (0)