from collections import namedtuple

import torch
+ from torch import nn, cat
import torch.nn.functional as F
- from torch import nn
+ from torch.nn import Module, ModuleList
from torch.nn.attention import SDPBackend, sdpa_kernel

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

-
# helpers

def exists(val):
    return val is not None

+ def divisible_by(num, den):
+     return (num % den) == 0
+
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

- class FeedForward(nn.Module):
+ class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
@@ -33,7 +36,7 @@ def __init__(self, dim, hidden_dim, dropout = 0.):
    def forward(self, x):
        return self.net(x)

- class Attention(nn.Module):
+ class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.use_flash_attn = use_flash_attn
@@ -55,80 +58,101 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn =
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

-     def flash_attn(self, q, k, v, mask = None):
+     def flash_attn(self, q, k, v, mask = None):
+
        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
-             out = F.scaled_dot_product_attention(q, k, v,
-                 attn_mask = mask,
-                 dropout_p = self.dropout_p,
-                 is_causal = False,
-                 scale = self.scale)
+
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask = mask,
+                 dropout_p = self.dropout_p,
+                 is_causal = False,
+                 scale = self.scale
+             )

        return out

-     def forward(self, x, mask = None):
-         B, F, _ = x.size()
+     def forward(self, x, mask = None):
+         batch, seq, _ = x.shape
+
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

+         if exists(mask):
+             mask = rearrange(mask, 'b j -> b 1 1 j')
+
        if self.use_flash_attn:
-             out = self.flash_attn(q, k, v, mask = mask)
+             out = self.flash_attn(q, k, v, mask = mask)

        else:
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

-             if mask is not None:
-                 dots = dots.masked_fill(mask.view(B, 1, 1, F) == 0, float('-inf'))
+             if exists(mask):
+                 dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
+
            attn = self.attend(dots)
            attn = self.dropout(attn)

            out = torch.matmul(attn, v)
+
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

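Why the mask is expanded to `b 1 1 j` before either branch: both `F.scaled_dot_product_attention` and the manual `masked_fill` path expect a boolean mask that broadcasts over heads and query positions. A minimal, self-contained sketch (not part of the commit; shapes and values are illustrative) comparing the two paths:

import torch
import torch.nn.functional as F
from einops import rearrange

b, h, n, d = 2, 4, 6, 16
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# key-padding mask: True = keep, False = padded-out position
mask = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])
mask = rearrange(mask, 'b j -> b 1 1 j')          # broadcastable over heads and queries

# fused path: boolean attn_mask, True positions are attended to
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask = mask)

# manual path: fill masked scores with the most negative finite value before softmax
scores = (q @ k.transpose(-1, -2)) * d ** -0.5
scores = scores.masked_fill(~mask, -torch.finfo(scores.dtype).max)
out_manual = scores.softmax(dim = -1) @ v

assert torch.allclose(out_sdpa, out_manual, atol = 1e-5)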
- class Transformer(nn.Module):
+ class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.use_flash_attn = use_flash_attn
+
        self.norm = nn.LayerNorm(dim)
-         self.layers = nn.ModuleList([])
+         self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))
-     def forward(self, x, mask = None):
+
+     def forward(self, x, mask = None):
+
        for attn, ff in self.layers:
-             x = attn(x, mask = mask) + x
+             x = attn(x, mask = mask) + x
            x = ff(x) + x
+
        return self.norm(x)

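For orientation, a hedged usage sketch of the encoder above (not part of the commit; it assumes the classes in this file are already in scope, e.g. by running or importing the module):

import torch

encoder = Transformer(dim = 128, depth = 2, heads = 4, dim_head = 32, mlp_dim = 256)

tokens = torch.randn(2, 10, 128)     # (batch, tokens, dim)
mask = torch.ones(2, 10).bool()      # True = real token, False = padding
mask[1, 7:] = False                  # pad out the tail of the second sequence

out = encoder(tokens, mask = mask)   # the same mask is forwarded to every Attention layer
assert out.shape == (2, 10, 128)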
- class FactorizedTransformer(nn.Module):
+ class FactorizedTransformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
        super().__init__()
        self.use_flash_attn = use_flash_attn
+
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
+
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

-     def forward(self, x):
-         b, f, n, _ = x.shape
+     def forward(self, x, mask = None):
+         batch, frames, seq, _ = x.shape
+
+         if exists(mask):
+             mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])
+
        for spatial_attn, temporal_attn, ff in self.layers:
            x = rearrange(x, 'b f n d -> (b f) n d')
            x = spatial_attn(x) + x
-             x = rearrange(x, '(b f) n d -> (b n) f d', b = b, f = f)
-             x = temporal_attn(x) + x
+             x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
+             x = temporal_attn(x, mask = mask) + x
            x = ff(x) + x
-             x = rearrange(x, '(b n) f d -> b f n d', b = b, n = n)
+             x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)

        return self.norm(x)

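The factorization works by folding one axis into the batch before each attention call, so the temporal mask has to be repeated once per spatial position, in the same `(b n)` order that the rearrange produces. A small, self-contained shape sketch (illustrative only, not from the commit):

import torch
from einops import rearrange, repeat

b, f, n, d = 2, 4, 9, 32                 # batch, frame patches, spatial patches, dim
x = torch.randn(b, f, n, d)
mask = torch.ones(b, f).bool()           # one flag per frame patch

x = rearrange(x, 'b f n d -> (b f) n d')                   # spatial attention: (b*f) sequences of length n
x = rearrange(x, '(b f) n d -> (b n) f d', b = b, f = f)   # temporal attention: (b*n) sequences of length f

temporal_mask = repeat(mask, 'b ... -> (b space) ...', space = n)
assert temporal_mask.shape == (b * n, f)                   # row i lines up with batch row i above

x = rearrange(x, '(b n) f d -> b f n d', b = b, n = n)     # back to the original layout
assert x.shape == (b, f, n, d)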
- class ViT(nn.Module):
+ class ViViT(Module):
    def __init__(
        self,
        *,
@@ -154,8 +178,8 @@ def __init__(
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

-         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
-         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
+         assert divisible_by(image_height, patch_height) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+         assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
        assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'

        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
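As a quick sanity check of the patch arithmetic, using the demo configuration from the `__main__` block added at the bottom of this commit (the frame-patch count is presumably computed in the lines elided just below):

image_height = image_width = 256
patch_height = patch_width = 16
frames, frame_patch_size = 8, 2

num_image_patches = (image_height // patch_height) * (image_width // patch_width)   # 16 * 16 = 256
num_frame_patches = frames // frame_patch_size                                      # 8 // 2 = 4
assert num_image_patches == 256 and num_frame_patches == 4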
@@ -165,6 +189,8 @@ def __init__(

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

+         self.frame_patch_size = frame_patch_size
+
        self.global_average_pool = pool == 'mean'

        self.to_patch_embedding = nn.Sequential(
@@ -193,25 +219,36 @@ def __init__(
        self.mlp_head = nn.Linear(dim, num_classes)
        self.variant = variant

-     def forward(self, video, mask = None):
+     def forward(self, video, mask = None):
+         device = video.device
+
        x = self.to_patch_embedding(video)
-         b, f, n, _ = x.shape
+         batch, frames, seq, _ = x.shape

-         x = x + self.pos_embedding[:, :f, :n]
+         x = x + self.pos_embedding[:, :frames, :seq]

        if exists(self.spatial_cls_token):
-             spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-             x = torch.cat((spatial_cls_tokens, x), dim = 2)
+             spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
+             x = cat((spatial_cls_tokens, x), dim = 2)

        x = self.dropout(x)

+         # maybe temporal mask
+
+         temporal_mask = None
+
+         if exists(mask):
+             temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
+
+         # the two variants
+
        if self.variant == 'factorized_encoder':
            x = rearrange(x, 'b f n d -> (b f) n d')

            # attend across space

            x = self.spatial_transformer(x)
-             x = rearrange(x, '(b f) n d -> b f n d', b = b)
+             x = rearrange(x, '(b f) n d -> b f n d', b = batch)

            # excise out the spatial cls tokens or average pool for temporal attention

@@ -220,27 +257,50 @@ def forward(self, video, mask=None):
            # append temporal CLS tokens

            if exists(self.temporal_cls_token):
-                 temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+                 temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = batch)

-                 x = torch.cat((temporal_cls_tokens, x), dim = 1)
-
-             if mask is not None:
-                 temporal_mask = torch.ones((b, f + 1), device = x.device, dtype = torch.bool)
-                 temporal_mask[:, 1:] = mask
-             else:
-                 temporal_mask = None
+                 x = cat((temporal_cls_tokens, x), dim = 1)
+
+                 if exists(temporal_mask):
+                     temporal_mask = F.pad(temporal_mask, (1, 0), value = True)

            # attend across time

-             x = self.temporal_transformer(x, mask = temporal_mask)
+             x = self.temporal_transformer(x, mask = temporal_mask)

            # excise out temporal cls token or average pool

            x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

        elif self.variant == 'factorized_self_attention':
+
-             x = self.factorized_transformer(x)
+             x = self.factorized_transformer(x, mask = temporal_mask)
+
            x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')

        x = self.to_latent(x)
        return self.mlp_head(x)
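How the temporal mask is derived in the forward above: one boolean per raw frame comes in, a frame patch stays valid only if every frame inside it is valid (the `'all'` reduction), and a leading `True` is padded in so the temporal CLS token is always attended to. An illustrative, self-contained sketch (not from the commit):

import torch
import torch.nn.functional as F
from einops import reduce

frame_patch_size = 2
mask = torch.tensor([[True, True, True, True, False, False, False, False]])   # (batch, frames)

temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = frame_patch_size)
assert temporal_mask.tolist() == [[True, True, False, False]]

temporal_mask = F.pad(temporal_mask, (1, 0), value = True)                     # slot for the CLS token
assert temporal_mask.tolist() == [[True, True, True, False, False]]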
+
+ # main
+
+ if __name__ == '__main__':
+
+     vivit = ViViT(
+         dim = 512,
+         spatial_depth = 2,
+         temporal_depth = 2,
+         heads = 4,
+         mlp_dim = 2048,
+         image_size = 256,
+         image_patch_size = 16,
+         frames = 8,
+         frame_patch_size = 2,
+         num_classes = 1000,
+         variant = 'factorized_encoder',
+     )
+
+     video = torch.randn(3, 3, 8, 256, 256)
+     mask = torch.randint(0, 2, (3, 8)).bool()
+
+     logits = vivit(video, mask = None)
+     assert logits.shape == (3, 1000)
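The smoke test above builds `mask` but calls the model with `mask = None`; to also exercise the masking path, the same objects can presumably be reused as below (a sketch, not part of the commit):

    # one boolean per raw frame; 8 frames / frame_patch_size 2 -> 4 frame patches internally
    masked_logits = vivit(video, mask = mask)
    assert masked_logits.shape == (3, 1000)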