mnist_cnn_train.py
# Some code was borrowed from https://github.com/petewarden/tensorflow_makefile/blob/master/tensorflow/models/image/mnist/convolutional.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy
import tensorflow as tf
import tensorflow.contrib.slim as slim

import mnist_data
import cnn_model
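# NOTE: mnist_data and cnn_model are local modules that ship alongside this
# script in the same repository; they are not pip packages. The summary API
# used below (tf.scalar_summary, tf.merge_all_summaries, tf.train.SummaryWriter)
# is the pre-1.0 TensorFlow naming, so this script assumes an older (~0.12-era)
# TensorFlow build with tf.contrib.slim available.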
MODEL_DIRECTORY = "model/model.ckpt"
LOGS_DIRECTORY = "logs/train"

# Params for training
training_epochs = 10  # 10 for augmented training data, 20 for training data
TRAIN_BATCH_SIZE = 50
display_step = 100
validation_step = 500

# Params for test
TEST_BATCH_SIZE = 5000

def train():
    # Some parameters
    batch_size = TRAIN_BATCH_SIZE
    num_labels = mnist_data.NUM_LABELS

    # Prepare MNIST data
    train_total_data, train_size, validation_data, validation_labels, test_data, test_labels = mnist_data.prepare_MNIST_data(True)
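    # prepare_MNIST_data(True) presumably returns the augmented training set
    # (cf. the training_epochs comment above). train_total_data packs each
    # flattened 784-pixel image and its one-hot label into a single row, which
    # is why it can be shuffled as one array and split with -num_labels below.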
    # Boolean for MODE of train or test
    is_training = tf.placeholder(tf.bool, name='MODE')

    # tf Graph input
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])  # answer

    # Predict
    y = cnn_model.CNN(x)
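    # Note: the MODE placeholder is fed on every sess.run call below, but CNN()
    # is called here without it; whether it actually switches dropout/batch-norm
    # behaviour between training and evaluation depends on how cnn_model.py is
    # written (not shown in this file).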
    # Get loss of model
    with tf.name_scope("LOSS"):
        loss = slim.losses.softmax_cross_entropy(y, y_)

    # Create a summary to monitor loss tensor
    tf.scalar_summary('loss', loss)

    # Define optimizer
    with tf.name_scope("ADAM"):
        # Optimizer: set up a variable that's incremented once per batch and
        # controls the learning rate decay.
        batch = tf.Variable(0)

        learning_rate = tf.train.exponential_decay(
            1e-4,                # Base learning rate.
            batch * batch_size,  # Current index into the dataset.
            train_size,          # Decay step.
            0.95,                # Decay rate.
            staircase=True)

        # Use Adam for the optimization; `batch` is incremented once per step.
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=batch)

    # Create a summary to monitor learning_rate tensor
    tf.scalar_summary('learning_rate', learning_rate)
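    # With staircase=True the schedule is 1e-4 * 0.95**floor(examples_seen / train_size),
    # i.e. the learning rate drops by 5% after each full pass over the training data.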
    # Get accuracy of model
    with tf.name_scope("ACC"):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Create a summary to monitor accuracy tensor
    tf.scalar_summary('acc', accuracy)

    # Merge all summaries into a single op
    merged_summary_op = tf.merge_all_summaries()

    # Add ops to save and restore all the variables
    saver = tf.train.Saver()

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer(), feed_dict={is_training: True})
    # Training cycle
    total_batch = int(train_size / batch_size)

    # op to write logs to TensorBoard
    summary_writer = tf.train.SummaryWriter(LOGS_DIRECTORY, graph=tf.get_default_graph())

    # Save the maximum accuracy value for validation data
    max_acc = 0.
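    # Checkpointing strategy: only the best-so-far validation accuracy is kept
    # (see the saver.save call below), and that checkpoint is restored for the
    # final test evaluation. The loss/accuracy/learning-rate curves written to
    # LOGS_DIRECTORY can be inspected with: tensorboard --logdir=logs/train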
    # Loop for epoch
    for epoch in range(training_epochs):

        # Random shuffling
        numpy.random.shuffle(train_total_data)
        train_data_ = train_total_data[:, :-num_labels]
        train_labels_ = train_total_data[:, -num_labels:]
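        # Shuffling the packed array keeps every image row aligned with its
        # one-hot label; the two views above are just column slices of it.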
        # Loop over all batches
        for i in range(total_batch):

            # Compute the offset of the current minibatch in the data.
            offset = (i * batch_size) % (train_size)
            batch_xs = train_data_[offset:(offset + batch_size), :]
            batch_ys = train_labels_[offset:(offset + batch_size), :]

            # Run optimization op (backprop), loss op (to get loss value)
            # and summary nodes
            _, train_accuracy, summary = sess.run([train_step, accuracy, merged_summary_op],
                                                  feed_dict={x: batch_xs, y_: batch_ys, is_training: True})
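            # train_accuracy is measured on this 50-example mini-batch only,
            # so it is noisier than the validation accuracy printed below.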
            # Write logs at every iteration
            summary_writer.add_summary(summary, epoch * total_batch + i)

            # Display logs
            if i % display_step == 0:
                print("Epoch:", '%04d,' % (epoch + 1),
                      "batch_index %4d/%4d, training accuracy %.5f" % (i, total_batch, train_accuracy))

            # Get accuracy for validation data
            if i % validation_step == 0:
                # Calculate accuracy
                validation_accuracy = sess.run(accuracy,
                                               feed_dict={x: validation_data, y_: validation_labels, is_training: False})

                print("Epoch:", '%04d,' % (epoch + 1),
                      "batch_index %4d/%4d, validation accuracy %.5f" % (i, total_batch, validation_accuracy))

                # Save the current model if the maximum accuracy is updated
                if validation_accuracy > max_acc:
                    max_acc = validation_accuracy
                    save_path = saver.save(sess, MODEL_DIRECTORY)
                    print("Model updated and saved in file: %s" % save_path)
print("Optimization Finished!")
# Restore variables from disk
saver.restore(sess, MODEL_DIRECTORY)
# Calculate accuracy for all mnist test images
test_size = test_labels.shape[0]
batch_size = TEST_BATCH_SIZE
total_batch = int(test_size / batch_size)
acc_buffer = []
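    # The standard MNIST test set has 10,000 images, so with TEST_BATCH_SIZE =
    # 5000 this is exactly two batches. If the batch size did not divide the
    # test set evenly, the integer division above would silently drop the
    # remainder.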
    # Loop over all batches
    for i in range(total_batch):
        # Compute the offset of the current minibatch in the data.
        offset = (i * batch_size) % (test_size)
        batch_xs = test_data[offset:(offset + batch_size), :]
        batch_ys = test_labels[offset:(offset + batch_size), :]

        y_final = sess.run(y, feed_dict={x: batch_xs, y_: batch_ys, is_training: False})
        correct_prediction = numpy.equal(numpy.argmax(y_final, 1), numpy.argmax(batch_ys, 1))
        acc_buffer.append(numpy.sum(correct_prediction) / batch_size)

    print("test accuracy for the stored model: %g" % numpy.mean(acc_buffer))

if __name__ == '__main__':
    train()
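# Usage (inferred from the constants above, not stated in the original file):
# run `python mnist_cnn_train.py` with mnist_data.py and cnn_model.py in the
# same directory; training summaries go to logs/train and the best-on-
# validation checkpoint is written to model/model.ckpt.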