
Commit 4b8e488

Mask for temporal transformer in ViViT
This allows videos to be padded to a fixed length, so the transformer can ignore the padded frames when using batch sizes > 1.
1 parent fb5014f commit 4b8e488
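
For context, a hedged usage sketch of the padded-batch workflow this commit enables. The constructor arguments, input shapes, and the assumption that the mask carries one boolean per frame token (one entry per group of frame_patch_size frames) are inferred from the surrounding vivit.py code rather than taken from this commit:

import torch
from vit_pytorch.vivit import ViT

# Illustrative configuration; check the values against the actual ViT
# signature in vit_pytorch/vivit.py.
model = ViT(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 512,
    spatial_depth = 4,
    temporal_depth = 4,
    heads = 8,
    mlp_dim = 1024
)

# Batch of 4 clips, zero-padded along the frame axis to 16 frames each.
videos = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)

# With frame_patch_size = 2 there are 16 // 2 = 8 frame tokens, so the mask
# has shape (batch, frames // frame_patch_size). Suppose the last two clips
# only contain 8 real frames (4 real frame tokens); mark the rest as padding.
mask = torch.ones(4, 8, dtype = torch.bool)
mask[2:, 4:] = False

preds = model(videos, mask = mask)  # (4, num_classes)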

1 file changed: +12 -5 lines


vit_pytorch/vivit.py

Lines changed: 12 additions & 5 deletions
@@ -48,13 +48,15 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x, mask=None):
         x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
+        if mask is not None:
+            dots = dots.masked_fill(mask[:, None, None, :] == 0, float('-inf'))  # broadcast (b, n) mask over heads and query positions
         attn = self.attend(dots)
         attn = self.dropout(attn)

@@ -72,9 +74,9 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
-    def forward(self, x):
+    def forward(self, x, mask=None):
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, mask=mask) + x
             x = ff(x) + x
         return self.norm(x)

@@ -166,7 +168,7 @@ def __init__(
         self.mlp_head = nn.Linear(dim, num_classes)
         self.variant = variant
 
-    def forward(self, video):
+    def forward(self, video, mask=None):
         x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape

@@ -197,10 +199,15 @@ def forward(self, video):
 
         x = torch.cat((temporal_cls_tokens, x), dim = 1)
 
+        if mask is not None:
+            temporal_mask = torch.ones((b, f+1), device=x.device, dtype=torch.bool)  # extra leading slot keeps the temporal cls token unmasked
+            temporal_mask[:, 1:] = mask
+        else:
+            temporal_mask = None
 
         # attend across time
 
-        x = self.temporal_transformer(x)
+        x = self.temporal_transformer(x, mask=temporal_mask)
 
         # excise out temporal cls token or average pool
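
A minimal standalone sketch (not library code) of the masking mechanic used in Attention above: filling masked key positions with -inf before the softmax drives their attention weights to exactly zero, so padded frame tokens contribute nothing to the weighted sum over values.

import torch
import torch.nn.functional as F

# Toy example: batch of 1, one head, 3 temporal tokens, the last one padding.
dots = torch.randn(1, 1, 3, 3)              # attention logits of shape (b, h, n, n)
mask = torch.tensor([[True, True, False]])  # (b, n); False marks a padded token

dots = dots.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
attn = F.softmax(dots, dim = -1)

print(attn[0, 0])
# The column for the padded token is all zeros, so its value vector never
# contributes to the attention output.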
