DecoderLayer.py
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from tensorflow import keras
from MultiHeadAttention import MultiHeadAttention
from FFN import feed_forward_network
# Decoder Layer
class DecoderLayer(keras.layers.Layer):
    """
    x -> self attention -> add & normalize & dropout -> out1
    out1, encoding_outputs -> attention -> add & normalize & dropout -> out2
    out2 -> feed_forward -> add & normalize & dropout -> out3
    """
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = feed_forward_network(d_model, dff)
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm3 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)
        self.dropout3 = keras.layers.Dropout(rate)

    def call(self, x, encoding_outputs, training, decoder_mask, encoder_decoder_padding_mask):
        # decoder_mask is the combination (element-wise AND) of the look_ahead_mask
        # and the decoder_padding_mask.
        # x.shape: (batch_size, target_seq_len, d_model)
        # encoding_outputs.shape: (batch_size, input_seq_len, d_model)

        # Masked self-attention over the decoder input.
        # attn1, out1.shape: (batch_size, target_seq_len, d_model)
        attn1, attn_weights1 = self.mha1(x, x, x, decoder_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layer_norm1(attn1 + x)

        # Encoder-decoder attention: queries come from the decoder,
        # keys/values from the encoder output.
        # attn2, out2.shape: (batch_size, target_seq_len, d_model)
        attn2, attn_weights2 = self.mha2(
            out1, encoding_outputs, encoding_outputs, encoder_decoder_padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layer_norm2(attn2 + out1)

        # Position-wise feed-forward network.
        # ffn_output, out3.shape: (batch_size, target_seq_len, d_model)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layer_norm3(ffn_output + out2)
        return out3, attn_weights1, attn_weights2
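
A minimal shape-check sketch of how the layer could be driven, kept under a main guard so the module still imports cleanly. It assumes the companion MultiHeadAttention returns (output, attention_weights), that feed_forward_network returns a Keras layer/model, and that both accept mask=None; the tensor sizes below are illustrative only.

if __name__ == "__main__":
    import tensorflow as tf

    # Hypothetical sizes for a quick smoke test (not from the original file).
    sample_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)
    x = tf.random.uniform((64, 26, 512))        # (batch_size, target_seq_len, d_model)
    enc_out = tf.random.uniform((64, 62, 512))  # (batch_size, input_seq_len, d_model)

    out, w1, w2 = sample_layer(
        x, enc_out, training=False,
        decoder_mask=None,                      # assumes the attention code tolerates mask=None
        encoder_decoder_padding_mask=None)
    print(out.shape)                            # expected: (64, 26, 512)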