This repository was archived by the owner on Aug 1, 2023. It is now read-only.

Add source/target dialect token mapping #577

Open
wants to merge 2 commits into base: master
Changes from all commits
2 changes: 2 additions & 0 deletions pytorch_translate/constants.py
@@ -10,3 +10,5 @@
DENOISING_AUTOENCODER_TASK = "pytorch_translate_denoising_autoencoder"
MULTILINGUAL_TRANSLATION_TASK = "pytorch_translate_multilingual_task"
LATENT_VARIABLE_TASK = "translation_vae"

MAX_LANGUAGES = 300
22 changes: 22 additions & 0 deletions pytorch_translate/tasks/pytorch_translate_multi_task.py
@@ -4,6 +4,7 @@
from fairseq.data import FairseqDataset, data_utils
from fairseq.models import FairseqMultiModel
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from pytorch_translate import vocab_constants
from pytorch_translate.data import iterators as ptt_iterators


@@ -67,3 +68,24 @@ def get_batch_iterator(
def max_positions(self):
"""Return None to allow model to dictate max sentence length allowed"""
return None

def get_encoder_langtok(self, src_lang, tgt_lang):
if self.args.encoder_langtok is not None:
if self.args.encoder_langtok == "src":
if src_lang in vocab_constants.DIALECT_CODES:
return vocab_constants.DIALECT_CODES[src_lang]
else:
if tgt_lang in vocab_constants.DIALECT_CODES:
return vocab_constants.DIALECT_CODES[tgt_lang]
# fall back to EOS if encoder_langtok is None, or if the relevant language
# is not in vocab_constants.DIALECT_CODES
return self.dicts[src_lang].eos()

def get_decoder_langtok(self, tgt_lang):
if (
not self.args.decoder_langtok
or tgt_lang not in vocab_constants.DIALECT_CODES
):
return self.dicts[tgt_lang].eos()
else:
return vocab_constants.DIALECT_CODES[tgt_lang]
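For context, a minimal sketch of how the two helpers above resolve the token to prepend. The args values, language pair, and the EOS stand-in below are hypothetical, and the real methods read self.dicts[lang].eos() rather than the EOS_ID constant:

from argparse import Namespace

from pytorch_translate import vocab_constants

args = Namespace(encoder_langtok="src", decoder_langtok=True)  # hypothetical config
src_lang, tgt_lang = "de_DE", "es_XX"                          # hypothetical pair
eos_id = vocab_constants.EOS_ID  # stand-in for self.dicts[lang].eos()

# encoder side: with encoder_langtok == "src", the source dialect token is used
encoder_langtok = vocab_constants.DIALECT_CODES.get(src_lang, eos_id)  # 20
# decoder side: with decoder_langtok set, the target dialect token is used
decoder_langtok = vocab_constants.DIALECT_CODES.get(tgt_lang, eos_id)  # 23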
73 changes: 73 additions & 0 deletions pytorch_translate/vocab_constants.py
@@ -14,3 +14,76 @@
EOS_ID = 2
UNK_ID = 3
MASK_ID = 4

DIALECT_CODES = {
"af_ZA": 8,
"am_ET": 9,
"ar_AR": 10,
"be_BY": 11,
"bg_BG": 12,
"bn_IN": 13,
"br_FR": 14,
"bs_BA": 15,
"ca_ES": 16,
"cs_CZ": 17,
"cy_GB": 18,
"da_DK": 19,
"de_DE": 20,
"el_GR": 21,
"en_XX": 22,
"es_XX": 23,
"et_EE": 24,
"eu_ES": 25,
"fa_IR": 26,
"fi_FI": 27,
"fr_XX": 28,
"gu_IN": 29,
"ha_NG": 30,
"he_IL": 31,
"hi_IN": 32,
"hr_HR": 33,
"hu_HU": 34,
"id_ID": 35,
"it_IT": 36,
"ja_XX": 37,
"km_KH": 38,
"kn_IN": 39,
"ko_KR": 40,
"lt_LT": 41,
"lv_LV": 42,
"mk_MK": 43,
"ml_IN": 44,
"mn_MN": 45,
"mr_IN": 46,
"ms_MY": 47,
"my_MM": 48,
"ne_NP": 49,
"nl_XX": 50,
"no_XX": 51,
"pa_IN": 52,
"pl_PL": 53,
"ps_AF": 54,
"pt_XX": 55,
"ro_RO": 56,
"ru_RU": 57,
"si_LK": 57,
"sk_SK": 58,
"sl_SI": 59,
"so_SO": 60,
"sq_AL": 61,
"sr_RS": 62,
"sv_SE": 63,
"sw_KE": 64,
"ta_IN": 65,
"te_IN": 66,
"th_TH": 67,
"tl_XX": 68,
"tr_TR": 69,
"uk_UA": 70,
"ur_PK": 71,
"vi_VN": 72,
"xh_ZA": 73,
"zh_CN": 74,
"zh_TW": 75,
"zu_ZA": 76,
}
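Because these IDs index the dialect dimension (sized constants.MAX_LANGUAGES) of the tensors introduced in vocab_reduction.py below, a small sanity check along these lines (illustrative only, not part of the diff) would catch duplicate or out-of-range codes:

from pytorch_translate import constants, vocab_constants

codes = list(vocab_constants.DIALECT_CODES.values())
# every dialect should get a distinct ID
assert len(codes) == len(set(codes))
# and every ID should fit inside the MAX_LANGUAGES-sized dialect dimension
assert all(0 <= code < constants.MAX_LANGUAGES for code in codes)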
87 changes: 58 additions & 29 deletions pytorch_translate/vocab_reduction.py
@@ -7,6 +7,8 @@
import torch
import torch.nn as nn

from pytorch_translate import constants


logger = logging.getLogger(__name__)

@@ -71,6 +73,7 @@ def set_arg_defaults(args):

def select_top_candidate_per_word(
source_index,
target_dialect,
target_indices_with_prob,
counter_per_word,
max_translation_candidates_per_word,
@@ -80,22 +83,28 @@
translation_candidates_saved = 0
target_indices_with_prob.sort(key=lambda x: x[1], reverse=True)
for target_index_with_prob in target_indices_with_prob:
- if counter_per_word[source_index] >= max_translation_candidates_per_word:
+ if (
+ counter_per_word[target_dialect][source_index]
+ >= max_translation_candidates_per_word
+ ):
# don't save more than max_translation_candidates_per_word
# translation candidates for any one source token
break

# update translation candidates matrix at [source index, running counter
# per source token] to = target index
translation_candidates[
- source_index, counter_per_word[source_index]
+ target_dialect, source_index, counter_per_word[target_dialect][source_index]
] = target_index_with_prob[0]
- translation_candidates_set.update((source_index, target_index_with_prob[0]))
- counter_per_word[source_index] += 1
+ translation_candidates_set.add(
+ (target_dialect, source_index, target_index_with_prob[0])
+ )
+ counter_per_word[target_dialect][source_index] += 1
translation_candidates_saved += 1
return translation_candidates_saved


# Dummy Change
def get_translation_candidates(
src_dict,
dst_dict,
@@ -104,33 +113,37 @@ def get_translation_candidates(
max_translation_candidates_per_word,
):
"""
- Reads a lexical dictionary file, where each line is (source token, possible
- translation of source token, probability). The file is generally grouped
- by source tokens, but within the group, the probabilities are not
- necessarily sorted.
-
- A a 0.3
- A c 0.1
- A e 0.05
- A f 0.01
- B b 0.6
- B b 0.2
- A z 0.001
- A y 0.002
+ Reads a lexical dictionary file, where each line is (target dialect ID, source
+ token, possible translation of source token, probability). The file is
+ generally grouped by target dialect and source tokens, but within the group,
+ the probabilities are not necessarily sorted.
+
+ 0 A a 0.3
+ 0 A c 0.1
+ 0 A e 0.05
+ 25 A f 0.01
+ 25 B b 0.6
+ 25 B b 0.2
+ 25 A z 0.001
+ 25 A y 0.002
...

Returns: translation_candidates
- Matrix of shape (src_dict, max_translation_candidates_per_word) where
- each row corresponds to a source word in the vocab and contains token
- indices of translation candidates for that source word
+ Matrix of shape (constants.MAX_LANGUAGES, src_dict,
+ max_translation_candidates_per_word) where each row corresponds to a
+ source word in the vocab and contains token indices of translation
+ candidates for that source word
"""

translation_candidates = np.zeros(
- [len(src_dict), max_translation_candidates_per_word], dtype=np.int32
+ [constants.MAX_LANGUAGES, len(src_dict), max_translation_candidates_per_word],
+ dtype=np.int32,
)

- # running count of translation candidates per source word
- counter_per_word = np.zeros(len(src_dict), dtype=np.int32)
+ # running count of translation candidates per source word per target dialect
+ counter_per_word = np.zeros(
+ [constants.MAX_LANGUAGES, len(src_dict)], dtype=np.int32
+ )

# tracks if we've already seen some (source token, target token) pair so we
# ignore duplicate lines
@@ -142,13 +155,14 @@ def get_translation_candidates(

with codecs.open(lexical_dictionary, "r", "utf-8") as lexical_dictionary_file:
current_source_index = None
current_target_dialect = None
current_target_indices = []
for line in lexical_dictionary_file.readlines():
alignment_data = line.split()
- if len(alignment_data) != 3:
+ if len(alignment_data) != 4:
logger.warning(f"Malformed line in lexical dictionary: {line}")
continue
- source_word, target_word, prob = alignment_data
+ target_dialect, source_word, target_word, prob = alignment_data
prob = float(prob)
source_index = src_dict.index(source_word)
target_index = dst_dict.index(target_word)
@@ -159,31 +173,37 @@ def get_translation_candidates(
continue

if source_index is not None and target_index is not None:
- if source_index != current_source_index:
+ if (
+ source_index != current_source_index
+ or target_dialect != current_target_dialect
+ ):
# We've finished processing the possible translation
# candidates for this source token group, so save the
# extracted translation candidates
translation_candidates_saved += select_top_candidate_per_word(
current_source_index,
current_target_dialect,
current_target_indices,
counter_per_word,
max_translation_candidates_per_word,
translation_candidates,
translation_candidates_set,
)
current_target_dialect = target_dialect
current_source_index = source_index
current_target_indices = []

if (
target_index >= num_top_words
- and (source_index, target_index)
+ and (target_dialect, source_index, target_index)
not in translation_candidates_set
):
current_target_indices.append((target_index, prob))
# Save the extracted translation candidates for the last source token
# group
translation_candidates_saved += select_top_candidate_per_word(
current_source_index,
current_target_dialect,
current_target_indices,
counter_per_word,
max_translation_candidates_per_word,
@@ -230,7 +250,13 @@ def __init__(
)

# encoder_output is default None for backwards compatibility
- def forward(self, src_tokens, encoder_output=None, decoder_input_tokens=None):
+ def forward(
+ self,
+ src_tokens,
+ encoder_output=None,
+ decoder_input_tokens=None,
+ target_dialect=0,
+ ):
assert self.dst_dict.pad() == 0, (
f"VocabReduction only works correctly when the padding ID is 0 "
"(to ensure its position in possible_translation_tokens is also 0), "
@@ -243,7 +269,10 @@ def forward(self, src_tokens, encoder_output=None, decoder_input_tokens=None):
vocab_list.append(flat_decoder_input_tokens)

if self.translation_candidates is not None:
- reduced_vocab = self.translation_candidates.index_select(
+ candidates_in_target_dialect = self.translation_candidates.index_select(
+ dim=0, index=target_dialect
+ )
+ reduced_vocab = candidates_in_target_dialect.index_select(
dim=0, index=src_tokens.view(-1)
).view(-1)
vocab_list.append(reduced_vocab)
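To make the new per-dialect indexing concrete, here is a small self-contained sketch (illustrative only; the sizes, token IDs, and the squeeze-based lookup are assumptions, not code from this PR) of how a (MAX_LANGUAGES, vocab, candidates) table is filled from one 4-column lexical dictionary line and then reduced for a batch:

import numpy as np
import torch

from pytorch_translate import constants

VOCAB_SIZE, MAX_CANDS = 100, 2  # made-up sizes
table = np.zeros([constants.MAX_LANGUAGES, VOCAB_SIZE, MAX_CANDS], dtype=np.int32)

# one lexical dictionary line: target dialect ID, source token, target token, prob;
# the dialect column has to be parsed as an integer before it can index the table
target_dialect, source_index, target_index, prob = "25", 7, 42, 0.6
table[int(target_dialect), source_index, 0] = target_index

# pick the candidate table for the batch's target dialect, then gather the
# candidates of every source token in the batch (one way to express the lookup)
candidates = torch.from_numpy(table)
src_tokens = torch.tensor([[7, 7, 3]])
dialect = torch.tensor([25])  # index_select expects a 1-D LongTensor
per_dialect = candidates.index_select(dim=0, index=dialect).squeeze(0)  # vocab x cands
reduced_vocab = per_dialect.index_select(dim=0, index=src_tokens.view(-1)).view(-1)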