Skip to content

Commit

Permalink
Replace train create_generators with Plugin impl
Browse files Browse the repository at this point in the history
Replaced entire create_generators routine with PluginManager based
system. Includes error handling for plugins which may not conform to API
entirely.
Dictionary based capture to support extensibility.
Additionally, changed plugin path to more correctly represent project
root.
  • Loading branch information
awilliamson committed Feb 22, 2018
1 parent 3f79156 commit 27a2613
Showing 1 changed file with 16 additions and 49 deletions.
65 changes: 16 additions & 49 deletions keras_retinanet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from .. import layers
from ..callbacks import RedirectModel
from ..callbacks.eval import Evaluate
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..preprocessing.csv_generator import CSVGenerator
from ..models.resnet import resnet50_retinanet, custom_objects
from ..utils.transform import random_transform_generator
from ..utils.keras_version import check_keras_version
Expand Down Expand Up @@ -121,55 +119,25 @@ def create_generators(args):
# create random transform generator for augmenting training data
transform_generator = random_transform_generator(flip_x_chance=0.5)

if args.dataset_type == 'coco':
# import here to prevent unnecessary dependency on cocoapi
from ..preprocessing.coco import CocoGenerator
generators = None

train_generator = CocoGenerator(
args.coco_path,
'train2017',
transform_generator=transform_generator,
batch_size=args.batch_size
)

validation_generator = CocoGenerator(
args.coco_path,
'val2017',
batch_size=args.batch_size
)
elif args.dataset_type == 'pascal':
train_generator = PascalVocGenerator(
args.pascal_path,
'trainval',
transform_generator=transform_generator,
batch_size=args.batch_size
)

validation_generator = PascalVocGenerator(
args.pascal_path,
'test',
batch_size=args.batch_size
)
elif args.dataset_type == 'csv':
train_generator = CSVGenerator(
args.annotations,
args.classes,
transform_generator=transform_generator,
batch_size=args.batch_size
)

if args.val_annotations:
validation_generator = CSVGenerator(
args.val_annotations,
args.classes,
batch_size=args.batch_size
)
for plugin in PluginManagerSingleton.get().getAllPlugins():
try:
getattr(plugin.plugin_object, 'dataset_type')
except AttributeError:
print("Plugin ({}) does not contain a definition for dataset_type and cannot be utilised.".format(plugin.name))
else:
validation_generator = None
else:
if plugin.plugin_object.dataset_type.lower() == args.dataset_type.lower():
if callable(getattr(plugin.plugin_object, 'create_generators', None)):
generators = plugin.plugin_object.create_generators(args, transform_generator=transform_generator)
break
else:
print("Plugin ({}) does not contain a definition for get_generator and cannot be utilised".format(plugin.name))

if generators is None:
raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

return train_generator, validation_generator
return generators["train_generator"], generators["validation_generator"]


def check_args(parsed_args):
Expand Down Expand Up @@ -247,11 +215,10 @@ def load_plugins(plugin_path):
print("Loaded: {}".format(n.name))

def main(args=None):

#Load plugins first, as their procedures are needed for parsing args.
print("Loading plugins...")
# Load Plugins
load_plugins(['keras_retinanet/plugins'])
load_plugins(['plugins'])
print("Loaded plugins.")

# parse arguments
Expand Down

0 comments on commit 27a2613

Please sign in to comment.