
Commit cd03085

Remove einops dependency (cvg#25)
1 parent 1902630 commit cd03085

File tree

2 files changed: +9 -12 lines changed


lightglue/lightglue.py

Lines changed: 8 additions & 10 deletions
@@ -5,7 +5,6 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
 from typing import Optional, List, Callable
 
 try:
@@ -34,10 +33,9 @@ def normalize_keypoints(
 
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
-    x = rearrange(x, '... (d r) -> ... d r', r=2)
+    x = x.unflatten(-1, (-1, 2))
     x1, x2 = x.unbind(dim=-1)
-    x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, '... d r -> ... (d r)')
+    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
 
 
 def apply_cached_rotary_emb(
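
The einops split-rotate-merge in rotate_half maps directly onto unflatten/flatten. A quick equivalence sketch (illustrative only, not part of the commit; einops is imported here solely to run the removed variant, and the tensor shape is arbitrary):

import torch
from einops import rearrange

def rotate_half_einops(x: torch.Tensor) -> torch.Tensor:
    # removed version: split the last dim into pairs, swap-and-negate, merge back
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')

def rotate_half_native(x: torch.Tensor) -> torch.Tensor:
    # new version: the same pairwise rotation with built-in tensor ops
    x = x.unflatten(-1, (-1, 2))
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)

x = torch.randn(2, 4, 32, 64)
assert torch.equal(rotate_half_einops(x), rotate_half_native(x))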
@@ -59,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         projected = self.Wr(x)
         cosines, sines = torch.cos(projected), torch.sin(projected)
         emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
-        return repeat(emb, '... n -> ... (n r)', r=2)
+        return emb.repeat_interleave(2, dim=-1)
 
 
 class TokenConfidence(nn.Module):
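
The pattern '... n -> ... (n r)' with r=2 duplicates each element in place along the last dimension, which is exactly repeat_interleave (and not Tensor.repeat, which tiles the whole dimension). A small illustration, with a toy tensor standing in for the embedding:

import torch

emb = torch.tensor([[1., 2., 3.]])
# in-place duplication: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]
print(emb.repeat_interleave(2, dim=-1))  # tensor([[1., 1., 2., 2., 3., 3.]])
# Tensor.repeat would instead tile the dimension: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]
print(emb.repeat(1, 2))                  # tensor([[1., 2., 3., 1., 2., 3.]])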
@@ -130,14 +128,14 @@ def __init__(self, embed_dim: int, num_heads: int,
     def _forward(self, x: torch.Tensor,
                  encoding: Optional[torch.Tensor] = None):
         qkv = self.Wqkv(x)
-        qkv = rearrange(qkv, 'b n (h d three) -> b h n d three',
-                        three=3, h=self.num_heads)
+        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
         q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
         if encoding is not None:
             q = apply_cached_rotary_emb(encoding, q)
             k = apply_cached_rotary_emb(encoding, k)
         context = self.inner_attn(q, k, v)
-        message = self.out_proj(rearrange(context, 'b h n d -> b n (h d)'))
+        message = self.out_proj(
+            context.transpose(1, 2).flatten(start_dim=-2))
         return x + self.ffn(torch.cat([x, message], -1))
 
     def forward(self, x0, x1, encoding0=None, encoding1=None):
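
The fused QKV reshape keeps the same factorization order as the einops pattern 'b n (h d three)': heads first, then head dimension, then the Q/K/V index. A minimal shape check (the b/n/head sizes below are chosen arbitrarily for illustration):

import torch

b, n, num_heads, head_dim = 2, 5, 4, 16
qkv = torch.randn(b, n, num_heads * head_dim * 3)
# 'b n (h d three) -> b h n d three'
qkv = qkv.unflatten(-1, (num_heads, -1, 3)).transpose(1, 2)
assert qkv.shape == (b, num_heads, n, head_dim, 3)
q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]  # each (b, num_heads, n, head_dim)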
@@ -174,7 +172,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
         qk0, qk1 = self.map_(self.to_qk, x0, x1)
         v0, v1 = self.map_(self.to_v, x0, x1)
         qk0, qk1, v0, v1 = map(
-            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads),
+            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
             (qk0, qk1, v0, v1))
         if self.flash is not None:
             m0 = self.flash(qk0, qk1, v1)
@@ -186,7 +184,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> List[torch.Tensor]:
         attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
         m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
         m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1), v0)
-        m0, m1 = self.map_(lambda t: rearrange(t, 'b h n d -> b n (h d)'),
+        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
                            m0, m1)
         m0, m1 = self.map_(self.to_out, m0, m1)
         x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
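
The last two hunks are the usual multi-head split ('b n (h d) -> b h n d') and merge ('b h n d -> b n (h d)'); in native PyTorch the two expressions are exact inverses. A round-trip sketch covering both (sizes arbitrary):

import torch

b, n, heads, d = 2, 7, 4, 16
t = torch.randn(b, n, heads * d)
# split: 'b n (h d) -> b h n d'
split = t.unflatten(-1, (heads, -1)).transpose(1, 2)
assert split.shape == (b, heads, n, d)
# merge: 'b h n d -> b n (h d)'
merged = split.transpose(1, 2).flatten(start_dim=-2)
assert torch.equal(merged, t)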

requirements.txt

Lines changed: 1 addition & 2 deletions
@@ -3,5 +3,4 @@ torchvision>=0.3
 numpy
 opencv-python
 matplotlib
-kornia>=0.6.11
-einops
+kornia>=0.6.11
