Skip to content

Commit

Permalink
Mypy fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Carlstrom <[email protected]>
  • Loading branch information
InvincibleRMC committed Mar 29, 2024
1 parent 3e30aa5 commit 1b049b8
Showing 1 changed file with 35 additions and 17 deletions.
52 changes: 35 additions & 17 deletions mqtt_ros_bridge/encodings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Protocol, Type, TypeVar, TypeAlias, cast, Iterable
import json
from array import array
from typing import (Any, Iterable, MutableSequence, Protocol, Type, TypeAlias,
TypeVar, cast)

from numpy import ndarray, array
from numpy import ndarray
from numpy.typing import NDArray

import json
from rclpy.type_support import check_is_valid_msg_type


NestedDictionary: TypeAlias = dict[str, 'NestedDictionary'] | dict[str, object]


Expand All @@ -19,18 +19,22 @@ def get_fields_and_field_types(cls) -> dict[str, str]:


MsgLikeT = TypeVar("MsgLikeT", bound=MsgLike)
ArrayElementT = TypeVar('ArrayElementT', int, float, str)


RESERVED_FIELD_TYPE = '_msgs/'
RESERVED_FIELD_TYPE = '/'
ENCODING = 'latin-1'


def numpy_encoding(array: NDArray) -> list:
return [int(x) for x in array]
def numpy_encoding(array_arg: NDArray[Any]) -> list[int]:
return [int(x) for x in array_arg]


def numpy_decoding(ls: list) -> NDArray:
return array(ls)
def array_encoding(array_arg: MutableSequence[ArrayElementT]) -> list[ArrayElementT]:
if len(array_arg) == 0:
return []
element_type = type(array_arg[0])
return [element_type(x) for x in array_arg]


def human_readable_encoding(msg: MsgLike) -> bytes:
Expand All @@ -55,8 +59,12 @@ def human_readable_encoding_recursive(msg: MsgLike) -> NestedDictionary:
value = [byte.decode(ENCODING) for byte in value]
elif RESERVED_FIELD_TYPE in field_types:
value = [human_readable_encoding_recursive(msg_in_list) for msg_in_list in value]
elif isinstance(value, list) and len(value) == 0:
value = []
elif isinstance(value, ndarray):
value = numpy_encoding(value)
elif isinstance(value, array):
value = array_encoding(value)
elif RESERVED_FIELD_TYPE in field_types:
value = human_readable_encoding_recursive(value)
msg_dict[field] = value
Expand All @@ -77,19 +85,29 @@ def human_readable_decoding_recursive(msg_dict: NestedDictionary,
msg = msg_type()
set_value: object
for field, value in msg_dict.items():
if isinstance(getattr(msg, field), bytes):
field_default = getattr(msg, field)
if isinstance(field_default, bytes):
if isinstance(value, str):
set_value = value.encode(ENCODING)
elif isinstance(value, dict):
set_value = human_readable_decoding_recursive(value, type(getattr(msg, field)))
elif isinstance(field_default, list):
if len(field_default) == 0:
set_value = []
else:
field_default_element = field_default[0]
if isinstance(field_default_element, bytes):
value = cast(list[str], value)
set_value = [byte.encode(ENCODING) for byte in value]
elif RESERVED_FIELD_TYPE in msg_type.get_fields_and_field_types()[field]:
value = cast(Iterable[NestedDictionary], value)
set_value = [human_readable_decoding_recursive(msg_in_list,
type(getattr(msg, field)[0]))
for msg_in_list in value]
else:
set_value = value
else:
set_value = value

setattr(msg, field, set_value)
return msg

from test_msgs.msg import Arrays

# print(Arrays())

print(human_readable_decoding(human_readable_encoding(Arrays()), Arrays))

0 comments on commit 1b049b8

Please sign in to comment.