-
Notifications
You must be signed in to change notification settings - Fork 7
/
layers.py
97 lines (77 loc) · 3.06 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn as nn
import torch.nn.functional as nnf
# WARP FUNCTION
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer

    Resamples a source tensor at the locations obtained by adding a dense
    displacement field (expressed in voxel units) to an identity sampling grid.
    The permutation logic in ``forward`` covers 2-D and 3-D spatial inputs.
    """

    def __init__(self, size, mode='bilinear'):
        """
        Args:
            size: iterable with the extent of every spatial dimension; the
                identity grid is built with shape ``(1, len(size), *size)``.
            mode: interpolation mode forwarded to ``grid_sample``.
        """
        super().__init__()

        self.mode = mode

        # create sampling grid: channel i holds the index along spatial dim i.
        # 'ij' indexing is stated explicitly because relying on the implicit
        # default of torch.meshgrid is deprecated.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(*vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.float()

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. the buffer is intentionally kept persistent so that
        # previously saved checkpoints (which contain 'grid') keep loading.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """
        Warp ``src`` with the displacement field ``flow``.

        Args:
            src: tensor to resample; handed to ``grid_sample`` as its input.
            flow: displacement field of shape ``(batch, ndims, *spatial)``,
                in voxel units, broadcast-compatible with the identity grid.

        Returns:
            The result of ``grid_sample`` evaluated at ``grid + flow``.
        """
        shape = flow.shape[2:]

        # new locations (absolute voxel coordinates)
        new_locs = self.grid + flow

        # need to normalize grid values to [-1, 1] for resampler. index 0 maps
        # to -1 and index (extent - 1) maps to +1, i.e. the corner-aligned
        # convention, which is why align_corners=True is required below.
        for i, extent in enumerate(shape):
            # a singleton dimension would otherwise divide by zero and fill
            # the sampling grid with NaNs
            span = max(extent - 1, 1)
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / span - 0.5)

        # move channels dim to last position.
        # grid_sample expects the last axis ordered as (x, y[, z]) where x
        # addresses the *last* spatial dimension, whereas our channels are in
        # index order (dim0, dim1, ...), so the channels are reversed.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, mode=self.mode, align_corners=True)
class VecInt(nn.Module):
    """
    Scaling-and-squaring integrator for a vector field.

    The input field is multiplied by ``1 / 2**nsteps`` and then composed with
    itself ``nsteps`` times, each composition warping the current field by
    itself through a ``SpatialTransformer`` and adding the result.
    """

    def __init__(self, inshape, nsteps):
        """
        Args:
            inshape: spatial shape handed to the internal ``SpatialTransformer``.
            nsteps: number of squaring steps; must be non-negative.
        """
        super().__init__()

        assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
        self.nsteps = nsteps
        # initial down-scaling so that nsteps doublings recover the full field
        self.scale = 1.0 / (2 ** self.nsteps)
        self.transformer = SpatialTransformer(inshape)

    def forward(self, vec):
        """Return the integrated field for the input vector field ``vec``."""
        field = vec * self.scale
        for _step in range(self.nsteps):
            # compose the field with itself: warp it by itself, then add
            warped = self.transformer(field, field)
            field = field + warped
        return field
class ResizeTransform(nn.Module):
    """
    Resize a transform: the vector field is spatially resampled *and* its
    values are multiplied by the same factor.
    """

    def __init__(self, vel_resize, ndims):
        """
        Args:
            vel_resize: reciprocal of the scale factor applied in ``forward``.
            ndims: number of spatial dimensions; 2 selects 'bilinear',
                3 selects 'trilinear', anything else leaves 'linear'.
        """
        super().__init__()
        self.factor = 1.0 / vel_resize
        prefix = 'bi' if ndims == 2 else ('tri' if ndims == 3 else '')
        self.mode = prefix + 'linear'

    def forward(self, x):
        """Return ``x`` resampled and rescaled by ``self.factor``."""
        factor = self.factor

        if factor < 1:
            # shrinking: resize first to save memory, then rescale the values
            shrunk = nnf.interpolate(x, scale_factor=factor, mode=self.mode)
            return factor * shrunk

        if factor > 1:
            # growing: multiply first to save memory, then resize
            scaled = factor * x
            return nnf.interpolate(scaled, scale_factor=factor, mode=self.mode)

        # don't do anything if resize is 1
        return x