Skip to content

Commit d518e89

Browse files
AmitMY and claude authored
cache position grids in NaViT forward pass (#354)
Use lru_cache to cache unique (ph, pw, device) position grids, avoiding redundant computation when multiple images share the same patch dimensions. Cache persists across forward passes.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent dd6462d commit d518e89

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

vit_pytorch/na_vit.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from functools import partial
3+
from functools import partial, lru_cache
44
from typing import List
55

66
import torch
@@ -27,6 +27,12 @@ def pair(t):
2727
def divisible_by(numer, denom):
2828
return (numer % denom) == 0
2929

30+
@lru_cache(maxsize=128)
31+
def posemb_grid(ph, pw, device):
32+
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
33+
w_idx = torch.arange(pw, device=device).repeat(ph)
34+
return torch.stack([h_idx, w_idx], dim=-1)
35+
3036
# auto grouping images
3137

3238
def group_images_by_max_seq_len(
@@ -293,12 +299,8 @@ def forward(
293299
# extract patches for all images
294300
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
295301

296-
# compute positions using repeat_interleave (faster than meshgrid per image)
297-
positions = []
298-
for ph, pw in patch_dims:
299-
h_idx = arange(ph).repeat_interleave(pw)
300-
w_idx = arange(pw).repeat(ph)
301-
positions.append(torch.stack([h_idx, w_idx], dim=-1))
302+
# compute positions - uses lru_cache to avoid redundant computation across forward passes
303+
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
302304

303305
# handle token dropout
304306
if has_token_dropout:

0 commit comments

Comments
 (0)