diff --git a/circuitmatter/__init__.py b/circuitmatter/__init__.py index 7d2b8bf..20db8a1 100644 --- a/circuitmatter/__init__.py +++ b/circuitmatter/__init__.py @@ -43,7 +43,7 @@ def __init__( # Define the UDP IP address and port UDP_IP = "::" # Listen on all available network interfaces - self.UDP_PORT = 5540 + self.UDP_PORT = 5541 # Create the UDP socket self.socket = self.socketpool.socket( @@ -63,6 +63,7 @@ def __init__( basic_info = data_model.BasicInformationCluster() basic_info.vendor_id = vendor_id basic_info.product_id = product_id + basic_info.product_name = "CircuitMatter" self.add_cluster(0, basic_info) group_keys = core.GroupKeyManagementCluster() self.add_cluster(0, group_keys) @@ -150,8 +151,11 @@ def process_packets(self): self.process_packet(addr, self.packet_buffer[:nbytes]) def get_report(self, cluster, path): - report = interaction_model.AttributeReportIB() - report.AttributeData = cluster.get_attribute_data(path) + reports = [] + for data in cluster.get_attribute_data(path): + report = interaction_model.AttributeReportIB() + report.AttributeData = data + reports.append(report) # Only add status if an error occurs # astatus = interaction_model.AttributeStatusIB() # astatus.Path = path @@ -160,7 +164,7 @@ def get_report(self, cluster, path): # status.ClusterStatus = 0 # astatus.Status = status # report.AttributeStatus = astatus - return report + return reports def invoke(self, session, cluster, path, fields, command_ref): print("invoke", path) @@ -412,13 +416,13 @@ def process_packet(self, address, data): path.Endpoint = endpoint print(path.Endpoint) print(path) - attribute_reports.append(self.get_report(cluster, path)) + attribute_reports.extend(self.get_report(cluster, path)) else: print(f"Cluster 0x{path.Cluster:02x} not found") else: if path.Cluster in self._endpoints[path.Endpoint]: cluster = self._endpoints[path.Endpoint][path.Cluster] - attribute_reports.append(self.get_report(cluster, path)) + attribute_reports.extend(self.get_report(cluster, path)) else: print(f"Cluster 0x{path.Cluster:02x} not found") response = interaction_model.ReportDataMessage() @@ -478,6 +482,21 @@ def process_packet(self, address, data): ) elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE: print("Received Invoke Response") + elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST: + print("Received Subscribe Request") + subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode( + message.application_payload[0], message.application_payload[1:] + ) + print(subscribe_request) + error_status = session.StatusReport() + error_status.general_code = session.GeneralCode.UNSUPPORTED + error_status.protocol_id = ProtocolId.SECURE_CHANNEL + exchange.send( + ProtocolId.SECURE_CHANNEL, + SecureProtocolOpcode.STATUS_REPORT, + error_status, + ) + else: print(message) print("application payload", message.application_payload.hex(" ")) diff --git a/circuitmatter/__main__.py b/circuitmatter/__main__.py index f31add0..d852064 100644 --- a/circuitmatter/__main__.py +++ b/circuitmatter/__main__.py @@ -134,8 +134,8 @@ def advertise_service( if self.publish_address is None: command = [ "avahi-publish-address", - f"{instance_name}.local", - "fe80::642:1aff:fe0c:9f2a", + "dalinar.local", + "fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", # "fe80::642:1aff:fe0c:9f2a", ] print("run", command) self.publish_address = subprocess.Popen(command) diff --git a/circuitmatter/data_model.py b/circuitmatter/data_model.py index 63e81b3..b50fa8f 100644 --- a/circuitmatter/data_model.py +++ b/circuitmatter/data_model.py @@ -1,6 +1,7 @@ import enum import random import struct +import typing from typing import Iterable, Union from . import interaction_model @@ -123,9 +124,9 @@ def encode(self, value) -> bytes: class StructAttribute(Attribute): - def __init__(self, _id, struct_type): + def __init__(self, _id, struct_type, default=None): self.struct_type = struct_type - super().__init__(_id) + super().__init__(_id, default=default) def encode(self, value) -> memoryview: buffer = memoryview(bytearray(value.max_length() + 2)) @@ -145,8 +146,13 @@ class UTF8StringAttribute(Attribute): def __init__(self, _id, min_length=0, max_length=1200, default=None): self.min_length = min_length self.max_length = max_length + self.member = tlv.UTF8StringMember(None, max_length=max_length) super().__init__(_id, default=default) + def encode(self, value): + print(repr(value)) + return self.member.encode(value) + class BitmapAttribute(Attribute): pass @@ -175,24 +181,31 @@ def _attributes(cls) -> Iterable[tuple[str, Attribute]]: if not field_name.startswith("_") and isinstance(descriptor, Attribute): yield field_name, descriptor - def get_attribute_data(self, path) -> interaction_model.AttributeDataIB: - data = interaction_model.AttributeDataIB() - data.DataVersion = 0 - data.Path = path - found = False + def get_attribute_data( + self, path + ) -> typing.List[interaction_model.AttributeDataIB]: + replies = [] for field_name, descriptor in self._attributes(): - if descriptor.id != path.Attribute: + if path.Attribute is not None and descriptor.id != path.Attribute: continue print("reading", field_name) value = getattr(self, field_name) print("encoding anything", value) + data = interaction_model.AttributeDataIB() + data.DataVersion = 0 + attribute_path = interaction_model.AttributePathIB() + attribute_path.Endpoint = path.Endpoint + attribute_path.Cluster = path.Cluster + attribute_path.Attribute = descriptor.id + data.Path = attribute_path data.Data = descriptor.encode(value) print("get", field_name, data.Data.hex(" ")) - found = True - break - if not found: + replies.append(data) + if path.Attribute is not None: + break + if not replies: print("not found", path.Attribute) - return data + return replies @classmethod def _commands(cls) -> Iterable[tuple[str, Command]]: @@ -297,30 +310,42 @@ class CapabilityMinima(tlv.Structure): ) class ProductAppearance(tlv.Structure): - Finish = tlv.EnumMember(0, ProductFinish) - PrimaryColor = tlv.EnumMember(1, Color) + Finish = tlv.EnumMember(0, ProductFinish, default=ProductFinish.OTHER) + PrimaryColor = tlv.EnumMember(1, Color, default=Color.BLACK) - data_model_revision = NumberAttribute(0x00, signed=False, bits=16) - vendor_name = UTF8StringAttribute(0x01, max_length=32) + data_model_revision = NumberAttribute(0x00, signed=False, bits=16, default=16) + vendor_name = UTF8StringAttribute(0x01, max_length=32, default="CircuitMatter") vendor_id = NumberAttribute(0x02, signed=False, bits=16) - product_name = UTF8StringAttribute(0x03, max_length=32) + product_name = UTF8StringAttribute(0x03, max_length=32, default="Test Device") product_id = NumberAttribute(0x04, signed=False, bits=16) node_label = UTF8StringAttribute(0x05, max_length=32, default="") location = UTF8StringAttribute(0x06, max_length=2, default="XX") - hardware_version = NumberAttribute(0x07, signed=False, bits=16) - hardware_version_string = UTF8StringAttribute(0x08, min_length=1, max_length=64) - software_version = NumberAttribute(0x09, signed=False, bits=32) - software_version_string = UTF8StringAttribute(0x0A, min_length=1, max_length=64) - manufacturing_date = UTF8StringAttribute(0x0B, min_length=8, max_length=16) - part_number = UTF8StringAttribute(0x0C, max_length=32) - product_url = UTF8StringAttribute(0x0D, max_length=256) - product_label = UTF8StringAttribute(0x0E, max_length=64) - serial_number = UTF8StringAttribute(0x0F, max_length=32) + hardware_version = NumberAttribute(0x07, signed=False, bits=16, default=0) + hardware_version_string = UTF8StringAttribute( + 0x08, min_length=1, max_length=64, default="Unknown" + ) + software_version = NumberAttribute(0x09, signed=False, bits=32, default=0) + software_version_string = UTF8StringAttribute( + 0x0A, min_length=1, max_length=64, default="Unknown" + ) + manufacturing_date = UTF8StringAttribute( + 0x0B, min_length=8, max_length=16, default="Unknown" + ) + part_number = UTF8StringAttribute(0x0C, max_length=32, default="") + product_url = UTF8StringAttribute( + 0x0D, max_length=256, default="https://github.com/adafruit/circuitmatter" + ) + product_label = UTF8StringAttribute(0x0E, max_length=64, default="") + serial_number = UTF8StringAttribute(0x0F, max_length=32, default="") local_config_disabled = BoolAttribute(0x10, default=False) reachable = BoolAttribute(0x11, default=True) - unique_id = UTF8StringAttribute(0x12, max_length=32) - capability_minima = StructAttribute(0x13, CapabilityMinima) - product_appearance = StructAttribute(0x14, ProductAppearance) + unique_id = UTF8StringAttribute(0x12, max_length=32, default="") + capability_minima = StructAttribute( + 0x13, CapabilityMinima, default=CapabilityMinima() + ) + product_appearance = StructAttribute( + 0x14, ProductAppearance, default=ProductAppearance() + ) specification_version = NumberAttribute(0x15, signed=False, bits=32, default=0) max_paths_per_invoke = NumberAttribute(0x16, signed=False, bits=16, default=1) diff --git a/circuitmatter/exchange.py b/circuitmatter/exchange.py index 875a08e..29feed3 100644 --- a/circuitmatter/exchange.py +++ b/circuitmatter/exchange.py @@ -81,7 +81,8 @@ def receive(self, message) -> bool: # Drop messages that are missing an acknowledgement counter. return True if ( - message.acknowledged_message_counter + self.pending_retransmission is not None + and message.acknowledged_message_counter != self.pending_retransmission.message_counter ): # Drop messages that have the wrong acknowledgement counter. diff --git a/circuitmatter/interaction_model.py b/circuitmatter/interaction_model.py index fab7f7f..d158995 100644 --- a/circuitmatter/interaction_model.py +++ b/circuitmatter/interaction_model.py @@ -76,7 +76,7 @@ class AttributePathIB(tlv.List): WildcardPathFlags = tlv.IntMember(6, signed=False, octets=4, optional=True) -class EventPathIB(tlv.Structure): +class EventPathIB(tlv.List): """Section 10.6.8""" Node = tlv.IntMember(0, signed=False, octets=8) @@ -202,3 +202,14 @@ class InvokeResponseMessage(InteractionModelMessage): SuppressResponse = tlv.BoolMember(0) InvokeResponses = tlv.ArrayMember(1, InvokeResponseIB) MoreChunkedMessages = tlv.BoolMember(2, optional=True) + + +class SubscribeRequestMessage(InteractionModelMessage): + KeepSubscriptions = tlv.BoolMember(0) + MinIntervalFloor = tlv.IntMember(1, signed=False, octets=2) + MaxIntervalCeiling = tlv.IntMember(2, signed=False, octets=2) + AttributeRequests = tlv.ArrayMember(3, AttributePathIB, optional=True) + EventRequests = tlv.ArrayMember(4, EventPathIB, optional=True) + EventFilters = tlv.ArrayMember(5, EventFilterIB, optional=True) + FabricFiltered = tlv.BoolMember(7) + DataVersionFilters = tlv.ArrayMember(8, DataVersionFilterIB, optional=True) diff --git a/circuitmatter/message.py b/circuitmatter/message.py index 7cd30fe..5f690b4 100644 --- a/circuitmatter/message.py +++ b/circuitmatter/message.py @@ -26,6 +26,7 @@ class SecurityFlags(enum.IntFlag): class Message: def __init__(self): self.clear() + self.buffer = None def clear(self): self.flags: int = 0 @@ -94,6 +95,15 @@ def decode(self, buffer): else: self.source_node_id = 0 + dsiz = self.flags & 0b11 + if dsiz == 1: + self.destination_node_id = struct.unpack_from("> 4) != 0: raise RuntimeError("Incorrect version") self.secure_session = not ( @@ -106,6 +116,9 @@ def decode(self, buffer): self.duplicate = None def encode_into(self, buffer, cipher=None): + if self.buffer is not None: + buffer[: len(self.buffer)] = self.buffer + return len(self.buffer) offset = 0 struct.pack_into( "