Commit 9fd10e1

Merge pull request #42 from chhoumann/transformer
transformer (last minute, not refactored) impl. for web-app
2 parents c9138d6 + 6c8da93 commit 9fd10e1

15 files changed, with 521,541 additions and 1 deletion.
5 binary files changed (contents not shown; one is 1.05 MB).

code/web-app/app.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 from classes.lstm import LSTMModel, Optimization
 from classes.gru import GRUModel
 from classes.rnn import RNNModel
+from classes.transformer import *
 from rmse import RMSELoss
 from statsmodels.tsa.arima.model import ARIMA
 from predict_page import show_predict_page
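
The wildcard import above exposes the new Time2Vector, SingleAttention, MultiAttention, TransformerEncoder, and create_model names inside app.py. How the app actually uses them is not part of this diff; one plausible pattern, since the custom layers implement get_config "for saving and loading", is restoring a pre-trained transformer from disk. A minimal sketch, assuming a hypothetical saved-model file transformer.h5:

from tensorflow import keras
from classes.transformer import Time2Vector, SingleAttention, MultiAttention, TransformerEncoder

# Hypothetical artifact name; the real path is not shown in this commit.
MODEL_PATH = "transformer.h5"

# Keras must be told about the custom layer classes to rebuild the saved graph.
model = keras.models.load_model(
    MODEL_PATH,
    custom_objects={
        "Time2Vector": Time2Vector,
        "SingleAttention": SingleAttention,
        "MultiAttention": MultiAttention,
        "TransformerEncoder": TransformerEncoder,
    },
)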

code/web-app/classes/transformer.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
import numpy as np
import pandas as pd
import os, datetime
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from keras import backend as K
from tensorflow import keras
from tensorflow.keras import layers

batch_size = 10
seq_len = 20

d_k = 40
d_v = 40
n_heads = 6
ff_dim = 40

LAG = 10

class Time2Vector(Layer):
    def __init__(self, seq_len, **kwargs):
        super(Time2Vector, self).__init__()
        self.seq_len = seq_len

    def build(self, input_shape):
        '''Initialize weights and biases with shape (batch, seq_len)'''
        self.weights_linear = self.add_weight(name='weight_linear',
                                              shape=(int(self.seq_len),),
                                              initializer='uniform',
                                              trainable=True)

        self.bias_linear = self.add_weight(name='bias_linear',
                                           shape=(int(self.seq_len),),
                                           initializer='uniform',
                                           trainable=True)

        self.weights_periodic = self.add_weight(name='weight_periodic',
                                                shape=(int(self.seq_len),),
                                                initializer='uniform',
                                                trainable=True)

        self.bias_periodic = self.add_weight(name='bias_periodic',
                                             shape=(int(self.seq_len),),
                                             initializer='uniform',
                                             trainable=True)

    def call(self, x):
        '''Calculate linear and periodic time features'''
        x = tf.math.reduce_mean(x[:, :, :LAG], axis=-1)
        time_linear = self.weights_linear * x + self.bias_linear  # Linear time feature
        time_linear = tf.expand_dims(time_linear, axis=-1)  # Add dimension (batch, seq_len, 1)

        time_periodic = tf.math.sin(tf.multiply(x, self.weights_periodic) + self.bias_periodic)
        time_periodic = tf.expand_dims(time_periodic, axis=-1)  # Add dimension (batch, seq_len, 1)
        return tf.concat([time_linear, time_periodic], axis=-1)  # shape = (batch, seq_len, 2)

    def get_config(self):  # Needed for saving and loading model with custom layer
        config = super().get_config().copy()
        config.update({'seq_len': self.seq_len})
        return config


class SingleAttention(Layer):
    def __init__(self, d_k, d_v):
        super(SingleAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v

    def build(self, input_shape):
        self.query = Dense(self.d_k,
                           input_shape=input_shape,
                           kernel_initializer='glorot_uniform',
                           bias_initializer='glorot_uniform')

        self.key = Dense(self.d_k,
                         input_shape=input_shape,
                         kernel_initializer='glorot_uniform',
                         bias_initializer='glorot_uniform')

        self.value = Dense(self.d_v,
                           input_shape=input_shape,
                           kernel_initializer='glorot_uniform',
                           bias_initializer='glorot_uniform')

    def call(self, inputs):  # inputs = (in_seq, in_seq, in_seq)
        q = self.query(inputs[0])
        k = self.key(inputs[1])

        attn_weights = tf.matmul(q, k, transpose_b=True)
        attn_weights = tf.map_fn(lambda x: x / np.sqrt(self.d_k), attn_weights)
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        v = self.value(inputs[2])
        attn_out = tf.matmul(attn_weights, v)
        return attn_out


class MultiAttention(Layer):
    def __init__(self, d_k, d_v, n_heads):
        super(MultiAttention, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.attn_heads = list()

    def build(self, input_shape):
        for n in range(self.n_heads):
            self.attn_heads.append(SingleAttention(self.d_k, self.d_v))
        self.linear = Dense(LAG + 3, input_shape=input_shape, kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')

    def call(self, inputs):
        attn = [self.attn_heads[i](inputs) for i in range(self.n_heads)]
        concat_attn = tf.concat(attn, axis=-1)
        multi_linear = self.linear(concat_attn)
        return multi_linear


class TransformerEncoder(Layer):
    def __init__(self, d_k, d_v, n_heads, ff_dim, dropout=0.3, **kwargs):
        super(TransformerEncoder, self).__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.attn_heads = list()
        self.dropout_rate = dropout

    def build(self, input_shape):
        self.attn_multi = MultiAttention(self.d_k, self.d_v, self.n_heads)
        self.attn_dropout = Dropout(self.dropout_rate)
        self.attn_normalize = LayerNormalization(input_shape=input_shape, epsilon=1e-7)
        self.ff_conv1D_1 = Conv1D(filters=self.ff_dim, kernel_size=1, activation='relu',
                                  kernel_initializer='glorot_uniform', bias_initializer='glorot_uniform')
        self.ff_conv1D_2 = Conv1D(filters=LAG + 3, kernel_size=1, kernel_initializer='glorot_uniform',
                                  bias_initializer='glorot_uniform')  # input_shape[0]=(batch, seq_len, 8), input_shape[0][-1]=8
        self.ff_dropout = Dropout(self.dropout_rate)
        self.ff_normalize = LayerNormalization(input_shape=input_shape, epsilon=1e-7)

    def call(self, inputs):  # inputs = (in_seq, in_seq, in_seq)
        attn_layer = self.attn_multi(inputs)
        attn_layer = self.attn_dropout(attn_layer)
        attn_layer = self.attn_normalize(inputs[0] + attn_layer)

        ff_layer = self.ff_conv1D_1(attn_layer)
        ff_layer = self.ff_conv1D_2(ff_layer)
        ff_layer = self.ff_dropout(ff_layer)
        ff_layer = self.ff_normalize(inputs[0] + ff_layer)
        return ff_layer

    def get_config(self):  # Needed for saving and loading model with custom layer
        config = super().get_config().copy()
        config.update({'d_k': self.d_k,
                       'd_v': self.d_v,
                       'n_heads': self.n_heads,
                       'ff_dim': self.ff_dim,
                       'attn_heads': self.attn_heads,
                       'dropout_rate': self.dropout_rate})
        return config


def create_model():
    time_embedding = Time2Vector(seq_len)
    attn_layer1 = TransformerEncoder(d_k, d_v, n_heads, ff_dim)
    attn_layer2 = TransformerEncoder(d_k, d_v, n_heads, ff_dim)
    attn_layer3 = TransformerEncoder(d_k, d_v, n_heads, ff_dim)

    in_seq = Input(shape=(seq_len, LAG + 1))
    x = time_embedding(in_seq)
    x = Concatenate(axis=-1)([in_seq, x])
    x = attn_layer1((x, x, x))
    x = attn_layer2((x, x, x))
    x = attn_layer3((x, x, x))
    x = GlobalAveragePooling1D(data_format='channels_first')(x)
    x = Dropout(0.3)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.3)(x)
    out = Dense(1, activation='linear')(x)

    model = Model(inputs=in_seq, outputs=out)
    optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07)
    # optimizer = keras.optimizers.SGD(learning_rate=0.02, momentum=0.9, nesterov=True, clipnorm=1.0, clipvalue=0.5)
    model.compile(loss='mse', optimizer=optimizer, metrics=['mae'])
    return model
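
create_model wires a Time2Vector embedding and three TransformerEncoder blocks over an input window of shape (seq_len, LAG + 1) and compiles with Adam and MSE loss. A minimal smoke-test sketch, using random stand-in arrays rather than the app's real lagged windows:

import numpy as np
from classes.transformer import create_model, seq_len, LAG, batch_size

model = create_model()
model.summary()

# Random placeholders purely to exercise the graph; real training data would be
# windows of shape (n_samples, seq_len, LAG + 1) with one target value per window.
X_dummy = np.random.rand(64, seq_len, LAG + 1).astype("float32")
y_dummy = np.random.rand(64, 1).astype("float32")

model.fit(X_dummy, y_dummy, batch_size=batch_size, epochs=1, verbose=0)
preds = model.predict(X_dummy[:5])  # shape (5, 1)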
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
date,meantemp,humidity,wind_speed,meanpressure
2017-01-01,15.91304347826087,85.8695652173913,2.743478260869565,59.0
2017-01-02,18.5,77.22222222222223,2.8944444444444444,1018.2777777777778
2017-01-03,17.11111111111111,81.88888888888889,4.016666666666667,1018.3333333333334
2017-01-04,18.7,70.05,4.545,1015.7
2017-01-05,18.38888888888889,74.94444444444444,3.3000000000000003,1014.3333333333334
2017-01-06,19.318181818181817,79.31818181818181,8.681818181818182,1011.7727272727273
2017-01-07,14.708333333333334,95.83333333333333,10.041666666666664,1011.375
2017-01-08,15.68421052631579,83.52631578947368,1.95,1015.55
2017-01-09,14.571428571428571,80.80952380952381,6.542857142857142,1015.952380952381
2017-01-10,12.11111111111111,71.94444444444444,9.361111111111109,1016.8888888888889
2017-01-11,11.0,72.11111111111111,9.77222222222222,1016.7777777777778
2017-01-12,11.789473684210526,74.57894736842105,6.626315789473684,1016.3684210526316
2017-01-13,13.235294117647058,67.05882352941177,6.435294117647059,1017.5294117647059
2017-01-14,13.2,74.28,5.276,1018.84
2017-01-15,16.434782608695652,72.56521739130434,3.630434782608696,1018.1304347826087
2017-01-16,14.65,78.45,10.38,1017.15
2017-01-17,11.722222222222221,84.44444444444444,8.038888888888888,1018.3888888888889
2017-01-18,13.041666666666666,78.33333333333333,6.029166666666664,1021.9583333333334
2017-01-19,14.619047619047619,75.14285714285714,10.338095238095239,1022.8095238095239
2017-01-20,15.263157894736842,66.47368421052632,11.226315789473684,1021.7894736842105
2017-01-21,15.391304347826088,70.8695652173913,13.695652173913043,1020.4782608695652
2017-01-22,18.44,76.24,5.8679999999999986,1021.04
2017-01-23,18.11764705882353,76.0,6.752941176470588,1019.8235294117648
2017-01-24,18.347826086956523,68.1304347826087,3.3913043478260865,1018.8695652173913
2017-01-25,21.0,69.96,8.755999999999998,1018.4
2017-01-26,16.178571428571427,91.64285714285714,8.467857142857143,1017.7857142857143
2017-01-27,16.5,77.04166666666667,14.358333333333333,1018.125
2017-01-28,14.863636363636363,82.77272727272727,9.690909090909093,1019.6363636363636
2017-01-29,15.666666666666666,81.77777777777777,10.294444444444444,1017.3888888888889
2017-01-30,16.444444444444443,77.55555555555556,4.322222222222222,1015.8333333333334
2017-01-31,16.125,76.0,4.625,1015.5
2017-02-01,15.25,78.625,5.1000000000000005,1017.5
2017-02-02,17.09090909090909,66.54545454545455,3.027272727272727,1018.9090909090909
2017-02-03,15.636363636363637,78.18181818181819,1.8545454545454545,1017.7272727272727
2017-02-04,18.7,77.6,9.819999999999999,1014.4
2017-02-05,18.63157894736842,77.63157894736842,8.099999999999998,1014.2105263157895
2017-02-06,16.88888888888889,69.66666666666667,9.044444444444444,1016.0
2017-02-07,15.125,63.75,7.637500000000001,1016.125
2017-02-08,15.7,68.4,4.08,1015.6
2017-02-09,15.375,68.375,7.875000000000002,1016.375
2017-02-10,14.666666666666666,71.77777777777777,9.066666666666666,1015.6666666666666
2017-02-11,15.625,64.0,3.95,1016.625
2017-02-12,16.25,70.375,1.625,1019.625
2017-02-13,16.333333333333332,67.0,6.377777777777778,1021.5555555555555
2017-02-14,16.875,65.5,6.9625,1021.375
2017-02-15,17.571428571428573,67.71428571428571,5.557142857142857,1020.5714285714286
2017-02-16,20.25,56.75,10.4375,1017.625
2017-02-17,21.3,64.4,9.279999999999998,1016.5
2017-02-18,21.125,70.75,6.25,1016.25
2017-02-19,22.363636363636363,66.0909090909091,6.054545454545456,1013.0
2017-02-20,23.375,60.125,6.937499999999999,1005.375
2017-02-21,21.833333333333332,69.41666666666667,12.341666666666667,1007.4166666666666
2017-02-22,19.125,57.125,7.4125000000000005,1012.25
2017-02-23,18.625,42.875,14.35,1015.25
2017-02-24,19.125,40.375,16.6625,1016.125
2017-02-25,19.0,50.42857142857143,11.928571428571427,1014.2857142857143
2017-02-26,18.75,59.0,11.1125,1012.375
2017-02-27,19.875,58.375,5.1000000000000005,1014.25
2017-02-28,23.333333333333332,51.666666666666664,3.9111111111111114,1013.1111111111111
2017-03-01,24.46153846153846,47.92307692307692,6.415384615384617,1012.9230769230769
2017-03-02,23.75,54.25,5.930000000000001,1012.15
2017-03-03,20.5,42.5,7.4125000000000005,1010.625
2017-03-04,19.125,43.125,8.350000000000001,1010.0
2017-03-05,19.75,41.25,9.962499999999999,1010.5
2017-03-06,20.0,42.44444444444444,9.666666666666664,1010.3333333333334
2017-03-07,22.625,41.5,6.025,1007.375
2017-03-08,21.545454545454547,52.72727272727273,10.263636363636364,1008.9090909090909
2017-03-09,20.785714285714285,69.07142857142857,8.342857142857143,1007.3571428571429
2017-03-10,19.9375,67.75,11.4625,1006.875
2017-03-11,18.533333333333335,60.4,5.566666666666666,1009.8
2017-03-12,17.375,56.625,7.637499999999999,1014.75
2017-03-13,17.444444444444443,49.333333333333336,9.055555555555554,1014.8888888888889
2017-03-14,18.0,56.333333333333336,4.522222222222222,1016.5555555555555
2017-03-15,19.875,54.75,7.175000000000001,1014.125
2017-03-16,24.0,49.2,5.5600000000000005,1011.1
2017-03-17,20.9,59.7,11.489999999999998,1010.7
2017-03-18,24.692307692307693,46.30769230769231,7.1230769230769235,1009.8461538461538
2017-03-19,24.666666666666668,52.27777777777778,9.161111111111111,1011.8888888888889
2017-03-20,23.333333333333332,54.666666666666664,10.077777777777778,1012.5555555555555
2017-03-21,25.0,49.0,9.2625,1011.75
2017-03-22,27.25,45.0,10.187500000000002,1009.75
2017-03-23,28.0,49.75,3.4875000000000003,1008.875
2017-03-24,28.916666666666668,37.666666666666664,10.033333333333335,1010.5833333333334
2017-03-25,26.5,39.375,10.425,1009.875
2017-03-26,29.1,37.1,17.59,1010.2
2017-03-27,29.5,38.625,13.65,1009.5
2017-03-28,29.88888888888889,40.666666666666664,8.844444444444445,1009.0
2017-03-29,31.0,34.5,13.2,1007.125
2017-03-30,29.285714285714285,36.857142857142854,10.585714285714285,1007.1428571428571
2017-03-31,30.625,37.625,6.949999999999999,1007.5
2017-04-01,31.375,35.125,9.0375,1005.0
2017-04-02,29.75,33.75,9.2625,1004.25
2017-04-03,30.5,29.75,6.9375,1004.25
2017-04-04,30.933333333333334,31.866666666666667,14.319999999999999,1007.2
2017-04-05,29.23076923076923,46.0,14.384615384615387,1005.0
2017-04-06,31.22222222222222,26.0,13.577777777777776,1002.8888888888889
2017-04-07,27.0,29.875,4.65,1007.375
2017-04-08,25.625,29.375,8.337499999999999,1010.375
2017-04-09,27.125,21.125,14.125,1010.625
2017-04-10,27.857142857142858,19.428571428571427,19.314285714285713,1008.5714285714286
2017-04-11,29.25,17.75,15.512500000000001,1006.25
2017-04-12,29.25,26.0,9.4875,1005.875
2017-04-13,29.666666666666668,29.11111111111111,4.944444444444445,1006.7777777777778
2017-04-14,30.5,37.625,1.3875000000000002,1004.625
2017-04-15,31.22222222222222,30.444444444444443,5.966666666666667,1002.4444444444445
2017-04-16,31.0,34.25,2.0999999999999996,1003.25
2017-04-17,32.55555555555556,38.44444444444444,5.366666666666666,1004.4444444444445
2017-04-18,34.0,27.333333333333332,7.811111111111111,1003.1111111111111
2017-04-19,33.5,24.125,9.025,1000.875
2017-04-20,34.5,27.5,5.5625,998.625
2017-04-21,34.25,39.375,6.9625,999.875
2017-04-22,32.9,40.9,8.89,1001.6
2017-04-23,32.875,27.5,9.962499999999999,1002.125
2017-04-24,32.0,27.142857142857142,12.157142857142858,1004.1428571428571
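
The model above consumes windows of shape (seq_len, LAG + 1). One way the daily series in this CSV could be shaped into such windows is sketched below; the target column (meantemp), the file name, and the lagging scheme are illustrative assumptions, not something this commit specifies:

import numpy as np
import pandas as pd
from classes.transformer import seq_len, LAG

# Hypothetical file name for the CSV added in this commit.
df = pd.read_csv("climate.csv", parse_dates=["date"])

# LAG lagged copies of the target plus the current value -> LAG + 1 features per timestep.
series = df["meantemp"]
features = pd.concat([series.shift(i) for i in range(LAG, 0, -1)] + [series], axis=1).dropna().to_numpy()

# Slide a seq_len-long window over the feature rows; the target is the next day's value.
X = np.stack([features[i:i + seq_len] for i in range(len(features) - seq_len)])
y = series.to_numpy()[LAG + seq_len:]
print(X.shape, y.shape)  # (n_windows, seq_len, LAG + 1), (n_windows,)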

0 commit comments
