Skip to content

Commit 316694f

Browse files
committed
fix: add tfrecords buffer size in bytes for tfrecords dataset
1 parent 0635cfe commit 316694f

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tensorflow_asr/datasets.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626
# An ASR dataset is some `.tsv` files in format: `PATH\tDURATION\tTRANSCRIPT`. You must create those files by your own with your own data and methods.
2727

28-
# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT` because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob:
28+
# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT`
29+
# because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob:
2930

30-
# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file. Default is [English](../featurizers/english.txt)
31+
# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file.
32+
# Default is [English](../featurizers/english.txt)
3133

3234
# **Inputs**
3335

@@ -141,8 +143,9 @@ def get_global_shape(
141143

142144

143145
BUFFER_SIZE = 100
146+
TFRECORD_BUFFER_SIZE = 32 * 1024 * 1024
144147
TFRECORD_SHARDS = 16
145-
AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.experimental.AUTOTUNE)
148+
AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.AUTOTUNE)
146149

147150

148151
class BaseDataset:
@@ -416,6 +419,7 @@ def __init__(
416419
indefinite: bool = True,
417420
drop_remainder: bool = True,
418421
buffer_size: int = BUFFER_SIZE,
422+
tfrecords_buffer_size: int = TFRECORD_BUFFER_SIZE,
419423
compression_type: str = "GZIP",
420424
sample_rate: int = 16000,
421425
name: str = "",
@@ -442,6 +446,7 @@ def __init__(
442446
if tfrecords_shards <= 0:
443447
raise ValueError("tfrecords_shards must be positive")
444448
self.tfrecords_shards = tfrecords_shards
449+
self.tfrecords_buffer_size = tfrecords_buffer_size
445450
self.compression_type = compression_type
446451

447452
def write_tfrecord_file(self, splitted_entries: tuple):
@@ -506,7 +511,9 @@ def create(self, batch_size: int, padded_shapes=None):
506511
ignore_order = tf.data.Options()
507512
ignore_order.deterministic = False
508513
files_ds = files_ds.with_options(ignore_order)
509-
dataset = tf.data.TFRecordDataset(files_ds, compression_type=self.compression_type, num_parallel_reads=AUTOTUNE)
514+
dataset = tf.data.TFRecordDataset(
515+
files_ds, compression_type=self.compression_type, buffer_size=self.tfrecords_buffer_size, num_parallel_reads=AUTOTUNE
516+
)
510517

511518
return self.process(dataset, batch_size, padded_shapes=padded_shapes)
512519

0 commit comments

Comments
 (0)