|
2 | 2 | from typing import Any
|
3 | 3 |
|
4 | 4 | import numpy as np
|
5 |
| -from autofaiss import build_index |
| 5 | +from faiss import IndexFlatL2 |
6 | 6 |
|
7 | 7 | from tti_eval.common import Embeddings
|
8 |
| -from tti_eval.utils import disable_tqdm, enable_tqdm |
| 8 | +from tti_eval.utils import disable_tqdm |
9 | 9 |
|
10 | 10 | from .base import EvaluationModel
|
11 | 11 |
|
@@ -46,28 +46,29 @@ def __init__(
|
46 | 46 | self._class_counts[class_ids] = counts
|
47 | 47 |
|
48 | 48 | disable_tqdm() # Disable tqdm progress bar when building the index
|
49 |
| - index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR) |
50 |
| - enable_tqdm() |
51 |
| - if index is None: |
52 |
| - raise ValueError("Failed to build an index for knn search") |
53 |
| - self._index = index |
54 |
| - |
55 |
| - logger.info("knn classifier index_infos", extra=self.index_infos) |
| 49 | + d = self._val_embeddings.images.shape[-1] |
| 50 | + self._index = IndexFlatL2(d) |
| 51 | + self._index.add(self._val_embeddings.images) |
56 | 52 |
|
57 | 53 | def evaluate(self) -> float:
|
58 | 54 | _, nearest_indices = self._index.search(self._train_embeddings.images, self.k) # type: ignore
|
59 | 55 | nearest_classes = self._val_embeddings.labels[nearest_indices]
|
60 | 56 |
|
61 | 57 | # To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved,
|
62 | 58 | # where Q represents the size of the respective class in the validation embeddings
|
63 |
| - top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k) |
| 59 | + top_nearest_per_class = np.where( |
| 60 | + self._class_counts < self.k, self._class_counts, self.k |
| 61 | + ) |
64 | 62 | top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels]
|
65 | 63 |
|
66 | 64 | # Add a placeholder value for indices outside the retrieval scope
|
67 | 65 | nearest_classes[np.arange(self.k) >= top_nearest_per_sample[:, np.newaxis]] = -1
|
68 | 66 |
|
69 | 67 | # Count the number of neighbours that match the class of the sample and compute the mean accuracy
|
70 |
| - matches_per_sample = np.sum(nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], axis=1) |
| 68 | + matches_per_sample = np.sum( |
| 69 | + nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], |
| 70 | + axis=1, |
| 71 | + ) |
71 | 72 | accuracies = np.divide(
|
72 | 73 | matches_per_sample,
|
73 | 74 | top_nearest_per_sample,
|
|
0 commit comments