-
Notifications
You must be signed in to change notification settings - Fork 138
/
modules.py
276 lines (218 loc) · 8.04 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
from functions import vq, vq_st
def to_scalar(arr):
    """Convert a tensor, or a list/tuple of tensors, to Python number(s).

    Args:
        arr: an object exposing ``.item()`` (e.g. a one-element tensor), or a
            list/tuple of such objects.

    Returns:
        ``arr.item()`` for a single value, or a list of ``.item()`` results
        when ``arr`` is a list or tuple.
    """
    # isinstance (not ``type(...) ==``) so list subclasses and tuples are
    # also treated as sequences instead of failing on ``.item()``.
    if isinstance(arr, (list, tuple)):
        return [x.item() for x in arr]
    return arr.item()
def weights_init(m):
    """Xavier-initialize conv weights and zero conv biases.

    Meant to be used as ``module.apply(weights_init)``. Only modules whose
    class name contains ``'Conv'`` are touched (e.g. ``Conv2d`` and
    ``ConvTranspose2d``). A 'Conv'-named module that has no ``weight`` of its
    own (such as a wrapper composed of other layers) is skipped with a
    message. A conv created with ``bias=False`` still gets its weight
    initialized; only the bias step is skipped.

    Args:
        m: the ``nn.Module`` being visited.
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname:
        return

    # nn.Module.__getattr__ raises AttributeError for unknown names, so
    # getattr's default handles modules without parameters of their own.
    weight = getattr(m, 'weight', None)
    if weight is None:
        print("Skipping initialization of ", classname)
        return

    nn.init.xavier_uniform_(weight)

    # Previously a missing bias raised AttributeError *after* the weight was
    # initialized, producing a misleading "Skipping" message.
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.zeros_(bias)
class VAE(nn.Module):
    """Convolutional VAE with a diagonal-Gaussian posterior.

    The encoder emits ``2 * z_dim`` channels, which are split into a mean and
    a log-variance. The decoder maps a reparameterized sample back to
    ``input_dim`` channels, squashed by tanh.

    Args:
        input_dim: number of image channels.
        dim: hidden channel width.
        z_dim: number of latent channels.
    """

    def __init__(self, input_dim, dim, z_dim):
        super().__init__()

        # Conv -> BN -> ReLU triples: (in_channels, kernel, stride, padding).
        # As a worked example, a 28x28 input shrinks 28 -> 14 -> 7 -> 3, and
        # the final 3x3 conv below brings it to 1x1.
        encoder_layers = []
        for in_ch, k, s, p in ((input_dim, 4, 2, 1), (dim, 4, 2, 1), (dim, 5, 1, 0)):
            encoder_layers += [nn.Conv2d(in_ch, dim, k, s, p), nn.BatchNorm2d(dim), nn.ReLU(True)]
        # Mean and log-variance stacked along the channel axis.
        encoder_layers += [nn.Conv2d(dim, z_dim * 2, 3, 1, 0), nn.BatchNorm2d(z_dim * 2)]
        self.encoder = nn.Sequential(*encoder_layers)

        # Transposed-conv mirror of the encoder, ending in tanh.
        decoder_layers = []
        for in_ch, k, s, p in ((z_dim, 3, 1, 0), (dim, 5, 1, 0), (dim, 4, 2, 1)):
            decoder_layers += [nn.ConvTranspose2d(in_ch, dim, k, s, p), nn.BatchNorm2d(dim), nn.ReLU(True)]
        decoder_layers += [nn.ConvTranspose2d(dim, input_dim, 4, 2, 1), nn.Tanh()]
        self.decoder = nn.Sequential(*decoder_layers)

        self.apply(weights_init)

    def forward(self, x):
        """Encode, draw a reparameterized sample, and decode it.

        Returns:
            ``(x_tilde, kl_div)``: the reconstruction, and the KL divergence
            to a standard-normal prior. The KL is summed over latent channels
            (dim 1) and averaged over all remaining axes.
        """
        stats = self.encoder(x)
        mu, logvar = stats.chunk(2, dim=1)
        std = logvar.mul(.5).exp()

        posterior = Normal(mu, std)
        prior = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
        kl_div = kl_divergence(posterior, prior).sum(1).mean()

        z = posterior.rsample()
        return self.decoder(z), kl_div
class VQEmbedding(nn.Module):
    """Learned codebook of ``K`` vectors, each of dimension ``D``.

    Both entry points take a 4-D map and move dim 1 to the end (presumably
    the ``D`` feature channels). They then hand it, together with the
    codebook weight, to the ``vq`` / ``vq_st`` helpers imported from
    ``functions``.
    """

    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        # Small symmetric range that shrinks as the codebook grows.
        limit = 1. / K
        self.embedding.weight.data.uniform_(-limit, limit)

    @staticmethod
    def _channels_last(t):
        # (N, C, H, W) -> (N, H, W, C), made contiguous for the vq helpers.
        return t.permute(0, 2, 3, 1).contiguous()

    def forward(self, z_e_x):
        """Return whatever ``vq`` yields for the channels-last input."""
        return vq(self._channels_last(z_e_x), self.embedding.weight)

    def straight_through(self, z_e_x):
        """Quantize ``z_e_x`` and return ``(z_q_x, z_q_x_bar)``.

        Both outputs are permuted back to the input's channels-first layout.
        ``z_q_x`` comes from ``vq_st`` with a *detached* codebook, so no
        gradient reaches the codebook through it. ``z_q_x_bar`` is gathered
        from the live codebook weight with the returned indices, so gradients
        do reach the codebook through it.
        """
        flat_in = self._channels_last(z_e_x)
        quantized, indices = vq_st(flat_in, self.embedding.weight.detach())
        z_q_x = quantized.permute(0, 3, 1, 2).contiguous()

        # One codebook row per position; reshaped to the channels-last input.
        picked = self.embedding.weight.index_select(0, indices)
        z_q_x_bar = picked.view_as(flat_in).permute(0, 3, 1, 2).contiguous()
        return z_q_x, z_q_x_bar
class ResBlock(nn.Module):
    """Residual block computing ``x + BN(Conv1x1(ReLU(BN(Conv3x3(ReLU(x))))))``.

    The 3x3 conv uses padding 1 and the 1x1 conv has none, so the channel
    count ``dim`` and the spatial size are preserved.

    Args:
        dim: number of input/output channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            # Must NOT be in-place. An in-place ReLU here overwrites the
            # caller's tensor ``x`` before forward() reads it for the skip
            # connection, turning the identity shortcut into relu(x) and
            # silently mutating the input. ReLU has no parameters, so
            # state_dict keys/indices are unchanged.
            nn.ReLU(False),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),  # safe: acts on a fresh BatchNorm output
            nn.Conv2d(dim, dim, 1),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x):
        return x + self.block(x)
class VectorQuantizedVAE(nn.Module):
    """VQ-VAE: conv encoder, ``K``-entry codebook, transposed-conv decoder.

    The encoder's two stride-2 4x4 convs (padding 1) halve H and W twice,
    so even-sized inputs are downsampled by 4. The decoder's two stride-2
    transposed convs undo that, and its output is squashed by tanh.

    Args:
        input_dim: number of image channels.
        dim: hidden width, also used as the codebook vector size.
        K: number of codebook entries.
    """

    def __init__(self, input_dim, dim, K=512):
        super().__init__()

        # Submodules are built in the order encoder, codebook, decoder, and
        # Sequential indices are unchanged, so RNG use and state_dict keys
        # stay the same.
        downsample = [
            nn.Conv2d(input_dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 4, 2, 1),
        ]
        self.encoder = nn.Sequential(*downsample, ResBlock(dim), ResBlock(dim))

        self.codebook = VQEmbedding(K, dim)

        residual_stack = [ResBlock(dim), ResBlock(dim)]
        upsample = [
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*residual_stack, *upsample)

        self.apply(weights_init)

    def encode(self, x):
        """Return the codebook's quantization output for the encoded ``x``."""
        return self.codebook(self.encoder(x))

    def decode(self, latents):
        """Decode codebook indices: look up vectors, go channels-first, decode."""
        codes = self.codebook.embedding(latents)
        return self.decoder(codes.permute(0, 3, 1, 2))  # (B, D, H, W)

    def forward(self, x):
        """Return ``(x_tilde, z_e_x, z_q_x)``.

        ``x_tilde`` is decoded from the first (straight-through) output of
        ``codebook.straight_through``. ``z_e_x`` is the raw encoder output.
        ``z_q_x`` is the second output of ``straight_through``, presumably
        consumed by the caller's codebook/commitment losses.
        """
        z_e_x = self.encoder(x)
        z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x)
        return self.decoder(z_q_x_st), z_e_x, z_q_x
class GatedActivation(nn.Module):
    """Gated activation unit ``tanh(a) * sigmoid(b)``.

    The input is split into two chunks along dim 1. The first chunk ``a`` is
    the signal and the second ``b`` is the gate, so an input with ``2*C``
    channels yields ``C`` output channels. This assumes an even channel
    count; an odd count gives unequal chunks.
    """

    def forward(self, x):
        a, b = x.chunk(2, dim=1)
        # torch.tanh/torch.sigmoid replace the deprecated F.tanh/F.sigmoid.
        return torch.tanh(a) * torch.sigmoid(b)
class GatedMaskedConv2d(nn.Module):
    """One class-conditional gated masked convolution layer (Gated PixelCNN).

    The layer keeps two stacks:
    - Vertical stack: after cropping, output row ``i`` sees input rows
      ``i - kernel//2 .. i``.
    - Horizontal stack: after cropping, output column ``j`` sees input
      columns ``j - kernel//2 .. j`` of the same row.
    For mask ``'A'``, ``make_causal`` zeroes the taps on the current row
    (vertical) and the current column (horizontal) before every forward.

    Args:
        mask_type: ``'A'`` applies ``make_causal`` on each forward; any other
            value (e.g. ``'B'``) does not.
        dim: channel count of ``x_v`` / ``x_h`` and of both outputs.
        kernel: odd kernel size.
        residual: if True, add ``x_h`` back to the horizontal output.
        n_classes: size of the class-conditioning embedding table.
    """

    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
        super().__init__()
        # The message must be a string: the old ``print(...)`` argument made
        # the AssertionError message None.
        assert kernel % 2 == 1, "Kernel size must be odd"
        self.mask_type = mask_type
        self.residual = residual

        # One 2*dim bias vector per class, added before both gates.
        self.class_cond_embedding = nn.Embedding(
            n_classes, 2 * dim
        )

        kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
        padding_shp = (kernel // 2, kernel // 2)
        self.vert_stack = nn.Conv2d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)

        kernel_shp = (1, kernel // 2 + 1)
        padding_shp = (0, kernel // 2)
        self.horiz_stack = nn.Conv2d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        self.horiz_resid = nn.Conv2d(dim, dim, 1)

        self.gate = GatedActivation()

    def make_causal(self):
        """Zero the kernel taps that would see the current row/column."""
        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column

    def forward(self, x_v, x_h, h):
        """Return ``(out_v, out_h)``, each with ``dim`` channels.

        Args:
            x_v: vertical-stack input, 4-D with ``dim`` channels.
            x_h: horizontal-stack input, same shape as ``x_v``.
            h: 1-D tensor of class indices. Its embedding is indexed as
                ``[:, :, None, None]``, so its length must match (or
                broadcast to) the batch size.
        """
        if self.mask_type == 'A':
            self.make_causal()

        cond = self.class_cond_embedding(h)[:, :, None, None]

        # Padding adds kernel//2 rows to the output height; keep the first
        # H = size(-2) rows. The old code cropped by size(-1) (the width),
        # which broke every non-square input.
        h_vert = self.vert_stack(x_v)
        h_vert = h_vert[:, :, :x_v.size(-2), :]
        out_v = self.gate(h_vert + cond)

        # Likewise kernel//2 extra columns; keep the first W = size(-1).
        h_horiz = self.horiz_stack(x_h)
        h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
        v2h = self.vert_to_horiz(h_vert)

        out = self.gate(v2h + h_horiz + cond)
        out_h = self.horiz_resid(out)
        if self.residual:
            out_h = out_h + x_h

        return out_v, out_h
class GatedPixelCNN(nn.Module):
    """Class-conditional Gated PixelCNN over a grid of discrete tokens.

    Args:
        input_dim: token vocabulary size; also the number of output logits
            per position.
        dim: hidden channel width.
        n_layers: number of GatedMaskedConv2d layers. The first uses mask
            'A', a 7x7 kernel and no residual; the rest use mask 'B', a 3x3
            kernel and a residual.
        n_classes: number of conditioning classes.
    """

    def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10):
        super().__init__()
        self.dim = dim

        # Create embedding layer to embed input
        self.embedding = nn.Embedding(input_dim, dim)

        # Building the PixelCNN layer by layer
        self.layers = nn.ModuleList()

        # Initial block with Mask-A convolution
        # Rest with Mask-B convolutions
        for i in range(n_layers):
            first = i == 0
            mask_type = 'A' if first else 'B'
            kernel = 7 if first else 3
            residual = not first
            self.layers.append(
                GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes)
            )

        # Add the output layer
        self.output_conv = nn.Sequential(
            nn.Conv2d(dim, 512, 1),
            nn.ReLU(True),
            nn.Conv2d(512, input_dim, 1)
        )

        self.apply(weights_init)

    def forward(self, x, label):
        """Return logits with ``input_dim`` channels for every position.

        Args:
            x: integer token grid of shape (B, H, W).
            label: class indices forwarded to every layer's conditioning
                embedding (presumably one per batch element).
        """
        shp = x.size() + (-1, )
        # reshape (not view) also accepts non-contiguous index tensors.
        x = self.embedding(x.reshape(-1)).view(shp)  # (B, H, W, C)
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)

        x_v, x_h = (x, x)
        for layer in self.layers:
            x_v, x_h = layer(x_v, x_h, label)

        return self.output_conv(x_h)

    @torch.no_grad()
    def generate(self, label, shape=(8, 8), batch_size=64):
        """Sample a ``(batch_size, *shape)`` int64 grid in raster order.

        For each position (i, j) the full network is rerun on the partially
        filled grid, and one token is drawn from the softmax over that
        position's logits. This runs under no_grad: the sampled indices are
        not differentiable, and tracking autograd over H*W forward passes
        only wasted memory.

        Args:
            label: class indices passed to ``forward`` (presumably a
                ``(batch_size,)`` LongTensor on the model's device).
            shape: (H, W) of the grid to generate.
            batch_size: number of samples.
        """
        param = next(self.parameters())
        x = torch.zeros(
            (batch_size, *shape),
            dtype=torch.int64, device=param.device
        )

        for i in range(shape[0]):
            for j in range(shape[1]):
                logits = self.forward(x, label)
                probs = F.softmax(logits[:, :, i, j], -1)
                # squeeze(-1) keeps a (batch_size,) vector even when batch_size == 1.
                x[:, i, j] = probs.multinomial(1).squeeze(-1)
        return x