
Commit b98fb3d
Implementation of the paper "Jointly Learning to Align & Translate with Transformer" (#1615)
Zenglinxiao authored and francoishernandez committed Nov 22, 2019
1 parent 17feb20 commit b98fb3d
Showing 26 changed files with 746 additions and 131 deletions.
64 changes: 50 additions & 14 deletions docs/source/FAQ.md
@@ -8,30 +8,30 @@ the script is a slightly modified version of ylhsieh's one.

Usage:

```shell
embeddings_to_torch.py [-h] [-emb_file_both EMB_FILE_BOTH]
[-emb_file_enc EMB_FILE_ENC]
[-emb_file_dec EMB_FILE_DEC] -output_file
OUTPUT_FILE -dict_file DICT_FILE [-verbose]
[-skip_lines SKIP_LINES]
[-type {GloVe,word2vec}]
```

Run `embeddings_to_torch.py -h` for more complete info.

Example


1) get GloVe files:

```shell
mkdir "glove_dir"
wget http://nlp.stanford.edu/data/glove.6B.zip
unzip glove.6B.zip -d "glove_dir"
```

2) prepare data:

```shell
onmt_preprocess \
-train_src data/train.src.txt \
-train_tgt data/train.tgt.txt \
@@ -42,15 +42,15 @@ onmt_preprocess \

3) prepare embeddings:

```shell
./tools/embeddings_to_torch.py -emb_file_both "glove_dir/glove.6B.100d.txt" \
-dict_file "data/data.vocab.pt" \
-output_file "data/embeddings"
```

4) train using pre-trained embeddings:

```shell
onmt_train -save_model data/model \
-batch_size 64 \
-layers 2 \
@@ -61,14 +61,13 @@
-data data/data
```


## How do I use the Transformer model?

The transformer model is very sensitive to hyperparameters. To run it
effectively you need to set a bunch of different options that mimic the Google
setup. We have confirmed the following command can replicate their WMT results.

```shell
python train.py -data /tmp/de2/data -save_model /tmp/extra \
-layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \
-encoder_type transformer -decoder_type transformer -position_encoding \
@@ -77,17 +76,16 @@ python train.py -data /tmp/de2/data -save_model /tmp/extra \
-optim adam -adam_beta2 0.998 -decay_method noam -warmup_steps 8000 -learning_rate 2 \
-max_grad_norm 0 -param_init 0 -param_init_glorot \
-label_smoothing 0.1 -valid_steps 10000 -save_checkpoint_steps 10000 \
        -world_size 4 -gpu_ranks 0 1 2 3
```

Here is what each of the parameters means:

* `param_init_glorot` `-param_init 0`: correct initialization of parameters
* `position_encoding`: add sinusoidal position encoding to each embedding
* `optim adam`, `decay_method noam`, `warmup_steps 8000`: use a special learning rate schedule.
* `batch_type tokens`, `normalization tokens`, `accum_count 4`: batch and normalize based on number of tokens and not sentences. Compute gradients based on four batches.
* `label_smoothing 0.1`: use label smoothing loss.

## Do you support multi-gpu?

@@ -98,6 +96,7 @@ If you want to use GPU id 1 and 3 of your OS, you will need to `export CUDA_VISIBLE_DEVICES=1,3`
Both `-world_size` and `-gpu_ranks` need to be set. E.g. `-world_size 4 -gpu_ranks 0 1 2 3` will use 4 GPUs on this node only.

If you want to use 2 nodes with 2 GPUs each, you need to set `-master_ip` and `-master_port`, and use the following settings (a combined sketch is given after this list):

* `-world_size 4 -gpu_ranks 0 1`: on the first node
* `-world_size 4 -gpu_ranks 2 3`: on the second node
* `-accum_count 2`: This will accumulate over 2 batches before updating parameters.
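
For example, a 2-node launch could be sketched as below; the master address is a placeholder, the data and model paths reuse the Transformer example above, and all other training options must be identical on both nodes:

```shell
# Node 1 (also hosts the rendezvous master; IP/port are placeholders):
python train.py -data /tmp/de2/data -save_model /tmp/extra \
        -world_size 4 -gpu_ranks 0 1 \
        -master_ip 10.0.0.1 -master_port 10000 -accum_count 2

# Node 2 (same command, remaining ranks):
python train.py -data /tmp/de2/data -save_model /tmp/extra \
        -world_size 4 -gpu_ranks 2 3 \
        -master_ip 10.0.0.1 -master_port 10000 -accum_count 2
```
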
@@ -122,27 +121,64 @@ Bear in mind that your models must share the same target vocabulary.
We introduced `-train_ids`, which is a list of IDs that will be given to the preprocessed shards.

E.g. if we have two corpora, `parallel.en`/`parallel.de` and `from_backtranslation.en`/`from_backtranslation.de`, we can pass the following in the `preprocess.py` command:

```shell
...
-train_src parallel.en from_backtranslation.en \
-train_tgt parallel.de from_backtranslation.de \
-train_ids A B \
-save_data my_data \
...
```

and it will dump `my_data.train_A.X.pt` based on `parallel.en`//`parallel.de` and `my_data.train_B.X.pt` based on `from_backtranslation.en`//`from_backtranslation.de`.

### Training

We introduced `-data_ids` based on the same principle as above, as well as `-data_weights`, which is the list of weights each corpus should have.
E.g.

```shell
...
-data my_data \
-data_ids A B \
-data_weights 1 7 \
...
```

will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, and that when building batches, we'll take 1 example from corpus A, then 7 examples from corpus B, and so on.

**Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing.

## Can I get word alignment while translating?

### Raw alignments from averaging Transformer attention heads

Currently, we support producing word alignments while translating for Transformer-based models. Using `-report_align` when calling `translate.py` will output the inferred alignments in Pharaoh format. Those alignments are computed from an argmax on the average of the attention heads of the *second to last* decoder layer. The resulting src-tgt alignment (Pharaoh) is appended to the translated sentence, separated by ` ||| `. An example call is sketched after the list below.
Note: the *second to last* default behaviour was determined empirically. It is not the same as in the paper (they take the *penultimate* layer), probably because of slight differences in the architecture.

* alignments use the standard "Pharaoh format", where a pair `i-j` indicates that the i<sub>th</sub> word of the source language is aligned to the j<sub>th</sub> word of the target language.
* Example: {'src': 'das stimmt nicht !'; 'output': 'that is not true ! ||| 0-0 0-1 1-2 2-3 1-4 1-5 3-6'}
* Using the `-tgt` option when calling `translate.py`, we output alignments between the source and the gold target rather than the inferred target, assuming we are doing evaluation.
* To convert subword alignments to word alignments, or to symmetrize bidirectional alignments, please refer to the [lilt scripts](https://github.com/lilt/alignment-scripts).
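
For instance, a minimal sketch of such a call (model and file paths are placeholders) could be:

```shell
# Hypothetical paths; -report_align appends the Pharaoh alignments to each
# output line, separated by " ||| ".
python translate.py -model wmt_ende_transformer.pt \
        -src data/test.src.txt \
        -output pred.txt \
        -report_align \
        -gpu 0
```

Each line of `pred.txt` then contains the translated sentence followed by ` ||| ` and the `i-j` pairs described above.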

### Supervised learning on a specific head

The quality of output alignments can be further improved by providing reference alignments while training. This will invoke multi-task learning on translation and alignment. This is an implementation based on the paper [Jointly Learning to Align and Translate with Transformer Models](https://arxiv.org/abs/1909.02074).

The data need to be preprocessed with the reference alignments in order to learn the supervised task.

When calling `preprocess.py`, add:

* `--train_align <path>`: path(s) to the training alignments in Pharaoh format
* `--valid_align <path>`: path to the validation set alignments in Pharaoh format (optional).
The reference alignment file(s) can be generated by [GIZA++](https://github.com/moses-smt/mgiza/) or [fast_align](https://github.com/clab/fast_align).

Note: There should be no blank lines in the alignment files provided.

Options to learn such alignments are (a combined example is given after the list):

* `-lambda_align`: set a value > 0.0 to enable joint alignment training; the paper suggests 0.05;
* `-alignment_layer`: index of the decoder layer to use for the alignment loss;
* `-alignment_heads`: number of attention heads used for the alignment task; set it to 1 for the supervised task, and preferably keep the default (or the same as `num_heads`) for the average task;
* `-full_context_alignment`: perform a full-context decoder pass (no future mask) when computing alignments. This slows down training (~12% in terms of tok/s) but yields better alignments.
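
Putting these options together, here is a sketch of the preprocessing and training calls; the file paths and the non-alignment hyperparameters are placeholders to adapt to your own setup:

```shell
# Preprocess with reference alignments (Pharaoh format, no blank lines):
python preprocess.py \
        -train_src data/train.src.txt -train_tgt data/train.tgt.txt \
        --train_align data/train.align \
        -valid_src data/valid.src.txt -valid_tgt data/valid.tgt.txt \
        --valid_align data/valid.align \
        -save_data data/demo

# Train a Transformer with the joint translation + alignment objective.
# -lambda_align 0.05 follows the paper's suggestion, -alignment_heads 1
# supervises a single head, and -full_context_alignment trades ~12% tok/s
# for better alignments; -alignment_layer can also be set to pick the
# supervised decoder layer. The remaining hyperparameters are illustrative.
python train.py -data data/demo -save_model data/model_align \
        -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \
        -encoder_type transformer -decoder_type transformer -position_encoding \
        -lambda_align 0.05 -alignment_heads 1 -full_context_alignment \
        -world_size 1 -gpu_ranks 0
```
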
10 changes: 10 additions & 0 deletions docs/source/refs.bib
@@ -435,3 +435,13 @@ @article{DBLP:journals/corr/MartinsA16
biburl = {https://dblp.org/rec/bib/journals/corr/MartinsA16},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
address = {Hong Kong},
month = {November},
url = {https://arxiv.org/abs/1909.02074},
year = {2019},
}
48 changes: 29 additions & 19 deletions onmt/bin/preprocess.py
@@ -44,21 +44,22 @@ def check_existing_pt_files(opt, corpus_type, ids, existing_fields):


def process_one_shard(corpus_params, params):
corpus_type, fields, src_reader, tgt_reader, opt, existing_fields,\
src_vocab, tgt_vocab = corpus_params
i, (src_shard, tgt_shard, maybe_id, filter_pred) = params
corpus_type, fields, src_reader, tgt_reader, align_reader, opt,\
existing_fields, src_vocab, tgt_vocab = corpus_params
i, (src_shard, tgt_shard, align_shard, maybe_id, filter_pred) = params
# create one counter per shard
sub_sub_counter = defaultdict(Counter)
assert len(src_shard) == len(tgt_shard)
logger.info("Building shard %d." % i)

src_data = {"reader": src_reader, "data": src_shard, "dir": opt.src_dir}
tgt_data = {"reader": tgt_reader, "data": tgt_shard, "dir": None}
align_data = {"reader": align_reader, "data": align_shard, "dir": None}
_readers, _data, _dir = inputters.Dataset.config(
[('src', src_data), ('tgt', tgt_data), ('align', align_data)])

dataset = inputters.Dataset(
fields,
readers=([src_reader, tgt_reader]
if tgt_reader else [src_reader]),
data=([("src", src_shard), ("tgt", tgt_shard)]
if tgt_reader else [("src", src_shard)]),
dirs=([opt.src_dir, None]
if tgt_reader else [opt.src_dir]),
fields, readers=_readers, data=_data, dirs=_dir,
sort_key=inputters.str2sortkey[opt.data_type],
filter_pred=filter_pred
)
Expand Down Expand Up @@ -125,19 +126,22 @@ def maybe_load_vocab(corpus_type, counters, opt):
return src_vocab, tgt_vocab, existing_fields


def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader,
align_reader, opt):
assert corpus_type in ['train', 'valid']

if corpus_type == 'train':
counters = defaultdict(Counter)
srcs = opt.train_src
tgts = opt.train_tgt
ids = opt.train_ids
aligns = opt.train_align
elif corpus_type == 'valid':
counters = None
srcs = [opt.valid_src]
tgts = [opt.valid_tgt]
ids = [None]
aligns = [opt.valid_align]

src_vocab, tgt_vocab, existing_fields = maybe_load_vocab(
corpus_type, counters, opt)
@@ -149,12 +153,12 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
if existing_shards == ids and not opt.overwrite:
return

def shard_iterator(srcs, tgts, ids, existing_shards,
def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
existing_fields, corpus_type, opt):
"""
Builds a single iterator yielding every shard of every corpus.
"""
for src, tgt, maybe_id in zip(srcs, tgts, ids):
for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
if maybe_id in existing_shards:
if opt.overwrite:
logger.warning("Overwrite shards for corpus {}"
@@ -180,15 +184,18 @@ def shard_iterator(srcs, tgts, ids, existing_shards,
filter_pred = None
src_shards = split_corpus(src, opt.shard_size)
tgt_shards = split_corpus(tgt, opt.shard_size)
for i, (ss, ts) in enumerate(zip(src_shards, tgt_shards)):
yield (i, (ss, ts, maybe_id, filter_pred))
align_shards = split_corpus(maybe_align, opt.shard_size)
for i, (ss, ts, a_s) in enumerate(
zip(src_shards, tgt_shards, align_shards)):
yield (i, (ss, ts, a_s, maybe_id, filter_pred))

shard_iter = shard_iterator(srcs, tgts, ids, existing_shards,
shard_iter = shard_iterator(srcs, tgts, ids, aligns, existing_shards,
existing_fields, corpus_type, opt)

with Pool(opt.num_threads) as p:
dataset_params = (corpus_type, fields, src_reader, tgt_reader,
opt, existing_fields, src_vocab, tgt_vocab)
align_reader, opt, existing_fields,
src_vocab, tgt_vocab)
func = partial(process_one_shard, dataset_params)
for sub_counter in p.imap(func, shard_iter):
if sub_counter is not None:
@@ -253,19 +260,22 @@ def preprocess(opt):
src_nfeats,
tgt_nfeats,
dynamic_dict=opt.dynamic_dict,
with_align=opt.train_align[0] is not None,
src_truncate=opt.src_seq_length_trunc,
tgt_truncate=opt.tgt_seq_length_trunc)

src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = inputters.str2reader["text"].from_opt(opt)
align_reader = inputters.str2reader["text"].from_opt(opt)

logger.info("Building & saving training data...")
build_save_dataset(
'train', fields, src_reader, tgt_reader, opt)
'train', fields, src_reader, tgt_reader, align_reader, opt)

if opt.valid_src and opt.valid_tgt:
logger.info("Building & saving validation data...")
build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
build_save_dataset(
'valid', fields, src_reader, tgt_reader, align_reader, opt)


def _get_parser():
5 changes: 4 additions & 1 deletion onmt/bin/server.py
@@ -75,14 +75,17 @@ def translate():
inputs = request.get_json(force=True)
out = {}
try:
trans, scores, n_best, times = translation_server.run(inputs)
trans, scores, n_best, _, aligns = translation_server.run(inputs)
assert len(trans) == len(inputs) * n_best
assert len(scores) == len(inputs) * n_best
assert len(aligns) == len(inputs) * n_best

out = [[] for _ in range(n_best)]
for i in range(len(trans)):
response = {"src": inputs[i // n_best]['src'], "tgt": trans[i],
"n_best": n_best, "pred_score": scores[i]}
if aligns[i] is not None:
response["align"] = aligns[i]
out[i % n_best].append(response)
except ServerModelError as e:
out['error'] = str(e)
2 changes: 2 additions & 0 deletions onmt/bin/train.py
@@ -125,6 +125,8 @@ def next_batch(device_id):
if hasattr(b, 'alignment') else None
b.src_map = b.src_map.to(torch.device(device_id)) \
if hasattr(b, 'src_map') else None
b.align = b.align.to(torch.device(device_id)) \
if hasattr(b, 'align') else None

# hack to dodge unpicklable `dict_keys`
b.fields = list(b.fields)
7 changes: 3 additions & 4 deletions onmt/bin/translate.py
@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-

from __future__ import unicode_literals
from itertools import repeat

from onmt.utils.logging import init_logger
from onmt.utils.misc import split_corpus
@@ -18,8 +17,7 @@ def translate(opt):

translator = build_translator(opt, report_score=True)
src_shards = split_corpus(opt.src, opt.shard_size)
tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
if opt.tgt is not None else repeat(None)
tgt_shards = split_corpus(opt.tgt, opt.shard_size)
shard_pairs = zip(src_shards, tgt_shards)

for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
@@ -30,7 +28,8 @@
src_dir=opt.src_dir,
batch_size=opt.batch_size,
batch_type=opt.batch_type,
attn_debug=opt.attn_debug
attn_debug=opt.attn_debug,
align_debug=opt.align_debug
)


3 changes: 2 additions & 1 deletion onmt/decoders/decoder.py
@@ -190,7 +190,8 @@ def detach_state(self):
self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
self.state["input_feed"] = self.state["input_feed"].detach()

def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
**kwargs):
"""
Args:
tgt (LongTensor): sequences of padded tokens
3 changes: 2 additions & 1 deletion onmt/decoders/ensemble.py
@@ -51,7 +51,8 @@ def __init__(self, model_decoders):
super(EnsembleDecoder, self).__init__(attentional)
self.model_decoders = model_decoders

def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
**kwargs):
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
# Memory_lengths is a single tensor shared between all models.
# This assumption will not hold if Translator is modified