forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
attention_layer.py
148 lines (117 loc) · 5.4 KB
/
attention_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class Attention(tf.layers.Layer):
  """Multi-headed dot-product attention layer.

  Queries are projected from `x`, keys and values from `y`. Each projection
  is split into `num_heads` heads, scaled dot-product attention is computed
  per head, and the heads are merged and passed through a final linear
  projection.
  """

  def __init__(self, hidden_size, num_heads, attention_dropout, train):
    """Create the projection layers and store the configuration.

    Args:
      hidden_size: size of the last dimension of the inputs and outputs. Must
        be a multiple of `num_heads`.
      num_heads: number of attention heads.
      attention_dropout: dropout rate applied to the attention weights; the
        keep probability passed to dropout is `1.0 - attention_dropout`.
      train: if truthy, dropout is applied to the attention weights.

    Raises:
      ValueError: if `hidden_size` is not evenly divisible by `num_heads`.
    """
    if hidden_size % num_heads:
      raise ValueError("Hidden size must be evenly divisible by the number of "
                       "heads.")

    super(Attention, self).__init__()
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout
    self.train = train

    # Bias-free linear projections for the queries, keys and values, plus the
    # projection applied after the heads have been merged back together.
    self.q_dense_layer = tf.layers.Dense(
        hidden_size, use_bias=False, name="q")
    self.k_dense_layer = tf.layers.Dense(
        hidden_size, use_bias=False, name="k")
    self.v_dense_layer = tf.layers.Dense(
        hidden_size, use_bias=False, name="v")
    self.output_dense_layer = tf.layers.Dense(
        hidden_size, use_bias=False, name="output_transform")

  def split_heads(self, x):
    """Reshape the hidden dimension into per-head slices.

    The head axis is moved in front of the length axis to ensure the two
    innermost dimensions are the ones consumed by the batched matrix
    multiplications in `call`.

    Args:
      x: A tensor with shape [batch_size, length, hidden_size]

    Returns:
      A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
    """
    # Width of each head once the hidden dimension has been divided up.
    head_depth = self.hidden_size // self.num_heads

    with tf.name_scope("split_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[1]

      # [batch, length, hidden] --> [batch, length, num_heads, depth]
      per_head = tf.reshape(
          x, [batch_size, length, self.num_heads, head_depth])

      # --> [batch, num_heads, length, depth]
      return tf.transpose(per_head, [0, 2, 1, 3])

  def combine_heads(self, x):
    """Merge per-head slices back into one hidden dimension.

    This is the inverse of `split_heads`.

    Args:
      x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]

    Returns:
      A tensor with shape [batch_size, length, hidden_size]
    """
    with tf.name_scope("combine_heads"):
      batch_size = tf.shape(x)[0]
      length = tf.shape(x)[2]

      # [batch, num_heads, length, depth] --> [batch, length, num_heads, depth]
      heads_last = tf.transpose(x, [0, 2, 1, 3])
      return tf.reshape(heads_last, [batch_size, length, self.hidden_size])

  def call(self, x, y, bias, cache=None):
    """Attend from `x` (queries) over `y` (keys and values).

    Args:
      x: a tensor with shape [batch_size, length_x, hidden_size]
      y: a tensor with shape [batch_size, length_y, hidden_size]
      bias: attention bias that is added to the query/key dot products before
        the softmax.
      cache: (Used during prediction) dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, key_channels],
             "v": tensor with shape [batch_size, i, value_channels]}
        where i is the current decoded length. When given, the dictionary is
        updated in place with the extended keys and values.

    Returns:
      Attention layer output with shape [batch_size, length_x, hidden_size]
    """
    # Each of q, k and v has its own learned projection; the projected
    # tensors are later divided into heads so that every head attends with
    # its own slice of the projected features.
    queries = self.q_dense_layer(x)
    keys = self.k_dense_layer(y)
    values = self.v_dense_layer(y)

    if cache is not None:
      # Prepend the previously computed keys/values along the length axis.
      keys = tf.concat([cache["k"], keys], axis=1)
      values = tf.concat([cache["v"], values], axis=1)

      # Store the extended tensors so the caller's cache sees them next time.
      cache["k"] = keys
      cache["v"] = values

    # [batch, length, hidden] --> [batch, num_heads, length, depth]
    queries = self.split_heads(queries)
    keys = self.split_heads(keys)
    values = self.split_heads(values)

    # Scale the queries by 1/sqrt(depth) to keep the dot products with the
    # keys from growing too large.
    head_depth = self.hidden_size // self.num_heads
    queries = queries * (head_depth ** -0.5)

    # Dot-product attention: softmax(q . k^T + bias) . v
    logits = tf.matmul(queries, keys, transpose_b=True)
    logits = logits + bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if self.train:
      weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout)
    per_head_output = tf.matmul(weights, values)

    # Merge the heads --> [batch_size, length, hidden_size], then apply the
    # final linear projection.
    merged = self.combine_heads(per_head_output)
    return self.output_dense_layer(merged)
class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self, x, bias, cache=None):
    """Apply attention with `x` serving as both of the parent's inputs.

    Args:
      x: tensor passed to `Attention.call` as both its `x` and `y` arguments.
      bias: attention bias, forwarded unchanged.
      cache: optional cache dictionary, forwarded unchanged.

    Returns:
      The result of `Attention.call(x, x, bias, cache)`.
    """
    attend = super(SelfAttention, self).call
    return attend(x, x, bias, cache)