diff --git a/pytorch_translate/constants.py b/pytorch_translate/constants.py
index 22472871..698f9093 100644
--- a/pytorch_translate/constants.py
+++ b/pytorch_translate/constants.py
@@ -10,3 +10,5 @@
 DENOISING_AUTOENCODER_TASK = "pytorch_translate_denoising_autoencoder"
 MULTILINGUAL_TRANSLATION_TASK = "pytorch_translate_multilingual_task"
 LATENT_VARIABLE_TASK = "translation_vae"
+
+MAX_LANGUAGES = 300
diff --git a/pytorch_translate/tasks/pytorch_translate_multi_task.py b/pytorch_translate/tasks/pytorch_translate_multi_task.py
index 7913ab93..8a99265b 100644
--- a/pytorch_translate/tasks/pytorch_translate_multi_task.py
+++ b/pytorch_translate/tasks/pytorch_translate_multi_task.py
@@ -4,6 +4,7 @@
 from fairseq.data import FairseqDataset, data_utils
 from fairseq.models import FairseqMultiModel
 from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
+from pytorch_translate import vocab_constants
 from pytorch_translate.data import iterators as ptt_iterators
 
 
@@ -67,3 +68,24 @@ def get_batch_iterator(
     def max_positions(self):
         """Return None to allow model to dictate max sentence length allowed"""
         return None
+
+    def get_encoder_langtok(self, src_lang, tgt_lang):
+        if self.args.encoder_langtok is not None:
+            if self.args.encoder_langtok == "src":
+                if src_lang in vocab_constants.DIALECT_CODES:
+                    return vocab_constants.DIALECT_CODES[src_lang]
+            else:
+                if tgt_lang in vocab_constants.DIALECT_CODES:
+                    return vocab_constants.DIALECT_CODES[tgt_lang]
+        # if encoder_langtok is None, or the requested src_lang / tgt_lang is
+        # not in vocab_constants.DIALECT_CODES, fall back to the EOS token
+        return self.dicts[src_lang].eos()
+
+    def get_decoder_langtok(self, tgt_lang):
+        if (
+            not self.args.decoder_langtok
+            or tgt_lang not in vocab_constants.DIALECT_CODES
+        ):
+            return self.dicts[tgt_lang].eos()
+        else:
+            return vocab_constants.DIALECT_CODES[tgt_lang]
diff --git a/pytorch_translate/vocab_constants.py b/pytorch_translate/vocab_constants.py
index 8c279dda..78550d6b 100644
--- a/pytorch_translate/vocab_constants.py
+++ b/pytorch_translate/vocab_constants.py
@@ -14,3 +14,76 @@
 EOS_ID = 2
 UNK_ID = 3
 MASK_ID = 4
+
+DIALECT_CODES = {
+    "af_ZA": 8,
+    "am_ET": 9,
+    "ar_AR": 10,
+    "be_BY": 11,
+    "bg_BG": 12,
+    "bn_IN": 13,
+    "br_FR": 14,
+    "bs_BA": 15,
+    "ca_ES": 16,
+    "cs_CZ": 17,
+    "cy_GB": 18,
+    "da_DK": 19,
+    "de_DE": 20,
+    "el_GR": 21,
+    "en_XX": 22,
+    "es_XX": 23,
+    "et_EE": 24,
+    "eu_ES": 25,
+    "fa_IR": 26,
+    "fi_FI": 27,
+    "fr_XX": 28,
+    "gu_IN": 29,
+    "ha_NG": 30,
+    "he_IL": 31,
+    "hi_IN": 32,
+    "hr_HR": 33,
+    "hu_HU": 34,
+    "id_ID": 35,
+    "it_IT": 36,
+    "ja_XX": 37,
+    "km_KH": 38,
+    "kn_IN": 39,
+    "ko_KR": 40,
+    "lt_LT": 41,
+    "lv_LV": 42,
+    "mk_MK": 43,
+    "ml_IN": 44,
+    "mn_MN": 45,
+    "mr_IN": 46,
+    "ms_MY": 47,
+    "my_MM": 48,
+    "ne_NP": 49,
+    "nl_XX": 50,
+    "no_XX": 51,
+    "pa_IN": 52,
+    "pl_PL": 53,
+    "ps_AF": 54,
+    "pt_XX": 55,
+    "ro_RO": 56,
+    "ru_RU": 57,
+    "si_LK": 58,
+    "sk_SK": 59,
+    "sl_SI": 60,
+    "so_SO": 61,
+    "sq_AL": 62,
+    "sr_RS": 63,
+    "sv_SE": 64,
+    "sw_KE": 65,
+    "ta_IN": 66,
+    "te_IN": 67,
+    "th_TH": 68,
+    "tl_XX": 69,
+    "tr_TR": 70,
+    "uk_UA": 71,
+    "ur_PK": 72,
+    "vi_VN": 73,
+    "xh_ZA": 74,
+    "zh_CN": 75,
+    "zh_TW": 76,
+    "zu_ZA": 77,
+}
diff --git a/pytorch_translate/vocab_reduction.py b/pytorch_translate/vocab_reduction.py
index 3f279783..7ede9afa 100644
--- a/pytorch_translate/vocab_reduction.py
+++ b/pytorch_translate/vocab_reduction.py
@@ -7,6 +7,8 @@
 import torch
 import torch.nn as nn
 
+from pytorch_translate import constants
+
 
 logger = logging.getLogger(__name__)
 
@@ -71,6 +73,7 @@ def set_arg_defaults(args):
 
 def select_top_candidate_per_word(
     source_index,
+    target_dialect,
     target_indices_with_prob,
     counter_per_word,
     max_translation_candidates_per_word,
@@ -80,7 +83,10 @@ def select_top_candidate_per_word(
     translation_candidates_saved = 0
     target_indices_with_prob.sort(key=lambda x: x[1], reverse=True)
     for target_index_with_prob in target_indices_with_prob:
-        if counter_per_word[source_index] >= max_translation_candidates_per_word:
+        if (
+            counter_per_word[target_dialect][source_index]
+            >= max_translation_candidates_per_word
+        ):
             # don't save more than max_translation_candidates_per_word
             # translation candidates for any one source token
             break
@@ -88,14 +94,17 @@ def select_top_candidate_per_word(
         # update translation candidates matrix at [source index, running counter
         # per source token] to = target index
         translation_candidates[
            target_dialect, source_index, counter_per_word[target_dialect][source_index]
         ] = target_index_with_prob[0]
-        translation_candidates_set.update((source_index, target_index_with_prob[0]))
-        counter_per_word[source_index] += 1
+        translation_candidates_set.add(
+            (target_dialect, source_index, target_index_with_prob[0])
+        )
+        counter_per_word[target_dialect][source_index] += 1
         translation_candidates_saved += 1
     return translation_candidates_saved
 
 
+# Dummy Change
 def get_translation_candidates(
     src_dict,
     dst_dict,
@@ -104,33 +113,37 @@ def get_translation_candidates(
     max_translation_candidates_per_word,
 ):
     """
-    Reads a lexical dictionary file, where each line is (source token, possible
-    translation of source token, probability). The file is generally grouped
-    by source tokens, but within the group, the probabilities are not
-    necessarily sorted.
-
-    A a 0.3
-    A c 0.1
-    A e 0.05
-    A f 0.01
-    B b 0.6
-    B b 0.2
-    A z 0.001
-    A y 0.002
+    Reads a lexical dictionary file, where each line is (target dialect ID, source
+    token, possible translation of source token, probability). The file is
+    generally grouped by target dialect and source tokens, but within the group,
+    the probabilities are not necessarily sorted.
+
+    0 A a 0.3
+    0 A c 0.1
+    0 A e 0.05
+    25 A f 0.01
+    25 B b 0.6
+    25 B b 0.2
+    25 A z 0.001
+    25 A y 0.002
     ...
 
     Returns:
         translation_candidates
-            Matrix of shape (src_dict, max_translation_candidates_per_word) where
-            each row corresponds to a source word in the vocab and contains token
-            indices of translation candidates for that source word
+            Matrix of shape (constants.MAX_LANGUAGES, src_dict,
+            max_translation_candidates_per_word) where each
+            [target dialect, source word] row contains token indices of
+            translation candidates for that source word in that dialect
     """
     translation_candidates = np.zeros(
-        [len(src_dict), max_translation_candidates_per_word], dtype=np.int32
+        [constants.MAX_LANGUAGES, len(src_dict), max_translation_candidates_per_word],
+        dtype=np.int32,
    )
 
-    # running count of translation candidates per source word
-    counter_per_word = np.zeros(len(src_dict), dtype=np.int32)
+    # running count of translation candidates per source word per target dialect
+    counter_per_word = np.zeros(
+        [constants.MAX_LANGUAGES, len(src_dict)], dtype=np.int32
+    )
 
     # tracks if we've already seen some (source token, target token) pair so we
     # ignore duplicate lines
@@ -142,13 +155,14 @@ def get_translation_candidates(
 
     with codecs.open(lexical_dictionary, "r", "utf-8") as lexical_dictionary_file:
         current_source_index = None
+        current_target_dialect = None
         current_target_indices = []
         for line in lexical_dictionary_file.readlines():
             alignment_data = line.split()
-            if len(alignment_data) != 3:
+            if len(alignment_data) != 4:
                 logger.warning(f"Malformed line in lexical dictionary: {line}")
                 continue
-            source_word, target_word, prob = alignment_data
-            prob = float(prob)
+            target_dialect, source_word, target_word, prob = alignment_data
+            target_dialect, prob = int(target_dialect), float(prob)
             source_index = src_dict.index(source_word)
             target_index = dst_dict.index(target_word)
@@ -159,24 +173,29 @@ def get_translation_candidates(
 
             if source_index is not None and target_index is not None:
-                if source_index != current_source_index:
+                if (
+                    source_index != current_source_index
+                    or target_dialect != current_target_dialect
+                ):
                     # We've finished processing the possible translation
                     # candidates for this source token group, so save the
                     # extracted translation candidates
                     translation_candidates_saved += select_top_candidate_per_word(
                         current_source_index,
+                        current_target_dialect,
                         current_target_indices,
                         counter_per_word,
                         max_translation_candidates_per_word,
                         translation_candidates,
                         translation_candidates_set,
                     )
+                    current_target_dialect = target_dialect
                     current_source_index = source_index
                     current_target_indices = []
 
                 if (
                     target_index >= num_top_words
-                    and (source_index, target_index)
+                    and (target_dialect, source_index, target_index)
                     not in translation_candidates_set
                 ):
                     current_target_indices.append((target_index, prob))
 
@@ -184,6 +203,7 @@ def get_translation_candidates(
         # group
         translation_candidates_saved += select_top_candidate_per_word(
             current_source_index,
+            current_target_dialect,
             current_target_indices,
             counter_per_word,
             max_translation_candidates_per_word,
@@ -230,7 +250,13 @@ def __init__(
         )
 
     # encoder_output is default None for backwards compatibility
-    def forward(self, src_tokens, encoder_output=None, decoder_input_tokens=None):
+    def forward(
+        self,
+        src_tokens,
+        encoder_output=None,
+        decoder_input_tokens=None,
+        target_dialect=0,
+    ):
         assert self.dst_dict.pad() == 0, (
             f"VocabReduction only works correctly when the padding ID is 0 "
             "(to ensure its position in possible_translation_tokens is also 0), "
@@ -243,7 +269,10 @@ def forward(self, src_tokens, encoder_output=None, decoder_input_tokens=None):
         vocab_list.append(flat_decoder_input_tokens)
 
         if self.translation_candidates is not None:
-            reduced_vocab = self.translation_candidates.index_select(
+            candidates_in_target_dialect = self.translation_candidates.index_select(
+                dim=0, index=target_dialect
+            ).squeeze(0)
+            reduced_vocab = candidates_in_target_dialect.index_select(
                 dim=0, index=src_tokens.view(-1)
             ).view(-1)
             vocab_list.append(reduced_vocab)
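
Note on the new lexical dictionary format: get_translation_candidates now expects four whitespace-separated columns per line (target dialect ID, source token, target token, probability) and fills a per-dialect candidate matrix of shape (constants.MAX_LANGUAGES, len(src_dict), max_translation_candidates_per_word). The sketch below only illustrates that indexing; the vocabularies, sizes, and input lines are made up, and it skips the probability sorting, num_top_words filter, and duplicate tracking that the real function performs.

```python
import numpy as np

# Toy illustration of the 4-column lexical dictionary format:
#   <target dialect id> <source token> <target token> <probability>
LINES = """\
0 A a 0.3
0 A c 0.1
25 A f 0.01
25 B b 0.6"""

# Hypothetical stand-ins for src_dict.index() / dst_dict.index(); the real
# code uses fairseq Dictionary objects.
src_vocab = {"A": 5, "B": 6}
dst_vocab = {"a": 5, "b": 6, "c": 7, "f": 8}

MAX_LANGUAGES = 300       # mirrors constants.MAX_LANGUAGES
VOCAB_SIZE = 10           # mirrors len(src_dict)
CANDIDATES_PER_WORD = 2   # mirrors max_translation_candidates_per_word

candidates = np.zeros((MAX_LANGUAGES, VOCAB_SIZE, CANDIDATES_PER_WORD), dtype=np.int32)
counter = np.zeros((MAX_LANGUAGES, VOCAB_SIZE), dtype=np.int32)

for line in LINES.splitlines():
    dialect, src_word, tgt_word, prob = line.split()
    dialect, prob = int(dialect), float(prob)
    s, t = src_vocab[src_word], dst_vocab[tgt_word]
    # cap the number of saved candidates per (dialect, source word) pair
    if counter[dialect, s] < CANDIDATES_PER_WORD:
        candidates[dialect, s, counter[dialect, s]] = t
        counter[dialect, s] += 1

# Candidates for source word "A" now depend on the target dialect.
print(candidates[0, src_vocab["A"]])   # -> [5 7]  (from the dialect-0 lines)
print(candidates[25, src_vocab["A"]])  # -> [8 0]  (from the dialect-25 lines)
```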
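Likewise, the per-dialect lookup that VocabReduction.forward now performs can be pictured as below. This is a minimal, self-contained sketch rather than the module itself: the candidate matrix is random, the sizes are toy values, and target_dialect is assumed to be a one-element LongTensor (e.g. 22 for "en_XX" in vocab_constants.DIALECT_CODES).

```python
import torch

# Illustrative sizes only: they stand in for constants.MAX_LANGUAGES,
# len(src_dict), and max_translation_candidates_per_word.
MAX_LANGUAGES, VOCAB_SIZE, CANDIDATES = 300, 10, 4

# 3-D candidate matrix, indexed as [target dialect, source token, candidate slot].
translation_candidates = torch.randint(
    0, 100, (MAX_LANGUAGES, VOCAB_SIZE, CANDIDATES), dtype=torch.long
)

src_tokens = torch.tensor([[5, 2, 7]])  # a batch with one source sentence
target_dialect = torch.tensor([22])     # e.g. DIALECT_CODES["en_XX"]

# Pick the (vocab, candidates) slice for this dialect, then gather the
# candidate rows for every source token in the batch and flatten them.
per_dialect = translation_candidates.index_select(0, target_dialect).squeeze(0)
reduced_vocab = per_dialect.index_select(0, src_tokens.view(-1)).view(-1)

# Deduplicating these ids (together with the special tokens and, at training
# time, the target-side tokens) yields the reduced output vocabulary.
possible_translation_tokens = torch.unique(reduced_vocab)
print(possible_translation_tokens)
```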