Commit 4c89017

fix up vivit

1 parent 580258d commit 4c89017

2 files changed, +105 -45 lines changed
pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.17.3"
+version = "1.17.4"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
vit_pytorch/vivit.py

Lines changed: 104 additions & 44 deletions

@@ -1,25 +1,28 @@
 from collections import namedtuple
 
 import torch
+from torch import nn, cat
 import torch.nn.functional as F
-from torch import nn
+from torch.nn import Module, ModuleList
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
-
 # helpers
 
 def exists(val):
     return val is not None
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 # classes
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
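
The new divisible_by helper (together with the existing pair) backs the constructor's shape checks later in this diff. A quick sketch of how the two helpers behave; the values are illustrative only:

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(num, den):
    return (num % den) == 0

image_height, image_width = pair(256)       # a bare int becomes (256, 256)
patch_height, patch_width = pair((32, 16))  # a tuple passes through unchanged

assert divisible_by(image_height, patch_height)  # 256 % 32 == 0
assert divisible_by(image_width, patch_width)    # 256 % 16 == 0
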
@@ -33,7 +36,7 @@ def __init__(self, dim, hidden_dim, dropout = 0.):
     def forward(self, x):
         return self.net(x)
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn = True):
         super().__init__()
         self.use_flash_attn = use_flash_attn
@@ -55,80 +58,101 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_flash_attn =
             nn.Dropout(dropout)
         ) if project_out else nn.Identity()
 
-    def flash_attn(self, q, k, v, mask=None):
+    def flash_attn(self, q, k, v, mask = None):
+
         with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
-            out = F.scaled_dot_product_attention(q, k, v,
-                attn_mask=mask,
-                dropout_p=self.dropout_p,
-                is_causal=False,
-                scale=self.scale)
+
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask = mask,
+                dropout_p = self.dropout_p,
+                is_causal = False,
+                scale = self.scale
+            )
 
         return out
 
-    def forward(self, x, mask=None):
-        B, F, _ = x.size()
+    def forward(self, x, mask = None):
+        batch, seq, _ = x.shape
+
         x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+
         if self.use_flash_attn:
-            out = self.flash_attn(q, k, v, mask=mask)
+            out = self.flash_attn(q, k, v, mask = mask)
 
         else:
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
 
-            if mask is not None:
-                dots = dots.masked_fill(mask.view(B, 1, 1, F) == 0, float('-inf'))
+            if exists(mask):
+                mask = rearrange(mask, 'b j -> b 1 1 j')
+                dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
+
             attn = self.attend(dots)
            attn = self.dropout(attn)
 
             out = torch.matmul(attn, v)
+
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
         super().__init__()
         self.use_flash_attn = use_flash_attn
+
         self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
-    def forward(self, x, mask=None):
+
+    def forward(self, x, mask = None):
+
         for attn, ff in self.layers:
-            x = attn(x, mask=mask) + x
+            x = attn(x, mask = mask) + x
             x = ff(x) + x
+
         return self.norm(x)
 
-class FactorizedTransformer(nn.Module):
+class FactorizedTransformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_flash_attn = True):
         super().__init__()
         self.use_flash_attn = use_flash_attn
+
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
+
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
 
-    def forward(self, x):
-        b, f, n, _ = x.shape
+    def forward(self, x, mask = None):
+        batch, frames, seq, _ = x.shape
+
+        if exists(mask):
+            mask = repeat(mask, 'b ... -> (b space) ...', space = x.shape[2])
+
         for spatial_attn, temporal_attn, ff in self.layers:
             x = rearrange(x, 'b f n d -> (b f) n d')
             x = spatial_attn(x) + x
-            x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
-            x = temporal_attn(x) + x
+            x = rearrange(x, '(b f) n d -> (b n) f d', b = batch, f = frames)
+            x = temporal_attn(x, mask = mask) + x
             x = ff(x) + x
-            x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
+            x = rearrange(x, '(b n) f d -> b f n d', b = batch, n = seq)
 
         return self.norm(x)
 
-class ViT(nn.Module):
+class ViViT(Module):
     def __init__(
         self,
         *,
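
The reworked Attention above broadcasts an optional boolean padding mask to shape (batch, 1, 1, keys) and passes it to F.scaled_dot_product_attention as attn_mask. A minimal standalone sketch of that call with toy shapes (assumes PyTorch 2.3+ for torch.nn.attention.sdpa_kernel; the backend list here omits CUDNN_ATTENTION, which needs a more recent PyTorch):

import torch
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention import SDPBackend, sdpa_kernel

# toy shapes: batch 2, 4 heads, 6 tokens, head dimension 8
q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))

# boolean key-padding mask: True = keep, False = ignore (second sample has 3 padded tokens)
mask = torch.tensor([[True] * 6, [True] * 3 + [False] * 3])
mask = rearrange(mask, 'b j -> b 1 1 j')  # broadcast over heads and query positions

with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask = mask,   # False key positions are excluded before the softmax
        dropout_p = 0.,
        is_causal = False
    )

assert out.shape == (2, 4, 6, 8)

With a boolean attn_mask, True marks key positions that may be attended to, the same convention as the masked_fill(~mask, ...) fallback in the non-flash branch.
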
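
FactorizedTransformer.forward now also threads a frame-level mask into its temporal attention step. Because temporal attention runs on sequences folded to (batch * spatial positions, frames, dim), the mask has to be repeated once per spatial position so its rows line up with the folded batch. A small sketch of that bookkeeping with toy sizes (the names frame_mask and temporal_x are illustrative, not from the library):

import torch
from einops import rearrange, repeat

batch, frames, seq, dim = 2, 4, 6, 8  # toy sizes
x = torch.randn(batch, frames, seq, dim)

# one boolean flag per frame; the second clip has its last two frames masked out
frame_mask = torch.tensor([[True, True, True, True],
                           [True, True, False, False]])

# fold spatial positions into the batch for temporal attention, then repeat the
# frame mask once per spatial position (batch-major, space-minor, matching the
# '(b n)' grouping used in the rearrange)
temporal_x = rearrange(x, 'b f n d -> (b n) f d')
temporal_mask = repeat(frame_mask, 'b ... -> (b space) ...', space = seq)

assert temporal_x.shape == (batch * seq, frames, dim)
assert temporal_mask.shape == (batch * seq, frames)
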
@@ -154,8 +178,8 @@ def __init__(
         image_height, image_width = pair(image_size)
         patch_height, patch_width = pair(image_patch_size)
 
-        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
-        assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
+        assert divisible_by(image_height, patch_height) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+        assert divisible_by(frames, frame_patch_size), 'Frames must be divisible by frame patch size'
         assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
 
         num_image_patches = (image_height // patch_height) * (image_width // patch_width)

@@ -165,6 +189,8 @@ def __init__(
 
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
 
+        self.frame_patch_size = frame_patch_size
+
         self.global_average_pool = pool == 'mean'
 
         self.to_patch_embedding = nn.Sequential(
@@ -193,25 +219,36 @@ def __init__(
         self.mlp_head = nn.Linear(dim, num_classes)
         self.variant = variant
 
-    def forward(self, video, mask=None):
+    def forward(self, video, mask = None):
+        device = video.device
+
         x = self.to_patch_embedding(video)
-        b, f, n, _ = x.shape
+        batch, frames, seq, _ = x.shape
 
-        x = x + self.pos_embedding[:, :f, :n]
+        x = x + self.pos_embedding[:, :frames, :seq]
 
         if exists(self.spatial_cls_token):
-            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-            x = torch.cat((spatial_cls_tokens, x), dim = 2)
+            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = batch, f = frames)
+            x = cat((spatial_cls_tokens, x), dim = 2)
 
         x = self.dropout(x)
 
+        # maybe temporal mask
+
+        temporal_mask = None
+
+        if exists(mask):
+            temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = self.frame_patch_size)
+
+        # the two variants
+
         if self.variant == 'factorized_encoder':
             x = rearrange(x, 'b f n d -> (b f) n d')
 
             # attend across space
 
             x = self.spatial_transformer(x)
-            x = rearrange(x, '(b f) n d -> b f n d', b = b)
+            x = rearrange(x, '(b f) n d -> b f n d', b = batch)
 
             # excise out the spatial cls tokens or average pool for temporal attention
 
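ViViT.forward now derives the temporal mask by pooling the caller's per-frame boolean mask down to one flag per frame patch, using einops' reduce with the 'all' reduction: a frame patch stays attendable only if every raw frame inside it is unmasked. A tiny standalone sketch with toy values (assumes an einops version that supports boolean 'all' reductions):

import torch
from einops import reduce

frame_patch_size = 2

# per-frame mask for one 8-frame clip whose last two frames are padding
mask = torch.tensor([[True, True, True, True, True, True, False, False]])

# pool pairs of frames into frame patches; a patch is kept only if all its frames are kept
temporal_mask = reduce(mask, 'b (f patch) -> b f', 'all', patch = frame_patch_size)

print(temporal_mask)  # tensor([[ True,  True,  True, False]])
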
@@ -220,27 +257,50 @@ def forward(self, video, mask=None):
             # append temporal CLS tokens
 
             if exists(self.temporal_cls_token):
-                temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+                temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = batch)
 
-                x = torch.cat((temporal_cls_tokens, x), dim = 1)
-
-                if mask is not None:
-                    temporal_mask = torch.ones((b, f+1), device=x.device, dtype=torch.bool)
-                    temporal_mask[:, 1:] = mask
-                else:
-                    temporal_mask = None
+                x = cat((temporal_cls_tokens, x), dim = 1)
+
+                if exists(temporal_mask):
+                    temporal_mask = F.pad(temporal_mask, (1, 0), value = True)
 
             # attend across time
 
-            x = self.temporal_transformer(x, mask=temporal_mask)
+            x = self.temporal_transformer(x, mask = temporal_mask)
 
             # excise out temporal cls token or average pool
 
             x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
 
         elif self.variant == 'factorized_self_attention':
-            x = self.factorized_transformer(x)
+
+            x = self.factorized_transformer(x, mask = temporal_mask)
+
             x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
 
         x = self.to_latent(x)
         return self.mlp_head(x)
+
+# main
+
+if __name__ == '__main__':
+
+    vivit = ViViT(
+        dim = 512,
+        spatial_depth = 2,
+        temporal_depth = 2,
+        heads = 4,
+        mlp_dim = 2048,
+        image_size = 256,
+        image_patch_size = 16,
+        frames = 8,
+        frame_patch_size = 2,
+        num_classes = 1000,
+        variant = 'factorized_encoder',
+    )
+
+    video = torch.randn(3, 3, 8, 256, 256)
+    mask = torch.randint(0, 2, (3, 8)).bool()
+
+    logits = vivit(video, mask = None)
+    assert logits.shape == (3, 1000)
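
The demo above constructs a per-frame mask but then calls the model with mask = None. A hedged sketch of exercising the new masking path end to end, reusing the demo's hyperparameters (assumes this commit's ViViT is importable from vit_pytorch.vivit):

import torch
from vit_pytorch.vivit import ViViT

vivit = ViViT(
    dim = 512,
    spatial_depth = 2,
    temporal_depth = 2,
    heads = 4,
    mlp_dim = 2048,
    image_size = 256,
    image_patch_size = 16,
    frames = 8,
    frame_patch_size = 2,
    num_classes = 1000,
    variant = 'factorized_encoder'
)

video = torch.randn(3, 3, 8, 256, 256)  # (batch, channels, frames, height, width)

# one boolean per raw frame: True = real frame, False = padded frame
mask = torch.ones(3, 8, dtype = torch.bool)
mask[:, 6:] = False  # pretend the last two frames of every clip are padding

# internally the mask is pooled to frame patches, padded with True for the
# temporal CLS token, and applied inside the temporal transformer
logits = vivit(video, mask = mask)
assert logits.shape == (3, 1000)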

0 commit comments
