Skip to content

Commit

Permalink
Stricter types
Browse files Browse the repository at this point in the history
Makes things like this give a type error:

```
class NullTest(tlv.TLVStructure):
    n = tlv.BoolMember(None, nullable=True)
    notn = tlv.BoolMember(None, nullable=False)

s = NullTest()
s.n = None  # checks ok
s.notn = None  # doesn't type check
```
  • Loading branch information
CharString committed Jul 20, 2024
1 parent 6078dbe commit d6917a9
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 57 deletions.
172 changes: 120 additions & 52 deletions circuitmatter/tlv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import enum
import math
import struct
from typing import Any, Optional, Type, Union
from typing import Literal
from abc import ABC, abstractmethod
from typing import AnyStr, Generic, Iterable, Literal, Optional, Type, TypeVar, overload

# As a byte string to save space.
TAG_LENGTH = b"\x00\x01\x02\x04\x02\x04\x06\x08"
Expand All @@ -26,8 +28,8 @@ class ElementType(enum.IntEnum):
class TLVStructure:
_max_length = None

def __init__(self, buffer=None):
self.buffer: memoryview = buffer
def __init__(self, buffer: Optional[bytes] = None):
self.buffer = buffer
# These three dicts are keyed by tag.
self.tag_value_offset = {}
self.null_tags = set()
Expand All @@ -39,19 +41,13 @@ def __init__(self, buffer=None):
def max_length(cls):
if cls._max_length is None:
cls._max_length = 0
for field in vars(cls):
descriptor_class = vars(cls)[field]
if field.startswith("_") or not isinstance(descriptor_class, Member):
continue
for _field, descriptor_class in cls._members(cls):
cls._max_length += descriptor_class.max_length
return cls._max_length

def __str__(self):
members = []
for field in vars(type(self)):
descriptor_class = vars(type(self))[field]
if field.startswith("_") or not isinstance(descriptor_class, Member):
continue
for field, descriptor_class in self._members(type(self)):
value = descriptor_class.print(self)
if isinstance(descriptor_class, StructMember):
value = value.replace("\n", "\n ")
Expand All @@ -63,14 +59,17 @@ def encode(self) -> memoryview:
end = self.encode_into(buffer)
return memoryview(buffer)[:end]

def encode_into(self, buffer, offset=0):
for field in vars(type(self)):
descriptor_class = vars(type(self))[field]
if field.startswith("_") or not isinstance(descriptor_class, Member):
continue
def encode_into(self, buffer, offset: int = 0):
for _, descriptor_class in self._members(type(self)):
offset = descriptor_class.encode_into(self, buffer, offset)
return offset

@staticmethod
def _members(class_: Type) -> Iterable[tuple[str, Member]]:
for field_name, descriptor in vars(class_).items():
if not field_name.startswith("_") and isinstance(descriptor, Member):
yield field_name, descriptor

def scan_until(self, tag):
if self.buffer is None:
return
Expand Down Expand Up @@ -169,8 +168,14 @@ def scan_until(self, tag):
break


class Member:
def __init__(self, tag, optional=False, nullable=False):
_T = TypeVar("_T")
_NULLABLE = TypeVar("_NULLABLE", Literal[True], Literal[False])


class Member(ABC, Generic[_T, _NULLABLE]):
max_value_length: int = 0

def __init__(self, tag, *, optional: bool = False, nullable: _NULLABLE = False):
self.tag = tag
self.optional = optional
self.nullable = nullable
Expand All @@ -185,11 +190,21 @@ def __init__(self, tag, optional=False, nullable=False):
def max_length(self):
return 1 + self.tag_length + self.max_value_length

@overload
def __get__(
self,
obj: Optional[TLVStructure],
self: Member[_T, Literal[True]],
obj: TLVStructure,
objtype: Optional[Type[TLVStructure]] = None,
) -> Any:
) -> Optional[_T]: ...

@overload
def __get__(
self: Member[_T, Literal[False]],
obj: TLVStructure,
objtype: Optional[Type[TLVStructure]] = None,
) -> _T: ...

def __get__(self, obj, objtype=None):
if self.tag in obj.cached_values:
return obj.cached_values[self.tag]
if self.tag not in obj.tag_value_offset:
Expand All @@ -205,13 +220,21 @@ def __get__(
obj.cached_values[self.tag] = value
return value

def __set__(self, obj: TLVStructure, value: Any) -> None:
@overload
def __set__(
self: Member[_T, Literal[True]], obj: TLVStructure, value: Optional[_T]
) -> None: ...
@overload
def __set__(
self: Member[_T, Literal[False]], obj: TLVStructure, value: _T
) -> None: ...
def __set__(self, obj, value):
if value is None and not self.nullable:
raise ValueError("Not nullable")
obj.cached_values[self.tag] = value

def encode_into(self, obj: TLVStructure, buffer: bytearray, offset: int) -> int:
value = self.__get__(obj)
value = self.__get__(obj) # type: ignore # self inference issues
element_type = ElementType.NULL
if value is not None:
element_type = self.encode_element_type(value)
Expand All @@ -227,19 +250,57 @@ def encode_into(self, obj: TLVStructure, buffer: bytearray, offset: int) -> int:
buffer[offset] = self.tag
offset += 1
if value is not None:
new_offset = self.encode_value_into(value, buffer, offset)
new_offset = self.encode_value_into( # type: ignore # self inference issues
value,
buffer,
offset,
)
return new_offset
return offset

def print(self, obj):
value = self.__get__(obj)
value = self.__get__(obj) # type: ignore # self inference issues
if value is None:
return "null"
return self._print(value)


class NumberMember(Member):
def __init__(self, tag, _format, optional=False):
@abstractmethod
def decode(self, buffer, length: int, offset: int = 0) -> _T:
"Return the decoded value at `offset` in `buffer`"
...

@abstractmethod
def encode_element_type(self, value: _T) -> int:
"Return Element Type Field as defined in Appendix A in the spec"
...

@overload
@abstractmethod
def encode_value_into(
self: Member[_T, Literal[True]], value: Optional[_T], buffer, offset: int
) -> int: ...
@overload
@abstractmethod
def encode_value_into(
self: Member[_T, Literal[False]], value: _T, buffer, offset: int
) -> int: ...
@abstractmethod
def encode_value_into(self, value, buffer, offset: int) -> int:
"Encode `value` into `buffer` and return the new offset"
...

@abstractmethod
def _print(self, value: _T) -> str:
"Return string representation of `value`"
...


# number type
_NT = TypeVar("_NT", float, int)


class NumberMember(Member[_NT, _NULLABLE], Generic[_NT, _NULLABLE]):
def __init__(self, tag, _format: str, nullable: _NULLABLE = False, **kwargs):
self.format = _format
self.integer = _format[-1].upper() in INT_SIZE
self.signed = self.format.islower()
Expand All @@ -253,22 +314,22 @@ def __init__(self, tag, _format, optional=False):
self._element_type = ElementType.FLOAT
if self.max_value_length == 8:
self._element_type |= 1
super().__init__(tag, optional)
super().__init__(tag, nullable=nullable, **kwargs)

def __set__(self, obj, value):
if self.integer:
if value is not None and self.integer:
octets = 2 ** INT_SIZE.index(self.format.upper()[-1])
bits = 8 * octets
max_size = (2 ** (bits - 1) if self.signed else 2**bits) - 1
min_size = -max_size - 1 if self.signed else 0
max_size: int = (2 ** (bits - 1) if self.signed else 2**bits) - 1
min_size: int = -max_size - 1 if self.signed else 0
if not min_size <= value <= max_size:
raise ValueError(
f"Out of bounds for {octets} octet {'' if self.signed else 'un'}signed int"
)

super().__set__(obj, value)
super().__set__(obj, value) # type: ignore # self inference issues

def decode(self, buffer, length, offset=0):
def decode(self, buffer, length, offset=0) -> _NT:
if self.integer:
encoded_format = INT_SIZE[int(math.log(length, 2))]
if self.format.islower():
Expand All @@ -294,23 +355,26 @@ def encode_value_into(self, value, buffer, offset) -> int:
return offset + self.max_value_length


IntOctetCount = Union[Literal[1], Literal[2], Literal[4], Literal[8]]


class IntMember(NumberMember):
class IntMember(NumberMember[int, _NULLABLE]):
def __init__(
self, tag, /, signed: bool = True, octets: IntOctetCount = 1, optional=False
self,
tag,
*,
signed: bool = True,
octets: Literal[1, 2, 4, 8] = 1,
nullable: _NULLABLE = False,
**kwargs,
):
uformat = INT_SIZE[int(math.log2(octets))]
# little-endian
self.format = f"<{uformat.lower() if signed else uformat}"
super().__init__(tag, _format=self.format, optional=optional)
super().__init__(tag, _format=self.format, nullable=nullable, **kwargs)


class BoolMember(Member):
class BoolMember(Member[bool, _NULLABLE]):
max_value_length = 0

def decode(self, buffer, length, offset=0) -> bool:
def decode(self, buffer, length, offset=0):
octet = buffer[offset]
return octet & 1 == 1

Expand All @@ -326,18 +390,18 @@ def encode_value_into(self, value, buffer, offset) -> int:
return offset


class OctetStringMember(Member):
_base_element_type = ElementType.OCTET_STRING
class StringMember(Member[AnyStr, _NULLABLE], Generic[AnyStr, _NULLABLE]):
_base_element_type: ElementType

def __init__(self, tag, max_length, optional=False):
def __init__(
self, tag, max_length, *, optional=False, nullable: _NULLABLE = False, **kwargs
):
self.max_value_length = max_length
length_encoding = 0
while max_length > (256 ** (length_encoding + 1)):
length_encoding += 1
length_encoding = int(math.log(max_length, 256))
self._element_type = self._base_element_type | length_encoding
self.length_format = INT_SIZE[length_encoding]
self.length_length = struct.calcsize(self.length_format)
super().__init__(tag, optional)
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)

def decode(self, buffer, length, offset=0):
return buffer[offset : offset + length]
Expand All @@ -355,7 +419,11 @@ def encode_value_into(self, value, buffer, offset) -> int:
return offset + len(value)


class UTF8StringMember(OctetStringMember):
class OctetStringMember(StringMember[bytes, _NULLABLE]):
_base_element_type: ElementType = ElementType.OCTET_STRING


class UTF8StringMember(StringMember[str, _NULLABLE]):
_base_element_type = ElementType.UTF8_STRING

def decode(self, buffer, length, offset=0):
Expand All @@ -372,7 +440,7 @@ class StructMember(Member):
def __init__(self, tag, substruct_class, optional=False):
self.substruct_class = substruct_class
self.max_value_length = substruct_class.max_length() + 1
super().__init__(tag, optional)
super().__init__(tag, optional=optional)

def decode(self, buffer, length, offset=0) -> TLVStructure:
return self.substruct_class(buffer[offset : offset + length])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ test = [
"hypothesis",
"pytest",
"pytest-cov",
# "typing_extensions",
"typing_extensions",
]

[tool.coverage.run]
Expand Down
39 changes: 35 additions & 4 deletions tests/test_tlv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from circuitmatter import tlv
from hypothesis import given, strategies as st
import math
from typing import Optional

import pytest
from hypothesis import given
from hypothesis import strategies as st
from typing_extensions import assert_type

import math
from circuitmatter import tlv

# Test TLV encoding using examples from spec

Expand Down Expand Up @@ -194,6 +198,21 @@ def test_roundtrip(self, v: int):
assert s2.i == s.i
assert str(s2) == str(s)

def test_nullability(self):
class Struct(tlv.TLVStructure):
i = tlv.IntMember(None)
ni = tlv.IntMember(None, nullable=True)

s = Struct()
assert_type(s.i, int)
assert_type(s.ni, Optional[int])

s.ni = None
assert s.ni is None

with pytest.raises(ValueError):
s.i = None


# UTF-8 String, 1-octet length, "Hello!"
# 0c 06 48 65 6c 6c 6f 21
Expand Down Expand Up @@ -270,7 +289,12 @@ def test_roundtrip(self, v: bytes):


class Null(tlv.TLVStructure):
n = tlv.BoolMember(None, nullable=True)
n = tlv.BoolMember(None, nullable=True, optional=False)


class NotNull(tlv.TLVStructure):
n = tlv.BoolMember(None, nullable=True, optional=False)
b = tlv.BoolMember(None)


class TestNull:
Expand All @@ -284,6 +308,13 @@ def test_null_encode(self):
s.n = None
assert s.encode().tobytes() == b"\x14"

def test_nullable(self):
s = NotNull()

assert_type(s.b, bool)
with pytest.raises(ValueError):
s.b = None # type: ignore # testing runtime behaviour


# Single precision floating point 0.0
# 0a 00 00 00 00
Expand Down

0 comments on commit d6917a9

Please sign in to comment.