From 088fda70de153a388836b2fdca90fe726ec5b222 Mon Sep 17 00:00:00 2001
From: Ning Dong
Date: Wed, 29 May 2019 13:10:07 -0700
Subject: [PATCH] Make target-text-file optional in generate.py

Summary: Currently --target-text-file is mandatory whenever
--source-text-file is specified. However, the target file is only needed
when the generated translations are evaluated against references;
generation alone does not require it. Remove the validation. LMK if there
are potential concerns.

Differential Revision: D15540664

fbshipit-source-id: 71889cdbad746abbcbc07968994f9c9f733b204e
---
 pytorch_translate/generate.py       |   5 +-
 .../tasks/pytorch_translate_task.py | 180 +++++++++++-------
 2 files changed, 107 insertions(+), 78 deletions(-)

diff --git a/pytorch_translate/generate.py b/pytorch_translate/generate.py
index f11fa71f..d95a3b94 100644
--- a/pytorch_translate/generate.py
+++ b/pytorch_translate/generate.py
@@ -530,7 +530,7 @@ def get_parser_with_args():
     )
     generation_group.add_argument(
         "--target-text-file",
-        default="",
+        default=None,
         metavar="FILE",
         help="Path to raw text file containing examples in target dialect. "
         "This overrides what would be loaded from the data dir.",
@@ -619,9 +619,6 @@ def validate_args(args):
         (src_file and os.path.isfile(src_file))
         for src_file in args.source_text_file
     ), "Please specify a valid file for --source-text-file"
-    assert args.target_text_file and os.path.isfile(
-        args.target_text_file
-    ), "Please specify a valid file for --target-text-file"
 
 
 def generate(args):
diff --git a/pytorch_translate/tasks/pytorch_translate_task.py b/pytorch_translate/tasks/pytorch_translate_task.py
index d7506b08..10483ae5 100644
--- a/pytorch_translate/tasks/pytorch_translate_task.py
+++ b/pytorch_translate/tasks/pytorch_translate_task.py
@@ -385,26 +385,29 @@ def load_dataset_from_text(
         self,
         split: str,
         source_text_file: str,
-        target_text_file: str,
+        target_text_file: Optional[str] = None,
         append_eos: Optional[bool] = False,
         reverse_source: Optional[bool] = True,
     ):
-        dst_dataset = data.IndexedRawTextDataset(
-            path=target_text_file,
-            dictionary=self.target_dictionary,
-            # We always append EOS to the target sentence since we still want
-            # the model to output an indication the sentence has finished, even
-            # if we don't append the EOS symbol to the source sentence
-            # (to prevent the model from misaligning UNKs or other words
-            # to the frequently occurring EOS).
-            append_eos=True,
-            # We don't reverse the order of the target sentence, since
-            # even if the source sentence is fed to the model backwards,
-            # we still want the model to start outputting from the first word.
-            reverse_order=False,
+        dst_dataset = (
+            data.IndexedRawTextDataset(
+                path=target_text_file,
+                dictionary=self.target_dictionary,
+                # We always append EOS to the target sentence since we still want
+                # the model to output an indication the sentence has finished, even
+                # if we don't append the EOS symbol to the source sentence
+                # (to prevent the model from misaligning UNKs or other words
+                # to the frequently occurring EOS).
+                append_eos=True,
+                # We don't reverse the order of the target sentence, since
+                # even if the source sentence is fed to the model backwards,
+                # we still want the model to start outputting from the first word.
+                reverse_order=False,
+            )
+            if target_text_file
+            else None
         )
-
-        if self.char_source_dict is not None:
+        if self.char_source_dict:
             src_dataset = char_data.InMemoryNumpyWordCharDataset()
             src_dataset.parse(
                 path=source_text_file,
@@ -413,14 +416,21 @@
                 reverse_order=reverse_source,
                 append_eos=append_eos,
             )
-            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
-                src=src_dataset,
-                src_sizes=src_dataset.sizes,
-                src_dict=self.source_dictionary,
-                tgt=dst_dataset,
-                tgt_sizes=dst_dataset.sizes,
-                tgt_dict=self.target_dictionary,
-            )
+            if dst_dataset:
+                self.datasets[split] = char_data.LanguagePairSourceCharDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                    tgt=dst_dataset,
+                    tgt_sizes=dst_dataset.sizes,
+                    tgt_dict=self.target_dictionary,
+                )
+            else:
+                self.datasets[split] = char_data.LanguagePairSourceCharDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                )
         else:
             src_dataset = data.IndexedRawTextDataset(
                 path=source_text_file,
@@ -428,15 +438,23 @@
                 append_eos=append_eos,
                 reverse_order=reverse_source,
             )
-            self.datasets[split] = data.LanguagePairDataset(
-                src=src_dataset,
-                src_sizes=src_dataset.sizes,
-                src_dict=self.source_dictionary,
-                tgt=dst_dataset,
-                tgt_sizes=dst_dataset.sizes,
-                tgt_dict=self.target_dictionary,
-                left_pad_source=False,
-            )
+            if dst_dataset:
+                self.datasets[split] = data.LanguagePairDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                    tgt=dst_dataset,
+                    tgt_sizes=dst_dataset.sizes,
+                    tgt_dict=self.target_dictionary,
+                    left_pad_source=False,
+                )
+            else:
+                self.datasets[split] = data.LanguagePairDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                    left_pad_source=False,
+                )
 
         print(f"| {split} {len(self.datasets[split])} examples")
 
@@ -444,7 +462,7 @@ def load_multisource_dataset_from_text(
         self,
         split: str,
         source_text_files: List[str],
-        target_text_file: str,
+        target_text_file: Optional[str] = None,
         append_eos: Optional[bool] = False,
         reverse_source: Optional[bool] = True,
     ):
@@ -454,28 +472,35 @@
             append_eos=append_eos,
             reverse_order=reverse_source,
         )
-        tgt_dataset = data.IndexedRawTextDataset(
-            path=target_text_file,
-            dictionary=self.target_dictionary,
-            # We always append EOS to the target sentence since we still want
-            # the model to output an indication the sentence has finished, even
-            # if we don't append the EOS symbol to the source sentence
-            # (to prevent the model from misaligning UNKs or other words
-            # to the frequently occurring EOS).
-            append_eos=True,
-            # We don't reverse the order of the target sentence, since
-            # even if the source sentence is fed to the model backwards,
-            # we still want the model to start outputting from the first word.
-            reverse_order=False,
-        )
-        self.datasets[split] = multisource_data.MultisourceLanguagePairDataset(
-            src=src_dataset,
-            src_sizes=src_dataset.sizes,
-            src_dict=self.source_dictionary,
-            tgt=tgt_dataset,
-            tgt_sizes=tgt_dataset.sizes,
-            tgt_dict=self.target_dictionary,
-        )
+        if target_text_file:
+            tgt_dataset = data.IndexedRawTextDataset(
+                path=target_text_file,
+                dictionary=self.target_dictionary,
+                # We always append EOS to the target sentence since we still want
+                # the model to output an indication the sentence has finished, even
+                # if we don't append the EOS symbol to the source sentence
+                # (to prevent the model from misaligning UNKs or other words
+                # to the frequently occurring EOS).
+                append_eos=True,
+                # We don't reverse the order of the target sentence, since
+                # even if the source sentence is fed to the model backwards,
+                # we still want the model to start outputting from the first word.
+                reverse_order=False,
+            )
+            self.datasets[split] = multisource_data.MultisourceLanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+                tgt=tgt_dataset,
+                tgt_sizes=tgt_dataset.sizes,
+                tgt_dict=self.target_dictionary,
+            )
+        else:
+            self.datasets[split] = multisource_data.MultisourceLanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+            )
 
     @property
     def source_dictionary(self):
@@ -562,9 +587,9 @@ def load_dataset_from_text_multilingual(
         self,
         split: str,
         source_text_file: str,
-        target_text_file: str,
+        target_text_file: Optional[str],
         source_lang_id: int,
-        target_lang_id: int,
+        target_lang_id: Optional[int],
         append_eos: Optional[bool] = False,
         reverse_source: Optional[bool] = True,
     ):
@@ -576,22 +601,29 @@
             reverse_order=reverse_source,
             prepend_language_id=False,
         )
-        tgt_dataset = pytorch_translate_data.IndexedRawTextDatasetWithLangId(
-            path=target_text_file,
-            dictionary=self.target_dictionary,
-            lang_id=target_lang_id,
-            append_eos=True,
-            reverse_order=False,
-            prepend_language_id=True,
-        )
-        self.datasets[split] = data.LanguagePairDataset(
-            src=src_dataset,
-            src_sizes=src_dataset.sizes,
-            src_dict=self.source_dictionary,
-            tgt=tgt_dataset,
-            tgt_sizes=tgt_dataset.sizes,
-            tgt_dict=self.target_dictionary,
-        )
+        if target_text_file:
+            tgt_dataset = pytorch_translate_data.IndexedRawTextDatasetWithLangId(
+                path=target_text_file,
+                dictionary=self.target_dictionary,
+                lang_id=target_lang_id,
+                append_eos=True,
+                reverse_order=False,
+                prepend_language_id=True,
+            )
+            self.datasets[split] = data.LanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+                tgt=tgt_dataset,
+                tgt_sizes=tgt_dataset.sizes,
+                tgt_dict=self.target_dictionary,
+            )
+        else:
+            self.datasets[split] = data.LanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+            )
         print(f"| {split} {len(self.datasets[split])} examples")
 
     def set_encoder_langs(self, encoder_langs):
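
Usage note (editorial, not part of the patch): with target_text_file now optional, a task can build a source-only dataset for pure generation. The sketch below assumes a PytorchTranslateTask-like object whose source and target dictionaries are already loaded; the split names and file paths are hypothetical placeholders, while the keyword arguments mirror the signature patched above.

    # Generation-only: build a split without a reference file.
    # dst_dataset stays None internally, so the dataset carries no targets.
    task.load_dataset_from_text(
        split="test",
        source_text_file="test.src.txt",  # hypothetical path
        target_text_file=None,            # now optional
        append_eos=False,
        reverse_source=True,
    )

    # Evaluation flow is unchanged: pass a reference file as before.
    task.load_dataset_from_text(
        split="valid",
        source_text_file="valid.src.txt",  # hypothetical path
        target_text_file="valid.tgt.txt",  # hypothetical path
    )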
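Design note: the if dst_dataset/else branches duplicate the source-side keyword lists. Equivalent behavior could be expressed once by assembling the arguments conditionally; this is a sketch only, not part of this change, and it assumes data.LanguagePairDataset accepts exactly the keywords used in the hunks above.

    # Source-side arguments shared by both branches.
    dataset_kwargs = dict(
        src=src_dataset,
        src_sizes=src_dataset.sizes,
        src_dict=self.source_dictionary,
        left_pad_source=False,
    )
    if dst_dataset is not None:
        # Attach target-side fields only when a target file was provided.
        dataset_kwargs.update(
            tgt=dst_dataset,
            tgt_sizes=dst_dataset.sizes,
            tgt_dict=self.target_dictionary,
        )
    self.datasets[split] = data.LanguagePairDataset(**dataset_kwargs)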