Skip to content

Commit e619c9c

Browse files
authored
Use pytorch way to do escape NaN
The old tensorflow-style grammar still cause NaN output when wi is very small.
1 parent ee1fd75 commit e619c9c

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

model/networks.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,7 @@ def forward(self, f, b, mask=None):
285285
if self.use_cuda:
286286
escape_NaN = escape_NaN.cuda()
287287
wi = wi[0] # [L, C, k, k]
288-
max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
289-
axis=[1, 2, 3],
290-
keepdim=True)),
291-
escape_NaN)
288+
max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
292289
wi_normed = wi / max_wi
293290
# xi shape: [1, C, H, W], yi shape: [1, L, H, W]
294291
xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W

0 commit comments

Comments
 (0)