Skip to content

Commit

Permalink
Switch back to adaptivepool from resizeright in MakeCutouts
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Sepiol committed Oct 11, 2021
1 parent 7f8619b commit 5bf3174
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 7 deletions.
1 change: 1 addition & 0 deletions cgd/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


def range_loss(input):
"""(Katherine Crowson) - Spherical distance loss"""
return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])


Expand Down
8 changes: 1 addition & 7 deletions cgd/modules.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import torch as th
import torch.nn.functional as tf
import torchvision.transforms as tvt
import torchvision.transforms.functional as tvtf
from torchvision.transforms.transforms import RandomAdjustSharpness, RandomVerticalFlip

from cgd.ResizeRight import interp_methods, resize_right


class MakeCutouts(th.nn.Module):
def __init__(self, cut_size: int, num_cutouts: int, cutout_size_power: float = 1.0, use_augs: bool = False):
Expand Down Expand Up @@ -37,9 +32,8 @@ def forward(self, input: th.Tensor):
offsetx = th.randint(0, side_x - size + 1, ())
offsety = th.randint(0, side_y - size + 1, ())
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
cutout = resize_right.resize(cutout, out_shape=[self.cut_size, self.cut_size], interp_method=interp_methods.lanczos3, by_convs=True)
cutout = self.augs(cutout)
# cutout = tf.adaptive_avg_pool2d(cutout, self.cut_size)
cutout = tf.adaptive_avg_pool2d(cutout, self.cut_size)
cutouts.append(cutout)

return th.cat(cutouts)

0 comments on commit 5bf3174

Please sign in to comment.