diff --git a/modules.py b/modules.py
index 7efa387..d1e0807 100644
--- a/modules.py
+++ b/modules.py
@@ -156,6 +156,16 @@ def mask(inputs, queries=None, keys=None, type=None):
 
     return outputs
 
+def zero_padding_mask(inputs):
+    """
+    :param inputs: (N, T, d)
+    :return: (N, T, d) float mask, 1. at time steps with any non-zero feature, 0. at all-zero (padded) time steps
+    """
+    masks = tf.sign(tf.reduce_sum(tf.abs(inputs), axis=-1)) # (N, T)
+    masks = tf.expand_dims(masks, -1) # (N, T, 1)
+    masks = tf.tile(masks, [1, 1, tf.shape(inputs)[-1]]) # (N, T, d)
+    return masks
+
 def multihead_attention(queries, keys, values,
                         num_heads=8,
                         dropout_rate=0,
@@ -178,10 +188,13 @@ def multihead_attention(queries, keys, values,
     d_model = queries.get_shape().as_list()[-1]
     with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
         # Linear projections
-        Q = tf.layers.dense(queries, d_model, use_bias=False) # (N, T_q, d_model)
-        K = tf.layers.dense(keys, d_model, use_bias=False) # (N, T_k, d_model)
-        V = tf.layers.dense(values, d_model, use_bias=False) # (N, T_k, d_model)
-
+        Q = tf.layers.dense(queries, d_model, name="query_linproj") # (N, T_q, d_model)
+        Q = Q * zero_padding_mask(queries)
+        K = tf.layers.dense(keys, d_model, name="key_linproj") # (N, T_k, d_model)
+        K = K * zero_padding_mask(keys)
+        V = tf.layers.dense(values, d_model, name="value_linproj") # (N, T_k, d_model)
+        V = V * zero_padding_mask(values)
+
         # Split and concat
         Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h)
         K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)
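
For reviewers, a minimal standalone sketch (not part of the patch) of what zero_padding_mask computes on a toy padded batch; it assumes TensorFlow 1.x (the tf.layers / tf.variable_scope API used above) and re-declares the helper so it runs on its own, with an illustrative input tensor:

    # Standalone sketch of the helper added in this patch, assuming TensorFlow 1.x.
    import tensorflow as tf

    def zero_padding_mask(inputs):
        # 1. where a time step has any non-zero feature, 0. where it is all zeros (padding)
        masks = tf.sign(tf.reduce_sum(tf.abs(inputs), axis=-1))  # (N, T)
        masks = tf.expand_dims(masks, -1)                        # (N, T, 1)
        masks = tf.tile(masks, [1, 1, tf.shape(inputs)[-1]])     # (N, T, d)
        return masks

    # One sequence, 3 time steps, 2 features; the last time step is zero padding.
    x = tf.constant([[[0.5, -1.0],
                      [2.0,  0.0],
                      [0.0,  0.0]]], dtype=tf.float32)  # (1, 3, 2)

    with tf.Session() as sess:
        print(sess.run(zero_padding_mask(x)))
        # [[[1. 1.]
        #   [1. 1.]
        #   [0. 0.]]]

The reason for multiplying Q, K, and V by this mask right after the projections appears to be that the dense layers now carry a bias (use_bias=False was dropped), so zero-padded time steps would otherwise no longer map to zero vectors; the elementwise multiply re-zeroes them before the heads are split.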