Skip to content

Commit 3900ff4

Browse files
authored
Merge pull request #19 from NVlabs/dev
update model & checkpoints
2 parents f679914 + c77aae0 commit 3900ff4

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

models/gc_vit.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -445,16 +445,16 @@ def __init__(self,
445445
self.num_windows = int((input_resolution // window_size) * (input_resolution // window_size))
446446

447447
def forward(self, x, q_global):
448-
B, H, W, C = x.shape
449-
shortcut = x
450-
x = self.norm1(x)
451-
x_windows = window_partition(x, self.window_size)
452-
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
453-
attn_windows = self.attn(x_windows, q_global)
454-
x = window_reverse(attn_windows, self.window_size, H, W)
455-
x = shortcut + self.drop_path(self.gamma1 * x)
456-
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
457-
return x
448+
B, H, W, C = x.shape
449+
shortcut = x
450+
x = self.norm1(x)
451+
x_windows = window_partition(x, self.window_size)
452+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
453+
attn_windows = self.attn(x_windows, q_global)
454+
x = window_reverse(attn_windows, self.window_size, H, W)
455+
x = shortcut + self.drop_path(self.gamma1 * x)
456+
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
457+
return x
458458

459459

460460
class GlobalQueryGen(nn.Module):
@@ -474,6 +474,9 @@ def __init__(self,
474474
input_resolution: input image resolution.
475475
window_size: window size.
476476
num_heads: number of heads.
477+
478+
For instance, repeating log2(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at
479+
down-sampling ratio 2. Please check Fig.5 of GC ViT paper for details.
477480
"""
478481

479482
super().__init__()

0 commit comments

Comments
 (0)