-
Notifications
You must be signed in to change notification settings - Fork 75
/
exercise3.m
140 lines (109 loc) · 3.69 KB
/
exercise3.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
% -------------------------------------------------------------------------
% Part 3: Learning a simple CNN
% -------------------------------------------------------------------------
% Run the setup routine (defined outside this file) before anything else
setup ;

% -------------------------------------------------------------------------
% Part 3.1: Load an example image and generate its labels
% -------------------------------------------------------------------------

% Read the example image, convert it to single precision, then reduce it
% to a single grayscale channel
im = imread('data/dots.jpg') ;
im = im2single(im) ;
im = rgb2gray(im) ;

% Locate the black blobs in the image. pos and neg are used further down
% as masks selecting the positive (blob centre) and negative (not a blob)
% pixels.
[pos, neg] = extractBlackBlobs(im) ;

% Show the image side by side with its two label maps
figure(1) ;
clf ;

subplot(1, 3, 1) ;
imagesc(im) ;
axis('equal') ;
title('image') ;

subplot(1, 3, 2) ;
imagesc(pos) ;
axis('equal') ;
title('positive points (blob centres)') ;

subplot(1, 3, 3) ;
imagesc(neg) ;
axis('equal') ;
title('negative points (not a blob)') ;

colormap('gray') ;
% -------------------------------------------------------------------------
% Part 3.2: Image preprocessing
% -------------------------------------------------------------------------
% Smooth the image before learning (imsmooth is called with parameter 3)
im = imsmooth(im, 3) ;

% Centre the intensities: take the median over all pixels of the smoothed
% image and subtract it from every pixel
im = im - median(reshape(im, [], 1)) ;
% -------------------------------------------------------------------------
% Part 3.3: Learning with stochastic gradient descent
% -------------------------------------------------------------------------
% SGD parameters:
% - numIterations: maximum number of iterations
% - rate: learning rate
% - momentum: momentum rate
% - shrinkRate: shrinkage rate (or coefficient of the L2 regulariser)
% - plotPeriod: how often to plot (in iterations)
numIterations = 500 ;
rate = 5 ;
momentum = 0.9 ;
shrinkRate = 0.0001 ;
plotPeriod = 10 ;

% Initial CNN parameters: a random 3x3 filter shifted to have zero mean,
% and a zero bias, both stored in single precision
w = 10 * randn(3, 3, 1) ;
w = single(w - mean(w(:))) ;
b = single(0) ;

% Create pixel-level labels to compute the loss:
% +1 at positive points, -1 at negative points, 0 everywhere else
y = zeros(size(pos),'single') ;
y(pos) = +1 ;
y(neg) = -1 ;

% Initial momentum: zero buffers with the same size and class as the
% parameters they accumulate updates for. The size is passed explicitly
% because zeros('like', w) on its own returns a scalar.
w_momentum = zeros(size(w), 'like', w) ;
b_momentum = zeros(size(b), 'like', b) ;
% SGD with momentum

% Preallocate the objective history, one column per iteration:
%   row 1: data term, row 2: regulariser, row 3: their sum.
% Allocating it here avoids growing the array inside the loop and discards
% any stale E left in the workspace by an earlier run.
E = zeros(3, numIterations) ;

% The number of positive and negative labels does not change during
% training, so count them once
numPos = sum(pos(:)) ;
numNeg = sum(neg(:)) ;

for t = 1:numIterations

  % Forward pass
  res = tinycnn(im, w, b) ;

  % Loss: hinge-type penalties on the network output res.x3, averaged
  % separately over the positive pixels (penalised when below 1) and the
  % negative pixels (penalised when above 0), plus the L2 regulariser on w
  E(1,t) = ...
    mean(max(0, 1 - res.x3(pos))) + ...
    mean(max(0, res.x3(neg))) ;
  E(2,t) = 0.5 * shrinkRate * sum(w(:).^2) ;
  E(3,t) = E(1,t) + E(2,t) ;

  % Derivative of the data term with respect to res.x3: non-zero only at
  % labelled pixels that currently incur a penalty
  dzdx3 = ...
    - single(res.x3 < 1 & pos) / numPos + ...
    + single(res.x3 > 0 & neg) / numNeg ;

  % Backward pass
  res = tinycnn(im, w, b, dzdx3) ;

  % Update momentum (the regulariser gradient shrinkRate * w is added for
  % the filter; the bias uses a step size scaled by 0.1)
  w_momentum = momentum * w_momentum + rate * (res.dzdw + shrinkRate * w) ;
  b_momentum = momentum * b_momentum + rate * 0.1 * res.dzdb ;

  % Gradient step
  w = w - w_momentum ;
  b = b - b_momentum ;

  % Plots: every plotPeriod iterations and at the last iteration
  if mod(t-1, plotPeriod) == 0 || t == numIterations
    % Classify every labelled pixel against the two margins (0 and 1)
    fp = res.x3 > 0 & y < 0 ;
    fn = res.x3 < 1 & y > 0 ;
    tn = res.x3 <= 0 & y < 0 ;
    tp = res.x3 >= 1 & y > 0 ;
    % Three-channel map: errors, correct predictions, unlabelled pixels
    err = cat(3, fp|fn , tp|tn, y==0) ;

    figure(2) ; clf ;
    colormap gray ;

    % Objective terms up to the current iteration
    subplot(2,3,1) ;
    plot(1:t, E(:,1:t)') ;
    grid on ; title('objective') ;
    ylim([0 1.5]) ; legend('error', 'regularizer', 'total') ;

    % Normalised score histograms for positive and negative pixels, with
    % the two margins drawn as dashed lines
    subplot(2,3,2) ; hold on ;
    [h,x]=hist(res.x3(pos(:)),30) ; plot(x,h/max(h),'g') ;
    [h,x]=hist(res.x3(neg(:)),30) ; plot(x,h/max(h),'r') ;
    plot([0 0], [0 1], 'b--') ;
    plot([1 1], [0 1], 'b--') ;
    xlim([-2 3]) ;
    title('histograms of scores') ; legend('pos', 'neg') ;

    subplot(2,3,3) ;
    vl_imarraysc(w) ;
    title('learned filter') ; axis equal ;

    subplot(2,3,4) ;
    imagesc(res.x3) ;
    title('network output') ; axis equal ;

    subplot(2,3,5) ;
    imagesc(res.x2) ;
    title('first layer output') ; axis equal ;

    subplot(2,3,6) ;
    image(err) ;
    title('red: pred. error, green: correct, blue: ignore') ;

    % Refresh the figure; the drawnow variant depends on the MATLAB version
    if verLessThan('matlab', '8.4.0')
      drawnow ;
    else
      drawnow expose ;
    end
  end
end