diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 6a168e9905ba..226180707bb0 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -137,6 +137,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [DPR](model_doc/dpr) | ✅ | ✅ | ❌ |
 | [DPT](model_doc/dpt) | ✅ | ❌ | ❌ |
 | [EfficientFormer](model_doc/efficientformer) | ✅ | ✅ | ❌ |
+| [EfficientLoFTR](model_doc/efficientloftr) | ✅ | ❌ | ❌ |
 | [EfficientNet](model_doc/efficientnet) | ✅ | ❌ | ❌ |
 | [ELECTRA](model_doc/electra) | ✅ | ✅ | ✅ |
 | [Emu3](model_doc/emu3) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/efficientloftr.md b/docs/source/en/model_doc/efficientloftr.md
new file mode 100644
index 000000000000..afbd1f283378
--- /dev/null
+++ b/docs/source/en/model_doc/efficientloftr.md
@@ -0,0 +1,98 @@
+
+
+# EfficientLoFTR
+
+## Overview
+
+The EfficientLoFTR model was proposed in [Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed](https://arxiv.org/abs/2403.04765) by Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.
+
+The model matches two images by finding dense pixel correspondences between them, which can then be used to estimate the relative pose of the two views.
+It is useful for tasks such as image matching and homography estimation.
+
+The abstract from the paper is the following:
+
+*We present a novel method for efficiently producing semidense matches across images. Previous detector-free matcher
+LoFTR has shown remarkable matching capability in handling large-viewpoint change and texture-poor scenarios but suffers
+from low efficiency. We revisit its design choices and derive multiple improvements for both efficiency and accuracy.
+One key observation is that performing the transformer over the entire feature map is redundant due to shared local
+information, therefore we propose an aggregated attention mechanism with adaptive token selection for efficiency.
+Furthermore, we find spatial variance exists in LoFTR’s fine correlation module, which is adverse to matching accuracy.
+A novel two-stage correlation layer is proposed to achieve accurate subpixel correspondences for accuracy improvement.
+Our efficiency optimized model is ∼ 2.5× faster than LoFTR which can even surpass state-of-the-art efficient sparse
+matching pipeline SuperPoint + LightGlue. Moreover, extensive experiments show that our method can achieve higher
+accuracy compared with competitive semi-dense matchers, with considerable efficiency benefits. This opens up exciting
+prospects for large-scale or latency-sensitive applications such as image retrieval and 3D reconstruction.
+Project page: [https://zju3dv.github.io/efficientloftr/](https://zju3dv.github.io/efficientloftr/).*
+
+## How to use
+
+Here is a quick example of using the model.
+```python
+from transformers import AutoImageProcessor, AutoModel
+import torch
+from PIL import Image
+import requests
+
+url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
+image1 = Image.open(requests.get(url_image1, stream=True).raw)
+url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
+image2 = Image.open(requests.get(url_image2, stream=True).raw)
+
+images = [image1, image2]
+
+processor = AutoImageProcessor.from_pretrained("stevenbucaille/efficientloftr")
+model = AutoModel.from_pretrained("stevenbucaille/efficientloftr")
+
+inputs = processor(images, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+```
+
+You can use the `post_process_keypoint_matching` method from the `ImageProcessor` to get the keypoints and matches in a more readable format:
+
+```python
+image_sizes = [[(image.height, image.width) for image in images]]
+outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
+for i, output in enumerate(outputs):
+    print("For the image pair", i)
+    for keypoint0, keypoint1, matching_score in zip(
+        output["keypoints0"], output["keypoints1"], output["matching_scores"]
+    ):
+        print(
+            f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
+        )
+```
+
+From the outputs, you can visualize the matches between the two images using the following code:
+```python
+processor.plot_keypoint_matching(images, outputs)
+```
+
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/01ZYaLB1NL5XdA8u7yCo4.png)
+
+This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
+The original code can be found [here](https://github.com/zju3dv/EfficientLoFTR).
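+
+Because the matches are returned as absolute pixel coordinates, they can also be fed to classical geometry estimators. The snippet below is a minimal sketch (OpenCV and NumPy are external dependencies, not part of the Transformers API) of how one could estimate the epipolar geometry between the two images from the matches; it assumes `outputs` comes from the `post_process_keypoint_matching` call above.
+
+```python
+# Sketch using OpenCV (not part of the model's API): estimate the fundamental matrix from the matches
+import cv2
+import numpy as np
+
+# Matched keypoints for the first (and only) image pair, as (x, y) pixel coordinates
+pair_output = outputs[0]
+points0 = pair_output["keypoints0"].numpy().astype(np.float32)
+points1 = pair_output["keypoints1"].numpy().astype(np.float32)
+
+# RANSAC-based fundamental matrix estimation; `inlier_mask` flags the geometrically consistent matches
+fundamental_matrix, inlier_mask = cv2.findFundamentalMat(points0, points1, cv2.FM_RANSAC, 1.0, 0.99)
+print(f"{int(inlier_mask.sum())} inlier matches out of {len(points0)}")
+```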
+ +## EfficientLoFTRConfig + +[[autodoc]] EfficientLoFTRConfig + +## EfficientLoFTRForKeypointMatching + +[[autodoc]] EfficientLoFTRForKeypointMatching + +- forward \ No newline at end of file diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index b8896114eccb..076f8679f64d 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -49,6 +49,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) +* [EfficientLoFTR](https://huggingface.co/docs/transformers/model_doc/efficientloftr#transformers.EfficientLoFTRForKeypointMatching) * [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) @@ -252,6 +253,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) +* [EfficientLoFTR](https://huggingface.co/docs/transformers/model_doc/efficientloftr#transformers.EfficientLoFTRForKeypointMatching) * [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel) * [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dc427aad5727..762ea1730164 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -425,6 +425,7 @@ "DPRReaderTokenizer", ], "models.dpt": ["DPTConfig"], + "models.efficientloftr": ["EfficientLoFTRConfig"], "models.efficientnet": ["EfficientNetConfig"], "models.electra": [ "ElectraConfig", @@ -2274,6 +2275,12 @@ "DPTPreTrainedModel", ] ) + _import_structure["models.efficientloftr"].extend( + [ + "EfficientLoFTRForKeypointMatching", + "EfficientLoFTRPreTrainedModel", + ] + ) _import_structure["models.efficientnet"].extend( [ "EfficientNetForImageClassification", @@ -5554,6 +5561,7 @@ DPRReaderTokenizer, ) from .models.dpt import DPTConfig + from .models.efficientloftr import EfficientLoFTRConfig from .models.efficientnet import ( EfficientNetConfig, ) @@ -7284,6 +7292,7 @@ DPTModel, DPTPreTrainedModel, ) + from .models.efficientloftr import EfficientLoFTRForKeypointMatching, EfficientLoFTRPreTrainedModel from .models.efficientnet import ( EfficientNetForImageClassification, EfficientNetModel, diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index b2d343e0237f..2088bcf5c1ec 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -345,6 +345,49 @@ def _compute_llama3_parameters( return inv_freq_llama, attention_factor +def _compute_2d_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: 
Optional[int] = None,
+    **rope_kwargs,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for the 2D RoPE variant, following the same frequency schedule as the original
+    RoPE implementation
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_2d_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        dim = config.hidden_size // 4
+        dim = int(dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies, i.e. 1.0 / (base ** (arange(0, dim) / dim)), in log space for numerical stability
+    inv_freq = torch.exp(torch.arange(0, dim, 1, dtype=torch.int64, device=device).float() * (-math.log(base) / dim))
+    return inv_freq, attention_factor
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
@@ -355,6 +398,7 @@ def _compute_llama3_parameters( "yarn": _compute_yarn_parameters, "longrope": _compute_longrope_parameters, "llama3": _compute_llama3_parameters, + "2d": _compute_2d_parameters, } diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 220f5dfa59c6..d65866485977 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -86,6 +86,7 @@ donut, dpr, dpt, + efficientloftr, efficientnet, electra, emu3, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e81e41006a6..ea2f915891d1 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -103,6 +103,7 @@ ("dpr", "DPRConfig"), ("dpt", "DPTConfig"), ("efficientformer", "EfficientFormerConfig"), + ("efficientloftr", "EfficientLoFTRConfig"), ("efficientnet", "EfficientNetConfig"), ("electra", "ElectraConfig"), ("emu3", "Emu3Config"), @@ -433,6 +434,7 @@ ("dpr", "DPR"), ("dpt", "DPT"), ("efficientformer", "EfficientFormer"), + ("efficientloftr", "EfficientLoFTR"), ("efficientnet", "EfficientNet"), ("electra", "ELECTRA"), ("emu3", "Emu3"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8fbe1b6c0d68..5be2a9902ee9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -101,6 +101,7 @@ ("dpr", "DPRQuestionEncoder"), ("dpt", "DPTModel"), ("efficientformer", "EfficientFormerModel"), + ("efficientloftr", "EfficientLoFTRForKeypointMatching"), ("efficientnet", "EfficientNetModel"), ("electra", "ElectraModel"), ("encodec", "EncodecModel"), diff --git a/src/transformers/models/efficientloftr/__init__.py b/src/transformers/models/efficientloftr/__init__.py new file mode 100644 index 000000000000..b5fe0acd35f1 --- /dev/null +++ b/src/transformers/models/efficientloftr/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_efficientloftr import * + from .modeling_efficientloftr import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/efficientloftr/configuration_efficientloftr.py b/src/transformers/models/efficientloftr/configuration_efficientloftr.py new file mode 100644 index 000000000000..10815a9c9f61 --- /dev/null +++ b/src/transformers/models/efficientloftr/configuration_efficientloftr.py @@ -0,0 +1,177 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/efficientloftr/modular_efficientloftr.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_efficientloftr.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+from ...configuration_utils import PretrainedConfig
+
+
+class EfficientLoFTRConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`EfficientLoFTRForKeypointMatching`].
+    It is used to instantiate an EfficientLoFTR model according to the specified arguments, defining the model
+    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+    EfficientLoFTR [stevenbucaille/efficient_loftr](https://huggingface.co/stevenbucaille/efficient_loftr) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        stage_block_dims (`List`, *optional*, defaults to [64, 64, 128, 256]):
+            The hidden size of the features in the blocks of each stage.
+        stage_num_blocks (`List`, *optional*, defaults to [1, 2, 4, 14]):
+            The number of blocks in each stage.
+        stage_hidden_expansion (`List`, *optional*, defaults to [1, 1, 1, 1]):
+            The rate of expansion of the hidden size in each stage.
+        stage_stride (`List`, *optional*, defaults to [2, 1, 2, 2]):
+            The stride used in each stage.
+        hidden_size (`int`, *optional*, defaults to 256):
+            The dimension of the descriptors.
+        activation_function (`str`, *optional*, defaults to `"relu"`):
+            The activation function used in the backbone.
+        aggregation_sizes (`List`, *optional*, defaults to [4, 4]):
+            The size of each aggregation for the fusion network.
+        num_attention_layers (`int`, *optional*, defaults to 4):
+            Number of attention layers in the LocalFeatureTransformer.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            The number of attention heads in the local feature transformer layers.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during attention.
+        mlp_activation_function (`str`, *optional*, defaults to `"leaky_relu"`):
+            The activation function used in the attention MLP layer.
+        coarse_matching_skip_softmax (`bool`, *optional*, defaults to `False`):
+            Whether to skip softmax or not at the coarse matching step.
+        coarse_matching_threshold (`float`, *optional*, defaults to 0.2):
+            The threshold for the minimum score required for a match.
+        coarse_matching_temperature (`float`, *optional*, defaults to 0.1):
+            The temperature to apply to the coarse similarity matrix.
+        coarse_matching_border_removal (`int`, *optional*, defaults to 2):
+            The size of the border to remove during coarse matching.
+        fine_kernel_size (`int`, *optional*, defaults to 8):
+            The kernel size used for the fine feature matching.
+        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the batch normalization layers.
+        rope_type (`str`, *optional*, defaults to `"2d"`):
+            The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+            'llama3', '2d'], with 'default' being the original RoPE implementation.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        fine_matching_slicedim (`int`, *optional*, defaults to 8):
+            The size of the slice used to divide the fine features for the first and second fine matching stages.
+        fine_matching_regress_temperature (`float`, *optional*, defaults to 10.0):
+            The temperature to apply to the fine similarity matrix.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+    Examples:
+    ```python
+    >>> from transformers import EfficientLoFTRConfig, EfficientLoFTRForKeypointMatching
+
+    >>> # Initializing an EfficientLoFTR style configuration
+    >>> configuration = EfficientLoFTRConfig()
+
+    >>> # Initializing a model from the EfficientLoFTR style configuration
+    >>> model = EfficientLoFTRForKeypointMatching(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "efficientloftr"
+
+    def __init__(
+        self,
+        stage_block_dims: List[int] = None,
+        stage_num_blocks: List[int] = None,
+        stage_hidden_expansion: List[float] = None,
+        stage_stride: List[int] = None,
+        hidden_size: int = 256,
+        activation_function: str = "relu",
+        aggregation_sizes: List[int] = None,
+        num_attention_layers: int = 4,
+        num_attention_heads: int = 8,
+        num_key_value_heads: int = None,
+        attention_dropout: float = 0.0,
+        attention_bias: bool = False,
+        mlp_activation_function: str = "leaky_relu",
+        coarse_matching_skip_softmax: bool = False,
+        coarse_matching_threshold: float = 0.2,
+        coarse_matching_temperature: float = 0.1,
+        coarse_matching_border_removal: int = 2,
+        fine_kernel_size: int = 8,
+        batch_norm_eps: float = 1e-5,
+        rope_type: str = "2d",
+        rope_theta: float = 10000.0,
+        fine_matching_slicedim: int = 8,
+        fine_matching_regress_temperature: float = 10.0,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        self.stage_block_dims = stage_block_dims if stage_block_dims is not None else [64, 64, 128, 256]
+        self.stage_num_blocks = stage_num_blocks if stage_num_blocks is not None else [1, 2, 4, 14]
+        self.stage_hidden_expansion = stage_hidden_expansion if stage_hidden_expansion is not None else [1, 1, 1, 1]
+        self.stage_stride = stage_stride if stage_stride is not None else [2, 1, 2, 2]
+        self.hidden_size = hidden_size
+        if self.hidden_size != self.stage_block_dims[-1]:
+            raise ValueError(
+                f"hidden_size should be equal to the last value in stage_block_dims. hidden_size = {self.hidden_size}, stage_block_dims = {self.stage_block_dims}"
+            )
+
+        self.activation_function = activation_function
+        self.aggregation_sizes = aggregation_sizes if aggregation_sizes is not None else [4, 4]
+        self.num_attention_layers = num_attention_layers
+        self.num_attention_heads = num_attention_heads
+        self.attention_dropout = attention_dropout
+        self.attention_bias = attention_bias
+        self.mlp_activation_function = mlp_activation_function
+        self.coarse_matching_skip_softmax = coarse_matching_skip_softmax
+        self.coarse_matching_threshold = coarse_matching_threshold
+        self.coarse_matching_temperature = coarse_matching_temperature
+        self.coarse_matching_border_removal = coarse_matching_border_removal
+        self.fine_kernel_size = fine_kernel_size
+        self.batch_norm_eps = batch_norm_eps
+        self.fine_matching_slicedim = fine_matching_slicedim
+        self.fine_matching_regress_temperature = fine_matching_regress_temperature
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+
+        self.rope_type = rope_type
+        self.rope_theta = rope_theta
+
+        self.initializer_range = initializer_range
+
+        super().__init__(**kwargs)
+
+
+__all__ = ["EfficientLoFTRConfig"]
diff --git a/src/transformers/models/efficientloftr/convert_efficientloftr_to_hf.py b/src/transformers/models/efficientloftr/convert_efficientloftr_to_hf.py
new file mode 100644
index 000000000000..9ed0d87b611c
--- /dev/null
+++ b/src/transformers/models/efficientloftr/convert_efficientloftr_to_hf.py
@@ -0,0 +1,258 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse +import gc +import os +import re +from typing import List + +import torch +from datasets import load_dataset +from huggingface_hub import hf_hub_download + +from transformers import SuperGlueImageProcessor +from transformers.models.efficientloftr.modeling_efficientloftr import ( + EfficientLoFTRConfig, + EfficientLoFTRForKeypointMatching, +) + + +DEFAULT_MODEL_REPO = "stevenbucaille/efficient_loftr_pth" +DEFAULT_FILE = "eloftr.pth" + + +def prepare_imgs(): + dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train") + image0 = dataset[0]["image"] + image2 = dataset[2]["image"] + return [[image2, image0]] + + +def verify_model_outputs(model, device): + images = prepare_imgs() + preprocessor = SuperGlueImageProcessor() + inputs = preprocessor(images=images, return_tensors="pt").to(device) + model.to(device) + model.eval() + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_matches_values = outputs.matches[0, 0, 20:30] + predicted_matching_scores_values = outputs.matching_scores[0, 0, 20:30] + + predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item() + + expected_max_number_keypoints = 501 + expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints)) + + expected_matches_values = torch.tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=torch.int64).to(device) + expected_matching_scores_values = torch.tensor( + [0.4148, 0.4459, 0.4732, 0.4315, 0.3388, 0.5651, 0.4266, 0.4288, 0.6642, 0.5476] + ).to(device) + + expected_number_of_matches = 501 + + assert outputs.matches.shape == expected_matches_shape + assert outputs.matching_scores.shape == expected_matching_scores_shape + + assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-3) + assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-3) + + assert predicted_number_of_matches == expected_number_of_matches + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"matcher.backbone.layer(\d+).rbr_dense.conv": r"backbone.stages.\1.blocks.0.conv1.conv", + r"matcher.backbone.layer(\d+).rbr_dense.bn": r"backbone.stages.\1.blocks.0.conv1.norm", + r"matcher.backbone.layer(\d+).rbr_1x1.conv": r"backbone.stages.\1.blocks.0.conv2.conv", + r"matcher.backbone.layer(\d+).rbr_1x1.bn": r"backbone.stages.\1.blocks.0.conv2.norm", + r"matcher.backbone.layer(\d+).(\d+).rbr_dense.conv": r"backbone.stages.\1.blocks.\2.conv1.conv", + r"matcher.backbone.layer(\d+).(\d+).rbr_dense.bn": r"backbone.stages.\1.blocks.\2.conv1.norm", + r"matcher.backbone.layer(\d+).(\d+).rbr_1x1.conv": r"backbone.stages.\1.blocks.\2.conv2.conv", + r"matcher.backbone.layer(\d+).(\d+).rbr_1x1.bn": r"backbone.stages.\1.blocks.\2.conv2.norm", + r"matcher.backbone.layer(\d+).(\d+).rbr_identity": r"backbone.stages.\1.blocks.\2.identity", + r"matcher.loftr_coarse.layers.(\d*[02468]).aggregate": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.aggregation.q_aggregation", + r"matcher.loftr_coarse.layers.(\d*[02468]).norm1": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.aggregation.norm", + r"matcher.loftr_coarse.layers.(\d*[02468]).q_proj": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.attention.q_proj", + r"matcher.loftr_coarse.layers.(\d*[02468]).k_proj": lambda m: 
f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.attention.k_proj", + r"matcher.loftr_coarse.layers.(\d*[02468]).v_proj": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.attention.v_proj", + r"matcher.loftr_coarse.layers.(\d*[02468]).merge": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.attention.o_proj", + r"matcher.loftr_coarse.layers.(\d*[02468]).mlp.(\d+)": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.mlp.fc{1 if m.group(2) == '0' else 2}", + r"matcher.loftr_coarse.layers.(\d*[02468]).norm2": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.self_attention.mlp.layer_norm", + r"matcher.loftr_coarse.layers.(\d*[13579]).aggregate": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.aggregation.q_aggregation", + r"matcher.loftr_coarse.layers.(\d*[13579]).norm1": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.aggregation.norm", + r"matcher.loftr_coarse.layers.(\d*[13579]).q_proj": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.attention.q_proj", + r"matcher.loftr_coarse.layers.(\d*[13579]).k_proj": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.attention.k_proj", + r"matcher.loftr_coarse.layers.(\d*[13579]).v_proj": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.attention.v_proj", + r"matcher.loftr_coarse.layers.(\d*[13579]).merge": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.attention.o_proj", + r"matcher.loftr_coarse.layers.(\d*[13579]).mlp.(\d+)": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.mlp.fc{1 if m.group(2) == '0' else 2}", + r"matcher.loftr_coarse.layers.(\d*[13579]).norm2": lambda m: f"local_feature_transformer.layers.{int(m.group(1)) // 2}.cross_attention.mlp.layer_norm", + r"matcher.fine_preprocess.layer3_outconv": "refinement_layer.out_conv", + r"matcher.fine_preprocess.layer(\d+)_outconv.weight": lambda m: f"refinement_layer.out_conv_layers.{0 if int(m.group(1)) == 2 else m.group(1)}.out_conv1.weight", + r"matcher.fine_preprocess.layer(\d+)_outconv2\.0": lambda m: f"refinement_layer.out_conv_layers.{0 if int(m.group(1)) == 2 else m.group(1)}.out_conv2", + r"matcher.fine_preprocess.layer(\d+)_outconv2\.1": lambda m: f"refinement_layer.out_conv_layers.{0 if int(m.group(1)) == 2 else m.group(1)}.batch_norm", + r"matcher.fine_preprocess.layer(\d+)_outconv2\.3": lambda m: f"refinement_layer.out_conv_layers.{0 if int(m.group(1)) == 2 else m.group(1)}.out_conv3", +} + + +def convert_old_keys_to_new_keys(state_dict_keys: List[str]): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. 
+ """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +@torch.no_grad() +def write_model( + model_path, + model_repo, + file_name, + organization, + safe_serialization=True, + push_to_hub=False, +): + os.makedirs(model_path, exist_ok=True) + + # ------------------------------------------------------------ + # EfficientLoFTR config + # ------------------------------------------------------------ + + config = EfficientLoFTRConfig() + config.architectures = ["EfficientLoFTRForKeypointMatching"] + config.save_pretrained(model_path) + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + print(f"Fetching all parameters from the checkpoint at {model_repo}/{file_name}...") + checkpoint_path = hf_hub_download(repo_id=model_repo, filename=file_name) + original_state_dict = torch.load(checkpoint_path, weights_only=True, map_location="cpu")["state_dict"] + + print("Converting model...") + all_keys = list(original_state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + state_dict[new_key] = original_state_dict.pop(key).contiguous().clone() + + del original_state_dict + gc.collect() + + print("Loading the checkpoint in a EfficientLoFTR model...") + device = "cuda" + with torch.device(device): + model = EfficientLoFTRForKeypointMatching(config) + model.load_state_dict(state_dict) + print("Checkpoint loaded successfully...") + del model.config._name_or_path + + print("Saving the model...") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + del state_dict, model + + # Safety check: reload the converted model + gc.collect() + print("Reloading the model to check if it's saved correctly.") + model = EfficientLoFTRForKeypointMatching.from_pretrained(model_path) + print("Model reloaded successfully.") + + model_name = "efficientloftr" + if model_repo == DEFAULT_MODEL_REPO: + print("Checking the model outputs...") + verify_model_outputs(model, device) + print("Model outputs verified successfully.") + + if push_to_hub: + print("Pushing model to the hub...") + model.push_to_hub( + repo_id=f"{organization}/{model_name}", + commit_message="Add model", + ) + config.push_to_hub(repo_id=f"{organization}/{model_name}", commit_message="Add config") + + write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub) + + +def write_image_processor(save_dir, model_name, organization, push_to_hub=False): + image_processor = SuperGlueImageProcessor() + image_processor.save_pretrained(save_dir) + + if push_to_hub: + print("Pushing image processor to the hub...") + image_processor.push_to_hub( + repo_id=f"{organization}/{model_name}", + commit_message="Add image processor", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--repo_id", + default=DEFAULT_MODEL_REPO, + type=str, + help="Model repo ID of the original EfficientLoFTR checkpoint you'd like to convert.", + ) + parser.add_argument( + "--file_name", 
+ default=DEFAULT_FILE, + type=str, + help="File name of the original EfficientLoFTR checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push model and image preprocessor to the hub", + ) + parser.add_argument( + "--organization", + default="stevenbucaille", + type=str, + help="Hub organization in which you want the model to be uploaded.", + ) + + args = parser.parse_args() + write_model( + args.pytorch_dump_folder_path, + args.repo_id, + args.file_name, + args.organization, + safe_serialization=True, + push_to_hub=args.push_to_hub, + ) diff --git a/src/transformers/models/efficientloftr/modeling_efficientloftr.py b/src/transformers/models/efficientloftr/modeling_efficientloftr.py new file mode 100644 index 000000000000..015801e8da22 --- /dev/null +++ b/src/transformers/models/efficientloftr/modeling_efficientloftr.py @@ -0,0 +1,1420 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/efficientloftr/modular_efficientloftr.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_efficientloftr.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...activations import ACT2CLS, ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_efficientloftr import EfficientLoFTRConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +class KeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number + of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of + images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is + used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching + information. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches and matching_scores are keypoint matching information. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)`, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`) + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + mask: Optional[torch.IntTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class EfficientLoFTRRotaryEmbedding(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, device="cpu") -> None: + super().__init__() + self.config = config + self.rope_type = config.rope_type + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + b, _, h, w = x.shape + + i_position_indices = torch.ones(h, w, device=x.device).cumsum(0).float().unsqueeze(-1) + j_position_indices = torch.ones(h, w, device=x.device).cumsum(1).float().unsqueeze(-1) + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, None, None, :].float().expand(1, 1, 1, -1) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + emb = torch.zeros(1, h, w, self.config.hidden_size // 2) + emb[:, :, :, 0::2] = i_position_indices * inv_freq_expanded + emb[:, :, :, 1::2] = j_position_indices * inv_freq_expanded + + sin = emb.sin() + cos = emb.cos() + + sin = sin.repeat_interleave(2, dim=-1) + cos = cos.repeat_interleave(2, dim=-1) + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + sin = sin.to(device=x.device, dtype=x.dtype) + cos = cos.to(device=x.device, dtype=x.dtype) + + return cos, sin + + +class EfficientLoFTRConvNormLayer(nn.Module): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class EfficientLoFTRRepVGGBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: EfficientLoFTRConfig, in_channels: int, out_channels: int, stride: int = 1) -> None: + super().__init__() + activation = config.activation_function + self.conv1 = EfficientLoFTRConvNormLayer( + config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1 + ) + self.conv2 = EfficientLoFTRConvNormLayer( + config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0 + ) + self.identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None + self.activation = nn.Identity() if activation is None else ACT2FN[activation] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.identity is not None: + identity_out = self.identity(hidden_states) + else: + identity_out = 0 + hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out + hidden_states = self.activation(hidden_states) + return hidden_states + + +class EfficientLoFTRRepVGGStage(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, in_channels, out_channels, num_blocks, stride) -> None: + super().__init__() + + strides = [stride] + [1] * (num_blocks - 1) + current_channel_dim = in_channels + blocks = [] + for stride in strides: + blocks.append( + EfficientLoFTRRepVGGBlock( + config, + current_channel_dim, + out_channels, + stride, + ) + ) + current_channel_dim = out_channels + self.blocks = nn.ModuleList(blocks) + + def forward( + self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + for block in self.blocks: + hidden_states = block(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + return hidden_states, all_hidden_states + + +class EfficientLoFTRepVGG(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + self.stages = nn.ModuleList([]) + num_stages = len(config.stage_block_dims) + current_in_channels = 1 + + for i in range(num_stages): + out_channels = int(config.stage_block_dims[i] * config.stage_hidden_expansion[i]) + stage = EfficientLoFTRRepVGGStage( + config, current_in_channels, out_channels, config.stage_num_blocks[i], config.stage_stride[i] + ) + current_in_channels = out_channels + self.stages.append(stage) + + def forward( + self, hidden_states: 
torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, List[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + outputs = [] + all_hidden_states = () if output_hidden_states else None + for stage in self.stages: + stage_outputs = stage(hidden_states, output_hidden_states=output_hidden_states) + hidden_states = stage_outputs[0] + if output_hidden_states: + all_hidden_states = all_hidden_states + stage_outputs[1] + outputs.append(hidden_states) + + # Exclude first stage in outputs + outputs = outputs[1:] + # Last stage outputs are coarse outputs + coarse_features = outputs[-1] + # Rest is residual features used in EfficientLoFTRFineFusionLayer + residual_features = outputs[:-1] + return coarse_features, residual_features, all_hidden_states + + +class EfficientLoFTRAggregationLayer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + hidden_size = config.hidden_size + aggregation_sizes = config.aggregation_sizes + self.q_aggregation = ( + nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=aggregation_sizes[0], + padding=0, + stride=aggregation_sizes[0], + bias=False, + groups=hidden_size, + ) + if aggregation_sizes[0] != 1 + else nn.Identity() + ) + + self.kv_aggregation = ( + torch.nn.MaxPool2d(kernel_size=aggregation_sizes[1], stride=aggregation_sizes[1]) + if aggregation_sizes[1] != 1 + else nn.Identity() + ) + + self.norm = nn.LayerNorm(hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + query_states = hidden_states + is_cross_attention = encoder_hidden_states is not None + kv_states = encoder_hidden_states if is_cross_attention else hidden_states + + query_states = self.q_aggregation(query_states) + kv_states = self.kv_aggregation(kv_states) + query_states = query_states.permute(0, 2, 3, 1) + kv_states = kv_states.permute(0, 2, 3, 1) + hidden_states = self.norm(query_states) + encoder_hidden_states = self.norm(kv_states) + if attention_mask is not None: + current_mask = encoder_attention_mask if is_cross_attention else attention_mask + attention_mask = self.kv_aggregation(attention_mask.float()).bool() + encoder_attention_mask = self.kv_aggregation(current_mask.float()).bool() + return hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class EfficientLoFTRAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: 
Optional[torch.Tensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_len, dim = hidden_states.shape + input_shape = hidden_states.shape[:-1] + + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim) + value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) + + query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EfficientLoFTRMLP(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + hidden_size = config.hidden_size + self.fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size, bias=False) + self.activation = ACT2FN[config.mlp_activation_function] + self.fc2 = nn.Linear(2 * hidden_size, hidden_size, bias=False) + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +def get_positional_embeddings_slice( + hidden_states: torch.Tensor, positional_embeddings: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, h, w, _ = hidden_states.shape + positional_embeddings = tuple( + tensor[:, :h, :w, :].expand(batch_size, -1, -1, -1) for tensor in positional_embeddings + ) + return positional_embeddings + + +class EfficientLoFTRAggregatedAttention(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int) -> None: + super().__init__() + + self.aggregation_sizes = config.aggregation_sizes + self.aggregation = EfficientLoFTRAggregationLayer(config) + self.attention = EfficientLoFTRAttention(config, layer_idx) + self.mlp = EfficientLoFTRMLP(config) + + def forward( + self, + 
hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> Tuple[torch.Tensor]: + batch_size, channels, h, w = hidden_states.shape + + # Aggregate features + aggregated_hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask = self.aggregation( + hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask + ) + + attention_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, channels) + encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, channels) + + if position_embeddings is not None: + position_embeddings = get_positional_embeddings_slice(aggregated_hidden_states, position_embeddings) + position_embeddings = tuple(tensor.reshape(batch_size, -1, channels) for tensor in position_embeddings) + + # Multi-head attention + attention_outputs = self.attention( + attention_hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + position_embeddings=position_embeddings, + ) + message = attention_outputs[0] + + # Upsample features + _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape + # (batch_size, seq_len, channels) -> (batch_size, channels, h, w) with seq_len = h * w + message = message.permute(0, 2, 1) + message = message.reshape(batch_size, channels, aggregated_h, aggregated_w) + if self.aggregation_sizes[0] != 1: + message = torch.nn.functional.interpolate( + message, scale_factor=self.aggregation_sizes[0], mode="bilinear", align_corners=False + ) + intermediate_states = torch.cat([hidden_states, message], dim=1) + intermediate_states = intermediate_states.permute(0, 2, 3, 1) + output_states = self.mlp(intermediate_states) + output_states = output_states.permute(0, 3, 1, 2) + + hidden_states = hidden_states + output_states + + outputs = (hidden_states,) + attention_outputs[1:] + return outputs + + +class EfficientLoFTRLocalFeatureTransformerLayer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int) -> None: + super().__init__() + + self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx) + self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_attentions = () if output_attentions else None + batch_size, _, c, h, w = hidden_states.shape + + hidden_states = hidden_states.reshape(-1, c, h, w) + if attention_mask is not None: + attention_mask = attention_mask.reshape(-1, c, h, w) + + self_attention_outputs = self.self_attention( + hidden_states, attention_mask, position_embeddings=position_embeddings + ) + hidden_states = self_attention_outputs[0] + + encoder_hidden_states = hidden_states.reshape(-1, 2, c, h, w).flip(1).reshape(-1, c, h, w) + encoder_attention_mask = None + if attention_mask is not None: + encoder_attention_mask = attention_mask.reshape(-1, 2, c, h, w).flip(1).reshape(-1, c, h, w) + + cross_attention_outputs = self.cross_attention( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + + hidden_states = cross_attention_outputs[0] + hidden_states = hidden_states.reshape(batch_size, -1, c, h, w) + + if 
output_attentions: + all_attentions = all_attentions + (self_attention_outputs[1], cross_attention_outputs[1]) + + return hidden_states, all_attentions + + +class EfficientLoFTRLocalFeatureTransformer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [ + EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i) + for i in range(config.num_attention_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_attentions = () if output_attentions else None + + for layer in self.layers: + layer_outputs = layer( + hidden_states, position_embeddings=position_embeddings, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + layer_outputs[1] + return hidden_states, all_attentions + + +class EfficientLoFTROutConvBlock(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + + self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False) + self.out_conv2 = nn.Conv2d( + intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False + ) + self.batch_norm = nn.BatchNorm2d(intermediate_size) + self.activation = ACT2CLS[config.mlp_activation_function]() + self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, hidden_states: torch.Tensor, residual_states: List[torch.Tensor]) -> torch.Tensor: + residual_states = self.out_conv1(residual_states) + residual_states = residual_states + hidden_states + residual_states = self.out_conv2(residual_states) + residual_states = self.batch_norm(residual_states) + residual_states = self.activation(residual_states) + residual_states = self.out_conv3(residual_states) + residual_states = nn.functional.interpolate( + residual_states, scale_factor=2.0, mode="bilinear", align_corners=False + ) + return residual_states + + +class EfficientLoFTRFineFusionLayer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + self.fine_kernel_size = config.fine_kernel_size + + stage_block_dims = config.stage_block_dims + stage_block_dims = list(reversed(stage_block_dims))[:-1] + self.out_conv = nn.Conv2d( + stage_block_dims[0], stage_block_dims[0], kernel_size=1, stride=1, padding=0, bias=False + ) + self.out_conv_layers = nn.ModuleList() + for i in range(1, len(stage_block_dims)): + out_conv = EfficientLoFTROutConvBlock(config, stage_block_dims[i], stage_block_dims[i - 1]) + self.out_conv_layers.append(out_conv) + + def forward_pyramid( + self, + hidden_states: torch.Tensor, + residual_states: List[torch.Tensor], + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + hidden_states = self.out_conv(hidden_states) + hidden_states = nn.functional.interpolate( + hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for i, layer in enumerate(self.out_conv_layers): + hidden_states = self.out_conv_layers[i](hidden_states, residual_states[i]) + if output_hidden_states: + 
all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states + + def forward( + self, + coarse_features: torch.Tensor, + residual_features: List[torch.Tensor], + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor]]]: + """ + For each image pair, compute the fine features of pixels. + In both images, compute a patch of fine features center cropped around each coarse pixel. + In the first image, the feature patch is kernel_size large and long. + In the second image, it is (kernel_size + 2) large and long. + """ + batch_size, _, channels, coarse_height, coarse_width = coarse_features.shape + + coarse_features = coarse_features.reshape(-1, channels, coarse_height, coarse_width) + residual_features = list(reversed(residual_features)) + + # 1. Fine feature extraction + pyramid_outputs = self.forward_pyramid( + coarse_features, residual_features, output_hidden_states=output_hidden_states + ) + fine_features = pyramid_outputs[0] + _, fine_channels, fine_height, fine_width = fine_features.shape + + fine_features = fine_features.reshape(batch_size, 2, fine_channels, fine_height, fine_width) + fine_features_0 = fine_features[:, 0] + fine_features_1 = fine_features[:, 1] + + # 2. Unfold all local windows in crops + stride = int(fine_height // coarse_height) + fine_features_0 = nn.functional.unfold( + fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0 + ) + _, _, seq_len = fine_features_0.shape + fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len) + fine_features_0 = fine_features_0.permute(0, 3, 2, 1) + + fine_features_1 = nn.functional.unfold( + fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1 + ) + fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len) + fine_features_1 = fine_features_1.permute(0, 3, 2, 1) + + return fine_features_0, fine_features_1, pyramid_outputs[1] + + +class EfficientLoFTRPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EfficientLoFTRConfig + base_model_prefix = "efficientloftr" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: + """ + Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same, + extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR. 
This is + a workaround for the issue discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width) + + Returns: + pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width) + + """ + return pixel_values[:, 0, :, :][:, None, :, :] + + +def create_meshgrid( + height: int, + width: int, + normalized_coordinates: bool = False, + device: torch.device = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + """ + Copied from kornia library : kornia/kornia/utils/grid.py:26 + + Generate a coordinate grid for an image. + + When the flag ``normalized_coordinates`` is set to True, the grid is + normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch + function :py:func:`torch.nn.functional.grid_sample`. + + Args: + height (`int`): + The image height (rows). + width (`int`): + The image width (cols). + normalized_coordinates (`bool`): + Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the + PyTorch function :py:func:`torch.nn.functional.grid_sample`. + device (`torch.device`): + The device on which the grid will be generated. + dtype (`torch.dtype`): + The data type of the generated grid. + + Return: + grid (`torch.Tensor` of shape `(1, height, width, 2)`): + The grid tensor. + + Example: + >>> create_meshgrid(2, 2) + tensor([[[[-1., -1.], + [ 1., -1.]], + + [[-1., 1.], + [ 1., 1.]]]]) + + >>> create_meshgrid(2, 2, normalized_coordinates=False) + tensor([[[[0., 0.], + [1., 0.]], + + [[0., 1.], + [1., 1.]]]]) + + """ + xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype) + ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype) + if normalized_coordinates: + xs = (xs / (width - 1) - 0.5) * 2 + ys = (ys / (height - 1) - 0.5) * 2 + grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1) + grid = grid.permute(1, 0, 2).unsqueeze(0) + return grid + + +def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor: + r""" + Copied from kornia library : kornia/geometry/subpix/dsnt.py:76 + Compute the expectation of coordinate values using spatial probabilities. + + The input heatmap is assumed to represent a valid spatial probability distribution, + which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`. + + Args: + input (`torch.Tensor` of shape `(batch_size, channels, height, width)`): + The input tensor representing dense spatial probabilities. + normalized_coordinates (`bool`): + Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return + the coordinates in the range of the input shape. + + Returns: + output (`torch.Tensor` of shape `(batch_size, channels, 2)`) + Expected value of the 2D coordinates. Output order of the coordinates is (x, y). + + Examples: + >>> heatmaps = torch.tensor([[[ + ... [0., 0., 0.], + ... [0., 0., 0.], + ... [0., 1., 0.]]]]) + >>> spatial_expectation2d(heatmaps, False) + tensor([[[1., 2.]]]) + + """ + batch_size, channels, height, width = input.shape + + # Create coordinates grid. + grid = create_meshgrid(height, width, normalized_coordinates, input.device) + grid = grid.to(input.dtype) + + pos_x = grid[..., 0].reshape(-1) + pos_y = grid[..., 1].reshape(-1) + + input_flat = input.view(batch_size, channels, -1) + + # Compute the expectation of the coordinates. 
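+ # As the docstring notes, each (batch, channel) heatmap is assumed to sum to 1, so the weighted + # sums below are the expected x and y coordinates over the meshgrid, i.e. E[x] = sum_i p_i * x_i + # and E[y] = sum_i p_i * y_i, computed independently per channel.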
+ expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True) + expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True) + + output = torch.cat([expected_x, expected_y], -1) + + return output.view(batch_size, channels, 2) + + +def mask_border(tensor: torch.Tensor, border_margin: int, value: Union[bool, float, int]) -> torch.Tensor: + """ + Mask a tensor border with a given value + + Args: + tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + The tensor to mask + border_margin (`int`) : + The size of the border + value (`Union[bool, int, float]`): + The value to place in the tensor's borders + + Returns: + tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + The masked tensor + """ + if border_margin <= 0: + return tensor + + tensor[:, :border_margin] = value + tensor[:, :, :border_margin] = value + tensor[:, :, :, :border_margin] = value + tensor[:, :, :, :, :border_margin] = value + tensor[:, -border_margin:] = value + tensor[:, :, -border_margin:] = value + tensor[:, :, :, -border_margin:] = value + tensor[:, :, :, :, -border_margin:] = value + return tensor + + +EFFICIENTLOFTR_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`EfficientLoFTRConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + +EFFICIENTLOFTR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SuperGlueImageProcessor`]. See + [`SuperGlueImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "EfficientLoFTR model taking images as inputs and outputting the matching of them.", + EFFICIENTLOFTR_START_DOCSTRING, +) +class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel): + """EfficientLoFTR dense image matcher + + Given two images, we determine the correspondences by: + 1. Extracting coarse and fine features through a backbone + 2. Transforming coarse features through self and cross attention + 3. Matching coarse features to obtain coarse coordinates of matches + 4. Obtaining full resolution fine features by fusing transformed and backbone coarse features + 5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement + + Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou. + Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed + In CVPR, 2024. 
https://arxiv.org/abs/2403.04765 + """ + + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__(config) + + self.config = config + self.backbone = EfficientLoFTRepVGG(config) + self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config) + self.refinement_layer = EfficientLoFTRFineFusionLayer(config) + + self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config) + + self.post_init() + + def get_matches_from_scores(self, scores: torch.Tensor): + """ + Based on a keypoint score matrix, compute the best keypoint matches between the first and second image. + Since each image pair can have different number of matches, the matches are concatenated together for all pair + in the batch and a batch_indices tensor is returned to specify which match belong to which element in the batch. + Args: + scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + Scores of keypoints + + Returns: + matched_indices (`torch.Tensor` of shape `(2, num_matches)`): + Indices representing which pixel in the first image matches which pixel in the second image + matching_scores (`torch.Tensor` of shape `(num_matches,)`): + Scores of each match + batch_indices (`torch.Tensor` of shape `(num_matches,)`): + Batch correspondences of matches + """ + batch_size, height0, width0, height1, width1 = scores.shape + + scores = scores.reshape(batch_size, height0 * width0, height1 * width1) + + # For each keypoint, get the best match + max_0 = scores.max(2, keepdim=True).values + max_1 = scores.max(1, keepdim=True).values + + # 1. Thresholding + mask = scores > self.config.coarse_matching_threshold + + # 2. Border removal + mask = mask.reshape(batch_size, height0, width0, height1, width1) + mask = mask_border(mask, self.config.coarse_matching_border_removal, False) + mask = mask.reshape(batch_size, height0 * width0, height1 * width1) + + # 3. Mutual nearest neighbors + mask = mask * (scores == max_0) * (scores == max_1) + + # 4. Fine coarse matches + mask_values, mask_indices = mask.max(dim=2) + batch_indices, matched_indices_0 = torch.where(mask_values) + matched_indices_1 = mask_indices[batch_indices, matched_indices_0] + matching_scores = scores[batch_indices, matched_indices_0, matched_indices_1] + + matched_indices = torch.stack([matched_indices_0, matched_indices_1], dim=0) + return matched_indices, matching_scores, batch_indices + + def coarse_matching( + self, coarse_features: torch.Tensor, coarse_scale: float + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For each image pair, compute the matching confidence between each coarse element (by default (image_height / 8) + * (image_width / 8 elements)) from the first image to the second image. Since the number of matches can vary + with different image pairs, the matches are concatenated together in a dimension. A batch_indices tensor is + returned to inform which keypoint is part of which image pair. + + Args: + coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`): + Coarse features + coarse_scale (`float`): Scale between the image size and the coarse size + + Returns: + matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`): + Matched keypoint between the first and the second image. All matched keypoints are concatenated in the + second dimension. + matching_scores (`torch.Tensor` of shape `(batch_size, num_matches)`): + The confidence score of each matched keypoint. 
+ batch_indices (`torch.Tensor` of shape `(num_matches,)`): + Indices of batches for each matched keypoint found. + """ + batch_size, _, channels, height, width = coarse_features.shape + + # (batch_size, 2, channels, height, width) -> (batch_size, 2, height * width, channels) + coarse_features = coarse_features.permute(0, 1, 3, 4, 2) + coarse_features = coarse_features.reshape(batch_size, 2, -1, channels) + + coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5 + coarse_features_0 = coarse_features[:, 0] + coarse_features_1 = coarse_features[:, 1] + + similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2) + similarity = similarity / self.config.coarse_matching_temperature + + if self.config.coarse_matching_skip_softmax: + confidence = similarity + else: + confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2) + + confidence = confidence.reshape(batch_size, height, width, height, width) + matched_indices, matching_scores, batch_indices = self.get_matches_from_scores(confidence) + + matched_keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale + + return ( + matched_keypoints, + matching_scores, + batch_indices, + matched_indices, + ) + + def get_first_stage_fine_matching( + self, + fine_confidence: torch.Tensor, + coarse_matched_keypoints: torch.Tensor, + fine_window_size: int, + fine_scale: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + For each coarse pixel, retrieve the highest fine confidence score and index. + The index represents the matching between a pixel position in the fine window in the first image and a pixel + position in the fine window of the second image. + For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38 + (2474 // 64) in the fine window of the first image, and the index 42 in the second image. This means that 38, + which corresponds to the position (4, 6) (38 // 8 and 38 % 8), is matched with the position (5, 2). In this example + the coarse matched coordinate will be shifted to the matched fine coordinates in the first and second image. + + Args: + fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`): + First stage confidence of matching fine features between the first and the second image + coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coarse matched keypoint between the first and the second image.
+ fine_window_size (`int`): + Size of the window used to refine matches + fine_scale (`float`): + Scale between the size of fine features and coarse features + + Returns: + indices (`torch.Tensor` of shape `(2, num_matches, 1)`): + Indices of the fine coordinate matched in the fine window + fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the first fine stage + """ + num_matches, _, _ = fine_confidence.shape + fine_kernel_size = int(math.sqrt(fine_window_size)) + + fine_confidence = fine_confidence.reshape(num_matches, -1) + values, indices = torch.max(fine_confidence, dim=-1) + indices = indices[..., None] + indices_0 = indices // fine_window_size + indices_1 = indices % fine_window_size + + grid = create_meshgrid( + fine_kernel_size, + fine_kernel_size, + normalized_coordinates=False, + device=fine_confidence.device, + dtype=fine_confidence.dtype, + ) + grid = grid - (fine_kernel_size // 2) + 0.5 + grid = grid.reshape(1, -1, 2).expand(num_matches, -1, -1) + delta_0 = torch.gather(grid, 1, indices_0.unsqueeze(-1).expand(-1, -1, 2)).squeeze(1) + delta_1 = torch.gather(grid, 1, indices_1.unsqueeze(-1).expand(-1, -1, 2)).squeeze(1) + + fine_matches_0 = coarse_matched_keypoints[0] + delta_0 * fine_scale + fine_matches_0 = fine_matches_0.reshape(num_matches, 2) + fine_matches_1 = coarse_matched_keypoints[1] + delta_1 * fine_scale + fine_matches_1 = fine_matches_1.reshape(num_matches, 2) + + indices = torch.stack([indices_0, indices_1], dim=0) + fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=0) + + return indices, fine_matches + + def get_second_stage_fine_matching( + self, + indices: torch.Tensor, + fine_matches: torch.Tensor, + fine_confidence: torch.Tensor, + fine_window_size: int, + fine_scale: float, + ) -> torch.Tensor: + """ + For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this position. + After applying softmax to these confidences, compute the 2D spatial expected coordinates. + Shift the first stage fine matching with these expected coordinates. 
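+ With the 3x3 neighborhood used below, the expected coordinates are normalized to [-1, 1], scaled by + (3 // 2) * fine_scale and added to the second-image coordinates of the first stage match, while the + first-image coordinates are left unchanged.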
+ + Args: + indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`): + Indices representing the position of each keypoint in the fine window + fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the first fine stage + fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`): + Second stage confidence of matching fine features between the first and the second image + fine_window_size (`int`): + Size of the window used to refine matches + fine_scale (`float`): + Scale between the size of fine features and coarse features + Returns: + fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the second fine stage + """ + num_matches, _, _ = fine_confidence.shape + fine_kernel_size = int(math.sqrt(fine_window_size)) + + indices_0 = indices[0] + indices_1 = indices[1] + indices_1_i = indices_1 // fine_kernel_size + indices_1_j = indices_1 % fine_kernel_size + + matches_indices = torch.arange(num_matches, device=indices_0.device) + + # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3) + matches_indices = matches_indices[..., None, None].expand(-1, 3, 3) + indices_0 = indices_0[..., None].expand(-1, 3, 3) + indices_1_i = indices_1_i[..., None].expand(-1, 3, 3) + indices_1_j = indices_1_j[..., None].expand(-1, 3, 3) + + delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long) + delta = delta[None, ...] + + indices_1_i = indices_1_i + delta[..., 1] + indices_1_j = indices_1_j + delta[..., 0] + + fine_confidence = fine_confidence.reshape( + num_matches, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2 + ) + # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3) + fine_confidence = fine_confidence[matches_indices, indices_0, indices_1_i, indices_1_j] + fine_confidence = fine_confidence.reshape(num_matches, 9) + fine_confidence = nn.functional.softmax( + fine_confidence / self.config.fine_matching_regress_temperature, dim=-1 + ) + + heatmap = fine_confidence.reshape(1, -1, 3, 3) + fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0] + + fine_matches_0 = fine_matches[0] + fine_matches_1 = fine_matches[1] + (fine_coordinates_normalized * (3 // 2) * fine_scale) + + fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=0) + + return fine_matches + + def fine_matching( + self, + fine_features_0: torch.Tensor, + fine_features_1: torch.Tensor, + coarse_matched_keypoints: torch.Tensor, + fine_scale: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For each coarse pixel with a corresponding window of fine features, compute the matching confidence between fine + features in the first image and the second image. + + Fine features are sliced in two part : + - The first part used for the first stage are the first fine_hidden_size - config.fine_matching_slicedim (64 - 8 + = 56 by default) features. + - The second part used for the second stage are the last config.fine_matching_slicedim (8 by default) features. + + Each part is used to compute a fine confidence tensor of the following shape : + (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size) + They correspond to the score between each fine pixel in the first image and each fine pixel in the second image. 
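+ With the default configuration (fine_kernel_size = 8 and fine_matching_slicedim = 8), the first stage therefore + compares 56-dimensional descriptors between the 8 x 8 window of the first image and the 10 x 10 window of the + second image (whose border is then cropped back to 8 x 8), while the second stage compares the remaining 8 + dimensions to refine the match within a 3 x 3 neighborhood.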
+ + Args: + fine_features_0 (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`): + Fine features from the first image + fine_features_1 (`torch.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2) + ** 2)`): + Fine features from the second image + coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`): + Keypoint coordinates found in coarse matching for the first and second image + fine_scale (`int`): + Scale between the size of fine features and coarse features + + Returns: + fine_coordinates (`torch.Tensor` of shape `(2, num_matches, 2)`): + Matched keypoint between the first and the second image. All matched keypoints are concatenated in the + second dimension. + first_stage_fine_confidence (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size + ** 2)`): + Scores of fine matching in the first stage + second_stage_fine_confidence (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, + (fine_kernel_size + 2) ** 2)`): + Scores of fine matching in the second stage + + """ + num_matches, fine_window_size, _ = fine_features_0.shape + + if num_matches == 0: + fine_confidence = torch.empty(0, fine_window_size, fine_window_size, device=fine_features_0.device) + return coarse_matched_keypoints, fine_confidence, fine_confidence + + fine_kernel_size = int(math.sqrt(fine_window_size)) + + first_stage_fine_features_0 = fine_features_0[..., : -self.config.fine_matching_slicedim] + first_stage_fine_features_1 = fine_features_1[..., : -self.config.fine_matching_slicedim] + first_stage_fine_features_0 = first_stage_fine_features_0 / first_stage_fine_features_0.shape[-1] ** 0.5 + first_stage_fine_features_1 = first_stage_fine_features_1 / first_stage_fine_features_1.shape[-1] ** 0.5 + first_stage_fine_confidence = first_stage_fine_features_0 @ first_stage_fine_features_1.transpose(-1, -2) + first_stage_fine_confidence = nn.functional.softmax(first_stage_fine_confidence, 1) * nn.functional.softmax( + first_stage_fine_confidence, 2 + ) + first_stage_fine_confidence = first_stage_fine_confidence.reshape( + num_matches, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2 + ) + first_stage_fine_confidence = first_stage_fine_confidence[..., 1:-1, 1:-1] + first_stage_fine_confidence = first_stage_fine_confidence.reshape( + num_matches, fine_window_size, fine_window_size + ) + + fine_indices, fine_matches = self.get_first_stage_fine_matching( + first_stage_fine_confidence, + coarse_matched_keypoints, + fine_window_size, + fine_scale, + ) + + second_stage_fine_features_0 = fine_features_0[..., -self.config.fine_matching_slicedim :] + second_stage_fine_features_1 = fine_features_1[..., -self.config.fine_matching_slicedim :] + second_stage_fine_features_1 = second_stage_fine_features_1 / self.config.fine_matching_slicedim**0.5 + second_stage_fine_confidence = second_stage_fine_features_0 @ second_stage_fine_features_1.transpose(-1, -2) + + fine_coordinates = self.get_second_stage_fine_matching( + fine_indices, + fine_matches, + second_stage_fine_confidence, + fine_window_size, + fine_scale, + ) + + return fine_coordinates, first_stage_fine_confidence, second_stage_fine_confidence + + @add_start_docstrings_to_model_forward(EFFICIENTLOFTR_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> 
Union[Tuple, "KeypointMatchingOutput"]: + """ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModel + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true" + >>> image1 = Image.open(requests.get(url, stream=True).raw) + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true" + >>> image2 = Image.open(requests.get(url, stream=True).raw) + >>> images = [image1, image2] + + >>> processor = AutoImageProcessor.from_pretrained("stevenbucaille/efficient_loftr") + >>> model = AutoModel.from_pretrained("stevenbucaille/efficient_loftr") + + >>> with torch.no_grad(): + ... inputs = processor(images, return_tensors="pt") + ... outputs = model(**inputs) + ```""" + loss = None + if labels is not None: + raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values.ndim != 5 or pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + pixel_values = self.extract_one_channel_pixel_values(pixel_values) + + # 1. Local Feature CNN + backbone_outputs = self.backbone(pixel_values, output_hidden_states=output_hidden_states) + coarse_features, residual_features = backbone_outputs[:2] + coarse_channels, coarse_height, coarse_width = coarse_features.shape[-3:] + + # 2. Coarse-level LoFTR module + position_embeddings = self.rotary_emb(coarse_features) + coarse_features = coarse_features.reshape(batch_size, 2, coarse_channels, coarse_height, coarse_width) + local_feature_transformer_outputs = self.local_feature_transformer( + coarse_features, position_embeddings=position_embeddings, output_attentions=output_attentions + ) + coarse_features = local_feature_transformer_outputs[0] + + # 3. Compute coarse-level matching + coarse_scale = height / coarse_height + ( + coarse_matched_keypoints, + coarse_matching_scores, + batch_indices, + matched_indices, + ) = self.coarse_matching(coarse_features, coarse_scale) + + # 4. Fine-level refinement + refinement_layer_outputs = self.refinement_layer( + coarse_features, residual_features, output_hidden_states=output_hidden_states + ) + fine_features_0, fine_features_1 = refinement_layer_outputs[:2] + fine_features_0 = fine_features_0[batch_indices, matched_indices[0]] + fine_features_1 = fine_features_1[batch_indices, matched_indices[1]] + + # 5. 
Compute fine-level matching + fine_height = int(coarse_height * coarse_scale) + fine_scale = height / fine_height + matching_keypoints, first_stage_matching_scores, second_stage_matching_scores = self.fine_matching( + fine_features_0, + fine_features_1, + coarse_matched_keypoints, + fine_scale, + ) + + matching_keypoints[:, :, 0] = matching_keypoints[:, :, 0] / width + matching_keypoints[:, :, 1] = matching_keypoints[:, :, 1] / height + + unique_values, counts = torch.unique_consecutive(batch_indices, return_counts=True) + + if len(unique_values) > 0: + matching_keypoints_0 = matching_keypoints[0] + matching_keypoints_1 = matching_keypoints[1] + split_keypoints_0 = torch.split(matching_keypoints_0, counts.tolist()) + split_keypoints_1 = torch.split(matching_keypoints_1, counts.tolist()) + split_scores = torch.split(coarse_matching_scores, counts.tolist()) + + split_mask = [torch.ones(size, device=matching_keypoints.device) for size in counts.tolist()] + split_indices = [torch.arange(size, device=matching_keypoints.device) for size in counts.tolist()] + + keypoints_0 = pad_sequence(split_keypoints_0, batch_first=True) + keypoints_1 = pad_sequence(split_keypoints_1, batch_first=True) + matching_scores = pad_sequence(split_scores, batch_first=True) + mask = pad_sequence(split_mask, batch_first=True) + matches = pad_sequence(split_indices, batch_first=True) + + keypoints = torch.stack([keypoints_0, keypoints_1], dim=1) + matching_scores = torch.stack([matching_scores, matching_scores], dim=1) + mask = torch.stack([mask, mask], dim=1) + matches = torch.stack([matches, matches], dim=1) + + else: + keypoints = matching_keypoints.unsqueeze(0) + matching_scores = torch.stack([coarse_matching_scores, coarse_matching_scores], dim=0).unsqueeze(0) + mask = torch.ones_like(keypoints) + matches = torch.stack([matched_indices, matched_indices], dim=0).unsqueeze(0) + + if output_hidden_states: + all_hidden_states = all_hidden_states + backbone_outputs[2] + refinement_layer_outputs[2] + + if output_attentions: + all_attentions = all_attentions + local_feature_transformer_outputs[1] + + if not return_dict: + return tuple( + v + for v in [loss, matches, matching_scores, keypoints, mask, all_hidden_states, all_attentions] + if v is not None + ) + + return KeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + mask=mask, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRForKeypointMatching"] diff --git a/src/transformers/models/efficientloftr/modular_efficientloftr.py b/src/transformers/models/efficientloftr/modular_efficientloftr.py new file mode 100644 index 000000000000..2e81853997da --- /dev/null +++ b/src/transformers/models/efficientloftr/modular_efficientloftr.py @@ -0,0 +1,1434 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import math +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...activations import ACT2CLS, ACT2FN +from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ..cohere.modeling_cohere import apply_rotary_pos_emb +from ..llama.modeling_llama import LlamaAttention, eager_attention_forward +from ..rt_detr_v2.modeling_rt_detr_v2 import RTDetrV2ConvNormLayer +from ..superpoint.modeling_superpoint import SuperPointPreTrainedModel + + +"""PyTorch EfficientLoFTR model.""" + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC_ = "EfficientLoFTRConfig" +_CHECKPOINT_FOR_DOC_ = "stevenbucaille/efficient_loftr" + + +def create_meshgrid( + height: int, + width: int, + normalized_coordinates: bool = False, + device: torch.device = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + """ + Copied from kornia library : kornia/kornia/utils/grid.py:26 + + Generate a coordinate grid for an image. + + When the flag ``normalized_coordinates`` is set to True, the grid is + normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch + function :py:func:`torch.nn.functional.grid_sample`. + + Args: + height (`int`): + The image height (rows). + width (`int`): + The image width (cols). + normalized_coordinates (`bool`): + Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the + PyTorch function :py:func:`torch.nn.functional.grid_sample`. + device (`torch.device`): + The device on which the grid will be generated. + dtype (`torch.dtype`): + The data type of the generated grid. + + Return: + grid (`torch.Tensor` of shape `(1, height, width, 2)`): + The grid tensor. + + Example: + >>> create_meshgrid(2, 2) + tensor([[[[-1., -1.], + [ 1., -1.]], + + [[-1., 1.], + [ 1., 1.]]]]) + + >>> create_meshgrid(2, 2, normalized_coordinates=False) + tensor([[[[0., 0.], + [1., 0.]], + + [[0., 1.], + [1., 1.]]]]) + + """ + xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype) + ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype) + if normalized_coordinates: + xs = (xs / (width - 1) - 0.5) * 2 + ys = (ys / (height - 1) - 0.5) * 2 + grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1) + grid = grid.permute(1, 0, 2).unsqueeze(0) + return grid + + +def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor: + r""" + Copied from kornia library : kornia/geometry/subpix/dsnt.py:76 + Compute the expectation of coordinate values using spatial probabilities. + + The input heatmap is assumed to represent a valid spatial probability distribution, + which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`. + + Args: + input (`torch.Tensor` of shape `(batch_size, channels, height, width)`): + The input tensor representing dense spatial probabilities. + normalized_coordinates (`bool`): + Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return + the coordinates in the range of the input shape. 
+ + Returns: + output (`torch.Tensor` of shape `(batch_size, channels, 2)`) + Expected value of the 2D coordinates. Output order of the coordinates is (x, y). + + Examples: + >>> heatmaps = torch.tensor([[[ + ... [0., 0., 0.], + ... [0., 0., 0.], + ... [0., 1., 0.]]]]) + >>> spatial_expectation2d(heatmaps, False) + tensor([[[1., 2.]]]) + + """ + batch_size, channels, height, width = input.shape + + # Create coordinates grid. + grid = create_meshgrid(height, width, normalized_coordinates, input.device) + grid = grid.to(input.dtype) + + pos_x = grid[..., 0].reshape(-1) + pos_y = grid[..., 1].reshape(-1) + + input_flat = input.view(batch_size, channels, -1) + + # Compute the expectation of the coordinates. + expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True) + expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True) + + output = torch.cat([expected_x, expected_y], -1) + + return output.view(batch_size, channels, 2) + + +def mask_border(tensor: torch.Tensor, border_margin: int, value: Union[bool, float, int]) -> torch.Tensor: + """ + Mask a tensor border with a given value + + Args: + tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + The tensor to mask + border_margin (`int`) : + The size of the border + value (`Union[bool, int, float]`): + The value to place in the tensor's borders + + Returns: + tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + The masked tensor + """ + if border_margin <= 0: + return tensor + + tensor[:, :border_margin] = value + tensor[:, :, :border_margin] = value + tensor[:, :, :, :border_margin] = value + tensor[:, :, :, :, :border_margin] = value + tensor[:, -border_margin:] = value + tensor[:, :, -border_margin:] = value + tensor[:, :, :, -border_margin:] = value + tensor[:, :, :, :, -border_margin:] = value + return tensor + + +class EfficientLoFTRConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`EfficientLoFTRForKeypointMatching`]. + It is used to instantiate an EfficientLoFTR model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the + EfficientLoFTR [stevenbucaille/efficient_loftr](https://huggingface.co/stevenbucaille/efficient_loftr) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + stage_block_dims (`List`, *optional*, defaults to [64, 64, 128, 256]): + The hidden size of the features in the blocks of each stage + stage_num_blocks (`List`, *optional*, defaults to [1, 2, 4, 14]): + The number of blocks in each stage + stage_hidden_expansion (`List`, *optional*, defaults to [1, 1, 1, 1]): + The rate of expansion of hidden size in each stage + stage_stride (`List`, *optional*, defaults to [2, 1, 2, 2]): + The stride used in each stage + hidden_size (`int`, *optional*, defaults to 256): + The dimension of the descriptors.
+ activation_function (`str`, *optional*, defaults to `"relu"`): + The activation function used in the backbone + aggregation_sizes (`List`, *optional*, defaults to [4, 4]): + The size of each aggregation for the fusion network + num_attention_layers (`int`, *optional*, defaults to 4): + Number of attention layers in the LocalFeatureTransformer + num_attention_heads (`int`, *optional*, defaults to 8): + The number of heads in the GNN layers. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during attention. + mlp_activation_function (`str`, *optional*, defaults to `"leaky_relu"`): + Activation function used in the attention mlp layer. + coarse_matching_skip_softmax (`bool`, *optional*, defaults to `False`): + Whether to skip softmax or not at the coarse matching step. + coarse_matching_threshold (`float`, *optional*, defaults to 0.2): + The threshold for the minimum score required for a match. + coarse_matching_temperature (`float`, *optional*, defaults to 0.1): + The temperature to apply to the coarse similarity matrix + coarse_matching_border_removal (`int`, *optional*, defaults to 2): + The size of the border to remove during coarse matching + fine_kernel_size (`int`, *optional*, defaults to 8): + Kernel size used for the fine feature matching + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch normalization layers. + rope_type (`str`, *optional*, defaults to `"2d"`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3', '2d'], with 'default' being the original RoPE implementation. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + fine_matching_slicedim (`int`, *optional*, defaults to 8): + The size of the slice used to divide the fine features for the first and second fine matching stages. + fine_matching_regress_temperature (`float`, *optional*, defaults to 10.0): + The temperature to apply to the fine similarity matrix + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Examples: + ```python + >>> from transformers import EfficientLoFTRConfig, EfficientLoFTRForKeypointMatching + + >>> # Initializing an EfficientLoFTR style configuration + >>> configuration = EfficientLoFTRConfig() + + >>> # Initializing a model from the EfficientLoFTR style configuration + >>> model = EfficientLoFTRForKeypointMatching(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "efficientloftr" + + def __init__( + self, + stage_block_dims: List[int] = None, + stage_num_blocks: List[int] = None, + stage_hidden_expansion: List[float] = None, + stage_stride: List[int] = None, + hidden_size: int = 256, + activation_function: str = "relu", + aggregation_sizes: List[int] = None, + num_attention_layers: int = 4, + num_attention_heads: int = 8, + num_key_value_heads: int = None, + attention_dropout: float = 0.0, + attention_bias: bool = False, + mlp_activation_function: str = "leaky_relu", + coarse_matching_skip_softmax: bool = False, + coarse_matching_threshold: float = 0.2, + coarse_matching_temperature: float = 0.1, + coarse_matching_border_removal: int = 2, + fine_kernel_size: int = 8, + batch_norm_eps: float = 1e-5, + rope_type: str = "2d", + rope_theta: float = 10000.0, + fine_matching_slicedim: int = 8, + fine_matching_regress_temperature: float = 10.0, + initializer_range: float = 0.02, + **kwargs, + ): + self.stage_block_dims = stage_block_dims if stage_block_dims is not None else [64, 64, 128, 256] + self.stage_num_blocks = stage_num_blocks if stage_num_blocks is not None else [1, 2, 4, 14] + self.stage_hidden_expansion = stage_hidden_expansion if stage_hidden_expansion is not None else [1, 1, 1, 1] + self.stage_stride = stage_stride if stage_stride is not None else [2, 1, 2, 2] + self.hidden_size = hidden_size + if self.hidden_size != self.stage_block_dims[-1]: + raise ValueError( + f"hidden_size should be equal to the last value in stage_block_dims. hidden_size = {self.hidden_size}, stage_block_dims = {self.stage_block_dims}" + ) + + self.activation_function = activation_function + self.aggregation_sizes = aggregation_sizes if aggregation_sizes is not None else [4, 4] + self.num_attention_layers = num_attention_layers + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.mlp_activation_function = mlp_activation_function + self.coarse_matching_skip_softmax = coarse_matching_skip_softmax + self.coarse_matching_threshold = coarse_matching_threshold + self.coarse_matching_temperature = coarse_matching_temperature + self.coarse_matching_border_removal = coarse_matching_border_removal + self.fine_kernel_size = fine_kernel_size + self.batch_norm_eps = batch_norm_eps + self.fine_matching_slicedim = fine_matching_slicedim + self.fine_matching_regress_temperature = fine_matching_regress_temperature + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + + self.rope_type = rope_type + self.rope_theta = rope_theta + + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +@dataclass +class KeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number + of keypoints is not fixed and can vary from image to image, which makes batching non-trivial.
In the batch of + images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is + used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching + information. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches and matching_scores are keypoint matching information. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)`, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`) + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + mask: Optional[torch.IntTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class EfficientLoFTRRotaryEmbedding(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, device="cpu") -> None: + super().__init__() + self.config = config + self.rope_type = config.rope_type + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + b, _, h, w = x.shape + + i_position_indices = torch.ones(h, w, device=x.device).cumsum(0).float().unsqueeze(-1) + j_position_indices = torch.ones(h, w, device=x.device).cumsum(1).float().unsqueeze(-1) + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, None, None, :].float().expand(1, 1, 1, -1) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + emb = torch.zeros(1, h, w, self.config.hidden_size // 2) + emb[:, :, :, 0::2] = i_position_indices * inv_freq_expanded + emb[:, :, :, 1::2] = j_position_indices * inv_freq_expanded + + sin = emb.sin() + cos = emb.cos() + + sin = sin.repeat_interleave(2, dim=-1) + cos = cos.repeat_interleave(2, dim=-1) + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + sin = sin.to(device=x.device, dtype=x.dtype) + cos = cos.to(device=x.device, dtype=x.dtype) + + return cos, sin + + +class EfficientLoFTRConvNormLayer(RTDetrV2ConvNormLayer): + pass + + +class EfficientLoFTRRepVGGBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: EfficientLoFTRConfig, in_channels: int, out_channels: int, stride: int = 1) -> None: + super().__init__() + activation = config.activation_function + self.conv1 = EfficientLoFTRConvNormLayer( + config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1 + ) + self.conv2 = EfficientLoFTRConvNormLayer( + config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0 + ) + self.identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None + self.activation = nn.Identity() if activation is None else ACT2FN[activation] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.identity is not None: + identity_out = self.identity(hidden_states) + else: + identity_out = 0 + hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out + hidden_states = self.activation(hidden_states) + return hidden_states + + +class EfficientLoFTRRepVGGStage(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, in_channels, out_channels, num_blocks, stride) -> None: + super().__init__() + + strides = [stride] + [1] * (num_blocks - 1) + current_channel_dim = in_channels + blocks = [] + for stride in strides: + blocks.append( + EfficientLoFTRRepVGGBlock( + config, + current_channel_dim, + out_channels, + stride, + ) + ) + current_channel_dim = out_channels + self.blocks = nn.ModuleList(blocks) + + def forward( + self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + for block in self.blocks: + hidden_states = block(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + return hidden_states, all_hidden_states + + +class EfficientLoFTRepVGG(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + self.stages = nn.ModuleList([]) + num_stages = len(config.stage_block_dims) + current_in_channels = 1 + + for i in range(num_stages): + out_channels = int(config.stage_block_dims[i] * config.stage_hidden_expansion[i]) + stage = EfficientLoFTRRepVGGStage( + config, current_in_channels, out_channels, config.stage_num_blocks[i], config.stage_stride[i] + ) + current_in_channels = out_channels + self.stages.append(stage) + + def forward( + self, hidden_states: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, List[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + outputs = [] + all_hidden_states = () if output_hidden_states else None + for stage in self.stages: + stage_outputs = stage(hidden_states, output_hidden_states=output_hidden_states) + hidden_states = stage_outputs[0] + if output_hidden_states: + all_hidden_states = all_hidden_states + stage_outputs[1] + outputs.append(hidden_states) + + # Exclude first stage in outputs + outputs = outputs[1:] + # Last stage outputs are coarse outputs + coarse_features = outputs[-1] + # Rest is 
residual features used in EfficientLoFTRFineFusionLayer + residual_features = outputs[:-1] + return coarse_features, residual_features, all_hidden_states + + +class EfficientLoFTRAggregationLayer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + hidden_size = config.hidden_size + aggregation_sizes = config.aggregation_sizes + self.q_aggregation = ( + nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=aggregation_sizes[0], + padding=0, + stride=aggregation_sizes[0], + bias=False, + groups=hidden_size, + ) + if aggregation_sizes[0] != 1 + else nn.Identity() + ) + + self.kv_aggregation = ( + torch.nn.MaxPool2d(kernel_size=aggregation_sizes[1], stride=aggregation_sizes[1]) + if aggregation_sizes[1] != 1 + else nn.Identity() + ) + + self.norm = nn.LayerNorm(hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + query_states = hidden_states + is_cross_attention = encoder_hidden_states is not None + kv_states = encoder_hidden_states if is_cross_attention else hidden_states + + query_states = self.q_aggregation(query_states) + kv_states = self.kv_aggregation(kv_states) + query_states = query_states.permute(0, 2, 3, 1) + kv_states = kv_states.permute(0, 2, 3, 1) + hidden_states = self.norm(query_states) + encoder_hidden_states = self.norm(kv_states) + if attention_mask is not None: + current_mask = encoder_attention_mask if is_cross_attention else attention_mask + attention_mask = self.kv_aggregation(attention_mask.float()).bool() + encoder_attention_mask = self.kv_aggregation(current_mask.float()).bool() + return hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask + + +class EfficientLoFTRAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_len, dim = hidden_states.shape + input_shape = hidden_states.shape[:-1] + + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim) + value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) + + query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + 
logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EfficientLoFTRMLP(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + hidden_size = config.hidden_size + self.fc1 = nn.Linear(2 * hidden_size, 2 * hidden_size, bias=False) + self.activation = ACT2FN[config.mlp_activation_function] + self.fc2 = nn.Linear(2 * hidden_size, hidden_size, bias=False) + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +def get_positional_embeddings_slice( + hidden_states: torch.Tensor, positional_embeddings: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, h, w, _ = hidden_states.shape + positional_embeddings = tuple( + tensor[:, :h, :w, :].expand(batch_size, -1, -1, -1) for tensor in positional_embeddings + ) + return positional_embeddings + + +class EfficientLoFTRAggregatedAttention(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int) -> None: + super().__init__() + + self.aggregation_sizes = config.aggregation_sizes + self.aggregation = EfficientLoFTRAggregationLayer(config) + self.attention = EfficientLoFTRAttention(config, layer_idx) + self.mlp = EfficientLoFTRMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> Tuple[torch.Tensor]: + batch_size, channels, h, w = hidden_states.shape + + # Aggregate features + aggregated_hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask = self.aggregation( + hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask + ) + + attention_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, channels) + encoder_hidden_states = encoder_hidden_states.reshape(batch_size, -1, channels) + + if position_embeddings is not None: + position_embeddings = get_positional_embeddings_slice(aggregated_hidden_states, position_embeddings) + position_embeddings = tuple(tensor.reshape(batch_size, -1, channels) for tensor in position_embeddings) + + # Multi-head attention + attention_outputs = self.attention( + attention_hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + position_embeddings=position_embeddings, + ) + message = attention_outputs[0] + + # Upsample features + _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape + # (batch_size, seq_len, channels) -> (batch_size, channels, h, w) with 
seq_len = h * w + message = message.permute(0, 2, 1) + message = message.reshape(batch_size, channels, aggregated_h, aggregated_w) + if self.aggregation_sizes[0] != 1: + message = torch.nn.functional.interpolate( + message, scale_factor=self.aggregation_sizes[0], mode="bilinear", align_corners=False + ) + intermediate_states = torch.cat([hidden_states, message], dim=1) + intermediate_states = intermediate_states.permute(0, 2, 3, 1) + output_states = self.mlp(intermediate_states) + output_states = output_states.permute(0, 3, 1, 2) + + hidden_states = hidden_states + output_states + + outputs = (hidden_states,) + attention_outputs[1:] + return outputs + + +class EfficientLoFTRLocalFeatureTransformerLayer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int) -> None: + super().__init__() + + self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx) + self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_attentions = () if output_attentions else None + batch_size, _, c, h, w = hidden_states.shape + + hidden_states = hidden_states.reshape(-1, c, h, w) + if attention_mask is not None: + attention_mask = attention_mask.reshape(-1, c, h, w) + + self_attention_outputs = self.self_attention( + hidden_states, attention_mask, position_embeddings=position_embeddings + ) + hidden_states = self_attention_outputs[0] + + encoder_hidden_states = hidden_states.reshape(-1, 2, c, h, w).flip(1).reshape(-1, c, h, w) + encoder_attention_mask = None + if attention_mask is not None: + encoder_attention_mask = attention_mask.reshape(-1, 2, c, h, w).flip(1).reshape(-1, c, h, w) + + cross_attention_outputs = self.cross_attention( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + + hidden_states = cross_attention_outputs[0] + hidden_states = hidden_states.reshape(batch_size, -1, c, h, w) + + if output_attentions: + all_attentions = all_attentions + (self_attention_outputs[1], cross_attention_outputs[1]) + + return hidden_states, all_attentions + + +class EfficientLoFTRLocalFeatureTransformer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [ + EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i) + for i in range(config.num_attention_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_attentions = () if output_attentions else None + + for layer in self.layers: + layer_outputs = layer( + hidden_states, position_embeddings=position_embeddings, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + layer_outputs[1] + return hidden_states, all_attentions + + +class EfficientLoFTROutConvBlock(nn.Module): + def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int) -> None: + super().__init__() + + self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False) + self.out_conv2 = nn.Conv2d( + intermediate_size, 
intermediate_size, kernel_size=3, stride=1, padding=1, bias=False + ) + self.batch_norm = nn.BatchNorm2d(intermediate_size) + self.activation = ACT2CLS[config.mlp_activation_function]() + self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, hidden_states: torch.Tensor, residual_states: List[torch.Tensor]) -> torch.Tensor: + residual_states = self.out_conv1(residual_states) + residual_states = residual_states + hidden_states + residual_states = self.out_conv2(residual_states) + residual_states = self.batch_norm(residual_states) + residual_states = self.activation(residual_states) + residual_states = self.out_conv3(residual_states) + residual_states = nn.functional.interpolate( + residual_states, scale_factor=2.0, mode="bilinear", align_corners=False + ) + return residual_states + + +class EfficientLoFTRFineFusionLayer(nn.Module): + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__() + + self.fine_kernel_size = config.fine_kernel_size + + stage_block_dims = config.stage_block_dims + stage_block_dims = list(reversed(stage_block_dims))[:-1] + self.out_conv = nn.Conv2d( + stage_block_dims[0], stage_block_dims[0], kernel_size=1, stride=1, padding=0, bias=False + ) + self.out_conv_layers = nn.ModuleList() + for i in range(1, len(stage_block_dims)): + out_conv = EfficientLoFTROutConvBlock(config, stage_block_dims[i], stage_block_dims[i - 1]) + self.out_conv_layers.append(out_conv) + + def forward_pyramid( + self, + hidden_states: torch.Tensor, + residual_states: List[torch.Tensor], + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + hidden_states = self.out_conv(hidden_states) + hidden_states = nn.functional.interpolate( + hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for i, layer in enumerate(self.out_conv_layers): + hidden_states = self.out_conv_layers[i](hidden_states, residual_states[i]) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states + + def forward( + self, + coarse_features: torch.Tensor, + residual_features: List[torch.Tensor], + output_hidden_states: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor]]]: + """ + For each image pair, compute the fine features of pixels. + In both images, compute a patch of fine features center cropped around each coarse pixel. + In the first image, the feature patch is kernel_size large and long. + In the second image, it is (kernel_size + 2) large and long. + """ + batch_size, _, channels, coarse_height, coarse_width = coarse_features.shape + + coarse_features = coarse_features.reshape(-1, channels, coarse_height, coarse_width) + residual_features = list(reversed(residual_features)) + + # 1. Fine feature extraction + pyramid_outputs = self.forward_pyramid( + coarse_features, residual_features, output_hidden_states=output_hidden_states + ) + fine_features = pyramid_outputs[0] + _, fine_channels, fine_height, fine_width = fine_features.shape + + fine_features = fine_features.reshape(batch_size, 2, fine_channels, fine_height, fine_width) + fine_features_0 = fine_features[:, 0] + fine_features_1 = fine_features[:, 1] + + # 2. 
Unfold all local windows in crops + stride = int(fine_height // coarse_height) + fine_features_0 = nn.functional.unfold( + fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0 + ) + _, _, seq_len = fine_features_0.shape + fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len) + fine_features_0 = fine_features_0.permute(0, 3, 2, 1) + + fine_features_1 = nn.functional.unfold( + fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1 + ) + fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len) + fine_features_1 = fine_features_1.permute(0, 3, 2, 1) + + return fine_features_0, fine_features_1, pyramid_outputs[1] + + +class EfficientLoFTRPreTrainedModel(SuperPointPreTrainedModel): + config_class = EfficientLoFTRConfig + base_model_prefix = "efficientloftr" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +EFFICIENTLOFTR_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`EfficientLoFTRConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + +EFFICIENTLOFTR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SuperGlueImageProcessor`]. See + [`SuperGlueImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors. See `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "EfficientLoFTR model taking images as inputs and outputting the matching of them.", + EFFICIENTLOFTR_START_DOCSTRING, +) +class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel): + """EfficientLoFTR dense image matcher + + Given two images, we determine the correspondences by: + 1. Extracting coarse and fine features through a backbone + 2. Transforming coarse features through self and cross attention + 3. Matching coarse features to obtain coarse coordinates of matches + 4. Obtaining full resolution fine features by fusing transformed and backbone coarse features + 5. 
Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement + + Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou. + Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed + In CVPR, 2024. https://arxiv.org/abs/2403.04765 + """ + + def __init__(self, config: EfficientLoFTRConfig) -> None: + super().__init__(config) + + self.config = config + self.backbone = EfficientLoFTRepVGG(config) + self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config) + self.refinement_layer = EfficientLoFTRFineFusionLayer(config) + + self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config) + + self.post_init() + + def get_matches_from_scores(self, scores: torch.Tensor): + """ + Based on a keypoint score matrix, compute the best keypoint matches between the first and second image. + Since each image pair can have different number of matches, the matches are concatenated together for all pair + in the batch and a batch_indices tensor is returned to specify which match belong to which element in the batch. + Args: + scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + Scores of keypoints + + Returns: + matched_indices (`torch.Tensor` of shape `(2, num_matches)`): + Indices representing which pixel in the first image matches which pixel in the second image + matching_scores (`torch.Tensor` of shape `(num_matches,)`): + Scores of each match + batch_indices (`torch.Tensor` of shape `(num_matches,)`): + Batch correspondences of matches + """ + batch_size, height0, width0, height1, width1 = scores.shape + + scores = scores.reshape(batch_size, height0 * width0, height1 * width1) + + # For each keypoint, get the best match + max_0 = scores.max(2, keepdim=True).values + max_1 = scores.max(1, keepdim=True).values + + # 1. Thresholding + mask = scores > self.config.coarse_matching_threshold + + # 2. Border removal + mask = mask.reshape(batch_size, height0, width0, height1, width1) + mask = mask_border(mask, self.config.coarse_matching_border_removal, False) + mask = mask.reshape(batch_size, height0 * width0, height1 * width1) + + # 3. Mutual nearest neighbors + mask = mask * (scores == max_0) * (scores == max_1) + + # 4. Fine coarse matches + mask_values, mask_indices = mask.max(dim=2) + batch_indices, matched_indices_0 = torch.where(mask_values) + matched_indices_1 = mask_indices[batch_indices, matched_indices_0] + matching_scores = scores[batch_indices, matched_indices_0, matched_indices_1] + + matched_indices = torch.stack([matched_indices_0, matched_indices_1], dim=0) + return matched_indices, matching_scores, batch_indices + + def coarse_matching( + self, coarse_features: torch.Tensor, coarse_scale: float + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For each image pair, compute the matching confidence between each coarse element (by default (image_height / 8) + * (image_width / 8 elements)) from the first image to the second image. Since the number of matches can vary + with different image pairs, the matches are concatenated together in a dimension. A batch_indices tensor is + returned to inform which keypoint is part of which image pair. 
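Seen only through the batched implementation above, the coarse selection logic can be hard to picture. The following is a minimal, self-contained sketch of the same dual-softmax plus thresholding plus mutual-nearest-neighbour selection on a single pair of flattened coarse feature maps; the shapes and the 0.2 threshold are illustrative only, and the border-removal step is omitted.

```python
import torch

torch.manual_seed(0)
num_coarse_0, num_coarse_1, channels = 6, 8, 32
features_0 = torch.randn(num_coarse_0, channels) / channels**0.5
features_1 = torch.randn(num_coarse_1, channels) / channels**0.5

# Pairwise similarity followed by a dual softmax over both axes.
similarity = features_0 @ features_1.T  # (num_coarse_0, num_coarse_1)
confidence = similarity.softmax(dim=0) * similarity.softmax(dim=1)

# 1. Confidence threshold (0.2 is illustrative, not a model default).
mask = confidence > 0.2
# 2. Mutual nearest neighbours: keep (i, j) only if j is the best match of i
#    and i is the best match of j.
mask = mask & (confidence == confidence.max(dim=1, keepdim=True).values)
mask = mask & (confidence == confidence.max(dim=0, keepdim=True).values)

matched_0, matched_1 = torch.where(mask)
matching_scores = confidence[matched_0, matched_1]
print(matched_0.tolist(), matched_1.tolist(), matching_scores.tolist())
```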
+
+        Args:
+            coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`):
+                Coarse features
+            coarse_scale (`float`): Scale between the image size and the coarse size
+
+        Returns:
+            matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
+                Matched keypoints between the first and the second image. All matched keypoints are concatenated in the
+                second dimension.
+            matching_scores (`torch.Tensor` of shape `(num_matches,)`):
+                The confidence score of each matched keypoint.
+            batch_indices (`torch.Tensor` of shape `(num_matches,)`):
+                Indices of batches for each matched keypoint found.
+        """
+        batch_size, _, channels, height, width = coarse_features.shape
+
+        # (batch_size, 2, channels, height, width) -> (batch_size, 2, height * width, channels)
+        coarse_features = coarse_features.permute(0, 1, 3, 4, 2)
+        coarse_features = coarse_features.reshape(batch_size, 2, -1, channels)
+
+        coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5
+        coarse_features_0 = coarse_features[:, 0]
+        coarse_features_1 = coarse_features[:, 1]
+
+        similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2)
+        similarity = similarity / self.config.coarse_matching_temperature
+
+        if self.config.coarse_matching_skip_softmax:
+            confidence = similarity
+        else:
+            confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2)
+
+        confidence = confidence.reshape(batch_size, height, width, height, width)
+        matched_indices, matching_scores, batch_indices = self.get_matches_from_scores(confidence)
+
+        matched_keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale
+
+        return (
+            matched_keypoints,
+            matching_scores,
+            batch_indices,
+            matched_indices,
+        )
+
+    def get_first_stage_fine_matching(
+        self,
+        fine_confidence: torch.Tensor,
+        coarse_matched_keypoints: torch.Tensor,
+        fine_window_size: int,
+        fine_scale: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        For each coarse pixel, retrieve the highest fine confidence score and index.
+        The index represents the matching between a pixel position in the fine window of the first image and a pixel
+        position in the fine window of the second image.
+        For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38
+        (2474 // 64) in the fine window of the first image and the index 42 (2474 % 64) in the second image. Index 38,
+        which corresponds to the position (4, 6) (38 // 8 and 38 % 8), is matched with the position (5, 2). The coarse
+        matched coordinates are then shifted to these fine coordinates in the first and second image.
+
+        Args:
+            fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
+                First stage confidence of matching fine features between the first and the second image
+            coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
+                Coarse matched keypoints between the first and the second image.
+ fine_window_size (`int`): + Size of the window used to refine matches + fine_scale (`float`): + Scale between the size of fine features and coarse features + + Returns: + indices (`torch.Tensor` of shape `(2, num_matches, 1)`): + Indices of the fine coordinate matched in the fine window + fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the first fine stage + """ + num_matches, _, _ = fine_confidence.shape + fine_kernel_size = int(math.sqrt(fine_window_size)) + + fine_confidence = fine_confidence.reshape(num_matches, -1) + values, indices = torch.max(fine_confidence, dim=-1) + indices = indices[..., None] + indices_0 = indices // fine_window_size + indices_1 = indices % fine_window_size + + grid = create_meshgrid( + fine_kernel_size, + fine_kernel_size, + normalized_coordinates=False, + device=fine_confidence.device, + dtype=fine_confidence.dtype, + ) + grid = grid - (fine_kernel_size // 2) + 0.5 + grid = grid.reshape(1, -1, 2).expand(num_matches, -1, -1) + delta_0 = torch.gather(grid, 1, indices_0.unsqueeze(-1).expand(-1, -1, 2)).squeeze(1) + delta_1 = torch.gather(grid, 1, indices_1.unsqueeze(-1).expand(-1, -1, 2)).squeeze(1) + + fine_matches_0 = coarse_matched_keypoints[0] + delta_0 * fine_scale + fine_matches_0 = fine_matches_0.reshape(num_matches, 2) + fine_matches_1 = coarse_matched_keypoints[1] + delta_1 * fine_scale + fine_matches_1 = fine_matches_1.reshape(num_matches, 2) + + indices = torch.stack([indices_0, indices_1], dim=0) + fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=0) + + return indices, fine_matches + + def get_second_stage_fine_matching( + self, + indices: torch.Tensor, + fine_matches: torch.Tensor, + fine_confidence: torch.Tensor, + fine_window_size: int, + fine_scale: float, + ) -> torch.Tensor: + """ + For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this position. + After applying softmax to these confidences, compute the 2D spatial expected coordinates. + Shift the first stage fine matching with these expected coordinates. 
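The spatial expectation step described above amounts to a soft-argmax over a 3x3 neighbourhood. Below is a plain-PyTorch sketch of that step; it does not use the meshgrid/soft-argmax helpers of the actual implementation, and the confidence values, temperature and `fine_scale` are made up for illustration.

```python
import torch

# 3x3 confidences gathered around the first-stage match (values made up).
confidence = torch.tensor(
    [[0.1, 0.3, 0.2],
     [0.4, 2.0, 0.9],
     [0.1, 0.5, 0.3]]
)
temperature = 1.0  # stands in for config.fine_matching_regress_temperature
weights = torch.softmax(confidence.flatten() / temperature, dim=0).reshape(3, 3)

# Normalized coordinates in [-1, 1] over the 3x3 window.
coords = torch.linspace(-1.0, 1.0, 3)
grid_y, grid_x = torch.meshgrid(coords, coords, indexing="ij")

# 2D spatial expectation: a sub-pixel (x, y) offset inside the window.
expected = torch.stack([(weights * grid_x).sum(), (weights * grid_y).sum()])

# Scale the normalized offset back to pixels, mirroring
# `fine_coordinates_normalized * (3 // 2) * fine_scale` in the method above.
fine_scale = 2.0  # illustrative value
offset = expected * (3 // 2) * fine_scale
print(offset)
```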
+ + Args: + indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`): + Indices representing the position of each keypoint in the fine window + fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the first fine stage + fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`): + Second stage confidence of matching fine features between the first and the second image + fine_window_size (`int`): + Size of the window used to refine matches + fine_scale (`float`): + Scale between the size of fine features and coarse features + Returns: + fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the second fine stage + """ + num_matches, _, _ = fine_confidence.shape + fine_kernel_size = int(math.sqrt(fine_window_size)) + + indices_0 = indices[0] + indices_1 = indices[1] + indices_1_i = indices_1 // fine_kernel_size + indices_1_j = indices_1 % fine_kernel_size + + matches_indices = torch.arange(num_matches, device=indices_0.device) + + # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3) + matches_indices = matches_indices[..., None, None].expand(-1, 3, 3) + indices_0 = indices_0[..., None].expand(-1, 3, 3) + indices_1_i = indices_1_i[..., None].expand(-1, 3, 3) + indices_1_j = indices_1_j[..., None].expand(-1, 3, 3) + + delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long) + delta = delta[None, ...] + + indices_1_i = indices_1_i + delta[..., 1] + indices_1_j = indices_1_j + delta[..., 0] + + fine_confidence = fine_confidence.reshape( + num_matches, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2 + ) + # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3) + fine_confidence = fine_confidence[matches_indices, indices_0, indices_1_i, indices_1_j] + fine_confidence = fine_confidence.reshape(num_matches, 9) + fine_confidence = nn.functional.softmax( + fine_confidence / self.config.fine_matching_regress_temperature, dim=-1 + ) + + heatmap = fine_confidence.reshape(1, -1, 3, 3) + fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0] + + fine_matches_0 = fine_matches[0] + fine_matches_1 = fine_matches[1] + (fine_coordinates_normalized * (3 // 2) * fine_scale) + + fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=0) + + return fine_matches + + def fine_matching( + self, + fine_features_0: torch.Tensor, + fine_features_1: torch.Tensor, + coarse_matched_keypoints: torch.Tensor, + fine_scale: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For each coarse pixel with a corresponding window of fine features, compute the matching confidence between fine + features in the first image and the second image. + + Fine features are sliced in two part : + - The first part used for the first stage are the first fine_hidden_size - config.fine_matching_slicedim (64 - 8 + = 56 by default) features. + - The second part used for the second stage are the last config.fine_matching_slicedim (8 by default) features. + + Each part is used to compute a fine confidence tensor of the following shape : + (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size) + They correspond to the score between each fine pixel in the first image and each fine pixel in the second image. 
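The channel split described above can be illustrated with made-up sizes; the toy sketch below only shows how the two feature slices produce the first- and second-stage confidence tensors, not the full matching pipeline.

```python
import torch

# Made-up sizes: fine_kernel_size = 4 (a 16-element window in the first image,
# a 36-element window in the second), hidden size 64, fine_matching_slicedim 8.
num_matches, fine_kernel_size, slicedim, hidden = 5, 4, 8, 64
window_0 = torch.randn(num_matches, fine_kernel_size**2, hidden)
window_1 = torch.randn(num_matches, (fine_kernel_size + 2) ** 2, hidden)

# First stage: every channel except the last `slicedim` ones.
first_0, first_1 = window_0[..., :-slicedim], window_1[..., :-slicedim]
# Second stage: only the last `slicedim` channels.
second_0, second_1 = window_0[..., -slicedim:], window_1[..., -slicedim:]

first_stage_confidence = first_0 @ first_1.transpose(-1, -2)
second_stage_confidence = second_0 @ second_1.transpose(-1, -2)
print(first_stage_confidence.shape)   # torch.Size([5, 16, 36])
print(second_stage_confidence.shape)  # torch.Size([5, 16, 36])
```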
+ + Args: + fine_features_0 (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`): + Fine features from the first image + fine_features_1 (`torch.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2) + ** 2)`): + Fine features from the second image + coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`): + Keypoint coordinates found in coarse matching for the first and second image + fine_scale (`int`): + Scale between the size of fine features and coarse features + + Returns: + fine_coordinates (`torch.Tensor` of shape `(2, num_matches, 2)`): + Matched keypoint between the first and the second image. All matched keypoints are concatenated in the + second dimension. + first_stage_fine_confidence (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size + ** 2)`): + Scores of fine matching in the first stage + second_stage_fine_confidence (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, + (fine_kernel_size + 2) ** 2)`): + Scores of fine matching in the second stage + + """ + num_matches, fine_window_size, _ = fine_features_0.shape + + if num_matches == 0: + fine_confidence = torch.empty(0, fine_window_size, fine_window_size, device=fine_features_0.device) + return coarse_matched_keypoints, fine_confidence, fine_confidence + + fine_kernel_size = int(math.sqrt(fine_window_size)) + + first_stage_fine_features_0 = fine_features_0[..., : -self.config.fine_matching_slicedim] + first_stage_fine_features_1 = fine_features_1[..., : -self.config.fine_matching_slicedim] + first_stage_fine_features_0 = first_stage_fine_features_0 / first_stage_fine_features_0.shape[-1] ** 0.5 + first_stage_fine_features_1 = first_stage_fine_features_1 / first_stage_fine_features_1.shape[-1] ** 0.5 + first_stage_fine_confidence = first_stage_fine_features_0 @ first_stage_fine_features_1.transpose(-1, -2) + first_stage_fine_confidence = nn.functional.softmax(first_stage_fine_confidence, 1) * nn.functional.softmax( + first_stage_fine_confidence, 2 + ) + first_stage_fine_confidence = first_stage_fine_confidence.reshape( + num_matches, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2 + ) + first_stage_fine_confidence = first_stage_fine_confidence[..., 1:-1, 1:-1] + first_stage_fine_confidence = first_stage_fine_confidence.reshape( + num_matches, fine_window_size, fine_window_size + ) + + fine_indices, fine_matches = self.get_first_stage_fine_matching( + first_stage_fine_confidence, + coarse_matched_keypoints, + fine_window_size, + fine_scale, + ) + + second_stage_fine_features_0 = fine_features_0[..., -self.config.fine_matching_slicedim :] + second_stage_fine_features_1 = fine_features_1[..., -self.config.fine_matching_slicedim :] + second_stage_fine_features_1 = second_stage_fine_features_1 / self.config.fine_matching_slicedim**0.5 + second_stage_fine_confidence = second_stage_fine_features_0 @ second_stage_fine_features_1.transpose(-1, -2) + + fine_coordinates = self.get_second_stage_fine_matching( + fine_indices, + fine_matches, + second_stage_fine_confidence, + fine_window_size, + fine_scale, + ) + + return fine_coordinates, first_stage_fine_confidence, second_stage_fine_confidence + + @add_start_docstrings_to_model_forward(EFFICIENTLOFTR_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> 
Union[Tuple, "KeypointMatchingOutput"]: + """ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoModel + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true" + >>> image1 = Image.open(requests.get(url, stream=True).raw) + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true" + >>> image2 = Image.open(requests.get(url, stream=True).raw) + >>> images = [image1, image2] + + >>> processor = AutoImageProcessor.from_pretrained("stevenbucaille/efficient_loftr") + >>> model = AutoModel.from_pretrained("stevenbucaille/efficient_loftr") + + >>> with torch.no_grad(): + >>> inputs = processor(images, return_tensors="pt") + >>> outputs = model(**inputs) + ```""" + loss = None + if labels is not None: + raise ValueError("SuperGlue is not trainable, no labels should be provided.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values.ndim != 5 or pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + pixel_values = self.extract_one_channel_pixel_values(pixel_values) + + # 1. Local Feature CNN + backbone_outputs = self.backbone(pixel_values, output_hidden_states=output_hidden_states) + coarse_features, residual_features = backbone_outputs[:2] + coarse_channels, coarse_height, coarse_width = coarse_features.shape[-3:] + + # 2. Coarse-level LoFTR module + position_embeddings = self.rotary_emb(coarse_features) + coarse_features = coarse_features.reshape(batch_size, 2, coarse_channels, coarse_height, coarse_width) + local_feature_transformer_outputs = self.local_feature_transformer( + coarse_features, position_embeddings=position_embeddings, output_attentions=output_attentions + ) + coarse_features = local_feature_transformer_outputs[0] + + # 3. Compute coarse-level matching + coarse_scale = height / coarse_height + ( + coarse_matched_keypoints, + coarse_matching_scores, + batch_indices, + matched_indices, + ) = self.coarse_matching(coarse_features, coarse_scale) + + # 4. Fine-level refinement + refinement_layer_outputs = self.refinement_layer( + coarse_features, residual_features, output_hidden_states=output_hidden_states + ) + fine_features_0, fine_features_1 = refinement_layer_outputs[:2] + fine_features_0 = fine_features_0[batch_indices, matched_indices[0]] + fine_features_1 = fine_features_1[batch_indices, matched_indices[1]] + + # 5. 
Computer fine-level matching + fine_height = int(coarse_height * coarse_scale) + fine_scale = height / fine_height + matching_keypoints, first_stage_matching_scores, second_stage_matching_scores = self.fine_matching( + fine_features_0, + fine_features_1, + coarse_matched_keypoints, + fine_scale, + ) + + matching_keypoints[:, :, 0] = matching_keypoints[:, :, 0] / width + matching_keypoints[:, :, 1] = matching_keypoints[:, :, 1] / height + + unique_values, counts = torch.unique_consecutive(batch_indices, return_counts=True) + + if len(unique_values) > 0: + matching_keypoints_0 = matching_keypoints[0] + matching_keypoints_1 = matching_keypoints[1] + split_keypoints_0 = torch.split(matching_keypoints_0, counts.tolist()) + split_keypoints_1 = torch.split(matching_keypoints_1, counts.tolist()) + split_scores = torch.split(coarse_matching_scores, counts.tolist()) + + split_mask = [torch.ones(size, device=matching_keypoints.device) for size in counts.tolist()] + split_indices = [torch.arange(size, device=matching_keypoints.device) for size in counts.tolist()] + + keypoints_0 = pad_sequence(split_keypoints_0, batch_first=True) + keypoints_1 = pad_sequence(split_keypoints_1, batch_first=True) + matching_scores = pad_sequence(split_scores, batch_first=True) + mask = pad_sequence(split_mask, batch_first=True) + matches = pad_sequence(split_indices, batch_first=True) + + keypoints = torch.stack([keypoints_0, keypoints_1], dim=1) + matching_scores = torch.stack([matching_scores, matching_scores], dim=1) + mask = torch.stack([mask, mask], dim=1) + matches = torch.stack([matches, matches], dim=1) + + else: + keypoints = matching_keypoints.unsqueeze(0) + matching_scores = torch.stack([coarse_matching_scores, coarse_matching_scores], dim=0).unsqueeze(0) + mask = torch.ones_like(keypoints) + matches = torch.stack([matched_indices, matched_indices], dim=0).unsqueeze(0) + + if output_hidden_states: + all_hidden_states = all_hidden_states + backbone_outputs[2] + refinement_layer_outputs[2] + + if output_attentions: + all_attentions = all_attentions + local_feature_transformer_outputs[1] + + if not return_dict: + return tuple( + v + for v in [loss, matches, matching_scores, keypoints, mask, all_hidden_states, all_attentions] + if v is not None + ) + + return KeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + mask=mask, + hidden_states=all_hidden_states, + attentions=all_attentions, + ) + + +__all__ = ["EfficientLoFTRConfig", "EfficientLoFTRPreTrainedModel", "EfficientLoFTRForKeypointMatching"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e04a785f2c94..984fee7ca876 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3866,6 +3866,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class EfficientLoFTRForKeypointMatching(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EfficientLoFTRPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EfficientNetForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/efficientloftr/__init__.py b/tests/models/efficientloftr/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/tests/models/efficientloftr/test_modeling_efficientloftr.py b/tests/models/efficientloftr/test_modeling_efficientloftr.py new file mode 100644 index 000000000000..4fa595c0c238 --- /dev/null +++ b/tests/models/efficientloftr/test_modeling_efficientloftr.py @@ -0,0 +1,323 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import unittest +from functools import reduce +from typing import List + +from datasets import load_dataset + +from transformers.models.efficientloftr import EfficientLoFTRConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor + + +if is_torch_available(): + import torch + + from transformers import EfficientLoFTRForKeypointMatching + +if is_vision_available(): + from transformers import AutoImageProcessor + + +class EfficientLoFTRModelTester: + def __init__( + self, + parent, + batch_size=2, + image_width=80, + image_height=60, + stage_block_dims: List[int] = [32, 32, 64], + stage_num_blocks: List[int] = [1, 1, 1], + stage_hidden_expansion: List[int] = [1, 1, 1], + stage_stride: List[int] = [2, 1, 2], + aggregation_sizes: List[int] = [1, 1], + num_attention_layers: int = 2, + num_attention_heads: int = 8, + hidden_size: int = 64, + coarse_matching_threshold: float = 0.0, + fine_kernel_size: int = 2, + coarse_matching_border_removal: int = 0, + ): + self.parent = parent + self.batch_size = batch_size + self.image_width = image_width + self.image_height = image_height + + self.stage_block_dims = stage_block_dims + self.stage_num_blocks = stage_num_blocks + self.stage_hidden_expansion = stage_hidden_expansion + self.stage_stride = stage_stride + self.aggregation_sizes = aggregation_sizes + self.num_attention_layers = num_attention_layers + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.coarse_matching_threshold = coarse_matching_threshold + self.coarse_matching_border_removal = coarse_matching_border_removal + self.fine_kernel_size = fine_kernel_size + + def prepare_config_and_inputs(self): + # SuperGlue expects a grayscale image as input + pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width]) + config = self.get_config() + return config, pixel_values + + def get_config(self): + return EfficientLoFTRConfig( + stage_block_dims=self.stage_block_dims, + stage_num_blocks=self.stage_num_blocks, + stage_hidden_expansion=self.stage_hidden_expansion, + stage_stride=self.stage_stride, + aggregation_sizes=self.aggregation_sizes, + num_attention_layers=self.num_attention_layers, + num_attention_heads=self.num_attention_heads, + hidden_size=self.hidden_size, + coarse_matching_threshold=self.coarse_matching_threshold, + coarse_matching_border_removal=self.coarse_matching_border_removal, 
+ fine_kernel_size=self.fine_kernel_size, + ) + + def create_and_check_model(self, config, pixel_values): + model = EfficientLoFTRForKeypointMatching(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + maximum_num_matches = result.mask.shape[-1] + self.parent.assertEqual( + result.keypoints.shape, + (self.batch_size, 2, maximum_num_matches, 2), + ) + self.parent.assertEqual( + result.matches.shape, + (self.batch_size, 2, maximum_num_matches), + ) + self.parent.assertEqual( + result.matching_scores.shape, + (self.batch_size, 2, maximum_num_matches), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class EfficientLoFTRModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (EfficientLoFTRForKeypointMatching,) if is_torch_available() else () + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = True + + def setUp(self): + self.model_tester = EfficientLoFTRModelTester(self) + self.config_tester = ConfigTester(self, config_class=EfficientLoFTRConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + @unittest.skip(reason="SuperGlueForKeypointMatching does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching does not use feedforward chunking") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SuperGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="SuperGlue does not output any loss term in the forward pass") + def test_retain_grad_hidden_states_attentions(self): + pass + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + 
model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_hidden_states = 2 * len(self.model_tester.stage_num_blocks) - 1 + self.assertEqual(len(hidden_states), expected_num_hidden_states) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.image_height // 2, self.model_tester.image_width // 2], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_attention_outputs(self): + def check_attention_output(inputs_dict, config, model_class): + config._attn_implementation = "eager" + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.attentions + total_stride = reduce(lambda a, b: a * b, config.stage_stride) + hidden_size = ( + self.model_tester.image_height // total_stride * self.model_tester.image_width // total_stride + ) + + expected_attention_shape = [ + self.model_tester.num_attention_heads, + hidden_size, + hidden_size, + ] + + for i, attention in enumerate(attentions): + self.assertListEqual( + list(attention.shape[-3:]), + expected_attention_shape, + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + check_attention_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + + check_attention_output(inputs_dict, config, model_class) + + @slow + def test_model_from_pretrained(self): + from_pretrained_ids = ["stevenbucaille/efficient_loftr"] + for model_name in from_pretrained_ids: + model = EfficientLoFTRForKeypointMatching.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_forward_labels_should_be_none(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + model_inputs = self._prepare_for_class(inputs_dict, model_class) + # Provide an arbitrary sized Tensor as labels to model inputs + model_inputs["labels"] = torch.rand((128, 128)) + + with self.assertRaises(ValueError) as cm: + model(**model_inputs) + self.assertEqual(ValueError, cm.exception.__class__) + + +def prepare_imgs(): + dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train") + image1 = dataset[0]["image"] + image2 = dataset[1]["image"] + image3 = dataset[2]["image"] + return [[image1, image2], [image3, image2]] + + +@require_torch +@require_vision +class EfficientLoFTRModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained("stevenbucaille/efficient_loftr") if is_vision_available() else None + + @slow + def test_inference(self): + model = 
EfficientLoFTRForKeypointMatching.from_pretrained( + "stevenbucaille/efficient_loftr", attn_implementation="eager" + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values = outputs.matches[0, 0, 10:20] + predicted_matching_scores_values = outputs.matching_scores[0, 0, 10:20] + + expected_number_of_matches = 780 + expected_matches_values = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + device=predicted_matches_values.device) # fmt:skip + expected_matching_scores_values = torch.tensor([0.9957,0.2224,0.8803, 0.9283, 0.2241, 0.6321, 0.5206, 0.8053, 0.7174, 0.9872], + device=predicted_matches_values.device) # fmt:skip + + self.assertTrue(predicted_number_of_matches == expected_number_of_matches) + self.assertTrue(torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-3)) + self.assertTrue(torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-3))
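The padded tensors checked in the integration test can also be unpacked per image pair with the returned `mask`. This is only a sketch: it assumes `outputs` and `images` come from a run like `test_inference` above, and that keypoints are returned normalized to `[0, 1]` in (x, y) order, as done at the end of the model's forward pass.

```python
import torch

for pair_index, (image_0, image_1) in enumerate(images):
    valid = outputs.mask[pair_index, 0] > 0  # 1 for real matches, 0 for padding
    if not valid.any():
        continue
    keypoints_0 = outputs.keypoints[pair_index, 0][valid]
    keypoints_1 = outputs.keypoints[pair_index, 1][valid]
    scores = outputs.matching_scores[pair_index, 0][valid]

    # Rescale normalized (x, y) keypoints with each original image's size.
    size_0 = torch.tensor([image_0.width, image_0.height], device=keypoints_0.device)
    size_1 = torch.tensor([image_1.width, image_1.height], device=keypoints_1.device)
    keypoints_0 = keypoints_0 * size_0
    keypoints_1 = keypoints_1 * size_1
    print(f"pair {pair_index}: {int(valid.sum())} matches, best score {scores.max().item():.3f}")
```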