from __future__ import print_function

import sys

from overrides import overrides
from keras import backend as K
from keras.engine import InputSpec, Layer
from keras.layers import LSTM, Dense


class NSE(Layer):
    '''
    Simple Neural Semantic Encoder (NSE), following "Neural Semantic Encoders" (Munkhdalai and Yu).
    The equation numbers in the comments below refer to that paper.
    '''
def __init__(self, output_dim, input_length=None, composer_activation='linear',
return_mode='last_output', weights=None, **kwargs):
'''
Arguments:
        output_dim (int): Dimensionality of the output (and of the input, which Equation 2 in the
            paper assumes to be the same).
        input_length (int): Length of the input sequence, if known.
        composer_activation (str): Activation of the composer MLP (Equation 4).
        return_mode (str): One of last_output, all_outputs, output_and_memory.
            This is analogous to the return_sequences flag in Keras' Recurrent.
            last_output returns only the last h_t.
            all_outputs returns the whole sequence of h_ts.
            output_and_memory returns the last output and the last memory concatenated
            (needed if this layer is followed by an MMA-NSE).
        weights (list): Initial weights.
'''
self.output_dim = output_dim
self.input_dim = output_dim # Equation 2 in the paper makes this assumption.
self.initial_weights = weights
self.input_spec = [InputSpec(ndim=3)]
self.input_length = input_length
self.composer_activation = composer_activation
super(NSE, self).__init__(**kwargs)
self.reader = LSTM(self.output_dim, dropout_W=0.0, dropout_U=0.0, consume_less="gpu",
name="{}_reader".format(self.name))
# TODO: Let the writer use parameter dropout and any consume_less mode.
# Setting dropout to 0 here to eliminate the need for constants.
# Setting consume_less to gpu to eliminate need for preprocessing
self.writer = LSTM(self.output_dim, dropout_W=0.0, dropout_U=0.0, consume_less="gpu",
name="{}_writer".format(self.name))
self.composer = Dense(self.output_dim * 2, activation=self.composer_activation,
name="{}_composer".format(self.name))
if return_mode not in ["last_output", "all_outputs", "output_and_memory"]:
raise Exception("Unrecognized return mode: %s" % (return_mode))
self.return_mode = return_mode
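
    # A minimal construction sketch (illustrative only; the variable names and sizes below are
    # hypothetical, not part of this module). Note that the layer assumes the incoming embedding
    # dimensionality equals output_dim:
    #
    #   embedded_sentence = ...  # a Keras tensor of shape (batch_size, input_length, 300)
    #   encoded = NSE(output_dim=300, return_mode='last_output')(embedded_sentence)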
def get_output_shape_for(self, input_shape):
input_length = input_shape[1]
if self.return_mode == "last_output":
return (input_shape[0], self.output_dim)
elif self.return_mode == "all_outputs":
return (input_shape[0], input_length, self.output_dim)
else:
# return_mode is output_and_memory. Output will be concatenated to memory.
return (input_shape[0], input_length + 1, self.output_dim)
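
    # For example (illustrative numbers only), with input shape (32, 10, 300) and output_dim=300:
    #   last_output       -> (32, 300)
    #   all_outputs       -> (32, 10, 300)
    #   output_and_memory -> (32, 11, 300)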
def compute_mask(self, input, mask):
if mask is None or self.return_mode == "last_output":
return None
elif self.return_mode == "all_outputs":
return mask # (batch_size, input_length)
else:
# Return mode is output_and_memory
# Mask memory corresponding to all the inputs that are masked, and do not mask the output
# (batch_size, input_length + 1)
return K.cast(K.concatenate([K.zeros_like(mask[:, :1]), mask]), 'uint8')
def get_composer_input_shape(self, input_shape):
# Takes concatenation of output and memory summary
return (input_shape[0], self.output_dim * 2)
def get_reader_input_shape(self, input_shape):
return input_shape
def build(self, input_shape):
self.input_spec = [InputSpec(shape=input_shape)]
input_dim = input_shape[-1]
reader_input_shape = self.get_reader_input_shape(input_shape)
print >>sys.stderr, "NSE reader input shape:", reader_input_shape
writer_input_shape = (input_shape[0], 1, self.output_dim * 2) # Will process one timestep at a time
print >>sys.stderr, "NSE writer input shape:", writer_input_shape
composer_input_shape = self.get_composer_input_shape(input_shape)
print >>sys.stderr, "NSE composer input shape:", composer_input_shape
self.reader.build(reader_input_shape)
self.writer.build(writer_input_shape)
self.composer.build(composer_input_shape)
# Aggregate weights of individual components for this layer.
reader_weights = self.reader.trainable_weights
writer_weights = self.writer.trainable_weights
composer_weights = self.composer.trainable_weights
self.trainable_weights = reader_weights + writer_weights + composer_weights
if self.initial_weights is not None:
self.set_weights(self.initial_weights)
del self.initial_weights
def get_initial_states(self, nse_input, input_mask=None):
'''
This method produces the 'read' mask for all timesteps
and initializes the memory slot mem_0.
Input: nse_input (batch_size, input_length, input_dim)
Output: list[Tensors]:
h_0 (batch_size, output_dim)
c_0 (batch_size, output_dim)
flattened_mem_0 (batch_size, input_length * output_dim)
While this method simply copies input to mem_0, variants that inherit from this class can do
something fancier.
'''
input_to_read = nse_input
mem_0 = input_to_read
flattened_mem_0 = K.batch_flatten(mem_0)
initial_states = self.reader.get_initial_states(nse_input)
initial_states += [flattened_mem_0]
return initial_states
@staticmethod
def summarize_memory(o_t, mem_tm1):
'''
This method selects the relevant parts of the memory given the read output and summarizes the
memory. Implements Equations 2-3 or 8-11 in the paper.
'''
# Selecting relevant memory slots, Equation 2
z_t = K.softmax(K.sum(K.expand_dims(o_t, dim=1) * mem_tm1, axis=2)) # (batch_size, input_length)
# Summarizing memory, Equation 3
m_rt = K.sum(K.expand_dims(z_t, dim=2) * mem_tm1, axis=1) # (batch_size, output_dim)
return z_t, m_rt
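
    # In other words, with M_{t-1} the (input_length x output_dim) memory and o_t the read output:
    #   z_t  = softmax(M_{t-1} . o_t)   # attention over the input_length memory slots
    #   m_rt = z_t^T . M_{t-1}          # attention-weighted summary of the memory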
def compose_memory_and_output(self, output_memory_list):
'''
        This method takes a list of tensors and applies the composition function on their concatenation.
        Implements Equation 4 or 12 in the paper.
'''
# Composition, Equation 4
c_t = self.composer.call(K.concatenate(output_memory_list)) # (batch_size, output_dim)
return c_t
def update_memory(self, z_t, h_t, mem_tm1):
'''
This method takes the attention vector (z_t), writer output (h_t) and previous timestep's memory (mem_tm1)
        and updates the memory. Implements Equation 6, 14 or 15 in the paper.
'''
tiled_z_t = K.tile(K.expand_dims(z_t), (self.output_dim)) # (batch_size, input_length, output_dim)
input_length = K.shape(mem_tm1)[1]
# (batch_size, input_length, output_dim)
tiled_h_t = K.permute_dimensions(K.tile(K.expand_dims(h_t), (input_length)), (0, 2, 1))
# Updating memory. First term in summation corresponds to selective forgetting and the second term to
# selective addition. Equation 6.
mem_t = mem_tm1 * (1 - tiled_z_t) + tiled_h_t * tiled_z_t # (batch_size, input_length, output_dim)
return mem_t
@staticmethod
def split_states(states):
        # This method is a helper for the step function to split the states into reader states, memory and
        # writer states.
return states[:2], states[2], states[3:]
def step(self, input_t, states):
'''
This method is a step function that updates the memory at each time step and produces
a new output vector (Equations 1 to 6 in the paper).
The memory_state is flattened because K.rnn requires all states to be of the same shape as the output,
because it uses the same mask for the output and the states.
Inputs:
input_t (batch_size, input_dim)
states (list[Tensor])
flattened_mem_tm1 (batch_size, input_length * output_dim)
writer_h_tm1 (batch_size, output_dim)
writer_c_tm1 (batch_size, output_dim)
Outputs:
h_t (batch_size, output_dim)
flattened_mem_t (batch_size, input_length * output_dim)
'''
reader_states, flattened_mem_tm1, writer_states = self.split_states(states)
input_mem_shape = K.shape(flattened_mem_tm1)
        mem_tm1_shape = (input_mem_shape[0], input_mem_shape[1] // self.output_dim, self.output_dim)
mem_tm1 = K.reshape(flattened_mem_tm1, mem_tm1_shape) # (batch_size, input_length, output_dim)
reader_constants = self.reader.get_constants(input_t) # Does not depend on input_t, see init.
reader_states = reader_states[:2] + reader_constants + reader_states[2:]
o_t, [_, reader_c_t] = self.reader.step(input_t, reader_states) # o_t, reader_c_t: (batch_size, output_dim)
z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
c_t = self.compose_memory_and_output([o_t, m_rt])
# Collecting the necessary variables to directly call writer's step function.
writer_constants = self.writer.get_constants(c_t) # returns dropouts for W and U (all 1s, see init)
writer_states += writer_constants
# Making a call to writer's step function, Equation 5
h_t, [_, writer_c_t] = self.writer.step(c_t, writer_states) # h_t, writer_c_t: (batch_size, output_dim)
mem_t = self.update_memory(z_t, h_t, mem_tm1)
flattened_mem_t = K.batch_flatten(mem_t)
return h_t, [o_t, reader_c_t, flattened_mem_t, h_t, writer_c_t]
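
    # One NSE step, in order: read (Equation 1), attend over and summarize the memory (Equations 2-3),
    # compose (Equation 4), write (Equation 5), and update the memory (Equation 6). The returned state
    # list mirrors the order of the initial states:
    #   [reader output, reader cell state, flattened memory, writer output, writer cell state]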
def loop(self, x, initial_states, mask):
# This is a separate method because Ontoaware variants will have to override this to make a call
# to changingdim rnn.
last_output, all_outputs, last_states = K.rnn(self.step, x, initial_states, mask=mask)
return last_output, all_outputs, last_states
def call(self, x, mask=None):
# input_shape = (batch_size, input_length, input_dim). This needs to be defined in build.
initial_read_states = self.get_initial_states(x, mask)
fake_writer_input = K.expand_dims(initial_read_states[0], dim=1) # (batch_size, 1, output_dim)
initial_write_states = self.writer.get_initial_states(fake_writer_input) # h_0 and c_0 of the writer LSTM
initial_states = initial_read_states + initial_write_states
        # last_output: (batch_size, output_dim)
        # all_outputs: (batch_size, input_length, output_dim)
        # last_states (same order as the states returned by step):
        #       reader output, reader cell state, flattened memory, writer output, writer cell state
        last_output, all_outputs, last_states = self.loop(x, initial_states, mask)
        # Unflatten the final memory: (batch_size, input_length * output_dim) -> (batch_size, input_length, output_dim)
        last_memory = K.reshape(last_states[2], K.shape(x))
if self.return_mode == "last_output":
return last_output
elif self.return_mode == "all_outputs":
return all_outputs
else:
# return mode is output_and_memory
expanded_last_output = K.expand_dims(last_output, dim=1) # (batch_size, 1, output_dim)
# (batch_size, 1+input_length, output_dim)
return K.concatenate([expanded_last_output, last_memory], axis=1)
def get_config(self):
config = {'output_dim': self.output_dim,
'input_length': self.input_length,
'composer_activation': self.composer_activation,
'return_mode': self.return_mode}
base_config = super(NSE, self).get_config()
config.update(base_config)
return config
class MultipleMemoryAccessNSE(NSE):
'''
    MultipleMemoryAccessNSE is very similar to the simple NSE. The difference is that, along with the sentence
    memory, it has access to one (or more) additional memories. The operations on the additional memory are
    exactly the same as those on the original memory. The additional memory is initialized from the final timestep
    of a different NSE, and the composer takes as input the concatenation of the reader output and the summaries
    of both memories.
'''
    # TODO: This currently assumes we need access to exactly one additional memory. Change it to support an arbitrary number.
@overrides
def get_output_shape_for(self, input_shape):
        # This class has twice the input length of an NSE due to the concatenated input. Pass the right size
        # to NSE's method to get the right output shape.
        nse_input_shape = (input_shape[0], input_shape[1] // 2, input_shape[2])
return super(MultipleMemoryAccessNSE, self).get_output_shape_for(nse_input_shape)
def get_reader_input_shape(self, input_shape):
        return (input_shape[0], input_shape[1] // 2, self.output_dim)
def get_composer_input_shape(self, input_shape):
return (input_shape[0], self.output_dim * 3)
@overrides
def get_initial_states(self, nse_input, input_mask=None):
'''
Read input in MMA-NSE will be of shape (batch_size, read_input_length*2, input_dim), a concatenation of
the actual input to this NSE and the output from a different NSE. The latter will be used to initialize
the shared memory. The former will be passed to the read LSTM and also used to initialize the current
memory.
'''
input_length = K.shape(nse_input)[1]
        read_input_length = input_length // 2
input_to_read = nse_input[:, :read_input_length, :]
initial_shared_memory = K.batch_flatten(nse_input[:, read_input_length:, :])
mem_0 = K.batch_flatten(input_to_read)
o_mask = self.reader.compute_mask(input_to_read, input_mask)
reader_states = self.reader.get_initial_states(nse_input)
initial_states = reader_states + [mem_0, initial_shared_memory]
return initial_states, o_mask
@overrides
def step(self, input_t, states):
reader_states = states[:2]
flattened_mem_tm1, flattened_shared_mem_tm1 = states[2:4]
writer_h_tm1, writer_c_tm1 = states[4:]
input_mem_shape = K.shape(flattened_mem_tm1)
        mem_shape = (input_mem_shape[0], input_mem_shape[1] // self.output_dim, self.output_dim)
mem_tm1 = K.reshape(flattened_mem_tm1, mem_shape)
shared_mem_tm1 = K.reshape(flattened_shared_mem_tm1, mem_shape)
reader_constants = self.reader.get_constants(input_t)
reader_states += reader_constants
o_t, [_, reader_c_t] = self.reader.step(input_t, reader_states)
z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
shared_z_t, shared_m_rt = self.summarize_memory(o_t, shared_mem_tm1)
c_t = self.compose_memory_and_output([o_t, m_rt, shared_m_rt])
# Collecting the necessary variables to directly call writer's step function.
writer_constants = self.writer.get_constants(c_t) # returns dropouts for W and U (all 1s, see init)
writer_states = [writer_h_tm1, writer_c_tm1] + writer_constants
# Making a call to writer's step function, Equation 5
h_t, [_, writer_c_t] = self.writer.step(c_t, writer_states) # h_t, writer_c_t: (batch_size, output_dim)
mem_t = self.update_memory(z_t, h_t, mem_tm1)
shared_mem_t = self.update_memory(shared_z_t, h_t, shared_mem_tm1)
return h_t, [o_t, reader_c_t, K.batch_flatten(mem_t), K.batch_flatten(shared_mem_t), h_t, writer_c_t]
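
    # Compared to NSE.step, the state list here carries one extra flattened tensor (the shared memory)
    # between the sentence memory and the writer states, and the composer sees three vectors instead of two.
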
class InputMemoryMerger(Layer):
'''
    This layer takes as input the memory part of the output of an NSE layer and the embedded input to an MMA-NSE
    layer, and prepares a single input tensor for the MMA-NSE that is a concatenation of the first sentence's
    memory and the second sentence's embedding.
    This is a concrete layer instead of a lambda function because we want to support masking.
'''
def __init__(self, **kwargs):
self.supports_masking = True
super(InputMemoryMerger, self).__init__(**kwargs)
def get_output_shape_for(self, input_shapes):
return (input_shapes[1][0], input_shapes[1][1]*2, input_shapes[1][2])
def compute_mask(self, inputs, mask=None):
# pylint: disable=unused-argument
if mask is None:
return None
elif mask == [None, None]:
return None
else:
memory_mask, mmanse_embed_mask = mask
return K.concatenate([mmanse_embed_mask, memory_mask], axis=1) # (batch_size, nse_input_length * 2)
def call(self, inputs, mask=None):
shared_memory = inputs[0]
mmanse_embed_input = inputs[1] # (batch_size, nse_input_length, output_dim)
return K.concatenate([mmanse_embed_input, shared_memory], axis=1)
class OutputSplitter(Layer):
'''
This layer takes the concatenation of output and memory from NSE and returns either the output or the
memory.
'''
def __init__(self, return_mode, **kwargs):
        self.supports_masking = True
if return_mode not in ["output", "memory"]:
raise Exception("Invalid return mode: %s" % return_mode)
self.return_mode = return_mode
super(OutputSplitter, self).__init__(**kwargs)
def get_output_shape_for(self, input_shape):
if self.return_mode == "output":
return (input_shape[0], input_shape[2])
else:
# Return mode is memory.
# input contains output and memory concatenated along the second dimension.
return (input_shape[0], input_shape[1] - 1, input_shape[2])
def compute_mask(self, inputs, mask=None):
# pylint: disable=unused-argument
if self.return_mode == "output" or mask is None:
return None
else:
# Return mode is memory and mask is not None
return mask[:, 1:] # (batch_size, nse_input_length)
def call(self, inputs, mask=None):
if self.return_mode == "output":
return inputs[:, 0, :] # (batch_size, output_dim)
else:
return inputs[:, 1:, :] # (batch_size, nse_input_length, output_dim)
def get_config(self):
config = {"return_mode": self.return_mode}
base_config = super(OutputSplitter, self).get_config()
config.update(base_config)
return config
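

if __name__ == '__main__':
    # A minimal smoke test / usage sketch. This block is not part of the original module: it assumes
    # Keras 1.x (matching the dropout_W/consume_less API used above), and all names and sizes below
    # are purely illustrative.
    import numpy
    from keras.layers import Input, Embedding
    from keras.models import Model

    vocab_size = 100
    input_length = 10
    embed_dim = 50  # NSE assumes input_dim == output_dim, so the embedding size fixes output_dim.

    word_ids = Input(shape=(input_length,), dtype='int32')
    embedded = Embedding(input_dim=vocab_size, output_dim=embed_dim)(word_ids)
    encoded = NSE(output_dim=embed_dim, input_length=input_length, return_mode='last_output')(embedded)
    model = Model(input=word_ids, output=encoded)
    model.summary()

    # Forward pass on random ids; the expected output shape is (2, embed_dim).
    dummy_ids = numpy.random.randint(1, vocab_size, size=(2, input_length))
    print(model.predict(dummy_ids).shape, file=sys.stderr)

    # For the MMA-NSE pipeline, the wiring suggested by the docstrings above would be (sketch only):
    #   premise_encoding = NSE(output_dim=embed_dim, return_mode='output_and_memory')(embedded_premise)
    #   premise_memory = OutputSplitter('memory')(premise_encoding)
    #   merged = InputMemoryMerger()([premise_memory, embedded_hypothesis])
    #   encoded_pair = MultipleMemoryAccessNSE(output_dim=embed_dim)(merged)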