Skip to content

Commit bddd7f8

Browse files
committed
Towards TF2.0
1 parent 2292300 commit bddd7f8

13 files changed

+114
-36
lines changed

docker_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ else
2525
gpu_opts=""
2626
fi
2727

28-
docker run $gpu_opts -it --rm -v $WORK:/work "$image" python "$@"
28+
docker run $gpu_opts -it --rm -v $WORK:/work "$image" "$@"

get_test_images_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env bash
22

3-
./docker_run.sh get_test_images.py
3+
./docker_run.sh python get_test_images.py

image_demo_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env bash
22

3-
./docker_run.sh image_demo.py --model 101 --image_dir ./images --output_dir ./output
3+
./docker_run.sh python image_demo.py --model 101 --image_dir ./images --output_dir ./output

image_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def main():
3939
img = cv2.imread(f)
4040
pose_scores, keypoint_scores, keypoint_coords = posenet.estimate_multiple_poses(img)
4141
img_poses = posenet.draw_poses(img, pose_scores, keypoint_scores, keypoint_coords)
42-
posenet.print_scores(img, pose_scores, keypoint_scores, keypoint_coords)
42+
posenet.print_scores(f, pose_scores, keypoint_scores, keypoint_coords)
4343
cv2.imwrite(os.path.join(args.output_dir, os.path.relpath(f, args.image_dir)), img_poses)
4444

4545
print('Average FPS:', len(filenames) / (time.time() - start))

image_test_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env bash
22

3-
./docker_run.sh image_test.py --model 101 --image_dir ./images --output_dir ./output
3+
./docker_run.sh python image_test.py --model 101 --image_dir ./images --output_dir ./output

inspect_saved_model.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/usr/bin/env bash
2+
3+
FOLDER=$1
4+
5+
# e.g.: $> ./inspect_saved_model.sh _tf_models/posenet/mobilenet_v1_100/stride16
6+
./docker_run.sh saved_model_cli show --dir "$FOLDER" --all

posenet/base_model.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44

55
class BaseModel(ABC):
66

7-
def __init__(self, sess, input_tensor_name, output_tensor_names, output_stride):
7+
# keys for the output_tensor_names map
8+
HEATMAP_KEY = "heatmap"
9+
OFFSETS_KEY = "offsets"
10+
DISPLACEMENT_FWD_KEY = "displacement_fwd"
11+
DISPLACEMENT_BWD_KEY = "displacement_bwd"
12+
13+
def __init__(self, model_function, output_tensor_names, output_stride):
814
self.output_stride = output_stride
9-
self.sess = sess
10-
self.input_tensor_name = input_tensor_name
11-
self.output_tensors = [
12-
tf.sigmoid(sess.graph.get_tensor_by_name(output_tensor_names['heatmap']), 'heatmap'), # sigmoid!!!
13-
sess.graph.get_tensor_by_name(output_tensor_names['offsets']),
14-
sess.graph.get_tensor_by_name(output_tensor_names['displacement_fwd']),
15-
sess.graph.get_tensor_by_name(output_tensor_names['displacement_bwd'])
16-
]
15+
self.output_tensor_names = output_tensor_names
16+
self.model_function = model_function
17+
# self.sess = sess
18+
# self.input_tensor_name = input_tensor_name
19+
# self.output_tensors = output_tensors
1720

1821
def valid_resolution(self, width, height):
1922
# calculate closest smaller width and height that is divisible by the stride after subtracting 1 (for the bias?)
@@ -27,11 +30,24 @@ def preprocess_input(self, image):
2730

2831
def predict(self, image):
2932
input_image, image_scale = self.preprocess_input(image)
30-
heatmap_result, offsets_result, displacement_fwd_result, displacement_bwd_result = self.sess.run(
31-
self.output_tensors,
32-
feed_dict={self.input_tensor_name: input_image}
33-
)
34-
return heatmap_result, offsets_result, displacement_fwd_result, displacement_bwd_result, image_scale
33+
34+
input_image = tf.convert_to_tensor(input_image, dtype=tf.float32)
35+
36+
result = self.model_function(input_image)
37+
38+
heatmap_result = result[self.output_tensor_names[self.HEATMAP_KEY]]
39+
offsets_result = result[self.output_tensor_names[self.OFFSETS_KEY]]
40+
displacement_fwd_result = result[self.output_tensor_names[self.DISPLACEMENT_FWD_KEY]]
41+
displacement_bwd_result = result[self.output_tensor_names[self.DISPLACEMENT_BWD_KEY]]
42+
43+
44+
# self.sess.run(
45+
# self.output_tensors,
46+
# feed_dict={self.input_tensor_name: input_image}
47+
# )
48+
49+
return tf.sigmoid(heatmap_result), offsets_result, displacement_fwd_result, displacement_bwd_result, image_scale
3550

3651
def close(self):
37-
self.sess.close()
52+
# self.sess.close()
53+
return

posenet/converter/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66

77
def load_config(config_name='config.yaml'):
88
cfg_f = open(os.path.join(BASE_DIR, config_name), "r+")
9-
cfg = yaml.load(cfg_f)
9+
cfg = yaml.load(cfg_f, Loader=yaml.FullLoader)
1010
return cfg

posenet/converter/tfjs2tf.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,46 @@
44
import posenet.converter.tfjsdownload as tfjsdownload
55

66

7+
def __tensor_info_def(sess, tensor_names):
8+
signatures = {}
9+
for tensor_name in tensor_names:
10+
tensor = sess.graph.get_tensor_by_name(tensor_name)
11+
tensor_info = tf.compat.v1.saved_model.build_tensor_info(tensor)
12+
signatures[tensor_name] = tensor_info
13+
return signatures
14+
15+
716
def convert(model, neuralnet, model_variant):
817
model_cfg = tfjsdownload.model_config(model, neuralnet, model_variant)
918
model_file_path = os.path.join(model_cfg['tfjs_dir'], model_cfg['filename'])
1019
if not os.path.exists(model_file_path):
1120
print('Cannot find tfjs model path %s, downloading tfjs model...' % model_file_path)
1221
tfjsdownload.download_tfjs_model(model, neuralnet, model_variant)
13-
tfjs.api.graph_model_to_saved_model(model_cfg['tfjs_dir'], model_cfg['tf_dir'], ['serve'])
22+
23+
# 'graph_model_to_saved_model' doesn't store the signature for the model!
24+
# tfjs.api.graph_model_to_saved_model(model_cfg['tfjs_dir'], model_cfg['tf_dir'], ['serve'])
25+
# so we do it manually here:
26+
27+
# see: https://www.programcreek.com/python/example/104885/tensorflow.python.saved_model.signature_def_utils.build_signature_def
28+
graph = tfjs.api.load_graph_model(model_cfg['tfjs_dir'])
29+
builder = tf.compat.v1.saved_model.Builder(model_cfg['tf_dir'])
30+
31+
with tf.compat.v1.Session(graph=graph) as sess:
32+
input_tensor_names = tfjs.util.get_input_tensors(graph)
33+
output_tensor_names = tfjs.util.get_output_tensors(graph)
34+
35+
signature_inputs = __tensor_info_def(sess, input_tensor_names)
36+
signature_outputs = __tensor_info_def(sess, output_tensor_names)
37+
38+
method_name = tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME
39+
signature_def = tf.compat.v1.saved_model.build_signature_def(inputs=signature_inputs,
40+
outputs=signature_outputs,
41+
method_name=method_name)
42+
signature_map = {tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def}
43+
builder.add_meta_graph_and_variables(sess=sess,
44+
tags=['serve'],
45+
signature_def_map=signature_map)
46+
return builder.save()
1447

1548

1649
def list_tensors(model, neuralnet, model_variant):

posenet/mobilenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
class MobileNet(BaseModel):
77

8-
def __init__(self, sess, input_tensor_name, output_tensor_names, output_stride):
9-
super().__init__(sess, input_tensor_name, output_tensor_names, output_stride)
8+
def __init__(self, model_function, output_tensor_names, output_stride):
9+
super().__init__(model_function, output_tensor_names, output_stride)
1010

1111
def preprocess_input(self, image):
1212
target_width, target_height = self.valid_resolution(image.shape[1], image.shape[0])

0 commit comments

Comments
 (0)