Skip to content

Commit 0f3567c

Browse files
committed
lots of changes are added
- add tiny yolo3 support, close qqwweee#62 - weights.h5 can be used in yolo.py, close qqwweee#54 - use binary cross entropy loss, close qqwweee#58 - use fit_generator and real-time data augmentation, close qqwweee#67
1 parent 3f93a89 commit 0f3567c

File tree

8 files changed

+490
-148
lines changed

8 files changed

+490
-148
lines changed

README.md

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5
2020
python yolo.py OR python yolo_video.py
2121
```
2222

23+
For Tiny YOLOv3, just do in the similar way. And modify model path and anchor path in `yolo.py`.
24+
2325
---
2426

2527
## Training
2628

2729
1. Generate your own annotation file and class names file.
2830
One row for one image;
29-
Row format: image_file_path box1 box2 ... boxN;
30-
Box format: x_min,y_min,x_max,y_max,class_id (no space).
31+
Row format: `image_file_path box1 box2 ... boxN`;
32+
Box format: `x_min,y_min,x_max,y_max,class_id` (no space).
3133
For VOC dataset, try `python voc_annotation.py`
3234

3335
2. Make sure you have run `python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5`
@@ -36,4 +38,18 @@ python yolo.py OR python yolo_video.py
3638

3739
3. Modify train.py and start training.
3840
`python train.py`
39-
You will get the trained model model_data/my_yolo.h5.
41+
Use your trained weights or checkpoint weights in `yolo.py`.
42+
Remember to modify class path or anchor path.
43+
44+
---
45+
46+
## Some issues to know
47+
48+
1. The test environment is
49+
- Python 3.5.2
50+
- Keras 2.1.5
51+
- tensorflow 1.6.0
52+
53+
2. Default anchors are used. If you use your own anchors, probably some changes are needed.
54+
55+
3. The training strategy is for reference only. Adjust it according to your dataset and your goal. And add further strategy if needed.

convert.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
from keras import backend as K
1515
from keras.layers import (Conv2D, Input, ZeroPadding2D, Add,
16-
UpSampling2D, Concatenate)
16+
UpSampling2D, MaxPooling2D, Concatenate)
1717
from keras.layers.advanced_activations import LeakyReLU
1818
from keras.layers.normalization import BatchNormalization
1919
from keras.models import Model
@@ -194,13 +194,23 @@ def _main(args):
194194
all_layers.append(skip_layer)
195195
prev_layer = skip_layer
196196

197+
elif section.startswith('maxpool'):
198+
size = int(cfg_parser[section]['size'])
199+
stride = int(cfg_parser[section]['stride'])
200+
all_layers.append(
201+
MaxPooling2D(
202+
pool_size=(size, size),
203+
strides=(stride, stride),
204+
padding='same')(prev_layer))
205+
prev_layer = all_layers[-1]
206+
197207
elif section.startswith('shortcut'):
198208
index = int(cfg_parser[section]['from'])
199209
activation = cfg_parser[section]['activation']
200210
assert activation == 'linear', 'Only linear activation supported.'
201211
all_layers.append(Add()([all_layers[index], prev_layer]))
202212
prev_layer = all_layers[-1]
203-
213+
204214
elif section.startswith('upsample'):
205215
stride = int(cfg_parser[section]['stride'])
206216
assert stride == 2, 'Only stride=2 supported.'

model_data/tiny_yolo_anchors.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
10,14, 23,27, 37,58, 81,82, 135,169, 344,319

train.py

Lines changed: 96 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,60 @@
44
import os
55

66
import numpy as np
7-
from PIL import Image
7+
import keras.backend as K
88
from keras.layers import Input, Lambda
99
from keras.models import load_model, Model
1010
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
1111

12-
from yolo3.model import preprocess_true_boxes, yolo_body, yolo_loss
13-
from yolo3.utils import letterbox_image
12+
from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss
13+
from yolo3.utils import get_random_data
1414

15-
# Default anchor boxes
16-
YOLO_ANCHORS = np.array(((10,13), (16,30), (33,23), (30,61),
17-
(62,45), (59,119), (116,90), (156,198), (373,326)))
1815

1916
def _main():
2017
annotation_path = 'train.txt'
21-
data_path = 'train.npz'
22-
output_path = 'model_data/my_yolo.h5'
2318
log_dir = 'logs/000/'
2419
classes_path = 'model_data/voc_classes.txt'
2520
anchors_path = 'model_data/yolo_anchors.txt'
2621
class_names = get_classes(classes_path)
2722
anchors = get_anchors(anchors_path)
2823

29-
input_shape = (416,416) # multiple of 32
30-
image_data, box_data = get_training_data(annotation_path, data_path,
31-
input_shape, max_boxes=100, load_previous=True)
32-
y_true = preprocess_true_boxes(box_data, input_shape, anchors, len(class_names))
24+
input_shape = (416,416) # multiple of 32, hw
3325

34-
infer_model, model = create_model(input_shape, anchors, len(class_names),
26+
is_tiny_version = len(anchors)==6 # default setting
27+
create_func = create_tiny_model if is_tiny_version else create_model
28+
model = create_func(input_shape, anchors, len(class_names),
3529
load_pretrained=True, freeze_body=True)
3630

37-
train(model, image_data/255., y_true, log_dir=log_dir)
31+
train(model, annotation_path, input_shape, anchors, len(class_names), log_dir=log_dir)
3832

39-
infer_model.save(output_path)
33+
def train(model, annotation_path, input_shape, anchors, num_classes, log_dir='logs/'):
34+
'''retrain/fine-tune the model'''
35+
model.compile(optimizer='adam', loss={
36+
# use custom yolo_loss Lambda layer.
37+
'yolo_loss': lambda y_true, y_pred: y_pred})
4038

39+
logging = TensorBoard(log_dir=log_dir)
40+
checkpoint = ModelCheckpoint(log_dir + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5",
41+
monitor='val_loss', save_weights_only=True, save_best_only=True)
42+
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='auto')
43+
44+
batch_size = 32
45+
val_split = 0.1
46+
with open(annotation_path) as f:
47+
lines = f.readlines()
48+
np.random.shuffle(lines)
49+
num_val = int(len(lines)*val_split)
50+
num_train = len(lines) - num_val
51+
52+
model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
53+
steps_per_epoch=max(1, num_train//batch_size),
54+
validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
55+
validation_steps=max(1, num_val//batch_size),
56+
epochs=30,
57+
initial_epoch=0,
58+
callbacks=[logging, checkpoint, early_stopping])
59+
model.save_weights(log_dir + 'trained_weights.h5')
60+
# Further training.
4161

4262

4363
def get_classes(classes_path):
@@ -49,63 +69,22 @@ def get_classes(classes_path):
4969

5070
def get_anchors(anchors_path):
5171
'''loads the anchors from a file'''
52-
if os.path.isfile(anchors_path):
53-
with open(anchors_path) as f:
54-
anchors = f.readline()
55-
anchors = [float(x) for x in anchors.split(',')]
56-
return np.array(anchors).reshape(-1, 2)
57-
else:
58-
Warning("Could not open anchors file, using default.")
59-
return YOLO_ANCHORS
60-
61-
def get_training_data(annotation_path, data_path, input_shape, max_boxes=100, load_previous=True):
62-
'''processes the data into standard shape
63-
annotation row format: image_file_path box1 box2 ... boxN
64-
box format: x_min,y_min,x_max,y_max,class_index (no space)
65-
'''
66-
if load_previous==True and os.path.isfile(data_path):
67-
data = np.load(data_path)
68-
print('Loading training data from ' + data_path)
69-
return data['image_data'], data['box_data']
70-
image_data = []
71-
box_data = []
72-
with open(annotation_path) as f:
73-
for line in f.readlines():
74-
line = line.split(' ')
75-
filename = line[0]
76-
image = Image.open(filename)
77-
boxed_image = letterbox_image(image, tuple(reversed(input_shape)))
78-
image_data.append(np.array(boxed_image,dtype='uint8'))
79-
80-
boxes = np.zeros((max_boxes,5), dtype='int32')
81-
for i, box in enumerate(line[1:]):
82-
if i < max_boxes:
83-
boxes[i] = np.array(list(map(int,box.split(','))))
84-
else:
85-
break
86-
image_size = np.array(image.size)
87-
input_size = np.array(input_shape[::-1])
88-
new_size = (image_size * np.min(input_size/image_size)).astype('int32')
89-
boxes[:i+1, 0:2] = (boxes[:i+1, 0:2]*new_size/image_size + (input_size-new_size)/2).astype('int32')
90-
boxes[:i+1, 2:4] = (boxes[:i+1, 2:4]*new_size/image_size + (input_size-new_size)/2).astype('int32')
91-
box_data.append(boxes)
92-
image_data = np.array(image_data)
93-
box_data = np.array(box_data)
94-
np.savez(data_path, image_data=image_data, box_data=box_data)
95-
print('Saving training data into ' + data_path)
96-
return image_data, box_data
72+
with open(anchors_path) as f:
73+
anchors = f.readline()
74+
anchors = [float(x) for x in anchors.split(',')]
75+
return np.array(anchors).reshape(-1, 2)
9776

9877

9978
def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=True):
10079
'''create the training model'''
10180
image_input = Input(shape=(None, None, 3))
10281
h, w = input_shape
103-
num_anchors = len(anchors)//3
104-
y_true = [Input(shape=(h//32, w//32, num_anchors, num_classes+5)),
105-
Input(shape=(h//16, w//16, num_anchors, num_classes+5)),
106-
Input(shape=(h//8, w//8, num_anchors, num_classes+5))]
82+
num_anchors = len(anchors)
10783

108-
model_body = yolo_body(image_input, num_anchors, num_classes)
84+
y_true = [Input(shape=(h//{0:32, 1:16, 2:8}[l], w//{0:32, 1:16, 2:8}[l], \
85+
num_anchors//3, num_classes+5)) for l in range(3)]
86+
87+
model_body = yolo_body(image_input, num_anchors//3, num_classes)
10988

11089
if load_pretrained:
11190
weights_path = os.path.join('model_data', 'yolo_weights.h5')
@@ -121,33 +100,66 @@ def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze
121100
model_body.layers[i].trainable = False
122101

123102
model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
124-
arguments={'anchors': anchors, 'num_classes': num_classes})(
103+
arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5})(
125104
[*model_body.output, *y_true])
126105
model = Model([model_body.input, *y_true], model_loss)
127106

128-
return model_body, model
107+
return model
129108

130-
def train(model, image_data, y_true, log_dir='logs/'):
131-
'''retrain/fine-tune the model'''
132-
model.compile(optimizer='adam', loss={
133-
# use custom yolo_loss Lambda layer.
134-
'yolo_loss': lambda y_true, y_pred: y_pred})
109+
def create_tiny_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=True):
110+
'''create the training model, for Tiny YOLOv3'''
111+
image_input = Input(shape=(None, None, 3))
112+
h, w = input_shape
113+
num_anchors = len(anchors)
135114

136-
logging = TensorBoard(log_dir=log_dir)
137-
checkpoint = ModelCheckpoint(log_dir + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5",
138-
monitor='val_loss', save_weights_only=True, save_best_only=True)
139-
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='auto')
115+
y_true = [Input(shape=(h//{0:32, 1:16}[l], w//{0:32, 1:16}[l], \
116+
num_anchors//2, num_classes+5)) for l in range(2)]
140117

141-
model.fit([image_data, *y_true],
142-
np.zeros(len(image_data)),
143-
validation_split=.1,
144-
batch_size=32,
145-
epochs=30,
146-
callbacks=[logging, checkpoint, early_stopping])
147-
model.save_weights(log_dir + 'trained_weights.h5')
148-
# Further training.
118+
model_body = tiny_yolo_body(image_input, num_anchors//2, num_classes)
149119

120+
if load_pretrained:
121+
weights_path = os.path.join('model_data/', 'tiny_yolo_weights.h5')
122+
if not os.path.exists(weights_path):
123+
print("CREATING WEIGHTS FILE" + weights_path)
124+
yolo_path = os.path.join('model_data', 'tiny_yolo.h5')
125+
orig_model = load_model(yolo_path, compile=False)
126+
orig_model.save_weights(weights_path)
127+
model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
128+
if freeze_body:
129+
# Do not freeze 2 output layers.
130+
for i in range(len(model_body.layers)-2):
131+
model_body.layers[i].trainable = False
132+
133+
model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
134+
arguments={'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.7})(
135+
[*model_body.output, *y_true])
136+
model = Model([model_body.input, *y_true], model_loss)
150137

138+
return model
139+
140+
def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes):
141+
'''data generator for fit_generator'''
142+
n = len(annotation_lines)
143+
np.random.shuffle(annotation_lines)
144+
i = 0
145+
while True:
146+
image_data = []
147+
box_data = []
148+
for b in range(batch_size):
149+
i %= n
150+
image, box = get_random_data(annotation_lines[i], input_shape)
151+
image_data.append(image)
152+
box_data.append(box)
153+
i += 1
154+
image_data = np.array(image_data)
155+
box_data = np.array(box_data)
156+
y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)
157+
yield [image_data, *y_true], np.zeros(batch_size)
158+
159+
def data_generator_wrap(annotation_lines, batch_size, input_shape, anchors, num_classes):
160+
n = len(annotation_lines)
161+
if n==0 or batch_size<=0: return None
162+
return data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes)
151163

152164
if __name__ == '__main__':
153165
_main()

yolo.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,28 @@
77
import colorsys
88
import os
99
import random
10-
from timeit import time
11-
from timeit import default_timer as timer ### to calculate FPS
10+
from timeit import default_timer as timer
1211

1312
import numpy as np
1413
from keras import backend as K
1514
from keras.models import load_model
15+
from keras.layers import Input
1616
from PIL import Image, ImageFont, ImageDraw
1717

18-
from yolo3.model import yolo_eval
18+
from yolo3.model import yolo_eval, yolo_body, tiny_yolo_body
1919
from yolo3.utils import letterbox_image
2020

2121
class YOLO(object):
2222
def __init__(self):
23-
self.model_path = 'model_data/yolo.h5'
23+
self.model_path = 'model_data/yolo.h5' # model path or trained weights path
2424
self.anchors_path = 'model_data/yolo_anchors.txt'
2525
self.classes_path = 'model_data/coco_classes.txt'
2626
self.score = 0.3
2727
self.iou = 0.5
2828
self.class_names = self._get_class()
2929
self.anchors = self._get_anchors()
3030
self.sess = K.get_session()
31-
self.model_image_size = (416, 416) # fixed size or (None, None)
32-
self.is_fixed_size = self.model_image_size != (None, None)
31+
self.model_image_size = (416, 416) # fixed size or (None, None), hw
3332
self.boxes, self.scores, self.classes = self.generate()
3433

3534
def _get_class(self):
@@ -43,15 +42,28 @@ def _get_anchors(self):
4342
anchors_path = os.path.expanduser(self.anchors_path)
4443
with open(anchors_path) as f:
4544
anchors = f.readline()
46-
anchors = [float(x) for x in anchors.split(',')]
47-
anchors = np.array(anchors).reshape(-1, 2)
48-
return anchors
45+
anchors = [float(x) for x in anchors.split(',')]
46+
return np.array(anchors).reshape(-1, 2)
4947

5048
def generate(self):
5149
model_path = os.path.expanduser(self.model_path)
52-
assert model_path.endswith('.h5'), 'Keras model must be a .h5 file.'
50+
assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
51+
52+
# Load model, or construct model and load weights.
53+
num_anchors = len(self.anchors)
54+
num_classes = len(self.class_names)
55+
is_tiny_version = num_anchors==6 # default setting
56+
try:
57+
self.yolo_model = load_model(model_path, compile=False)
58+
except:
59+
self.yolo_model = tiny_yolo_body(Input(shape=(None,None,3)), num_anchors//2, num_classes) \
60+
if is_tiny_version else yolo_body(Input(shape=(None,None,3)), num_anchors//3, num_classes)
61+
self.yolo_model.load_weights(self.model_path) # make sure model, anchors and classes match
62+
else:
63+
assert self.yolo_model.layers[-1].output_shape[-1] == \
64+
num_anchors/len(self.yolo_model.output) * (num_classes + 5), \
65+
'Mismatch between model and given anchor and class sizes'
5366

54-
self.yolo_model = load_model(model_path, compile=False)
5567
print('{} model, anchors, and classes loaded.'.format(model_path))
5668

5769
# Generate colors for drawing bounding boxes.
@@ -73,9 +85,9 @@ def generate(self):
7385
return boxes, scores, classes
7486

7587
def detect_image(self, image):
76-
start = time.time()
88+
start = timer()
7789

78-
if self.is_fixed_size:
90+
if self.model_image_size != (None, None):
7991
assert self.model_image_size[0]%32 == 0, 'Multiples of 32 required'
8092
assert self.model_image_size[1]%32 == 0, 'Multiples of 32 required'
8193
boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))
@@ -135,7 +147,7 @@ def detect_image(self, image):
135147
draw.text(text_origin, label, fill=(0, 0, 0), font=font)
136148
del draw
137149

138-
end = time.time()
150+
end = timer()
139151
print(end - start)
140152
return image
141153

0 commit comments

Comments
 (0)