-
Notifications
You must be signed in to change notification settings - Fork 15
/
embedding.py
98 lines (88 loc) · 5.35 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from keras import backend as K
from keras.layers import Embedding
class OntoAwareEmbedding(Embedding):
    '''
    This class modifies two aspects of the Embedding class in Keras:
    1. Higher order inputs: Embedding already works with inputs of any shape, except that the output shape
    returned by it assumes that the input is 2D. Changing it.
    2. Sense priors: The expected input shape is (num_samples, num_words, num_senses, num_hyps+1). The +1 at the
    end is the word index appended at the end of each sense indices vector. This is to define an additional real
    value for each word, which will act as the sense prior in OntoLSTM. OntoLSTM is responsible for handling this
    correctly. The output shape is (num_samples, num_words, num_senses, num_hyps, embedding_dim + 1).
    '''
    input_ndim = 4

    def __init__(self, word_index_size, synset_index_size, embedding_dim, set_sense_priors=True,
                 tune_embedding=True, **kwargs):
        '''
        word_index_size (int): size of the word vocabulary; sense priors are one scalar per word.
        synset_index_size (int): size of the synset vocabulary; this is the Embedding's input_dim.
        embedding_dim (int): size of each synset embedding (the layer outputs embedding_dim + 1 per hyp).
        set_sense_priors (bool): if True, learn per-word sense priors; if False, priors stay zero
            (OntoLSTM treats zero priors as uniform sense probabilities).
        tune_embedding (bool): if True, the synset embedding weights are trainable.
        '''
        self.embedding_dim = embedding_dim
        self.word_index_size = word_index_size
        self.synset_index_size = synset_index_size
        self.set_sense_priors = set_sense_priors
        # We have a separate "tune_embedding" field instead of using trainable because we have two sets of
        # parameters here: the embedding weights, and sense prior weights. We may want to fix only one of
        # them at a time.
        self.tune_embedding = tune_embedding
        # Convincing Embedding to return an embedding of the right shape. The output_dim of this layer is embedding_dim+1
        kwargs['output_dim'] = self.embedding_dim
        kwargs['input_dim'] = self.synset_index_size
        self.onto_aware_embedding_weights = None
        super(OntoAwareEmbedding, self).__init__(**kwargs)

    @staticmethod
    def _get_initial_sense_priors(shape, rate_range=None, name=None):
        # Returns a Keras variable with initial values drawn uniformly from rate_range,
        # which defaults to (0.01, 0.99).
        if rate_range is None:
            low, high = 0.01, 0.99
        else:
            low, high = rate_range
        return K.random_uniform_variable(shape, low, high, name=name)

    def build(self, input_shape):
        # input shape is (batch_size, num_words, num_senses, num_hyps)
        self.num_senses = input_shape[-2]
        self.num_hyps = input_shape[-1] - 1  # -1 because the last value is a word index
        # embedding of size 1.
        if self.set_sense_priors:
            self.sense_priors = self._get_initial_sense_priors((self.word_index_size, 1), name='{}_sense_priors'.format(self.name))
        else:
            # OntoLSTM makes sense probabilities uniform if the passed sense parameters are zero.
            self.sense_priors = K.zeros((self.word_index_size, 1))  # uniform sense probs
        # Keeping aside the initial weights to not let Embedding set them. It wouldn't know what sense priors are.
        if self.initial_weights is not None:
            self.onto_aware_embedding_weights = self.initial_weights
            self.initial_weights = None
        # The following method will set self.trainable_weights
        super(OntoAwareEmbedding, self).build(input_shape)  # input_shape will not be used by Embedding's build.
        if not self.tune_embedding:
            # Move embedding to non_trainable_weights
            self._non_trainable_weights.append(self._trainable_weights.pop())
        if self.set_sense_priors:
            self._trainable_weights.append(self.sense_priors)
        # Set the held-back initial weights now that both weight sets exist.
        if self.onto_aware_embedding_weights is not None:
            self.set_weights(self.onto_aware_embedding_weights)

    def call(self, x, mask=None):
        # Remove the word indices at the end before making a call to Embedding.
        x_synsets = x[:, :, :, :-1]  # (num_samples, num_words, num_senses, num_hyps)
        # Taking the last index from the first sense. The last index in all the senses will be the same.
        x_word_index = x[:, :, 0, -1]  # (num_samples, num_words)
        # (num_samples, num_words, num_senses, num_hyps, embedding_dim)
        synset_embedding = super(OntoAwareEmbedding, self).call(x_synsets, mask=None)
        # (num_samples, num_words, 1, 1)
        sense_prior_embedding = K.expand_dims(K.gather(self.sense_priors, x_word_index))
        # Now tile sense_prior_embedding and concatenate it with synset_embedding.
        # (num_samples, num_words, num_senses, num_hyps, 1)
        tiled_sense_prior_embedding = K.expand_dims(K.tile(sense_prior_embedding, (1, 1, self.num_senses, self.num_hyps)))
        synset_embedding_with_priors = K.concatenate([synset_embedding, tiled_sense_prior_embedding])
        return synset_embedding_with_priors

    def compute_mask(self, x, mask=None):
        # Since the output dim is different, we need to change the mask size
        embedding_mask = super(OntoAwareEmbedding, self).compute_mask(x, mask)
        return embedding_mask[:, :, :, :-1] if embedding_mask is not None else None

    def get_output_shape_for(self, input_shape):
        # Drop the appended word index from the last axis and add the embedding axis
        # (embedding_dim synset dims + 1 sense-prior dim).
        return input_shape[:3] + (self.num_hyps, self.embedding_dim + 1)

    def get_config(self):
        config = {"word_index_size": self.word_index_size,
                  "synset_index_size": self.synset_index_size,
                  "embedding_dim": self.embedding_dim,
                  "set_sense_priors": self.set_sense_priors,
                  # Bug fix: tune_embedding was previously omitted, so a layer restored
                  # from its config always fell back to the default (True).
                  "tune_embedding": self.tune_embedding
                  }
        base_config = super(OntoAwareEmbedding, self).get_config()
        config.update(base_config)
        return config