@@ -79,23 +79,6 @@ def get_input(self, batch, k):
79
79
def get_last_layer (self ):
80
80
return self .decoder .conv_out .weight
81
81
82
- @torch .no_grad ()
83
- def log_images (self , batch , only_inputs = False , ** kwargs ):
84
- log = dict ()
85
- x = self .get_input (batch , self .image_key )
86
- x = x .to (self .device )
87
- if not only_inputs :
88
- xrec , posterior = self (x )
89
- if x .shape [1 ] > 3 :
90
- # colorize with random projection
91
- assert xrec .shape [1 ] > 3
92
- x = self .to_rgb (x )
93
- xrec = self .to_rgb (xrec )
94
- log ["samples" ] = self .decode (torch .randn_like (posterior .sample ()))
95
- log ["reconstructions" ] = xrec
96
- log ["inputs" ] = x
97
- return log
98
-
99
82
def to_rgb (self , x ):
100
83
assert self .image_key == "segmentation"
101
84
if not hasattr (self , "colorize" ):
@@ -240,7 +223,6 @@ def __init__(self,
240
223
kl_weight = 1e-8 ,
241
224
remap = None ,
242
225
):
243
-
244
226
z_channels = ddconfig ["z_channels" ]
245
227
super ().__init__ (ddconfig ,
246
228
# lossconfig,
@@ -256,75 +238,22 @@ def __init__(self,
256
238
# self.loss.n_classes = n_embed
257
239
self .vocab_size = n_embed
258
240
259
- self .quantize = GumbelQuantize (z_channels , embed_dim ,
260
- n_embed = n_embed ,
261
- kl_weight = kl_weight , temp_init = 1.0 ,
262
- remap = remap )
241
+ self .quantize = GumbelQuantize (
242
+ z_channels , embed_dim ,
243
+ n_embed = n_embed ,
244
+ kl_weight = kl_weight , temp_init = 1.0 ,
245
+ remap = remap
246
+ )
263
247
264
248
# self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
265
249
266
250
if ckpt_path is not None :
267
251
self .init_from_ckpt (ckpt_path , ignore_keys = ignore_keys )
268
252
269
- def temperature_scheduling (self ):
270
- self .quantize .temperature = self .temperature_scheduler (self .global_step )
271
-
272
253
def encode_to_prequant (self , x ):
273
254
h = self .encoder (x )
274
255
h = self .quant_conv (h )
275
256
return h
276
257
277
258
def decode_code (self , code_b ):
278
259
raise NotImplementedError
279
-
280
- def training_step (self , batch , batch_idx , optimizer_idx ):
281
- self .temperature_scheduling ()
282
- x = self .get_input (batch , self .image_key )
283
- xrec , qloss = self (x )
284
-
285
- if optimizer_idx == 0 :
286
- # autoencode
287
- aeloss , log_dict_ae = self .loss (qloss , x , xrec , optimizer_idx , self .global_step ,
288
- last_layer = self .get_last_layer (), split = "train" )
289
-
290
- self .log_dict (log_dict_ae , prog_bar = False , logger = True , on_step = True , on_epoch = True )
291
- self .log ("temperature" , self .quantize .temperature , prog_bar = False , logger = True , on_step = True , on_epoch = True )
292
- return aeloss
293
-
294
- if optimizer_idx == 1 :
295
- # discriminator
296
- discloss , log_dict_disc = self .loss (qloss , x , xrec , optimizer_idx , self .global_step ,
297
- last_layer = self .get_last_layer (), split = "train" )
298
- self .log_dict (log_dict_disc , prog_bar = False , logger = True , on_step = True , on_epoch = True )
299
- return discloss
300
-
301
- def validation_step (self , batch , batch_idx ):
302
- x = self .get_input (batch , self .image_key )
303
- xrec , qloss = self (x , return_pred_indices = True )
304
- aeloss , log_dict_ae = self .loss (qloss , x , xrec , 0 , self .global_step ,
305
- last_layer = self .get_last_layer (), split = "val" )
306
-
307
- discloss , log_dict_disc = self .loss (qloss , x , xrec , 1 , self .global_step ,
308
- last_layer = self .get_last_layer (), split = "val" )
309
- rec_loss = log_dict_ae ["val/rec_loss" ]
310
- self .log ("val/rec_loss" , rec_loss ,
311
- prog_bar = True , logger = True , on_step = False , on_epoch = True , sync_dist = True )
312
- self .log ("val/aeloss" , aeloss ,
313
- prog_bar = True , logger = True , on_step = False , on_epoch = True , sync_dist = True )
314
- self .log_dict (log_dict_ae )
315
- self .log_dict (log_dict_disc )
316
- return self .log_dict
317
-
318
- def log_images (self , batch , ** kwargs ):
319
- log = dict ()
320
- x = self .get_input (batch , self .image_key )
321
- x = x .to (self .device )
322
- # encode
323
- h = self .encoder (x )
324
- h = self .quant_conv (h )
325
- quant , _ , _ = self .quantize (h )
326
- # decode
327
- x_rec = self .decode (quant )
328
- log ["inputs" ] = x
329
- log ["reconstructions" ] = x_rec
330
- return log
0 commit comments