Skip to content

Commit 6bfedce

Browse files
committed
bin/train.py: Update for updated named_subparser API.
1 parent 7506b0e commit 6bfedce

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

keras_retinanet/bin/train.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ def check_args(parsed_args, dataset_plugins):
175175

176176
# let the selected dataset check args too
177177
dataset_plugins[parsed_args.dataset.name].check_args(parsed_args.dataset.args)
178-
if parsed_args.evaluate:
179-
dataset_plugins[parsed_args.evaluate.name].check_args(parsed_args.evaluate.args)
178+
if parsed_args.validate:
179+
dataset_plugins[parsed_args.validate.name].check_args(parsed_args.validate.args)
180180

181181
return parsed_args
182182

@@ -210,10 +210,8 @@ def parse_args(args, dataset_plugins):
210210
parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output', default='./logs')
211211
parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false')
212212

213-
train_dataset = NamedSubparser('--dataset', required=True)
214-
validation_dataset = NamedSubparser('--evaluate', required=False)
215-
parser.add_named_subparser(train_dataset)
216-
parser.add_named_subparser(validation_dataset)
213+
train_dataset = parser.add_named_subparser(['--dataset'], required=True)
214+
validation_dataset = parser.add_named_subparser(['--validate'])
217215

218216
# let all plugins register their arguments.
219217
for name, plugin in dataset_plugins.items():
@@ -259,8 +257,8 @@ def main(args=None):
259257

260258
# make the validation data generator
261259
validation_generator = None
262-
if args.evaluate:
263-
validation_generator = create_generator(arg.evaluate, dataset_plugins, batch_size=args.batch_size)
260+
if args.validate:
261+
validation_generator = create_generator(arg.validate, dataset_plugins, batch_size=args.batch_size)
264262

265263
if 'resnet' in args.backbone:
266264
from ..models.resnet import resnet_retinanet as retinanet, custom_objects, download_imagenet

0 commit comments

Comments
 (0)