Skip to content

Commit a41c6ff

Browse files
committed
Wiring in the cli parameters.
1 parent b43e7e4 commit a41c6ff

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ version and converted on the fly.
7575
Image demo runs inference on an input folder of images and outputs those images with the keypoints and skeleton
7676
overlayed.
7777

78-
`python image_demo.py --model 101 --image_dir ./images --output_dir ./output`
78+
`python image_demo.py --model resnet50 --stride 32 --image_dir ./images --output_dir ./output`
7979

8080
A folder of suitable test images can be downloaded by first running the `get_test_images.py` script.
8181

benchmark.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88

99
parser = argparse.ArgumentParser()
10-
parser.add_argument('--model', type=int, default=101)
10+
parser.add_argument('--model', type=str, default='resnet50') # mobilenet resnet50
11+
parser.add_argument('--stride', type=int, default=16) # 8, 16, 32 (max 16 for mobilenet)
12+
parser.add_argument('--quant_bytes', type=int, default=4) # 4 = float
13+
parser.add_argument('--multiplier', type=float, default=1.0) # only for mobilenet
1114
parser.add_argument('--image_dir', type=str, default='./images')
1215
parser.add_argument('--num_images', type=int, default=1000)
1316
args = parser.parse_args()
@@ -18,10 +21,10 @@ def main():
1821
print('Tensorflow version: %s' % tf.__version__)
1922
assert tf.__version__.startswith('2.'), "Tensorflow version 2.x must be used!"
2023

21-
model = 'resnet50' # mobilenet resnet50
22-
stride = 32 # 8, 16, 32
23-
quant_bytes = 4 # float
24-
multiplier = 1.0 # only for mobilenet
24+
model = args.model # mobilenet resnet50
25+
stride = args.stride # 8, 16, 32 (max 16 for mobilenet)
26+
quant_bytes = args.quant_bytes # float
27+
multiplier = args.multiplier # only for mobilenet
2528

2629
posenet = load_model(model, stride, quant_bytes, multiplier)
2730

image_demo.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from posenet.posenet_factory import load_model
77

88
parser = argparse.ArgumentParser()
9-
parser.add_argument('--model', type=int, default=101)
10-
parser.add_argument('--scale_factor', type=float, default=1.0)
9+
parser.add_argument('--model', type=str, default='resnet50') # mobilenet resnet50
10+
parser.add_argument('--stride', type=int, default=16) # 8, 16, 32 (max 16 for mobilenet)
11+
parser.add_argument('--quant_bytes', type=int, default=4) # 4 = float
12+
parser.add_argument('--multiplier', type=float, default=1.0) # only for mobilenet
1113
parser.add_argument('--notxt', action='store_true')
1214
parser.add_argument('--image_dir', type=str, default='./images')
1315
parser.add_argument('--output_dir', type=str, default='./output')
@@ -23,10 +25,10 @@ def main():
2325
if not os.path.exists(args.output_dir):
2426
os.makedirs(args.output_dir)
2527

26-
model = 'resnet50' # mobilenet resnet50
27-
stride = 32 # 8, 16, 32 (max 16 for mobilenet)
28-
quant_bytes = 4 # float
29-
multiplier = 1.0 # only for mobilenet
28+
model = args.model # mobilenet resnet50
29+
stride = args.stride # 8, 16, 32 (max 16 for mobilenet)
30+
quant_bytes = args.quant_bytes # float
31+
multiplier = args.multiplier # only for mobilenet
3032

3133
posenet = load_model(model, stride, quant_bytes, multiplier)
3234

0 commit comments

Comments
 (0)