Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EfficientLoFTR model #36355

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DPR](model_doc/dpr) | ✅ | ✅ | ❌ |
| [DPT](model_doc/dpt) | ✅ | ❌ | ❌ |
| [EfficientFormer](model_doc/efficientformer) | ✅ | ✅ | ❌ |
| [EfficientLoFTR](model_doc/efficientloftr) | ✅ | ❌ | ❌ |
| [EfficientNet](model_doc/efficientnet) | ✅ | ❌ | ❌ |
| [ELECTRA](model_doc/electra) | ✅ | ✅ | ✅ |
| [Emu3](model_doc/emu3) | ✅ | ❌ | ❌ |
Expand Down
98 changes: 98 additions & 0 deletions docs/source/en/model_doc/efficientloftr.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the MIT License; you may not use this file except in compliance with
the License.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# EfficientLoFTR

## Overview

The EfficientLoFTR model was proposed in [Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed](https://arxiv.org/abs/2403.04765) by Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.

This model consists of matching two images together by finding pixel correspondences. It can be used to estimate the pose between them.
This model is useful for tasks such as image matching, homography estimation, etc.

The abstract from the paper is the following:

*We present a novel method for efficiently producing semidense matches across images. Previous detector-free matcher
LoFTR has shown remarkable matching capability in handling large-viewpoint change and texture-poor scenarios but suffers
from low efficiency. We revisit its design choices and derive multiple improvements for both efficiency and accuracy.
One key observation is that performing the transformer over the entire feature map is redundant due to shared local
information, therefore we propose an aggregated attention mechanism with adaptive token selection for efficiency.
Furthermore, we find spatial variance exists in LoFTR’s fine correlation module, which is adverse to matching accuracy.
A novel two-stage correlation layer is proposed to achieve accurate subpixel correspondences for accuracy improvement.
Our efficiency optimized model is ∼ 2.5× faster than LoFTR which can even surpass state-of-the-art efficient sparse
matching pipeline SuperPoint + LightGlue. Moreover, extensive experiments show that our method can achieve higher
accuracy compared with competitive semi-dense matchers, with considerable efficiency benefits. This opens up exciting
prospects for large-scale or latency-sensitive applications such as image retrieval and 3D reconstruction.
Project page: [https://zju3dv.github.io/efficientloftr/](https://zju3dv.github.io/efficientloftr/).*

## How to use

Here is a quick example of using the model.
```python
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import requests

url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
image1 = Image.open(requests.get(url_image1, stream=True).raw)
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
image2 = Image.open(requests.get(url_image2, stream=True).raw)

images = [image1, image2]

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

inputs = processor(images, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
```

You can use the `post_process_keypoint_matching` method from the `ImageProcessor` to get the keypoints and matches in a more readable format:

```python
image_sizes = [[(image.height, image.width) for image in images]]
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
for i, output in enumerate(outputs):
print("For the image pair", i)
for keypoint0, keypoint1, matching_score in zip(
output["keypoints0"], output["keypoints1"], output["matching_scores"]
):
print(
f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
)

```

From the outputs, you can visualize the matches between the two images using the following code:
```python
processor.plot_keypoint_matching(images, outputs)
```

![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/01ZYaLB1NL5XdA8u7yCo4.png)

This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
The original code can be found [here](https://github.com/magicleap/SuperGluePretrainedNetwork).

## EfficientLoFTRConfig

[[autodoc]] EfficientLoFTRConfig

## EfficientLoFTRForKeypointMatching

[[autodoc]] EfficientLoFTRForKeypointMatching

- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [EfficientLoFTR](https://huggingface.co/docs/transformers/model_doc/efficientloftr#transformers.EfficientLoFTRForKeypointMatching)
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
Expand Down Expand Up @@ -252,6 +253,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [EfficientLoFTR](https://huggingface.co/docs/transformers/model_doc/efficientloftr#transformers.EfficientLoFTRForKeypointMatching)
* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@
"DPRReaderTokenizer",
],
"models.dpt": ["DPTConfig"],
"models.efficientloftr": ["EfficientLoFTRConfig"],
"models.efficientnet": ["EfficientNetConfig"],
"models.electra": [
"ElectraConfig",
Expand Down Expand Up @@ -2274,6 +2275,12 @@
"DPTPreTrainedModel",
]
)
_import_structure["models.efficientloftr"].extend(
[
"EfficientLoFTRForKeypointMatching",
"EfficientLoFTRPreTrainedModel",
]
)
_import_structure["models.efficientnet"].extend(
[
"EfficientNetForImageClassification",
Expand Down Expand Up @@ -5554,6 +5561,7 @@
DPRReaderTokenizer,
)
from .models.dpt import DPTConfig
from .models.efficientloftr import EfficientLoFTRConfig
from .models.efficientnet import (
EfficientNetConfig,
)
Expand Down Expand Up @@ -7284,6 +7292,7 @@
DPTModel,
DPTPreTrainedModel,
)
from .models.efficientloftr import EfficientLoFTRForKeypointMatching, EfficientLoFTRPreTrainedModel
from .models.efficientnet import (
EfficientNetForImageClassification,
EfficientNetModel,
Expand Down
44 changes: 44 additions & 0 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,49 @@ def _compute_llama3_parameters(
return inv_freq_llama, attention_factor


def _compute_2d_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
if config is not None and len(rope_kwargs) > 0:
raise ValueError(
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
f"`_compute_2d_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
)
if len(rope_kwargs) > 0:
base = rope_kwargs["base"]
dim = rope_kwargs["dim"]
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = config.hidden_size // 4
dim = int(dim * partial_rotary_factor)

attention_factor = 1.0 # Unused in this type of RoPE

# Compute the inverse frequencies
# inv_freq = 1.0 / (base ** (torch.arange(0, dim, 1, dtype=torch.int64).float().to(device) / dim))
inv_freq = torch.exp(torch.arange(0, dim, 1, dtype=torch.int64, device=device).float() * (-math.log(base) / dim))
return inv_freq, attention_factor


# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
Expand All @@ -355,6 +398,7 @@ def _compute_llama3_parameters(
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
"llama3": _compute_llama3_parameters,
"2d": _compute_2d_parameters,
}


Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
donut,
dpr,
dpt,
efficientloftr,
efficientnet,
electra,
emu3,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
("efficientformer", "EfficientFormerConfig"),
("efficientloftr", "EfficientLoFTRConfig"),
("efficientnet", "EfficientNetConfig"),
("electra", "ElectraConfig"),
("emu3", "Emu3Config"),
Expand Down Expand Up @@ -433,6 +434,7 @@
("dpr", "DPR"),
("dpt", "DPT"),
("efficientformer", "EfficientFormer"),
("efficientloftr", "EfficientLoFTR"),
("efficientnet", "EfficientNet"),
("electra", "ELECTRA"),
("emu3", "Emu3"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
("dpr", "DPRQuestionEncoder"),
("dpt", "DPTModel"),
("efficientformer", "EfficientFormerModel"),
("efficientloftr", "EfficientLoFTRForKeypointMatching"),
("efficientnet", "EfficientNetModel"),
("electra", "ElectraModel"),
("encodec", "EncodecModel"),
Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/efficientloftr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_efficientloftr import *
from .modeling_efficientloftr import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Loading