Skip to content
This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Command-line support for loading mmap datasets #685

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pytorch_translate/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,11 @@ def create_from_file(path, is_npz=True, num_examples_limit: Optional[int] = None
return result
else:
# idx, bin format
return InMemoryIndexedDataset(path)
impl = data.indexed_dataset.infer_dataset_impl(path)
if impl == "mmap":
return data.indexed_dataset.MMapIndexedDataset(path)
else:
return InMemoryIndexedDataset(path)

def subsample(self, indices):
"""
Expand Down
6 changes: 6 additions & 0 deletions pytorch_translate/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ def add_preprocessing_args(parser):
default="",
help="Path for the binary file containing target side monolingual data",
)
group.add_argument(
"--fairseq-binary-data-format",
default=False,
action="store_true",
help="Binary data paths are prefixes of .bin and .idx files",
)

# TODO(T43045193): Move this to multilingual_task.py eventually
group.add_argument(
Expand Down
14 changes: 8 additions & 6 deletions pytorch_translate/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pytorch_translate.data.dictionary import Dictionary


def maybe_generate_temp_file_path(output_path=None):
def maybe_generate_temp_file_path(output_path=None, is_npz=True):
"""
This function generates a temp file path if output_path is empty or None.
This is useful to do before calling any preprocessing function that has a
Expand All @@ -28,7 +28,7 @@ def maybe_generate_temp_file_path(output_path=None):
os.close(fd)
# numpy silently appends this suffix if it is not present, so this ensures
# that the correct path is returned
if not output_path.endswith(".npz"):
if is_npz and not output_path.endswith(".npz"):
output_path += ".npz"
return output_path

Expand Down Expand Up @@ -148,16 +148,18 @@ def preprocess_corpora(args, dictionary_cls=Dictionary):
utils.maybe_parse_collection_argument(args.train_target_binary_path), str
):
args.train_source_binary_path = maybe_generate_temp_file_path(
args.train_source_binary_path
args.train_source_binary_path,
is_npz=not args.fairseq_binary_data_format,
)
args.train_target_binary_path = maybe_generate_temp_file_path(
args.train_target_binary_path
args.train_target_binary_path,
is_npz=not args.fairseq_binary_data_format,
)
args.eval_source_binary_path = maybe_generate_temp_file_path(
args.eval_source_binary_path
args.eval_source_binary_path, is_npz=not args.fairseq_binary_data_format
)
args.eval_target_binary_path = maybe_generate_temp_file_path(
args.eval_target_binary_path
args.eval_target_binary_path, is_npz=not args.fairseq_binary_data_format
)

# Additional text preprocessing options could be added here before
Expand Down
2 changes: 2 additions & 0 deletions pytorch_translate/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def setup_training_model(args):
src_bin_path=args.train_source_binary_path,
tgt_bin_path=args.train_target_binary_path,
weights_file=getattr(args, "train_weights_path", None),
is_npz=not args.fairseq_binary_data_format,
)

if args.task == "dual_learning_task":
Expand All @@ -311,6 +312,7 @@ def setup_training_model(args):
split=args.valid_subset,
src_bin_path=args.eval_source_binary_path,
tgt_bin_path=args.eval_target_binary_path,
is_npz=not args.fairseq_binary_data_format,
)

return task, model, criterion
Expand Down