 import torch.nn.functional as F
 import numpy as np
 from torch.autograd import Variable
-from custom_layers import fadein_layer, ConcatTable, minibatch_std_concat_layer, Flatten, pixelwise_norm_layer
+from custom_layers import fadein_layer, ConcatTable, minibatch_std_concat_layer, Flatten, pixelwise_norm_layer, equalized_conv2d, equalized_deconv2d, equalized_linear
 import copy
 
 
 # defined for code simplicity.
 def deconv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False):
-    if wn:      layers.append(nn.utils.weight_norm(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad), name='weight'))
+    if wn:      layers.append(equalized_deconv2d(c_in, c_out, k_size, stride, pad))
     else:       layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad))
     if leaky:   layers.append(nn.LeakyReLU(0.2))
     else:       layers.append(nn.ReLU())
@@ -18,18 +18,19 @@ def deconv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, w
     return layers
 
 def conv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False):
-    if wn:      layers.append(nn.utils.weight_norm(nn.Conv2d(c_in, c_out, k_size, stride, pad), name='weight'))
+    if wn:      layers.append(equalized_conv2d(c_in, c_out, k_size, stride, pad, initializer='kaiming'))
     else:       layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad))
     if leaky:   layers.append(nn.LeakyReLU(0.2))
     else:       layers.append(nn.ReLU())
     if bn:      layers.append(nn.BatchNorm2d(c_out))
     if pixel:   layers.append(pixelwise_norm_layer())
     return layers
 
-def linear(layers, c_in, c_out, sigmoid=True):
+def linear(layers, c_in, c_out, sig=True, wn=False):
     layers.append(Flatten())
-    layers.append(nn.Linear(c_in, c_out))
-    if sigmoid: layers.append(nn.Sigmoid())
+    if wn:      layers.append(equalized_linear(c_in, c_out))
+    else:       layers.append(nn.Linear(c_in, c_out))
+    if sig:     layers.append(nn.Sigmoid())
     return layers
 
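The equalized_* layers come from custom_layers, which is not part of this diff. Assuming they implement the equalized-learning-rate trick from the PGGAN paper (weights initialized from N(0, 1) and rescaled at runtime by the per-layer He constant), a minimal sketch of equalized_conv2d might look like the following; equalized_deconv2d and equalized_linear would follow the same pattern around nn.ConvTranspose2d and nn.Linear.

import math
import torch
import torch.nn as nn

class equalized_conv2d(nn.Module):
    ''' Sketch only: N(0, 1)-initialized weights, He constant applied at runtime. '''
    def __init__(self, c_in, c_out, k_size, stride=1, pad=0, initializer='kaiming'):
        super(equalized_conv2d, self).__init__()
        self.conv = nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False)
        self.conv.weight.data.normal_(0.0, 1.0)                  # unit-variance init
        self.scale = math.sqrt(2.0 / (c_in * k_size * k_size))   # He constant
        self.bias = nn.Parameter(torch.zeros(c_out))

    def forward(self, x):
        # scaling the bias-free output is equivalent to scaling the weights
        return self.conv(x) * self.scale + self.bias.view(1, -1, 1, 1)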
@@ -41,6 +42,13 @@ def deepcopy_module(module, target):
     new_module[-1].load_state_dict(m.state_dict())     # copy weights
     return new_module
 
+def soft_copy_param(target_link, source_link, tau):
+    ''' soft-copy parameters of a link to another link. '''
+    target_params = dict(target_link.named_parameters())
+    for param_name, param in source_link.named_parameters():
+        target_params[param_name].data = target_params[param_name].data.mul(1.0 - tau)
+        target_params[param_name].data = target_params[param_name].data.add(param.data.mul(tau))
+
 def get_module_names(model):
     names = []
     for key, val in model.state_dict().iteritems():
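soft_copy_param mirrors the running-average generator that progressive-growing training keeps for evaluation; a hypothetical usage (G, D, G_run, and train_one_step are assumed names, not part of this diff) would be:

# keep G_run as an exponential moving average of G's weights
G_run = copy.deepcopy(G)                   # smoothed copy used for sampling / evaluation
for step in range(total_steps):
    train_one_step(G, D)                   # placeholder for the actual optimization step
    soft_copy_param(G_run, G, tau=0.001)   # G_run <- (1 - tau) * G_run + tau * G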
@@ -199,9 +207,9 @@ def last_block(self):
         ndim = self.ndf
         layers = []
         layers.append(minibatch_std_concat_layer())
-        layers = conv(layers, ndim+1, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-        layers = conv(layers, ndim, ndim, 4, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-        layers = linear(layers, ndim, 1, self.flag_sigmoid)
+        layers = conv(layers, ndim+1, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+        layers = conv(layers, ndim, ndim, 4, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+        layers = linear(layers, ndim, 1, sig=self.flag_sigmoid, wn=self.flag_wn)
         return nn.Sequential(*layers), ndim
 
     def intermediate_block(self, resl):
@@ -217,18 +225,18 @@ def intermediate_block(self, resl):
             ndim = ndim/2
         layers = []
         if halving:
-            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-            layers = deconv(layers, ndim, ndim*2, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
+            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+            layers = deconv(layers, ndim, ndim*2, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
         else:
-            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
-            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
+            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
+            layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
 
         layers.append(nn.AvgPool2d(kernel_size=2))        # scale down by factor of 2.0
         return nn.Sequential(*layers), ndim, layer_name
 
     def from_rgb_block(self, ndim):
         layers = []
-        layers = conv(layers, self.nc, ndim, 1, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise)
+        layers = conv(layers, self.nc, ndim, 1, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False)
         return nn.Sequential(*layers)
 
     def get_init_dis(self):
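As a shape sanity check for the updated helpers, a hypothetical smoke test of the last_block stack follows; ndf=512 is assumed, and minibatch_std_concat_layer is assumed to append one std-statistic channel, neither of which is shown in this diff.

layers = []
layers.append(minibatch_std_concat_layer())                                 # 512 -> 513 channels
layers = conv(layers, 513, 512, 3, 1, 1, True, False, True, pixel=False)    # keeps 4x4
layers = conv(layers, 512, 512, 4, 1, 0, True, False, True, pixel=False)    # 4x4 -> 1x1
layers = linear(layers, 512, 1, sig=True, wn=True)                          # Flatten -> equalized_linear -> Sigmoid
block = nn.Sequential(*layers)
x = Variable(torch.randn(8, 512, 4, 4))
print(block(x).size())   # expected: (8, 1)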