#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import random

import numpy as np
import ot

from utils import *

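# Requires the POT optimal-transport package (imported as ot) and the companion
# utils module from this directory, which is assumed to provide load_vectors,
# proj_ortho, procrustes and save_matrix as used below.
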
parser = argparse.ArgumentParser(description='Unsupervised multilingual alignment of word embeddings')

parser.add_argument('--embdir', default='data/', type=str)
parser.add_argument('--outdir', default='output/', type=str)
parser.add_argument('--lglist', default='en-fr-es-it-pt-de-pl-ru-da-nl-cs', type=str,
                    help='list of languages. The first element is the pivot. Example: en-fr-es to align English, French and Spanish with English as the pivot.')

parser.add_argument('--maxload', default=20000, type=int, help='max number of loaded vectors')
parser.add_argument('--uniform', action='store_true', help='switch to uniform probability of picking language pairs')

# optimization parameters for the square loss
parser.add_argument('--epoch', default=2, type=int, help='number of epochs for the square loss')
parser.add_argument('--niter', default=500, type=int, help='max number of iterations per epoch for the square loss')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate for the square loss')
parser.add_argument('--bsz', default=500, type=int, help='batch size for the square loss')

# optimization parameters for the RCSLS loss
parser.add_argument('--altepoch', default=100, type=int, help='number of epochs for the RCSLS loss')
parser.add_argument('--altlr', default=25, type=float, help='learning rate for the RCSLS loss')
parser.add_argument('--altbsz', default=1000, type=int, help='batch size for the RCSLS loss')

args = parser.parse_args()

###### SPECIFIC FUNCTIONS ######

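# getknn: contribution of the k nearest neighbors to the CSLS criterion.
# sc is a batch-by-candidates score matrix; the function returns the sum of the
# top-k scores (scaled by 1/k) together with the corresponding gradient with
# respect to the mapping, built from the top-k candidate vectors in y and the
# batch x.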
def getknn(sc, x, y, k=10):
    # indices of the k largest scores in each row (unordered)
    sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
    ytopk = y[sidx.flatten(), :]
    ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
    f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
    df = np.dot(ytopk.sum(1).T, x)
    return f / k, df / k


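# rcsls: value and gradient of the RCSLS loss (Joulin et al., 2018, "Loss in
# Translation"): the similarity between mapped pairs is penalized by the average
# similarity of their k nearest neighbors in both directions (via getknn above).
# Returns the negated objective so that it can be minimized by gradient descent.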
def rcsls(Xi, Xj, Zi, Zj, R, knn=10):
    X_trans = np.dot(Xi, R.T)
    f = 2 * np.sum(X_trans * Xj)
    df = 2 * np.dot(Xj.T, Xi)
    fk0, dfk0 = getknn(np.dot(X_trans, Zj.T), Xi, Zj, knn)
    fk1, dfk1 = getknn(np.dot(np.dot(Zi, R.T), Xj.T).T, Xj, Zi, knn)
    f = f - fk0 - fk1
    df = df - dfk0 - dfk1.T
    return -f / Xi.shape[0], -df.T / Xi.shape[0]


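# GWmatrix: intra-language cost matrix for the Gromov-Wasserstein initialization,
# C2[i, j] = (|x_i| + |x_j|) / 2 - <x_i, x_j>. Note that the norms are not
# squared: with unit-normalized embeddings (which load_vectors is assumed to
# return) this equals 1 - <x_i, x_j>, i.e. half the squared Euclidean distance.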
def GWmatrix(emb0):
    N = np.shape(emb0)[0]
    N2 = .5 * np.linalg.norm(emb0, axis=1).reshape(1, N)
    C2 = np.tile(N2.transpose(), (1, N)) + np.tile(N2, (N, 1))
    C2 -= np.dot(emb0, emb0.T)
    return C2


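# gromov_wasserstein: entropic Gromov-Wasserstein coupling between a source
# language and the pivot, followed by Procrustes to turn the coupling into an
# orthogonal mapping. The call below matches the signature of older POT
# releases; recent versions expose this solver as
# ot.gromov.entropic_gromov_wasserstein.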
def gromov_wasserstein(x_src, x_tgt, C2):
    N = x_src.shape[0]
    C1 = GWmatrix(x_src)
    M = ot.gromov_wasserstein(C1, C2, np.ones(N), np.ones(N), 'square_loss',
                              epsilon=0.55, max_iter=100, tol=1e-4)
    return procrustes(np.dot(M, x_tgt), x_src)


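# align: joint optimization of all mappings TRANS[1..l-1] against the pivot
# TRANS[0]. Stage 1 minimizes an entropic-OT square loss between random
# mini-batches of every sampled language pair; stage 2 refines the mappings with
# the RCSLS loss, using a greedy argmax matching as a crude word correspondence.
# Both stages keep the mappings orthogonal via proj_ortho and halve the learning
# rate whenever the loss stops improving.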
def align(EMB, TRANS, lglist, args):
    nmax, l = args.maxload, len(lglist)
    # create a list of language pairs to sample from
    # (default == higher probability of picking a language pair containing the pivot)
    # if --uniform: uniform probability of picking a language pair
    samples = []
    for i in range(l):
        for j in range(l):
            if j == i:
                continue
            if j > 0 and not args.uniform:
                samples.append((0, j))
            if i > 0 and not args.uniform:
                samples.append((i, 0))
            samples.append((i, j))

    # optimization of the l2 loss
    print('start optimizing L2 loss')
    lr0, bsz, nepoch, niter = args.lr, args.bsz, args.epoch, args.niter
    for epoch in range(nepoch):
        print("start epoch %d / %d" % (epoch + 1, nepoch))
        ones = np.ones(bsz)
        f, fold, nb, lr = 0.0, 0.0, 0.0, lr0
        for it in range(niter):
            if it > 1 and f > fold + 1e-3:
                lr /= 2
            if lr < .05:
                break
            fold = f
            f, nb = 0.0, 0.0
            for k in range(100 * (l - 1)):
                (i, j) = random.choice(samples)
                embi = EMB[i][np.random.permutation(nmax)[:bsz], :]
                embj = EMB[j][np.random.permutation(nmax)[:bsz], :]
                # entropic-OT coupling between the two mini-batches
                perm = ot.sinkhorn(ones, ones,
                                   np.linalg.multi_dot([embi, -TRANS[i], TRANS[j].T, embj.T]),
                                   reg=0.025, stopThr=1e-3)
                grad = np.linalg.multi_dot([embi.T, perm, embj])
                f -= np.trace(np.linalg.multi_dot([TRANS[i].T, grad, TRANS[j]])) / embi.shape[0]
                nb += 1
                # gradient step followed by projection back onto orthogonal
                # matrices; the pivot (index 0) stays fixed
                if i > 0:
                    TRANS[i] = proj_ortho(TRANS[i] + lr * np.dot(grad, TRANS[j]))
                if j > 0:
                    TRANS[j] = proj_ortho(TRANS[j] + lr * np.dot(grad.transpose(), TRANS[i]))
            print("iter %d / %d - epoch %d - loss: %.5f lr: %.4f" % (it, niter, epoch + 1, f / nb, lr))
        print("end of epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
        # shorter epochs with larger batches as the optimization progresses
        niter, bsz = max(int(niter / 2), 2), min(1000, bsz * 2)

    # optimization of the RCSLS loss
    print('start optimizing RCSLS loss')
    f, fold, nb, lr = 0.0, 0.0, 0.0, args.altlr
    for epoch in range(args.altepoch):
        if epoch > 1 and f - fold > -1e-4 * abs(fold):
            lr /= 2
        if lr < 1e-1:
            break
        fold = f
        f, nb = 0.0, 0.0
        for k in range(round(nmax / args.altbsz) * 10 * (l - 1)):
            (i, j) = random.choice(samples)
            sgdidx = np.random.choice(nmax, size=args.altbsz, replace=False)
            embi = EMB[i][sgdidx, :]
            embj = EMB[j][:nmax, :]
            # crude alignment approximation: greedy 1-nearest-neighbor matching
            T = np.dot(TRANS[i], TRANS[j].T)
            scores = np.linalg.multi_dot([embi, T, embj.T])
            perm = np.zeros_like(scores)
            perm[np.arange(len(scores)), scores.argmax(1)] = 1
            embj = np.dot(perm, embj)
            # normalization over a subset of embeddings for speed up
            fi, grad = rcsls(embi, embj, embi, embj, T.T)
            f += fi
            nb += 1
            if i > 0:
                TRANS[i] = proj_ortho(TRANS[i] - lr * np.dot(grad, TRANS[j]))
            if j > 0:
                TRANS[j] = proj_ortho(TRANS[j] - lr * np.dot(grad.transpose(), TRANS[i]))
        print("epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
    return TRANS

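# convex_init: alternative bilingual initialization via a convex relaxation over
# doubly stochastic matrices, optimized with Frank-Wolfe steps whose linear
# subproblems are approximated by Sinkhorn. Kept for reference: it is not called
# in the main script below, and the apply_sqrt flag is unused here.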
def convex_init(X, Y, niter=100, reg=0.05, apply_sqrt=False):
    n, d = X.shape
    K_X, K_Y = np.dot(X, X.T), np.dot(Y, Y.T)
    K_Y *= np.linalg.norm(K_X) / np.linalg.norm(K_Y)
    K2_X, K2_Y = np.dot(K_X, K_X), np.dot(K_Y, K_Y)
    P = np.ones([n, n]) / float(n)
    for it in range(1, niter + 1):
        G = np.dot(P, K2_X) + np.dot(K2_Y, P) - 2 * np.dot(K_Y, np.dot(P, K_X))
        q = ot.sinkhorn(np.ones(n), np.ones(n), G, reg, stopThr=1e-3)
        alpha = 2.0 / float(2.0 + it)
        P = alpha * q + (1.0 - alpha) * P
    return procrustes(np.dot(P, X), Y).T


###### MAIN ######

lglist = args.lglist.split('-')
l = len(lglist)

# load embeddings
EMB = {}
for i in range(l):
    fn = args.embdir + '/wiki.' + lglist[i] + '.vec'
    _, vecs = load_vectors(fn, maxload=args.maxload)
    EMB[i] = vecs

# initialization of the mappings with Gromov-Wasserstein
print("Computing initial bilingual mapping with Gromov-Wasserstein...")
TRANS = {}
maxinit = 2000
emb0 = EMB[0][:maxinit, :]
C0 = GWmatrix(emb0)
TRANS[0] = np.eye(EMB[0].shape[1])  # the pivot mapping stays fixed to the identity
for i in range(1, l):
    print("init " + lglist[i])
    embi = EMB[i][:maxinit, :]
    TRANS[i] = gromov_wasserstein(embi, emb0, C0)

# align
TRANS = align(EMB, TRANS, lglist, args)

print('saving matrices in ' + args.outdir)
languages = ''.join(lglist)
for i in range(l):
    save_matrix(args.outdir + '/W-' + languages + '-' + lglist[i], TRANS[i])
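
# Example invocation (a sketch; the file name is an assumption, and fastText
# wiki.XX.vec files must be present in --embdir):
#   python unsup_multialign.py --lglist en-fr-es --embdir data/ --outdir output/
# Each saved W maps its language into the pivot space: vectors of language i are
# compared via EMB[i].dot(TRANS[i]) against the pivot's EMB[0] (TRANS[0] = I).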