|
2 | 2 | from typing import Any |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | | -from autofaiss import build_index |
| 5 | +from faiss import IndexFlatL2 |
6 | 6 |
|
7 | 7 | from tti_eval.common import Embeddings |
8 | | -from tti_eval.utils import disable_tqdm, enable_tqdm |
| 8 | +from tti_eval.utils import disable_tqdm |
9 | 9 |
|
10 | 10 | from .base import EvaluationModel |
11 | 11 |
|
@@ -46,28 +46,29 @@ def __init__( |
46 | 46 | self._class_counts[class_ids] = counts |
47 | 47 |
|
48 | 48 | disable_tqdm() # Disable tqdm progress bar when building the index |
49 | | - index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR) |
50 | | - enable_tqdm() |
51 | | - if index is None: |
52 | | - raise ValueError("Failed to build an index for knn search") |
53 | | - self._index = index |
54 | | - |
55 | | - logger.info("knn classifier index_infos", extra=self.index_infos) |
| 49 | + d = self._val_embeddings.images.shape[-1] |
| 50 | + self._index = IndexFlatL2(d) |
| 51 | + self._index.add(self._val_embeddings.images) |
56 | 52 |
|
57 | 53 | def evaluate(self) -> float: |
58 | 54 | _, nearest_indices = self._index.search(self._train_embeddings.images, self.k) # type: ignore |
59 | 55 | nearest_classes = self._val_embeddings.labels[nearest_indices] |
60 | 56 |
|
61 | 57 | # To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved, |
62 | 58 | # where Q represents the size of the respective class in the validation embeddings |
63 | | - top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k) |
| 59 | + top_nearest_per_class = np.where( |
| 60 | + self._class_counts < self.k, self._class_counts, self.k |
| 61 | + ) |
64 | 62 | top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels] |
65 | 63 |
|
66 | 64 | # Add a placeholder value for indices outside the retrieval scope |
67 | 65 | nearest_classes[np.arange(self.k) >= top_nearest_per_sample[:, np.newaxis]] = -1 |
68 | 66 |
|
69 | 67 | # Count the number of neighbours that match the class of the sample and compute the mean accuracy |
70 | | - matches_per_sample = np.sum(nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], axis=1) |
| 68 | + matches_per_sample = np.sum( |
| 69 | + nearest_classes == np.array(self._train_embeddings.labels)[:, np.newaxis], |
| 70 | + axis=1, |
| 71 | + ) |
71 | 72 | accuracies = np.divide( |
72 | 73 | matches_per_sample, |
73 | 74 | top_nearest_per_sample, |
|
0 commit comments