Commit ad80b6c

fix positional embed for mean pool case and cleanup
1 parent 0ebd4ed commit ad80b6c

File tree

pyproject.toml
vit_pytorch/vit.py

2 files changed: +19 -13 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.0"
+version = "1.16.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/vit.py

Lines changed: 18 additions & 12 deletions

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn import Module, ModuleList
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
 
 # classes
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -26,7 +27,7 @@ def __init__(self, dim, hidden_dim, dropout = 0.):
     def forward(self, x):
         return self.net(x)
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ def forward(self, x):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
+
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -80,7 +82,7 @@ def forward(self, x):
 
         return self.norm(x)
 
-class ViT(nn.Module):
+class ViT(Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -101,8 +103,9 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             nn.LayerNorm(dim),
         )
 
-        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+        self.num_cls_tokens = num_cls_tokens
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
 
         self.dropout = nn.Dropout(emb_dropout)
 
@@ -114,12 +117,15 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
+        batch = img.shape[0]
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+
+        seq = x.shape[1]
+
+        x = x + self.pos_embedding[:seq]
         x = self.dropout(x)
 
        x = self.transformer(x)
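For context, a minimal usage sketch of the path this commit touches. It assumes, per the commit message, that pool = 'mean' is the pooling mode the fix targets; the constructor keywords come from the signature shown in the diff, while the concrete values below are only illustrative.

import torch
from vit_pytorch import ViT

# pool = 'mean' is the case named in the commit message. The new forward
# computes seq = x.shape[1] after concatenating cls tokens and patch tokens,
# then adds self.pos_embedding[:seq], so the positional embedding always
# matches the actual sequence length instead of the old [:, :(n + 1)] slice,
# which assumed exactly one extra CLS position.
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    pool = 'mean',      # default is 'cls', per the signature in the diff
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)          # (1, 1000)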
