Skip to content

Commit

Permalink
Fixes for matter.js interop
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Oct 4, 2024
1 parent ede6818 commit 97d7402
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 129 deletions.
31 changes: 25 additions & 6 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(

# Define the UDP IP address and port
UDP_IP = "::" # Listen on all available network interfaces
self.UDP_PORT = 5540
self.UDP_PORT = 5541

# Create the UDP socket
self.socket = self.socketpool.socket(
Expand All @@ -63,6 +63,7 @@ def __init__(
basic_info = data_model.BasicInformationCluster()
basic_info.vendor_id = vendor_id
basic_info.product_id = product_id
basic_info.product_name = "CircuitMatter"
self.add_cluster(0, basic_info)
group_keys = core.GroupKeyManagementCluster()
self.add_cluster(0, group_keys)
Expand Down Expand Up @@ -150,8 +151,11 @@ def process_packets(self):
self.process_packet(addr, self.packet_buffer[:nbytes])

def get_report(self, cluster, path):
report = interaction_model.AttributeReportIB()
report.AttributeData = cluster.get_attribute_data(path)
reports = []
for data in cluster.get_attribute_data(path):
report = interaction_model.AttributeReportIB()
report.AttributeData = data
reports.append(report)
# Only add status if an error occurs
# astatus = interaction_model.AttributeStatusIB()
# astatus.Path = path
Expand All @@ -160,7 +164,7 @@ def get_report(self, cluster, path):
# status.ClusterStatus = 0
# astatus.Status = status
# report.AttributeStatus = astatus
return report
return reports

def invoke(self, session, cluster, path, fields, command_ref):
print("invoke", path)
Expand Down Expand Up @@ -412,13 +416,13 @@ def process_packet(self, address, data):
path.Endpoint = endpoint
print(path.Endpoint)
print(path)
attribute_reports.append(self.get_report(cluster, path))
attribute_reports.extend(self.get_report(cluster, path))
else:
print(f"Cluster 0x{path.Cluster:02x} not found")
else:
if path.Cluster in self._endpoints[path.Endpoint]:
cluster = self._endpoints[path.Endpoint][path.Cluster]
attribute_reports.append(self.get_report(cluster, path))
attribute_reports.extend(self.get_report(cluster, path))
else:
print(f"Cluster 0x{path.Cluster:02x} not found")
response = interaction_model.ReportDataMessage()
Expand Down Expand Up @@ -478,6 +482,21 @@ def process_packet(self, address, data):
)
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
print("Received Invoke Response")
elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST:
print("Received Subscribe Request")
subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
)
print(subscribe_request)
error_status = session.StatusReport()
error_status.general_code = session.GeneralCode.UNSUPPORTED
error_status.protocol_id = ProtocolId.SECURE_CHANNEL
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.STATUS_REPORT,
error_status,
)

else:
print(message)
print("application payload", message.application_payload.hex(" "))
Expand Down
4 changes: 2 additions & 2 deletions circuitmatter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def advertise_service(
if self.publish_address is None:
command = [
"avahi-publish-address",
f"{instance_name}.local",
"fe80::642:1aff:fe0c:9f2a",
"dalinar.local",
"fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", # "fe80::642:1aff:fe0c:9f2a",
]
print("run", command)
self.publish_address = subprocess.Popen(command)
Expand Down
83 changes: 54 additions & 29 deletions circuitmatter/data_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import random
import struct
import typing
from typing import Iterable, Union

from . import interaction_model
Expand Down Expand Up @@ -123,9 +124,9 @@ def encode(self, value) -> bytes:


class StructAttribute(Attribute):
def __init__(self, _id, struct_type):
def __init__(self, _id, struct_type, default=None):
self.struct_type = struct_type
super().__init__(_id)
super().__init__(_id, default=default)

def encode(self, value) -> memoryview:
buffer = memoryview(bytearray(value.max_length() + 2))
Expand All @@ -145,8 +146,13 @@ class UTF8StringAttribute(Attribute):
def __init__(self, _id, min_length=0, max_length=1200, default=None):
self.min_length = min_length
self.max_length = max_length
self.member = tlv.UTF8StringMember(None, max_length=max_length)
super().__init__(_id, default=default)

def encode(self, value):
print(repr(value))
return self.member.encode(value)


class BitmapAttribute(Attribute):
pass
Expand Down Expand Up @@ -175,24 +181,31 @@ def _attributes(cls) -> Iterable[tuple[str, Attribute]]:
if not field_name.startswith("_") and isinstance(descriptor, Attribute):
yield field_name, descriptor

def get_attribute_data(self, path) -> interaction_model.AttributeDataIB:
data = interaction_model.AttributeDataIB()
data.DataVersion = 0
data.Path = path
found = False
def get_attribute_data(
self, path
) -> typing.List[interaction_model.AttributeDataIB]:
replies = []
for field_name, descriptor in self._attributes():
if descriptor.id != path.Attribute:
if path.Attribute is not None and descriptor.id != path.Attribute:
continue
print("reading", field_name)
value = getattr(self, field_name)
print("encoding anything", value)
data = interaction_model.AttributeDataIB()
data.DataVersion = 0
attribute_path = interaction_model.AttributePathIB()
attribute_path.Endpoint = path.Endpoint
attribute_path.Cluster = path.Cluster
attribute_path.Attribute = descriptor.id
data.Path = attribute_path
data.Data = descriptor.encode(value)
print("get", field_name, data.Data.hex(" "))
found = True
break
if not found:
replies.append(data)
if path.Attribute is not None:
break
if not replies:
print("not found", path.Attribute)
return data
return replies

@classmethod
def _commands(cls) -> Iterable[tuple[str, Command]]:
Expand Down Expand Up @@ -297,30 +310,42 @@ class CapabilityMinima(tlv.Structure):
)

class ProductAppearance(tlv.Structure):
Finish = tlv.EnumMember(0, ProductFinish)
PrimaryColor = tlv.EnumMember(1, Color)
Finish = tlv.EnumMember(0, ProductFinish, default=ProductFinish.OTHER)
PrimaryColor = tlv.EnumMember(1, Color, default=Color.BLACK)

data_model_revision = NumberAttribute(0x00, signed=False, bits=16)
vendor_name = UTF8StringAttribute(0x01, max_length=32)
data_model_revision = NumberAttribute(0x00, signed=False, bits=16, default=16)
vendor_name = UTF8StringAttribute(0x01, max_length=32, default="CircuitMatter")
vendor_id = NumberAttribute(0x02, signed=False, bits=16)
product_name = UTF8StringAttribute(0x03, max_length=32)
product_name = UTF8StringAttribute(0x03, max_length=32, default="Test Device")
product_id = NumberAttribute(0x04, signed=False, bits=16)
node_label = UTF8StringAttribute(0x05, max_length=32, default="")
location = UTF8StringAttribute(0x06, max_length=2, default="XX")
hardware_version = NumberAttribute(0x07, signed=False, bits=16)
hardware_version_string = UTF8StringAttribute(0x08, min_length=1, max_length=64)
software_version = NumberAttribute(0x09, signed=False, bits=32)
software_version_string = UTF8StringAttribute(0x0A, min_length=1, max_length=64)
manufacturing_date = UTF8StringAttribute(0x0B, min_length=8, max_length=16)
part_number = UTF8StringAttribute(0x0C, max_length=32)
product_url = UTF8StringAttribute(0x0D, max_length=256)
product_label = UTF8StringAttribute(0x0E, max_length=64)
serial_number = UTF8StringAttribute(0x0F, max_length=32)
hardware_version = NumberAttribute(0x07, signed=False, bits=16, default=0)
hardware_version_string = UTF8StringAttribute(
0x08, min_length=1, max_length=64, default="Unknown"
)
software_version = NumberAttribute(0x09, signed=False, bits=32, default=0)
software_version_string = UTF8StringAttribute(
0x0A, min_length=1, max_length=64, default="Unknown"
)
manufacturing_date = UTF8StringAttribute(
0x0B, min_length=8, max_length=16, default="Unknown"
)
part_number = UTF8StringAttribute(0x0C, max_length=32, default="")
product_url = UTF8StringAttribute(
0x0D, max_length=256, default="https://github.com/adafruit/circuitmatter"
)
product_label = UTF8StringAttribute(0x0E, max_length=64, default="")
serial_number = UTF8StringAttribute(0x0F, max_length=32, default="")
local_config_disabled = BoolAttribute(0x10, default=False)
reachable = BoolAttribute(0x11, default=True)
unique_id = UTF8StringAttribute(0x12, max_length=32)
capability_minima = StructAttribute(0x13, CapabilityMinima)
product_appearance = StructAttribute(0x14, ProductAppearance)
unique_id = UTF8StringAttribute(0x12, max_length=32, default="")
capability_minima = StructAttribute(
0x13, CapabilityMinima, default=CapabilityMinima()
)
product_appearance = StructAttribute(
0x14, ProductAppearance, default=ProductAppearance()
)
specification_version = NumberAttribute(0x15, signed=False, bits=32, default=0)
max_paths_per_invoke = NumberAttribute(0x16, signed=False, bits=16, default=1)

Expand Down
3 changes: 2 additions & 1 deletion circuitmatter/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def receive(self, message) -> bool:
# Drop messages that are missing an acknowledgement counter.
return True
if (
message.acknowledged_message_counter
self.pending_retransmission is not None
and message.acknowledged_message_counter
!= self.pending_retransmission.message_counter
):
# Drop messages that have the wrong acknowledgement counter.
Expand Down
13 changes: 12 additions & 1 deletion circuitmatter/interaction_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class AttributePathIB(tlv.List):
WildcardPathFlags = tlv.IntMember(6, signed=False, octets=4, optional=True)


class EventPathIB(tlv.Structure):
class EventPathIB(tlv.List):
"""Section 10.6.8"""

Node = tlv.IntMember(0, signed=False, octets=8)
Expand Down Expand Up @@ -202,3 +202,14 @@ class InvokeResponseMessage(InteractionModelMessage):
SuppressResponse = tlv.BoolMember(0)
InvokeResponses = tlv.ArrayMember(1, InvokeResponseIB)
MoreChunkedMessages = tlv.BoolMember(2, optional=True)


class SubscribeRequestMessage(InteractionModelMessage):
KeepSubscriptions = tlv.BoolMember(0)
MinIntervalFloor = tlv.IntMember(1, signed=False, octets=2)
MaxIntervalCeiling = tlv.IntMember(2, signed=False, octets=2)
AttributeRequests = tlv.ArrayMember(3, AttributePathIB, optional=True)
EventRequests = tlv.ArrayMember(4, EventPathIB, optional=True)
EventFilters = tlv.ArrayMember(5, EventFilterIB, optional=True)
FabricFiltered = tlv.BoolMember(7)
DataVersionFilters = tlv.ArrayMember(8, DataVersionFilterIB, optional=True)
13 changes: 13 additions & 0 deletions circuitmatter/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class SecurityFlags(enum.IntFlag):
class Message:
def __init__(self):
self.clear()
self.buffer = None

def clear(self):
self.flags: int = 0
Expand Down Expand Up @@ -94,6 +95,15 @@ def decode(self, buffer):
else:
self.source_node_id = 0

dsiz = self.flags & 0b11
if dsiz == 1:
self.destination_node_id = struct.unpack_from("<Q", buffer, offset)[0]
offset += 8
elif dsiz == 2:
self.destination_node_id = struct.unpack_from("<H", buffer, offset)[0]
offset += 2
self.destination_node_id |= 0xFFFF_FFFF_FFFF_0000

if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
self.secure_session = not (
Expand All @@ -106,6 +116,9 @@ def decode(self, buffer):
self.duplicate = None

def encode_into(self, buffer, cipher=None):
if self.buffer is not None:
buffer[: len(self.buffer)] = self.buffer
return len(self.buffer)
offset = 0
struct.pack_into(
"<BHBI",
Expand Down
3 changes: 0 additions & 3 deletions circuitmatter/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,11 +522,8 @@ def reply_to_sigma1(self, exchange, sigma1):
)
print(candidate_destination_id.hex(), sigma1.destinationId.hex())
if sigma1.destinationId == candidate_destination_id:
print("matched!")
matching_noc = i
break
else:
print("didn't match")

if matching_noc is None:
error_status = StatusReport()
Expand Down
2 changes: 1 addition & 1 deletion circuitmatter/tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def construct_containers(self):
tags.remove(tag)
self.values[tag] = member_class.from_value(self.values[tag])
if tags:
raise RuntimeError(f"Unknown tags {tags}")
raise RuntimeError(f"Unknown tags {tags} in {type(self)}")

@classmethod
def from_value(cls, value):
Expand Down
2 changes: 1 addition & 1 deletion test_data/device_state.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"discriminator": 1983,
"discriminator": 3840,
"passcode": 67202583,
"iteration-count": 10000,
"salt": "5uCP0ITHYzI9qBEe6hfU4HfY3y7VopSk0qNvhvznhiQ=",
Expand Down
Loading

0 comments on commit 97d7402

Please sign in to comment.