Skip to content

Commit

Permalink
refactor: use faiss instead of autofaiss (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-encord committed Jul 2, 2024
1 parent 7f7466b commit cf09968
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 23 deletions.
23 changes: 12 additions & 11 deletions tti_eval/evaluation/image_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any

import numpy as np
from autofaiss import build_index
from faiss import IndexFlatL2

from tti_eval.common import Embeddings
from tti_eval.utils import disable_tqdm, enable_tqdm
from tti_eval.utils import disable_tqdm

from .base import EvaluationModel

Expand Down Expand Up @@ -46,28 +46,29 @@ def __init__(
self._class_counts[class_ids] = counts

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index

logger.info("knn classifier index_infos", extra=self.index_infos)
d = self._val_embeddings.images.shape[-1]
self._index = IndexFlatL2(d)
self._index.add(self._val_embeddings.images)

def evaluate(self) -> float:
_, nearest_indices = self._index.search(self._train_embeddings.images, self.k) # type: ignore
nearest_classes = self._val_embeddings.labels[nearest_indices]

# To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved,
# where Q represents the size of the respective class in the validation embeddings
top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k)
top_nearest_per_class = np.where(
self._class_counts < self.k, self._class_counts, self.k
)
top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels]

# Add a placeholder value for indices outside the retrieval scope
nearest_classes[np.arange(self.k) >= top_nearest_per_sample[:, np.newaxis]] = -1

# Count the number of neighbours that match the class of the sample and compute the mean accuracy
matches_per_sample = np.sum(nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], axis=1)
matches_per_sample = np.sum(
nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis],
axis=1,
)
accuracies = np.divide(
matches_per_sample,
top_nearest_per_sample,
Expand Down
20 changes: 8 additions & 12 deletions tti_eval/evaluation/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any

import numpy as np
from autofaiss import build_index
from faiss import IndexFlatL2

from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
from tti_eval.utils import disable_tqdm, enable_tqdm
from tti_eval.utils import disable_tqdm

from .base import ClassificationModel
from .utils import softmax
Expand Down Expand Up @@ -44,15 +44,9 @@ def __init__(
super().__init__(train_embeddings, validation_embeddings, num_classes)
self.k = k
disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(
train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index

logger.info("knn classifier index_infos", extra=self.index_infos)
d = train_embeddings.images.shape[-1]
self._index = IndexFlatL2(d)
self._index.add(train_embeddings.images)

@staticmethod
def get_default_params() -> dict[str, Any]:
Expand All @@ -65,7 +59,9 @@ def predict(self) -> tuple[ProbabilityArray, ClassArray]:
# Calculate class votes from the distances (avoiding division by zero)
# Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors
max_value = np.finfo(np.float32).max
scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
scores = np.divide(
1, dists, out=np.full_like(dists, max_value), where=dists != 0
)
# NOTE: if self.k and self.num_classes are both large, this might become a big one.
        # We can shave off a factor of self.k if we count differently here.
n = len(self._val_embeddings.images)
Expand Down

0 comments on commit cf09968

Please sign in to comment.