 import torch
 from torch import nn
+from torch.nn import Module, ModuleList

 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):

 # classes

-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -26,7 +27,7 @@ def __init__(self, dim, hidden_dim, dropout = 0.):
     def forward(self, x):
         return self.net(x)

-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ def forward(self, x):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
+
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -80,7 +82,7 @@ def forward(self, x):

         return self.norm(x)

-class ViT(nn.Module):
+class ViT(Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -101,8 +103,9 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             nn.LayerNorm(dim),
         )

-        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+        self.num_cls_tokens = num_cls_tokens
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))

         self.dropout = nn.Dropout(emb_dropout)

@@ -114,12 +117,15 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         self.mlp_head = nn.Linear(dim, num_classes)

     def forward(self, img):
+        batch = img.shape[0]
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
-        x = torch.cat((cls_tokens, x), dim = 1)
-        x += self.pos_embedding[:, :(n + 1)]
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+
+        seq = x.shape[1]
+
+        x = x + self.pos_embedding[:seq]
         x = self.dropout(x)

         x = self.transformer(x)
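
Note: both parameters lose their leading batch dimension of 1, so the forward pass now relies on `repeat` for the cls tokens and on plain broadcasting for the positional embedding. A minimal sketch of that shape logic; the concrete sizes below are arbitrary, chosen only for illustration:

```python
import torch
from einops import repeat

batch, num_patches, num_cls_tokens, dim = 2, 64, 1, 256   # illustrative sizes

# parameters without a leading batch dimension, as in the new version
cls_token = torch.randn(num_cls_tokens, dim)
pos_embedding = torch.randn(num_patches + num_cls_tokens, dim)

x = torch.randn(batch, num_patches, dim)                  # stand-in for the patch embeddings

cls_tokens = repeat(cls_token, '... d -> b ... d', b = batch)
x = torch.cat((cls_tokens, x), dim = 1)                   # (batch, seq, dim)

seq = x.shape[1]
x = x + pos_embedding[:seq]                               # (seq, dim) broadcasts over the batch axis

print(x.shape)  # torch.Size([2, 65, 256])
```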
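For reference, a usage sketch of the class after this change. The keyword arguments follow the constructor signature in the diff; the import path and the concrete hyperparameter values are assumptions for illustration only:

```python
import torch
from vit_pytorch import ViT  # assumed import path

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```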