# Classifies MNIST digits with a convolutional network.
# Writes out saved model to the file "mnist_conv.bson".
# Demonstrates basic model construction, training, saving,
# conditional early-exit, and learning rate scheduling.
#
# This model, while simple, should hit around 99% test
# accuracy after training for approximately 20 epochs.
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy
using Base.Iterators: repeated, partition
using Printf, BSON
using CUDAapi
if has_cuda()
    @info "CUDA is on"
    import CuArrays
    CuArrays.allowscalar(false)
end
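# Note: if no CUDA-capable GPU is found, CuArrays is never imported and the
# `gpu(...)` calls further down fall back to the identity, so the script
# should still run (more slowly) on the CPU.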
# Load labels and images from Flux.Data.MNIST
@info("Loading data set")
train_labels = MNIST.labels()
train_imgs = MNIST.images()
# Bundle images together with labels and group into minibatches
function make_minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
    for i in 1:length(idxs)
        X_batch[:, :, :, i] = Float32.(X[idxs[i]])
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end
batch_size = 128
mb_idxs = partition(1:length(train_imgs), batch_size)
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]
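# Each entry of `train_set` should now be a (data, labels) tuple: with the
# default batch size of 128 that is a 28x28x1x128 Float32 array paired with a
# 10x128 one-hot matrix (the final minibatch may be smaller).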
# Prepare test set as one giant minibatch:
test_imgs = MNIST.images(:test)
test_labels = MNIST.labels(:test)
test_set = make_minibatch(test_imgs, test_labels, 1:length(test_imgs))
# Define our model. We will use a simple convolutional architecture with
# three iterations of Conv -> ReLU -> MaxPool, followed by a final Dense
# layer that feeds into a softmax probability output.
@info("Constructing model...")
model = Chain(
    # First convolution, operating upon a 28x28 image
    Conv((3, 3), 1=>16, pad=(1,1), relu),
    MaxPool((2,2)),

    # Second convolution, operating upon a 14x14 image
    Conv((3, 3), 16=>32, pad=(1,1), relu),
    MaxPool((2,2)),

    # Third convolution, operating upon a 7x7 image
    Conv((3, 3), 32=>32, pad=(1,1), relu),
    MaxPool((2,2)),

    # Reshape the 4d tensor into a 2d one; at this point it should be (3, 3, 32, N),
    # which is where we get the 288 in the `Dense` layer below:
    x -> reshape(x, :, size(x, 4)),
    Dense(288, 10),

    # Finally, softmax to get nice probabilities
    softmax,
)
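# Quick sanity check on the sizes above: each 3x3 convolution with pad=(1,1)
# preserves the spatial size, while each MaxPool((2,2)) halves it (rounding
# down), so 28 -> 14 -> 7 -> 3 and the flattened feature vector has
# 3*3*32 = 288 entries. One (hypothetical) way to verify this interactively
# would be to run the three conv/pool blocks on a dummy image, e.g.:
#
#   size(model[1:6](rand(Float32, 28, 28, 1, 1)))   # expected: (3, 3, 32, 1)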
# Load model and datasets onto GPU, if enabled
train_set = gpu.(train_set)
test_set = gpu.(test_set)
model = gpu(model)
# Make sure our model is nicely precompiled before starting our training loop
model(train_set[1][1])
# We augment `x` a little bit here, adding in random noise.
augment(x) = x .+ gpu(0.1f0*randn(eltype(x), size(x)))
paramvec(m) = vcat(map(p->reshape(p, :), params(m))...)
anynan(x) = any(isnan.(x))
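# `paramvec` flattens every parameter array into one long vector and `anynan`
# checks it for NaNs; together they are used in the training loop below to
# abort if the parameters diverge.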
# `loss()` calculates the crossentropy loss between our prediction `y_hat`
# (calculated from `model(x)`) and the ground truth `y`. We augment the data
# a bit, adding gaussian random noise to our image to make it more robust.
function loss(x, y)
    x̂ = augment(x)
    ŷ = model(x̂)
    return crossentropy(ŷ, y)
end
accuracy(x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
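# As a quick sanity check before training (not part of the loop below), both
# functions can be evaluated on a single minibatch, e.g.:
#
#   loss(train_set[1]...)
#   accuracy(train_set[1]...)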
# Train our model with the given training set using the ADAM optimizer and
# printing out performance against the test set as we go.
opt = ADAM(0.001)
@info("Beginning training loop...")
best_acc = 0.0
last_improvement = 0
for epoch_idx in 1:100
    global best_acc, last_improvement
    # Train for a single epoch
    Flux.train!(loss, params(model), train_set, opt)
    if anynan(paramvec(model))
        @error "NaN params"
        break
    end

    # Calculate accuracy:
    acc = accuracy(test_set...)
    @info(@sprintf("[%d]: Test accuracy: %.4f", epoch_idx, acc))

    # If our accuracy is good enough, quit out.
    if acc >= 0.999
        @info(" -> Early-exiting: We reached our target accuracy of 99.9%")
        break
    end

    # If this is the best accuracy we've seen so far, save the model out
    if acc >= best_acc
        @info(" -> New best accuracy! Saving model out to mnist_conv.bson")
        BSON.@save joinpath(dirname(@__FILE__), "mnist_conv.bson") params=cpu.(params(model)) epoch_idx acc
        best_acc = acc
        last_improvement = epoch_idx
    end

    # If we haven't seen improvement in 5 epochs, drop our learning rate:
    if epoch_idx - last_improvement >= 5 && opt.eta > 1e-6
        opt.eta /= 10.0
        @warn(" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!")

        # After dropping learning rate, give it a few epochs to improve
        last_improvement = epoch_idx
    end

    if epoch_idx - last_improvement >= 10
        @warn(" -> We're calling this converged.")
        break
    end
end
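# A minimal sketch (not part of the training loop above) of how the saved
# weights could be loaded back for inference, assuming the same `model`
# definition from above is in scope and lives on the same device the
# parameters were saved from:
#
#   using BSON: @load
#   @load joinpath(dirname(@__FILE__), "mnist_conv.bson") params
#   Flux.loadparams!(model, params)
#   @show accuracy(test_set...)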