Skip to content

Commit

Permalink
misc: silence unnecessary progress bars
Browse files Browse the repository at this point in the history
  • Loading branch information
eloy-encord committed May 9, 2024
1 parent 885a5ae commit 2f25b44
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
3 changes: 3 additions & 0 deletions tti_eval/evaluation/image_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autofaiss import build_index

from tti_eval.common import Embeddings
from tti_eval.utils import disable_tqdm, enable_tqdm

from .base import EvaluationModel

Expand Down Expand Up @@ -44,7 +45,9 @@ def __init__(
self._class_counts = np.zeros(self.num_classes, dtype=np.int32)
self._class_counts[class_ids] = counts

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index
Expand Down
9 changes: 4 additions & 5 deletions tti_eval/evaluation/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autofaiss import build_index

from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
from tti_eval.utils import disable_tqdm, enable_tqdm

from .base import ClassificationModel
from .utils import softmax
Expand Down Expand Up @@ -42,13 +43,11 @@ def __init__(
"""
super().__init__(train_embeddings, validation_embeddings, num_classes)
self.k = k

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(
train_embeddings.images,
metric_type="l2",
save_on_disk=False,
verbose=logging.ERROR,
train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index
Expand Down
11 changes: 11 additions & 0 deletions tti_eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from functools import partialmethod
from itertools import chain
from typing import Literal, overload

from tqdm import tqdm

from tti_eval.common import EmbeddingDefinition
from tti_eval.constants import PROJECT_PATHS


def disable_tqdm():
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)


def enable_tqdm():
tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)


@overload
def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
...
Expand Down

0 comments on commit 2f25b44

Please sign in to comment.