
Commit

Merge branch 'generator_plugin'
awilliamson committed Feb 14, 2018
2 parents 5b2027c + 6b01cba commit 66940b0
Showing 11 changed files with 341 additions and 77 deletions.
150 changes: 73 additions & 77 deletions keras_retinanet/bin/train.py
@@ -17,6 +17,7 @@
"""

import argparse
+from yapsy.PluginManager import PluginManagerSingleton
import os
import sys

@@ -36,9 +37,6 @@
from .. import layers
from ..callbacks import RedirectModel
from ..callbacks.eval import Evaluate
-from ..preprocessing.pascal_voc import PascalVocGenerator
-from ..preprocessing.csv_generator import CSVGenerator
-from ..preprocessing.open_images import OpenImagesGenerator
from ..models.resnet import resnet_retinanet, custom_objects, download_imagenet
from ..utils.transform import random_transform_generator
from ..utils.keras_version import check_keras_version
@@ -151,74 +149,47 @@ def create_generators(args):
    # create random transform generator for augmenting training data
    transform_generator = random_transform_generator(flip_x_chance=0.5)

-    if args.dataset_type == 'coco':
-        # import here to prevent unnecessary dependency on cocoapi
-        from ..preprocessing.coco import CocoGenerator
-
-        train_generator = CocoGenerator(
-            args.coco_path,
-            'train2017',
-            transform_generator=transform_generator,
-            batch_size=args.batch_size
-        )
-
-        validation_generator = CocoGenerator(
-            args.coco_path,
-            'val2017',
-            batch_size=args.batch_size
-        )
-    elif args.dataset_type == 'pascal':
-        train_generator = PascalVocGenerator(
-            args.pascal_path,
-            'trainval',
-            transform_generator=transform_generator,
-            batch_size=args.batch_size
-        )
-
-        validation_generator = PascalVocGenerator(
-            args.pascal_path,
-            'test',
-            batch_size=args.batch_size
-        )
-    elif args.dataset_type == 'csv':
-        train_generator = CSVGenerator(
-            args.annotations,
-            args.classes,
-            transform_generator=transform_generator,
-            batch_size=args.batch_size
-        )
-
-        if args.val_annotations:
-            validation_generator = CSVGenerator(
-                args.val_annotations,
-                args.classes,
-                batch_size=args.batch_size
-            )
-        else:
-            validation_generator = None
-    elif args.dataset_type == 'oid':
-        train_generator = OpenImagesGenerator(
-            args.main_dir,
-            subset='train',
-            version=args.version,
-            labels_filter=args.labels_filter,
-            annotation_cache_dir=args.annotation_cache_dir,
-            transform_generator=transform_generator,
-            batch_size=args.batch_size
-        )
-
-        validation_generator = OpenImagesGenerator(
-            args.main_dir,
-            subset='validation',
-            version=args.version,
-            labels_filter=args.labels_filter,
-            annotation_cache_dir=args.annotation_cache_dir,
-            batch_size=args.batch_size
-        )
-    else:
-        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
-
-    return train_generator, validation_generator
+    generators = None
+
+    for plugin in PluginManagerSingleton.get().getAllPlugins():
+        try:
+            getattr(plugin.plugin_object, 'dataset_type')
+        except AttributeError:
+            print("Plugin ({}) does not contain a definition for dataset_type and cannot be utilised.".format(plugin.name))
+        else:
+            if plugin.plugin_object.dataset_type.lower() == args.dataset_type.lower():
+                if callable(getattr(plugin.plugin_object, 'get_generator', None)):
+                    generators = plugin.plugin_object.get_generator(args, transform_generator=transform_generator)
+                    break
+                else:
+                    print("Plugin ({}) does not contain a definition for get_generator and cannot be utilised".format(plugin.name))
+
+    if args.dataset_type == 'oid':
+        generators = {
+            'train_generator' : OpenImagesGenerator(
+                args.main_dir,
+                subset='train',
+                version=args.version,
+                labels_filter=args.labels_filter,
+                annotation_cache_dir=args.annotation_cache_dir,
+                transform_generator=transform_generator,
+                batch_size=args.batch_size
+            ),
+
+            'validation_generator' : OpenImagesGenerator(
+                args.main_dir,
+                subset='validation',
+                version=args.version,
+                labels_filter=args.labels_filter,
+                annotation_cache_dir=args.annotation_cache_dir,
+                batch_size=args.batch_size
+            )
+        }
+
+    if generators is None:
+        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))
+
+    return generators["train_generator"], generators["validation_generator"]


def check_args(parsed_args):
@@ -241,6 +212,9 @@ def check_args(parsed_args):
"Multi GPU training ({}) and resuming from snapshots ({}) is not supported.".format(parsed_args.multi_gpu,
parsed_args.snapshot))

+    for plugin in PluginManagerSingleton.get().getAllPlugins():
+        plugin.plugin_object.check_args(parsed_args)

return parsed_args


@@ -249,26 +223,19 @@ def parse_args(args):
subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
subparsers.required = True

-    coco_parser = subparsers.add_parser('coco')
-    coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')
-
-    pascal_parser = subparsers.add_parser('pascal')
-    pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')

    def csv_list(string):
        return string.split(',')

+    # Invoke all loaded plugins for their subparsers.
+    for plugin in PluginManagerSingleton.get().getAllPlugins():
+        plugin.plugin_object.parser_args(subparsers)

    oid_parser = subparsers.add_parser('oid')
    oid_parser.add_argument('main_dir', help='Path to dataset directory.')
    oid_parser.add_argument('--version', help='The current dataset version is V3.', default='2017_11')
    oid_parser.add_argument('--labels_filter', help='A list of labels to filter.', type=csv_list, default=None)
    oid_parser.add_argument('--annotation_cache_dir', help='Path to store annotation cache.', default='.')

-    csv_parser = subparsers.add_parser('csv')
-    csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
-    csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
-    csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).')

group = parser.add_mutually_exclusive_group()
group.add_argument('--snapshot', help='Resume training from a snapshot.')
group.add_argument('--imagenet-weights', help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)
@@ -287,8 +254,37 @@ def csv_list(string):

return check_args(parser.parse_args(args))

+def load_plugins(plugin_path):
+    """
+    Responsible for initialising the plugin manager, setting the plugin directory to search, and loading all available
+    plugins - then activating them.
+    :param plugin_path: String/[Str] for plugin paths to check
+    :return: None
+    """
+
+    pm = PluginManagerSingleton.get()
+    pl = pm.getPluginLocator()
+    pl.setPluginInfoExtension("dataset")
+    pm.setPluginLocator(pl)
+
+    plugin_path = [plugin_path] if type(plugin_path) is str else plugin_path
+    pm.setPluginPlaces(plugin_path)
+
+    # Load all plugins
+    pm.collectPlugins()
+
+    for k, n in enumerate(pm.getAllPlugins()):
+        pm.activatePluginByName(n.name)
+        print("Loaded: {}".format(n.name))

def main(args=None):
+    # Load plugins first, as their procedures are needed for parsing args.
+    print("Loading plugins...")
+    # Load Plugins
+    load_plugins(['plugins'])
+    print("Loaded plugins.")

# parse arguments
if args is None:
args = sys.argv[1:]
Empty file.
23 changes: 23 additions & 0 deletions keras_retinanet/plugins/coco/coco.dataset
@@ -0,0 +1,23 @@
;Copyright 2017-2018 Ashley Williamson (https://inp.io)
;
;Licensed under the Apache License, Version 2.0 (the "License");
;you may not use this file except in compliance with the License.
;You may obtain a copy of the License at
;
; http://www.apache.org/licenses/LICENSE-2.0
;
;Unless required by applicable law or agreed to in writing, software
;distributed under the License is distributed on an "AS IS" BASIS,
;WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;See the License for the specific language governing permissions and
;limitations under the License.

[Core]
Name = Coco Dataset
Module = coco

[Documentation]
Description = RetinaNet Definition for COCO Dataset
Author = Ashley Williamson
Version = 1.0
Website = https://inp.io
54 changes: 54 additions & 0 deletions keras_retinanet/plugins/coco/coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Copyright 2017-2018 Ashley Williamson (https://inp.io)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras_retinanet.utils.plugin as plugins
# See https://yapsy.readthedocs.io/en/latest/Advices.html#plugin-class-detection-caveat
# Caveat surrounding import. Must use 'as' rather than directly importing DatasetPlugin

from keras_retinanet.preprocessing.coco import CocoGenerator


class CocoPlugin(plugins.DatasetPlugin):
    def __init__(self):
        super(CocoPlugin, self).__init__()

        self.dataset_type = "coco"

    def parser_args(self, subparsers):
        coco_parser = subparsers.add_parser(self.dataset_type)
        coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')

        return coco_parser

    def get_generator(self, args, transform_generator=None):
        train_generator = CocoGenerator(
            args.coco_path,
            'train2017',
            transform_generator=transform_generator,
            batch_size=args.batch_size
        )

        validation_generator = CocoGenerator(
            args.coco_path,
            'val2017',
            batch_size=args.batch_size
        )

        return {
            "train_generator": train_generator,
            "validation_generator": validation_generator
        }

Empty file.
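The CocoPlugin above and the CSVPlugin below both subclass DatasetPlugin, imported from keras_retinanet.utils.plugin, which is not rendered on this page. As a minimal sketch only, assuming the base class does no more than extend yapsy's IPlugin with the hooks train.py calls (dataset_type, parser_args, get_generator, check_args), it could look roughly like this:

# Hypothetical sketch of the DatasetPlugin base class -- the actual keras_retinanet/utils/plugin.py is not shown in this diff.
from yapsy.IPlugin import IPlugin


class DatasetPlugin(IPlugin):
    def __init__(self):
        super(DatasetPlugin, self).__init__()
        self.dataset_type = None  # e.g. 'coco' or 'csv'; matched case-insensitively against args.dataset_type

    def parser_args(self, subparsers):
        # Register a subparser for this dataset type; invoked from parse_args() in train.py.
        raise NotImplementedError

    def get_generator(self, args, transform_generator=None):
        # Return a dict with 'train_generator' and 'validation_generator'; invoked from create_generators().
        raise NotImplementedError

    def check_args(self, parsed_args):
        # Optional extra argument validation; invoked for every plugin from check_args() in train.py.
        return parsed_args

Whatever its real shape, the contract train.py relies on is only the four names above.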
23 changes: 23 additions & 0 deletions keras_retinanet/plugins/csv/csv.dataset
@@ -0,0 +1,23 @@
;Copyright 2017-2018 Ashley Williamson (https://inp.io)
;
;Licensed under the Apache License, Version 2.0 (the "License");
;you may not use this file except in compliance with the License.
;You may obtain a copy of the License at
;
; http://www.apache.org/licenses/LICENSE-2.0
;
;Unless required by applicable law or agreed to in writing, software
;distributed under the License is distributed on an "AS IS" BASIS,
;WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;See the License for the specific language governing permissions and
;limitations under the License.

[Core]
Name = CSV Dataset
Module = csv

[Documentation]
Description = RetinaNet Definition for CSV Dataset
Author = Ashley Williamson
Version = 1.0
Website = https://inp.io
59 changes: 59 additions & 0 deletions keras_retinanet/plugins/csv/csv.py
@@ -0,0 +1,59 @@
"""
Copyright 2017-2018 Ashley Williamson (https://inp.io)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import keras_retinanet.utils.plugin as plugins
# See https://yapsy.readthedocs.io/en/latest/Advices.html#plugin-class-detection-caveat
# Caveat surrounding import. Must use 'as' rather than directly importing DatasetPlugin

from keras_retinanet.preprocessing.csv_generator import CSVGenerator


class CSVPlugin(plugins.DatasetPlugin):
    def __init__(self):
        super(CSVPlugin, self).__init__()

        self.dataset_type = "csv"

    def parser_args(self, subparsers):
        csv_parser = subparsers.add_parser(self.dataset_type)
        csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')
        csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')
        csv_parser.add_argument('--val-annotations',
                                help='Path to CSV file containing annotations for validation (optional).')

        return csv_parser

    def get_generator(self, args, transform_generator=None):
        train_generator = CSVGenerator(
            args.annotations,
            args.classes,
            transform_generator=transform_generator,
            batch_size=args.batch_size
        )

        if args.val_annotations:
            validation_generator = CSVGenerator(
                args.val_annotations,
                args.classes,
                batch_size=args.batch_size
            )
        else:
            validation_generator = None

        return {
            "train_generator": train_generator,
            "validation_generator": validation_generator
        }
Empty file.
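With the COCO and CSV generators moved behind plugins, the command-line interface should stay the same: each plugin registers the same subparser it replaced, so an invocation along the lines of python keras_retinanet/bin/train.py csv /path/to/annotations.csv /path/to/classes.csv (the paths shown here are purely illustrative) is still parsed by parse_args and then routed to CSVPlugin.get_generator through the dataset_type match in create_generators.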
23 changes: 23 additions & 0 deletions keras_retinanet/plugins/voc/voc.dataset
@@ -0,0 +1,23 @@
;Copyright 2017-2018 Ashley Williamson (https://inp.io)
;
;Licensed under the Apache License, Version 2.0 (the "License");
;you may not use this file except in compliance with the License.
;You may obtain a copy of the License at
;
; http://www.apache.org/licenses/LICENSE-2.0
;
;Unless required by applicable law or agreed to in writing, software
;distributed under the License is distributed on an "AS IS" BASIS,
;WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;See the License for the specific language governing permissions and
;limitations under the License.

[Core]
Name = VOC Dataset
Module = voc

[Documentation]
Description = RetinaNet Definition for VOC Dataset
Author = Ashley Williamson
Version = 1.0
Website = https://inp.io
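Each plugin directory pairs a module like those above with a .dataset info file, which is what load_plugins tells yapsy to look for via setPluginInfoExtension("dataset"). As a hedged sketch of how a new dataset could be supported under this scheme (every 'mydataset' name below is invented for illustration and is not part of this commit), one would drop a mydataset.dataset descriptor (with Name and Module = mydataset under [Core], mirroring the files above) and a mydataset.py next to the existing plugins:

# Hypothetical example plugin -- not part of this commit; the 'mydataset' names are placeholders.
import keras_retinanet.utils.plugin as plugins  # import the module, per the yapsy class-detection caveat noted above

from keras_retinanet.preprocessing.csv_generator import CSVGenerator  # an existing generator, reused here for brevity


class MyDatasetPlugin(plugins.DatasetPlugin):
    def __init__(self):
        super(MyDatasetPlugin, self).__init__()

        self.dataset_type = "mydataset"  # matched against the chosen dataset_type subcommand

    def parser_args(self, subparsers):
        # Registers the 'mydataset' subcommand so train.py can parse its arguments.
        parser = subparsers.add_parser(self.dataset_type)
        parser.add_argument('annotations', help='Path to the annotations for this dataset.')
        parser.add_argument('classes', help='Path to the class mapping for this dataset.')

        return parser

    def get_generator(self, args, transform_generator=None):
        # create_generators() in train.py expects this dict shape from every plugin.
        return {
            "train_generator": CSVGenerator(
                args.annotations,
                args.classes,
                transform_generator=transform_generator,
                batch_size=args.batch_size
            ),
            "validation_generator": None
        }

Because main() calls load_plugins(['plugins']) before parsing arguments, anything collected and activated there is available by the time the subparsers are built.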
