layer_norm.py
import tensorflow as tf


class LayerNormalization(tf.keras.layers.Layer):
    """Applies layer normalization over the last dimension of the input."""

    def __init__(self):
        super(LayerNormalization, self).__init__()
        self.hidden_size = None

    def build(self, input_shape):
        # The normalized axis is the last dimension of the input.
        self.hidden_size = int(input_shape[-1])
        # Learnable scale (gamma) and bias (beta). They are kept in float32
        # even under mixed precision (experimental_autocast=False) for
        # numerical stability.
        self.gamma = self.add_weight(
            "layer_norm_scale",
            shape=[self.hidden_size],
            dtype="float32",
            initializer=tf.ones_initializer(),
            experimental_autocast=False)
        self.beta = self.add_weight(
            "layer_norm_bias",
            shape=[self.hidden_size],
            dtype="float32",
            initializer=tf.zeros_initializer(),
            experimental_autocast=False)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x, epsilon=1e-6, input_dtype=tf.float32):
        # Normalize over the last axis: subtract the mean and divide by the
        # standard deviation, with epsilon guarding against division by zero.
        mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)
        normalized = (x - mean) * tf.math.rsqrt(variance + epsilon)
        # Apply the learned affine transform and cast to the requested dtype.
        return tf.cast(normalized * self.gamma + self.beta, input_dtype)
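

# A minimal usage sketch, not part of the original file: the input shape and
# the tensors below are illustrative assumptions, chosen only to show the
# layer normalizing each hidden vector to roughly zero mean and unit variance.
if __name__ == "__main__":
    layer = LayerNormalization()
    x = tf.random.normal([2, 3, 4])  # assumed shape: batch 2, 3 positions, hidden size 4
    y = layer(x)  # same shape as x
    # gamma and beta are initialized to 1 and 0, so the affine step is
    # initially an identity; per-position means should be close to 0.
    print(tf.reduce_mean(y, axis=-1))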