
Commit 252c8a5

ajoulin authored and facebook-github-bot committed

Add unsupervised multilingual alignment
Summary: Add a script for unsupervised multilingual alignment.
Reviewed By: EdouardGrave
Differential Revision: D17180273
fbshipit-source-id: edbb139ff9474ef325a43bb16e9c0cf1a76e0900
1 parent cc325e5 commit 252c8a5

File tree

2 files changed (+215, -3 lines)


alignment/README.md (+17, -3)
@@ -16,9 +16,10 @@ The details of this approach can be found in [1].
 ### Unsupervised alignment

 The script `unsup_align.py` aligns word embeddings from two languages without requiring any supervision.
-The details of this approach can be found in [2].
+Additionally, the script `unsup_multialign.py` aligns multiple languages to a common space with no supervision.
+The details of these approaches can be found in [2] and [3] respectively.

-In addition to NumPy, the unsupervised method requires the [Python Optimal Transport](https://pot.readthedocs.io/en/stable/) toolbox.
+In addition to NumPy, the unsupervised methods require the [Python Optimal Transport](https://pot.readthedocs.io/en/stable/) toolbox.

 ### Download

@@ -39,7 +40,7 @@ If you use the supervised alignment method, please cite:
 }
 ```

-If you use the unsupervised alignment method, please cite:
+If you use the unsupervised bilingual alignment method, please cite:

 [2] E. Grave, A. Joulin, Q. Berthet, [*Unsupervised Alignment of Embeddings with Wasserstein Procrustes*](https://arxiv.org/abs/1805.11222)

@@ -51,3 +52,16 @@ If you use the unsupervised alignment method, please cite:
   year={2018}
 }
 ```
+
+If you use the unsupervised alignment script `unsup_multialign.py`, please cite:
+
+[3] J. Alaux, E. Grave, M. Cuturi, A. Joulin, [*Unsupervised Hyperalignment for Multilingual Word Embeddings*](https://arxiv.org/abs/1811.01124)
+
+```
+@article{alaux2018unsupervised,
+  title={Unsupervised hyperalignment for multilingual word embeddings},
+  author={Alaux, Jean and Grave, Edouard and Cuturi, Marco and Joulin, Armand},
+  journal={arXiv preprint arXiv:1811.01124},
+  year={2018}
+}
+```
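
For reference, the new script is driven entirely by command-line flags; a typical invocation, assuming fastText embeddings named `wiki.<lg>.vec` under `--embdir` (the convention hard-coded in the script's main block), would be:

```
python3 unsup_multialign.py --embdir data/ --outdir output/ --lglist en-fr-es
```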

alignment/unsup_multialign.py (new file, +198 lines)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import io, os, ot, argparse, random
import numpy as np
from utils import *

parser = argparse.ArgumentParser(description='Unsupervised multilingual alignment of word embeddings')

parser.add_argument('--embdir', default='data/', type=str)
parser.add_argument('--outdir', default='output/', type=str)
parser.add_argument('--lglist', default='en-fr-es-it-pt-de-pl-ru-da-nl-cs', type=str,
                    help='list of languages. The first element is the pivot. Example: en-fr-es to align English, French and Spanish, with English as the pivot.')

parser.add_argument('--maxload', default=20000, type=int, help='maximum number of loaded vectors')
parser.add_argument('--uniform', action='store_true', help='switch to a uniform probability of picking language pairs')

# optimization parameters for the square loss
parser.add_argument('--epoch', default=2, type=int, help='number of epochs for the square loss')
parser.add_argument('--niter', default=500, type=int, help='maximum number of iterations per epoch for the square loss')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate for the square loss')
parser.add_argument('--bsz', default=500, type=int, help='batch size for the square loss')

# optimization parameters for the RCSLS loss
parser.add_argument('--altepoch', default=100, type=int, help='number of epochs for the RCSLS loss')
parser.add_argument('--altlr', default=25, type=float, help='learning rate for the RCSLS loss')
parser.add_argument('--altbsz', default=1000, type=int, help='batch size for the RCSLS loss')

args = parser.parse_args()

###### SPECIFIC FUNCTIONS ######

def getknn(sc, x, y, k=10):
    # average similarity of the k nearest neighbors of each row of sc
    # (the CSLS penalty term) and its gradient
    sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
    ytopk = y[sidx.flatten(), :]
    ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
    f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
    df = np.dot(ytopk.sum(1).T, x)
    return f / k, df / k
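
A quick hypothetical sanity check for `getknn`: its value should match a brute-force average of the top-k scores of each row.

```
rng = np.random.RandomState(0)
sc = rng.randn(5, 8)                             # similarity scores
x, y = rng.randn(5, 4), rng.randn(8, 4)

f_ref = np.sort(sc, axis=1)[:, -3:].sum() / 3    # brute-force top-3 average
f, _ = getknn(sc, x, y, k=3)
assert np.isclose(f, f_ref)
```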

def rcsls(Xi, Xj, Zi, Zj, R, knn=10):
    # RCSLS loss: dot-product similarity relaxed with CSLS-style
    # nearest-neighbor penalties in both directions; returns the
    # negated criterion and its gradient with respect to R
    X_trans = np.dot(Xi, R.T)
    f = 2 * np.sum(X_trans * Xj)
    df = 2 * np.dot(Xj.T, Xi)
    fk0, dfk0 = getknn(np.dot(X_trans, Zj.T), Xi, Zj, knn)
    fk1, dfk1 = getknn(np.dot(np.dot(Zi, R.T), Xj.T).T, Xj, Zi, knn)
    f = f - fk0 - fk1
    df = df - dfk0 - dfk1.T
    return -f / Xi.shape[0], -df.T / Xi.shape[0]

def GWmatrix(emb0):
    # half the pairwise squared Euclidean distances, used as the
    # intra-domain cost matrix for Gromov-Wasserstein
    N = np.shape(emb0)[0]
    N2 = .5 * np.linalg.norm(emb0, axis=1).reshape(1, N) ** 2
    C2 = np.tile(N2.transpose(), (1, N)) + np.tile(N2, (N, 1))
    C2 -= np.dot(emb0, emb0.T)
    return C2


def gromov_wasserstein(x_src, x_tgt, C2):
    # entropic Gromov-Wasserstein coupling between the two point clouds,
    # turned into an orthogonal mapping with a Procrustes step
    N = x_src.shape[0]
    C1 = GWmatrix(x_src)
    M = ot.gromov_wasserstein(C1, C2, np.ones(N), np.ones(N),
                              'square_loss', epsilon=0.55, max_iter=100, tol=1e-4)
    return procrustes(np.dot(M, x_tgt), x_src)
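
`procrustes` (used just above) and `proj_ortho` (used in `align` below) come from the repo's `utils.py`, which this commit does not touch. A minimal sketch of what they presumably compute, in both cases via an SVD:

```
def procrustes_sketch(X_src, Y_tgt):
    # orthogonal R minimizing ||Y_tgt R - X_src||_F, so Y_tgt @ R ~ X_src
    U, s, V = np.linalg.svd(np.dot(Y_tgt.T, X_src))
    return np.dot(U, V)

def proj_ortho_sketch(M):
    # nearest orthogonal matrix to M in Frobenius norm
    U, s, V = np.linalg.svd(M)
    return np.dot(U, V)
```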

def align(EMB, TRANS, lglist, args):
    nmax, l = args.maxload, len(lglist)

    # create a list of language pairs to sample from
    # (default: higher probability of picking a pair containing the pivot;
    # with --uniform, every pair is equally likely)
    samples = []
    for i in range(l):
        for j in range(l):
            if j == i:
                continue
            if j > 0 and not args.uniform:
                samples.append((0, j))
            if i > 0 and not args.uniform:
                samples.append((i, 0))
            samples.append((i, j))

    # optimization of the L2 loss
    print('start optimizing L2 loss')
    lr0, bsz, nepoch, niter = args.lr, args.bsz, args.epoch, args.niter
    for epoch in range(nepoch):
        print("start epoch %d / %d" % (epoch + 1, nepoch))
        ones = np.ones(bsz)
        f, fold, nb, lr = 0.0, 0.0, 0.0, lr0
        for it in range(niter):
            # halve the learning rate when the loss stops decreasing
            if it > 1 and f > fold + 1e-3:
                lr /= 2
            if lr < .05:
                break
            fold = f
            f, nb = 0.0, 0.0
            for k in range(100 * (l - 1)):
                (i, j) = random.choice(samples)
                embi = EMB[i][np.random.permutation(nmax)[:bsz], :]
                embj = EMB[j][np.random.permutation(nmax)[:bsz], :]
                # entropic-OT matching of the two minibatches under the
                # current mappings
                perm = ot.sinkhorn(ones, ones,
                                   np.linalg.multi_dot([embi, -TRANS[i], TRANS[j].T, embj.T]),
                                   reg=0.025, stopThr=1e-3)
                grad = np.linalg.multi_dot([embi.T, perm, embj])
                f -= np.trace(np.linalg.multi_dot([TRANS[i].T, grad, TRANS[j]])) / embi.shape[0]
                nb += 1
                # projected gradient step on the orthogonal group
                # (the pivot mapping TRANS[0] is never updated)
                if i > 0:
                    TRANS[i] = proj_ortho(TRANS[i] + lr * np.dot(grad, TRANS[j]))
                if j > 0:
                    TRANS[j] = proj_ortho(TRANS[j] + lr * np.dot(grad.transpose(), TRANS[i]))
            print("iter %d / %d - epoch %d - loss: %.5f lr: %.4f" % (it, niter, epoch + 1, f / nb, lr))
        print("end of epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
        # shorter epochs with larger batches as optimization progresses
        niter, bsz = max(int(niter / 2), 2), min(1000, bsz * 2)

    # optimization of the RCSLS loss
    print('start optimizing RCSLS loss')
    f, fold, nb, lr = 0.0, 0.0, 0.0, args.altlr
    for epoch in range(args.altepoch):
        # halve the learning rate when the loss plateaus
        if epoch > 1 and f - fold > -1e-4 * abs(fold):
            lr /= 2
        if lr < 1e-1:
            break
        fold = f
        f, nb = 0.0, 0.0
        for k in range(round(nmax / args.altbsz) * 10 * (l - 1)):
            (i, j) = random.choice(samples)
            sgdidx = np.random.choice(nmax, size=args.altbsz, replace=False)
            embi = EMB[i][sgdidx, :]
            embj = EMB[j][:nmax, :]
            # crude alignment approximation: each row of embi is matched to
            # its current 1-nearest neighbor in embj; e.g. scores of
            # [[.1, .9, .0], [.7, .2, .6]] give perm [[0, 1, 0], [1, 0, 0]],
            # and np.dot(perm, embj) reorders embj accordingly
            T = np.dot(TRANS[i], TRANS[j].T)
            scores = np.linalg.multi_dot([embi, T, embj.T])
            perm = np.zeros_like(scores)
            perm[np.arange(len(scores)), scores.argmax(1)] = 1
            embj = np.dot(perm, embj)
            # normalization over a subset of embeddings for speed-up
            fi, grad = rcsls(embi, embj, embi, embj, T.T)
            f += fi
            nb += 1
            if i > 0:
                TRANS[i] = proj_ortho(TRANS[i] - lr * np.dot(grad, TRANS[j]))
            if j > 0:
                TRANS[j] = proj_ortho(TRANS[j] - lr * np.dot(grad.transpose(), TRANS[i]))
        print("epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
    return TRANS
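
The L2 stage above is the stochastic Wasserstein-Procrustes update of [2], applied to a random language pair at each step: an entropic-OT matching of two minibatches, then a projected gradient step on the rotations. A self-contained toy version for a single pair (hypothetical code; `proj_orthogonal` stands in for `proj_ortho` from `utils.py`):

```
import numpy as np
import ot

def proj_orthogonal(M):
    U, _, Vt = np.linalg.svd(M)
    return U @ Vt

rng = np.random.RandomState(0)
Y = rng.randn(200, 10)
Y /= np.linalg.norm(Y, axis=1, keepdims=True)    # unit-norm "pivot" vectors
# a small hidden rotation, so local optimization from R = I can succeed;
# the script handles the general case via Gromov-Wasserstein initialization
Q = proj_orthogonal(np.eye(10) + 0.1 * rng.randn(10, 10))
X = Y @ Q.T                                      # second "language"

R, lr, ones = np.eye(10), 0.5, np.ones(200)
for _ in range(100):
    # entropic-OT matching under the current rotation, as in align()
    P = ot.sinkhorn(ones, ones, -(X @ R) @ Y.T, reg=0.025, stopThr=1e-3)
    G = X.T @ P @ Y                              # gradient of tr(R.T X.T P Y)
    R = proj_orthogonal(R + lr * G)

print(np.linalg.norm(X @ R - Y))                 # near 0 once Q is recovered
```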

def convex_init(X, Y, niter=100, reg=0.05, apply_sqrt=False):
    # convex relaxation of the Procrustes problem, optimized with
    # Frank-Wolfe over couplings (Sinkhorn steps), cf. [2];
    # note: apply_sqrt is accepted but unused in this script
    n, d = X.shape
    K_X, K_Y = np.dot(X, X.T), np.dot(Y, Y.T)
    K_Y *= np.linalg.norm(K_X) / np.linalg.norm(K_Y)
    K2_X, K2_Y = np.dot(K_X, K_X), np.dot(K_Y, K_Y)
    P = np.ones([n, n]) / float(n)
    for it in range(1, niter + 1):
        G = np.dot(P, K2_X) + np.dot(K2_Y, P) - 2 * np.dot(K_Y, np.dot(P, K_X))
        q = ot.sinkhorn(np.ones(n), np.ones(n), G, reg, stopThr=1e-3)
        alpha = 2.0 / float(2.0 + it)
        P = alpha * q + (1.0 - alpha) * P
    return procrustes(np.dot(P, X), Y).T
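
`convex_init` is the convex-relaxation initialization of [2] (Frank-Wolfe over couplings with Sinkhorn steps). It is defined here but not called below, where initialization is done with Gromov-Wasserstein instead. A hypothetical smoke test on a toy rotation:

```
rng = np.random.RandomState(1)
Y = rng.randn(50, 8)
Y /= np.linalg.norm(Y, axis=1, keepdims=True)
Q = np.linalg.qr(rng.randn(8, 8))[0]        # random orthogonal matrix
X = np.dot(Y, Q)                            # same point cloud, rotated
W = convex_init(X, Y, niter=50)
print(np.linalg.norm(np.dot(X, W) - Y))     # small if the relaxation recovers Q
```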

###### MAIN ######

lglist = args.lglist.split('-')
l = len(lglist)

# load the embeddings of every language
EMB = {}
for i in range(l):
    fn = args.embdir + '/wiki.' + lglist[i] + '.vec'
    _, vecs = load_vectors(fn, maxload=args.maxload)
    EMB[i] = vecs

# initialization: bilingual Gromov-Wasserstein alignment of every
# language onto the pivot
print("Computing initial bilingual mapping with Gromov-Wasserstein...")
TRANS = {}
maxinit = 2000
emb0 = EMB[0][:maxinit, :]
C0 = GWmatrix(emb0)
TRANS[0] = np.eye(300)  # the pivot stays fixed (assumes 300-dim vectors)
for i in range(1, l):
    print("init " + lglist[i])
    embi = EMB[i][:maxinit, :]
    TRANS[i] = gromov_wasserstein(embi, emb0, C0)

# joint refinement of all mappings
align(EMB, TRANS, lglist, args)

print('saving matrices in ' + args.outdir)
languages = ''.join(lglist)
for i in range(l):
    save_matrix(args.outdir + '/W-' + languages + '-' + lglist[i], TRANS[i])
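
Note the convention implied by `align`: embeddings are compared as `np.dot(emb, TRANS[i])`, so each saved `W-...-<lg>` matrix maps that language's vectors into the pivot space by right-multiplication.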
