-
Notifications
You must be signed in to change notification settings - Fork 324
/
train.py
293 lines (241 loc) · 9.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import tensorflow as tf
import numpy as np
import glob
import sys
from matplotlib import pyplot as plt
from batchnorm import ConvolutionalBatchNormalizer
# Training images: every JPEG one directory level below ../colornet/.
filenames = sorted(glob.glob("../colornet/*/*.jpg"))
if not filenames:
    # Fail early with a clear message instead of a generic error raised later
    # when the (empty) list reaches tf.train.string_input_producer.
    raise ValueError("No training images found matching ../colornet/*/*.jpg")
batch_size = 1
# Effectively "run forever". string_input_producer documents num_epochs as an
# integer, so use an int rather than the float literal 1e+9.
num_epochs = 1000000000
# Incremented by the optimizer on every training step.
global_step = tf.Variable(0, name='global_step', trainable=False)
# Fed True for training steps and False for evaluation runs.
phase_train = tf.placeholder(tf.bool, name='phase_train')
# Intended loss selector fed at run time: 1 = U channel, 2 = V channel,
# anything else = both (see where `loss` is built).
uv = tf.placeholder(tf.uint8, name='uv')
def read_my_file_format(filename_queue, randomize=False):
    """
    Read one JPEG from `filename_queue` as a random 224x224 float RGB crop.

    Args:
      filename_queue: queue of file names (here produced by
        tf.train.string_input_producer).
      randomize: if True, additionally apply random left-right and up-down
        flips to the crop.

    Returns:
      float32 tensor of shape (224, 224, 3) with values in [0, 1].
    """
    reader = tf.WholeFileReader()
    # read() yields (key, value); only the raw file contents are needed.
    # Named `contents` rather than `file` to avoid shadowing the Python 2 builtin.
    _, contents = reader.read(filename_queue)
    uint8image = tf.image.decode_jpeg(contents, channels=3)
    # NOTE(review): tf.random_crop asserts the image is at least 224x224;
    # smaller source images will fail at run time.
    uint8image = tf.random_crop(uint8image, (224, 224, 3))
    if randomize:
        uint8image = tf.image.random_flip_left_right(uint8image)
        uint8image = tf.image.random_flip_up_down(uint8image, seed=None)
    # Scale 0..255 bytes to 0..1 floats.
    return tf.div(tf.cast(uint8image, tf.float32), 255)
def input_pipeline(filenames, batch_size, num_epochs=None):
    """
    Build a shuffled batch of training crops from `filenames`.

    File names are enqueued in order (shuffle=False). The randomness comes
    from tf.train.shuffle_batch, which keeps at least `buffer_floor` decoded
    examples buffered and samples from them.

    Args:
      filenames: list of image paths.
      batch_size: number of images per batch.
      num_epochs: passed to string_input_producer; None means unlimited.

    Returns:
      float32 tensor of shape (batch_size, 224, 224, 3).
    """
    buffer_floor = 100
    queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, shuffle=False)
    image = read_my_file_format(queue, randomize=False)
    return tf.train.shuffle_batch(
        [image],
        batch_size=batch_size,
        capacity=buffer_floor + 3 * batch_size,
        min_after_dequeue=buffer_floor)
def batch_norm(x, depth, phase_train):
    """
    Batch-normalise `x` with a fresh ConvolutionalBatchNormalizer.

    Each call creates its own normalizer (epsilon 0.001) backed by an
    ExponentialMovingAverage with decay 0.9999, under a 'batchnorm'
    variable scope.

    Args:
      x: tensor to normalise.
      depth: channel count of `x`, passed to the normalizer.
      phase_train: forwarded as the normalizer's `train` argument.

    Returns:
      The normalised tensor returned by ConvolutionalBatchNormalizer.normalize.
    """
    with tf.variable_scope('batchnorm'):
        averager = tf.train.ExponentialMovingAverage(decay=0.9999)
        normalizer = ConvolutionalBatchNormalizer(depth, 0.001, averager, True)
        # NOTE(review): the op returned by get_assigner() is discarded and
        # never run or used as a control dependency. If it is the moving-average
        # update (as the name suggests), the averages may never be refreshed.
        # Verify against batchnorm.py.
        normalizer.get_assigner()
        return normalizer.normalize(x, train=phase_train)
def conv2d(_X, w, sigmoid=False, bn=False):
    """
    Stride-1 'SAME' convolution, optional batch norm, then an activation.

    Args:
      _X: input tensor.
      w: convolution kernel [height, width, in_channels, out_channels].
        Its last dimension is also used as the batch-norm depth.
      sigmoid: if True, finish with a sigmoid. Otherwise finish with a leaky
        ReLU of slope 0.01.
      bn: if True, batch-normalise the convolution output. This uses the
        module-level `phase_train` placeholder.

    Returns:
      The activated tensor.
    """
    with tf.variable_scope('conv2d'):
        _X = tf.nn.conv2d(_X, w, [1, 1, 1, 1], 'SAME')
        if bn:
            _X = batch_norm(_X, w.get_shape()[3], phase_train)
        if sigmoid:
            return tf.sigmoid(_X)
        # Leaky ReLU. A plain tf.nn.relu must NOT be applied first: it zeroes
        # every negative value, making max(0.01 * x, x) == x and the leak a no-op.
        return tf.maximum(0.01 * _X, _X)
def colornet(_tensors):
    """
    Network architecture http://tinyclouds.org/colorize/residual_encoder.png

    Residual decoder on top of VGG-16 activations. The decoder:
      1. Starts from batch-normalised conv4_3.
      2. Repeatedly upsamples, adds a batch-normalised VGG skip tensor and
         applies a 3x3 conv.
      3. Adds the grayscale input itself.
      4. Ends in a sigmoid layer with two output channels.

    Args:
      _tensors: dict with the VGG activations "conv1_2", "conv2_2",
        "conv3_3" and "conv4_3", the 3-channel "grayscale" input, and
        "weights", a dict of kernels "wc1".."wc6".

    Returns:
      Sigmoid output, Bx224x224x2 per the shape comments below.
    """
    w = _tensors["weights"]
    with tf.variable_scope('colornet'):
        # Bx28x28x512 -> batch norm -> 1x1 conv -> ReLU = Bx28x28x256
        net = tf.nn.relu(tf.nn.conv2d(
            batch_norm(_tensors["conv4_3"], 512, phase_train),
            w["wc1"], [1, 1, 1, 1], 'SAME'))

        # Each stage: upscale, add a batch-normalised VGG skip, 3x3 conv.
        # The calls run in exactly the original order, so auto-generated op
        # and variable names come out the same.
        decoder_stages = (
            # (upscale to, skip tensor, skip depth, kernel)
            ((56, 56), "conv3_3", 256, 'wc2'),    # -> Bx56x56x128
            ((112, 112), "conv2_2", 128, 'wc3'),  # -> Bx112x112x64
            ((224, 224), "conv1_2", 64, 'wc4'),   # -> Bx224x224x3
        )
        for size, skip_name, skip_depth, kernel_name in decoder_stages:
            net = tf.image.resize_bilinear(net, size)
            net = tf.add(net, batch_norm(_tensors[skip_name], skip_depth,
                                         phase_train))
            net = conv2d(net, w[kernel_name], sigmoid=False, bn=True)

        # Add the batch-normalised grayscale input, then two final convs.
        net = tf.add(net, batch_norm(_tensors["grayscale"], 3, phase_train))
        # Bx224x224x3 -> 3x3 conv = Bx224x224x3
        net = conv2d(net, w['wc5'], sigmoid=False, bn=True)
        # Bx224x224x3 -> 3x3 conv + sigmoid = Bx224x224x2
        return conv2d(net, w['wc6'], sigmoid=True, bn=True)
def concat_images(imga, imgb):
    """
    Place two color images side by side on a black float32 canvas.

    The canvas is as tall as the taller image and as wide as both combined.
    `imga` goes at the top-left and `imgb` directly to its right. Any area
    not covered by an image stays 0.

    Args:
      imga, imgb: arrays of shape (H, W, 3), or anything that broadcasts
        into an (H, W, 3) slot.

    Returns:
      float32 ndarray of shape (max(Ha, Hb), Wa + Wb, 3).
    """
    height_a, width_a = imga.shape[:2]
    height_b, width_b = imgb.shape[:2]
    canvas = np.zeros(
        shape=(max(height_a, height_b), width_a + width_b, 3),
        dtype=np.float32)
    canvas[:height_a, :width_a] = imga
    canvas[:height_b, width_a:] = imgb
    return canvas
def rgb2yuv(rgb):
    """
    Convert an RGB image batch into YUV https://en.wikipedia.org/wiki/YUV

    Luma uses the 0.299 / 0.587 / 0.114 weights. The chroma rows appear to be
    the JPEG-style YCbCr coefficients rounded to about three decimals. Both
    chroma channels get +0.5, so inputs in [0, 1] map to outputs that stay
    roughly within [0, 1].

    Args:
      rgb: 4-D float tensor with 3 channels in the last dimension.

    Returns:
      Tensor of the same shape holding (Y, U, V).
    """
    # A 1x1 convolution is a per-pixel 3x3 matrix multiply;
    # kernel[0][0][input_channel][output_channel].
    kernel = tf.constant(
        [[[[0.299, -0.169, 0.499],
           [0.587, -0.331, -0.418],
           [0.114, 0.499, -0.0813]]]])
    offset = tf.constant([0., 0.5, 0.5])
    return tf.nn.bias_add(
        tf.nn.conv2d(rgb, kernel, [1, 1, 1, 1], 'SAME'), offset)
def yuv2rgb(yuv):
    """
    Convert a YUV image batch back into RGB https://en.wikipedia.org/wiki/YUV

    This is the inverse of rgb2yuv. The input is scaled to 0..255 first.
    The bias terms equal -1.402*128, (0.344 + 0.714)*128 and -1.772*128, so
    they undo the +0.5 chroma offset that rgb2yuv adds.

    Args:
      yuv: 4-D float tensor with (Y, U, V) in the last dimension, values in
        the range produced by rgb2yuv.

    Returns:
      RGB tensor of the same shape, clipped to [0, 1].
    """
    yuv = tf.mul(yuv, 255)
    yuv2rgb_filter = tf.constant(
        [[[[1., 1., 1.],
           [0., -0.34413999, 1.77199996],
           [1.40199995, -0.71414, 0.]]]])
    yuv2rgb_bias = tf.constant([-179.45599365, 135.45983887, -226.81599426])
    temp = tf.nn.conv2d(yuv, yuv2rgb_filter, [1, 1, 1, 1], 'SAME')
    temp = tf.nn.bias_add(temp, yuv2rgb_bias)
    # Clamp to the valid 0..255 range. The previous version used
    # tf.zeros/tf.ones(temp.get_shape()), which requires a fully static shape
    # (e.g. fails for an unknown batch size). clip_by_value has no such limit
    # and also avoids building full-size constant tensors.
    temp = tf.clip_by_value(temp, 0.0, 255.0)
    temp = tf.div(temp, 255)
    return temp
# Load the pre-trained VGG-16 GraphDef (a serialized protobuf) from disk.
graph_def = tf.GraphDef()
with open("vgg/tensorflow-vgg16/vgg16-20160129.tfmodel", mode='rb') as f:
    fileContent = f.read()
    graph_def.ParseFromString(fileContent)
with tf.variable_scope('colornet'):
    # Store layers weight. Kernel layout is
    # [height, width, in_channels, out_channels], as tf.nn.conv2d expects.
    # Keep the creation order: auto-generated variable names depend on it.
    weights = {
        # 1x1 conv, 512 inputs, 256 outputs
        'wc1': tf.Variable(tf.truncated_normal([1, 1, 512, 256], stddev=0.01)),
        # 3x3 conv, 256 inputs, 128 outputs
        'wc2': tf.Variable(tf.truncated_normal([3, 3, 256, 128], stddev=0.01)),
        # 3x3 conv, 128 inputs, 64 outputs
        'wc3': tf.Variable(tf.truncated_normal([3, 3, 128, 64], stddev=0.01)),
        # 3x3 conv, 64 inputs, 3 outputs
        'wc4': tf.Variable(tf.truncated_normal([3, 3, 64, 3], stddev=0.01)),
        # 3x3 conv, 3 inputs, 3 outputs
        'wc5': tf.Variable(tf.truncated_normal([3, 3, 3, 3], stddev=0.01)),
        # 3x3 conv, 3 inputs, 2 outputs
        'wc6': tf.Variable(tf.truncated_normal([3, 3, 3, 2], stddev=0.01)),
    }
# Input batch and its YUV encoding; the U/V channels are the training targets.
colorimage = input_pipeline(filenames, batch_size, num_epochs=num_epochs)
colorimage_yuv = rgb2yuv(colorimage)

# Luminance-only view of the same batch: one channel, then replicated to RGB.
grayscale = tf.image.rgb_to_grayscale(colorimage)
grayscale_rgb = tf.image.grayscale_to_rgb(grayscale)
grayscale_yuv = rgb2yuv(grayscale_rgb)

# Feed the imported VGG graph's "images" input with the grayscale channel
# repeated three times.
grayscale = tf.concat(3, [grayscale] * 3)
tf.import_graph_def(graph_def, input_map={"images": grayscale})

# Fetch the VGG activations that colornet() uses as skip connections.
graph = tf.get_default_graph()
with tf.variable_scope('vgg'):
    conv1_2, conv2_2, conv3_3, conv4_3 = (
        graph.get_tensor_by_name(tensor_name)
        for tensor_name in ("import/conv1_2/Relu:0",
                            "import/conv2_2/Relu:0",
                            "import/conv3_3/Relu:0",
                            "import/conv4_3/Relu:0"))

tensors = {
    "conv1_2": conv1_2,
    "conv2_2": conv2_2,
    "conv3_3": conv3_3,
    "conv4_3": conv4_3,
    "grayscale": grayscale,
    "weights": weights
}
# Construct model
pred = colornet(tensors)
# Recombine the grayscale luma (Y) with the predicted chroma for display.
pred_yuv = tf.concat(3, [tf.split(3, 3, grayscale_yuv)[0], pred])
pred_rgb = yuv2rgb(pred_yuv)

# Per-pixel squared error between the predicted chroma and the true U/V
# channels of the color image, then split into its U and V halves.
target_yuv = tf.split(3, 3, colorimage_yuv)
loss = tf.square(tf.sub(pred, tf.concat(3, [target_yuv[1], target_yuv[2]])))
u_loss, v_loss = tf.split(3, 2, loss)

# Select the loss from the value fed to `uv` at run time:
#   1 -> U only; 2 -> V only; anything else -> mean of both.
# This has to happen in the graph. Comparing the placeholder object with a
# Python int (`if uv == 1:`) is decided once while the graph is built and
# never sees the fed value, so the per-channel branches were unreachable.
uv_id = tf.cast(uv, tf.int32)
loss = tf.cond(
    tf.equal(uv_id, 1),
    lambda: u_loss,
    lambda: tf.cond(tf.equal(uv_id, 2),
                    lambda: v_loss,
                    lambda: (u_loss + v_loss) / 2))

# The optimizer is built unconditionally. `if phase_train:` tested the
# Python truthiness of a placeholder object, not its fed value.
optimizer = tf.train.GradientDescentOptimizer(0.0001)
opt = optimizer.minimize(
    loss, global_step=global_step, gate_gradients=optimizer.GATE_NONE)
# Summaries: histograms of every kernel and of the mean loss, plus one
# example image each of the input, the prediction and the grayscale view.
for _tag, _key in (("weights1", "wc1"), ("weights2", "wc2"),
                   ("weights3", "wc3"), ("weights4", "wc4"),
                   ("weights5", "wc5"), ("weights6", "wc6")):
    tf.histogram_summary(_tag, weights[_key])
tf.histogram_summary("instant_loss", tf.reduce_mean(loss))
for _tag, _image in (("colorimage", colorimage),
                     ("pred_rgb", pred_rgb),
                     ("grayscale", grayscale_rgb)):
    tf.image_summary(_tag, _image, max_images=1)
# Checkpoint saver and the op that initializes every variable defined above.
saver = tf.train.Saver()
init_op = tf.initialize_all_variables()
# Single op that evaluates all registered summaries.
merged = tf.merge_all_summaries()

# Session, variable initialization and the TensorBoard writer (./tb_log).
sess = tf.Session()
sess.run(init_op)
writer = tf.train.SummaryWriter("tb_log", sess.graph_def)

# The input pipeline is queue-based: start its runner threads under a
# coordinator so they can be stopped cleanly.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    while not coord.should_stop():
        # Two optimization steps per iteration: first on the U-channel loss
        # (uv=1), then on the V-channel loss (uv=2).
        sess.run(opt, feed_dict={phase_train: True, uv: 1})
        sess.run(opt, feed_dict={phase_train: True, uv: 2})
        step = sess.run(global_step)

        # Evaluation run with phase_train=False and the combined loss (uv=3).
        # It pulls a freshly dequeued batch from the input pipeline.
        pred_, pred_rgb_, colorimage_, grayscale_rgb_, cost, merged_ = sess.run(
            [pred, pred_rgb, colorimage, grayscale_rgb, loss, merged],
            feed_dict={phase_train: False, uv: 3})
        # Call form of print: same output on Python 2, and valid syntax on
        # Python 3 (the `print {...}` statement form is a SyntaxError there).
        print({
            "step": step,
            "cost": np.mean(cost)
        })
        if step % 1000 == 0:
            # Save grayscale | prediction | ground truth side by side.
            summary_image = concat_images(grayscale_rgb_[0], pred_rgb_[0])
            summary_image = concat_images(summary_image, colorimage_[0])
            plt.imsave("summary/" + str(step) + "_0", summary_image)
        sys.stdout.flush()
        writer.add_summary(merged_, step)
        writer.flush()

        # Starting from 0, global_step grows by 2 per iteration (two opt runs),
        # so it stays even; 99998 is even and therefore reachable.
        if step % 100000 == 99998:
            save_path = saver.save(sess, "model.ckpt")
            print("Model saved in file: %s" % save_path)
            sys.stdout.flush()
except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()