Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Jan 7, 2021
1 parent 16530e7 commit 7657165
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
6 changes: 4 additions & 2 deletions midaGAN/nn/discriminators/ms_patchgan3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ class MultiScalePatchGAN3D(nn.Module):
def __init__(self, in_channels, ndf, n_layers, kernel_size, scales, norm_type):
super().__init__()
# Multiscale PatchGAN consists of multiple PatchGANs.
self.model = nn.ModuleDict({str(scale): patchgan3d.PatchGAN3D(in_channels, ndf, n_layers, kernel_size, norm_type) \
for scale in range(1, scales + 1)})
self.model = nn.ModuleDict()
for scale in range(1, scales + 1):
self.model[str(scale)] = patchgan3d.PatchGAN3D(in_channels, ndf, n_layers, kernel_size,
norm_type)

def forward(self, input):
model_outputs = {}
Expand Down
4 changes: 1 addition & 3 deletions midaGAN/nn/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,8 @@ def forward(self, prediction: Union[Dict[str, torch.Tensor], torch.Tensor], \
"""
# If prediction is a dict, compute loss and reduce over all keys of the dict
if isinstance(prediction, dict):
loss_list = [self.calculate_loss(pred, target_is_real) \
for k, pred in prediction.items()]
loss_list = [self.calculate_loss(pred, target_is_real) for pred in prediction.values()]
loss = torch.stack(loss_list).mean()

else:
loss = self.calculate_loss(prediction, target_is_real)

Expand Down
3 changes: 1 addition & 2 deletions midaGAN/utils/trackers/wandb_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ def __init__(self, conf):

config_dict = OmegaConf.to_container(conf, resolve=True)

wandb.init(project=project, entity=entity,
config=config_dict) # TODO: project and organization from conf
wandb.init(project=project, entity=entity, config=config_dict)

if conf.logging.wandb.run:
wandb.run.name = conf.logging.wandb.run
Expand Down

0 comments on commit 7657165

Please sign in to comment.