Skip to content

Commit

Permalink
Connect back up to real networking, record packets and enable replay
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Jul 18, 2024
1 parent e590447 commit 6078dbe
Show file tree
Hide file tree
Showing 5 changed files with 386 additions and 160 deletions.
296 changes: 270 additions & 26 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,17 @@
"""Pure Python implementation of the Matter IOT protocol."""

import binascii
import enum
import pathlib
import json
import os
import struct
import time

from . import tlv

__version__ = "0.0.0"

# descriminator = 3840
# avahi = subprocess.Popen(["avahi-publish-service", "-v", f"--subtype=_L{descriminator}._sub._matterc._udp", "--subtype=_CM._sub._matterc._udp", "FA93546B21F5FB54", "_matterc._udp", "5540", "PI=", "PH=33", "CM=1", f"D={descriminator}", "CRI=3000", "CRA=4000", "T=1", "VP=65521+32769"])

# # Define the UDP IP address and port
# UDP_IP = "::" # Listen on all available network interfaces
# UDP_PORT = 5540

# # Create the UDP socket
# sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)

# # Bind the socket to the IP and port
# sock.bind((UDP_IP, UDP_PORT))

# print(f"Listening on UDP port {UDP_PORT}")

# Section 4.11.2
MSG_COUNTER_WINDOW_SIZE = 32
MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
Expand Down Expand Up @@ -243,11 +230,54 @@ def process_counter(self, counter) -> bool:
return False


class Exchange:
def __init__(self, initiator: bool, exchange_id: int, protocols):
self.initiator = initiator
self.exchange_id = exchange_id
self.protocols = protocols

self.pending_acknowledgement = None
self.next_retransmission_time = None
self.pending_retransmission = None

def send(self, message):
pass

def receive(self, message) -> bool:
"""Process the message and return if the packet should be dropped."""
if message.protocol_id not in self.protocols:
# Drop messages that don't match the protocols we're waiting for.
return True

# Section 4.10.5.2.1
if message.exchange_flags & ExchangeFlags.A:
if message.acknowledged_message_counter is None:
# Drop messages that are missing an acknowledgement counter.
return True
if self.pending_acknowledgement is None:
# Drop messages that are not waiting for an acknowledgement.
return True
if message.acknowledged_message_counter != self.pending_acknowledgement:
# Drop messages that have the wrong acknowledgement counter.
return True
self.pending_acknowledgement = None
self.pending_retransmission = None
self.next_retransmission_time = None

# Section 4.10.5.2.2
# if message.exchange_flags & ExchangeFlags.R:
# if message
if message.duplicate:
return True
return False


class UnsecuredSessionContext:
def __init__(self, initiator, ephemeral_initiator_node_id):
self.initiator = initiator
self.ephemeral_initiator_node_id = ephemeral_initiator_node_id
self.message_reception_state = None
self.exchanges = {}


class SecureSessionContext:
Expand Down Expand Up @@ -283,6 +313,7 @@ def __init__(self, local_session_id):
self.session_idle_interval = None
self.session_active_interval = None
self.session_active_threshold = None
self.exchanges = {}

@property
def peer_active(self):
Expand All @@ -295,6 +326,7 @@ def __init__(self, buffer):
self.flags, self.session_id, self.security_flags, self.message_counter = (
struct.unpack_from("<BHBI", buffer)
)
self.security_flags = SecurityFlags(self.security_flags)
offset = 8
self.source_node_id = None
if self.flags & (1 << 2):
Expand All @@ -303,17 +335,18 @@ def __init__(self, buffer):

if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
self.secure_session = self.security_flags & 0x3 != 0 or self.session_id != 0
self.secure_session = not (
not (self.security_flags & SecurityFlags.GROUP) and self.session_id == 0
)

if not self.secure_session:
self.payload = memoryview(buffer)[offset:]

context = UnsecuredSessionContext(False, self.source_node_id)
self.unsecured_session_context[self.source_node_id] = context
else:
self.payload = None

def _parse_protocol_header(self):
self.duplicate = None

def parse_protocol_header(self):
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
struct.unpack_from("<BBH", self.payload)
)
Expand All @@ -329,7 +362,7 @@ def _parse_protocol_header(self):
protocol_id = struct.unpack_from("<H", self.payload, decrypted_offset)[0]
decrypted_offset += 2
self.protocol_id = ProtocolId(protocol_id)
self.protocol_opcode = PROTOCOL_OPCODES[protocol_id](self.protocol_opcode)
self.protocol_opcode = PROTOCOL_OPCODES[self.protocol_id](self.protocol_opcode)

self.acknowledged_message_counter = None
if self.exchange_flags & ExchangeFlags.A:
Expand All @@ -338,6 +371,8 @@ def _parse_protocol_header(self):
)[0]
decrypted_offset += 4

self.application_payload = self.payload[decrypted_offset:]

def reply(self, payload, protocol_id=None, protocol_opcode=None) -> memoryview:
reply = bytearray(1280)
offset = 0
Expand Down Expand Up @@ -375,12 +410,11 @@ def __init__(self):
def _increment(self, value):
return (value + 1) % 0xFFFFFFFF

def counter_ok(self, message):
"""Implements 4.6.7"""
def get_session(self, message):
if message.secure_session:
if message.security_flags & SecurityFlags.GROUP:
if message.source_node_id is None:
return False
return None
# TODO: Get MRS for source node id and message type
else:
session_context = self.secure_session_contexts[message.session_id]
Expand All @@ -393,16 +427,22 @@ def counter_ok(self, message):
)
)
session_context = self.unsecured_session_context[message.source_node_id]
return session_context

def mark_duplicate(self, message):
"""Implements 4.6.7"""
session_context = self.get_session(message)

if session_context.message_reception_state is None:
session_context.message_reception_state = MessageReceptionState(
message.message_counter,
rollover=False,
encrypted=message.secure_session,
)
return True
message.duplicate = False
return

return session_context.message_reception_state.process_counter(
message.duplicate = session_context.message_reception_state.process_counter(
message.message_counter
)

Expand Down Expand Up @@ -443,3 +483,207 @@ def new_context(self):

self.secure_session_contexts[session_id] = SecureSessionContext(session_id)
return self.secure_session_contexts[session_id]

def process_exchange(self, message):
session = self.get_session(message)
if session is None:
return None
# Step 1 of 4.12.5.2
if (
message.exchange_flags & (ExchangeFlags.R | ExchangeFlags.A)
and not message.security_flags & SecurityFlags.C
and message.security_flags & SecurityFlags.GROUP
):
# Drop illegal combination of flags.
return None
if message.exchange_id not in session.exchanges:
# Section 4.10.5.2
initiator = message.exchange_flags & ExchangeFlags.I
if initiator and not message.duplicate:
session.exchanges[message.exchange_id] = Exchange(
not initiator, message.exchange_id, [message.protocol_id]
)
# Drop because the message isn't from an initiator.
elif message.exchange_flags & ExchangeFlags.R:
# Send a bare acknowledgement back.
raise NotImplementedError("Send a bare acknowledgement back")
return None
else:
# Just drop it.
return None

exchange = session.exchanges[message.exchange_id]
if exchange.receive(message):
# If we want to drop the message, then return None.
return None

return exchange


class CircuitMatter:
def __init__(self, socketpool, mdns_server, state_filename, record_to=None):
self.socketpool = socketpool
self.mdns_server = mdns_server
self.avahi = None
self.record_to = record_to
if self.record_to:
self.recorded_packets = []
else:
self.recorded_packets = None
self.manager = SessionManager()

with open(state_filename, "r") as state_file:
self.nonvolatile = json.load(state_file)

for key in ["descriminator", "salt", "iteration-count"]:
if key not in self.nonvolatile:
raise RuntimeError(f"Missing key {key} in state file")

commission = "fabrics" not in self.nonvolatile

self.packet_buffer = memoryview(bytearray(1280))

# Define the UDP IP address and port
UDP_IP = "::" # Listen on all available network interfaces
self.UDP_PORT = 5540

# Create the UDP socket
self.socket = self.socketpool.socket(
self.socketpool.AF_INET6, self.socketpool.SOCK_DGRAM
)

# Bind the socket to the IP and port
self.socket.bind((UDP_IP, self.UDP_PORT))
self.socket.setblocking(False)

print(f"Listening on UDP port {self.UDP_PORT}")

if commission:
self.start_commissioning()

def start_commissioning(self):
descriminator = self.nonvolatile["descriminator"]
txt_records = {
"PI": "",
"PH": "33",
"CM": "1",
"D": str(descriminator),
"CRI": "3000",
"CRA": "4000",
"T": "1",
"VP": "65521+32769",
}
self.mdns_server.advertise_service(
"_matterc",
"_udp",
self.UDP_PORT,
txt_records=txt_records,
instance_name="FA93546B21F5FB54",
subtypes=[
f"_L{descriminator}._sub._matterc._udp",
"_CM._sub._matterc._udp",
],
)

def process_packets(self):
while True:
try:
nbytes, addr = self.socket.recvfrom_into(
self.packet_buffer, len(self.packet_buffer)
)
except BlockingIOError:
break
if nbytes == 0:
break
if self.recorded_packets is not None:
self.recorded_packets.append(
(
"receive",
time.monotonic_ns(),
addr,
binascii.b2a_base64(
self.packet_buffer[:nbytes], newline=False
).decode("utf-8"),
)
)

self.process_packet(addr, self.packet_buffer[:nbytes])

def process_packet(self, address, data):
# Print the received data and the address of the sender
# This is section 4.7.2
message = Message(data)
if message.secure_session:
# Decrypt the payload
pass
message.parse_protocol_header()
self.manager.mark_duplicate(message)

exchange = self.manager.process_exchange(message)
if exchange is None:
print(f"Dropping message {message.message_counter}")
return

print(f"Received packet from {address}:")
print(f"{data.hex(' ')}")
print(f"Message counter {message.message_counter}")
protocol_id = message.protocol_id
protocol_opcode = message.protocol_opcode

if protocol_id == ProtocolId.SECURE_CHANNEL:
if protocol_opcode == SecureProtocolOpcode.MSG_COUNTER_SYNC_REQ:
print("Received Message Counter Synchronization Request")
elif protocol_opcode == SecureProtocolOpcode.MSG_COUNTER_SYNC_RSP:
print("Received Message Counter Synchronization Response")
elif protocol_opcode == SecureProtocolOpcode.PBKDF_PARAM_REQUEST:
print("Received PBKDF Parameter Request")
# This is Section 4.14.1.2
request = PBKDFParamRequest(message.application_payload[1:-1])
if request.passcodeId == 0:
pass
# Send back failure
# response = StatusReport()
# response.GeneralCode
print(request)
response = PBKDFParamResponse()
response.initiatorRandom = request.initiatorRandom

# Generate a random number
response.responderRandom = os.urandom(32)
session_context = self.manager.new_context()
response.responderSessionId = session_context.local_session_id
session_context.peer_session_id = request.initiatorSessionId
if not request.hasPBKDFParameters:
params = Crypto_PBKDFParameterSet()
params.iterations = self.nonvolatile["iteration-count"]
params.salt = binascii.a2b_base64(self.nonvolatile["salt"])
response.pbkdf_parameters = params
print(response)

elif protocol_opcode == SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:
print("Received PBKDF Parameter Response")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE1:
print("Received PASE PAKE1")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE2:
print("Received PASE PAKE2")
elif protocol_opcode == SecureProtocolOpcode.PASE_PAKE3:
print("Received PASE PAKE3")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA1:
print("Received CASE Sigma1")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2:
print("Received CASE Sigma2")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA3:
print("Received CASE Sigma3")
elif protocol_opcode == SecureProtocolOpcode.CASE_SIGMA2_RESUME:
print("Received CASE Sigma2 Resume")
elif protocol_opcode == SecureProtocolOpcode.STATUS_REPORT:
print("Received Status Report")
elif protocol_opcode == SecureProtocolOpcode.ICD_CHECK_IN:
print("Received ICD Check-in")

def __del__(self):
if self.avahi:
self.avahi.kill()
if self.recorded_packets and self.record_to:
with open(self.record_to, "w") as record_file:
json.dump(self.recorded_packets, record_file)
Loading

0 comments on commit 6078dbe

Please sign in to comment.