Skip to content

Commit

Permalink
Store nonvolatile state in json and restore. Improve decode too
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Oct 16, 2024
1 parent cdaa606 commit 6b9090e
Show file tree
Hide file tree
Showing 13 changed files with 368 additions and 127 deletions.
63 changes: 28 additions & 35 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import binascii
import hashlib
import json
import time

from . import case
from . import interaction_model
from . import nonvolatile
from .message import Message
from .protocol import InteractionModelOpcode, ProtocolId, SecureProtocolOpcode
from . import session
Expand All @@ -29,15 +29,12 @@ def __init__(
self.mdns_server = mdns_server
self.random = random_source

with open(state_filename, "r") as state_file:
self.nonvolatile = json.load(state_file)
self.nonvolatile = nonvolatile.PersistentDictionary(state_filename)

for key in ["discriminator", "salt", "iteration-count", "verifier"]:
if key not in self.nonvolatile:
raise RuntimeError(f"Missing key {key} in state file")

commission = "fabrics" not in self.nonvolatile

self.packet_buffer = memoryview(bytearray(1280))

# Define the UDP IP address and port
Expand All @@ -51,6 +48,7 @@ def __init__(

# Bind the socket to the IP and port
self.socket.bind((UDP_IP, self.UDP_PORT))
print(f"Listening on UDP port {self.UDP_PORT}")
self.socket.setblocking(False)

self._endpoints = {}
Expand All @@ -62,14 +60,11 @@ def __init__(

self.vendor_id = vendor_id
self.product_id = product_id

self.manager = session.SessionManager(
self.random, self.socket, self.root_node.noc
)

print(f"Listening on UDP port {self.UDP_PORT}")

if commission:
if self.root_node.fabric_count == 0:
self.start_commissioning()

def start_commissioning(self):
Expand Down Expand Up @@ -118,6 +113,12 @@ def add_device(self, device):
device.descriptor.ServerList.append(server.CLUSTER_ID)
self.add_cluster(self._next_endpoint, server)
self.add_cluster(self._next_endpoint, device.descriptor)

if "devices" not in self.nonvolatile:
self.nonvolatile["devices"] = {}
if device.name not in self.nonvolatile["devices"]:
self.nonvolatile["devices"][device.name] = {}
device.restore(self.nonvolatile["devices"][device.name])
self._next_endpoint += 1

def process_packets(self):
Expand Down Expand Up @@ -249,9 +250,7 @@ def process_packet(self, address, data):
from . import pase

# This is Section 4.14.1.2
request, _ = pase.PBKDFParamRequest.decode(
message.application_payload[0], message.application_payload[1:]
)
request = pase.PBKDFParamRequest.decode(message.application_payload)
exchange.commissioning_hash = hashlib.sha256(
b"CHIP PAKE V1 Commissioning"
)
Expand Down Expand Up @@ -287,9 +286,7 @@ def process_packet(self, address, data):
from . import pase

print("Received PASE PAKE1")
pake1, _ = pase.PAKE1.decode(
message.application_payload[0], message.application_payload[1:]
)
pake1 = pase.PAKE1.decode(message.application_payload)
pake2 = pase.PAKE2()
verifier = binascii.a2b_base64(self.nonvolatile["verifier"])
context = exchange.commissioning_hash.digest()
Expand All @@ -308,9 +305,7 @@ def process_packet(self, address, data):
from . import pase

print("Received PASE PAKE3")
pake3, _ = pase.PAKE3.decode(
message.application_payload[0], message.application_payload[1:]
)
pake3 = pase.PAKE3.decode(message.application_payload)
if pake3.cA != exchange.cA:
del exchange.cA
del exchange.Ke
Expand Down Expand Up @@ -341,19 +336,15 @@ def process_packet(self, address, data):
print("PASE succeeded")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA1:
print("Received CASE Sigma1")
sigma1, _ = case.Sigma1.decode(
message.application_payload[0], message.application_payload[1:]
)
sigma1 = case.Sigma1.decode(message.application_payload)
response = self.manager.reply_to_sigma1(exchange, sigma1)

exchange.send(response)
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2:
print("Received CASE Sigma2")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3:
print("Received CASE Sigma3")
sigma3, _ = case.Sigma3.decode(
message.application_payload[0], message.application_payload[1:]
)
sigma3 = case.Sigma3.decode(message.application_payload)
protocol_code = self.manager.reply_to_sigma3(exchange, sigma3)

error_status = session.StatusReport()
Expand Down Expand Up @@ -390,8 +381,8 @@ def process_packet(self, address, data):
message.session_id
]
if protocol_opcode == InteractionModelOpcode.READ_REQUEST:
read_request, _ = interaction_model.ReadRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
read_request = interaction_model.ReadRequestMessage.decode(
message.application_payload
)
attribute_reports = []
for path in read_request.AttributeRequests:
Expand All @@ -404,8 +395,8 @@ def process_packet(self, address, data):
exchange.send(response)
elif protocol_opcode == InteractionModelOpcode.WRITE_REQUEST:
print("Received Write Request")
write_request, _ = interaction_model.WriteRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
write_request = interaction_model.WriteRequestMessage.decode(
message.application_payload
)
write_responses = []
for request in write_request.WriteRequests:
Expand All @@ -421,8 +412,8 @@ def process_packet(self, address, data):

elif protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
invoke_request, _ = interaction_model.InvokeRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
invoke_request = interaction_model.InvokeRequestMessage.decode(
message.application_payload
)
for invoke in invoke_request.InvokeRequests:
path = invoke.CommandPath
Expand Down Expand Up @@ -460,14 +451,13 @@ def process_packet(self, address, data):
response = interaction_model.InvokeResponseMessage()
response.SuppressResponse = False
response.InvokeResponses = invoke_responses
print("sending invoke response", response)
exchange.send(response)
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
print("Received Invoke Response")
elif protocol_opcode == InteractionModelOpcode.SUBSCRIBE_REQUEST:
print("Received Subscribe Request")
subscribe_request, _ = interaction_model.SubscribeRequestMessage.decode(
message.application_payload[0], message.application_payload[1:]
subscribe_request = interaction_model.SubscribeRequestMessage.decode(
message.application_payload
)
print(subscribe_request)
attribute_reports = []
Expand All @@ -484,8 +474,8 @@ def process_packet(self, address, data):
final_response.MaxInterval = subscribe_request.MaxIntervalCeiling
exchange.queue(final_response)
elif protocol_opcode == InteractionModelOpcode.STATUS_RESPONSE:
status_response, _ = interaction_model.StatusResponseMessage.decode(
message.application_payload[0], message.application_payload[1:]
status_response = interaction_model.StatusResponseMessage.decode(
message.application_payload
)
print(
f"Received Status Response on {message.session_id}/{message.exchange_id} ack {message.acknowledged_message_counter}: {status_response.Status!r}"
Expand All @@ -502,3 +492,6 @@ def process_packet(self, address, data):
else:
print("Unknown protocol", message.protocol_id, message.protocol_opcode)
print()

self.nonvolatile.commit()
# TODO: Rollback on error?
28 changes: 17 additions & 11 deletions circuitmatter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import binascii
import json
import os
import pathlib
import secrets
import socket
import subprocess
Expand Down Expand Up @@ -114,13 +115,8 @@ def advertise_service(
subtypes=[],
instance_name="",
):
for active_service in self.active_services.values():
active_service.kill()
subtypes = [f"--subtype={subtype}" for subtype in subtypes]
txt_records = [f"{key}={value}" for key, value in txt_records.items()]
if service_type in self.active_services:
self.active_services[service_type].kill()
del self.active_services[service_type]
command = [
"avahi-publish-service",
*subtypes,
Expand All @@ -130,7 +126,7 @@ def advertise_service(
*txt_records,
]
print("running avahi", command)
self.active_services[service_type] = subprocess.Popen(command)
self.active_services[service_type + instance_name] = subprocess.Popen(command)
if self.publish_address is None:
command = [
"avahi-publish-address",
Expand Down Expand Up @@ -226,23 +222,33 @@ class NeoPixel(on_off.OnOffLight):


def run(replay_file=None):
device_state = pathlib.Path("test_data/device_state.json")
replay_device_state = pathlib.Path("test_data/replay_device_state.json")
if replay_file:
replay_lines = []
with open(replay_file, "r") as f:
device_state_fn = f.readline().strip()
for line in f:
replay_lines.append(json.loads(line))
socketpool = ReplaySocketPool(replay_lines)
mdns_server = DummyMDNS()
random_source = ReplayRandom(replay_lines)
# Reset device state to before the captured run
device_state.write_text(pathlib.Path(device_state_fn).read_text())
else:
record_file = open("test_data/recorded_packets.jsonl", "w")
timestamp = time.strftime("%Y%m%d-%H%M%S")
record_file = open(f"test_data/recorded_packets-{timestamp}.jsonl", "w")
device_state_fn = f"test_data/device_state-{timestamp}.json"
record_file.write(f"{device_state_fn}\n")
socketpool = RecordingSocketPool(record_file)
mdns_server = MDNSServer()
random_source = RecordingRandom(record_file)
matter = cm.CircuitMatter(
socketpool, mdns_server, random_source, "test_data/device_state.json"
)
led = NeoPixel()
# Save device state before we run so replays can use it.
replay_device_state = pathlib.Path(device_state_fn)
replay_device_state.write_text(device_state.read_text())

matter = cm.CircuitMatter(socketpool, mdns_server, random_source, device_state)
led = NeoPixel("neopixel1")
matter.add_device(led)
while True:
matter.process_packets()
Expand Down
18 changes: 10 additions & 8 deletions circuitmatter/clusters/device_management/group_key_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ class GroupKeyMulticastPolicyEnum(Enum8):
class GroupKeySetStruct(tlv.Structure):
GroupKeySetID = tlv.IntMember(0, signed=False, octets=2)
GroupKeySecurityPolicy = tlv.EnumMember(1, GroupKeySetSecurityPolicyEnum)
EpochKey0 = tlv.OctetStringMember(2, 16)
EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8)
EpochKey1 = tlv.OctetStringMember(4, 16)
EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8)
EpochKey2 = tlv.OctetStringMember(6, 16)
EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8)
GroupKeyMulticastPolicy = tlv.EnumMember(8, GroupKeyMulticastPolicyEnum)
EpochKey0 = tlv.OctetStringMember(2, 16, nullable=True)
EpochStartTime0 = tlv.IntMember(3, signed=False, octets=8, nullable=True)
EpochKey1 = tlv.OctetStringMember(4, 16, nullable=True)
EpochStartTime1 = tlv.IntMember(5, signed=False, octets=8, nullable=True)
EpochKey2 = tlv.OctetStringMember(6, 16, nullable=True)
EpochStartTime2 = tlv.IntMember(7, signed=False, octets=8, nullable=True)
GroupKeyMulticastPolicy = tlv.EnumMember(
8, GroupKeyMulticastPolicyEnum, nullable=True
)


class GroupKeyManagementCluster(Cluster):
Expand All @@ -48,7 +50,7 @@ class GroupInfoMapStruct(tlv.Structure):
class KeySetWrite(tlv.Structure):
GroupKeySet = tlv.StructMember(0, GroupKeySetStruct)

group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[])
group_key_map = ListAttribute(0, GroupKeyMapStruct, default=[], N_nonvolatile=True)
group_table = ListAttribute(1, GroupInfoMapStruct, default=[])
max_groups_per_fabric = NumberAttribute(2, signed=False, bits=16, default=0)
max_group_keys_per_fabric = NumberAttribute(3, signed=False, bits=16, default=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,20 @@ class RemoveFabric(tlv.Structure):
class AddTrustedRootCertificate(tlv.Structure):
RootCACertificate = tlv.OctetStringMember(0, 400)

nocs = ListAttribute(0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True)
fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True)
nocs = ListAttribute(
0, NOCStruct, N_nonvolatile=True, C_changes_omitted=True, default=[]
)
fabrics = ListAttribute(1, FabricDescriptorStruct, N_nonvolatile=True, default=[])
supported_fabrics = NumberAttribute(2, signed=False, bits=8, F_fixed=True)
commissioned_fabrics = NumberAttribute(3, signed=False, bits=8, N_nonvolatile=True)
commissioned_fabrics = NumberAttribute(
3, signed=False, bits=8, N_nonvolatile=True, default=0
)
trusted_root_certificates = ListAttribute(
4, tlv.OctetStringMember(None, 400), N_nonvolatile=True, C_changes_omitted=True
4,
tlv.OctetStringMember(None, 400),
N_nonvolatile=True,
C_changes_omitted=True,
default=[],
)
# This attribute is weird because it is fabric sensitive but not marked as such.
# Cluster sets current_fabric_index for use in fabric sensitive attributes and
Expand Down
Loading

0 comments on commit 6b9090e

Please sign in to comment.