diff --git a/tti_eval/cli/main.py b/tti_eval/cli/main.py index 5b8862e..93be060 100644 --- a/tti_eval/cli/main.py +++ b/tti_eval/cli/main.py @@ -58,7 +58,7 @@ def build_command( embd_defn.save_embeddings(embeddings=embeddings, split=split, overwrite=True) print(f"Embeddings saved successfully to file at `{embd_defn.embedding_path(split)}`") except Exception as e: - print(f"Failed to build embeddings for this bastard: {embd_defn}") + print(f"Failed to build embeddings for {embd_defn} on the specified split {split}") print(e) import traceback diff --git a/tti_eval/dataset/types/encord_ds.py b/tti_eval/dataset/types/encord_ds.py index bfaec85..4c70b21 100644 --- a/tti_eval/dataset/types/encord_ds.py +++ b/tti_eval/dataset/types/encord_ds.py @@ -108,8 +108,8 @@ def _setup( ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") if ssh_key_path is None: raise ValueError( - "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing." - "Please set one of them to proceed" + "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing. " + "Please set one of them to proceed." ) client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) self._project = client.get_project(project_hash) diff --git a/tti_eval/dataset/utils.py b/tti_eval/dataset/utils.py index 21ae9a0..ee6d736 100644 --- a/tti_eval/dataset/utils.py +++ b/tti_eval/dataset/utils.py @@ -95,16 +95,25 @@ def simple_random_split( :raises ValueError: If the sum of `train_split` and `validation_split` is greater than 1, or if `train_split` or `validation_split` are less than 0. """ + if dataset_size < 3: + raise ValueError(f"Expected a dataset with size at least 3, got {dataset_size}") + if train_split < 0 or validation_split < 0: raise ValueError(f"Expected positive splits, got ({train_split=}, {validation_split=})") - if train_split + validation_split > 1: + if train_split + validation_split >= 1: raise ValueError( f"Expected `train_split` and `validation_split` sum between 0 and 1, got {train_split + validation_split}" ) rng = np.random.default_rng(seed) selection = rng.permutation(dataset_size) - train_size = int(dataset_size * train_split) - validation_size = int(dataset_size * validation_split) + train_size = max(1, int(dataset_size * train_split)) + validation_size = max(1, int(dataset_size * validation_split)) + # Ensure that the TEST split has at least an element + if train_size + validation_size == dataset_size: + if train_size > 1: + train_size -= 1 + if validation_size > 1: + validation_size -= 1 return { Split.TRAIN: selection[:train_size], Split.VALIDATION: selection[train_size : train_size + validation_size],