|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn.functional as F |
| 5 | +from torch import pi, nn, arange, cat, stack, Tensor |
| 6 | +from torch.nn import Module, ModuleList |
| 7 | +from torch.amp import autocast |
| 8 | + |
| 9 | +from einops import rearrange, repeat, reduce, pack, unpack |
| 10 | +from einops.layers.torch import Rearrange |
| 11 | + |
| 12 | +# helpers |
| 13 | + |
| 14 | +def exists(val): |
| 15 | + return val is not None |
| 16 | + |
| 17 | +def l2norm(t): |
| 18 | + return F.normalize(t, dim = -1, p = 2) |
| 19 | + |
| 20 | +def join(arr, delimiter = ' '): |
| 21 | + return delimiter.join(arr) |
| 22 | + |
| 23 | +def ensure_tuple(t, length): |
| 24 | + if isinstance(t, (tuple, list)): |
| 25 | + assert len(t) == length, f'Expected tuple of length {length}, got {len(t)}' |
| 26 | + return tuple(t) |
| 27 | + |
| 28 | + return (t,) * length |
| 29 | + |
| 30 | +# golden gate rotary - Jerry Xiong, PhD student at UIUC |
| 31 | +# https://jerryxio.ng/posts/nd-rope/ |
| 32 | + |
| 33 | +# but using polar version instead |
| 34 | +# Gopalakrishnan et al. https://arxiv.org/abs/2509.10534 |
| 35 | + |
| 36 | +def _phi(m: int) -> float: |
| 37 | + x = 2.0 |
| 38 | + for _ in range(10): |
| 39 | + x = (1 + x) ** (1.0 / (m + 1.0)) |
| 40 | + return x |
| 41 | + |
| 42 | +def make_directions(n: int, d: int) -> Tensor: |
| 43 | + g = _phi(d) |
| 44 | + alpha = (1.0 / g) ** arange(1, d + 1, dtype = torch.float64) |
| 45 | + i = arange(1, n + 1, dtype = torch.float64).unsqueeze(1) |
| 46 | + z = torch.fmod(i * alpha, 1.0) |
| 47 | + directions = torch.erfinv(2.0 * z - 1.0) |
| 48 | + directions = l2norm(directions) |
| 49 | + return directions.float() |
| 50 | + |
| 51 | +class GoldenGatePoPENd(Module): |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + dim_pos: int, |
| 55 | + heads: int, |
| 56 | + dim_head: int, |
| 57 | + min_freq: float = 1.0, |
| 58 | + max_freq: float = 10000.0, |
| 59 | + p_zero_freqs: float = 0.0, # proportion of frequencies set to 0 |
| 60 | + init_learned_bias_uniform = False |
| 61 | + ): |
| 62 | + super().__init__() |
| 63 | + n_freqs = dim_head |
| 64 | + n_zero_freqs = round(p_zero_freqs * n_freqs) |
| 65 | + |
| 66 | + omega = cat(( |
| 67 | + torch.zeros(n_zero_freqs), |
| 68 | + min_freq * (max_freq / min_freq) ** torch.linspace(0, 1, n_freqs - n_zero_freqs), |
| 69 | + )) |
| 70 | + |
| 71 | + directions = rearrange( |
| 72 | + make_directions(heads * n_freqs, dim_pos), |
| 73 | + '(h f) p -> h f p', |
| 74 | + h = heads |
| 75 | + ) |
| 76 | + |
| 77 | + omega_expanded = rearrange(omega, 'f -> f 1') |
| 78 | + self.register_buffer('freqs', directions * omega_expanded) # shape: (h, f, p) |
| 79 | + |
| 80 | + self.learned_bias = nn.Parameter(torch.zeros(heads, dim_head)) |
| 81 | + |
| 82 | + if init_learned_bias_uniform: |
| 83 | + self.learned_bias.uniform_(-2. * pi, 0.) |
| 84 | + |
| 85 | + @autocast('cuda', enabled = False) |
| 86 | + def forward(self, pos): |
| 87 | + |
| 88 | + freqs = rearrange(self.freqs, 'h f p -> 1 h 1 f p') |
| 89 | + positions = rearrange(pos.float(), 'b n p -> b 1 n 1 p') |
| 90 | + |
| 91 | + # compute theta for each (batch, head, seq, freq) |
| 92 | + |
| 93 | + theta = reduce(freqs * positions, 'b h n f p -> b h n f', 'sum') |
| 94 | + |
| 95 | + bias = self.learned_bias.clamp(-2. * pi, 0.) |
| 96 | + bias = rearrange(bias, 'h d -> h 1 d') |
| 97 | + |
| 98 | + return theta, bias |
| 99 | + |
| 100 | +@autocast('cuda', enabled = False) |
| 101 | +def apply_polar_pos_emb(t, freqs): |
| 102 | + orig_dtype = t.dtype |
| 103 | + |
| 104 | + t = t.float() |
| 105 | + t = F.softplus(t) |
| 106 | + |
| 107 | + out = cat((t * freqs.cos(), t * freqs.sin()), dim = -1) |
| 108 | + |
| 109 | + return out.type(orig_dtype) |
| 110 | + |
| 111 | +# classes |
| 112 | + |
| 113 | +class FeedForward(Module): |
| 114 | + def __init__(self, dim, hidden_dim, dropout = 0.): |
| 115 | + super().__init__() |
| 116 | + self.net = nn.Sequential( |
| 117 | + nn.LayerNorm(dim), |
| 118 | + nn.Linear(dim, hidden_dim), |
| 119 | + nn.GELU(), |
| 120 | + nn.Dropout(dropout), |
| 121 | + nn.Linear(hidden_dim, dim), |
| 122 | + nn.Dropout(dropout) |
| 123 | + ) |
| 124 | + |
| 125 | + def forward(self, x): |
| 126 | + return self.net(x) |
| 127 | + |
| 128 | +class Attention(Module): |
| 129 | + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): |
| 130 | + super().__init__() |
| 131 | + inner_dim = dim_head * heads |
| 132 | + project_out = not (heads == 1 and dim_head == dim) |
| 133 | + |
| 134 | + self.heads = heads |
| 135 | + self.scale = dim_head ** -0.5 |
| 136 | + |
| 137 | + self.norm = nn.LayerNorm(dim) |
| 138 | + self.attend = nn.Softmax(dim = -1) |
| 139 | + self.dropout = nn.Dropout(dropout) |
| 140 | + |
| 141 | + self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False) |
| 142 | + self.to_v = nn.Linear(dim, inner_dim, bias = False) |
| 143 | + |
| 144 | + self.to_out = nn.Sequential( |
| 145 | + nn.Linear(inner_dim, dim), |
| 146 | + nn.Dropout(dropout) |
| 147 | + ) if project_out else nn.Identity() |
| 148 | + |
| 149 | + def forward(self, x, polar_pos_emb = None): |
| 150 | + x = self.norm(x) |
| 151 | + qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x)) |
| 152 | + |
| 153 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) |
| 154 | + |
| 155 | + if exists(polar_pos_emb): |
| 156 | + freqs, bias = polar_pos_emb |
| 157 | + q = apply_polar_pos_emb(q, freqs) |
| 158 | + k = apply_polar_pos_emb(k, freqs + bias) |
| 159 | + |
| 160 | + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale |
| 161 | + |
| 162 | + attn = self.attend(dots) |
| 163 | + attn = self.dropout(attn) |
| 164 | + |
| 165 | + out = torch.matmul(attn, v) |
| 166 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 167 | + return self.to_out(out) |
| 168 | + |
| 169 | +class Transformer(Module): |
| 170 | + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., polar_emb = None): |
| 171 | + super().__init__() |
| 172 | + self.norm = nn.LayerNorm(dim) |
| 173 | + |
| 174 | + self.polar_emb = polar_emb |
| 175 | + |
| 176 | + self.layers = ModuleList([]) |
| 177 | + |
| 178 | + for _ in range(depth): |
| 179 | + self.layers.append(ModuleList([ |
| 180 | + Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), |
| 181 | + FeedForward(dim, mlp_dim, dropout = dropout) |
| 182 | + ])) |
| 183 | + |
| 184 | + def forward(self, x, pos = None): |
| 185 | + |
| 186 | + # pope embedding |
| 187 | + |
| 188 | + polar_pos_emb = None |
| 189 | + if exists(pos) and exists(self.polar_emb): |
| 190 | + polar_pos_emb = self.polar_emb(pos) |
| 191 | + |
| 192 | + # transformer layers |
| 193 | + |
| 194 | + for attn, ff in self.layers: |
| 195 | + x = attn(x, polar_pos_emb) + x |
| 196 | + x = ff(x) + x |
| 197 | + |
| 198 | + return self.norm(x) |
| 199 | + |
| 200 | +class ViTND(Module): |
| 201 | + def __init__( |
| 202 | + self, |
| 203 | + *, |
| 204 | + ndim: int, |
| 205 | + input_shape: int | tuple[int, ...], |
| 206 | + patch_size: int | tuple[int, ...], |
| 207 | + num_classes: int, |
| 208 | + dim: int, |
| 209 | + depth: int, |
| 210 | + heads: int, |
| 211 | + mlp_dim: int, |
| 212 | + channels: int = 3, |
| 213 | + dim_head: int = 64, |
| 214 | + dropout: float = 0., |
| 215 | + emb_dropout: float = 0., |
| 216 | + pope_min_freq: float = 1.0, |
| 217 | + pope_max_freq: float = 10000.0, |
| 218 | + pope_p_zero_freqs: float = 0.0, |
| 219 | + pope_init_learned_bias_uniform = False |
| 220 | + ): |
| 221 | + super().__init__() |
| 222 | + |
| 223 | + assert 1 <= ndim <= 7, 'ndim must be between 1 and 7' |
| 224 | + |
| 225 | + self.ndim = ndim |
| 226 | + |
| 227 | + input_shape = ensure_tuple(input_shape, ndim) |
| 228 | + patch_size = ensure_tuple(patch_size, ndim) |
| 229 | + |
| 230 | + for i, (inp_dim, patch_dim) in enumerate(zip(input_shape, patch_size)): |
| 231 | + assert inp_dim % patch_dim == 0, f'Input dimension {i} ({inp_dim}) must be divisible by patch size ({patch_dim})' |
| 232 | + |
| 233 | + num_patches_per_dim = [inp_dim // patch_dim for inp_dim, patch_dim in zip(input_shape, patch_size)] |
| 234 | + num_patches = 1 |
| 235 | + for n in num_patches_per_dim: |
| 236 | + num_patches *= n |
| 237 | + |
| 238 | + patch_dim = channels |
| 239 | + for p in patch_size: |
| 240 | + patch_dim *= p |
| 241 | + |
| 242 | + dim_names = 'fghijkl'[:ndim] |
| 243 | + |
| 244 | + input_dims = [f'({d} p{i})' for i, d in enumerate(dim_names)] |
| 245 | + patch_dims = [f'p{i}' for i in range(ndim)] |
| 246 | + |
| 247 | + input_pattern = f'b c {join(input_dims)}' |
| 248 | + output_pattern = f'b {join(dim_names)} ({join(patch_dims)} c)' |
| 249 | + rearrange_str = f'{input_pattern} -> {output_pattern}' |
| 250 | + |
| 251 | + rearrange_kwargs = {f'p{i}': p for i, p in enumerate(patch_size)} |
| 252 | + |
| 253 | + self.to_patch_embedding = nn.Sequential( |
| 254 | + Rearrange(rearrange_str, **rearrange_kwargs), |
| 255 | + nn.Linear(patch_dim, dim), |
| 256 | + nn.LayerNorm(dim), |
| 257 | + ) |
| 258 | + |
| 259 | + self.dropout = nn.Dropout(emb_dropout) |
| 260 | + |
| 261 | + # golden gate pope |
| 262 | + |
| 263 | + self.polar_emb = GoldenGatePoPENd( |
| 264 | + dim_pos = ndim, |
| 265 | + heads = heads, |
| 266 | + dim_head = dim_head, |
| 267 | + min_freq = pope_min_freq, |
| 268 | + max_freq = pope_max_freq, |
| 269 | + p_zero_freqs = pope_p_zero_freqs, |
| 270 | + init_learned_bias_uniform = pope_init_learned_bias_uniform |
| 271 | + ) |
| 272 | + |
| 273 | + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, polar_emb = self.polar_emb) |
| 274 | + |
| 275 | + self.to_latent = nn.Identity() |
| 276 | + self.mlp_head = nn.Linear(dim, num_classes) |
| 277 | + |
| 278 | + def muon_parameters(self): |
| 279 | + params = [] |
| 280 | + |
| 281 | + for m in self.modules(): |
| 282 | + if isinstance(m, Attention): |
| 283 | + params.extend([ |
| 284 | + m.to_v.weight, |
| 285 | + m.to_out[0].weight |
| 286 | + ]) |
| 287 | + elif isinstance(m, FeedForward): |
| 288 | + params.extend([ |
| 289 | + m.net[1].weight, |
| 290 | + m.net[-2].weight |
| 291 | + ]) |
| 292 | + |
| 293 | + return params |
| 294 | + |
| 295 | + def forward( |
| 296 | + self, |
| 297 | + x, |
| 298 | + return_embed = False |
| 299 | + ): |
| 300 | + x = self.to_patch_embedding(x) # (b, *spatial_dims, patch_dim) |
| 301 | + |
| 302 | + batch, *spatial_dims, _, device = *x.shape, x.device |
| 303 | + |
| 304 | + # Generate position coordinates |
| 305 | + |
| 306 | + grids = [arange(d, device = device, dtype = torch.float32) for d in spatial_dims] |
| 307 | + grid = torch.meshgrid(*grids, indexing = 'ij') |
| 308 | + pos = stack(grid, dim = -1) # (*spatial_dims, ndim) |
| 309 | + |
| 310 | + # flatten spatial dimensions for attention with nd rotary |
| 311 | + |
| 312 | + pos = repeat(pos, '... p -> b (...) p', b = batch) |
| 313 | + x, packed_shape = pack([x], 'b * d') |
| 314 | + |
| 315 | + x = self.dropout(x) |
| 316 | + |
| 317 | + embed = self.transformer(x, pos) |
| 318 | + |
| 319 | + # return the embed with reconstituted patch shape |
| 320 | + |
| 321 | + if return_embed: |
| 322 | + embed, = unpack(embed, packed_shape, 'b * d') |
| 323 | + return embed |
| 324 | + |
| 325 | + # pooling to logits |
| 326 | + |
| 327 | + pooled = reduce(embed, 'b n d -> b d', 'mean') |
| 328 | + |
| 329 | + pooled = self.to_latent(pooled) |
| 330 | + return self.mlp_head(pooled) |
| 331 | + |
| 332 | +if __name__ == '__main__': |
| 333 | + |
| 334 | + model = ViTND( |
| 335 | + ndim = 5, |
| 336 | + input_shape = (4, 8, 16, 32, 64), |
| 337 | + patch_size = (2, 2, 4, 4, 8), |
| 338 | + num_classes = 1000, |
| 339 | + dim = 512, |
| 340 | + depth = 6, |
| 341 | + heads = 8, |
| 342 | + mlp_dim = 2048, |
| 343 | + channels = 3, |
| 344 | + dropout = 0.1, |
| 345 | + emb_dropout = 0.1 |
| 346 | + ) |
| 347 | + |
| 348 | + data = torch.randn(3, 3, 4, 8, 16, 32, 64) |
| 349 | + |
| 350 | + logits = model(data) |
| 351 | + |
| 352 | + embed = model(data, return_embed = True) # (2, 2, 4, 4, 8, 8, 512) |
0 commit comments