Training Stability: Patch HF Hub and Datasets methods and update datasets.py (#280)

* bring over ZQ's apply_all_patches

* update datasets

* update types and tests
farzadab authored Feb 11, 2025
1 parent 227c775 commit 5a50e8e
Showing 9 changed files with 403 additions and 98 deletions.
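The Hub and Datasets patching named in the commit title lives in a file that is not loaded below. As a rough sketch only: `apply_all_patches` (the name comes from the commit message) presumably wraps flaky Hub-backed calls in retries for long training runs. The wrapped target and retry policy here are illustrative guesses, not the repository's actual code.

import functools
import time

import datasets


def _with_retries(fn, attempts: int = 3, backoff: float = 2.0):
    """Retry transient failures with exponential backoff (illustrative)."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        for attempt in range(attempts):
            try:
                return fn(*args, **kwargs)
            except Exception:
                if attempt == attempts - 1:
                    raise
                time.sleep(backoff**attempt)

    return wrapper


def apply_all_patches() -> None:
    # Hypothetical: make training resilient to transient HF Hub errors.
    datasets.load_dataset = _with_retries(datasets.load_dataset)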
1 change: 1 addition & 0 deletions ultravox/data/__init__.py
@@ -12,6 +12,7 @@
     "VoiceDataset",
     "VoiceDatasetArgs",
     "VoiceSample",
+    "DatasetOptions",
     "create_dataset",
     "register_datasets",
 ]
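With the re-export in place, downstream code can pull `DatasetOptions` straight from the package root; a minimal usage sketch:

# DatasetOptions is now part of the public ultravox.data surface.
from ultravox.data import DatasetOptions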
10 changes: 7 additions & 3 deletions ultravox/data/configs/librispeech.py
@@ -13,11 +13,12 @@
     subset="clean",
     splits=[
         types.DatasetSplitConfig(
-            name="train.100", num_samples=28_539, split_type=types.DatasetSplit.TRAIN
+            name="train.100", num_samples=28_539, split=types.DatasetSplit.TRAIN
         ),
         types.DatasetSplitConfig(
-            name="train.360", num_samples=104_014, split_type=types.DatasetSplit.TRAIN
+            name="train.360", num_samples=104_014, split=types.DatasetSplit.TRAIN
         ),
+        types.DatasetSplitConfig(name="test", num_samples=2620),
     ],
 )
 
@@ -27,21 +28,24 @@
     subset="other",
     splits=[
         types.DatasetSplitConfig(
-            name="train.500", num_samples=148_688, split_type=types.DatasetSplit.TRAIN
+            name="train.500", num_samples=148_688, split=types.DatasetSplit.TRAIN
         ),
+        types.DatasetSplitConfig(name="test", num_samples=2939),
     ],
 )
 
 LS_CLEAN_TRANS_CONFIG = types.DatasetConfig(
     name="librispeech-clean-transcription",
     base="librispeech-clean",
     user_template=types.TRANSCRIPTION_USER_TEMPLATE,
+    eval_config=types.EvalConfig(metric="wer", args={"lang_id": "en"}),
 )
 
 LS_OTHER_TRANS_CONFIG = types.DatasetConfig(
     name="librispeech-other-transcription",
     base="librispeech-other",
     user_template=types.TRANSCRIPTION_USER_TEMPLATE,
+    eval_config=types.EvalConfig(metric="wer", args={"lang_id": "en"}),
 )
 
 LS_CLEAN_CONT_CONFIG = types.DatasetConfig(
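For reference, split configs are now written with the renamed `split` field, and transcription configs carry an explicit WER eval. Both constructions below are lifted directly from the diff, assuming `types` is the `ultravox.data.types` module imported by these config files:

from ultravox.data import types

# TRAIN splits are tagged explicitly; the new "test" entries rely on the default.
train_split = types.DatasetSplitConfig(
    name="train.100", num_samples=28_539, split=types.DatasetSplit.TRAIN
)

# Transcription configs now declare their evaluation metric inline.
wer_eval = types.EvalConfig(metric="wer", args={"lang_id": "en"})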
3 changes: 2 additions & 1 deletion ultravox/data/configs/peoplespeech.py
@@ -7,7 +7,7 @@
     splits=[
         types.DatasetSplitConfig(name="train", num_samples=1_501_271),
         types.DatasetSplitConfig(
-            name="test", num_samples=34_898, split_type=types.DatasetSplit.VALIDATION
+            name="test", num_samples=34_898, split=types.DatasetSplit.VALIDATION
         ),
     ],
     assistant_template="{{text_proc.format_asr_text(text)}}",
@@ -18,6 +18,7 @@
     name="peoplespeech-clean-transcription",
     base="peoplespeech",
     user_template=types.TRANSCRIPTION_USER_TEMPLATE,
+    eval_config=types.EvalConfig(metric="wer", args={"lang_id": "en"}),
 )
 
 PS_CONT_CONFIG = types.DatasetConfig(
153 changes: 120 additions & 33 deletions ultravox/data/datasets.py
@@ -97,7 +97,16 @@ class SizedIterableDataset(abc.ABC, data.IterableDataset):
     """
 
     @abc.abstractmethod
-    def __len__(self):
+    def __len__(self) -> int:
         pass
+
+    @abc.abstractmethod
+    def __str__(self) -> str:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        pass
 
 
@@ -111,14 +120,27 @@ def __init__(self, args: types.VoiceDatasetArgs) -> None:
         super().__init__()
         self._args = args
         self._rng = np.random.default_rng(self._args.shuffle_seed)
+        self._name = "[unset]"
+        self._length = -1
 
-    def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None:
+    # num_samples is the total number of samples in the dataset
+    def _init_dataset(
+        self,
+        dataset: data.Dataset,
+        name: str,
+        num_samples: int,
+    ) -> None:
         self._dataset = dataset
+        self._name = name
         self._length = num_samples
 
     def __len__(self):
         return self._length
 
+    @property
+    def name(self):
+        return self._name
+
     def _load_hf_dataset(
         self,
         path: str,
@@ -172,40 +194,39 @@ def _load_mds_dataset(
 
     def __iter__(self):
         actual_length = 0
-        for _, row in enumerate(self._dataset):
+        skipped_samples = 0
+        bad_samples = 0
+        dataset_iter = iter(self._dataset)
+        for row in dataset_iter:
+            actual_length += 1
             sample = self._get_sample(row)
             if sample is None:
-                raise ValueError(
-                    f"Sample is None in dataset {self._config.alias} for row {row}"
-                )
+                print(f"Sample is None in dataset {self._config.alias} for row {row}")
+                bad_samples += 1
+                continue  # Skip this sample and proceed to the next
 
             if self._args.include_audio:
                 # If audio_field is set, make sure the sample has audio data.
                 if sample.audio is None:
-                    raise ValueError(f"Audio is None for sample {sample}")
+                    print(f"Audio is None for sample {sample}")
+                    bad_samples += 1
+                    continue  # Skip this sample
                 if sample.audio.shape[-1] == 0:
-                    raise ValueError(f"Audio length is 0 for sample {sample}")
+                    print(f"Audio length is 0 for sample {sample}")
+                    bad_samples += 1
+                    continue  # Skip this sample
                 if (
                     self._args.max_audio_duration_secs is not None
                     and sample.audio.shape[-1] / data_sample.SAMPLE_RATE
                     > self._args.max_audio_duration_secs
                 ):
-                    duration = sample.audio.shape[-1] / data_sample.SAMPLE_RATE
-                    warnings.warn(
-                        f"Audio length ({duration}s) exceeds max audio duration ({self._args.max_audio_duration_secs}s), skipping sample."
-                    )
-                    continue
+                    skipped_samples += 1
+                    continue  # Skip this sample
 
             yield sample
-            actual_length += 1
-            if actual_length == len(self) + 1:
-                warnings.warn(
-                    f"The presumed length {self._length} has been exceeded for {self._config.name}:{self._args.split.value}. Make sure to update."
-                )
-        if actual_length != len(self):
-            warnings.warn(
-                f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for {self._config.name}:{self._args.split.value}. Make sure to update."
-            )
+
+        logging.info(
+            f"Extracted {actual_length} samples from {self.name} (total: {len(self)}), removed {bad_samples} bad samples, and skipped {skipped_samples} samples for exceeding max audio duration ({self._args.max_audio_duration_secs}s)."
+        )
 
     @abc.abstractmethod
     def _get_sample(
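This reworked `__iter__` is the heart of the stability fix: a malformed row used to raise and kill the epoch, and now it is logged, counted, and skipped. A stripped-down, self-contained sketch of the same skip-and-count pattern (generic names, not the repository's code):

import logging
from typing import Callable, Iterable, Iterator, Optional


def iterate_skipping_bad_rows(
    rows: Iterable[dict], parse: Callable[[dict], Optional[dict]]
) -> Iterator[dict]:
    """Yield parsed samples; count and skip rows that fail to parse."""
    seen = bad = 0
    for row in rows:
        seen += 1
        sample = parse(row)
        if sample is None:
            bad += 1
            continue  # one bad row no longer aborts the whole training run
        yield sample
    logging.info("Saw %d rows, skipped %d bad samples.", seen, bad)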
@@ -262,7 +283,7 @@ def __init__(
         dsets = []
         total_samples = 0
         for split in config.splits:
-            if split.split_type == self._args.split:
+            if split.split == self._args.split:
                 if not config.use_mds:
                     ds = self._load_hf_dataset(
                         config.path,
@@ -283,7 +304,13 @@
             len(dsets) > 0
         ), f"The {config.name} dataset has no {self._args.split} splits."
         dataset = ds if len(dsets) == 1 else hf_datasets.concatenate_datasets(dsets)
-        super()._init_dataset(dataset, total_samples)
+
+        dataset_name = f"{config.name}.{self._args.split.value}"
+
+        super()._init_dataset(dataset, dataset_name, total_samples)
+
+    def __str__(self):
+        return f"GenericDataset({self._config})"
 
     def _get_sample(self, row) -> Optional[data_sample.VoiceSample]:
         assert self._config.user_template is not None
@@ -296,15 +323,14 @@ def _get_sample(self, row) -> Optional[data_sample.VoiceSample]:
             ).render(
                 **row,
                 text_proc=text_proc,
-                dataset=self,
                 **self._config.user_template_args,
             )
             assistant_content = jinja2.Template(
                 self._config.assistant_template, undefined=jinja2.StrictUndefined
-            ).render(**row, text_proc=text_proc, dataset=self)
+            ).render(**row, text_proc=text_proc)
             transcript = jinja2.Template(
                 self._config.transcript_template, undefined=jinja2.StrictUndefined
-            ).render(**row, text_proc=text_proc, dataset=self)
+            ).render(**row, text_proc=text_proc)
         except jinja2.TemplateError as e:
             print(f"Error rendering template: {e}")
             print(f"user_template: {self._config.user_template}")
@@ -322,18 +348,34 @@ def _get_sample(self, row) -> Optional[data_sample.VoiceSample]:
         audio = self._get_audio(row, self._config.audio_field)
         return self._make_sample(messages, audio, audio_transcript=transcript)
 
+    def get_config(self):
+        return self._config
+
 
-class LibriSpeechDummyDataset(VoiceDataset):
+class LibriSpeechDummyDataset(GenericDataset):
     def __init__(self, args: types.VoiceDatasetArgs) -> None:
-        super().__init__(args)
+        VoiceDataset.__init__(self, args)
         # This dataset doesn't support streaming.
         dataset = self._load_hf_dataset(
             "hf-internal-testing/librispeech_asr_dummy",
             "clean",
             split="validation",
             streaming=False,
         )
-        self._init_dataset(dataset, 73)
+        self._init_dataset(dataset, "dummy", 73)
+
+    def __str__(self):
+        return "LibriSpeechDummyDataset"
+
+    @property
+    def name(self):
+        return "dummy"
+
+    def get_config(self):
+        return types.DatasetConfig(
+            name="dummy",
+            path="hf-internal-testing/librispeech_asr_dummy",
+        )
 
     def _get_sample(
         self, row: transformers.BatchFeature
@@ -360,6 +402,13 @@ def __iter__(self):
     def __len__(self):
         return self._length
 
+    def __str__(self):
+        return f"EmptyDataset(length={self._length})"
+
+    @property
+    def name(self):
+        return "empty"
+
 
 class InterleaveDataset(SizedIterableDataset):
     """Interleaves multiple SizedIterableDataset objects based on normalized weights."""
@@ -380,6 +429,7 @@ def __init__(
             assert len(weights) == len(datasets)
         else:
             weights = [1.0] * len(datasets)
+        self._weights = weights
         self._weighted_samples = [int(w * len(d)) for w, d in zip(weights, datasets)]
         self._total_samples = sum(self._weighted_samples)
 
@@ -398,12 +448,25 @@ def __iter__(self):
                 yield next(ds_iters[iter_index])
             except StopIteration:
                 ds_iters[iter_index] = iter(self._datasets[iter_index])
-                yield next(ds_iters[iter_index])
+                try:
+                    yield next(ds_iters[iter_index])
+                except StopIteration:
+                    warnings.warn(
+                        f"Dataset {iter_index} is empty. num_workers is likely too high. Stopping iteration."
+                    )
+                    break
             ds_pos[iter_index] += 1
 
     def __len__(self):
         return self._total_samples
 
+    def __str__(self):
+        return "+".join([f"{d}:{w:.2f}" for w, d in zip(self._weights, self._datasets)])
+
+    @property
+    def name(self):
+        return "+".join([ds.name for ds in self._datasets])
+
 
 class Dataproc(SizedIterableDataset):
     """Base class to preprocess a dataset of VoiceSamples."""
@@ -421,6 +484,13 @@ def __iter__(self):
     def __len__(self):
         return len(self._dataset)
 
+    def __str__(self):
+        return f"Dataproc({self._dataset})"
+
+    @property
+    def name(self):
+        return self._dataset.name
+
 
 class Range(SizedIterableDataset):
     """Limits the number of samples from another dataset."""
@@ -431,13 +501,30 @@ def __init__(
         self._dataset = dataset
         self._length = num_samples or len(dataset)
         if self._length > len(dataset):
-            raise ValueError("num_samples exceeds dataset length.")
+            warnings.warn(
+                f"num_samples ({self._length}) exceeds dataset length ({len(dataset)}). Truncating to {len(dataset)}."
+            )
+            self._length = len(dataset)
+        self._name = f"{dataset.name}.{self._length}"
 
     def __iter__(self):
         for i, sample in enumerate(self._dataset):
             if i >= self._length:
                 break
             yield sample
 
+    def __str__(self):
+        return f"Range({self._dataset}%{len(self)})"
+
     def __len__(self):
         return self._length
+
+    @property
+    def name(self):
+        return self._name
+
+    def get_config(self):
+        if isinstance(self._dataset, GenericDataset):
+            return self._dataset.get_config()
+        else:
+            raise ValueError("Cannot get config for non-GenericDataset")
