Skip to content

Commit

Permalink
Apple commissioning almost works
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Oct 3, 2024
1 parent 7b05d28 commit 4a0c4b6
Show file tree
Hide file tree
Showing 14 changed files with 443 additions and 107 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ You do not need to pay anything or be a member organization.

CircuitMatter is currently developed in CPython 3.12, the de facto implementation written in C. It is designed with minimal dependencies so that it can also be used on CircuitPython on microcontrollers.

After cloning the repo, pip install `ecdsa` and `cryptography`.
After cloning the repo, pip install `ecdsa`, `cryptography` and `qrcode`.

### Running a CircuitMatter replay

Expand Down Expand Up @@ -73,7 +73,9 @@ Logs can be added into the chip sources to understand what is happening on the c

### Apple Home

The Apple Home app can also discover and (attempt to) commission the device. Tap Add Accessory and the CircuitMatter device will show up as a nearby Matter Accessory. Tap it and then enter the setup code `67202583`. This will start the commissioning process from Apple Home.
The Apple Home app can also discover and (attempt to) commission the device. Tap Add Accessory.
* By default this will pull up the camera to scan a QR Code. CircuitMatter will print the qrcode to the console to scan.
* You can also use the passcode by clicking "More options" and the CircuitMatter device will show up as a nearby Matter Accessory. Tap it and then enter the setup code `67202583`. This will start the commissioning process from Apple Home.

## Generate a certificate declaration

Expand Down
41 changes: 34 additions & 7 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
with open(state_filename, "r") as state_file:
self.nonvolatile = json.load(state_file)

for key in ["descriminator", "salt", "iteration-count", "verifier"]:
for key in ["discriminator", "salt", "iteration-count", "verifier"]:
if key not in self.nonvolatile:
raise RuntimeError(f"Missing key {key} in state file")

Expand All @@ -58,6 +58,7 @@ def __init__(
self._next_endpoint = 0
self._descriptor = data_model.DescriptorCluster()
self._descriptor.PartsList = []
self._descriptor.ServerList = []
self.add_cluster(0, self._descriptor)
basic_info = data_model.BasicInformationCluster()
basic_info.vendor_id = vendor_id
Expand All @@ -66,6 +67,11 @@ def __init__(
group_keys = core.GroupKeyManagementCluster()
self.add_cluster(0, group_keys)
network_info = data_model.NetworkCommissioningCluster()

ethernet = data_model.NetworkCommissioningCluster.NetworkInfoStruct()
ethernet.NetworkID = "enp13s0".encode("utf-8")
ethernet.Connected = True
network_info.networks = [ethernet]
network_info.connect_max_time_seconds = 10
self.add_cluster(0, network_info)
general_commissioning = core.GeneralCommissioningCluster()
Expand All @@ -75,6 +81,9 @@ def __init__(
)
self.add_cluster(0, noc)

self.vendor_id = vendor_id
self.product_id = product_id

self.manager = session.SessionManager(self.random, self.socket, noc)

print(f"Listening on UDP port {self.UDP_PORT}")
Expand All @@ -83,17 +92,21 @@ def __init__(
self.start_commissioning()

def start_commissioning(self):
descriminator = self.nonvolatile["descriminator"]
discriminator = self.nonvolatile["discriminator"]
passcode = self.nonvolatile["passcode"]
txt_records = {
"PI": "",
"PH": "33",
"CM": "1",
"D": str(descriminator),
"D": str(discriminator),
"CRI": "3000",
"CRA": "4000",
"T": "1",
"VP": "65521+32769",
"VP": f"{self.vendor_id}+{self.product_id}",
}
from . import pase

pase.show_qr_code(self.vendor_id, self.product_id, discriminator, passcode)
instance_name = self.random.urandom(8).hex().upper()
self.mdns_server.advertise_service(
"_matterc",
Expand All @@ -102,21 +115,25 @@ def start_commissioning(self):
txt_records=txt_records,
instance_name=instance_name,
subtypes=[
f"_L{descriminator}._sub._matterc._udp",
f"_L{discriminator}._sub._matterc._udp",
"_CM._sub._matterc._udp",
],
)

def add_cluster(self, endpoint, cluster):
if endpoint not in self._endpoints:
self._endpoints[endpoint] = {}
self._descriptor.PartsList.append(endpoint)
if endpoint > 0:
self._descriptor.PartsList.append(endpoint)
self._next_endpoint = max(self._next_endpoint, endpoint + 1)
if endpoint == 0:
self._descriptor.ServerList.append(cluster.CLUSTER_ID)
self._endpoints[endpoint][cluster.CLUSTER_ID] = cluster

def add_device(self, device):
self._endpoints[self._next_endpoint] = {}
self._descriptor.PartsList.append(self._next_endpoint)
if self._next_endpoint > 0:
self._descriptor.PartsList.append(self._next_endpoint)
self._next_endpoint += 1

def process_packets(self):
Expand Down Expand Up @@ -358,8 +375,16 @@ def process_packet(self, address, data):
report = session.StatusReport()
report.decode(message.application_payload)
print(report)

# Acknowledge the message because we have no further reply.
if message.exchange_flags & session.ExchangeFlags.R:
exchange.send_standalone()
elif protocol_opcode == SecureProtocolOpcode.ICD_CHECK_IN:
print("Received ICD Check-in")
elif protocol_opcode == SecureProtocolOpcode.MRP_STANDALONE_ACK:
print("Received MRP Standalone Ack")
else:
print("Unhandled secure channel opcode", protocol_opcode)
elif message.protocol_id == ProtocolId.INTERACTION_MODEL:
secure_session_context = self.manager.secure_session_contexts[
message.session_id
Expand Down Expand Up @@ -456,4 +481,6 @@ def process_packet(self, address, data):
else:
print(message)
print("application payload", message.application_payload.hex(" "))
else:
print("Unknown protocol", message.protocol_id, message.protocol_opcode)
print()
11 changes: 11 additions & 0 deletions circuitmatter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def advertise_service(
class MDNSServer(DummyMDNS):
def __init__(self):
self.active_services = {}
self.publish_address = None

def advertise_service(
self,
Expand Down Expand Up @@ -130,10 +131,20 @@ def advertise_service(
]
print("running avahi", command)
self.active_services[service_type] = subprocess.Popen(command)
if self.publish_address is None:
command = [
"avahi-publish-address",
f"{instance_name}.local",
"fe80::642:1aff:fe0c:9f2a",
]
print("run", command)
self.publish_address = subprocess.Popen(command)

def __del__(self):
for active_service in self.active_services.values():
active_service.kill()
if self.publish_address is not None:
self.publish_address.kill()


class RecordingRandom:
Expand Down
7 changes: 4 additions & 3 deletions circuitmatter/clusters/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ def add_noc(
noc, _ = crypto.MatterCertificate.decode(
args.NOCValue[0], memoryview(args.NOCValue)[1:]
)
icac, _ = crypto.MatterCertificate.decode(
args.ICACValue[0], memoryview(args.ICACValue)[1:]
)
if args.ICACValue:
icac, _ = crypto.MatterCertificate.decode(
args.ICACValue[0], memoryview(args.ICACValue)[1:]
)

response = data_model.NodeOperationalCredentialsCluster.NOCResponse()

Expand Down
74 changes: 61 additions & 13 deletions circuitmatter/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@ class Enum16(enum.IntEnum):
pass


class Uint16(tlv.IntMember):
def __init__(self, _id=None, minimum=0):
super().__init__(_id, signed=False, octets=2, minimum=minimum)


class GroupId(Uint16):
pass


class ClusterId(Uint16):
pass


class EndpointNumber(Uint16):
def __init__(self, _id=None):
super().__init__(_id, minimum=1)


# Data model "lists" are encoded as tlv arrays. 🙄
class List(tlv.ArrayMember):
pass


class Attribute:
def __init__(self, _id, default=None):
self.id = _id
Expand Down Expand Up @@ -86,7 +109,12 @@ def __init__(self, _id, enum_type, default=None):


class ListAttribute(Attribute):
pass
def __init__(self, _id, element_type):
self.tlv_type = tlv.ArrayMember(None, element_type)
super().__init__(_id)

def encode(self, value) -> bytes:
return self.tlv_type.encode(value)


class BoolAttribute(Attribute):
Expand Down Expand Up @@ -218,11 +246,10 @@ class DeviceTypeStruct(tlv.Structure):
devtype_id = tlv.IntMember(0, signed=False, octets=4)
revision = tlv.IntMember(1, signed=False, octets=2, minimum=1)

DeviceTypeList = ListAttribute(0x0000)
ServerList = ListAttribute(0x0001)
ClientList = ListAttribute(0x0002)
PartsList = ListAttribute(0x0003)
TagList = ListAttribute(0x0004)
DeviceTypeList = ListAttribute(0x0000, DeviceTypeStruct)
ServerList = ListAttribute(0x0001, ClusterId())
ClientList = ListAttribute(0x0002, ClusterId())
PartsList = ListAttribute(0x0003, EndpointNumber())


class ProductFinish(enum.IntEnum):
Expand Down Expand Up @@ -323,11 +350,20 @@ class GroupKeySetStruct(tlv.Structure):
class GroupKeyManagementCluster(Cluster):
CLUSTER_ID = 0x3F

class GroupKeyMapStruct(tlv.Structure):
GroupId = GroupId(1)
GroupKeySetID = tlv.IntMember(2, signed=False, octets=2, minimum=1)

class GroupInfoMapStruct(tlv.Structure):
GroupId = GroupId(1)
Endpoints = List(2, EndpointNumber())
GroupName = tlv.UTF8StringMember(3, max_length=16)

class KeySetWrite(tlv.Structure):
GroupKeySet = tlv.StructMember(0, GroupKeySetStruct)

group_key_map = ListAttribute(0)
group_table = ListAttribute(1)
group_key_map = ListAttribute(0, GroupKeyMapStruct)
group_table = ListAttribute(1, GroupInfoMapStruct)
max_groups_per_fabric = NumberAttribute(2, signed=False, bits=16, default=0)
max_group_keys_per_fabric = NumberAttribute(3, signed=False, bits=16, default=1)

Expand Down Expand Up @@ -403,6 +439,14 @@ class FeatureBitmap(enum.IntFlag):
THREAD_NETWORK_INTERFACE = 0b010
ETHERNET_NETWORK_INTERFACE = 0b100

class WifiBandEnum(Enum8):
BAND_2G4 = 0
BAND_3G65 = 1
BAND_5G = 2
BAND_6G = 3
BAND_60G = 4
BAND_1G = 5

class NetworkCommissioningStatus(Enum8):
SUCCESS = 0
"""Ok, no error"""
Expand Down Expand Up @@ -443,15 +487,19 @@ class NetworkCommissioningStatus(Enum8):
UNKNOWN_ERROR = 12
"""Unknown error"""

class NetworkInfoStruct(tlv.Structure):
NetworkID = tlv.OctetStringMember(0, min_length=1, max_length=32)
Connected = tlv.BoolMember(1)

max_networks = NumberAttribute(0, signed=False, bits=8)
networks = ListAttribute(1)
networks = ListAttribute(1, NetworkInfoStruct)
scan_max_time_seconds = NumberAttribute(2, signed=False, bits=8)
connect_max_time_seconds = NumberAttribute(3, signed=False, bits=8)
interface_enabled = BoolAttribute(4)
last_network_status = EnumAttribute(5, NetworkCommissioningStatus)
last_network_id = OctetStringAttribute(6, min_length=1, max_length=32)
last_connect_error_value = NumberAttribute(7, signed=True, bits=32)
supported_wifi_bands = ListAttribute(8)
supported_wifi_bands = ListAttribute(8, WifiBandEnum)
supported_thread_features = BitmapAttribute(9)
thread_version = NumberAttribute(10, signed=False, bits=16)

Expand Down Expand Up @@ -547,11 +595,11 @@ class RemoveFabric(tlv.Structure):
class AddTrustedRootCertificate(tlv.Structure):
RootCACertificate = tlv.OctetStringMember(0, 400)

nocs = ListAttribute(0)
fabrics = ListAttribute(1)
nocs = ListAttribute(0, NOCStruct)
fabrics = ListAttribute(1, FabricDescriptorStruct)
supported_fabrics = NumberAttribute(2, signed=False, bits=8)
commissioned_fabrics = NumberAttribute(3, signed=False, bits=8)
trusted_root_certificates = ListAttribute(4)
trusted_root_certificates = ListAttribute(4, tlv.OctetStringMember(None, 400))
current_fabric_index = NumberAttribute(5, signed=False, bits=8, default=0)

attestation_request = Command(0x00, AttestationRequest, 0x01, AttestationResponse)
Expand Down
33 changes: 25 additions & 8 deletions circuitmatter/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def __init__(self, session, initiator: bool, exchange_id: int, protocols):
self.pending_retransmission = None
"""Message that we've attempted to send but hasn't been acked"""

def send(self, protocol_id, protocol_opcode, application_payload=None):
def send(
self, protocol_id, protocol_opcode, application_payload=None, reliable=True
):
message = Message()
message.exchange_flags = ExchangeFlags(0)
if self.initiator:
Expand All @@ -49,6 +51,9 @@ def send(self, protocol_id, protocol_opcode, application_payload=None):
self.send_standalone_time = None
message.acknowledged_message_counter = self.pending_acknowledgement
self.pending_acknowledgement = None
if reliable:
message.exchange_flags |= ExchangeFlags.R
self.pending_retransmission = message
message.source_node_id = self.session.local_node_id
message.protocol_id = protocol_id
message.protocol_opcode = protocol_opcode
Expand All @@ -57,36 +62,48 @@ def send(self, protocol_id, protocol_opcode, application_payload=None):
self.session.send(message)

def send_standalone(self):
if self.pending_retransmission is not None:
print("resending", self.pending_retransmission)
self.session.send(self.pending_retransmission)
return
self.send(
ProtocolId.SECURE_CHANNEL, SecureProtocolOpcode.MRP_STANDALONE_ACK, None
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.MRP_STANDALONE_ACK,
None,
reliable=False,
)

def receive(self, message) -> bool:
"""Process the message and return if the packet should be dropped."""
if message.protocol_id not in self.protocols:
# Drop messages that don't match the protocols we're waiting for.
return True

# Section 4.12.5.2.1
if message.exchange_flags & ExchangeFlags.A:
if message.acknowledged_message_counter is None:
# Drop messages that are missing an acknowledgement counter.
return True
if message.acknowledged_message_counter != self.pending_acknowledgement:
if (
message.acknowledged_message_counter
!= self.pending_retransmission.message_counter
):
# Drop messages that have the wrong acknowledgement counter.
return True
self.pending_retransmission = None
self.next_retransmission_time = None

if message.protocol_id not in self.protocols:
print("protocol mismatch")
# Drop messages that don't match the protocols we're waiting for.
return True

# Section 4.12.5.2.2
# Incoming packets that are marked Reliable.
if message.exchange_flags & ExchangeFlags.R:
if message.duplicate:
# Send a standalone acknowledgement.
self.send_standalone()
return True
if self.pending_acknowledgement is not None:
# Send a standalone acknowledgement with the message counter we're about to overwrite.
pass
self.send_standalone()
self.pending_acknowledgement = message.message_counter
self.send_standalone_time = (
time.monotonic() + MRP_STANDALONE_ACK_TIMEOUT_MS / 1000
Expand Down
Loading

0 comments on commit 4a0c4b6

Please sign in to comment.