@@ -48,13 +48,15 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()

-    def forward(self, x):
+    def forward(self, x, mask = None):
         x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

+        if mask is not None:
+            dots = dots.masked_fill(rearrange(mask, 'b j -> b 1 1 j') == 0, float('-inf'))
         attn = self.attend(dots)
         attn = self.dropout(attn)

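Note on the mask broadcast (not part of the diff): the key-padding mask is reshaped to (batch, 1, 1, keys) so it broadcasts against the (batch, heads, queries, keys) score tensor; key positions marked 0/False are set to -inf before the softmax and therefore receive zero attention weight. A minimal self-contained sketch with arbitrary shapes:

import torch

b, h, n = 2, 8, 5
dots = torch.randn(b, h, n, n)                 # attention scores
mask = torch.tensor([[1, 1, 1, 0, 0],          # sample 0: last two key positions are padding
                     [1, 1, 1, 1, 1]]).bool()  # sample 1: nothing masked
dots = dots.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
attn = dots.softmax(dim = -1)                  # masked keys get exactly zero weight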
@@ -72,9 +74,9 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
-    def forward(self, x):
+    def forward(self, x, mask = None):
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, mask = mask) + x
             x = ff(x) + x
         return self.norm(x)

@@ -166,7 +168,7 @@ def __init__(
         self.mlp_head = nn.Linear(dim, num_classes)
         self.variant = variant

-    def forward(self, video):
+    def forward(self, video, mask = None):
         x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape

@@ -197,10 +199,15 @@ def forward(self, video):

         x = torch.cat((temporal_cls_tokens, x), dim = 1)

+        if mask is not None:
+            temporal_mask = torch.ones((b, f + 1), device = x.device, dtype = torch.bool)
+            temporal_mask[:, 1:] = mask
+        else:
+            temporal_mask = None

         # attend across time

-        x = self.temporal_transformer(x)
+        x = self.temporal_transformer(x, mask = temporal_mask)

         # excise out temporal cls token or average pool

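For context, a minimal usage sketch of the new mask argument. This is not from the commit: it assumes the factorized-encoder ViViT class (ViT in vit_pytorch/vivit.py) with the changes above applied, and the constructor hyperparameters are illustrative. The mask is boolean with one entry per temporal token (frames / frame_patch_size), True for real frames and False for padding; the temporal cls token at position 0 is always kept.

import torch
from vit_pytorch.vivit import ViT   # assumed import path for the patched model

# illustrative hyperparameters (not from the commit)
model = ViT(
    image_size = 128,
    frames = 16,
    image_patch_size = 16,
    frame_patch_size = 2,
    num_classes = 10,
    dim = 512,
    spatial_depth = 4,
    temporal_depth = 4,
    heads = 8,
    mlp_dim = 1024,
    variant = 'factorized_encoder',
)

video = torch.randn(2, 3, 16, 128, 128)            # (batch, channels, frames, height, width)

# one mask entry per temporal token: 16 frames / frame_patch_size 2 = 8 tokens
frame_mask = torch.ones(2, 8, dtype = torch.bool)
frame_mask[0, 6:] = False                          # trailing tokens of sample 0 are padding

logits = model(video, mask = frame_mask)           # (2, num_classes)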