Skip to content

Commit

Permalink
Check format by flake and isort (HiroIshida#31)
Browse files Browse the repository at this point in the history
* Fix format according to flake and isort

* Check format in ci
  • Loading branch information
HiroIshida authored Nov 8, 2022
1 parent ddb047a commit ffc3f08
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 88 deletions.
8 changes: 8 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[flake8]
ignore =
I, H # follow isort
E501 # max length should be checked by black
E203 # https://github.com/psf/black/issues/315
W503 # https://github.com/psf/black/issues/52
E402 # because we can't avoid sys.path.append stuff
A, B, C, CNL, D, Q # external plugins if installed
7 changes: 6 additions & 1 deletion .github/workflows/peripheral.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ jobs:
uses: actions/checkout@v2
- name: pip install formatters and mypy
run: |
pip3 install mypy
pip3 install mypy flake8 isort
- name: check by mypy
run: |
pip3 install -r requirements.txt
pip3 install numpy==1.23 # to enable numpy's type checking
mypy --version
mypy .
- name: check by isrot and flake8
run: |
python3 -m isort example/ test/ node_script/
python3 -m flake8 example/ test/ node_script/
2 changes: 2 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[settings]
profile = black
6 changes: 3 additions & 3 deletions example/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#!/usr/bin/env python3
import argparse
import cv_bridge

import cv2
import matplotlib.pyplot as plt
import rospy
from cv_bridge import CvBridge
from sensor_msgs.msg import Image

from detic_ros.srv import DeticSeg
import matplotlib.pyplot as plt

if __name__ == '__main__':
parser = argparse.ArgumentParser()
Expand Down
12 changes: 6 additions & 6 deletions example/masked_image_publisher.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#!/usr/bin/env python3
import copy

import message_filters
import numpy as np
import rospy
import message_filters
from cv_bridge import CvBridge
from sensor_msgs.msg import Image

from detic_ros.msg import SegmentationInfo
from cv_bridge import CvBridge


class SampleNode:
Expand Down Expand Up @@ -46,7 +48,7 @@ def callback(self, msg_image, msg_info: SegmentationInfo):
img = bridge.imgmsg_to_cv2(msg_image, desired_encoding='passthrough')

# Add 1 to label_index to account for the background
mask_indexes = np.where(img==label_index+1)
mask_indexes = np.where(img == label_index + 1)

masked_img = copy.deepcopy(img)
masked_img[mask_indexes] = 0 # filled by black
Expand All @@ -55,7 +57,7 @@ def callback(self, msg_image, msg_info: SegmentationInfo):
self.pub.publish(msg_out)


if __name__=='__main__':
if __name__ == '__main__':
import sys
argv = [x for x in sys.argv if not x.startswith('_')] # remove roslaunch args
if len(argv) == 1:
Expand All @@ -70,5 +72,3 @@ def callback(self, msg_image, msg_info: SegmentationInfo):
SampleNode(mask_class_name=class_name)
rospy.spin()
pass


15 changes: 7 additions & 8 deletions node_script/batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@
import argparse
import os
import pickle
from numpy import inf
from typing import List, Optional

import rosbag
import rospy
import tqdm
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from detic_ros.msg import SegmentationInfo
from moviepy.editor import ImageSequenceClip
import tqdm

from node_config import NodeConfig
from sensor_msgs.msg import Image
from wrapper import DeticWrapper


Expand Down Expand Up @@ -78,7 +74,7 @@ def dump_result_as_rosbag(input_bagfile_name, results, output_file_name):
bag_out.close()


if __name__=='__main__':
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('input', type=str, help='input file path')
parser.add_argument('-model', type=str, default='swin', help='model type')
Expand Down Expand Up @@ -129,7 +125,10 @@ def dump_result_as_rosbag(input_bagfile_name, results, output_file_name):

# dump debug gif image
bridge = CvBridge()
convert = lambda msg: bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')

def convert(msg):
bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')

debug_images = [result[1] for result in results]
images = list(map(convert, debug_images))
clip = ImageSequenceClip(images, fps=20)
Expand Down
14 changes: 7 additions & 7 deletions node_script/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from typing import Optional

import rospy
from rospy import Subscriber, Publisher
from sensor_msgs.msg import Image
from detic_ros.msg import SegmentationInfo
from detic_ros.srv import DeticSeg, DeticSegRequest, DeticSegResponse
from jsk_recognition_msgs.msg import LabelArray, VectorArray

from node_config import NodeConfig
from rospy import Publisher, Subscriber
from sensor_msgs.msg import Image
from wrapper import DeticWrapper

from detic_ros.msg import SegmentationInfo
from detic_ros.srv import DeticSeg, DeticSegRequest, DeticSegResponse


class DeticRosNode:
detic_wrapper: DeticWrapper
Expand All @@ -22,7 +22,7 @@ class DeticRosNode:
pub_labels: Publisher
pub_score: Publisher

def __init__(self, node_config: Optional[NodeConfig]=None):
def __init__(self, node_config: Optional[NodeConfig] = None):
if node_config is None:
node_config = NodeConfig.from_rosparam()

Expand Down Expand Up @@ -90,7 +90,7 @@ def callback_srv(self, req: DeticSegRequest) -> DeticSegResponse:
return resp


if __name__=='__main__':
if __name__ == '__main__':
rospy.init_node('detic_node', anonymous=True)
node = DeticRosNode()
rospy.spin()
59 changes: 29 additions & 30 deletions node_script/node_config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
import sys
from dataclasses import dataclass
import rospy
import rospkg

import rospkg
import rospy
import torch

# Dirty but no way, because CenterNet2 is not package oriented
sys.path.insert(0, os.path.join(sys.path[0], 'third_party/CenterNet2/'))

from detectron2.config import get_cfg
from centernet.config import add_centernet_config
from detectron2.config import get_cfg
from detic.config import add_detic_config


Expand All @@ -36,7 +36,8 @@ class NodeConfig:
}

@classmethod
def from_args(cls,
def from_args(
cls,
model_type: str = 'swin',
enable_pubsub: bool = True,
out_debug_img: bool = True,
Expand All @@ -46,8 +47,7 @@ def from_args(cls,
confidence_threshold: float = 0.5,
device_name: str = 'auto',
vocabulary: str = 'lvis',
custom_vocabulary: str = '',
):
custom_vocabulary: str = ''):

if device_name == 'auto':
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
Expand All @@ -67,33 +67,32 @@ def from_args(cls,
model_name + '.pth')

return cls(
enable_pubsub,
out_debug_img,
out_debug_segimg,
verbose,
use_jsk_msgs,
vocabulary,
custom_vocabulary,
default_detic_config_path,
default_model_weights_path,
confidence_threshold,
device_name)
enable_pubsub,
out_debug_img,
out_debug_segimg,
verbose,
use_jsk_msgs,
vocabulary,
custom_vocabulary,
default_detic_config_path,
default_model_weights_path,
confidence_threshold,
device_name)

@classmethod
def from_rosparam(cls):

return cls.from_args(
rospy.get_param('~model_type', 'swin'),
rospy.get_param('~enable_pubsub', True),
rospy.get_param('~out_debug_img', True),
rospy.get_param('~out_debug_segimg', False),
rospy.get_param('~verbose', True),
rospy.get_param('~use_jsk_msgs', False),
rospy.get_param('~confidence_threshold', 0.5),
rospy.get_param('~device', 'auto'),
rospy.get_param('~vocabulary', 'lvis'),
rospy.get_param('~custom_vocabulary', ''),
)
rospy.get_param('~model_type', 'swin'),
rospy.get_param('~enable_pubsub', True),
rospy.get_param('~out_debug_img', True),
rospy.get_param('~out_debug_segimg', False),
rospy.get_param('~verbose', True),
rospy.get_param('~use_jsk_msgs', False),
rospy.get_param('~confidence_threshold', 0.5),
rospy.get_param('~device', 'auto'),
rospy.get_param('~vocabulary', 'lvis'),
rospy.get_param('~custom_vocabulary', ''))

def to_detectron_config(self):
cfg = get_cfg()
Expand All @@ -107,13 +106,13 @@ def to_detectron_config(self):
cfg.merge_from_list(['MODEL.WEIGHTS', self.model_weights_path])

# Similar to https://github.com/facebookresearch/Detic/demo.py
cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand' # load later
cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand' # load later
cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = True

# Maybe should edit detic_configs/Base-C2_L_R5021k_640b64_4x.yaml
pack_path = rospkg.RosPack().get_path('detic_ros')
cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = os.path.join(
pack_path, 'datasets/metadata/lvis_v1_train_cat_info.json')
pack_path, 'datasets/metadata/lvis_v1_train_cat_info.json')

cfg.freeze()
return cfg
29 changes: 14 additions & 15 deletions node_script/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import os
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple

import rospy
import detic
import numpy as np
import rospkg
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from std_msgs.msg import Header
from detic_ros.msg import SegmentationInfo

import rospy
import torch
import numpy as np

import detic
from cv_bridge import CvBridge
from detectron2.utils.visualizer import VisImage
from detic.predictor import VisualizationDemo
from jsk_recognition_msgs.msg import Label, LabelArray, VectorArray
from detectron2.utils.visualizer import VisImage
from node_config import NodeConfig
from sensor_msgs.msg import Image
from std_msgs.msg import Header

from detic_ros.msg import SegmentationInfo


class DeticWrapper:
Expand All @@ -26,8 +25,9 @@ class DeticWrapper:
header: Optional[Header]
data: Optional[np.ndarray]

class DummyArgs:
class DummyArgs:
vocabulary: str

def __init__(self, vocabulary, custom_vocabulary):
assert vocabulary in ['lvis', 'openimages', 'objects365', 'coco', 'custom']
self.vocabulary = vocabulary
Expand All @@ -45,14 +45,13 @@ def __init__(self, node_config: NodeConfig):
self.header = None
self.data = None


@staticmethod
def _adhoc_hack_metadata_path():
# because original BUILDIN_CLASSIFIER is somehow posi-dep
rospack = rospkg.RosPack()
pack_path = rospack.get_path('detic_ros')
path_dict = detic.predictor.BUILDIN_CLASSIFIER
for key in path_dict.keys():
for key in path_dict.keys():
path_dict[key] = os.path.join(pack_path, path_dict[key])

def inference_step(self, msg: Image) -> Tuple[Image, List[int], List[float], Optional[VisImage]]:
Expand Down Expand Up @@ -123,7 +122,7 @@ def get_segmentation_info(self, seg_img: Image,

def get_label_array(self, detected_classes: List[int]) -> LabelArray:
# Label 0 is reserved for the background
labels = [Label(id=i+1, name=self.class_names[i]) for i in detected_classes]
labels = [Label(id=i + 1, name=self.class_names[i]) for i in detected_classes]
lab_arr = LabelArray(header=self.header,
labels=labels)
return lab_arr
Expand Down
9 changes: 5 additions & 4 deletions test/test_batch_processor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#!/usr/bin/env python3
import os
import pickle
import unittest

import pickle
import rostest, rospy
import rospkg
import rospy
import rostest


class TestNode(unittest.TestCase):
Expand All @@ -19,15 +20,15 @@ def test_pkl_dump(self):

pkl_path = os.path.join(pkg_path, 'test', 'data', 'desk_segmented.pkl')
with open(pkl_path, 'rb') as f:
obj = pickle.load(f)
pickle.load(f)

def test_rosbag_dump(self):
pkg_path = rospkg.RosPack().get_path('detic_ros')
bag_path = os.path.join(pkg_path, 'test', 'data', 'desk.bag')
ret = os.system('rosrun detic_ros batch_processor.py {} -n 1 -format bag'.format(bag_path))
assert ret == 0

outbag_path = os.path.join(pkg_path, 'test', 'data', 'desk_segmented.bag')
os.path.join(pkg_path, 'test', 'data', 'desk_segmented.bag')


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit ffc3f08

Please sign in to comment.