From 5a50e8ec923f3cb2fadc77923ea4be894948dcf2 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Mon, 10 Feb 2025 17:46:25 -0800 Subject: [PATCH] Training Stability: Patch HF Hub and Datasets methods and update datasets.py (#280) * bring over ZQ's apply_all_patches * update datasets * update types and tests --- ultravox/data/__init__.py | 1 + ultravox/data/configs/librispeech.py | 10 +- ultravox/data/configs/peoplespeech.py | 3 +- ultravox/data/datasets.py | 153 ++++++++++++++++++++------ ultravox/data/datasets_test.py | 109 ++++++++++-------- ultravox/data/registry.py | 10 +- ultravox/data/types.py | 81 ++++++++++++-- ultravox/training/train.py | 3 + ultravox/utils/monkey_patches.py | 131 ++++++++++++++++++++++ 9 files changed, 403 insertions(+), 98 deletions(-) create mode 100644 ultravox/utils/monkey_patches.py diff --git a/ultravox/data/__init__.py b/ultravox/data/__init__.py index 8505c985..7db57896 100644 --- a/ultravox/data/__init__.py +++ b/ultravox/data/__init__.py @@ -12,6 +12,7 @@ "VoiceDataset", "VoiceDatasetArgs", "VoiceSample", + "DatasetOptions", "create_dataset", "register_datasets", ] diff --git a/ultravox/data/configs/librispeech.py b/ultravox/data/configs/librispeech.py index 0fd9c083..cca01291 100644 --- a/ultravox/data/configs/librispeech.py +++ b/ultravox/data/configs/librispeech.py @@ -13,11 +13,12 @@ subset="clean", splits=[ types.DatasetSplitConfig( - name="train.100", num_samples=28_539, split_type=types.DatasetSplit.TRAIN + name="train.100", num_samples=28_539, split=types.DatasetSplit.TRAIN ), types.DatasetSplitConfig( - name="train.360", num_samples=104_014, split_type=types.DatasetSplit.TRAIN + name="train.360", num_samples=104_014, split=types.DatasetSplit.TRAIN ), + types.DatasetSplitConfig(name="test", num_samples=2620), ], ) @@ -27,8 +28,9 @@ subset="other", splits=[ types.DatasetSplitConfig( - name="train.500", num_samples=148_688, split_type=types.DatasetSplit.TRAIN + name="train.500", num_samples=148_688, split=types.DatasetSplit.TRAIN ), + types.DatasetSplitConfig(name="test", num_samples=2939), ], ) @@ -36,12 +38,14 @@ name="librispeech-clean-transcription", base="librispeech-clean", user_template=types.TRANSCRIPTION_USER_TEMPLATE, + eval_config=types.EvalConfig(metric="wer", args={"lang_id": "en"}), ) LS_OTHER_TRANS_CONFIG = types.DatasetConfig( name="librispeech-other-transcription", base="librispeech-other", user_template=types.TRANSCRIPTION_USER_TEMPLATE, + eval_config=types.EvalConfig(metric="wer", args={"lang_id": "en"}), ) LS_CLEAN_CONT_CONFIG = types.DatasetConfig( diff --git a/ultravox/data/configs/peoplespeech.py b/ultravox/data/configs/peoplespeech.py index 36f8e8f8..1165ee55 100644 --- a/ultravox/data/configs/peoplespeech.py +++ b/ultravox/data/configs/peoplespeech.py @@ -7,7 +7,7 @@ splits=[ types.DatasetSplitConfig(name="train", num_samples=1_501_271), types.DatasetSplitConfig( - name="test", num_samples=34_898, split_type=types.DatasetSplit.VALIDATION + name="test", num_samples=34_898, split=types.DatasetSplit.VALIDATION ), ], assistant_template="{{text_proc.format_asr_text(text)}}", @@ -18,6 +18,7 @@ name="peoplespeech-clean-transcription", base="peoplespeech", user_template=types.TRANSCRIPTION_USER_TEMPLATE, + eval_config=types.EvalConfig(metric="wer", args={"lang_id": "en"}), ) PS_CONT_CONFIG = types.DatasetConfig( diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 5badf0c5..9010868d 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -97,7 +97,16 @@ class SizedIterableDataset(abc.ABC, data.IterableDataset): """ @abc.abstractmethod - def __len__(self): + def __len__(self) -> int: + pass + + @abc.abstractmethod + def __str__(self) -> str: + pass + + @property + @abc.abstractmethod + def name(self) -> str: pass @@ -111,14 +120,27 @@ def __init__(self, args: types.VoiceDatasetArgs) -> None: super().__init__() self._args = args self._rng = np.random.default_rng(self._args.shuffle_seed) + self._name = "[unset]" + self._length = -1 - def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None: + # num_samples is the total number of samples in the dataset + def _init_dataset( + self, + dataset: data.Dataset, + name: str, + num_samples: int, + ) -> None: self._dataset = dataset + self._name = name self._length = num_samples def __len__(self): return self._length + @property + def name(self): + return self._name + def _load_hf_dataset( self, path: str, @@ -172,40 +194,39 @@ def _load_mds_dataset( def __iter__(self): actual_length = 0 - for _, row in enumerate(self._dataset): + skipped_samples = 0 + bad_samples = 0 + dataset_iter = iter(self._dataset) + for row in dataset_iter: + actual_length += 1 sample = self._get_sample(row) if sample is None: - raise ValueError( - f"Sample is None in dataset {self._config.alias} for row {row}" - ) + print(f"Sample is None in dataset {self._config.alias} for row {row}") + bad_samples += 1 + continue # Skip this sample and proceed to the next if self._args.include_audio: - # If audio_field is set, make sure the sample has audio data. if sample.audio is None: - raise ValueError(f"Audio is None for sample {sample}") + print(f"Audio is None for sample {sample}") + bad_samples += 1 + continue # Skip this sample if sample.audio.shape[-1] == 0: - raise ValueError(f"Audio length is 0 for sample {sample}") + print(f"Audio length is 0 for sample {sample}") + bad_samples += 1 + continue # Skip this sample if ( self._args.max_audio_duration_secs is not None and sample.audio.shape[-1] / data_sample.SAMPLE_RATE > self._args.max_audio_duration_secs ): - duration = sample.audio.shape[-1] / data_sample.SAMPLE_RATE - warnings.warn( - f"Audio length ({duration}s) exceeds max audio duration ({self._args.max_audio_duration_secs}s), skipping sample." - ) - continue + skipped_samples += 1 + continue # Skip this sample yield sample - actual_length += 1 - if actual_length == len(self) + 1: - warnings.warn( - f"The presumed length {self._length} has been exceeded for {self._config.name}:{self._args.split.value}. Make sure to update." - ) - if actual_length != len(self): - warnings.warn( - f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for {self._config.name}:{self._args.split.value}. Make sure to update." - ) + + logging.info( + f"Extracted {actual_length} samples from {self.name} (total: {len(self)}), removed {bad_samples} bad samples, and skipped {skipped_samples} samples for exceeding max audio duration ({self._args.max_audio_duration_secs}s)." + ) @abc.abstractmethod def _get_sample( @@ -262,7 +283,7 @@ def __init__( dsets = [] total_samples = 0 for split in config.splits: - if split.split_type == self._args.split: + if split.split == self._args.split: if not config.use_mds: ds = self._load_hf_dataset( config.path, @@ -283,7 +304,13 @@ def __init__( len(dsets) > 0 ), f"The {config.name} dataset has no {self._args.split} splits." dataset = ds if len(dsets) == 1 else hf_datasets.concatenate_datasets(dsets) - super()._init_dataset(dataset, total_samples) + + dataset_name = f"{config.name}.{self._args.split.value}" + + super()._init_dataset(dataset, dataset_name, total_samples) + + def __str__(self): + return f"GenericDataset({self._config})" def _get_sample(self, row) -> Optional[data_sample.VoiceSample]: assert self._config.user_template is not None @@ -296,15 +323,14 @@ def _get_sample(self, row) -> Optional[data_sample.VoiceSample]: ).render( **row, text_proc=text_proc, - dataset=self, **self._config.user_template_args, ) assistant_content = jinja2.Template( self._config.assistant_template, undefined=jinja2.StrictUndefined - ).render(**row, text_proc=text_proc, dataset=self) + ).render(**row, text_proc=text_proc) transcript = jinja2.Template( self._config.transcript_template, undefined=jinja2.StrictUndefined - ).render(**row, text_proc=text_proc, dataset=self) + ).render(**row, text_proc=text_proc) except jinja2.TemplateError as e: print(f"Error rendering template: {e}") print(f"user_template: {self._config.user_template}") @@ -322,10 +348,13 @@ def _get_sample(self, row) -> Optional[data_sample.VoiceSample]: audio = self._get_audio(row, self._config.audio_field) return self._make_sample(messages, audio, audio_transcript=transcript) + def get_config(self): + return self._config + -class LibriSpeechDummyDataset(VoiceDataset): +class LibriSpeechDummyDataset(GenericDataset): def __init__(self, args: types.VoiceDatasetArgs) -> None: - super().__init__(args) + VoiceDataset.__init__(self, args) # This dataset doesn't support streaming. dataset = self._load_hf_dataset( "hf-internal-testing/librispeech_asr_dummy", @@ -333,7 +362,20 @@ def __init__(self, args: types.VoiceDatasetArgs) -> None: split="validation", streaming=False, ) - self._init_dataset(dataset, 73) + self._init_dataset(dataset, "dummy", 73) + + def __str__(self): + return "LibriSpeechDummyDataset" + + @property + def name(self): + return "dummy" + + def get_config(self): + return types.DatasetConfig( + name="dummy", + path="hf-internal-testing/librispeech_asr_dummy", + ) def _get_sample( self, row: transformers.BatchFeature @@ -360,6 +402,13 @@ def __iter__(self): def __len__(self): return self._length + def __str__(self): + return f"EmptyDataset(length={self._length})" + + @property + def name(self): + return "empty" + class InterleaveDataset(SizedIterableDataset): """Interleaves multiple SizedIterableDataset objects based on normalized weights.""" @@ -380,6 +429,7 @@ def __init__( assert len(weights) == len(datasets) else: weights = [1.0] * len(datasets) + self._weights = weights self._weighted_samples = [int(w * len(d)) for w, d in zip(weights, datasets)] self._total_samples = sum(self._weighted_samples) @@ -398,12 +448,25 @@ def __iter__(self): yield next(ds_iters[iter_index]) except StopIteration: ds_iters[iter_index] = iter(self._datasets[iter_index]) - yield next(ds_iters[iter_index]) + try: + yield next(ds_iters[iter_index]) + except StopIteration: + warnings.warn( + f"Dataset {iter_index} is empty. num_workers is likely too high. Stopping iteration." + ) + break ds_pos[iter_index] += 1 def __len__(self): return self._total_samples + def __str__(self): + return "+".join([f"{d}:{w:.2f}" for w, d in zip(self._weights, self._datasets)]) + + @property + def name(self): + return "+".join([ds.name for ds in self._datasets]) + class Dataproc(SizedIterableDataset): """Base class to preprocess a dataset of VoiceSamples.""" @@ -421,6 +484,13 @@ def __iter__(self): def __len__(self): return len(self._dataset) + def __str__(self): + return f"Dataproc({self._dataset})" + + @property + def name(self): + return self._dataset.name + class Range(SizedIterableDataset): """Limits the number of samples from another dataset.""" @@ -431,7 +501,11 @@ def __init__( self._dataset = dataset self._length = num_samples or len(dataset) if self._length > len(dataset): - raise ValueError("num_samples exceeds dataset length.") + warnings.warn( + f"num_samples ({self._length}) exceeds dataset length ({len(dataset)}). Truncating to {len(dataset)}." + ) + self._length = len(dataset) + self._name = f"{dataset.name}.{self._length}" def __iter__(self): for i, sample in enumerate(self._dataset): @@ -439,5 +513,18 @@ def __iter__(self): break yield sample + def __str__(self): + return f"Range({self._dataset}%{len(self)})" + def __len__(self): return self._length + + @property + def name(self): + return self._name + + def get_config(self): + if isinstance(self._dataset, GenericDataset): + return self._dataset.get_config() + else: + raise ValueError("Cannot get config for non-GenericDataset") diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 3fd9689d..06db8b35 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -27,6 +27,13 @@ def __iter__(self): def __len__(self): return self._length + def __str__(self): + return "FakeSizedIterableDataset" + + @property + def name(self): + return "fake" + class FakeHuggingFaceIterableDataset(hf_datasets.IterableDataset): """Fake version of an ASR Hugging Face IterableDataset.""" @@ -50,12 +57,19 @@ class FakeTranscribeDataset(datasets.VoiceDataset): def __init__(self, n: int, args: Optional[types.VoiceDatasetArgs] = None): super().__init__(args or types.VoiceDatasetArgs()) - self._init_dataset(FakeHuggingFaceIterableDataset(n), n) + self._init_dataset(FakeHuggingFaceIterableDataset(n), "fake", n) def _get_sample(self, row: BatchFeature) -> Optional[data_sample.VoiceSample]: messages = self._make_messages("<|audio|>", row["text"]) return self._make_sample(messages, np.zeros(256), row["text"]) + def __str__(self): + return "FakeTranscribeDataset" + + @property + def name(self): + return "fake_transcribe" + class FakeGenericDataset(datasets.GenericDataset): """Fake version of GenericDataset, hooked to return a FakeHuggingFaceIterableDataset.""" @@ -141,7 +155,7 @@ def test_range(): s = datasets.Range(ds, 5) assert len(s) == 5 assert list(s) == [0, 1, 2, 3, 4] - with pytest.raises(ValueError, match="exceeds dataset length"): + with pytest.warns(UserWarning, match="exceeds dataset length"): s = datasets.Range(ds, 100) s = datasets.Range(ds, 10) assert list(s) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -170,16 +184,16 @@ def test_dataset_config(): path="mock_path", splits=[ types.DatasetSplitConfig( - name="clean", num_samples=5000, split_type=types.DatasetSplit.TRAIN + name="clean", num_samples=5000, split=types.DatasetSplit.TRAIN ), types.DatasetSplitConfig( - name="other", num_samples=10000, split_type=types.DatasetSplit.TRAIN + name="other", num_samples=10000, split=types.DatasetSplit.TRAIN ), types.DatasetSplitConfig(name="validation", num_samples=1000), types.DatasetSplitConfig( name="another_validation", num_samples=1000, - split_type=types.DatasetSplit.VALIDATION, + split=types.DatasetSplit.VALIDATION, ), ], ) @@ -188,16 +202,16 @@ def test_dataset_config(): assert len(config.splits) == 4 assert config.splits[0].name == "clean" assert config.splits[0].num_samples == 5000 - assert config.splits[0].split_type == types.DatasetSplit.TRAIN + assert config.splits[0].split == types.DatasetSplit.TRAIN assert config.splits[1].name == "other" assert config.splits[1].num_samples == 10000 - assert config.splits[1].split_type == types.DatasetSplit.TRAIN + assert config.splits[1].split == types.DatasetSplit.TRAIN assert config.splits[2].name == "validation" assert config.splits[2].num_samples == 1000 - assert config.splits[2].split_type == types.DatasetSplit.VALIDATION + assert config.splits[2].split == types.DatasetSplit.VALIDATION assert config.splits[3].name == "another_validation" assert config.splits[3].num_samples == 1000 - assert config.splits[3].split_type == types.DatasetSplit.VALIDATION + assert config.splits[3].split == types.DatasetSplit.VALIDATION def test_dataset_config_serialization(): @@ -206,10 +220,10 @@ def test_dataset_config_serialization(): path="fake_path", splits=[ types.DatasetSplitConfig( - name="clean", num_samples=5000, split_type=types.DatasetSplit.TRAIN + name="clean", num_samples=5000, split=types.DatasetSplit.TRAIN ), types.DatasetSplitConfig( - name="other", num_samples=10000, split_type=types.DatasetSplit.TRAIN + name="other", num_samples=10000, split=types.DatasetSplit.TRAIN ), ], ) @@ -231,7 +245,7 @@ def test_generic_dataset(): path="fake_path", splits=[ types.DatasetSplitConfig( - name="fake", num_samples=5, split_type=types.DatasetSplit.TRAIN + name="fake", num_samples=5, split=types.DatasetSplit.TRAIN ) ], ) @@ -254,7 +268,7 @@ def test_generic_dataset_custom_templates(): path="fake_path", splits=[ types.DatasetSplitConfig( - name="fake", num_samples=5, split_type=types.DatasetSplit.TRAIN + name="fake", num_samples=5, split=types.DatasetSplit.TRAIN ) ], user_template="Listen to the following and respond with 'xyzzy':\n<|audio|>", @@ -283,7 +297,7 @@ def test_generic_dataset_text_only(): path="fake_path", splits=[ types.DatasetSplitConfig( - name="fake", num_samples=5, split_type=types.DatasetSplit.TRAIN + name="fake", num_samples=5, split=types.DatasetSplit.TRAIN ) ], user_template="Transcribe\n<|audio|>", @@ -305,7 +319,7 @@ def test_generic_dataset_merge_configs(): path="fake_path", splits=[ types.DatasetSplitConfig( - name="fake", num_samples=5, split_type=types.DatasetSplit.TRAIN + name="fake", num_samples=5, split=types.DatasetSplit.TRAIN ) ], ) @@ -327,7 +341,7 @@ def test_generic_dataset_merge_configs(): assert config.path == "fake_path" assert config.splits[0].name == "fake" assert config.splits[0].num_samples == 5 - assert config.splits[0].split_type == types.DatasetSplit.TRAIN + assert config.splits[0].split == types.DatasetSplit.TRAIN assert config.user_template == "fake_user_template" assert config.user_template_args == {"a": 1} assert config.assistant_template == "{{text}}" # the default @@ -335,37 +349,38 @@ def test_generic_dataset_merge_configs(): assert config.audio_field == "fake_audio_field" -def test_generic_dataset_length_mismatch(): - config = types.DatasetConfig( - name="fake_dataset", - path="fake_path", - splits=[ - types.DatasetSplitConfig( - name="fake", num_samples=5, split_type=types.DatasetSplit.TRAIN - ) - ], - ) - ds = FakeGenericDataset(10, config) - assert len(ds) == 5 - - pattern = r"(has been exceeded|Mismatch between presumed length)" - with pytest.warns(UserWarning, match=pattern): - list(ds) - - config = types.DatasetConfig( - name="fake_dataset", - path="fake_path", - splits=[ - types.DatasetSplitConfig( - name="fake", num_samples=10, split_type=types.DatasetSplit.TRAIN - ) - ], - ) - ds = FakeGenericDataset(5, config) - assert len(ds) == 10 - - with pytest.warns(UserWarning, match="Mismatch between presumed length"): - list(ds) +# This test is disabled as we don't have a good way to measure the actual length of the dataset when num_workers > 1 +# def test_generic_dataset_length_mismatch(): +# config = types.DatasetConfig( +# name="fake_dataset", +# path="fake_path", +# splits=[ +# types.DatasetSplitConfig( +# name="fake", num_samples=5, split=types.DatasetSplit.TRAIN +# ) +# ], +# ) +# ds = FakeGenericDataset(10, config) +# assert len(ds) == 5 + +# pattern = r"(has been exceeded|Mismatch between presumed length)" +# with pytest.warns(UserWarning, match=pattern): +# list(ds) + +# config = types.DatasetConfig( +# name="fake_dataset", +# path="fake_path", +# splits=[ +# types.DatasetSplitConfig( +# name="fake", num_samples=10, split=types.DatasetSplit.TRAIN +# ) +# ], +# ) +# ds = FakeGenericDataset(5, config) +# assert len(ds) == 10 + +# with pytest.warns(UserWarning, match="Mismatch between presumed length"): +# list(ds) def test_generic_dataset_multiple_splits(): diff --git a/ultravox/data/registry.py b/ultravox/data/registry.py index 4788e132..925e6c0f 100644 --- a/ultravox/data/registry.py +++ b/ultravox/data/registry.py @@ -1,4 +1,5 @@ import dataclasses +import logging from typing import Dict, List, Optional from ultravox.data import datasets @@ -41,8 +42,8 @@ def _merge_configs(configs: List[types.DatasetConfig]) -> types.DatasetConfig: def create_dataset( - name: str, args: types.VoiceDatasetArgs -) -> datasets.SizedIterableDataset: + name: str, args: types.VoiceDatasetArgs, verbose: bool = False +) -> datasets.GenericDataset: if name == "dummy": return datasets.LibriSpeechDummyDataset(args) assert name in DATASET_MAP, f"Unknown dataset: {name}" @@ -60,7 +61,10 @@ def create_dataset( raise ValueError(f"Dataset {name} has no path") if not merged_config.splits: raise ValueError(f"Dataset {name} has no splits") - return datasets.GenericDataset(args, merged_config) + if verbose: + logging.info(f"Creating dataset {name} with config:\n{merged_config}") + dataset = datasets.GenericDataset(args, merged_config) + return dataset register_datasets(boolq.configs) diff --git a/ultravox/data/types.py b/ultravox/data/types.py index 2d24559e..1ebc909f 100644 --- a/ultravox/data/types.py +++ b/ultravox/data/types.py @@ -1,6 +1,7 @@ import dataclasses import enum -from typing import Dict, List, Optional +import json +from typing import Any, Dict, List, Optional from simple_parsing import helpers @@ -11,7 +12,10 @@ f"Continue the following text using less than 50 words:\n\n{AUDIO_PLACEHOLDER}" ) CONTINUATION_ASSISTANT_TEMPLATE = "{{continuation}}" -TRANSCRIPTION_USER_TEMPLATE = f"Transcribe\n{AUDIO_PLACEHOLDER}" +QA_USER_TEMPLATE = f"Answer the following question:\n\n{AUDIO_PLACEHOLDER}" +TRANSCRIPTION_USER_TEMPLATE = ( + f"Repeat the following text, without any explanation: {AUDIO_PLACEHOLDER}" +) class DatasetSplit(str, enum.Enum): @@ -20,26 +24,68 @@ class DatasetSplit(str, enum.Enum): TEST = "test" +@dataclasses.dataclass +class DatasetOptions: + name: str + weight: float = 1.0 + + @dataclasses.dataclass class VoiceDatasetArgs: - """Global arguments for voice datasets.""" + """Global arguments for train/val/test dataset creation.""" - batch_size: int = 4 - """Batch size for train, eval, or validation.""" + split: DatasetSplit = DatasetSplit.TRAIN + """Which split of the dataset to use.""" include_audio: bool = True """Whether to include audio in the samples.""" shuffle: bool = False """Whether to shuffle the dataset.""" shuffle_seed: int = 42 """Seed for shuffling the dataset.""" - max_audio_duration_secs: Optional[float] = None + shuffle_buffer_size: int = 1000 + """Buffer size for shuffling the dataset. Only used for streaming datasets.""" + max_audio_duration_secs: Optional[float] = 16 """Whether to skip samples with audio longer than this duration.""" - split: DatasetSplit = DatasetSplit.TRAIN - """Which split of the dataset to use.""" + max_samples: Optional[int] = None + """max number of samples to use per dataset""" def __post_init__(self): if isinstance(self.split, str): self.split = DatasetSplit(self.split.lower()) + if self.max_audio_duration_secs and self.max_audio_duration_secs < 0: + self.max_audio_duration_secs = None + + +@dataclasses.dataclass +class TrainDatasetArgs(VoiceDatasetArgs): + split: DatasetSplit = DatasetSplit.TRAIN + shuffle: bool = True + + def __post_init__(self): + super().__post_init__() + assert self.split == DatasetSplit.TRAIN + + +@dataclasses.dataclass +class ValDatasetArgs(VoiceDatasetArgs): + split: DatasetSplit = DatasetSplit.VALIDATION + max_samples: Optional[int] = 64 + + def __post_init__(self): + super().__post_init__() + assert self.split == DatasetSplit.VALIDATION + assert self.shuffle is False + + +@dataclasses.dataclass +class EvalDatasetArgs(VoiceDatasetArgs): + split: DatasetSplit = DatasetSplit.TEST + max_audio_duration_secs: Optional[float] = 30 + + def __post_init__(self): + super().__post_init__() + assert self.split == DatasetSplit.TEST + assert self.shuffle is False @dataclasses.dataclass @@ -48,20 +94,27 @@ class DatasetSplitConfig(helpers.Serializable): """Name of the split.""" num_samples: int """Number of samples in the split""" - split_type: Optional[DatasetSplit] = None + split: Optional[DatasetSplit] = None """Type of split, i.e., train, test, or validation.""" def __post_init__(self): """Automatically set split type based on split name""" - if self.split_type is None: + if self.split is None: try: - self.split_type = DatasetSplit(self.name.lower()) + self.split = DatasetSplit(self.name.lower()) except ValueError: raise ValueError( f"Could not automatically determine split type from split name '{self.name}'. Please explicitly specify split_type for splits that are not named 'train', 'validation', or 'test'." ) +# Eval config for a single metric, added to the dataset config +@dataclasses.dataclass +class EvalConfig(helpers.Serializable): + metric: str + args: Dict[str, Any] = dataclasses.field(default_factory=dict) + + @dataclasses.dataclass class DatasetConfig(helpers.Serializable): # Note that subclasses can override any of these fields, but they currently can't @@ -91,6 +144,8 @@ class DatasetConfig(helpers.Serializable): """Set to True to load the dataset from GCP (using MDS) instead of Hugging Face.""" mds_batch_size: Optional[int] = None """Batch size for the dataset when using MDS.""" + eval_config: Optional[EvalConfig] = None + """Eval config for the dataset.""" def __post_init__(self): """Set defaults only if this is a root config, so that said defaults in a subclass don't act as overrides.""" @@ -103,8 +158,12 @@ def __post_init__(self): "audio_field": "audio", "use_mds": False, "mds_batch_size": 32, + "eval_config": {}, } if self.base is None: for attr, default_value in DEFAULTS.items(): if getattr(self, attr) is None: setattr(self, attr, default_value) + + def __str__(self) -> str: + return json.dumps(self.to_dict(), indent=2) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 9bc33390..6b98682c 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -28,6 +28,7 @@ from ultravox.training import config_base from ultravox.training import ddp_utils from ultravox.training.helpers import prefetch_weights +from ultravox.utils import monkey_patches def prepare_dataset( @@ -61,6 +62,8 @@ def prepare_dataset( def main() -> None: + monkey_patches.apply_all_patches() + # Disable parallelism to avoid deadlocks in DataLoader, apparently # multiple processes are forked when using multiple datasets. os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/ultravox/utils/monkey_patches.py b/ultravox/utils/monkey_patches.py new file mode 100644 index 00000000..0cbe86d3 --- /dev/null +++ b/ultravox/utils/monkey_patches.py @@ -0,0 +1,131 @@ +import logging +from functools import wraps +from typing import Any, Type + +import datasets as hf_datasets +import huggingface_hub as hf_hub +import requests +import tenacity +from huggingface_hub.utils._typing import HTTP_METHOD_T + +logger = logging.getLogger(__name__) + +IS_PATCHED = False + + +def patch_with_retry(cls: Type[Any], method_name: str, max_attempts: int = 10) -> None: + """ + Generic function to patch any method with retry capability. + + Args: + cls: The class containing the method to patch + method_name: The name of the method to patch + max_attempts: Maximum number of retry attempts (default: 10) + """ + original_method = getattr(cls, method_name) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(max_attempts), + wait=tenacity.wait_exponential(multiplier=1, min=4, max=10), + retry=tenacity.retry_if_exception_type(Exception), + before_sleep=tenacity.before_sleep_log(logger, logging.INFO), + ) + @wraps(original_method) + def method_with_retry(self, *args, **kwargs): + return original_method(self, *args, **kwargs) + + # Apply the patch + setattr(cls, method_name, method_with_retry) + logger.info( + f"Applied retry patch to {cls.__name__}.{method_name} with max_attempts={max_attempts}" + ) + + +def patch_hf_hub_http_backoff(): + """ + Monkey patch the huggingface_hub http_backoff implementation to include the ChunkedEncodingError exception. + """ + original_http_backoff = hf_hub.hf_file_system.http_backoff + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(multiplier=1, min=4, max=10), + retry=tenacity.retry_if_exception_type(Exception), + before_sleep=tenacity.before_sleep_log(logger, logging.INFO), + ) + def http_backoff( + method: HTTP_METHOD_T, + url: str, + *, + max_retries: int = 10, + retry_on_exceptions: tuple = ( + requests.Timeout, + requests.ConnectionError, + requests.exceptions.ChunkedEncodingError, + ), + **kwargs, + ) -> requests.Response: + return original_http_backoff( + method=method, + url=url, + max_retries=max_retries, + retry_on_exceptions=retry_on_exceptions, + **kwargs, + ) + + hf_hub.hf_file_system.http_backoff = http_backoff + logger.info( + "Applied retry patch to huggingface_hub http_backoff with ChunkedEncodingError support" + ) + + +def patch_audio_decoder(): + """ + Monkey-patch the datasets.Audio.decode_example method to handle errors gracefully. + When decoding fails, returns a dict with None for array and original path. + """ + # Store the original decode_example method + original_decode_example = hf_datasets.Audio.decode_example + + def safe_decode_example(self, value, token_per_repo_id=None): + try: + # Try to decode using the original method + return original_decode_example(self, value, token_per_repo_id) + except Exception as e: + logger.warning(f"Error decoding audio at path {value.get('path')}: {e}") + return { + "array": None, + "path": value.get("path", None), + "sampling_rate": self.sampling_rate or 16000, + } + + # Replace the original decode_example with our safe version + hf_datasets.Audio.decode_example = safe_decode_example + logger.info( + "Applied patch to datasets.Audio.decode_example for graceful error handling" + ) + + +def apply_all_patches(): + """ + Apply all patches at once. + """ + global IS_PATCHED + if IS_PATCHED: + return + + logger.info("Starting to apply patches...") + + # Patch HF Hub methods + patch_with_retry(hf_hub.HfApi, "dataset_info") + patch_with_retry(hf_hub.HfApi, "model_info") + patch_with_retry(hf_hub.HfApi, "repo_info") + + # Patch http_backoff + patch_hf_hub_http_backoff() + + # Patch audio decoder + patch_audio_decoder() + + IS_PATCHED = True + logger.info("Successfully applied all patches")