Skip to content

Commit c3dce22

Browse files
committed
allow for no final output head on the vit
1 parent fb5014f commit c3dce22

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vit-pytorch"
7-
version = "1.17.1"
7+
version = "1.17.2"
88
description = "Vision Transformer (ViT) - Pytorch"
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

vit_pytorch/vit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
113113
self.pool = pool
114114
self.to_latent = nn.Identity()
115115

116-
self.mlp_head = nn.Linear(dim, num_classes)
116+
self.mlp_head = nn.Linear(dim, num_classes) if num_classes > 0 else None
117117

118118
def forward(self, img):
119119
batch = img.shape[0]
@@ -129,6 +129,9 @@ def forward(self, img):
129129

130130
x = self.transformer(x)
131131

132+
if not exists(self.mlp_head):
133+
return x
134+
132135
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
133136

134137
x = self.to_latent(x)

0 commit comments

Comments
 (0)