 import torch
 from torch import nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
 from typing import Optional, List, Callable

 try:
@@ -34,10 +33,9 @@ def normalize_keypoints(


 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x = rearrange(x, '... (d r) -> ... d r', r=2)
+    x = x.unflatten(-1, (-1, 2))
     x1, x2 = x.unbind(dim=-1)
-    x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, '... d r -> ... (d r)')
+    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


 def apply_cached_rotary_emb(
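Note (illustrative, not part of the commit): the rewritten rotate_half is intended as a drop-in replacement for the einops version; a quick equivalence check, assuming einops is still installed, could look like this:

    import torch
    from einops import rearrange

    def rotate_half_old(x: torch.Tensor) -> torch.Tensor:
        # original einops formulation
        x = rearrange(x, '... (d r) -> ... d r', r=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return rearrange(x, '... d r -> ... (d r)')

    def rotate_half_new(x: torch.Tensor) -> torch.Tensor:
        # einops-free formulation from this commit
        x = x.unflatten(-1, (-1, 2))
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)

    x = torch.randn(2, 4, 128, 64)  # arbitrary test shape
    assert torch.equal(rotate_half_old(x), rotate_half_new(x))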
@@ -59,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         projected = self.Wr(x)
         cosines, sines = torch.cos(projected), torch.sin(projected)
         emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
-        return repeat(emb, '... n -> ... (n r)', r=2)
+        return emb.repeat_interleave(2, dim=-1)


 class TokenConfidence(nn.Module):
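Similarly (illustrative only): repeat(emb, '... n -> ... (n r)', r=2) duplicates each entry of the last axis in place, which is exactly what repeat_interleave does:

    import torch
    from einops import repeat

    emb = torch.randn(2, 1, 128, 32)  # arbitrary test shape
    # both produce [a, a, b, b, ...] along the last axis
    assert torch.equal(repeat(emb, '... n -> ... (n r)', r=2),
                       emb.repeat_interleave(2, dim=-1))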
@@ -130,14 +128,14 @@ def __init__(self, embed_dim: int, num_heads: int,
     def _forward(self, x: torch.Tensor,
                  encoding: Optional[torch.Tensor] = None):
         qkv = self.Wqkv(x)
-        qkv = rearrange(qkv, 'b n (h d three) -> b h n d three',
-                        three=3, h=self.num_heads)
+        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
         q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
         if encoding is not None:
             q = apply_cached_rotary_emb(encoding, q)
             k = apply_cached_rotary_emb(encoding, k)
         context = self.inner_attn(q, k, v)
-        message = self.out_proj(rearrange(context, 'b h n d -> b n (h d)'))
+        message = self.out_proj(
+            context.transpose(1, 2).flatten(start_dim=-2))
         return x + self.ffn(torch.cat([x, message], -1))

     def forward(self, x0, x1, encoding0=None, encoding1=None):
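The qkv reshape works because einops treats the last factor of '(h d three)' as the fastest-varying one, so unflatten(-1, (h, -1, 3)) followed by transpose(1, 2) reproduces the same layout; a small check (illustrative, with arbitrary sizes):

    import torch
    from einops import rearrange

    num_heads, head_dim = 4, 64
    qkv = torch.randn(2, 128, num_heads * head_dim * 3)  # b n (h d three)
    old = rearrange(qkv, 'b n (h d three) -> b h n d three',
                    three=3, h=num_heads)
    new = qkv.unflatten(-1, (num_heads, -1, 3)).transpose(1, 2)
    assert torch.equal(old, new)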
@@ -174,7 +172,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
         qk0, qk1 = self.map_(self.to_qk, x0, x1)
         v0, v1 = self.map_(self.to_v, x0, x1)
         qk0, qk1, v0, v1 = map(
-            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads),
+            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
             (qk0, qk1, v0, v1))
         if self.flash is not None:
             m0 = self.flash(qk0, qk1, v1)
@@ -186,7 +184,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
             attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
             m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
             m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0)
-        m0, m1 = self.map_(lambda t: rearrange(t, 'b h n d -> b n (h d)'),
+        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
                            m0, m1)
         m0, m1 = self.map_(self.to_out, m0, m1)
         x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
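The head-merge pattern used here is the inverse of the split above: 'b h n d -> b n (h d)' corresponds to transpose(1, 2) followed by flattening the last two axes (illustrative check, arbitrary sizes):

    import torch
    from einops import rearrange

    m = torch.randn(2, 4, 128, 64)  # b h n d
    assert torch.equal(rearrange(m, 'b h n d -> b n (h d)'),
                       m.transpose(1, 2).flatten(start_dim=-2))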