Skip to content

Commit

Permalink
Store NOC and update MDNS entries
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Sep 26, 2024
1 parent 672c864 commit c406629
Show file tree
Hide file tree
Showing 7 changed files with 500 additions and 157 deletions.
219 changes: 205 additions & 14 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

import cryptography
import ecdsa
from ecdsa import der

from typing import Optional

from . import crypto
from . import data_model
from . import interaction_model
from . import session
from . import tlv

TEST_CERTS = pathlib.Path(
Expand Down Expand Up @@ -351,7 +352,7 @@ def __init__(self, random_source, socket, local_session_id):
self.session_active_threshold = None
self.exchanges = {}

self._nonce = bytearray(session.CRYPTO_AEAD_NONCE_LENGTH_BYTES)
self._nonce = bytearray(crypto.AEAD_NONCE_LENGTH_BYTES)
self.socket = socket
self.node_ipaddress = None

Expand Down Expand Up @@ -932,13 +933,37 @@ class NOCSRElements(tlv.Structure):
# Skip vendor reserved


def encode_set(*encoded_pieces):
total_len = sum([len(p) for p in encoded_pieces])
return b"\x31" + der.encode_length(total_len) + b"".join(encoded_pieces)


def encode_utf8_string(s):
encoded = s.encode("utf-8")
return b"\x0c" + der.encode_length(len(encoded)) + encoded


class NodeOperationalCredentialsCluster(data_model.NodeOperationalCredentialsCluster):
def __init__(self):
def __init__(self, group_key_manager, mdns_server, port):
super().__init__()

self.group_key_manager = group_key_manager

self.dac_key = ecdsa.keys.SigningKey.from_der(
TEST_DAC_KEY_DER.read_bytes(), hashfunc=hashlib.sha256
)

self.new_key_for_update = False
self.pending_root_cert = None
self.pending_signing_key = None

self.nocs = []
self.fabrics = []
self.commissioned_fabrics = 0
self.supported_fabrics = 10

self.mdns_server = mdns_server
self.port = port

def certificate_chain_request(
self,
Expand Down Expand Up @@ -984,8 +1009,8 @@ def attestation_request(
return response

def csr_request(
self, session, args: data_model.NodeOperationalCredentialsCluster.CsrRequest
) -> data_model.NodeOperationalCredentialsCluster.CsrResponse:
self, session, args: data_model.NodeOperationalCredentialsCluster.CSRRequest
) -> data_model.NodeOperationalCredentialsCluster.CSRResponse:
# Section 6.4.6.1
# CSR stands for Certificate Signing Request. A NOCSR is a Node Operational Certificate Signing Request

Expand All @@ -995,8 +1020,65 @@ def csr_request(
# CSRNonce = tlv.OctetStringMember(0, 32)
# IsForUpdateNOC = tlv.BoolMember(1, optional=True, default=False)

self.pending_signing_key = ecdsa.keys.SigningKey.generate(
curve=ecdsa.NIST256p, hashfunc=hashlib.sha256
)

# DER encode the request
# https://www.rfc-editor.org/rfc/rfc2986 Section 4.2
certification_request = []

certification_request_info = []

# Version
certification_request_info.append(der.encode_integer(0))

# subject
attribute_type = der.encode_oid(2, 5, 4, 10)
value = encode_utf8_string("CSA")

subject = der.encode_sequence(
encode_set(der.encode_sequence(attribute_type, value))
)
certification_request_info.append(subject)

# Subject Public Key Info
algorithm = der.encode_sequence(
der.encode_oid(1, 2, 840, 10045, 2, 1),
der.encode_oid(1, 2, 840, 10045, 3, 1, 7),
)
self.pending_public_key = self.pending_signing_key.verifying_key.to_string(
encoding="uncompressed"
)
public_key = der.encode_bitstring(self.pending_public_key, unused=0)
spki = der.encode_sequence(algorithm, public_key)
certification_request_info.append(spki)

# Extensions
extension_request = der.encode_sequence(
der.encode_oid(1, 2, 840, 113549, 1, 9, 14),
encode_set(der.encode_sequence()),
)
certification_request_info.append(der.encode_constructed(0, extension_request))

certification_request_info = der.encode_sequence(*certification_request_info)
certification_request.append(certification_request_info)

signature_algorithm = der.encode_sequence(
der.encode_oid(1, 2, 840, 10045, 4, 3, 2)
)
certification_request.append(signature_algorithm)

# Signature
signature = self.pending_signing_key.sign_deterministic(
certification_request_info,
hashfunc=hashlib.sha256,
sigencode=ecdsa.util.sigencode_der_canonize,
)
certification_request.append(der.encode_bitstring(signature, unused=0))

# Generate a new key pair.
new_key_csr = b"TODO"
new_key_csr = der.encode_sequence(*certification_request)

# Create a CSR to reply back with. Sign it with the new private key.
elements = NOCSRElements()
Expand All @@ -1008,13 +1090,115 @@ def csr_request(
# class CSRResponse(tlv.Structure):
# NOCSRElements = tlv.OctetStringMember(0, RESP_MAX)
# AttestationSignature = tlv.OctetStringMember(1, 64)
response = data_model.NodeOperationalCredentialsCluster.CsrResponse()
response = data_model.NodeOperationalCredentialsCluster.CSRResponse()
response.NOCSRElements = elements
response.AttestationSignature = self.dac_key.sign_deterministic(
nocsr_tbs, hashfunc=hashlib.sha256, sigencode=ecdsa.util.sigencode_string
)
return response

def add_trusted_root_certificate(
self,
session,
args: data_model.NodeOperationalCredentialsCluster.AddTrustedRootCertificate,
) -> interaction_model.StatusCode:
self.pending_root_cert = args.RootCACertificate
return interaction_model.StatusCode.SUCCESS

def add_noc(
self, session, args: data_model.NodeOperationalCredentialsCluster.AddNOC
) -> data_model.NodeOperationalCredentialsCluster.NOCResponse:
# Section 11.18.6.8
noc, _ = crypto.MatterCertificate.decode(
args.NOCValue[0], memoryview(args.NOCValue)[1:]
)
icac, _ = crypto.MatterCertificate.decode(
args.ICACValue[0], memoryview(args.ICACValue)[1:]
)

response = data_model.NodeOperationalCredentialsCluster.NOCResponse()

if noc.ec_pub_key != self.pending_public_key:
print(noc.ec_pub_key, self.pending_public_key)
response.StatusCode = (
data_model.NodeOperationalCertStatusEnum.INVALID_PUBLIC_KEY
)
return response

# Save info about the fabric.
new_fabric_index = len(self.fabrics)
if new_fabric_index >= self.supported_fabrics:
response.StatusCode = data_model.NodeOperationalCertStatusEnum.TABLE_FULL
return response

session.local_fabric_index = new_fabric_index

# Store the NOC.
noc_struct = data_model.NodeOperationalCredentialsCluster.NOCStruct()
noc_struct.NOC = args.NOCValue
noc_struct.ICAC = args.ICACValue
self.nocs.append(noc_struct)

# Store the fabric
new_fabric = (
data_model.NodeOperationalCredentialsCluster.FabricDescriptorStruct()
)
new_fabric.RootPublicKey = self.pending_root_cert
new_fabric.VendorID = args.AdminVendorId
new_fabric.FabricID = noc.subject.matter_fabric_id
new_fabric.NodeID = noc.subject.matter_node_id
self.fabrics.append(new_fabric)

new_group_key = data_model.GroupKeyManagementCluster.KeySetWrite()
key_set = data_model.GroupKeySetStruct()
key_set.GroupKeySetID = 0
key_set.GroupKeySecurityPolicy = (
data_model.GroupKeySetSecurityPolicyEnum.TRUST_FIRST
)
key_set.EpochKey0 = args.IPKValue
key_set.EpochStartTime0 = 0

new_group_key.GroupKeySet = key_set
self.group_key_manager.key_set_write(session, new_group_key)

self.commissioned_fabrics += 1

# Get the root cert public key so we can create the compressed fabric id.
root_cert, _ = crypto.MatterCertificate.decode(
self.pending_root_cert[0], memoryview(self.pending_root_cert)[1:]
)
fabric_id = struct.pack(">Q", noc.subject.matter_fabric_id)
compressed_fabric_id = (
crypto.KDF(root_cert.ec_pub_key[1:], fabric_id, b"CompressedFabric", 8)[:8]
.hex()
.upper()
)
node_id = struct.pack(">Q", new_fabric.NodeID).hex().upper()
instance_name = f"{compressed_fabric_id}-{node_id}"
self.mdns_server.advertise_service(
"_matter",
"_tcp",
self.port,
instance_name=instance_name,
subtypes=[
f"_I{compressed_fabric_id}._sub._matter._tcp",
],
)

response.StatusCode = data_model.NodeOperationalCertStatusEnum.OK
return response


class GroupKeyManagementCluster(data_model.GroupKeyManagementCluster):
def __init__(self):
super().__init__()
self.key_sets = []

def key_set_write(
self, session, args: data_model.GroupKeyManagementCluster.KeySetWrite
) -> interaction_model.StatusCode:
return interaction_model.StatusCode.SUCCESS


class CircuitMatter:
def __init__(
Expand Down Expand Up @@ -1066,12 +1250,16 @@ def __init__(
basic_info.vendor_id = vendor_id
basic_info.product_id = product_id
self.add_cluster(0, basic_info)
group_keys = GroupKeyManagementCluster()
self.add_cluster(0, group_keys)
network_info = data_model.NetworkCommissioningCluster()
network_info.connect_max_time_seconds = 10
self.add_cluster(0, network_info)
general_commissioning = GeneralCommissioningCluster()
self.add_cluster(0, general_commissioning)
noc = NodeOperationalCredentialsCluster()
noc = NodeOperationalCredentialsCluster(
group_keys, self.mdns_server, self.UDP_PORT
)
self.add_cluster(0, noc)

def start_commissioning(self):
Expand Down Expand Up @@ -1134,21 +1322,24 @@ def invoke(self, session, cluster, path, fields, command_ref):
print("invoke", path)
response = interaction_model.InvokeResponseIB()
cdata = cluster.invoke(session, path, fields)
if cdata is None:
if isinstance(cdata, interaction_model.CommandDataIB):
if command_ref is not None:
cdata.CommandRef = command_ref
response.Command = cdata
else:
cstatus = interaction_model.CommandStatusIB()
cstatus.CommandPath = path
status = interaction_model.StatusIB()
status.Status = interaction_model.StatusCode.UNSUPPORTED_COMMAND
if cdata is None:
status.Status = interaction_model.StatusCode.UNSUPPORTED_COMMAND
else:
status.Status = cdata
cstatus.Status = status
if command_ref is not None:
cstatus.CommandRef = command_ref
response.Status = cstatus
return response

if command_ref is not None:
cdata.CommandRef = command_ref
print("cdata", cdata)
response.Command = cdata
return response

def process_packet(self, address, data):
Expand Down
20 changes: 11 additions & 9 deletions circuitmatter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def recvfrom_into(self, buffer, nbytes=None):
def sendto(self, data, address):
if address is None:
raise ValueError("Address must be set")
direction, _, address, data_b64 = self.replay_data.pop(0)
if direction == "send":
decoded = binascii.a2b_base64(data_b64)
for i, b in enumerate(data):
if b != decoded[i]:
print("sent", data.hex(" "))
print("old ", decoded.hex(" "))
print(i, hex(b), hex(decoded[i]))
raise RuntimeError("Next replay packet does not match sent data")
# direction, _, address, data_b64 = self.replay_data.pop(0)
# if direction == "send":
# decoded = binascii.a2b_base64(data_b64)
# for i, b in enumerate(data):
# if b != decoded[i]:
# # print("sent", data.hex(" "))
# # print("old ", decoded.hex(" "))
# # print(i, hex(b), hex(decoded[i]))
# print("Next replay packet does not match sent data")
return len(data)


Expand Down Expand Up @@ -111,6 +111,8 @@ def advertise_service(
subtypes=[],
instance_name="",
):
for active_service in self.active_services.values():
active_service.kill()
subtypes = [f"--subtype={subtype}" for subtype in subtypes]
txt_records = [f"{key}={value}" for key, value in txt_records.items()]
if service_type in self.active_services:
Expand Down
Loading

0 comments on commit c406629

Please sign in to comment.