-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
580 additions
and
267 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
# -*- coding: utf-8 -*- | ||
# File: convert_model.py | ||
# Author: Qian Ge <[email protected]> | ||
# Modified from | ||
# modified from: | ||
# https://github.com/qqwweee/keras-yolo3/blob/master/convert.py | ||
# reference: | ||
# https://github.com/pjreddie/darknet/blob/b13f67bfdd87434e141af532cdb5dc1b8369aa3b/src/parser.c#L958 | ||
|
@@ -17,20 +17,18 @@ | |
|
||
|
||
def unique_config_sections(config_path): | ||
"""Convert all config sections to have unique names. | ||
Adds unique suffixes to config sections for compability with configparser. | ||
""" Convert all config sections to have unique names. | ||
Adds unique suffixes to config sections for compability with configparser. | ||
""" | ||
section_counters = defaultdict(int) | ||
output_stream = io.StringIO() | ||
yolo_id = 0 | ||
# prev_dim = 3 | ||
# prev_dim_dict = {} | ||
|
||
with open(config_path) as fin: | ||
for line in fin: | ||
if line.startswith('['): | ||
section = line.strip().strip('[]') | ||
n_section = section | ||
# out_dim = prev_dim | ||
if section == 'yolo': | ||
n_section = section | ||
yolo_id += 1 | ||
|
@@ -50,15 +48,25 @@ def unique_config_sections(config_path): | |
_section = n_section + '_' + str(section_counters[n_section]-1) | ||
else: | ||
_section = n_section + '_' + str(section_counters[n_section]) | ||
# prev_dim_dict[_section] = prev_dim | ||
# prev_dim = out_dim | ||
|
||
print(_section) | ||
line = line.replace(section, _section) | ||
output_stream.write(line) | ||
output_stream.seek(0) | ||
return output_stream | ||
|
||
def parse_conv(weights_file, cfg_parser, section, layer_dict): | ||
""" parse conv layer | ||
Args: | ||
weights_file (file object): file object of .weights file | ||
cfg_parser (ConfigParser object): ConfigParser object of .cfg file for net | ||
section (str): name of conv layer | ||
layer_dict (dictionary): dict storing layer info | ||
Returns: | ||
dict storing layer info and weights values | ||
""" | ||
prev_layer_channel = layer_dict['prev_layer_channel'] | ||
count = layer_dict['count'] | ||
|
||
|
@@ -69,11 +77,6 @@ def parse_conv(weights_file, cfg_parser, section, layer_dict): | |
activation = cfg_parser[section]['activation'] | ||
batch_normalize = 'batch_normalize' in cfg_parser[section] | ||
|
||
# Setting weights. | ||
# Darknet serializes convolutional weights as: | ||
# [bias/beta, [gamma, mean, variance], conv_weights] | ||
# prev_layer_shape = K.int_shape(prev_layer) | ||
|
||
weights_shape = (size, size, prev_layer_channel, filters) | ||
darknet_w_shape = (filters, weights_shape[2], size, size) | ||
weights_size = np.product(weights_shape) | ||
|
@@ -125,6 +128,16 @@ def parse_conv(weights_file, cfg_parser, section, layer_dict): | |
return layer_dict | ||
|
||
def convert(weights_path, config_path, save_path): | ||
""" convert .weight file to .npy file | ||
The converted .npy file will be saved in save_path | ||
Args: | ||
weights_path (str): path of .weight file | ||
config_path (str): path of configuration .cfg file | ||
save_path (str): path for saving .npy file | ||
""" | ||
# load weights file | ||
weights_file = open(weights_path, 'rb') | ||
major, minor, revision = np.ndarray( | ||
shape=(3, ), dtype='int32', buffer=weights_file.read(12)) | ||
|
@@ -134,7 +147,7 @@ def convert(weights_path, config_path, save_path): | |
else: | ||
seen = np.ndarray(shape=(1,), dtype='int32', buffer=weights_file.read(4)) | ||
print('Weights Header: ', major, minor, revision, seen) | ||
|
||
# parse net configuration | ||
net_config = unique_config_sections(config_path) | ||
cfg_parser = configparser.ConfigParser() | ||
cfg_parser.read_file(net_config) | ||
|
@@ -144,29 +157,23 @@ def convert(weights_path, config_path, save_path): | |
dim_list = [] | ||
layer_dict['prev_layer_channel'] = 3 | ||
layer_dict['count'] = 0 | ||
# layer_id = 0 | ||
for section in cfg_parser.sections(): | ||
print('Parsing section {}'.format(section)) | ||
if section.startswith('conv'): | ||
save_weight_dict[section] = {} | ||
|
||
layer_dict = parse_conv(weights_file, cfg_parser, section, layer_dict) | ||
save_weight_dict[section]['weights'] = layer_dict['conv_weights'] | ||
|
||
if len(layer_dict['bn_weight_list']) > 0: | ||
save_weight_dict[section]['bn'] = layer_dict['bn_weight_list'] | ||
if len(layer_dict['conv_bias']) > 0: | ||
save_weight_dict[section]['biases'] = layer_dict['conv_bias'] | ||
elif section.startswith('route'): | ||
route_layers = list(map(int, (cfg_parser[section]['layers']).split(','))) | ||
layer_dict['prev_layer_channel'] = sum([dim_list[layer_] for layer_ in route_layers]) | ||
# print(route_layers) | ||
dim_list.append(layer_dict['prev_layer_channel']) | ||
# layer_id += 1 | ||
remaining_weights = len(weights_file.read()) / 4 | ||
print('Load {} of {} from weights.'.format(layer_dict['count'], remaining_weights + layer_dict['count'])) | ||
weights_file.close() | ||
# print(dim_list) | ||
np.save(save_path, save_weight_dict) | ||
|
||
def get_args(): | ||
|
@@ -192,20 +199,22 @@ def get_args(): | |
weights_dir = FLAGS.weights_dir | ||
save_dir = FLAGS.save_dir | ||
|
||
|
||
FLAGS = get_args() | ||
if FLAGS.model == 'darknet': | ||
# convert Darknet53 for classification | ||
config_path = 'darknet53.cfg' | ||
weights_path = os.path.join(weights_dir, 'darknet53_448.weights') | ||
save_path = os.path.join(save_dir, 'darknet53_448.npy') | ||
elif FLAGS.model == 'yolov3_feat': | ||
# convert Darknet53 first 52 conv layers for feature extration in yolov3 | ||
config_path = 'yolov3_feat.cfg' | ||
weights_path = os.path.join(weights_dir, 'yolov3.weights') | ||
save_path = os.path.join(save_dir, 'yolov3_feat.npy') | ||
elif FLAGS.model == 'yolo': | ||
# convert yolov3 trained on COCO dataset | ||
config_path = 'yolov3.cfg' | ||
weights_path = os.path.join(weights_dir, 'yolov3.weights') | ||
save_path = os.path.join(save_dir, 'yolov3.npy') | ||
|
||
convert(weights_path, config_path, save_path) | ||
# unique_config_sections(config_path) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.