
Commit 580258d

Allow passing a mask parameter for the temporal transformer in ViViT (#356)

* Mask for the temporal transformer in ViViT: allows padding videos to a fixed number of frames so that the transformer can ignore the padded frames when using batch sizes > 1
* Added flash attention to ViViT
1 parent 6f1caef commit 580258d
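
The commit message describes the intended workflow: pad every clip in a batch to a common frame count, then hand the model a boolean mask telling the temporal transformer which positions are real. A minimal data-preparation sketch of that idea follows. It is illustrative only; it assumes the tubelet embedding in this file folds frame_patch_size consecutive frames into one temporal token, so the mask carries one flag per temporal token rather than per raw frame, and the shapes and values below are made up for the example.

    import torch

    batch, frames, frame_patch_size = 2, 16, 2        # illustrative values, not part of this commit

    # two clips, the second only 10 frames long, zero-padded up to 16 frames
    video = torch.randn(batch, 3, frames, 128, 128)   # (batch, channels, frames, height, width)
    video[1, :, 10:] = 0.

    # one boolean per temporal token: True = real frames, False = padding only
    tokens = frames // frame_patch_size               # 16 // 2 = 8 temporal tokens
    mask = torch.ones(batch, tokens, dtype = torch.bool)
    mask[1, 10 // frame_patch_size:] = False          # tokens 5..7 cover only padded frames

    # video and mask are what the new forward(video, mask = ...) signature expects

Whether a token that mixes real and padded frames should stay unmasked is a judgement call; here the padding is aligned to the tubelet boundary, so the question does not come up.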

File tree

1 file changed: +49 / -17 lines


vit_pytorch/vivit.py

Lines changed: 49 additions & 17 deletions
@@ -1,9 +1,14 @@
+from collections import namedtuple
+
 import torch
+import torch.nn.functional as F
 from torch import nn
+from torch.nn.attention import SDPBackend, sdpa_kernel

 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

+
 # helpers

 def exists(val):
@@ -29,8 +34,10 @@ def forward(self, x):
         return self.net(x)

 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
         super().__init__()
+        self.use_flash_attn = use_flash_attn
+        self.dropout_p = dropout
         inner_dim = dim_head * heads
         project_out = not (heads == 1 and dim_head == dim)

@@ -48,45 +55,64 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()

-    def forward(self, x):
+    def flash_attn(self, q, k, v, mask=None):
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
+            out = F.scaled_dot_product_attention(q, k, v,
+                                                 attn_mask=mask,
+                                                 dropout_p=self.dropout_p,
+                                                 is_causal=False,
+                                                 scale=self.scale)
+
+        return out
+
+    def forward(self, x, mask=None):
+        B, F, _ = x.size()
         x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

-        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        if self.use_flash_attn:
+            out = self.flash_attn(q, k, v, mask=mask)

-        attn = self.attend(dots)
-        attn = self.dropout(attn)
+        else:
+            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

-        out = torch.matmul(attn, v)
+            if mask is not None:
+                dots = dots.masked_fill(mask.view(B, 1, 1, F) == 0, float('-inf'))
+            attn = self.attend(dots)
+            attn = self.dropout(attn)
+
+            out = torch.matmul(attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
         super().__init__()
+        self.use_flash_attn = use_flash_attn
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
-    def forward(self, x):
+    def forward(self, x, mask=None):
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, mask=mask) + x
             x = ff(x) + x
         return self.norm(x)

 class FactorizedTransformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
         super().__init__()
+        self.use_flash_attn = use_flash_attn
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
-                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
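
Both branches of the new Attention.forward implement the same key-padding semantics: positions whose mask entry is False are excluded before the softmax. The standalone comparison below (plain PyTorch, not code from this file) checks that the explicit masked_fill path and F.scaled_dot_product_attention with a boolean attn_mask agree; the mask is reshaped to (b, 1, 1, n) so it broadcasts over heads and query positions. Note also that when an attn_mask is supplied, the SDPA dispatcher usually cannot select the flash kernel and falls back to the efficient or math backend, which is one reason the sdpa_kernel call above lists several backends.

    import torch
    import torch.nn.functional as F

    b, h, n, d = 2, 4, 6, 16
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

    # key-padding mask: True = real token, False = padded token
    mask = torch.ones(b, n, dtype = torch.bool)
    mask[:, 4:] = False                               # pretend the last two tokens are padding

    # explicit path, as in the use_flash_attn = False branch above
    dots = torch.matmul(q, k.transpose(-1, -2)) * d ** -0.5
    dots = dots.masked_fill(mask.view(b, 1, 1, n) == 0, float('-inf'))
    out_manual = torch.matmul(dots.softmax(dim = -1), v)

    # SDPA path: a boolean attn_mask keeps True positions and drops False ones
    out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask = mask.view(b, 1, 1, n))

    assert torch.allclose(out_manual, out_sdpa, atol = 1e-5)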

@@ -122,6 +148,7 @@ def __init__(
         dropout = 0.,
         emb_dropout = 0.,
         variant = 'factorized_encoder',
+        use_flash_attn: bool = True,
     ):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -154,19 +181,19 @@ def __init__(

         if variant == 'factorized_encoder':
             self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
-            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
-            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
+            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
+            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)
         elif variant == 'factorized_self_attention':
             assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
-            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
+            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout, use_flash_attn)

         self.pool = pool
         self.to_latent = nn.Identity()

         self.mlp_head = nn.Linear(dim, num_classes)
         self.variant = variant

-    def forward(self, video):
+    def forward(self, video, mask=None):
         x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape
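
For context, this is how the model would be constructed with the new flag. The constructor arguments below mirror the ViViT example in the repository README and are assumptions about the surrounding file rather than part of this diff; only use_flash_attn, and the mask parameter on forward, come from this commit, and the flash path needs a PyTorch recent enough to provide torch.nn.attention.sdpa_kernel.

    import torch
    from vit_pytorch.vivit import ViT

    model = ViT(
        image_size = 128,
        frames = 16,
        image_patch_size = 16,
        frame_patch_size = 2,
        num_classes = 1000,
        dim = 1024,
        spatial_depth = 6,
        temporal_depth = 6,
        heads = 8,
        mlp_dim = 2048,
        variant = 'factorized_encoder',
        use_flash_attn = True        # new flag, threaded through to the transformer blocks
    )

    video = torch.randn(2, 3, 16, 128, 128)   # (batch, channels, frames, height, width)
    preds = model(video)                      # (2, 1000); a boolean mask of shape (batch, frames // frame_patch_size) may also be passed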

@@ -197,10 +224,15 @@ def forward(self, video):

             x = torch.cat((temporal_cls_tokens, x), dim = 1)

+        if mask is not None:
+            temporal_mask = torch.ones((b, f+1), device=x.device, dtype=torch.bool)
+            temporal_mask[:, 1:] = mask
+        else:
+            temporal_mask = None

         # attend across time

-        x = self.temporal_transformer(x)
+        x = self.temporal_transformer(x, mask=temporal_mask)

         # excise out temporal cls token or average pool
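
The point of the temporal mask is that whatever sits in the padded positions cannot leak into the unmasked ones; the always-True first column keeps the temporal CLS slot visible, so with CLS pooling the representation fed to mlp_head depends only on real frames. A standalone sanity check of that invariance (plain PyTorch, not code from this file):

    import torch
    import torch.nn.functional as F

    b, h, n, d = 2, 4, 9, 16                      # e.g. 1 CLS slot + 8 temporal tokens
    q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

    keep = torch.ones(b, 1, 1, n, dtype = torch.bool)
    keep[..., 6:] = False                         # last three positions are padding

    out = F.scaled_dot_product_attention(q, k, v, attn_mask = keep)

    # scramble the content of the masked positions; the output must not change
    k2, v2 = k.clone(), v.clone()
    k2[:, :, 6:] = torch.randn_like(k2[:, :, 6:])
    v2[:, :, 6:] = torch.randn_like(v2[:, :, 6:])
    out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask = keep)

    assert torch.allclose(out, out2, atol = 1e-6)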

0 commit comments
