+from collections import namedtuple
+
 import torch
+import torch.nn.functional as F
 from torch import nn
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
+
 # helpers
 
 def exists(val):
@@ -29,8 +34,10 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
         super().__init__()
+        self.use_flash_attn = use_flash_attn
+        self.dropout_p = dropout
         inner_dim = dim_head * heads
         project_out = not (heads == 1 and dim_head == dim)
 
@@ -48,45 +55,64 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
 
-    def forward(self, x):
+    def flash_attn(self, q, k, v, mask = None):
+        # key padding mask arrives as (b, n) booleans (True = attend); broadcast it
+        # over heads and query positions for scaled_dot_product_attention
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask = mask,
+                dropout_p = self.dropout_p if self.training else 0.,
+                is_causal = False,
+                scale = self.scale
+            )
+
+        return out
+
+    def forward(self, x, mask = None):
         x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
-        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        if self.use_flash_attn:
+            out = self.flash_attn(q, k, v, mask = mask)
+        else:
+            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
-        attn = self.attend(dots)
-        attn = self.dropout(attn)
+            if exists(mask):
+                # mask is (b, n) booleans, True = attend; broadcast over heads and queries
+                dots = dots.masked_fill(rearrange(mask, 'b j -> b 1 1 j') == 0, float('-inf'))
 
-        out = torch.matmul(attn, v)
+            attn = self.attend(dots)
+            attn = self.dropout(attn)
+
+            out = torch.matmul(attn, v)
+
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
         super().__init__()
+        self.use_flash_attn = use_flash_attn
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
-    def forward(self, x):
+    def forward(self, x, mask = None):
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, mask = mask) + x
             x = ff(x) + x
         return self.norm(x)
 
 class FactorizedTransformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
         super().__init__()
+        self.use_flash_attn = use_flash_attn
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
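A quick way to sanity-check the new flash path against the existing manual path (a sketch, assuming Attention is imported from vit_pytorch/vivit.py after this change; with dropout = 0 the two paths should agree numerically):

import torch
from vit_pytorch.vivit import Attention

attn_flash = Attention(dim = 64, heads = 4, dim_head = 16, dropout = 0., use_flash_attn = True)
attn_manual = Attention(dim = 64, heads = 4, dim_head = 16, dropout = 0., use_flash_attn = False)
attn_manual.load_state_dict(attn_flash.state_dict())   # share weights so the outputs are comparable

x = torch.randn(2, 10, 64)
mask = torch.ones(2, 10, dtype = torch.bool)
mask[:, -3:] = False                                    # mask out the last three tokens as keys

with torch.no_grad():
    out_flash = attn_flash(x, mask = mask)
    out_manual = attn_manual(x, mask = mask)

assert torch.allclose(out_flash, out_manual, atol = 1e-5)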
@@ -122,6 +148,7 @@ def __init__(
         dropout = 0.,
         emb_dropout = 0.,
         variant = 'factorized_encoder',
+        use_flash_attn = True,
     ):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -154,19 +181,19 @@ def __init__(
 
         if variant == 'factorized_encoder':
             self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
-            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
-            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
+            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
+            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
         elif variant == 'factorized_self_attention':
             assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
-            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
+            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
 
         self.pool = pool
         self.to_latent = nn.Identity()
 
         self.mlp_head = nn.Linear(dim, num_classes)
         self.variant = variant
 
-    def forward(self, video):
+    def forward(self, video, mask = None):
         x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape
 
@@ -197,10 +224,15 @@ def forward(self, video):
 
                 x = torch.cat((temporal_cls_tokens, x), dim = 1)
 
+            if exists(mask):
+                # frame-level boolean mask of shape (b, f), True = attend
+                if exists(self.temporal_cls_token):
+                    # prepend an always-attended entry for the temporal cls token
+                    temporal_mask = torch.ones((b, f + 1), device = x.device, dtype = torch.bool)
+                    temporal_mask[:, 1:] = mask
+                else:
+                    temporal_mask = mask
+            else:
+                temporal_mask = None
 
             # attend across time
 
-            x = self.temporal_transformer(x)
+            x = self.temporal_transformer(x, mask = temporal_mask)
 
             # excise out temporal cls token or average pool
 
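For reference, a usage sketch of the new arguments (assuming the model class is the ViT defined in vit_pytorch/vivit.py and the constructor arguments follow the repository README; the names and sizes below are illustrative):

import torch
from vit_pytorch.vivit import ViT

model = ViT(
    image_size = 128,
    frames = 16,
    image_patch_size = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,
    temporal_depth = 6,
    heads = 8,
    mlp_dim = 2048,
    variant = 'factorized_encoder',
    use_flash_attn = True                   # flag added in this change
)

video = torch.randn(2, 3, 16, 128, 128)     # (batch, channels, frames, height, width)

# one boolean per temporal patch (frames // frame_patch_size = 8), True = attend
frame_mask = torch.ones(2, 8, dtype = torch.bool)
frame_mask[:, -2:] = False                  # e.g. the trailing frames are padding

preds = model(video, mask = frame_mask)     # (2, 1000)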