From 72febb73a326dfbe72c719ebaffa58e6fa289747 Mon Sep 17 00:00:00 2001 From: James Cross Date: Tue, 11 Feb 2020 15:16:45 -0800 Subject: [PATCH] cmd-line support for loading mmap datasets Summary: Allows loading data from the Fairseq .idx/.bin format (including most current "mmap" implementation) by specifying the `--fairseq-binary-data-format` flag. (Note that D16867809 added iniernal support for loading legacy .idx / .bin files, but did not expose an option for using that format to the command-line trainer.) Differential Revision: D19844619 fbshipit-source-id: a5a6c34524dea94165aacb09f1996f6575bdee21 --- pytorch_translate/data/data.py | 6 +++++- pytorch_translate/options.py | 6 ++++++ pytorch_translate/preprocess.py | 14 ++++++++------ pytorch_translate/train.py | 2 ++ 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pytorch_translate/data/data.py b/pytorch_translate/data/data.py index 03eba1fd..ef0fb189 100644 --- a/pytorch_translate/data/data.py +++ b/pytorch_translate/data/data.py @@ -261,7 +261,11 @@ def create_from_file(path, is_npz=True, num_examples_limit: Optional[int] = None return result else: # idx, bin format - return InMemoryIndexedDataset(path) + impl = data.indexed_dataset.infer_dataset_impl(path) + if impl == "mmap": + return data.indexed_dataset.MMapIndexedDataset(path) + else: + return InMemoryIndexedDataset(path) def subsample(self, indices): """ diff --git a/pytorch_translate/options.py b/pytorch_translate/options.py index c20cfe44..c3769b98 100644 --- a/pytorch_translate/options.py +++ b/pytorch_translate/options.py @@ -293,6 +293,12 @@ def add_preprocessing_args(parser): default="", help="Path for the binary file containing target side monolingual data", ) + group.add_argument( + "--fairseq-binary-data-format", + default=False, + action="store_true", + help="Binary data paths are prefixes of .bin and .idx files", + ) # TODO(T43045193): Move this to multilingual_task.py eventually group.add_argument( diff --git a/pytorch_translate/preprocess.py b/pytorch_translate/preprocess.py index 6667b27e..031ae6d3 100644 --- a/pytorch_translate/preprocess.py +++ b/pytorch_translate/preprocess.py @@ -15,7 +15,7 @@ from pytorch_translate.data.dictionary import Dictionary -def maybe_generate_temp_file_path(output_path=None): +def maybe_generate_temp_file_path(output_path=None, is_npz=True): """ This function generates a temp file path if output_path is empty or None. This is useful to do before calling any preprocessing function that has a @@ -28,7 +28,7 @@ def maybe_generate_temp_file_path(output_path=None): os.close(fd) # numpy silently appends this suffix if it is not present, so this ensures # that the correct path is returned - if not output_path.endswith(".npz"): + if is_npz and not output_path.endswith(".npz"): output_path += ".npz" return output_path @@ -148,16 +148,18 @@ def preprocess_corpora(args, dictionary_cls=Dictionary): utils.maybe_parse_collection_argument(args.train_target_binary_path), str ): args.train_source_binary_path = maybe_generate_temp_file_path( - args.train_source_binary_path + args.train_source_binary_path, + is_npz=not args.fairseq_binary_data_format, ) args.train_target_binary_path = maybe_generate_temp_file_path( - args.train_target_binary_path + args.train_target_binary_path, + is_npz=not args.fairseq_binary_data_format, ) args.eval_source_binary_path = maybe_generate_temp_file_path( - args.eval_source_binary_path + args.eval_source_binary_path, is_npz=not args.fairseq_binary_data_format ) args.eval_target_binary_path = maybe_generate_temp_file_path( - args.eval_target_binary_path + args.eval_target_binary_path, is_npz=not args.fairseq_binary_data_format ) # Additional text preprocessing options could be added here before diff --git a/pytorch_translate/train.py b/pytorch_translate/train.py index d8a644c6..9c335407 100644 --- a/pytorch_translate/train.py +++ b/pytorch_translate/train.py @@ -300,6 +300,7 @@ def setup_training_model(args): src_bin_path=args.train_source_binary_path, tgt_bin_path=args.train_target_binary_path, weights_file=getattr(args, "train_weights_path", None), + is_npz=not args.fairseq_binary_data_format, ) if args.task == "dual_learning_task": @@ -311,6 +312,7 @@ def setup_training_model(args): split=args.valid_subset, src_bin_path=args.eval_source_binary_path, tgt_bin_path=args.eval_target_binary_path, + is_npz=not args.fairseq_binary_data_format, ) return task, model, criterion