-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathimbd_qrnn.py
54 lines (46 loc) · 1.84 KB
/
imbd_qrnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
'''Trains a QRNN on the IMDB sentiment classification task.
Modified from keras' examples/imdb_lstm.py.
'''
from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
from keras.preprocessing import sequence
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Embedding
from keras.layers import LSTM, SimpleRNN, GRU
from keras.regularizers import l2
from keras.constraints import maxnorm
from keras.datasets import imdb
from qrnn import QRNN
# Hyper-parameters shared by the data pipeline and the model.
max_features = 20000  # passed to the loader as nb_words and to Embedding as its input size
maxlen = 80  # cut texts after this number of words (among top max_features most common words)
batch_size = 32

print('Loading data...')
(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features)
# Report how many sequences each split contains.
for split_size, split_label in ((len(X_train), 'train sequences'),
                                (len(X_test), 'test sequences')):
    print(split_size, split_label)

print('Pad sequences (samples x time)')
# Apply the same maxlen to both splits.
X_train, X_test = (
    sequence.pad_sequences(reviews, maxlen=maxlen)
    for reviews in (X_train, X_test)
)
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)
print('Build model...')
# Stack: Embedding -> QRNN -> Dense(1) -> sigmoid.
# The layers are instantiated in the same order in which they are added.
model = Sequential()
embedding_layer = Embedding(max_features, 128, dropout=0.2)
# L2 penalties and max-norm constraints are passed for both the W_* and b_*
# parameters of the QRNN layer.
recurrent_layer = QRNN(
    128,
    window_size=3,
    dropout=0.2,
    W_regularizer=l2(1e-4),
    b_regularizer=l2(1e-4),
    W_constraint=maxnorm(10),
    b_constraint=maxnorm(10)
)
output_layer = Dense(1)
output_activation = Activation('sigmoid')
for layer in (embedding_layer, recurrent_layer, output_layer, output_activation):
    model.add(layer)

# The optimizer (and its configuration) is a good knob to experiment with.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)
print('Train...')
# NOTE: the test split is also used as validation data during training, so
# the per-epoch validation figures and the final test figures come from the
# same samples.
fit_options = dict(
    batch_size=batch_size,
    nb_epoch=15,
    validation_data=(X_test, y_test)
)
model.fit(X_train, y_train, **fit_options)

# evaluate() yields the loss followed by the accuracy metric requested at
# compile time.
score, acc = model.evaluate(X_test, y_test, batch_size=batch_size)
for report_label, report_value in (('Test score:', score),
                                   ('Test accuracy:', acc)):
    print(report_label, report_value)