@@ -445,16 +445,16 @@ def __init__(self,
445
445
self .num_windows = int ((input_resolution // window_size ) * (input_resolution // window_size ))
446
446
447
447
def forward (self , x , q_global ):
448
- B , H , W , C = x .shape
449
- shortcut = x
450
- x = self .norm1 (x )
451
- x_windows = window_partition (x , self .window_size )
452
- x_windows = x_windows .view (- 1 , self .window_size * self .window_size , C )
453
- attn_windows = self .attn (x_windows , q_global )
454
- x = window_reverse (attn_windows , self .window_size , H , W )
455
- x = shortcut + self .drop_path (self .gamma1 * x )
456
- x = x + self .drop_path (self .gamma2 * self .mlp (self .norm2 (x )))
457
- return x
448
+ B , H , W , C = x .shape
449
+ shortcut = x
450
+ x = self .norm1 (x )
451
+ x_windows = window_partition (x , self .window_size )
452
+ x_windows = x_windows .view (- 1 , self .window_size * self .window_size , C )
453
+ attn_windows = self .attn (x_windows , q_global )
454
+ x = window_reverse (attn_windows , self .window_size , H , W )
455
+ x = shortcut + self .drop_path (self .gamma1 * x )
456
+ x = x + self .drop_path (self .gamma2 * self .mlp (self .norm2 (x )))
457
+ return x
458
458
459
459
460
460
class GlobalQueryGen (nn .Module ):
@@ -474,6 +474,9 @@ def __init__(self,
474
474
input_resolution: input image resolution.
475
475
window_size: window size.
476
476
num_heads: number of heads.
477
+
478
+ For instance, repeating log2(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at
479
+ a down-sampling ratio of 2. Please see Fig. 5 of the GC ViT paper for details.
477
480
"""
478
481
479
482
super ().__init__ ()
0 commit comments