Commit dcfd590

Initial commit
1 parent 0a2039c commit dcfd590

File tree

14 files changed (+1898, -2 lines)


.gitignore

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+data/
+output/
+
+*~
+.DS_Store
+__pycache__/
+*.py[cod]
+*$py.class
+.idea/

LICENSE.txt

Lines changed: 674 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 71 additions & 2 deletions
@@ -1,2 +1,71 @@
-# UNdreaMT (Unsupervised Neural Machine Translation)
-Coming soon ;)
+UNdreaMT: Unsupervised Neural Machine Translation
+==============
+
+This is an open source implementation of our unsupervised neural machine translation system, described in the following paper:
+
+Mikel Artetxe, Gorka Labaka, Eneko Agirre, and Kyunghyun Cho. 2018. **[Unsupervised Neural Machine Translation](https://arxiv.org/pdf/1710.11041.pdf)**. In *Proceedings of the Sixth International Conference on Learning Representations (ICLR 2018)*.
+
+If you use this software for academic research, please cite the paper in question:
+```
+@inproceedings{artetxe2018iclr,
+  author = {Artetxe, Mikel and Labaka, Gorka and Agirre, Eneko and Cho, Kyunghyun},
+  title = {Unsupervised neural machine translation},
+  booktitle = {Proceedings of the Sixth International Conference on Learning Representations},
+  month = {April},
+  year = {2018}
+}
+```
+
+
+Requirements
+--------
+- Python 3
+- PyTorch (tested with v0.3)
+
+
+Usage
+--------
+
+The following command trains an unsupervised NMT system from monolingual corpora using the exact same settings described in the paper:
+
+```
+python3 train.py --src SRC.MONO.TXT --trg TRG.MONO.TXT --src_embeddings SRC.EMB.TXT --trg_embeddings TRG.EMB.TXT --save MODEL_PREFIX --cuda
+```
+
+The data in the above command should be provided as follows:
+- `SRC.MONO.TXT` and `TRG.MONO.TXT` are the source and target language monolingual corpora. They should both be pre-processed so that atomic symbols (either tokens or BPE units) are separated by whitespace. For that purpose, we recommend using [Moses](http://www.statmt.org/moses/) to tokenize and truecase the corpora and, optionally, [Subword-NMT](https://github.com/rsennrich/subword-nmt) if you want to use BPE.
+- `SRC.EMB.TXT` and `TRG.EMB.TXT` are the source and target language cross-lingual embeddings. In order to obtain them, we recommend training monolingual embeddings on the corpora above using either [word2vec](https://github.com/tmikolov/word2vec) or [fasttext](https://github.com/facebookresearch/fastText), and then mapping them to a shared space using [VecMap](https://github.com/artetxem/vecmap). Please make sure to cut off the vocabulary as desired before mapping the embeddings.
+- `MODEL_PREFIX` is the prefix of the output model.
+
+Using the above settings, training takes about 3 days on a single Titan Xp. Once training is done, you can use the resulting model for translation as follows:
+
+```
+python3 translate.py MODEL_PREFIX.final.src2trg.pth < INPUT.TXT > OUTPUT.TXT
+```
+
+For more details and additional options, run the above scripts with the `--help` flag.
+
+
+FAQ
+--------
+
+###### You claim that your unsupervised NMT system is trained on monolingual corpora alone, but it also requires bilingual embeddings... Isn't that cheating?
+
+Not really, because we also learn the bilingual embeddings from monolingual corpora alone. We use our companion tool [VecMap](https://github.com/artetxem/vecmap) for that.
+
+
+###### Can I use this software to train a regular NMT system on parallel corpora?
+
+Yes! You can use the following arguments to make UNdreaMT behave like a regular NMT system:
+
+```
+python3 train.py --src2trg SRC.PARALLEL.TXT TRG.PARALLEL.TXT --src_vocabulary SRC.VOCAB.TXT --trg_vocabulary TRG.VOCAB.TXT --embedding_size 300 --learn_encoder_embeddings --disable_denoising --save MODEL_PREFIX --cuda
+```
+
+
+License
+-------
+
+Copyright (C) 2018, Mikel Artetxe
+
+Licensed under the terms of the GNU General Public License, either version 3 or (at your option) any later version. A full copy of the license can be found in LICENSE.txt.

train.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# Copyright (C) 2018 Mikel Artetxe <[email protected]>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import undreamt.train
+
+
+if __name__ == '__main__':
+    undreamt.train.main_train()
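
The entry point simply delegates to `undreamt.train.main_train()`. As a rough, assumption-laden sketch (assuming `main_train()` parses its options from `sys.argv` via argparse, which the `--help` flag mentioned in the README suggests), training could also be launched programmatically like this; all paths and flag values mirror the README command and are placeholders:

```
import sys
import undreamt.train

# Sketch only: main_train() is assumed to read its arguments from sys.argv,
# so we set sys.argv before calling it (all paths below are placeholders).
sys.argv = ['train.py',
            '--src', 'SRC.MONO.TXT', '--trg', 'TRG.MONO.TXT',
            '--src_embeddings', 'SRC.EMB.TXT', '--trg_embeddings', 'TRG.EMB.TXT',
            '--save', 'MODEL_PREFIX', '--cuda']
undreamt.train.main_train()
```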

translate.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright (C) 2018 Mikel Artetxe <[email protected]>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import argparse
+import sys
+import torch
+
+
+def main():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description='Translate using a pre-trained model')
+    parser.add_argument('model', help='a model previously trained with train.py')
+    parser.add_argument('--batch_size', type=int, default=50, help='the batch size (defaults to 50)')
+    parser.add_argument('--beam_size', type=int, default=12, help='the beam size (defaults to 12, 0 for greedy search)')
+    parser.add_argument('--encoding', default='utf-8', help='the character encoding for input/output (defaults to utf-8)')
+    parser.add_argument('-i', '--input', default=sys.stdin.fileno(), help='the input file (defaults to stdin)')
+    parser.add_argument('-o', '--output', default=sys.stdout.fileno(), help='the output file (defaults to stdout)')
+    args = parser.parse_args()
+
+    # Load model
+    translator = torch.load(args.model)
+
+    # Translate sentences
+    end = False
+    fin = open(args.input, encoding=args.encoding, errors='surrogateescape')
+    fout = open(args.output, mode='w', encoding=args.encoding, errors='surrogateescape')
+    while not end:
+        batch = []
+        while len(batch) < args.batch_size and not end:
+            line = fin.readline()
+            if not line:
+                end = True
+            else:
+                batch.append(line)
+        if args.beam_size <= 0 and len(batch) > 0:
+            for translation in translator.greedy(batch, train=False):
+                print(translation, file=fout)
+        elif len(batch) > 0:
+            for translation in translator.beam_search(batch, train=False, beam_size=args.beam_size):
+                print(translation, file=fout)
+        fout.flush()
+    fin.close()
+    fout.close()
+
+
+if __name__ == '__main__':
+    main()
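
For reference, a minimal sketch of driving the loaded translator object directly from Python, based only on the calls that appear in the script above; the model path and the sentences are placeholders, and the input is assumed to be tokenized the same way as the training data:

```
import torch

# Load a model produced by train.py (placeholder path) and translate a small
# batch of pre-tokenized sentences with the same beam_search call used above.
translator = torch.load('MODEL_PREFIX.final.src2trg.pth')
sentences = ['this is a tokenized sentence .', 'another example .']
for translation in translator.beam_search(sentences, train=False, beam_size=12):
    print(translation)
```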

undreamt/__init__.py

Whitespace-only changes.

undreamt/attention.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright (C) 2018 Mikel Artetxe <[email protected]>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import torch.nn as nn
+
+
+class GlobalAttention(nn.Module):
+    def __init__(self, dim, alignment_function='general'):
+        super(GlobalAttention, self).__init__()
+        self.alignment_function = alignment_function
+        if self.alignment_function == 'general':
+            self.linear_align = nn.Linear(dim, dim, bias=False)
+        elif self.alignment_function != 'dot':
+            raise ValueError('Invalid alignment function: {0}'.format(alignment_function))
+        self.softmax = nn.Softmax(dim=1)
+        self.linear_context = nn.Linear(dim, dim, bias=False)
+        self.linear_query = nn.Linear(dim, dim, bias=False)
+        self.tanh = nn.Tanh()
+
+    def forward(self, query, context, mask):
+        # query: batch*dim
+        # context: length*batch*dim
+        # ans: batch*dim
+
+        context_t = context.transpose(0, 1)  # batch*length*dim
+
+        # Compute alignment scores
+        q = query if self.alignment_function == 'dot' else self.linear_align(query)
+        align = context_t.bmm(q.unsqueeze(2)).squeeze(2)  # batch*length
+
+        # Mask alignment scores
+        if mask is not None:
+            align.data.masked_fill_(mask, -float('inf'))
+
+        # Compute attention weights from alignment scores
+        attention = self.softmax(align)  # batch*length
+
+        # Compute weighted context
+        weighted_context = attention.unsqueeze(1).bmm(context_t).squeeze(1)  # batch*dim
+
+        # Combine context and query
+        return self.tanh(self.linear_context(weighted_context) + self.linear_query(query))
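
As an illustration of the interface above, here is a minimal sketch that runs the module on random tensors; the sizes are made up for the example, and with the PyTorch v0.3 that the README lists as tested, the tensors would additionally need to be wrapped in `torch.autograd.Variable`:

```
import torch
from undreamt.attention import GlobalAttention

# Illustrative sizes only: 2 queries attending over a context of length 5,
# with hidden dimension 8; mask=None means no positions are masked out.
attention = GlobalAttention(dim=8, alignment_function='general')
query = torch.randn(2, 8)       # batch*dim
context = torch.randn(5, 2, 8)  # length*batch*dim
output = attention(query, context, mask=None)
print(output.size())            # (2, 8), i.e. batch*dim
```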
