Skip to content

Commit

Permalink
Use mypy for static type checking (HiroIshida#30)
Browse files Browse the repository at this point in the history
* Add mypy.ini

* Fix type annotation misstake by recent update

* Ignore type annotation miss

* Add mypy test in ci
  • Loading branch information
HiroIshida authored Nov 8, 2022
1 parent a897aff commit ddb047a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 3 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/peripheral.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Peripheral test

on:
push:
branches:
- master
pull_request:
branches:
- master

jobs:
peripheral:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v2
- name: pip install formatters and mypy
run: |
pip3 install mypy
- name: check by mypy
run: |
pip3 install -r requirements.txt
pip3 install numpy==1.23 # to enable numpy's type checking
mypy --version
mypy .
41 changes: 41 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
[mypy]
python_version = 3.8
exclude = Detic
show_error_codes = True
warn_unused_ignores = False
check_untyped_defs = True

[mypy-cv_bridge]
ignore_missing_imports = True
[mypy-rospy]
ignore_missing_imports = True
[mypy-rostest]
ignore_missing_imports = True
[mypy-rospkg]
ignore_missing_imports = True
[mypy-rosbag]
ignore_missing_imports = True
[mypy-message_filters]
ignore_missing_imports = True
[mypy-detectron2.*]
ignore_missing_imports = True
[mypy-detic.*]
ignore_missing_imports = True
[mypy-centernet.config]
ignore_missing_imports = True
[mypy-detic_ros.*]
ignore_missing_imports = True
[mypy-sensor_msgs.*]
ignore_missing_imports = True
[mypy-std_msgs.*]
ignore_missing_imports = True
[mypy-jsk_recognition_msgs.*]
ignore_missing_imports = True
[mypy-cv2]
ignore_missing_imports = True
[mypy-matplotlib.*]
ignore_missing_imports = True
[mypy-tqdm]
ignore_missing_imports = True
[mypy-moviepy.*]
ignore_missing_imports = True
2 changes: 1 addition & 1 deletion node_script/batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def deep_cast(msg):


def dump_result_as_pickle(results, image_list, output_file_name):
result_dict = {'image': [], 'seginfo': [], 'debug_image': []}
result_dict = {'image': [], 'seginfo': [], 'debug_image': []} # type: ignore
for ((seginfo, debug_image, _), image) in zip(results, image_list):
seginfo, debug_image, _ = detic_wrapper.infer(image)
result_dict['image'].append(image)
Expand Down
11 changes: 9 additions & 2 deletions node_script/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import rospkg
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from std_msgs.msg import Header
from detic_ros.msg import SegmentationInfo


import torch
import numpy as np

Expand All @@ -21,6 +21,10 @@
class DeticWrapper:
predictor: VisualizationDemo
node_config: NodeConfig
bridge: CvBridge
class_names: List[str]
header: Optional[Header]
data: Optional[np.ndarray]

class DummyArgs:
vocabulary: str
Expand Down Expand Up @@ -93,14 +97,17 @@ def get_debug_img(self, visualized_output: VisImage) -> Image:
# Call after inference_step
debug_img = self.bridge.cv2_to_imgmsg(visualized_output.get_image(),
encoding="rgb8")
assert self.header is not None
debug_img.header = self.header
return debug_img

def get_debug_segimg(self) -> Image:
# Call after inference_step
assert self.data is not None
human_friendly_scaling = 255 // self.data.max()
new_data = (self.data * human_friendly_scaling).astype(np.uint8)
debug_seg_img = self.bridge.cv2_to_imgmsg(new_data, encoding="mono8")
assert self.header is not None
debug_seg_img.header = self.header
return debug_seg_img

Expand All @@ -114,7 +121,7 @@ def get_segmentation_info(self, seg_img: Image,
header=self.header)
return seg_info

def get_label_array(self, detected_classes: List[str]) -> LabelArray:
def get_label_array(self, detected_classes: List[int]) -> LabelArray:
# Label 0 is reserved for the background
labels = [Label(id=i+1, name=self.class_names[i]) for i in detected_classes]
lab_arr = LabelArray(header=self.header,
Expand Down

0 comments on commit ddb047a

Please sign in to comment.