Skip to content

Commit

Permalink
Tests pass (more to come)
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Jul 15, 2024
1 parent c65dfaa commit c2eefe2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ repos:
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
args: [ "--fix", "--output-format=github" ]
# Run the formatter.
- id: ruff-format
38 changes: 22 additions & 16 deletions circuitmatter/tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


class ElementType(enum.IntEnum):
NULL = 0b10100
STRUCTURE = 0b10101
ARRAY = 0b10110
LIST = 0b10111
Expand All @@ -20,6 +21,7 @@ def __init__(self, buffer=None):
self.buffer: memoryview = buffer
# These three dicts are keyed by tag.
self.tag_value_offset = {}
self.null_tags = set()
self.tag_value_length = {}
self.cached_values = {}
self._offset = 0 # Stopped at the next control octet
Expand All @@ -46,7 +48,7 @@ def scan_until(self, tag):
tag_control = control_octet >> 5
element_type = control_octet & 0x1F
print(
f"Control 0x{control_octet:x} tag_control {tag_control} element_type {element_type}"
f"Control 0x{control_octet:x} tag_control {tag_control} element_type {element_type:x}"
)

this_tag = None
Expand Down Expand Up @@ -82,6 +84,7 @@ def scan_until(self, tag):

length_offset = self._offset + 1 + TAG_LENGTH[tag_control]
element_category = element_type >> 2
print(f"element_category {element_category}")
if element_category == 0 or element_category == 1: # ints
value_offset = length_offset
value_length = 1 << (element_type & 0x3)
Expand All @@ -97,15 +100,19 @@ def scan_until(self, tag):
elif (
element_category == 3 or element_category == 4
): # UTF-8 String or Octet String
print(f"element_type {element_type:x}", bin(element_type))
power_of_two = element_type & 0x3
print(f"power_of_two {power_of_two}")
length_length = 1 << power_of_two
print(f"length_length {length_length}")
value_offset = length_offset + length_length
value_length = struct.unpack_from(
INT_SIZE[power_of_two], self.buffer, length_offset
)[0]
elif element_type == 0b10100: # Null
value_offset = self._offset
value_length = 1
self.null_tags.add(this_tag)
else: # Container
value_offset = length_offset
value_length = 1
Expand Down Expand Up @@ -161,7 +168,7 @@ def __get__(
return obj.cached_values[self.tag]
if self.tag not in obj.tag_value_offset:
obj.scan_until(self.tag)
if self.tag not in obj.tag_value_offset:
if self.tag not in obj.tag_value_offset or self.tag in obj.null_tags:
return None

value = self.decode(
Expand All @@ -175,6 +182,12 @@ def __get__(
def __set__(self, obj: TLVStructure, value: Any) -> None:
obj.cached_values[self.tag] = value

def print(self, obj):
value = self.__get__(obj)
if value is None:
return "null"
return self._print(value)


class IntegerMember(Member):
def __init__(self, tag, _format, optional=False):
Expand All @@ -191,8 +204,7 @@ def decode(self, buffer, length, offset=0):
encoded_format = self.format
return struct.unpack_from(encoded_format, buffer, offset=offset)[0]

def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
unsigned = "U" if self.format.isupper() else ""
return f"{value}{unsigned}"

Expand All @@ -205,8 +217,7 @@ def decode(self, buffer, length, offset=0):
encoded_format = "<d"
return struct.unpack_from(encoded_format, buffer, offset=offset)[0]

def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
return f"{value}"


Expand All @@ -215,8 +226,8 @@ def decode(self, buffer, length, offset=0) -> bool:
octet = buffer[offset]
return octet & 1 == 1

def print(self, obj):
if self.__get__(obj):
def _print(self, value):
if value:
return "true"
return "false"

Expand All @@ -229,8 +240,7 @@ def __init__(self, tag, max_length, optional=False):
def decode(self, buffer, length, offset=0):
return buffer[offset : offset + length]

def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
return " ".join((f"{byte:02x}" for byte in value))


Expand All @@ -242,8 +252,7 @@ def __init__(self, tag, max_length, optional=False):
def decode(self, buffer, length, offset=0):
return buffer[offset : offset + length].decode("utf-8")

def print(self, obj):
value = self.__get__(obj)
def _print(self, value):
return f'"{value}"'


Expand All @@ -255,8 +264,5 @@ def __init__(self, tag, substruct_class, optional=False):
def decode(self, buffer, length, offset=0) -> TLVStructure:
return self.substruct_class(buffer[offset : offset + length])

def print(self, obj):
value = self.__get__(obj)
if value is None:
return "null"
def _print(self, value):
return str(value)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ dynamic = ["version", "description"]

[project.urls]
Home = "https://github.com/adafruit/circuitmatter"

[tool.pytest.ini_options]
pythonpath = [
"."
]
Empty file added tests/__init__.py
Empty file.
16 changes: 5 additions & 11 deletions tests/test_tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,17 @@ def test_utf8_string_tschs_decode(self):
# assert bytes(s) == b"\x0c\x06Hello!"


# Octet String, 1-octet length, octets 00 01 02 03 04 10 05 00 01 02 03 04
# Octet String, 1-octet length, octets 00 01 02 03 04
# encoded: 10 05 00 01 02 03 04
class OctetStringOneOctet(tlv.TLVStructure):
s = tlv.OctetStringMember(None, 16)


class TestOctetString:
def test_octet_string_decode(self):
s = OctetStringOneOctet(
b"\x0d\x0c\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
)
assert str(s) == "{\n s = 00 01 02 03 04 10 05 00 01 02 03 04\n}"
assert s.s == b"\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"

# def test_octet_string_encode(self):
# s = OctetString()
# s.s = b"\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
# assert bytes(s) == b"\x0d\x0c\x00\x01\x02\x03\x04\x10\x05\x00\x01\x02\x03\x04"
s = OctetStringOneOctet(b"\x10\x05\x00\x01\x02\x03\x04")
assert str(s) == "{\n s = 00 01 02 03 04\n}"
assert s.s == b"\x00\x01\x02\x03\x04"


# Null
Expand Down

0 comments on commit c2eefe2

Please sign in to comment.