Skip to content

Commit

Permalink
Merge pull request #3 from CharString/adding_stricter_type_annotations
Browse files Browse the repository at this point in the history
Adding stricter type annotations
  • Loading branch information
tannewt authored Jul 24, 2024
2 parents 05388f1 + b74a17d commit aa1b3fb
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 18 deletions.
24 changes: 12 additions & 12 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import binascii
import enum
import pathlib
import json
import os
import pathlib
import struct
import time

Expand Down Expand Up @@ -133,13 +133,13 @@ class SecureProtocolOpcode(enum.IntEnum):
# : UNSIGNED INTEGER [ range 16-bits ],
# }
class SessionParameterStruct(tlv.TLVStructure):
session_idle_interval = tlv.NumberMember(1, "<I", optional=True)
session_active_interval = tlv.NumberMember(2, "<I", optional=True)
session_active_threshold = tlv.NumberMember(3, "<H", optional=True)
data_model_revision = tlv.NumberMember(4, "<H")
interaction_model_revision = tlv.NumberMember(5, "<H")
specification_version = tlv.NumberMember(6, "<I")
max_paths_per_invoke = tlv.NumberMember(7, "<H")
session_idle_interval = tlv.IntMember(1, signed=False, octets=4, optional=True)
session_active_interval = tlv.IntMember(2, signed=False, octets=4, optional=True)
session_active_threshold = tlv.IntMember(3, signed=False, octets=2, optional=True)
data_model_revision = tlv.IntMember(4, signed=False, octets=2)
interaction_model_revision = tlv.IntMember(5, signed=False, octets=2)
specification_version = tlv.IntMember(6, signed=False, octets=4)
max_paths_per_invoke = tlv.IntMember(7, signed=False, octets=2)


# pbkdfparamreq-struct => STRUCTURE [ tag-order ]
Expand All @@ -156,8 +156,8 @@ class SessionParameterStruct(tlv.TLVStructure):
# }
class PBKDFParamRequest(tlv.TLVStructure):
initiatorRandom = tlv.OctetStringMember(1, 32)
initiatorSessionId = tlv.NumberMember(2, "<H")
passcodeId = tlv.NumberMember(3, "<H")
initiatorSessionId = tlv.IntMember(2, signed=False, octets=2)
passcodeId = tlv.IntMember(3, signed=False, octets=2)
hasPBKDFParameters = tlv.BoolMember(4)
initiatorSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)

Expand All @@ -168,7 +168,7 @@ class PBKDFParamRequest(tlv.TLVStructure):
# salt [2] : OCTET STRING [ length 16..32 ],
# }
class Crypto_PBKDFParameterSet(tlv.TLVStructure):
iterations = tlv.NumberMember(1, "<I")
iterations = tlv.IntMember(1, signed=False, octets=4)
salt = tlv.OctetStringMember(2, 32)


Expand All @@ -187,7 +187,7 @@ class Crypto_PBKDFParameterSet(tlv.TLVStructure):
class PBKDFParamResponse(tlv.TLVStructure):
initiatorRandom = tlv.OctetStringMember(1, 32)
responderRandom = tlv.OctetStringMember(2, 32)
responderSessionId = tlv.NumberMember(3, "<H")
responderSessionId = tlv.IntMember(3, signed=False, octets=2)
pbkdf_parameters = tlv.StructMember(4, Crypto_PBKDFParameterSet)
responderSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)

Expand Down
34 changes: 33 additions & 1 deletion circuitmatter/tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,46 @@ def __init__(
nullable: _NULLABLE = False,
**kwargs,
):
"""
:param octets: Number of octests to use for encoding.
1, 2, 4, 8 are 8, 16, 32, and 64 bits respectively
:param optional: Indicates whether the value MAY be omitted from the encoding.
Can be used for deprecation.
:param nullable: Indicates whether a TLV Null MAY be encoded in place of a value.
"""
# TODO 7.18.1 mentions other bit lengths (that are not a power of 2) than the TLV Appendix
uformat = INT_SIZE[int(math.log2(octets))]
# little-endian
# < = little-endian
self.format = f"<{uformat.lower() if signed else uformat}"
super().__init__(
tag, _format=self.format, optional=optional, nullable=nullable, **kwargs
)


class FloatMember(NumberMember[float, _OPT, _NULLABLE]):
def __init__(
self,
tag,
*,
octets: Literal[4, 8] = 4,
optional: _OPT = False,
nullable: _NULLABLE = False,
**kwargs,
):
"""
:param octets: Number of octests to use for encoding.
4, 8 are single and double precision floats respectively.
:param optional: Indicates whether the value MAY be omitted from the encoding.
Can be used for deprecation.
:param nullable: Indicates whether a TLV Null MAY be encoded in place of a value.
"""
# < = little-endian
self.format = f"<{'f' if octets == 4 else 'd'}"
super().__init__(
tag, _format=self.format, optional=optional, nullable=nullable, **kwargs
)


class BoolMember(Member[bool, _OPT, _NULLABLE]):
max_value_length = 0

Expand Down
29 changes: 24 additions & 5 deletions tests/test_tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,11 @@ def test_nullable(self):
# Double precision floating point negative infinity 0b 00 00 00 00 00 00 f0 ff
# (-∞)
class FloatSingle(tlv.TLVStructure):
f = tlv.NumberMember(None, "f")
f = tlv.FloatMember(None)


class FloatDouble(tlv.TLVStructure):
f = tlv.NumberMember(None, "d")
f = tlv.FloatMember(None, octets=8)


class TestFloatSingle:
Expand Down Expand Up @@ -398,12 +398,12 @@ def test_precision_float_negative_infinity_encode(self):
assert s.encode().tobytes() == b"\x0a\x00\x00\x80\xff"

@given(v=...)
def test_roundtrip(self, v: float):
s = FloatSingle()
def test_roundtrip_double(self, v: float):
s = FloatDouble()
s.f = v
buffer = s.encode().tobytes()

s2 = FloatSingle(buffer)
s2 = FloatDouble(buffer)

assert (
(math.isnan(s.f) and math.isnan(s2.f))
Expand All @@ -412,6 +412,25 @@ def test_roundtrip(self, v: float):
or math.isclose(s2.f, s.f, rel_tol=1e-7, abs_tol=1e-9)
)

@given(
v=st.floats(
# encoding to LE float32 raises OverflowError outside these ranges
# TODO: should we raise ValueError with a bounds check or encode -inf/inf?
min_value=(2**-126),
max_value=(2 - 2**-23) * 2**127,
),
)
def test_roundtrip_single(self, v: float):
s = FloatSingle()
s.f = v
buffer = s.encode().tobytes()

s2 = FloatSingle(buffer)

assert (math.isnan(s.f) and math.isnan(s2.f)) or math.isclose(
s2.f, s.f, rel_tol=1e-7, abs_tol=1e-9
)


class TestFloatDouble:
def test_precision_float_0_0_decode(self):
Expand Down

0 comments on commit aa1b3fb

Please sign in to comment.