-
Notifications
You must be signed in to change notification settings - Fork 45
/
example_mnist_conv.py
53 lines (45 loc) · 1.86 KB
/
example_mnist_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
from network import Network
from fc_layer import FCLayer
from conv_layer import ConvLayer
from flatten_layer import FlattenLayer
from activation_layer import ActivationLayer
from activations import tanh, tanh_prime
from losses import mse, mse_prime
from keras.datasets import mnist
from keras.utils import np_utils
# Load MNIST from the Keras dataset server: 60000 training samples and
# 10000 test samples.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Images: reshape every sample to (28, 28, 1) -- the input shape the ConvLayer
# below is built with -- convert to float32 and divide by 255 to normalize.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255

# Labels: encode each digit in range [0, 9] as a vector of size 10,
# e.g. number 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
# num_classes is pinned to 10 so the vector width always matches the 10-unit
# output layer instead of being inferred from the largest label present.
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)
# Network: convolution -> tanh -> flatten -> fully connected -> tanh
#          -> fully connected -> tanh
image_shape = (28, 28, 1)
kernel_shape = (3, 3)
flat_size = 26 * 26 * 1
hidden_size = 100
class_count = 10

net = Network()

# input_shape=(28, 28, 1) ; output_shape=(26, 26, 1)
# NOTE(review): the trailing 1 is presumably the number of output channels,
# given the (26, 26, 1) output shape -- confirm against conv_layer.py.
net.add(ConvLayer(image_shape, kernel_shape, 1))
net.add(ActivationLayer(tanh, tanh_prime))

# input_shape=(26, 26, 1) ; output_shape=(1, 26*26*1)
net.add(FlattenLayer())

# input_shape=(1, 26*26*1) ; output_shape=(1, 100)
net.add(FCLayer(flat_size, hidden_size))
net.add(ActivationLayer(tanh, tanh_prime))

# input_shape=(1, 100) ; output_shape=(1, 10)
net.add(FCLayer(hidden_size, class_count))
net.add(ActivationLayer(tanh, tanh_prime))
# Train on the first 1000 samples only: mini-batch gradient descent is not
# implemented, so updating at each iteration over all 60000 samples would make
# training pretty slow.
train_count = 1000

# mse / mse_prime come from the losses module (presumably the loss function
# and its derivative -- confirm against losses.py).
net.use(mse, mse_prime)
net.fit(
    x_train[:train_count],
    y_train[:train_count],
    epochs=100,
    learning_rate=0.1,
)
# Run the trained network on the first 3 test samples and show its output
# next to the matching one-hot labels.
sample_count = 3
out = net.predict(x_test[:sample_count])

# print("\n") emits two newlines (the literal plus print's own line ending),
# which separates this report from whatever was printed before it.
print("\n")
# sep="\n" puts the label and the values on separate lines, exactly as two
# consecutive print() calls would.
print("predicted values : ", out, sep="\n")
print("true values : ", y_test[:sample_count], sep="\n")