Skip to content

Commit 66940b0

Browse files
committed
Merge branch 'generator_plugin'
2 parents 5b2027c + 6b01cba commit 66940b0

File tree

11 files changed

+341
-77
lines changed

11 files changed

+341
-77
lines changed

keras_retinanet/bin/train.py

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
import argparse
20+
from yapsy.PluginManager import PluginManagerSingleton
2021
import os
2122
import sys
2223

@@ -36,9 +37,6 @@
3637
from .. import layers
3738
from ..callbacks import RedirectModel
3839
from ..callbacks.eval import Evaluate
39-
from ..preprocessing.pascal_voc import PascalVocGenerator
40-
from ..preprocessing.csv_generator import CSVGenerator
41-
from ..preprocessing.open_images import OpenImagesGenerator
4240
from ..models.resnet import resnet_retinanet, custom_objects, download_imagenet
4341
from ..utils.transform import random_transform_generator
4442
from ..utils.keras_version import check_keras_version
@@ -151,74 +149,47 @@ def create_generators(args):
151149
# create random transform generator for augmenting training data
152150
transform_generator = random_transform_generator(flip_x_chance=0.5)
153151

154-
if args.dataset_type == 'coco':
155-
# import here to prevent unnecessary dependency on cocoapi
156-
from ..preprocessing.coco import CocoGenerator
152+
generators = None
157153

158-
train_generator = CocoGenerator(
159-
args.coco_path,
160-
'train2017',
161-
transform_generator=transform_generator,
162-
batch_size=args.batch_size
163-
)
164-
165-
validation_generator = CocoGenerator(
166-
args.coco_path,
167-
'val2017',
168-
batch_size=args.batch_size
169-
)
170-
elif args.dataset_type == 'pascal':
171-
train_generator = PascalVocGenerator(
172-
args.pascal_path,
173-
'trainval',
174-
transform_generator=transform_generator,
175-
batch_size=args.batch_size
176-
)
177-
178-
validation_generator = PascalVocGenerator(
179-
args.pascal_path,
180-
'test',
181-
batch_size=args.batch_size
182-
)
183-
elif args.dataset_type == 'csv':
184-
train_generator = CSVGenerator(
185-
args.annotations,
186-
args.classes,
187-
transform_generator=transform_generator,
188-
batch_size=args.batch_size
189-
)
154+
for plugin in PluginManagerSingleton.get().getAllPlugins():
155+
try:
156+
getattr(plugin.plugin_object, 'dataset_type')
157+
except AttributeError:
158+
print("Plugin ({}) does not contain a definition for dataset_type and cannot be utilised.".format(plugin.name))
159+
else:
160+
if plugin.plugin_object.dataset_type.lower() == args.dataset_type.lower():
161+
if callable(getattr(plugin.plugin_object, 'get_generator', None)):
162+
generators = plugin.plugin_object.get_generator(args, transform_generator=transform_generator)
163+
break
164+
else:
165+
print("Plugin ({}) does not contain a definition for get_generator and cannot be utilised".format(plugin.name))
166+
167+
if args.dataset_type == 'oid':
168+
generators = {
169+
'train_generator' : OpenImagesGenerator(
170+
args.main_dir,
171+
subset='train',
172+
version=args.version,
173+
labels_filter=args.labels_filter,
174+
annotation_cache_dir=args.annotation_cache_dir,
175+
transform_generator=transform_generator,
176+
batch_size=args.batch_size
177+
),
190178

191-
if args.val_annotations:
192-
validation_generator = CSVGenerator(
193-
args.val_annotations,
194-
args.classes,
179+
'validation_generator' : OpenImagesGenerator(
180+
args.main_dir,
181+
subset='validation',
182+
version=args.version,
183+
labels_filter=args.labels_filter,
184+
annotation_cache_dir=args.annotation_cache_dir,
195185
batch_size=args.batch_size
196186
)
197-
else:
198-
validation_generator = None
199-
elif args.dataset_type == 'oid':
200-
train_generator = OpenImagesGenerator(
201-
args.main_dir,
202-
subset='train',
203-
version=args.version,
204-
labels_filter=args.labels_filter,
205-
annotation_cache_dir=args.annotation_cache_dir,
206-
transform_generator=transform_generator,
207-
batch_size=args.batch_size
208-
)
187+
}
209188

210-
validation_generator = OpenImagesGenerator(
211-
args.main_dir,
212-
subset='validation',
213-
version=args.version,
214-
labels_filter=args.labels_filter,
215-
annotation_cache_dir=args.annotation_cache_dir,
216-
batch_size=args.batch_size
217-
)
218-
else:
189+
if generators is None:
219190
raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
220191

221-
return train_generator, validation_generator
192+
return generators["train_generator"], generators["validation_generator"]
222193

223194

224195
def check_args(parsed_args):
@@ -241,6 +212,9 @@ def check_args(parsed_args):
241212
"Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(parsed_args.multi_gpu,
242213
parsed_args.snapshot))
243214

215+
for plugin in PluginManagerSingleton.get().getAllPlugins():
216+
plugin.plugin_object.check_args(parsed_args)
217+
244218
return parsed_args
245219

246220

@@ -249,26 +223,19 @@ def parse_args(args):
249223
subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
250224
subparsers.required = True
251225

252-
coco_parser = subparsers.add_parser('coco')
253-
coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
254-
255-
pascal_parser = subparsers.add_parser('pascal')
256-
pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')
257-
258226
def csv_list(string):
259227
return string.split(',')
260228

229+
# Invoke all loaded plugins for their subparsers.
230+
for plugin in PluginManagerSingleton.get().getAllPlugins():
231+
plugin.plugin_object.parser_args(subparsers)
232+
261233
oid_parser = subparsers.add_parser('oid')
262234
oid_parser.add_argument('main_dir', help='Path to dataset directory.')
263-
oid_parser.add_argument('--version', help='The current dataset version is V3.', default='2017_11')
264-
oid_parser.add_argument('--labels_filter', help='A list of labels to filter.', type=csv_list, default=None)
235+
oid_parser.add_argument('--version', help='The current dataset version is V3.', default='2017_11')
236+
oid_parser.add_argument('--labels_filter', help='A list of labels to filter.', type=csv_list, default=None)
265237
oid_parser.add_argument('--annotation_cache_dir', help='Path to store annotation cache.', default='.')
266238

267-
csv_parser = subparsers.add_parser('csv')
268-
csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
269-
csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
270-
csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).')
271-
272239
group = parser.add_mutually_exclusive_group()
273240
group.add_argument('--snapshot', help='Resume training from a snapshot.')
274241
group.add_argument('--imagenet-weights', help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)
@@ -287,8 +254,37 @@ def csv_list(string):
287254

288255
return check_args(parser.parse_args(args))
289256

257+
def load_plugins(plugin_path):
258+
"""
259+
Responsible for initialising the plugin manager, setting the plugin directory to search, and loading all available
260+
plugins - then activating them.
261+
262+
:param plugin_path: String/[Str] for plugin paths to check
263+
:return: None
264+
"""
265+
266+
pm = PluginManagerSingleton.get()
267+
pl = pm.getPluginLocator()
268+
pl.setPluginInfoExtension("dataset")
269+
pm.setPluginLocator(pl)
270+
271+
plugin_path = [plugin_path] if type(plugin_path) is str else plugin_path
272+
pm.setPluginPlaces(plugin_path)
273+
274+
# Load all plugins
275+
pm.collectPlugins()
276+
277+
for k, n in enumerate(pm.getAllPlugins()):
278+
pm.activatePluginByName(n.name)
279+
print("Loaded: {}".format(n.name))
290280

291281
def main(args=None):
282+
#Load plugins first, as their procedures are needed for parsing args.
283+
print("Loading plugins...")
284+
# Load Plugins
285+
load_plugins(['plugins'])
286+
print("Loaded plugins.")
287+
292288
# parse arguments
293289
if args is None:
294290
args = sys.argv[1:]

keras_retinanet/plugins/coco/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
;Copyright 2017-2018 Ashley Williamson (https://inp.io)
2+
;
3+
;Licensed under the Apache License, Version 2.0 (the "License");
4+
;you may not use this file except in compliance with the License.
5+
;You may obtain a copy of the License at
6+
;
7+
; http://www.apache.org/licenses/LICENSE-2.0
8+
;
9+
;Unless required by applicable law or agreed to in writing, software
10+
;distributed under the License is distributed on an "AS IS" BASIS,
11+
;WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
;See the License for the specific language governing permissions and
13+
;limitations under the License.
14+
15+
[Core]
16+
Name = Coco Dataset
17+
Module = coco
18+
19+
[Documentation]
20+
Description = RetinaNet Definition for COCO Dataset
21+
Author = Ashley Williamson
22+
Version = 1.0
23+
Website = https://inp.io

keras_retinanet/plugins/coco/coco.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
Copyright 2017-2018 Ashley Williamson (https://inp.io)
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import keras_retinanet.utils.plugin as plugins
18+
# See https://yapsy.readthedocs.io/en/latest/Advices.html#plugin-class-detection-caveat
19+
# Caveat surrounding import. Must us 'as' rather than directly importing DatasetPlugin
20+
21+
from keras_retinanet.preprocessing.coco import CocoGenerator
22+
23+
24+
class CocoPlugin(plugins.DatasetPlugin):
25+
def __init__(self):
26+
super(CocoPlugin, self).__init__()
27+
28+
self.dataset_type = "coco"
29+
30+
def parser_args(self, subparsers):
31+
coco_parser = subparsers.add_parser(self.dataset_type)
32+
coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
33+
34+
return coco_parser
35+
36+
def get_generator(self, args, transform_generator=None):
37+
train_generator = CocoGenerator(
38+
args.coco_path,
39+
'train2017',
40+
transform_generator=transform_generator,
41+
batch_size=args.batch_size
42+
)
43+
44+
validation_generator = CocoGenerator(
45+
args.coco_path,
46+
'val2017',
47+
batch_size=args.batch_size
48+
)
49+
50+
return {
51+
"train_generator": train_generator,
52+
"validation_generator": validation_generator
53+
}
54+

keras_retinanet/plugins/csv/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
;Copyright 2017-2018 Ashley Williamson (https://inp.io)
2+
;
3+
;Licensed under the Apache License, Version 2.0 (the "License");
4+
;you may not use this file except in compliance with the License.
5+
;You may obtain a copy of the License at
6+
;
7+
; http://www.apache.org/licenses/LICENSE-2.0
8+
;
9+
;Unless required by applicable law or agreed to in writing, software
10+
;distributed under the License is distributed on an "AS IS" BASIS,
11+
;WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
;See the License for the specific language governing permissions and
13+
;limitations under the License.
14+
15+
[Core]
16+
Name = CSV Dataset
17+
Module = csv
18+
19+
[Documentation]
20+
Description = RetinaNet Definition for CSV Dataset
21+
Author = Ashley Williamson
22+
Version = 1.0
23+
Website = https://inp.io

keras_retinanet/plugins/csv/csv.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
Copyright 2017-2018 Ashley Williamson (https://inp.io)
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import keras_retinanet.utils.plugin as plugins
18+
# See https://yapsy.readthedocs.io/en/latest/Advices.html#plugin-class-detection-caveat
19+
# Caveat surrounding import. Must us 'as' rather than directly importing DatasetPlugin
20+
21+
from keras_retinanet.preprocessing.csv_generator import CSVGenerator
22+
23+
24+
class CSVPlugin(plugins.DatasetPlugin):
25+
def __init__(self):
26+
super(CSVPlugin, self).__init__()
27+
28+
self.dataset_type = "csv"
29+
30+
def parser_args(self, subparsers):
31+
csv_parser = subparsers.add_parser(self.dataset_type)
32+
csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
33+
csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
34+
csv_parser.add_argument('--val-annotations',
35+
help='Path to CSV file containing annotations for validation (optional).')
36+
37+
return csv_parser
38+
39+
def get_generator(self, args, transform_generator=None):
40+
train_generator = CSVGenerator(
41+
args.annotations,
42+
args.classes,
43+
transform_generator=transform_generator,
44+
batch_size=args.batch_size
45+
)
46+
47+
if args.val_annotations:
48+
validation_generator = CSVGenerator(
49+
args.val_annotations,
50+
args.classes,
51+
batch_size=args.batch_size
52+
)
53+
else:
54+
validation_generator = None
55+
56+
return {
57+
"train_generator": train_generator,
58+
"validation_generator": validation_generator
59+
}

keras_retinanet/plugins/voc/__init__.py

Whitespace-only changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
;Copyright 2017-2018 Ashley Williamson (https://inp.io)
2+
;
3+
;Licensed under the Apache License, Version 2.0 (the "License");
4+
;you may not use this file except in compliance with the License.
5+
;You may obtain a copy of the License at
6+
;
7+
; http://www.apache.org/licenses/LICENSE-2.0
8+
;
9+
;Unless required by applicable law or agreed to in writing, software
10+
;distributed under the License is distributed on an "AS IS" BASIS,
11+
;WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
;See the License for the specific language governing permissions and
13+
;limitations under the License.
14+
15+
[Core]
16+
Name = VOC Dataset
17+
Module = voc
18+
19+
[Documentation]
20+
Description = RetinaNet Definition for VOC Dataset
21+
Author = Ashley Williamson
22+
Version = 1.0
23+
Website = https://inp.io

0 commit comments

Comments
 (0)