Skip to content

Commit d3d3c1a

Browse files
committed
Refactoring.
1 parent 4c401c7 commit d3d3c1a

File tree

5 files changed

+104
-20
lines changed

5 files changed

+104
-20
lines changed

posenet/base_model.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,34 @@
11
from abc import ABC, abstractmethod
2+
import tensorflow as tf
23

34

45
class BaseModel(ABC):
56

6-
def __init__(self, output_stride):
7+
def __init__(self, sess, input_tensor_name, output_tensor_names, output_stride):
78
self.output_stride = output_stride
9+
self.sess = sess
10+
self.input_tensor_name = input_tensor_name
11+
self.output_tensors = [
12+
tf.sigmoid(sess.graph.get_tensor_by_name(output_tensor_names['heatmap']), 'heatmap'), # sigmoid!!!
13+
sess.graph.get_tensor_by_name(output_tensor_names['offsets']),
14+
sess.graph.get_tensor_by_name(output_tensor_names['displacement_fwd']),
15+
sess.graph.get_tensor_by_name(output_tensor_names['displacement_bwd'])
16+
]
817

9-
@abstractmethod
10-
def preprocess_input(self):
11-
pass
18+
def valid_resolution(self, width, height):
19+
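# round width and height down to a multiple of the output stride, then add 1, so that (target - 1) is divisible by the stride (note: the result is one larger than the input when the input is already a multiple of the stride)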
20+
target_width = (int(width) // self.output_stride) * self.output_stride + 1
21+
target_height = (int(height) // self.output_stride) * self.output_stride + 1
22+
return target_width, target_height
1223

1324
@abstractmethod
14-
def name_output_results(self, graph):
15-
return graph
25+
def preprocess_input(self, image):
26+
pass
1627

17-
def predict(self, nhwc_images):
18-
return nhwc_images
28+
def predict(self, image):
29+
input_image, image_scale = self.preprocess_input(image)
30+
heatmap_result, offsets_result, displacement_fwd_result, displacement_bwd_result = self.sess.run(
31+
self.output_tensors,
32+
feed_dict={self.input_tensor_name: input_image}
33+
)
34+
return heatmap_result, offsets_result, displacement_fwd_result, displacement_bwd_result, image_scale

posenet/mobilenet.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from posenet.base_model import BaseModel
2+
import numpy as np
3+
import cv2
24

35

46
class MobileNet(BaseModel):
57

6-
def __init__(self, output_stride):
7-
super().__init__(output_stride)
8+
def __init__(self, sess, input_tensor_name, output_tensor_names, output_stride):
9+
super().__init__(sess, input_tensor_name, output_tensor_names, output_stride)
810

9-
def preprocess_input(self):
10-
return self
11+
def preprocess_input(self, image):
12+
target_width, target_height = self.valid_resolution(image.shape[1], image.shape[0])
13+
# per-axis (height, width) ratios of the original size to the resized size, used to map results back to the original image:
14+
scale = np.array([image.shape[0] / target_height, image.shape[1] / target_width])
15+
input_img = cv2.resize(image, (target_width, target_height), interpolation=cv2.INTER_LINEAR)
16+
input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB).astype(np.float32) # to RGB colors
1117

12-
def name_output_results(self, graph):
13-
return graph
18+
input_img = input_img * (2.0 / 255.0) - 1.0 # normalize to [-1,1]
19+
input_img = input_img.reshape(1, target_height, target_width, 3) # NHWC
20+
return input_img, scale

posenet/posenet.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from posenet.base_model import BaseModel
2+
3+
4+
class PoseNet:
    """Thin wrapper that runs a pose-estimation model on an image.

    Both estimation methods are currently placeholders: they run the model,
    check that it yields the expected five results, and return this object
    instead of decoded poses.
    """

    def __init__(self, model: BaseModel):
        """Store the model whose ``predict`` method is used for inference."""
        self.model = model

    def estimate_multiple_poses(self, image):
        """Run the model on ``image``; pose decoding is not implemented yet.

        Returns:
            This ``PoseNet`` instance (placeholder return value).
        """
        # Unpacking into five names keeps the implicit check that predict()
        # returns exactly five values; the values themselves are unused so far.
        (heatmap,
         offsets,
         displacement_fwd,
         displacement_bwd,
         scale) = self.model.predict(image)

        return self

    def estimate_single_pose(self, image):
        """Run the model on ``image``; pose decoding is not implemented yet.

        Returns:
            This ``PoseNet`` instance (placeholder return value).
        """
        (heatmap,
         offsets,
         displacement_fwd,
         displacement_bwd,
         scale) = self.model.predict(image)

        # TODO: presumably the decoded result should look like
        # poses = [{'nose': {'x': 0.0, 'y': 0.0, 'score': 0}}]

        return self

posenet/posenet_factory.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import tensorflow as tf
2+
import os
3+
import posenet.converter.tfjsdownload as tfjsdownload
4+
import posenet.converter.tfjs2tf as tfjs2tf
5+
from posenet.resnet import ResNet
6+
from posenet.mobilenet import MobileNet
7+
from posenet.posenet import PoseNet
8+
9+
10+
def load_model(model, neuralnet, model_variant):
    """Load a converted PoseNet TensorFlow model and wrap it in a ``PoseNet``.

    The model configuration comes from ``tfjsdownload.model_config``. When the
    configured ``tf_dir`` does not exist yet, the model is first converted
    from tfjs with ``tfjs2tf.convert``.

    Args:
        model: Model identifier, passed through to the converter helpers.
        neuralnet: Network name; ``'resnet50_v1'`` selects ``ResNet``, any
            other value selects ``MobileNet``.
        model_variant: Variant identifier, passed through to the converter
            helpers.

    Returns:
        A ``PoseNet`` whose underlying model holds an *open*
        ``tf.compat.v1.Session``. The session is intentionally not closed
        here, because the returned model needs it for every ``predict`` call;
        the caller owns it and may release it with ``sess.close()`` on the
        wrapped model's ``sess`` attribute.

    Raises:
        AssertionError: If the model directory is still missing after the
            conversion step.
    """
    model_cfg = tfjsdownload.model_config(model, neuralnet, model_variant)
    model_path = model_cfg['tf_dir']
    if not os.path.exists(model_path):
        print('Cannot find tf model path %s, converting from tfjs...' % model_path)
        tfjs2tf.convert(model, neuralnet, model_variant)
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert os.path.exists(model_path), \
            'tf model path %s is still missing after conversion' % model_path

    # Do NOT use ``with tf.compat.v1.Session() as sess``: leaving that block
    # closes the session, which would hand the caller a model bound to a dead
    # session. Create it explicitly and only close it if loading fails.
    sess = tf.compat.v1.Session()
    try:
        # Make the session's graph and the session itself the defaults while
        # the saved model is loaded and the network wrapper adds its ops.
        # Unlike the Session context manager, neither of these closes the
        # session on exit.
        with sess.graph.as_default(), sess.as_default():
            tf.compat.v1.saved_model.loader.load(sess, ["serve"], model_path)

            output_tensor_names = model_cfg['output_tensors']
            input_tensor_name = model_cfg['input_tensors']['image']
            output_stride = model_cfg['output_stride']

            net_class = ResNet if neuralnet == 'resnet50_v1' else MobileNet
            net = net_class(sess, input_tensor_name, output_tensor_names, output_stride)
    except BaseException:
        # Nothing will ever use the session if construction failed.
        sess.close()
        raise

    return PoseNet(net)

posenet/resnet.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
from posenet.base_model import BaseModel
2+
import numpy as np
3+
import cv2
24

35

46
class ResNet(BaseModel):
57

6-
def __init__(self, output_stride):
7-
super().__init__(output_stride)
8+
def __init__(self, sess, input_tensor_name, output_tensor_names, output_stride):
9+
super().__init__(sess, input_tensor_name, output_tensor_names, output_stride)
10+
self.image_net_mean = [-123.15, -115.90, -103.06]
811

9-
def preprocess_input(self):
10-
return self
12+
def preprocess_input(self, image):
13+
target_width, target_height = self.valid_resolution(image.shape[1], image.shape[0])
14+
# per-axis (height, width) ratios of the original size to the resized size, used to map results back to the original image:
15+
scale = np.array([image.shape[0] / target_height, image.shape[1] / target_width])
16+
input_img = cv2.resize(image, (target_width, target_height), interpolation=cv2.INTER_LINEAR)
17+
input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB).astype(np.float32) # to RGB colors
1118

12-
def name_output_results(self, graph):
13-
return graph
19+
input_img = input_img + self.image_net_mean
20+
input_img = input_img.reshape(1, target_height, target_width, 3) # NHWC
21+
return input_img, scale

0 commit comments

Comments
 (0)