@@ -215,7 +215,8 @@ def __init__(
215215 spec_hop_length = None ,
216216 spec_pad = 0 ,
217217 spec_center = True ,
218- spec_pad_mode = 'reflect'
218+ spec_pad_mode = 'reflect' ,
219+ num_register_tokens = 4
219220 ):
220221 super ().__init__ ()
221222 self .dim = dim
@@ -256,8 +257,11 @@ def __init__(
256257 )
257258
258259 self .final_norm = nn .LayerNorm (dim )
260+
259261 self .mlp_head = nn .Linear (dim , num_classes ) if exists (num_classes ) else nn .Identity ()
260262
263+ self .register_tokens = nn .Parameter (torch .randn (num_register_tokens , dim ) * 1e-2 )
264+
261265 def forward (
262266 self ,
263267 raw_audio_or_spec , # (b t) | (b f t)
@@ -296,6 +300,12 @@ def forward(
296300
297301 tokens = rearrange (tokens , 'b ... c -> b (...) c' )
298302
303+ # register tokens
304+
305+ register_tokens = repeat (self .register_tokens , 'n d -> b n d' , b = batch )
306+
307+ tokens , packed_shape = pack ((register_tokens , tokens ), 'b * d' )
308+
299309 # attention
300310
301311 attended , hiddens = self .transformer (tokens , return_hiddens = True )
@@ -307,6 +317,8 @@ def forward(
307317 if return_hiddens :
308318 return normed , stack (hiddens )
309319
320+ register_tokens , normed = unpack (normed , packed_shape , 'b * d' )
321+
310322 pooled = reduce (normed , 'b n d -> b d' , 'mean' )
311323
312324 maybe_logits = self .mlp_head (pooled )
@@ -384,7 +396,7 @@ def forward(self, img, return_hiddens = False):
384396 if return_hiddens :
385397 return x , stack (hiddens )
386398
387- cls_tokens , x , register_tokens = unpack (x , packed_shape , 'b * d' )
399+ register_tokens , cls_tokens , x = unpack (x , packed_shape , 'b * d' )
388400
389401 x = x .mean (dim = 1 ) if self .pool == 'mean' else cls_tokens
390402
0 commit comments