From 76c6cb2d0a50a68d0063dea44d6ae22cb2370b8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= <99720527+eloy-encord@users.noreply.github.com> Date: Wed, 8 May 2024 13:03:10 +0100 Subject: [PATCH] fix: empty dataset splits (#73) If a dataset is small enough or the values for the train/val/test splits is very small, it may be possible that no piece of data is assigned to that split. This change enforces that no split is empty, while it may have size 1 and thus not really useful, but it won't break the execution and will leave the responsibility to the user to take action in order to fix the disparity caused by his/her actions or the very small dataset size. --- tti_eval/cli/main.py | 2 +- tti_eval/dataset/types/encord_ds.py | 4 ++-- tti_eval/dataset/utils.py | 15 ++++++++++++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tti_eval/cli/main.py b/tti_eval/cli/main.py index 5b8862e..93be060 100644 --- a/tti_eval/cli/main.py +++ b/tti_eval/cli/main.py @@ -58,7 +58,7 @@ def build_command( embd_defn.save_embeddings(embeddings=embeddings, split=split, overwrite=True) print(f"Embeddings saved successfully to file at `{embd_defn.embedding_path(split)}`") except Exception as e: - print(f"Failed to build embeddings for this bastard: {embd_defn}") + print(f"Failed to build embeddings for {embd_defn} on the specified split {split}") print(e) import traceback diff --git a/tti_eval/dataset/types/encord_ds.py b/tti_eval/dataset/types/encord_ds.py index bfaec85..4c70b21 100644 --- a/tti_eval/dataset/types/encord_ds.py +++ b/tti_eval/dataset/types/encord_ds.py @@ -108,8 +108,8 @@ def _setup( ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") if ssh_key_path is None: raise ValueError( - "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing." - "Please set one of them to proceed" + "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing. " + "Please set one of them to proceed." ) client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) self._project = client.get_project(project_hash) diff --git a/tti_eval/dataset/utils.py b/tti_eval/dataset/utils.py index 21ae9a0..ee6d736 100644 --- a/tti_eval/dataset/utils.py +++ b/tti_eval/dataset/utils.py @@ -95,16 +95,25 @@ def simple_random_split( :raises ValueError: If the sum of `train_split` and `validation_split` is greater than 1, or if `train_split` or `validation_split` are less than 0. """ + if dataset_size < 3: + raise ValueError(f"Expected a dataset with size at least 3, got {dataset_size}") + if train_split < 0 or validation_split < 0: raise ValueError(f"Expected positive splits, got ({train_split=}, {validation_split=})") - if train_split + validation_split > 1: + if train_split + validation_split >= 1: raise ValueError( f"Expected `train_split` and `validation_split` sum between 0 and 1, got {train_split + validation_split}" ) rng = np.random.default_rng(seed) selection = rng.permutation(dataset_size) - train_size = int(dataset_size * train_split) - validation_size = int(dataset_size * validation_split) + train_size = max(1, int(dataset_size * train_split)) + validation_size = max(1, int(dataset_size * validation_split)) + # Ensure that the TEST split has at least an element + if train_size + validation_size == dataset_size: + if train_size > 1: + train_size -= 1 + if validation_size > 1: + validation_size -= 1 return { Split.TRAIN: selection[:train_size], Split.VALIDATION: selection[train_size : train_size + validation_size],