Skip to content

Commit

Permalink
refactor: wrapper should not have internal state (HiroIshida#37)
Browse files Browse the repository at this point in the history
* refactor: wrapper should not have internal state

* Fix batch_processor

* Remove version specification for local install

* Some more refactor
  • Loading branch information
HiroIshida authored Nov 12, 2022
1 parent 10d4904 commit 6738058
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 95 deletions.
31 changes: 22 additions & 9 deletions node_script/batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import pickle
from typing import List, Optional

import numpy as np
import rosbag
import tqdm
from cv_bridge import CvBridge
from moviepy.editor import ImageSequenceClip
from node_config import NodeConfig
from sensor_msgs.msg import Image
from wrapper import DeticWrapper
from wrapper import DeticWrapper, InferenceRawResult


def bag_to_images(file_path: str, topic_name_extract: Optional[str] = None):
Expand Down Expand Up @@ -48,10 +49,16 @@ def deep_cast(msg):
return image_list


def dump_result_as_pickle(results, image_list, output_file_name):
def dump_result_as_pickle(
results: List[InferenceRawResult],
images: List[Image],
output_file_name: str):

result_dict = {'image': [], 'seginfo': [], 'debug_image': []} # type: ignore
for ((seginfo, debug_image, _), image) in zip(results, image_list):
seginfo, debug_image, _ = detic_wrapper.infer(image)
for result, image in zip(results, images):
seginfo = result.get_segmentation_info()
debug_image = result.get_ros_debug_image()

result_dict['image'].append(image)
result_dict['seginfo'].append(seginfo)
result_dict['debug_image'].append(debug_image)
Expand All @@ -60,15 +67,21 @@ def dump_result_as_pickle(results, image_list, output_file_name):
pickle.dump(result_dict, f)


def dump_result_as_rosbag(input_bagfile_name, results, output_file_name):
def dump_result_as_rosbag(
input_bagfile_name: str,
results: List[InferenceRawResult],
output_file_name: str):

bag_out = rosbag.Bag(output_file_name, 'w')

bag_inp = rosbag.Bag(input_bagfile_name)
for topic_name, msg, t in bag_inp.read_messages():
bag_out.write(topic_name, msg, t)
bag_inp.close()

for seginfo, debug_image, _ in results:
for result in results:
seginfo = result.get_segmentation_info()
debug_image = result.get_ros_debug_image()
bag_out.write('/detic_segmentor/segmentation_info', seginfo, seginfo.header.stamp)
bag_out.write('/detic_segmentor/debug_image', debug_image, debug_image.header.stamp)
bag_out.close()
Expand Down Expand Up @@ -126,10 +139,10 @@ def dump_result_as_rosbag(input_bagfile_name, results, output_file_name):
# dump debug gif image
bridge = CvBridge()

def convert(msg):
bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
def convert(msg) -> np.ndarray:
return bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')

debug_images = [result[1] for result in results]
debug_images = [result.get_ros_debug_image() for result in results]
images = list(map(convert, debug_images))
clip = ImageSequenceClip(images, fps=20)
clip.write_gif(debug_file_name, fps=20)
32 changes: 20 additions & 12 deletions node_script/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
class DeticRosNode:
detic_wrapper: DeticWrapper
sub: Subscriber
pub_debug_image: Publisher
pub_debug_segmentation_image: Publisher
pub_debug_image: Optional[Publisher]
pub_debug_segmentation_image: Optional[Publisher]
pub_info: Publisher
pub_segimg: Publisher
pub_labels: Publisher
Expand Down Expand Up @@ -55,23 +55,28 @@ def __init__(self, node_config: Optional[NodeConfig] = None):

def callback_image(self, msg: Image):
# Inference
seg_img, labels, scores, vis_img = self.detic_wrapper.inference_step(msg)
raw_result = self.detic_wrapper.infer(msg)

# Publish main topics
if self.detic_wrapper.node_config.use_jsk_msgs:
seg_img = raw_result.get_ros_segmentaion_image()
labels = raw_result.get_label_array()
scores = raw_result.get_score_array()
self.pub_segimg.publish(seg_img)
self.pub_labels.publish(self.detic_wrapper.get_label_array(labels))
self.pub_score.publish(self.detic_wrapper.get_score_array(scores))
self.pub_labels.publish(labels)
self.pub_score.publish(scores)
else:
seg_info = self.detic_wrapper.get_segmentation_info(seg_img, labels, scores)
seg_info = raw_result.get_segmentation_info()
self.pub_info.publish(seg_info)

# Publish optional topics
if self.pub_debug_image is not None and vis_img is not None:
debug_img = self.detic_wrapper.get_debug_img(vis_img)

if self.pub_debug_image is not None:
debug_img = raw_result.get_ros_debug_image()
self.pub_debug_image.publish(debug_img)

if self.pub_debug_segmentation_image is not None:
debug_seg_img = self.detic_wrapper.get_debug_segimg()
debug_seg_img = raw_result.get_ros_debug_segmentation_img()
self.pub_debug_segmentation_image.publish(debug_seg_img)

# Print debug info
Expand All @@ -81,12 +86,15 @@ def callback_image(self, msg: Image):

def callback_srv(self, req: DeticSegRequest) -> DeticSegResponse:
msg = req.image
seginfo, debug_img, _ = self.detic_wrapper.infer(msg)
raw_result = self.detic_wrapper.infer(msg)
seginfo = raw_result.get_segmentation_info()

resp = DeticSegResponse()
resp.seg_info = seginfo
if debug_img is not None:
resp.debug_image = debug_img

if raw_result.visualization is not None:
debug_image = raw_result.get_ros_debug_segmentation_img()
resp.debug_image = debug_image
return resp


Expand Down
137 changes: 65 additions & 72 deletions node_script/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import List, Optional

import detic
import numpy as np
Expand All @@ -17,13 +18,62 @@
from detic_ros.msg import SegmentationInfo


_cv_bridge = CvBridge()


@dataclass(frozen=True)
class InferenceRawResult:
segmentation_raw_image: np.ndarray
class_indices: List[int]
scores: List[float]
visualization: Optional[VisImage]
header: Header
class_names: List[str]

def get_ros_segmentaion_image(self) -> Image:
seg_img = _cv_bridge.cv2_to_imgmsg(self.segmentation_raw_image, encoding="32SC1")
seg_img.header = self.header
return seg_img

def get_ros_debug_image(self) -> Image:
message = "you didn't configure the wrapper so that it computes the debug images"
assert self.visualization is not None, message
debug_img = _cv_bridge.cv2_to_imgmsg(
self.visualization.get_image(), encoding="rgb8")
debug_img.header = self.header
return debug_img

def get_ros_debug_segmentation_img(self) -> Image:
human_friendly_scaling = 255 // self.segmentation_raw_image.max()
new_data = (self.segmentation_raw_image * human_friendly_scaling).astype(np.uint8)
debug_seg_img = _cv_bridge.cv2_to_imgmsg(new_data, encoding="mono8")
assert self.header is not None
debug_seg_img.header = self.header
return debug_seg_img

def get_label_array(self) -> LabelArray:
labels = [Label(id=i + 1, name=self.class_names[i]) for i in self.class_indices]
lab_arr = LabelArray(header=self.header, labels=labels)
return lab_arr

def get_score_array(self) -> VectorArray:
vec_arr = VectorArray(header=self.header, vector_dim=len(self.scores), data=self.scores)
return vec_arr

def get_segmentation_info(self) -> SegmentationInfo:
seg_img = self.get_ros_segmentaion_image()
detected_classes_names = [self.class_names[i] for i in self.class_indices]
seg_info = SegmentationInfo(detected_classes=detected_classes_names,
scores=self.scores,
segmentation=seg_img,
header=self.header)
return seg_info


class DeticWrapper:
predictor: VisualizationDemo
node_config: NodeConfig
bridge: CvBridge
class_names: List[str]
header: Optional[Header]
data: Optional[np.ndarray]

class DummyArgs:
vocabulary: str
Expand All @@ -40,24 +90,20 @@ def __init__(self, node_config: NodeConfig):

self.predictor = VisualizationDemo(detectron_cfg, dummy_args)
self.node_config = node_config
self.bridge = CvBridge()
self.class_names = self.predictor.metadata.get("thing_classes", None)
self.header = None
self.data = None

@staticmethod
def _adhoc_hack_metadata_path():
# because original BUILDIN_CLASSIFIER is somehow posi-dep
# because original BUILDIN_CLASSIFIER is somehow position dependent
rospack = rospkg.RosPack()
pack_path = rospack.get_path('detic_ros')
path_dict = detic.predictor.BUILDIN_CLASSIFIER
for key in path_dict.keys():
path_dict[key] = os.path.join(pack_path, path_dict[key])

def inference_step(self, msg: Image) -> Tuple[Image, List[int], List[float], Optional[VisImage]]:
def infer(self, msg: Image) -> InferenceRawResult:
# Segmentation image, detected classes, detection scores, visualization image
img = self.bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
self.header = msg.header
img = _cv_bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')

if self.node_config.verbose:
time_start = rospy.Time.now()
Expand All @@ -83,68 +129,15 @@ def inference_step(self, msg: Image) -> Tuple[Image, List[int], List[float], Opt
mask = instances.pred_masks[i]
# label 0 is reserved for background label, so starting from 1
data[mask] = (i + 1)
self.data = data
seg_img = self.bridge.cv2_to_imgmsg(data, encoding="32SC1")
seg_img.header = self.header

# Get class and score arrays
class_indexes = instances.pred_classes.tolist()
scores = instances.scores.tolist()
return seg_img, class_indexes, scores, visualized_output

def get_debug_img(self, visualized_output: VisImage) -> Image:
# Call after inference_step
debug_img = self.bridge.cv2_to_imgmsg(visualized_output.get_image(),
encoding="rgb8")
assert self.header is not None
debug_img.header = self.header
return debug_img

def get_debug_segimg(self) -> Image:
# Call after inference_step
assert self.data is not None
human_friendly_scaling = 255 // self.data.max()
new_data = (self.data * human_friendly_scaling).astype(np.uint8)
debug_seg_img = self.bridge.cv2_to_imgmsg(new_data, encoding="mono8")
assert self.header is not None
debug_seg_img.header = self.header
return debug_seg_img

def get_segmentation_info(self, seg_img: Image,
detected_classes: List[int],
scores: List[float]) -> SegmentationInfo:
detected_classes_names = [self.class_names[i] for i in detected_classes]
seg_info = SegmentationInfo(detected_classes=detected_classes_names,
scores=scores,
segmentation=seg_img,
header=self.header)
return seg_info

def get_label_array(self, detected_classes: List[int]) -> LabelArray:
# Label 0 is reserved for the background
labels = [Label(id=i + 1, name=self.class_names[i]) for i in detected_classes]
lab_arr = LabelArray(header=self.header,
labels=labels)
return lab_arr

def get_score_array(self, scores: List[float]) -> VectorArray:
vec_arr = VectorArray(header=self.header,
vector_dim=len(scores),
data=scores)
return vec_arr

def infer(self, msg: Image) -> Tuple[SegmentationInfo, Optional[Image], Optional[Image]]:
seg_img, labels, scores, vis_img = self.inference_step(msg)
seg_info = self.get_segmentation_info(seg_img, labels, scores)

if self.node_config.out_debug_img:
debug_img = self.get_debug_img(vis_img)
else:
debug_img = None

if self.node_config.out_debug_img:
debug_seg_img = self.get_debug_segimg()
else:
debug_seg_img = None

return seg_info, debug_img, debug_seg_img
result = InferenceRawResult(
data,
class_indexes,
scores,
visualized_output,
msg.header,
self.class_names)
return result
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ detectron2
# Copied from Detic
opencv-python==4.5.5.62
timm==0.5.4
dataclasses==0.6
dataclasses # remove version specification for local install
ftfy==6.0.3
regex==2022.1.18
fasttext==0.9.2
scikit-learn==1.0.2
numpy==1.19
numpy # remove version specification for local install
lvis==0.5.3
nltk==3.6.7
git+https://github.com/openai/CLIP.git

0 comments on commit 6738058

Please sign in to comment.