Commit e5a3640

Author: Asif Ahmed (committed)

replace silu, reduce code

1 parent e450b8e · commit e5a3640

File tree

4 files changed: +24 −96 lines changed

setup.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
     name='vqcompress',
     author='Asif Ahmed',
     description='Image compression with vqgan, autoencoder etc.',
-    version='0.1.7',
+    version='0.1.8',
     url='https://github.com/quickgrid/vq-compress',
     packages=find_packages(),
     classifiers=[

vqcompress/core/ldm/autoencoder.py

Lines changed: 6 additions & 77 deletions

@@ -79,23 +79,6 @@ def get_input(self, batch, k):
     def get_last_layer(self):
         return self.decoder.conv_out.weight
 
-    @torch.no_grad()
-    def log_images(self, batch, only_inputs=False, **kwargs):
-        log = dict()
-        x = self.get_input(batch, self.image_key)
-        x = x.to(self.device)
-        if not only_inputs:
-            xrec, posterior = self(x)
-            if x.shape[1] > 3:
-                # colorize with random projection
-                assert xrec.shape[1] > 3
-                x = self.to_rgb(x)
-                xrec = self.to_rgb(xrec)
-            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
-            log["reconstructions"] = xrec
-        log["inputs"] = x
-        return log
-
     def to_rgb(self, x):
         assert self.image_key == "segmentation"
         if not hasattr(self, "colorize"):
@@ -240,7 +223,6 @@ def __init__(self,
                  kl_weight=1e-8,
                  remap=None,
                  ):
-
         z_channels = ddconfig["z_channels"]
         super().__init__(ddconfig,
                          # lossconfig,
@@ -256,75 +238,22 @@
         # self.loss.n_classes = n_embed
         self.vocab_size = n_embed
 
-        self.quantize = GumbelQuantize(z_channels, embed_dim,
-                                       n_embed=n_embed,
-                                       kl_weight=kl_weight, temp_init=1.0,
-                                       remap=remap)
+        self.quantize = GumbelQuantize(
+            z_channels, embed_dim,
+            n_embed=n_embed,
+            kl_weight=kl_weight, temp_init=1.0,
+            remap=remap
+        )
 
         # self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)  # annealing of temp
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 
-    def temperature_scheduling(self):
-        self.quantize.temperature = self.temperature_scheduler(self.global_step)
-
     def encode_to_prequant(self, x):
         h = self.encoder(x)
         h = self.quant_conv(h)
         return h
 
     def decode_code(self, code_b):
         raise NotImplementedError
-
-    def training_step(self, batch, batch_idx, optimizer_idx):
-        self.temperature_scheduling()
-        x = self.get_input(batch, self.image_key)
-        xrec, qloss = self(x)
-
-        if optimizer_idx == 0:
-            # autoencode
-            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
-                                            last_layer=self.get_last_layer(), split="train")
-
-            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
-            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
-            return aeloss
-
-        if optimizer_idx == 1:
-            # discriminator
-            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
-                                                last_layer=self.get_last_layer(), split="train")
-            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
-            return discloss
-
-    def validation_step(self, batch, batch_idx):
-        x = self.get_input(batch, self.image_key)
-        xrec, qloss = self(x, return_pred_indices=True)
-        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
-                                        last_layer=self.get_last_layer(), split="val")
-
-        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
-                                            last_layer=self.get_last_layer(), split="val")
-        rec_loss = log_dict_ae["val/rec_loss"]
-        self.log("val/rec_loss", rec_loss,
-                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
-        self.log("val/aeloss", aeloss,
-                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
-        self.log_dict(log_dict_ae)
-        self.log_dict(log_dict_disc)
-        return self.log_dict
-
-    def log_images(self, batch, **kwargs):
-        log = dict()
-        x = self.get_input(batch, self.image_key)
-        x = x.to(self.device)
-        # encode
-        h = self.encoder(x)
-        h = self.quant_conv(h)
-        quant, _, _ = self.quantize(h)
-        # decode
-        x_rec = self.decode(quant)
-        log["inputs"] = x
-        log["reconstructions"] = x_rec
-        return log
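
Note: with the Lightning training_step/validation_step/log_images hooks removed, GumbelVQ is effectively inference-only. A minimal reconstruction sketch using only the methods this diff keeps (encode_to_prequant, quantize, decode); `model` and `x` here are hypothetical stand-ins for a loaded GumbelVQ instance and an image batch, not part of the commit:

import torch

@torch.no_grad()
def reconstruct(model, x):
    # encoder + quant_conv, as kept by this commit
    h = model.encode_to_prequant(x)
    # GumbelQuantize returns a (quantized, loss, info) tuple,
    # per the removed log_images body above
    quant, _, _ = model.quantize(h)
    # decode the quantized latents back to image space
    return model.decode(quant)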

vqcompress/core/ldm/distributions.py

Lines changed: 11 additions & 7 deletions

@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch
 
 
 class DiagonalGaussianDistribution(object):
@@ -22,22 +22,26 @@ def kl(self, other=None):
             return torch.Tensor([0.])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2)
+                    + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3]
+                )
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
                     + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    dim=[1, 2, 3]
+                )
 
-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
             return torch.Tensor([0.])
         logtwopi = np.log(2.0 * np.pi)
         return 0.5 * torch.sum(
             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+            dim=dims
+        )
 
     def mode(self):
         return self.mean
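
Note: the reflowed kl() still computes the standard closed form KL(N(mean, var) || N(0, 1)) = 0.5 * sum(mean^2 + var - 1 - logvar). A self-contained sanity sketch (not part of the commit) checking that formula against torch.distributions:

import torch

mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8)
var = logvar.exp()

# Closed form used by DiagonalGaussianDistribution.kl() when other is None
kl_manual = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])

# Reference value from torch.distributions
p = torch.distributions.Normal(mean, var.sqrt())
q = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_ref = torch.distributions.kl_divergence(p, q).sum(dim=[1, 2, 3])

assert torch.allclose(kl_manual, kl_ref, atol=1e-5)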

vqcompress/core/ldm/model.py

Lines changed: 6 additions & 11 deletions

@@ -4,18 +4,13 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
-
+import torch.nn.functional as F
 try:
     import xformers.ops
 except ModuleNotFoundError as err:
     print(err)
 
 
-def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
-
-
 def Normalize(in_channels, num_groups=32):
     return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
@@ -115,14 +110,14 @@ def __init__(
     def forward(self, x, temb):
         h = x
         h = self.norm1(h)
-        h = nonlinearity(h)
+        h = F.silu(h)
         h = self.conv1(h)
 
         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+            h = h + self.temb_proj(F.silu(temb))[:, :, None, None]
 
         h = self.norm2(h)
-        h = nonlinearity(h)
+        h = F.silu(h)
         h = self.dropout(h)
         h = self.conv2(h)
 
@@ -349,7 +344,7 @@ def forward(self, x):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = F.silu(h)
         h = self.conv_out(h)
         return h
 
@@ -470,7 +465,7 @@ def forward(self, z):
             return h
 
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = F.silu(h)
         h = self.conv_out(h)
         if self.tanh_out:
             h = torch.tanh(h)
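
Note: the commit title's "replace silu" refers to swapping the hand-rolled nonlinearity() helper for torch.nn.functional.silu. SiLU is exactly swish with beta = 1, i.e. x * sigmoid(x), so the change is behavior-preserving. A quick equivalence check (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 32, 32)
# The removed helper computed swish(x) = x * sigmoid(x)
swish = x * torch.sigmoid(x)
# F.silu computes the identical function as a single built-in op
assert torch.allclose(F.silu(x), swish, atol=1e-6)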
