Skip to content

Commit

Permalink
Test struct decode
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Jul 15, 2024
1 parent c2eefe2 commit 5278ea1
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 64 deletions.
22 changes: 11 additions & 11 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ class SecureProtocolOpcode(enum.Enum):
# : UNSIGNED INTEGER [ range 16-bits ],
# }
class SessionParameterStruct(tlv.TLVStructure):
session_idle_interval = tlv.IntegerMember(1, "<I", optional=True)
session_active_interval = tlv.IntegerMember(2, "<I", optional=True)
session_active_threshold = tlv.IntegerMember(3, "<H", optional=True)
data_model_revision = tlv.IntegerMember(4, "<H")
interaction_model_revision = tlv.IntegerMember(5, "<H")
specification_version = tlv.IntegerMember(6, "<I")
max_paths_per_invoke = tlv.IntegerMember(7, "<H")
session_idle_interval = tlv.NumberMember(1, "<I", optional=True)
session_active_interval = tlv.NumberMember(2, "<I", optional=True)
session_active_threshold = tlv.NumberMember(3, "<H", optional=True)
data_model_revision = tlv.NumberMember(4, "<H")
interaction_model_revision = tlv.NumberMember(5, "<H")
specification_version = tlv.NumberMember(6, "<I")
max_paths_per_invoke = tlv.NumberMember(7, "<H")


# pbkdfparamreq-struct => STRUCTURE [ tag-order ]
Expand All @@ -139,8 +139,8 @@ class SessionParameterStruct(tlv.TLVStructure):
# }
class PBKDFParamRequest(tlv.TLVStructure):
initiatorRandom = tlv.OctetStringMember(1, 32)
initiatorSessionId = tlv.IntegerMember(2, "<H")
passcodeId = tlv.IntegerMember(3, "<H")
initiatorSessionId = tlv.NumberMember(2, "<H")
passcodeId = tlv.NumberMember(3, "<H")
hasPBKDFParameters = tlv.BoolMember(4)
initiatorSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)

Expand All @@ -151,7 +151,7 @@ class PBKDFParamRequest(tlv.TLVStructure):
# salt [2] : OCTET STRING [ length 16..32 ],
# }
class Crypto_PBKDFParameterSet(tlv.TLVStructure):
iterations = tlv.IntegerMember(1, "<I")
iterations = tlv.NumberMember(1, "<I")
salt = tlv.OctetStringMember(2, 32)


Expand All @@ -170,6 +170,6 @@ class Crypto_PBKDFParameterSet(tlv.TLVStructure):
class PBKDFParamResponse(tlv.TLVStructure):
initiatorRandom = tlv.OctetStringMember(1, 32)
responderRandom = tlv.OctetStringMember(2, 32)
responderSessionId = tlv.IntegerMember(3, "<H")
responderSessionId = tlv.NumberMember(3, "<H")
pbkdf_parameters = tlv.StructMember(4, Crypto_PBKDFParameterSet)
responderSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)
94 changes: 69 additions & 25 deletions circuitmatter/tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@


class ElementType(enum.IntEnum):
SIGNED_INT = 0b00000
UNSIGNED_INT = 0b00100
BOOL = 0b01000
NULL = 0b10100
STRUCTURE = 0b10101
ARRAY = 0b10110
Expand All @@ -17,6 +20,8 @@ class ElementType(enum.IntEnum):


class TLVStructure:
_max_length = None

def __init__(self, buffer=None):
self.buffer: memoryview = buffer
# These three dicts are keyed by tag.
Expand All @@ -26,6 +31,17 @@ def __init__(self, buffer=None):
self.cached_values = {}
self._offset = 0 # Stopped at the next control octet

@classmethod
def max_length(cls):
if cls._max_length is None:
cls._max_length = 0
for field in vars(cls):
descriptor_class = vars(cls)[field]
if field.startswith("_") or not isinstance(descriptor_class, Member):
continue
cls._max_length += descriptor_class.max_length
return cls._max_length

def __str__(self):
members = []
for field in vars(type(self)):
Expand All @@ -38,6 +54,15 @@ def __str__(self):
members.append(f"{field} = {value}")
return "{\n " + ",\n ".join(members) + "\n}"

def __bytes__(self):
buffer = bytearray(self.max_length())
offset = 0
for field in vars(type(self)):
descriptor_class = vars(type(self))[field]
if field.startswith("_") or not isinstance(descriptor_class, Member):
continue
offset += descriptor_class.encode_into(self, buffer, offset)

def scan_until(self, tag):
if self.buffer is None:
return
Expand Down Expand Up @@ -115,9 +140,8 @@ def scan_until(self, tag):
self.null_tags.add(this_tag)
else: # Container
value_offset = length_offset
value_length = 1
value_length = 0
nesting = 0
print("in container")
while (
self.buffer[value_offset + value_length]
!= ElementType.END_OF_CONTAINER
Expand All @@ -126,19 +150,13 @@ def scan_until(self, tag):
octet = self.buffer[value_offset + value_length]
if octet == ElementType.END_OF_CONTAINER:
nesting -= 1
print(nesting)
elif (octet & 0x1F) in (
ElementType.STRUCTURE,
ElementType.ARRAY,
ElementType.LIST,
):
nesting += 1
print(nesting)
value_length += 1
print(
f"new length {value_length} {self.buffer[value_offset + value_length]:02x}"
)
print(f"container length {value_length}")

self.tag_value_offset[this_tag] = value_offset
self.tag_value_length[this_tag] = value_length
Expand All @@ -158,6 +176,16 @@ class Member:
def __init__(self, tag, optional=False):
self.tag = tag
self.optional = optional
self.tag_length = 0
if isinstance(tag, int):
self.tag_length = 1
elif isinstance(tag, tuple):
self.tag_length = 8
self._max_length = None

@property
def max_length(self):
return self.tag_length + self.max_value_length

def __get__(
self,
Expand All @@ -182,17 +210,32 @@ def __get__(
def __set__(self, obj: TLVStructure, value: Any) -> None:
obj.cached_values[self.tag] = value

def encode_into(self, obj: TLVStructure, buffer: bytearray, offset: int) -> int:
value = self.__get__(obj)
element_type = ElementType.NULL
if value is not None:
element_type = self.encode_element_type(value)
buffer[offset] = 0x00 | element_type
offset += 1
if self.tag:
buffer[offset] = self.tag
offset += 1
if value is not None:
return self.encode_value_into(value, buffer, offset)
return offset

def print(self, obj):
value = self.__get__(obj)
if value is None:
return "null"
return self._print(value)


class IntegerMember(Member):
class NumberMember(Member):
def __init__(self, tag, _format, optional=False):
self.format = _format
self.integer = _format[-1] in INT_SIZE
self.integer = _format[-1].upper() in INT_SIZE
self.max_value_length = struct.calcsize(self.format)
super().__init__(tag, optional)

def decode(self, buffer, length, offset=0):
Expand All @@ -201,27 +244,20 @@ def decode(self, buffer, length, offset=0):
if self.format.islower():
encoded_format = encoded_format.lower()
else:
encoded_format = self.format
if length == 4:
encoded_format = "<f"
else:
encoded_format = "<d"
return struct.unpack_from(encoded_format, buffer, offset=offset)[0]

def _print(self, value):
unsigned = "U" if self.format.isupper() else ""
return f"{value}{unsigned}"


class FloatMember(Member):
def decode(self, buffer, length, offset=0):
if length == 4:
encoded_format = "<f"
else:
encoded_format = "<d"
return struct.unpack_from(encoded_format, buffer, offset=offset)[0]

def _print(self, value):
return f"{value}"


class BoolMember(Member):
max_value_length = 0

def decode(self, buffer, length, offset=0) -> bool:
octet = buffer[offset]
return octet & 1 == 1
Expand All @@ -231,10 +267,17 @@ def _print(self, value):
return "true"
return "false"

@property
def element_type(self, value):
return ElementType.BOOL | (1 if value else 0)

def encode_value_into(self, value, buffer, offset) -> int:
return offset


class OctetStringMember(Member):
def __init__(self, tag, max_length, optional=False):
self.max_length = max_length
self.max_value_length = max_length
super().__init__(tag, optional)

def decode(self, buffer, length, offset=0):
Expand All @@ -246,7 +289,7 @@ def _print(self, value):

class UTF8StringMember(Member):
def __init__(self, tag, max_length, optional=False):
self.max_length = max_length
self.max_value_length = max_length
super().__init__(tag, optional)

def decode(self, buffer, length, offset=0):
Expand All @@ -259,6 +302,7 @@ def _print(self, value):
class StructMember(Member):
def __init__(self, tag, substruct_class, optional=False):
self.substruct_class = substruct_class
self.max_value_length = substruct_class.max_length()
super().__init__(tag, optional)

def decode(self, buffer, length, offset=0) -> TLVStructure:
Expand Down
Loading

0 comments on commit 5278ea1

Please sign in to comment.