Skip to content

Commit 6aa0374

Browse files
committed
register tokens for the AST in VAAT
1 parent b35a97d commit 6aa0374

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vit-pytorch"
7-
version = "1.15.5"
7+
version = "1.15.6"
88
description = "Vision Transformer (ViT) - Pytorch"
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

vit_pytorch/vaat.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ def __init__(
215215
spec_hop_length = None,
216216
spec_pad = 0,
217217
spec_center = True,
218-
spec_pad_mode = 'reflect'
218+
spec_pad_mode = 'reflect',
219+
num_register_tokens = 4
219220
):
220221
super().__init__()
221222
self.dim = dim
@@ -256,8 +257,11 @@ def __init__(
256257
)
257258

258259
self.final_norm = nn.LayerNorm(dim)
260+
259261
self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
260262

263+
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
264+
261265
def forward(
262266
self,
263267
raw_audio_or_spec, # (b t) | (b f t)
@@ -296,6 +300,12 @@ def forward(
296300

297301
tokens = rearrange(tokens, 'b ... c -> b (...) c')
298302

303+
# register tokens
304+
305+
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
306+
307+
tokens, packed_shape = pack((register_tokens, tokens), 'b * d')
308+
299309
# attention
300310

301311
attended, hiddens = self.transformer(tokens, return_hiddens = True)
@@ -307,6 +317,8 @@ def forward(
307317
if return_hiddens:
308318
return normed, stack(hiddens)
309319

320+
register_tokens, normed = unpack(normed, packed_shape, 'b * d')
321+
310322
pooled = reduce(normed, 'b n d -> b d', 'mean')
311323

312324
maybe_logits = self.mlp_head(pooled)
@@ -384,7 +396,7 @@ def forward(self, img, return_hiddens = False):
384396
if return_hiddens:
385397
return x, stack(hiddens)
386398

387-
cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
399+
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
388400

389401
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
390402

vit_pytorch/vat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def forward(self, img, return_hiddens = False):
237237
if return_hiddens:
238238
return x, stack(hiddens)
239239

240-
cls_tokens, x, register_tokens = unpack(x, packed_shape, 'b * d')
240+
register_tokens, cls_tokens, x = unpack(x, packed_shape, 'b * d')
241241

242242
x = x.mean(dim = 1) if self.pool == 'mean' else cls_tokens
243243

0 commit comments

Comments
 (0)