
Commit 3cff5e5

address #352
1 parent fdaf7f9 commit 3cff5e5

6 files changed, +6 -6 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.1"
+version = "1.16.2"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/na_vit_nested_tensor_3d.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def __init__(
 
         self.channels = channels
         self.patch_size = patch_size
-        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
+        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
 
         self.to_patch_embedding = nn.Sequential(
             nn.LayerNorm(patch_dim),
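
For context, a minimal sketch of what this reordering changes (not part of the commit; it assumes torch and einops are installed and uses made-up toy sizes): both the old and the new pattern produce the same output shape, only the element order inside each flattened patch vector differs, so the patch embedding that follows sees its inputs permuted.

import torch
from einops.layers.torch import Rearrange

# toy sizes: channels, frame grid, frame patch, spatial grid, spatial patch
c, f, pf, h, p1, w, p2 = 3, 2, 2, 4, 16, 4, 16
video = torch.randn(c, f * pf, h * p1, w * p2)

old = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = p1, p2 = p2, pf = pf)
new = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = p1, p2 = p2, pf = pf)

# identical shape either way: (f, h, w, c * pf * p1 * p2)
assert old(video).shape == new(video).shape == (2, 4, 4, 1536)
# but each patch vector is a permutation of the other layout
assert not torch.equal(old(video), new(video))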

vit_pytorch/simple_flash_attn_vit_3d.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
         patch_dim = channels * patch_height * patch_width * frame_patch_size
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),

vit_pytorch/simple_vit_3d.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
         patch_dim = channels * patch_height * patch_width * frame_patch_size
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),

vit_pytorch/vit_3d.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),

vit_pytorch/vivit.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def __init__(
         self.global_average_pool = pool == 'mean'
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim)
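
A similar sketch (again illustrative only, with assumed toy dimensions) for the batched layout that simple_flash_attn_vit_3d, simple_vit_3d, vit_3d and vivit now share, where the frame patch axis pf comes ahead of the spatial patch axes in each flattened token:

import torch
from einops import rearrange

b, c, f, pf, h, p1, w, p2 = 2, 3, 2, 2, 4, 16, 4, 16
video = torch.randn(b, c, f * pf, h * p1, w * p2)

# vit_3d-style flattening of all patches into a single token axis
tokens = rearrange(video, 'b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)',
                   p1 = p1, p2 = p2, pf = pf)
assert tokens.shape == (b, f * h * w, pf * p1 * p2 * c)  # (2, 32, 1536)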
