Commit 5ddf309

Author: min-nashory
Message: add equalized-lr feature, add smoothed generator
1 parent b14d024, commit 5ddf309

File tree: 6 files changed, +96 / -29 lines

config.py

Lines changed: 4 additions & 3 deletions
@@ -6,7 +6,7 @@
 parser = argparse.ArgumentParser('PGGAN')

 ## general settings.
-parser.add_argument('--train_data_root', type=str, default='/home1/irteam/nashory/data/CelebA/Img')
+parser.add_argument('--train_data_root', type=str, default='/home1/work/nashory/data/CelebA/Img')
 parser.add_argument('--random_seed', type=int, default=int(time.time()))
 parser.add_argument('--n_gpu', type=int, default=1) # for Multi-GPU training.

@@ -16,7 +16,8 @@


 ## training parameters.
-parser.add_argument('--lr', type=float, default=0.0002)
+parser.add_argument('--lr', type=float, default=0.001)
+parser.add_argument('--smoothing', type=float, default=0.997)
 parser.add_argument('--nc', type=int, default=3)
 parser.add_argument('--nz', type=int, default=512)
 parser.add_argument('--ngf', type=int, default=512)

@@ -32,7 +33,7 @@
 parser.add_argument('--flag_bn', type=bool, default=False)
 parser.add_argument('--flag_pixelwise', type=bool, default=True)
 parser.add_argument('--flag_leaky', type=bool, default=True)
-parser.add_argument('--flag_tanh', type=bool, default=False)
+parser.add_argument('--flag_tanh', type=bool, default=True)
 parser.add_argument('--flag_sigmoid', type=bool, default=True)
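The two training-parameter changes above go together: --lr is raised from 0.0002 to 0.001, and the new --smoothing value (0.997) is the interpolation factor that trainer.py later passes to net.soft_copy_param to maintain a smoothed copy Gs of the generator. A minimal sketch of reading these defaults, for reference (argparse only; the flag names come from this commit, the rest is illustrative):

import argparse

parser = argparse.ArgumentParser('PGGAN')
parser.add_argument('--lr', type=float, default=0.001)          # raised from 0.0002 in this commit
parser.add_argument('--smoothing', type=float, default=0.997)   # interpolation factor for the smoothed generator Gs
config = parser.parse_args([])                                  # parse defaults only, for illustration
print(config.lr, config.smoothing)                              # 0.001 0.997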

custom_layers.py

Lines changed: 35 additions & 0 deletions
@@ -66,3 +66,38 @@ def forward(self, x):
         return torch.addcdiv(t, 1, x, norm)


+# for equalized learning rate.
+class equalized_conv2d(nn.Module):
+    def __init__(self, c_in, c_out, k_size, stride, pad, initializer='kaiming'):
+        super(equalized_conv2d, self).__init__()
+        self.conv = nn.Conv2d(c_in, c_out, k_size, stride, pad)
+        if initializer == 'kaiming': torch.nn.init.kaiming_normal(self.conv.weight)
+        elif initializer == 'xavier': torch.nn.init.xavier_normal(self.conv.weight)
+        self.inv_c = np.sqrt(2.0/(c_in*k_size**2))
+
+    def forward(self, x):
+        return self.conv(x.mul(self.inv_c))
+
+
+class equalized_deconv2d(nn.Module):
+    def __init__(self, c_in, c_out, k_size, stride, pad, initializer='kaiming'):
+        super(equalized_deconv2d, self).__init__()
+        self.deconv = nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad)
+        if initializer == 'kaiming': torch.nn.init.kaiming_normal(self.deconv.weight)
+        elif initializer == 'xavier': torch.nn.init.xavier_normal(self.deconv.weight)
+        self.inv_c = np.sqrt(2.0/(c_in*k_size**2))
+
+    def forward(self, x):
+        return self.deconv(x.mul(self.inv_c))
+
+
+class equalized_linear(nn.Module):
+    def __init__(self, c_in, c_out, initializer='kaiming'):
+        super(equalized_linear, self).__init__()
+        self.linear = nn.Linear(c_in, c_out)
+        if initializer == 'kaiming': torch.nn.init.kaiming_normal(self.linear.weight)
+        elif initializer == 'xavier': torch.nn.init.xavier_normal(self.linear.weight)
+        self.inv_c = np.sqrt(2.0/(c_in))
+
+    def forward(self, x):
+        return self.linear(x.mul(self.inv_c))
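These layers implement the equalized learning rate trick from the PGGAN paper: weights are drawn from a standard initializer, and the per-layer He constant sqrt(2 / fan_in) (here sqrt(2 / (c_in * k^2)), or sqrt(2 / c_in) for the linear layer) is applied at runtime by scaling the layer input, which is equivalent to scaling the weights since the operations are linear, so every layer sees a comparable effective step size under Adam. A minimal usage sketch, assuming this commit's custom_layers.py is importable and a PyTorch version that accepts plain tensors and the deprecated torch.nn.init.kaiming_normal call:

import torch
from custom_layers import equalized_conv2d   # layer defined in this commit

conv = equalized_conv2d(c_in=512, c_out=512, k_size=3, stride=1, pad=1)
x = torch.randn(4, 512, 8, 8)                # dummy feature map
y = conv(x)                                  # input is multiplied by sqrt(2 / (512 * 3**2)) before the conv
print(y.shape)                               # torch.Size([4, 512, 8, 8])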

network.py

Lines changed: 22 additions & 14 deletions
@@ -3,13 +3,13 @@
 import torch.nn.functional as F
 import numpy as np
 from torch.autograd import Variable
-from custom_layers import fadein_layer, ConcatTable, minibatch_std_concat_layer, Flatten, pixelwise_norm_layer
+from custom_layers import fadein_layer, ConcatTable, minibatch_std_concat_layer, Flatten, pixelwise_norm_layer, equalized_conv2d, equalized_deconv2d, equalized_linear
 import copy


 # defined for code simplicity.
 def deconv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False):
-    if wn: layers.append(nn.utils.weight_norm(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad), name='weight'))
+    if wn: layers.append(equalized_deconv2d(c_in, c_out, k_size, stride, pad))
     else: layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad))
     if leaky: layers.append(nn.LeakyReLU(0.2))
     else: layers.append(nn.ReLU())

@@ -18,18 +18,19 @@ def deconv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False):
     return layers

 def conv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False):
-    if wn: layers.append(nn.utils.weight_norm(nn.Conv2d(c_in, c_out, k_size, stride, pad), name='weight'))
+    if wn: layers.append(equalized_conv2d(c_in, c_out, k_size, stride, pad, initializer='kaiming'))
     else: layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad))
     if leaky: layers.append(nn.LeakyReLU(0.2))
     else: layers.append(nn.ReLU())
     if bn: layers.append(nn.BatchNorm2d(c_out))
     if pixel: layers.append(pixelwise_norm_layer())
     return layers

-def linear(layers, c_in, c_out, sigmoid=True):
+def linear(layers, c_in, c_out, sig=True, wn=False):
     layers.append(Flatten())
-    layers.append(nn.Linear(c_in, c_out))
-    if sigmoid: layers.append(nn.Sigmoid())
+    if wn: layers.append(equalized_linear(c_in, c_out))
+    else: layers.append(nn.Linear(c_in, c_out))
+    if sig: layers.append(nn.Sigmoid())
     return layers


@@ -41,6 +42,13 @@ def deepcopy_module(module, target):
             new_module[-1].load_state_dict(m.state_dict()) # copy weights
     return new_module

+def soft_copy_param(target_link, source_link, tau):
+    ''' soft-copy parameters of a link to another link. '''
+    target_params = dict(target_link.named_parameters())
+    for param_name, param in source_link.named_parameters():
+        target_params[param_name].data = target_params[param_name].data.mul(1.0-tau)
+        target_params[param_name].data = target_params[param_name].data.add(param.data.mul(tau))
+
 def get_module_names(model):
     names = []
     for key, val in model.state_dict().iteritems():

@@ -199,9 +207,9 @@ def last_block(self):
         ndim = self.ndf
         layers = []
         layers.append(minibatch_std_concat_layer())
-        layers = conv(layers, ndim+1, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-        layers = conv(layers, ndim, ndim, 4, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-        layers = linear(layers, ndim, 1, self.flag_sigmoid)
+        layers = conv(layers, ndim+1, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+        layers = conv(layers, ndim, ndim, 4, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+        layers = linear(layers, ndim, 1, sig=self.flag_sigmoid, wn=self.flag_wn)
         return nn.Sequential(*layers), ndim

     def intermediate_block(self, resl):

@@ -217,18 +225,18 @@ def intermediate_block(self, resl):
         ndim = ndim/2
         layers = []
         if halving:
-            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-            layers = deconv(layers, ndim, ndim*2, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
+            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+            layers = deconv(layers, ndim, ndim*2, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
         else:
-            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
+            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)

         layers.append(nn.AvgPool2d(kernel_size=2)) # scale up by factor of 2.0
         return nn.Sequential(*layers), ndim, layer_name

     def from_rgb_block(self, ndim):
         layers = []
-        layers = conv(layers, self.nc, ndim, 1, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
+        layers = conv(layers, self.nc, ndim, 1, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
         return nn.Sequential(*layers)

     def get_init_dis(self):
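The new soft_copy_param helper is the exponential-moving-average style update used for generator smoothing: each parameter of the target module is blended with the corresponding source parameter as theta_target <- (1 - tau) * theta_target + tau * theta_source, with tau taken from config.smoothing (0.997), so the smoothed copy tracks the training generator very closely. A standalone restatement of the same update on toy modules (illustrative only; the real call sites pass the generator networks):

import torch.nn as nn

def soft_copy_param(target, source, tau):
    # theta_target <- (1 - tau) * theta_target + tau * theta_source, matched by parameter name
    target_params = dict(target.named_parameters())
    for name, param in source.named_parameters():
        target_params[name].data.mul_(1.0 - tau).add_(param.data * tau)

G, Gs = nn.Linear(4, 4), nn.Linear(4, 4)   # toy stand-ins for the generator and its smoothed copy
soft_copy_param(Gs, G, tau=0.997)
print(next(Gs.parameters()).shape)          # torch.Size([4, 4])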

requirements.txt

Lines changed: 0 additions & 2 deletions
@@ -22,7 +22,5 @@ subprocess32==3.2.7
 tensorboardX==0.8
 tensorflow==1.3.0
 tensorflow-tensorboard==0.1.8
-torch==0.2.0.post3
-torchvision==0.1.9
 tqdm==4.19.4
 Werkzeug==0.12.2

trainer.py

Lines changed: 28 additions & 10 deletions
@@ -26,6 +26,7 @@ def __init__(self, config):
         self.optimizer = config.optimizer

         self.resl = 2 # we start from 2^2 = 4
+        self.smoothing = config.smoothing
         self.max_resl = config.max_resl
         self.trns_tick = config.trns_tick
         self.stab_tick = config.stab_tick

@@ -45,6 +46,20 @@ def __init__(self, config):

         # network and cirterion
         self.G = net.Generator(config)
+        self.Gs = net.Generator(config)
+
+        # shallow copy test.
+        #net.soft_copy_param(self.Gs, self.G, self.smoothing)
+        #for param in self.G.parameters():
+        #    print(param.data.mean())
+        #print('------------------')
+        #for param in self.Gs.parameters():
+        #    print(param.data.mean())
+        #print('------------------')
+
+
+
+
         self.D = net.Discriminator(config)
         print ('Generator structure: ')
         print(self.G.model)

@@ -123,8 +138,8 @@ def resl_scheduler(self):
            self.fadein['gen'].update_alpha(d_alpha)
            self.complete['gen'] = self.fadein['gen'].alpha*100
            self.flag_flush_gen = False
-           self.G.module.flush_network() # flush and,
-           #self.G.module.freeze_layers() # freeze.
+           self.G.module.flush_network() # flush G
+           self.Gs.flush_network() # flush Gs
            self.fadein['gen'] = None
            self.complete['gen'] = 0.0
            self.phase = 'dtrns'

@@ -134,18 +149,14 @@
            self.complete['dis'] = self.fadein['dis'].alpha*100
            self.flag_flush_dis = False
            self.D.module.flush_network() # flush and,
-           #self.D.module.freeze_layers() # freeze.
            self.fadein['dis'] = None
            self.complete['dis'] = 0.0
            self.phase = 'gtrns'

        # grow network.
        if floor(self.resl) != prev_resl:
-           #if prev_resl==2:
-           #    self.G.module.freeze_layers() # freeze.
-           #    self.D.module.freeze_layers() # freeze.
-
            self.G.module.grow_network(floor(self.resl))
+           self.Gs.grow_network(floor(self.resl))
            self.D.module.grow_network(floor(self.resl))
            self.renew_everything()
            self.fadein['gen'] = self.G.module.model.fadein_block

@@ -248,6 +259,10 @@ def train(self):
                loss_g.backward()
                self.opt_g.step()

+               # generator smoothing
+               net.soft_copy_param(self.Gs, self.G.module, self.smoothing)
+
+
                # logging.
                log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}] errD: {4:.4f} | errG: {5:.4f} | [cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.data[0], loss_g.data[0], self.resl, int(pow(2,floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'])
                tqdm.write(log_msg)

@@ -267,12 +282,15 @@ def train(self):
                # tensorboard visualization.
                if self.use_tb:
                    x_test = self.G(self.z_test)
+                   x_test_s = self.Gs(self.z_test)
                    self.tb.add_scalar('data/loss_g', loss_g.data[0], self.globalIter)
                    self.tb.add_scalar('data/loss_d', loss_d.data[0], self.globalIter)
                    self.tb.add_scalar('tick/globalTick', int(self.globalTick), self.globalIter)
-                   self.tb.add_image_grid('grid/x_test', 4, x_test.data.float(), self.globalIter)
-                   self.tb.add_image_grid('grid/x_tilde', 4, self.x_tilde.data.float(), self.globalIter)
-                   self.tb.add_image_grid('grid/x_intp', 1, self.x.data.float(), self.globalIter)
+                   self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.globalIter)
+                   self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter)
+                   self.tb.add_image_grid('grid/x_test_s', 4, utils.adjust_dyn_range(x_test_s.data.float(), [-1,1], [0,1]), self.globalIter)
+                   self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter)
+                   self.tb.add_image_grid('grid/x_intp', 1, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter)


     def snapshot(self, path):
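In the training loop the smoothing update runs immediately after the generator's optimizer step (net.soft_copy_param(self.Gs, self.G.module, self.smoothing)), and both G and the smoothed Gs are sampled for the TensorBoard image grids. A self-contained toy loop showing that ordering, with the blend written inline (toy model and loss, not the repo's GAN objective):

import torch
import torch.nn as nn

G, Gs = nn.Linear(8, 8), nn.Linear(8, 8)
Gs.load_state_dict(G.state_dict())                 # start the smoothed copy from the same weights
opt_g = torch.optim.Adam(G.parameters(), lr=0.001)
smoothing = 0.997

for _ in range(10):
    loss_g = G(torch.randn(4, 8)).pow(2).mean()    # stand-in for the generator loss
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    # generator smoothing: same blend soft_copy_param applies, written inline
    for (_, p), (_, ps) in zip(G.named_parameters(), Gs.named_parameters()):
        ps.data.mul_(1.0 - smoothing).add_(p.data * smoothing)

Note that the commit also grows and flushes Gs in lockstep with G inside resl_scheduler (Gs.grow_network / Gs.flush_network), which keeps the two modules' layer layouts, and therefore their named parameters, aligned for the soft copy.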

utils.py

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,13 @@
 import time


+def adjust_dyn_range(x, drange_in, drange_out):
+    if not drange_in == drange_out:
+        scale = float(drange_out[1]-drange_out[0])/float(drange_in[1]-drange_in[0])
+        bias = drange_out[0]-drange_in[0]*scale
+        x = x.mul(scale).add(bias)
+    return x
+

 def resize(x, size):
     transform = transforms.Compose([
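adjust_dyn_range linearly remaps a tensor from one value range to another; trainer.py uses it to map generator outputs from [-1, 1] (the tanh output range, now that flag_tanh defaults to True) to [0, 1] before writing image grids to TensorBoard. A worked check of that mapping (standalone restatement of the helper, assuming any recent PyTorch):

import torch

def adjust_dyn_range(x, drange_in, drange_out):
    if not drange_in == drange_out:
        scale = float(drange_out[1] - drange_out[0]) / float(drange_in[1] - drange_in[0])  # (1-0)/(1-(-1)) = 0.5
        bias = drange_out[0] - drange_in[0] * scale                                        # 0 - (-1)*0.5 = 0.5
        x = x.mul(scale).add(bias)
    return x

x = torch.tensor([-1.0, 0.0, 1.0])
print(adjust_dyn_range(x, [-1, 1], [0, 1]))   # tensor([0.0000, 0.5000, 1.0000]), i.e. x * 0.5 + 0.5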
