This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Make target-text-file optional in generate.py #559

Open · wants to merge 1 commit into base: master
pytorch_translate/generate.py (5 changes: 1 addition & 4 deletions)
@@ -530,7 +530,7 @@ def get_parser_with_args():
     )
     generation_group.add_argument(
         "--target-text-file",
-        default="",
+        default=None,
         metavar="FILE",
         help="Path to raw text file containing examples in target dialect. "
         "This overrides what would be loaded from the data dir.",
@@ -619,9 +619,6 @@ def validate_args(args):
         (src_file and os.path.isfile(src_file))
         for src_file in args.source_text_file
     ), "Please specify a valid file for --source-text-file"
-    assert args.target_text_file and os.path.isfile(
-        args.target_text_file
-    ), "Please specify a valid file for --target-text-file"


 def generate(args):
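Note on the removed check: validate_args now skips target-path validation entirely. If validation is still wanted whenever a target file is supplied, a guarded variant of the deleted assert would do it. This is a sketch only, not part of the diff; it assumes generate.py's existing args namespace and os import:

    # Sketch only -- not in this PR. Since --target-text-file now defaults to
    # None, only validate the path when the user actually passes one.
    if args.target_text_file is not None:
        assert os.path.isfile(
            args.target_text_file
        ), "Please specify a valid file for --target-text-file"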
pytorch_translate/tasks/pytorch_translate_task.py (180 changes: 106 additions & 74 deletions)
@@ -385,26 +385,29 @@ def load_dataset_from_text(
         self,
         split: str,
         source_text_file: str,
-        target_text_file: str,
+        target_text_file: Optional[str] = None,
         append_eos: Optional[bool] = False,
         reverse_source: Optional[bool] = True,
     ):
-        dst_dataset = data.IndexedRawTextDataset(
-            path=target_text_file,
-            dictionary=self.target_dictionary,
-            # We always append EOS to the target sentence since we still want
-            # the model to output an indication the sentence has finished, even
-            # if we don't append the EOS symbol to the source sentence
-            # (to prevent the model from misaligning UNKs or other words
-            # to the frequently occurring EOS).
-            append_eos=True,
-            # We don't reverse the order of the target sentence, since
-            # even if the source sentence is fed to the model backwards,
-            # we still want the model to start outputting from the first word.
-            reverse_order=False,
+        dst_dataset = (
+            data.IndexedRawTextDataset(
+                path=target_text_file,
+                dictionary=self.target_dictionary,
+                # We always append EOS to the target sentence since we still want
+                # the model to output an indication the sentence has finished, even
+                # if we don't append the EOS symbol to the source sentence
+                # (to prevent the model from misaligning UNKs or other words
+                # to the frequently occurring EOS).
+                append_eos=True,
+                # We don't reverse the order of the target sentence, since
+                # even if the source sentence is fed to the model backwards,
+                # we still want the model to start outputting from the first word.
+                reverse_order=False,
+            )
+            if target_text_file
+            else None
         )

-        if self.char_source_dict is not None:
+        if self.char_source_dict:
             src_dataset = char_data.InMemoryNumpyWordCharDataset()
             src_dataset.parse(
                 path=source_text_file,
@@ -413,38 +416,53 @@ def load_dataset_from_text(
                 reverse_order=reverse_source,
                 append_eos=append_eos,
             )
-            self.datasets[split] = char_data.LanguagePairSourceCharDataset(
-                src=src_dataset,
-                src_sizes=src_dataset.sizes,
-                src_dict=self.source_dictionary,
-                tgt=dst_dataset,
-                tgt_sizes=dst_dataset.sizes,
-                tgt_dict=self.target_dictionary,
-            )
+            if dst_dataset:
+                self.datasets[split] = char_data.LanguagePairSourceCharDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                    tgt=dst_dataset,
+                    tgt_sizes=dst_dataset.sizes,
+                    tgt_dict=self.target_dictionary,
+                )
+            else:
+                self.datasets[split] = char_data.LanguagePairSourceCharDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                )
         else:
             src_dataset = data.IndexedRawTextDataset(
                 path=source_text_file,
                 dictionary=self.source_dictionary,
                 append_eos=append_eos,
                 reverse_order=reverse_source,
             )
-            self.datasets[split] = data.LanguagePairDataset(
-                src=src_dataset,
-                src_sizes=src_dataset.sizes,
-                src_dict=self.source_dictionary,
-                tgt=dst_dataset,
-                tgt_sizes=dst_dataset.sizes,
-                tgt_dict=self.target_dictionary,
-                left_pad_source=False,
-            )
+            if dst_dataset:
+                self.datasets[split] = data.LanguagePairDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                    tgt=dst_dataset,
+                    tgt_sizes=dst_dataset.sizes,
+                    tgt_dict=self.target_dictionary,
+                    left_pad_source=False,
+                )
+            else:
+                self.datasets[split] = data.LanguagePairDataset(
+                    src=src_dataset,
+                    src_sizes=src_dataset.sizes,
+                    src_dict=self.source_dictionary,
+                    left_pad_source=False,
+                )

         print(f"| {split} {len(self.datasets[split])} examples")

     def load_multisource_dataset_from_text(
         self,
         split: str,
         source_text_files: List[str],
-        target_text_file: str,
+        target_text_file: Optional[str] = None,
         append_eos: Optional[bool] = False,
         reverse_source: Optional[bool] = True,
     ):
@@ -454,28 +472,35 @@ def load_multisource_dataset_from_text(
             append_eos=append_eos,
             reverse_order=reverse_source,
         )
-        tgt_dataset = data.IndexedRawTextDataset(
-            path=target_text_file,
-            dictionary=self.target_dictionary,
-            # We always append EOS to the target sentence since we still want
-            # the model to output an indication the sentence has finished, even
-            # if we don't append the EOS symbol to the source sentence
-            # (to prevent the model from misaligning UNKs or other words
-            # to the frequently occurring EOS).
-            append_eos=True,
-            # We don't reverse the order of the target sentence, since
-            # even if the source sentence is fed to the model backwards,
-            # we still want the model to start outputting from the first word.
-            reverse_order=False,
-        )
-        self.datasets[split] = multisource_data.MultisourceLanguagePairDataset(
-            src=src_dataset,
-            src_sizes=src_dataset.sizes,
-            src_dict=self.source_dictionary,
-            tgt=tgt_dataset,
-            tgt_sizes=tgt_dataset.sizes,
-            tgt_dict=self.target_dictionary,
-        )
+        if target_text_file:
+            tgt_dataset = data.IndexedRawTextDataset(
+                path=target_text_file,
+                dictionary=self.target_dictionary,
+                # We always append EOS to the target sentence since we still want
+                # the model to output an indication the sentence has finished, even
+                # if we don't append the EOS symbol to the source sentence
+                # (to prevent the model from misaligning UNKs or other words
+                # to the frequently occurring EOS).
+                append_eos=True,
+                # We don't reverse the order of the target sentence, since
+                # even if the source sentence is fed to the model backwards,
+                # we still want the model to start outputting from the first word.
+                reverse_order=False,
+            )
+            self.datasets[split] = multisource_data.MultisourceLanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+                tgt=tgt_dataset,
+                tgt_sizes=tgt_dataset.sizes,
+                tgt_dict=self.target_dictionary,
+            )
+        else:
+            self.datasets[split] = multisource_data.MultisourceLanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+            )

     @property
     def source_dictionary(self):
@@ -562,9 +587,9 @@ def load_dataset_from_text_multilingual(
         self,
         split: str,
         source_text_file: str,
-        target_text_file: str,
+        target_text_file: Optional[str],
         source_lang_id: int,
-        target_lang_id: int,
+        target_lang_id: Optional[int],
         append_eos: Optional[bool] = False,
         reverse_source: Optional[bool] = True,
     ):
@@ -576,22 +601,29 @@ def load_dataset_from_text_multilingual(
             reverse_order=reverse_source,
             prepend_language_id=False,
         )
-        tgt_dataset = pytorch_translate_data.IndexedRawTextDatasetWithLangId(
-            path=target_text_file,
-            dictionary=self.target_dictionary,
-            lang_id=target_lang_id,
-            append_eos=True,
-            reverse_order=False,
-            prepend_language_id=True,
-        )
-        self.datasets[split] = data.LanguagePairDataset(
-            src=src_dataset,
-            src_sizes=src_dataset.sizes,
-            src_dict=self.source_dictionary,
-            tgt=tgt_dataset,
-            tgt_sizes=tgt_dataset.sizes,
-            tgt_dict=self.target_dictionary,
-        )
+        if target_text_file:
+            tgt_dataset = pytorch_translate_data.IndexedRawTextDatasetWithLangId(
+                path=target_text_file,
+                dictionary=self.target_dictionary,
+                lang_id=target_lang_id,
+                append_eos=True,
+                reverse_order=False,
+                prepend_language_id=True,
+            )
+            self.datasets[split] = data.LanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+                tgt=tgt_dataset,
+                tgt_sizes=tgt_dataset.sizes,
+                tgt_dict=self.target_dictionary,
+            )
+        else:
+            self.datasets[split] = data.LanguagePairDataset(
+                src=src_dataset,
+                src_sizes=src_dataset.sizes,
+                src_dict=self.source_dictionary,
+            )
         print(f"| {split} {len(self.datasets[split])} examples")

     def set_encoder_langs(self, encoder_langs):
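Net effect for callers: each loader now registers a source-only dataset when no target file is given, instead of failing on a None path. A minimal usage sketch, assuming an already-constructed instance of the task class defined in pytorch_translate_task.py; the split name and file paths are made up for illustration:

    # Hypothetical call sites; `task` is an instance of the task class
    # defined in pytorch_translate/tasks/pytorch_translate_task.py.

    # With a reference target: builds a (src, tgt) LanguagePairDataset as before.
    task.load_dataset_from_text(
        split="test",
        source_text_file="/data/test.src",
        target_text_file="/data/test.tgt",
    )

    # Without one: dst_dataset is None, so a source-only LanguagePairDataset is
    # registered -- the case generation hits when no reference translations exist.
    task.load_dataset_from_text(
        split="test",
        source_text_file="/data/test.src",  # target_text_file defaults to None
    )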