chore: shuffle data on the generated HF splits
Generated splits previously used the data in its input order, which could produce skewed datasets, e.g. a split containing only one class because the last 15% of the train dataset held samples from a single class. Shuffling the data when generating the splits is enough to remove this dependency on the input order of Hugging Face datasets.
eloy-encord committed May 8, 2024
1 parent 219f89f commit 56fb13e
Showing 1 changed file with 46 additions and 34 deletions.
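To make the failure mode concrete, here is a minimal sketch (illustrative, not part of the commit) on a hypothetical toy dataset whose labels arrive sorted by class; data.select(range(85, 100)) emulates the old train[85%:] slice, while train_test_split shuffles with a fixed seed before splitting:

from datasets import Dataset

# Hypothetical ordered input: all samples of class 1 sit at the end
data = Dataset.from_dict({"label": [0] * 85 + [1] * 15})

# Old behaviour: slicing the tail keeps the input order,
# so the derived split contains only class 1
tail = data.select(range(85, 100))
print(set(tail["label"]))  # {1}

# New behaviour: train_test_split shuffles deterministically before splitting
split = data.train_test_split(test_size=0.15, seed=42)
print(set(split["test"]["label"]))  # usually {0, 1}; no longer dictated by input order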
80 changes: 46 additions & 34 deletions tti_eval/dataset/types/hugging_face.py
@@ -1,4 +1,5 @@
 from datasets import ClassLabel, DatasetDict, Sequence, Value, load_dataset
+from datasets import Dataset as _RemoteHFDataset

 from tti_eval.dataset import Dataset, Split

@@ -29,47 +30,58 @@ def set_transform(self, transform):
         super().set_transform(transform)
         self._dataset.set_transform(transform)

-    def _get_available_splits(self, **kwargs) -> list[Split]:
-        datasets: DatasetDict = load_dataset(self.title_in_source, cache_dir=self._cache_dir.as_posix(), **kwargs)
-        return list(Split(s) for s in datasets.keys() if s in [_ for _ in Split]) + [Split.ALL]
+    @staticmethod
+    def _get_available_splits(dataset_dict: DatasetDict) -> list[Split]:
+        return [Split(s) for s in dataset_dict.keys() if s in [_ for _ in Split]] + [Split.ALL]

-    def _setup(self, **kwargs):
+    def _get_hf_dataset_split(self, **kwargs) -> _RemoteHFDataset:
         try:
-            available_splits = self._get_available_splits(**kwargs)
+            if self.split == Split.ALL:  # Retrieve all the dataset data if the split is ALL
+                return load_dataset(self.title_in_source, split="all", cache_dir=self._cache_dir.as_posix(), **kwargs)
+            dataset_dict: DatasetDict = load_dataset(
+                self.title_in_source,
+                cache_dir=self._cache_dir.as_posix(),
+                **kwargs,
+            )
         except Exception as e:
             raise ValueError(f"Failed to load dataset from Hugging Face: `{self.title_in_source}`") from e

+        available_splits = HFDataset._get_available_splits(dataset_dict)
         missing_splits = [s for s in Split if s not in available_splits]
-        if self.split == Split.TRAIN and self.split in missing_splits:
-            # Train split must always exist
-            raise AttributeError(f"Missing train split in Hugging Face dataset `{self.title_in_source}`")
-
-        # Select appropriate HF dataset split
-        hf_split: str
-        if self.split == Split.ALL:
-            hf_split = str(self.split)
-        elif self.split == Split.TEST:
-            hf_split = f"train[{85}%:]" if self.split in missing_splits else str(self.split)
-        elif self.split == Split.VALIDATION:
-            # The range of the validation split in the training data may vary depending on whether
-            # the test split is also missing
-            hf_split = (
-                f"train[{100 - 15 * len(missing_splits)}%:{100 - 15 * (len(missing_splits) - 1)}%]"
-                if self.split in missing_splits
-                else str(self.split)
-            )
-        elif self.split == Split.TRAIN:
-            # Take into account the capacity taken by missing splits
-            hf_split = f"train[:{100 - 15 * len(missing_splits)}%]"
-        else:
-            raise ValueError(f"Unhandled split type `{self.split}`")
-
-        self._dataset = load_dataset(
-            self.title_in_source,
-            split=hf_split,
-            cache_dir=self._cache_dir.as_posix(),
-            **kwargs,
-        )
+
+        # Return the target dataset if it already exists and won't be modified
+        if self.split in [Split.VALIDATION, Split.TEST] and self.split in available_splits:
+            return dataset_dict[self.split]
+        if self.split == Split.TRAIN:
+            if self.split in missing_splits:
+                # Train split must always exist
+                raise AttributeError(f"Missing train split in Hugging Face dataset: `{self.title_in_source}`")
+            if not missing_splits:
+                # No need to split the train dataset, it can be returned as a whole
+                return dataset_dict[self.split]
+
+        # Take 15% of the train dataset for each missing split (VALIDATION, TEST or both).
+        # This operation shuffles the data to prevent splits with skewed class counts caused by the input order.
+        split_percent = 0.15 * len(missing_splits)
+        split_seed = 42
+        # Split the original train dataset in two: the final train dataset and the missing splits dataset
+        split_to_dataset = dataset_dict["train"].train_test_split(test_size=split_percent, seed=split_seed)
+        if self.split == Split.TRAIN:
+            return split_to_dataset[self.split]
+
+        if len(missing_splits) == 1:
+            # One missing split (either VALIDATION or TEST), so we return the 15% stored in "test"
+            return split_to_dataset["test"]
+
+        # Both VALIDATION and TEST splits are missing.
+        # Each one takes half of the 30% stored in "test".
+        if self.split == Split.VALIDATION:
+            return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["train"]
+        else:
+            return split_to_dataset["test"].train_test_split(test_size=0.5, seed=split_seed)["test"]
+
+    def _setup(self, **kwargs):
+        self._dataset = self._get_hf_dataset_split(**kwargs)

         if self._target_feature not in self._dataset.features:
             raise ValueError(f"The dataset `{self.title}` does not have the target feature `{self._target_feature}`")
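One design point worth noting: split_seed is fixed, so separate calls that request different splits reproduce the same shuffle, keeping the derived TRAIN, VALIDATION and TEST subsets disjoint across runs. A quick check of that property (again an illustrative sketch, not project code):

from datasets import Dataset

data = Dataset.from_dict({"x": list(range(100))})

# Same seed -> identical partition on every call
a = data.train_test_split(test_size=0.3, seed=42)
b = data.train_test_split(test_size=0.3, seed=42)
assert a["train"]["x"] == b["train"]["x"]

# Hence samples routed to the held-out portion never leak back into train
assert set(a["train"]["x"]).isdisjoint(a["test"]["x"])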
