Skip to content

Commit cf09968

Browse files
refactor: use faiss instead of autofaiss (#80)
1 parent 7f7466b commit cf09968

File tree

2 files changed

+20
-23
lines changed

2 files changed

+20
-23
lines changed

tti_eval/evaluation/image_retrieval.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import Any
33

44
import numpy as np
5-
from autofaiss import build_index
5+
from faiss import IndexFlatL2
66

77
from tti_eval.common import Embeddings
8-
from tti_eval.utils import disable_tqdm, enable_tqdm
8+
from tti_eval.utils import disable_tqdm
99

1010
from .base import EvaluationModel
1111

@@ -46,28 +46,29 @@ def __init__(
4646
self._class_counts[class_ids] = counts
4747

4848
disable_tqdm() # Disable tqdm progress bar when building the index
49-
index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
50-
enable_tqdm()
51-
if index is None:
52-
raise ValueError("Failed to build an index for knn search")
53-
self._index = index
54-
55-
logger.info("knn classifier index_infos", extra=self.index_infos)
49+
d = self._val_embeddings.images.shape[-1]
50+
self._index = IndexFlatL2(d)
51+
self._index.add(self._val_embeddings.images)
5652

5753
def evaluate(self) -> float:
5854
_, nearest_indices = self._index.search(self._train_embeddings.images, self.k) # type: ignore
5955
nearest_classes = self._val_embeddings.labels[nearest_indices]
6056

6157
# To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved,
6258
# where Q represents the size of the respective class in the validation embeddings
63-
top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k)
59+
top_nearest_per_class = np.where(
60+
self._class_counts < self.k, self._class_counts, self.k
61+
)
6462
top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels]
6563

6664
# Add a placeholder value for indices outside the retrieval scope
6765
nearest_classes[np.arange(self.k) >= top_nearest_per_sample[:, np.newaxis]] = -1
6866

6967
# Count the number of neighbours that match the class of the sample and compute the mean accuracy
70-
matches_per_sample = np.sum(nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], axis=1)
68+
matches_per_sample = np.sum(
69+
nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis],
70+
axis=1,
71+
)
7172
accuracies = np.divide(
7273
matches_per_sample,
7374
top_nearest_per_sample,

tti_eval/evaluation/knn.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import Any
33

44
import numpy as np
5-
from autofaiss import build_index
5+
from faiss import IndexFlatL2
66

77
from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
8-
from tti_eval.utils import disable_tqdm, enable_tqdm
8+
from tti_eval.utils import disable_tqdm
99

1010
from .base import ClassificationModel
1111
from .utils import softmax
@@ -44,15 +44,9 @@ def __init__(
4444
super().__init__(train_embeddings, validation_embeddings, num_classes)
4545
self.k = k
4646
disable_tqdm() # Disable tqdm progress bar when building the index
47-
index, self.index_infos = build_index(
48-
train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
49-
)
50-
enable_tqdm()
51-
if index is None:
52-
raise ValueError("Failed to build an index for knn search")
53-
self._index = index
54-
55-
logger.info("knn classifier index_infos", extra=self.index_infos)
47+
d = train_embeddings.images.shape[-1]
48+
self._index = IndexFlatL2(d)
49+
self._index.add(train_embeddings.images)
5650

5751
@staticmethod
5852
def get_default_params() -> dict[str, Any]:
@@ -65,7 +59,9 @@ def predict(self) -> tuple[ProbabilityArray, ClassArray]:
6559
# Calculate class votes from the distances (avoiding division by zero)
6660
# Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors
6761
max_value = np.finfo(np.float32).max
68-
scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
62+
scores = np.divide(
63+
1, dists, out=np.full_like(dists, max_value), where=dists != 0
64+
)
6965
# NOTE: if self.k and self.num_classes are both large, this might become a big one.
7066
# We can shape of a factor self.k if we count differently here.
7167
n = len(self._val_embeddings.images)

0 commit comments

Comments
 (0)