
Commit 7e703f2

get a version of n-dimensional vit with golden gate polar coordinate embeddings into the repo for future use
1 parent 0b7518e commit 7e703f2

File tree

3 files changed: +365 -1 lines changed

README.md
pyproject.toml
vit_pytorch/vit_nd_pope.py

README.md

Lines changed: 12 additions & 0 deletions
@@ -2213,4 +2213,16 @@ Coming from computer vision and new to transformers? Here are some resources that
 }
 ```
 
+```bibtex
+@misc{gopalakrishnan2025decouplingwhatwherepolar,
+    title = {Decoupling the "What" and "Where" With Polar Coordinate Positional Embeddings},
+    author = {Anand Gopalakrishnan and Robert Csordás and Jürgen Schmidhuber and Michael C. Mozer},
+    year = {2025},
+    eprint = {2509.10534},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2509.10534},
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.5"
+version = "1.17.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

vit_pytorch/vit_nd_pope.py

Lines changed: 352 additions & 0 deletions
@@ -0,0 +1,352 @@
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import pi, nn, arange, cat, stack, Tensor
from torch.nn import Module, ModuleList
from torch.amp import autocast

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def join(arr, delimiter = ' '):
    return delimiter.join(arr)

def ensure_tuple(t, length):
    if isinstance(t, (tuple, list)):
        assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}'
        return tuple(t)

    return (t,) * length

# golden gate rotary - Jerry Xiong, PhD student at UIUC
# https://jerryxio.ng/posts/nd-rope/

# but using polar version instead
# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534

def _phi(m: int) -> float:
    x = 2.0
    for _ in range(10):
        x = (1 + x) ** (1.0 / (m + 1.0))
    return x

def make_directions(n: int, d: int) -> Tensor:
    g = _phi(d)
    alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64)
    i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1)
    z = torch.fmod(i * alpha, 1.0)
    directions = torch.erfinv(2.0 * z - 1.0)
    directions = l2norm(directions)
    return directions.float()
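
# note (added for clarity, not part of the original commit): make_directions spreads
# n unit vectors quasi-uniformly over the (d-1)-sphere - a golden-ratio (Kronecker)
# low-discrepancy sequence in the unit cube is pushed through the inverse Gaussian CDF
# (erfinv) and l2-normalized, giving one direction per (head, frequency) pair when
# called as make_directions(heads * dim_head, dim_pos)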

class GoldenGatePoPENd(Module):
    def __init__(
        self,
        dim_pos: int,
        heads: int,
        dim_head: int,
        min_freq: float = 1.0,
        max_freq: float = 10000.0,
        p_zero_freqs: float = 0.0, # proportion of frequencies set to 0
        init_learned_bias_uniform = False
    ):
        super().__init__()
        n_freqs = dim_head
        n_zero_freqs = round(p_zero_freqs * n_freqs)

        omega = cat((
            torch.zeros(n_zero_freqs),
            min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs),
        ))

        directions = rearrange(
            make_directions(heads * n_freqs, dim_pos),
            '(h f) p -> h f p',
            h = heads
        )

        omega_expanded = rearrange(omega, 'f -> f 1')
        self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p)

        self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head))

        if init_learned_bias_uniform:
            nn.init.uniform_(self.learned_bias, -2. * pi, 0.)

    @autocast('cuda', enabled = False)
    def forward(self, pos):

        freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p')
        positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p')

        # compute theta for each (batch, head, seq, freq)

        theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum')

        bias = self.learned_bias.clamp(-2. * pi, 0.)
        bias = rearrange(bias, 'h d -> h 1 d')

        return theta, bias

@autocast('cuda', enabled = False)
def apply_polar_pos_emb(t, freqs):
    orig_dtype = t.dtype

    t = t.float()
    t = F.softplus(t)

    out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1)

    return out.type(orig_dtype)
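
# note (added for clarity, not part of the original commit): apply_polar_pos_emb doubles
# the last dim into (cos, sin) pairs, so a query at position i and a key at position j
# dot to sum_f softplus(q_f) * softplus(k_f) * cos(theta_if - theta_jf - bias_f):
# magnitudes (the "what") come from content through softplus, phases (the "where") come
# from position through theta, with a learned per-head phase bias applied to the keys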

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, polar_pos_emb = None):
        x = self.norm(x)
        qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        if exists(polar_pos_emb):
            freqs, bias = polar_pos_emb
            q = apply_polar_pos_emb(q, freqs)
            k = apply_polar_pos_emb(k, freqs + bias)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.polar_emb = polar_emb

        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x, pos = None):

        # pope embedding

        polar_pos_emb = None
        if exists(pos) and exists(self.polar_emb):
            polar_pos_emb = self.polar_emb(pos)

        # transformer layers

        for attn, ff in self.layers:
            x = attn(x, polar_pos_emb) + x
            x = ff(x) + x

        return self.norm(x)

class ViTND(Module):
    def __init__(
        self,
        *,
        ndim: int,
        input_shape: int | tuple[int, ...],
        patch_size: int | tuple[int, ...],
        num_classes: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        channels: int = 3,
        dim_head: int = 64,
        dropout: float = 0.,
        emb_dropout: float = 0.,
        pope_min_freq: float = 1.0,
        pope_max_freq: float = 10000.0,
        pope_p_zero_freqs: float = 0.0,
        pope_init_learned_bias_uniform = False
    ):
        super().__init__()

        assert 1 <= ndim <= 7, 'ndim must be between 1 and 7'

        self.ndim = ndim

        input_shape = ensure_tuple(input_shape, ndim)
        patch_size = ensure_tuple(patch_size, ndim)

        for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)):
            assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})'

        num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)]
        num_patches = 1
        for n in num_patches_per_dim:
            num_patches *= n

        patch_dim = channels
        for p in patch_size:
            patch_dim *= p

        dim_names = 'fghijkl'[:ndim]

        input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)]
        patch_dims = [f'p{i}' for i in range(ndim)]

        input_pattern = f'b c {join(input_dims)}'
        output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)'
        rearrange_str = f'{input_pattern} -> {output_pattern}'

        rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)}

        self.to_patch_embedding = nn.Sequential(
            Rearrange(rearrange_str, **rearrange_kwargs),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
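
        # note (added for clarity, not part of the original commit): for ndim = 3 the
        # rearrange_str built above is 'b c (f p0) (g p1) (h p2) -> b f g h (p0 p1 p2 c)',
        # i.e. each spatial axis splits into a patch-grid axis and a within-patch axis,
        # and each patch (with channels) is flattened into the features for nn.Linear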

        self.dropout = nn.Dropout(emb_dropout)

        # golden gate pope

        self.polar_emb = GoldenGatePoPENd(
            dim_pos = ndim,
            heads = heads,
            dim_head = dim_head,
            min_freq = pope_min_freq,
            max_freq = pope_max_freq,
            p_zero_freqs = pope_p_zero_freqs,
            init_learned_bias_uniform = pope_init_learned_bias_uniform
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb)

        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def muon_parameters(self):
        params = []

        for m in self.modules():
            if isinstance(m, Attention):
                params.extend([
                    m.to_v.weight,
                    m.to_out[0].weight
                ])
            elif isinstance(m, FeedForward):
                params.extend([
                    m.net[1].weight,
                    m.net[-2].weight
                ])

        return params
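
    # note (added for clarity, not part of the original commit): this appears to collect
    # the 2-d projection weights (attention value/output and feedforward linears) for a
    # Muon-style optimizer, leaving embeddings, norms, to_qk and the classifier head to a
    # standard optimizer such as AdamW - treat the split as a convention, not a requirement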

    def forward(
        self,
        x,
        return_embed = False
    ):
        x = self.to_patch_embedding(x) # (b, *spatial_dims, dim)

        batch, *spatial_dims, _, device = *x.shape, x.device

        # Generate position coordinates

        grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims]
        grid = torch.meshgrid(*grids, indexing = 'ij')
        pos = stack(grid, dim = -1) # (*spatial_dims, ndim)

        # flatten spatial dimensions for attention with nd rotary

        pos = repeat(pos, '... p -> b (...) p', b = batch)
        x, packed_shape = pack([x], 'b * d')

        x = self.dropout(x)

        embed = self.transformer(x, pos)

        # return the embed with reconstituted patch shape

        if return_embed:
            embed, = unpack(embed, packed_shape, 'b * d')
            return embed

        # pooling to logits

        pooled = reduce(embed, 'b n d -> b d', 'mean')

        pooled = self.to_latent(pooled)
        return self.mlp_head(pooled)

if __name__ == '__main__':

    model = ViTND(
        ndim = 5,
        input_shape = (4, 8, 16, 32, 64),
        patch_size = (2, 2, 4, 4, 8),
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        channels = 3,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    data = torch.randn(3, 3, 4, 8, 16, 32, 64)

    logits = model(data)

    embed = model(data, return_embed = True) # (3, 2, 4, 4, 8, 8, 512)
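
For reference, a minimal usage sketch for the new module (not part of the commit; the import path simply mirrors the new file `vit_pytorch/vit_nd_pope.py`, and every hyperparameter below is illustrative):

```python
import torch
from vit_pytorch.vit_nd_pope import ViTND

# 3-dimensional example, e.g. short clips as (frames, height, width)
model = ViTND(
    ndim = 3,
    input_shape = (8, 64, 64),   # each axis must be divisible by its patch size
    patch_size = (2, 8, 8),
    num_classes = 1000,
    dim = 256,
    depth = 4,
    heads = 8,
    mlp_dim = 1024
)

clips = torch.randn(2, 3, 8, 64, 64)       # (batch, channels, *input_shape)

logits = model(clips)                      # (2, 1000)
embed = model(clips, return_embed = True)  # (2, 4, 8, 8, 256) - patch grid restored
```

Positions for the polar embedding are generated internally from the patch grid, so no extra positional inputs are required.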
