4
4
import os
5
5
6
6
import numpy as np
7
- from PIL import Image
7
+ import keras . backend as K
8
8
from keras .layers import Input , Lambda
9
9
from keras .models import load_model , Model
10
10
from keras .callbacks import TensorBoard , ModelCheckpoint , EarlyStopping
11
11
12
- from yolo3 .model import preprocess_true_boxes , yolo_body , yolo_loss
13
- from yolo3 .utils import letterbox_image
12
+ from yolo3 .model import preprocess_true_boxes , yolo_body , tiny_yolo_body , yolo_loss
13
+ from yolo3 .utils import get_random_data
14
14
15
- # Default anchor boxes
16
- YOLO_ANCHORS = np .array (((10 ,13 ), (16 ,30 ), (33 ,23 ), (30 ,61 ),
17
- (62 ,45 ), (59 ,119 ), (116 ,90 ), (156 ,198 ), (373 ,326 )))
18
15
19
16
def _main ():
20
17
annotation_path = 'train.txt'
21
- data_path = 'train.npz'
22
- output_path = 'model_data/my_yolo.h5'
23
18
log_dir = 'logs/000/'
24
19
classes_path = 'model_data/voc_classes.txt'
25
20
anchors_path = 'model_data/yolo_anchors.txt'
26
21
class_names = get_classes (classes_path )
27
22
anchors = get_anchors (anchors_path )
28
23
29
- input_shape = (416 ,416 ) # multiple of 32
30
- image_data , box_data = get_training_data (annotation_path , data_path ,
31
- input_shape , max_boxes = 100 , load_previous = True )
32
- y_true = preprocess_true_boxes (box_data , input_shape , anchors , len (class_names ))
24
+ input_shape = (416 ,416 ) # multiple of 32, hw
33
25
34
- infer_model , model = create_model (input_shape , anchors , len (class_names ),
26
+ is_tiny_version = len (anchors )== 6 # default setting
27
+ create_func = create_tiny_model if is_tiny_version else create_model
28
+ model = create_func (input_shape , anchors , len (class_names ),
35
29
load_pretrained = True , freeze_body = True )
36
30
37
- train (model , image_data / 255. , y_true , log_dir = log_dir )
31
+ train (model , annotation_path , input_shape , anchors , len ( class_names ) , log_dir = log_dir )
38
32
39
- infer_model .save (output_path )
33
+ def train (model , annotation_path , input_shape , anchors , num_classes , log_dir = 'logs/' ):
34
+ '''retrain/fine-tune the model'''
35
+ model .compile (optimizer = 'adam' , loss = {
36
+ # use custom yolo_loss Lambda layer.
37
+ 'yolo_loss' : lambda y_true , y_pred : y_pred })
40
38
39
+ logging = TensorBoard (log_dir = log_dir )
40
+ checkpoint = ModelCheckpoint (log_dir + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5" ,
41
+ monitor = 'val_loss' , save_weights_only = True , save_best_only = True )
42
+ early_stopping = EarlyStopping (monitor = 'val_loss' , min_delta = 0 , patience = 5 , verbose = 1 , mode = 'auto' )
43
+
44
+ batch_size = 32
45
+ val_split = 0.1
46
+ with open (annotation_path ) as f :
47
+ lines = f .readlines ()
48
+ np .random .shuffle (lines )
49
+ num_val = int (len (lines )* val_split )
50
+ num_train = len (lines ) - num_val
51
+
52
+ model .fit_generator (data_generator_wrap (lines [:num_train ], batch_size , input_shape , anchors , num_classes ),
53
+ steps_per_epoch = max (1 , num_train // batch_size ),
54
+ validation_data = data_generator_wrap (lines [num_train :], batch_size , input_shape , anchors , num_classes ),
55
+ validation_steps = max (1 , num_val // batch_size ),
56
+ epochs = 30 ,
57
+ initial_epoch = 0 ,
58
+ callbacks = [logging , checkpoint , early_stopping ])
59
+ model .save_weights (log_dir + 'trained_weights.h5' )
60
+ # Further training.
41
61
42
62
43
63
def get_classes (classes_path ):
@@ -49,63 +69,22 @@ def get_classes(classes_path):
49
69
50
70
def get_anchors (anchors_path ):
51
71
'''loads the anchors from a file'''
52
- if os .path .isfile (anchors_path ):
53
- with open (anchors_path ) as f :
54
- anchors = f .readline ()
55
- anchors = [float (x ) for x in anchors .split (',' )]
56
- return np .array (anchors ).reshape (- 1 , 2 )
57
- else :
58
- Warning ("Could not open anchors file, using default." )
59
- return YOLO_ANCHORS
60
-
61
- def get_training_data (annotation_path , data_path , input_shape , max_boxes = 100 , load_previous = True ):
62
- '''processes the data into standard shape
63
- annotation row format: image_file_path box1 box2 ... boxN
64
- box format: x_min,y_min,x_max,y_max,class_index (no space)
65
- '''
66
- if load_previous == True and os .path .isfile (data_path ):
67
- data = np .load (data_path )
68
- print ('Loading training data from ' + data_path )
69
- return data ['image_data' ], data ['box_data' ]
70
- image_data = []
71
- box_data = []
72
- with open (annotation_path ) as f :
73
- for line in f .readlines ():
74
- line = line .split (' ' )
75
- filename = line [0 ]
76
- image = Image .open (filename )
77
- boxed_image = letterbox_image (image , tuple (reversed (input_shape )))
78
- image_data .append (np .array (boxed_image ,dtype = 'uint8' ))
79
-
80
- boxes = np .zeros ((max_boxes ,5 ), dtype = 'int32' )
81
- for i , box in enumerate (line [1 :]):
82
- if i < max_boxes :
83
- boxes [i ] = np .array (list (map (int ,box .split (',' ))))
84
- else :
85
- break
86
- image_size = np .array (image .size )
87
- input_size = np .array (input_shape [::- 1 ])
88
- new_size = (image_size * np .min (input_size / image_size )).astype ('int32' )
89
- boxes [:i + 1 , 0 :2 ] = (boxes [:i + 1 , 0 :2 ]* new_size / image_size + (input_size - new_size )/ 2 ).astype ('int32' )
90
- boxes [:i + 1 , 2 :4 ] = (boxes [:i + 1 , 2 :4 ]* new_size / image_size + (input_size - new_size )/ 2 ).astype ('int32' )
91
- box_data .append (boxes )
92
- image_data = np .array (image_data )
93
- box_data = np .array (box_data )
94
- np .savez (data_path , image_data = image_data , box_data = box_data )
95
- print ('Saving training data into ' + data_path )
96
- return image_data , box_data
72
+ with open (anchors_path ) as f :
73
+ anchors = f .readline ()
74
+ anchors = [float (x ) for x in anchors .split (',' )]
75
+ return np .array (anchors ).reshape (- 1 , 2 )
97
76
98
77
99
78
def create_model (input_shape , anchors , num_classes , load_pretrained = True , freeze_body = True ):
100
79
'''create the training model'''
101
80
image_input = Input (shape = (None , None , 3 ))
102
81
h , w = input_shape
103
- num_anchors = len (anchors )// 3
104
- y_true = [Input (shape = (h // 32 , w // 32 , num_anchors , num_classes + 5 )),
105
- Input (shape = (h // 16 , w // 16 , num_anchors , num_classes + 5 )),
106
- Input (shape = (h // 8 , w // 8 , num_anchors , num_classes + 5 ))]
82
+ num_anchors = len (anchors )
107
83
108
- model_body = yolo_body (image_input , num_anchors , num_classes )
84
+ y_true = [Input (shape = (h // {0 :32 , 1 :16 , 2 :8 }[l ], w // {0 :32 , 1 :16 , 2 :8 }[l ], \
85
+ num_anchors // 3 , num_classes + 5 )) for l in range (3 )]
86
+
87
+ model_body = yolo_body (image_input , num_anchors // 3 , num_classes )
109
88
110
89
if load_pretrained :
111
90
weights_path = os .path .join ('model_data' , 'yolo_weights.h5' )
@@ -121,33 +100,66 @@ def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze
121
100
model_body .layers [i ].trainable = False
122
101
123
102
model_loss = Lambda (yolo_loss , output_shape = (1 ,), name = 'yolo_loss' ,
124
- arguments = {'anchors' : anchors , 'num_classes' : num_classes })(
103
+ arguments = {'anchors' : anchors , 'num_classes' : num_classes , 'ignore_thresh' : 0.5 })(
125
104
[* model_body .output , * y_true ])
126
105
model = Model ([model_body .input , * y_true ], model_loss )
127
106
128
- return model_body , model
107
+ return model
129
108
130
- def train ( model , image_data , y_true , log_dir = 'logs/' ):
131
- '''retrain/fine-tune the model'''
132
- model . compile ( optimizer = 'adam' , loss = {
133
- # use custom yolo_loss Lambda layer.
134
- 'yolo_loss' : lambda y_true , y_pred : y_pred } )
109
+ def create_tiny_model ( input_shape , anchors , num_classes , load_pretrained = True , freeze_body = True ):
110
+ '''create the training model, for Tiny YOLOv3 '''
111
+ image_input = Input ( shape = ( None , None , 3 ))
112
+ h , w = input_shape
113
+ num_anchors = len ( anchors )
135
114
136
- logging = TensorBoard (log_dir = log_dir )
137
- checkpoint = ModelCheckpoint (log_dir + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5" ,
138
- monitor = 'val_loss' , save_weights_only = True , save_best_only = True )
139
- early_stopping = EarlyStopping (monitor = 'val_loss' , min_delta = 0 , patience = 5 , verbose = 1 , mode = 'auto' )
115
+ y_true = [Input (shape = (h // {0 :32 , 1 :16 }[l ], w // {0 :32 , 1 :16 }[l ], \
116
+ num_anchors // 2 , num_classes + 5 )) for l in range (2 )]
140
117
141
- model .fit ([image_data , * y_true ],
142
- np .zeros (len (image_data )),
143
- validation_split = .1 ,
144
- batch_size = 32 ,
145
- epochs = 30 ,
146
- callbacks = [logging , checkpoint , early_stopping ])
147
- model .save_weights (log_dir + 'trained_weights.h5' )
148
- # Further training.
118
+ model_body = tiny_yolo_body (image_input , num_anchors // 2 , num_classes )
149
119
120
+ if load_pretrained :
121
+ weights_path = os .path .join ('model_data/' , 'tiny_yolo_weights.h5' )
122
+ if not os .path .exists (weights_path ):
123
+ print ("CREATING WEIGHTS FILE" + weights_path )
124
+ yolo_path = os .path .join ('model_data' , 'tiny_yolo.h5' )
125
+ orig_model = load_model (yolo_path , compile = False )
126
+ orig_model .save_weights (weights_path )
127
+ model_body .load_weights (weights_path , by_name = True , skip_mismatch = True )
128
+ if freeze_body :
129
+ # Do not freeze 2 output layers.
130
+ for i in range (len (model_body .layers )- 2 ):
131
+ model_body .layers [i ].trainable = False
132
+
133
+ model_loss = Lambda (yolo_loss , output_shape = (1 ,), name = 'yolo_loss' ,
134
+ arguments = {'anchors' : anchors , 'num_classes' : num_classes , 'ignore_thresh' : 0.7 })(
135
+ [* model_body .output , * y_true ])
136
+ model = Model ([model_body .input , * y_true ], model_loss )
150
137
138
+ return model
139
+
140
+ def data_generator (annotation_lines , batch_size , input_shape , anchors , num_classes ):
141
+ '''data generator for fit_generator'''
142
+ n = len (annotation_lines )
143
+ np .random .shuffle (annotation_lines )
144
+ i = 0
145
+ while True :
146
+ image_data = []
147
+ box_data = []
148
+ for b in range (batch_size ):
149
+ i %= n
150
+ image , box = get_random_data (annotation_lines [i ], input_shape )
151
+ image_data .append (image )
152
+ box_data .append (box )
153
+ i += 1
154
+ image_data = np .array (image_data )
155
+ box_data = np .array (box_data )
156
+ y_true = preprocess_true_boxes (box_data , input_shape , anchors , num_classes )
157
+ yield [image_data , * y_true ], np .zeros (batch_size )
158
+
159
+ def data_generator_wrap (annotation_lines , batch_size , input_shape , anchors , num_classes ):
160
+ n = len (annotation_lines )
161
+ if n == 0 or batch_size <= 0 : return None
162
+ return data_generator (annotation_lines , batch_size , input_shape , anchors , num_classes )
151
163
152
164
if __name__ == '__main__' :
153
165
_main ()
0 commit comments