Skip to content

Commit

Permalink
Start subscribe support and report chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Oct 10, 2024
1 parent edb2996 commit aa514ac
Show file tree
Hide file tree
Showing 17 changed files with 653 additions and 261 deletions.
198 changes: 89 additions & 109 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import time

from . import case
from .clusters import core
from . import data_model
from . import interaction_model
from .message import Message
from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode
from . import session
from .device_types.utility.root_node import RootNode

__version__ = "0.0.0"

Expand Down Expand Up @@ -56,38 +56,17 @@ def __init__(

self._endpoints = {}
self._next_endpoint = 0
self._descriptor = data_model.DescriptorCluster()
self._descriptor.PartsList = []
self._descriptor.ServerList = []
self.add_cluster(0, self._descriptor)
basic_info = data_model.BasicInformationCluster()
basic_info.vendor_id = vendor_id
basic_info.product_id = product_id
basic_info.product_name = "CircuitMatter"
self.add_cluster(0, basic_info)
access_control = data_model.AccessControlCluster()
self.add_cluster(0, access_control)
group_keys = core.GroupKeyManagementCluster()
self.add_cluster(0, group_keys)
network_info = data_model.NetworkCommissioningCluster()

ethernet = data_model.NetworkCommissioningCluster.NetworkInfoStruct()
ethernet.NetworkID = "enp13s0".encode("utf-8")
ethernet.Connected = True
network_info.networks = [ethernet]
network_info.connect_max_time_seconds = 10
self.add_cluster(0, network_info)
general_commissioning = core.GeneralCommissioningCluster()
self.add_cluster(0, general_commissioning)
noc = core.NodeOperationalCredentialsCluster(
group_keys, random_source, self.mdns_server, self.UDP_PORT
self.root_node = RootNode(
random_source, self.mdns_server, self.UDP_PORT, vendor_id, product_id
)
self.add_cluster(0, noc)
self.add_device(self.root_node)

self.vendor_id = vendor_id
self.product_id = product_id

self.manager = session.SessionManager(self.random, self.socket, noc)
self.manager = session.SessionManager(
self.random, self.socket, self.root_node.noc
)

print(f"Listening on UDP port {self.UDP_PORT}")

Expand Down Expand Up @@ -127,16 +106,28 @@ def add_cluster(self, endpoint, cluster):
if endpoint not in self._endpoints:
self._endpoints[endpoint] = {}
if endpoint > 0:
self._descriptor.PartsList.append(endpoint)
self.root_node.descriptor.PartsList.append(endpoint)
self._next_endpoint = max(self._next_endpoint, endpoint + 1)
if endpoint == 0:
self._descriptor.ServerList.append(cluster.CLUSTER_ID)
self._endpoints[endpoint][cluster.CLUSTER_ID] = cluster

def add_device(self, device):
self._endpoints[self._next_endpoint] = {}
if self._next_endpoint > 0:
self._descriptor.PartsList.append(self._next_endpoint)
self.root_node.descriptor.PartsList.append(self._next_endpoint)

device.descriptor = data_model.DescriptorCluster()
device_type = data_model.DescriptorCluster.DeviceTypeStruct()
device_type.DeviceType = device.DEVICE_TYPE_ID
device_type.Revision = device.REVISION
device.descriptor.DeviceTypeList = [device_type]
device.descriptor.PartsList = [self._next_endpoint]
device.descriptor.ServerList = []
device.descriptor.ClientList = []

for server in device.servers:
device.descriptor.ServerList.append(server.CLUSTER_ID)
self.add_cluster(self._next_endpoint, server)
self.add_cluster(self._next_endpoint, device.descriptor)
self._next_endpoint += 1

def process_packets(self):
Expand Down Expand Up @@ -203,6 +194,32 @@ def invoke(self, session, cluster, path, fields, command_ref):

return response

def read_attribute_path(self, path):
attribute_reports = []
if path.Endpoint is None:
endpoints = self._endpoints
else:
endpoints = [path.Endpoint]

# Wildcard so we get it from every endpoint.
for endpoint in endpoints:
if path.Cluster is None:
clusters = self._endpoints[endpoint].values()
else:
if path.Cluster not in self._endpoints[endpoint]:
print(
f"Cluster 0x{path.Cluster:02x} not found on endpoint {endpoint}"
)
continue
clusters = [self._endpoints[endpoint][path.Cluster]]
for cluster in clusters:
# TODO: The path object probably needs to be cloned. Otherwise we'll
# change the endpoint for all uses.
path.Endpoint = endpoint
path.Cluster = cluster.CLUSTER_ID
attribute_reports.extend(self.get_report(cluster, path))
return attribute_reports

def process_packet(self, address, data):
# Print the received data and the address of the sender
# This is section 4.7.2
Expand Down Expand Up @@ -268,11 +285,8 @@ def process_packet(self, address, data):

encoded = response.encode()
exchange.commissioning_hash.update(encoded)
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.PBKDF_PARAM_RESPONSE,
response,
)
print(response)
exchange.send(response)

elif protocol_opcode == SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:
print("Received PBKDF Parameter Response")
Expand All @@ -293,9 +307,7 @@ def process_packet(self, address, data):
)
exchange.cA = cA
exchange.Ke = Ke
exchange.send(
ProtocolId.SECURE_CHANNEL, SecureProtocolOpcode.PASE_PAKE2, pake2
)
exchange.send(pake2)
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE2:
print("Received PASE PAKE2")
raise NotImplementedError("Implement SPAKE2+ prover")
Expand All @@ -316,11 +328,7 @@ def process_packet(self, address, data):
error_status.protocol_code = (
session.SecureChannelProtocolCode.INVALID_PARAMETER
)
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.STATUS_REPORT,
error_status,
)
exchange.send(error_status)
else:
exchange.session.session_timestamp = time.monotonic()
status_ok = session.StatusReport()
Expand All @@ -329,11 +337,7 @@ def process_packet(self, address, data):
status_ok.protocol_code = (
session.SecureChannelProtocolCode.SESSION_ESTABLISHMENT_SUCCESS
)
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.STATUS_REPORT,
status_ok,
)
exchange.send(status_ok)

# Fully initialize the secure session context we'll use going
# forwards.
Expand All @@ -349,16 +353,7 @@ def process_packet(self, address, data):
)
response = self.manager.reply_to_sigma1(exchange, sigma1)

opcode = SecureProtocolOpcode.STATUS_REPORT
if isinstance(response, case.Sigma2Resume):
opcode = SecureProtocolOpcode.CASE_SIGMA2_RESUME
elif isinstance(response, case.Sigma2):
opcode = SecureProtocolOpcode.CASE_SIGMA2
exchange.send(
ProtocolId.SECURE_CHANNEL,
opcode,
response,
)
exchange.send(response)
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2:
print("Received CASE Sigma2")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3:
Expand All @@ -378,17 +373,14 @@ def process_packet(self, address, data):
error_status.general_code = general_code
error_status.protocol_id = ProtocolId.SECURE_CHANNEL
error_status.protocol_code = protocol_code
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.STATUS_REPORT,
error_status,
)
exchange.send(error_status)
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2_RESUME:
print("Received CASE Sigma2 Resume")
elif protocol_opcode == SecureProtocolOpcode.STATUS_REPORT:
print("Received Status Report")
report = session.StatusReport()
report.decode(message.application_payload)
print(report)

# Acknowledge the message because we have no further reply.
if message.exchange_flags & session.ExchangeFlags.R:
Expand All @@ -410,37 +402,25 @@ def process_packet(self, address, data):
)
attribute_reports = []
for path in read_request.AttributeRequests:
if path.Endpoint is None:
# Wildcard so we get it from every endpoint.
for endpoint in self._endpoints:
if path.Cluster in self._endpoints[endpoint]:
cluster = self._endpoints[endpoint][path.Cluster]
# TODO: The path object probably needs to be cloned. Otherwise we'll
# change the endpoint for all uses.
path.Endpoint = endpoint
print(path.Endpoint)
print(path)
attribute_reports.extend(self.get_report(cluster, path))
else:
print(
f"Cluster 0x{path.Cluster:02x} not found on endpoint {endpoint}"
)
else:
if path.Cluster in self._endpoints[path.Endpoint]:
cluster = self._endpoints[path.Endpoint][path.Cluster]
attribute_reports.extend(self.get_report(cluster, path))
else:
print(f"Cluster 0x{path.Cluster:02x} not found at all")
# attribute_reports.append(
# self._build_attribute_error(path, interaction_model.StatusCode.UNSUPPORTED_CLUSTER)
# )
print("read", path)
attribute_reports.extend(self.read_attribute_path(path))
response = interaction_model.ReportDataMessage()
response.AttributeReports = attribute_reports
exchange.send(
ProtocolId.INTERACTION_MODEL,
InteractionModelOpcode.REPORT_DATA,
response,
exchange.send(response)
elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST:
print("Received Write Request")
write_request, _ = interaction_model.WriteRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
)
print(write_request)
write_responses = []
for request in write_request.WriteRequests:
path = request.Path
if path.Cluster in self._endpoints[path.Endpoint]:
cluster = self._endpoints[path.Endpoint][path.Cluster]
print(cluster)
write_responses.append(cluster.set_attribute(request))

elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
invoke_request, _ = interaction_model.InvokeRequestMessage.decode(
Expand Down Expand Up @@ -482,33 +462,33 @@ def process_packet(self, address, data):
response = interaction_model.InvokeResponseMessage()
response.SuppressResponse = False
response.InvokeResponses = invoke_responses
exchange.send(
ProtocolId.INTERACTION_MODEL,
InteractionModelOpcode.INVOKE_RESPONSE,
response,
)
print("sending invoke response", response)
exchange.send(response)
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
print("Received Invoke Response")
elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST:
print("Received Subscribe Request")
subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
)
error_status = session.StatusReport()
error_status.general_code = session.GeneralCode.UNSUPPORTED
error_status.protocol_id = ProtocolId.SECURE_CHANNEL
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.STATUS_REPORT,
error_status,
)
print(subscribe_request)
attribute_reports = []
for path in subscribe_request.AttributeRequests:
attribute_reports.extend(self.read_attribute_path(path))
response = interaction_model.ReportDataMessage()
response.AttributeReports = attribute_reports
exchange.send(response)
final_response = interaction_model.SubscribeResponseMessage()
final_response.SubscriptionId = exchange.exchange_id
final_response.MaxInterval = subscribe_request.MaxIntervalCeiling
exchange.queue(final_response)
elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE:
print("Received Status Response")
print(message)
status_response, _ = interaction_model.StatusResponseMessage.decode(
message.application_payload[0], message.application_payload[1:]
)
print(status_response)
print(
f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}"
)
else:
print(message)
print("application payload", message.application_payload.hex(" "))
Expand Down
4 changes: 2 additions & 2 deletions circuitmatter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import circuitmatter as cm

from circuitmatter.device_types.lighting import extended_color
from circuitmatter.device_types.lighting import on_off


class ReplaySocket:
Expand Down Expand Up @@ -221,7 +221,7 @@ def socket(self, *args, **kwargs):
return RecordingSocket(self.record_file, socket.socket(*args, **kwargs))


class NeoPixel(extended_color.ExtendedColorLight):
class NeoPixel(on_off.OnOffLight):
pass


Expand Down
18 changes: 14 additions & 4 deletions circuitmatter/case.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from . import crypto
from . import protocol
from . import session
from . import tlv


class Sigma1(tlv.Structure):
class CASEMessage(tlv.Structure):
PROTOCOL_ID = protocol.ProtocolId.SECURE_CHANNEL


class Sigma1(CASEMessage):
PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.CASE_SIGMA1

initiatorRandom = tlv.OctetStringMember(1, 32)
initiatorSessionId = tlv.IntMember(2, signed=False, octets=2)
destinationId = tlv.OctetStringMember(3, crypto.HASH_LEN_BYTES)
Expand Down Expand Up @@ -31,7 +38,8 @@ class Sigma2TbeData(tlv.Structure):
resumptionID = tlv.OctetStringMember(4, 16)


class Sigma2(tlv.Structure):
class Sigma2(CASEMessage):
PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.CASE_SIGMA2
responderRandom = tlv.OctetStringMember(1, 32)
responderSessionId = tlv.IntMember(2, signed=False, octets=2)
responderEphPubKey = tlv.OctetStringMember(3, crypto.PUBLIC_KEY_SIZE_BYTES)
Expand All @@ -54,11 +62,13 @@ class Sigma3TbeData(tlv.Structure):
signature = tlv.OctetStringMember(3, crypto.GROUP_SIZE_BYTES * 2)


class Sigma3(tlv.Structure):
class Sigma3(CASEMessage):
PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.CASE_SIGMA3
encrypted3 = tlv.OctetStringMember(1, Sigma3TbeData.max_length())


class Sigma2Resume(tlv.Structure):
class Sigma2Resume(CASEMessage):
PROTOCOL_OPCODE = protocol.SecureProtocolOpcode.CASE_SIGMA2_RESUME
resumptionID = tlv.OctetStringMember(1, 16)
sigma2ResumeMIC = tlv.OctetStringMember(2, 16)
responderSessionID = tlv.IntMember(3, signed=False, octets=2)
Expand Down
Loading

0 comments on commit aa514ac

Please sign in to comment.