Skip to content

Commit

Permalink
Work on sending a message
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Jul 19, 2024
1 parent 6078dbe commit fecabb3
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 62 deletions.
232 changes: 171 additions & 61 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import struct
import time

from typing import Optional

from . import tlv

__version__ = "0.0.0"
Expand All @@ -17,32 +19,51 @@
MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
MSG_COUNTER_SYNC_TIMEOUT_MS = 400

# Section 4.12.8
MRP_MAX_TRANSMISSIONS = 5
"""The maximum number of transmission attempts for a given reliable message. The sender MAY choose this value as it sees fit."""

MRP_BACKOFF_BASE = 1.6
"""The base number for the exponential backoff equation."""

MRP_BACKOFF_JITTER = 0.25
"""The scaler for random jitter in the backoff equation."""

MRP_BACKOFF_MARGIN = 1.1
"""The scaler margin increase to backoff over the peer idle interval."""

MRP_BACKOFF_THRESHOLD = 1
"""The number of retransmissions before transitioning from linear to exponential backoff."""

class ProtocolId(enum.Enum):
MRP_STANDALONE_ACK_TIMEOUT_MS = 200
"""Amount of time to wait for an opportunity to piggyback an acknowledgement on an outbound message before falling back to sending a standalone acknowledgement."""


class ProtocolId(enum.IntEnum):
SECURE_CHANNEL = 0
INTERACTION_MODEL = 1
BDX = 2
USER_DIRECTED_COMMISSIONING = 3
FOR_TESTING = 4


class SecurityFlags(enum.Flag):
class SecurityFlags(enum.IntFlag):
P = 1 << 7
C = 1 << 6
MX = 1 << 5
# This is actually 2 bits but the top bit is reserved and always zero.
GROUP = 1 << 0


class ExchangeFlags(enum.Flag):
class ExchangeFlags(enum.IntFlag):
V = 1 << 4
SX = 1 << 3
R = 1 << 2
A = 1 << 1
I = 1 << 0 # noqa: E741


class SecureProtocolOpcode(enum.Enum):
class SecureProtocolOpcode(enum.IntEnum):
MSG_COUNTER_SYNC_REQ = 0x00
"""The Message Counter Synchronization Request message queries the current message counter from a peer to bootstrap replay protection."""

Expand Down Expand Up @@ -230,55 +251,108 @@ def process_counter(self, counter) -> bool:
return False


class MessageCounter:
def __init__(self, starting_value=None):
if starting_value is None:
starting_value = os.urandom(4)
starting_value = struct.unpack("<I", starting_value)[0]
starting_value >>= 4
starting_value += 1
self.value = starting_value

def __next__(self):
self.value = (self.value + 1) % 0xFFFFFFFF
return self.value


class Exchange:
def __init__(self, initiator: bool, exchange_id: int, protocols):
def __init__(self, session, initiator: bool, exchange_id: int, protocols):
self.initiator = initiator
self.exchange_id = exchange_id
self.protocols = protocols
self.session = session

self.pending_acknowledgement = None
"""Message number that is waiting for an ack from us"""
self.send_standalone_time = None

self.next_retransmission_time = None
"""When to next resend the message that hasn't been acked"""
self.pending_retransmission = None

def send(self, message):
pass
"""Message that we've attempted to send but hasn't been acked"""

def send(self, protocol_id, protocol_opcode, application_payload=None):
message = Message()
message.exchange_flags = ExchangeFlags(0)
if self.initiator:
message.exchange_flags |= ExchangeFlags.I
if self.pending_acknowledgement is not None:
message.exchange_flags |= ExchangeFlags.A
self.send_standalone_time = None
self.pending_acknowledgement = None
message.protocol_id = protocol_id
message.protocol_opcode = protocol_opcode
message.exchange_id = self.exchange_id
message.application_payload = application_payload
self.session.send(message)

def send_standalone(self):
self.send(
ProtocolId.SECURE_CHANNEL, SecureProtocolOpcode.MRP_STANDALONE_ACK, None
)

def receive(self, message) -> bool:
"""Process the message and return if the packet should be dropped."""
if message.protocol_id not in self.protocols:
# Drop messages that don't match the protocols we're waiting for.
return True

# Section 4.10.5.2.1
# Section 4.12.5.2.1
if message.exchange_flags & ExchangeFlags.A:
if message.acknowledged_message_counter is None:
# Drop messages that are missing an acknowledgement counter.
return True
if self.pending_acknowledgement is None:
# Drop messages that are not waiting for an acknowledgement.
return True
if message.acknowledged_message_counter != self.pending_acknowledgement:
# Drop messages that have the wrong acknowledgement counter.
return True
self.pending_acknowledgement = None
self.pending_retransmission = None
self.next_retransmission_time = None

# Section 4.10.5.2.2
# if message.exchange_flags & ExchangeFlags.R:
# if message
# Section 4.12.5.2.2
# Incoming packets that are marked Reliable.
if message.exchange_flags & ExchangeFlags.R:
if message.duplicate:
# Send a standalone acknowledgement.
return True
if self.pending_acknowledgement is not None:
# Send a standalone acknowledgement with the message counter we're about to overwrite.
pass
self.pending_acknowledgement = message.message_counter
self.send_standalone_time = (
time.monotonic() + MRP_STANDALONE_ACK_TIMEOUT_MS / 1000
)

if message.duplicate:
return True
return False


class UnsecuredSessionContext:
def __init__(self, initiator, ephemeral_initiator_node_id):
def __init__(self, message_counter, initiator, ephemeral_initiator_node_id):
self.initiator = initiator
self.ephemeral_initiator_node_id = ephemeral_initiator_node_id
self.message_reception_state = None
self.message_counter = message_counter
self.exchanges = {}

def send(self, message):
message.destination_node_id = self.ephemeral_initiator_node_id
if message.message_counter is None:
message.message_counter = next(self.message_counter)
buf = memoryview(bytearray(1280))
nbytes = message.encode_into(buf)
print(nbytes, buf[:nbytes].hex(" "))


class SecureSessionContext:
def __init__(self, local_session_id):
Expand Down Expand Up @@ -321,30 +395,29 @@ def peer_active(self):


class Message:
def __init__(self, buffer):
self.buffer = buffer
self.flags, self.session_id, self.security_flags, self.message_counter = (
struct.unpack_from("<BHBI", buffer)
)
self.security_flags = SecurityFlags(self.security_flags)
offset = 8
def __init__(self):
self.clear()

def clear(self):
self.flags: int = 0
self.session_id: int = 0
self.security_flags: SecurityFlags = SecurityFlags(0)
self.message_counter: int = 0
self.source_node_id = None
if self.flags & (1 << 2):
self.source_node_id = struct.unpack_from("<Q", buffer, 8)[0]
offset += 8
self.secure_session: Optional[bool] = None
self.payload = None
self.duplicate: Optional[bool] = None

if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
self.secure_session = not (
not (self.security_flags & SecurityFlags.GROUP) and self.session_id == 0
)
# Filled in after the message payload is decrypted.
self.exchange_flags: ExchangeFlags = ExchangeFlags(0)
self.exchange_id: Optional[int] = None

if not self.secure_session:
self.payload = memoryview(buffer)[offset:]
else:
self.payload = None
self.protocol_vendor_id = 0
self.protocol_id = ProtocolId(0)
self.protocol_opcode: Optional[int] = None

self.duplicate = None
self.acknowledged_message_counter = None
self.application_payload = None

def parse_protocol_header(self):
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
Expand Down Expand Up @@ -373,15 +446,48 @@ def parse_protocol_header(self):

self.application_payload = self.payload[decrypted_offset:]

def reply(self, payload, protocol_id=None, protocol_opcode=None) -> memoryview:
reply = bytearray(1280)
offset = 0
def decode(self, buffer):
self.clear()
self.buffer = buffer
self.flags, self.session_id, self.security_flags, self.message_counter = (
struct.unpack_from("<BHBI", buffer)
)
self.security_flags = SecurityFlags(self.security_flags)
offset = 8
self.source_node_id = None
if self.flags & (1 << 2):
self.source_node_id = struct.unpack_from("<Q", buffer, 8)[0]
offset += 8

# struct.pack_into(
# "<BHBI", reply, offset, flags, session_id, security_flags, message_counter
# )
# offset += 8
return memoryview(reply)[:offset]
if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
self.secure_session = not (
not (self.security_flags & SecurityFlags.GROUP) and self.session_id == 0
)

if not self.secure_session:
self.payload = memoryview(buffer)[offset:]
else:
self.payload = None

self.duplicate = None

def encode_into(self, buffer):
offset = 0
struct.pack_into(
"BBHH",
buffer,
offset,
self.exchange_flags,
self.protocol_opcode,
self.exchange_id,
self.protocol_id,
)
offset += 6
if self.acknowledged_message_counter is not None:
struct.pack_into("I", buffer, offset, self.acknowledged_message_counter)
offset += struct.calcsize(4)
return offset


class SessionManager:
Expand All @@ -394,15 +500,15 @@ def __init__(self):
self.nonvolatile["unencrypted_message_counter"] = 0
self.nonvolatile["group_encrypted_data_message_counter"] = 0
self.nonvolatile["group_encrypted_control_message_counter"] = 0
self.unencrypted_message_counter = self.nonvolatile[
"unencrypted_message_counter"
]
self.group_encrypted_data_message_counter = self.nonvolatile[
"group_encrypted_data_message_counter"
]
self.group_encrypted_control_message_counter = self.nonvolatile[
"group_encrypted_control_message_counter"
]
self.unencrypted_message_counter = MessageCounter(
self.nonvolatile["unencrypted_message_counter"]
)
self.group_encrypted_data_message_counter = MessageCounter(
self.nonvolatile["group_encrypted_data_message_counter"]
)
self.group_encrypted_control_message_counter = MessageCounter(
self.nonvolatile["group_encrypted_control_message_counter"]
)
self.check_in_counter = 0
self.unsecured_session_context = {}
self.secure_session_contexts = ["reserved"]
Expand All @@ -422,6 +528,7 @@ def get_session(self, message):
if message.source_node_id not in self.unsecured_session_context:
self.unsecured_session_context[message.source_node_id] = (
UnsecuredSessionContext(
self.unencrypted_message_counter,
initiator=False,
ephemeral_initiator_node_id=message.source_node_id,
)
Expand Down Expand Up @@ -501,7 +608,7 @@ def process_exchange(self, message):
initiator = message.exchange_flags & ExchangeFlags.I
if initiator and not message.duplicate:
session.exchanges[message.exchange_id] = Exchange(
not initiator, message.exchange_id, [message.protocol_id]
session, not initiator, message.exchange_id, [message.protocol_id]
)
# Drop because the message isn't from an initiator.
elif message.exchange_flags & ExchangeFlags.R:
Expand All @@ -524,7 +631,6 @@ class CircuitMatter:
def __init__(self, socketpool, mdns_server, state_filename, record_to=None):
self.socketpool = socketpool
self.mdns_server = mdns_server
self.avahi = None
self.record_to = record_to
if self.record_to:
self.recorded_packets = []
Expand Down Expand Up @@ -573,12 +679,13 @@ def start_commissioning(self):
"T": "1",
"VP": "65521+32769",
}
instance_name = os.urandom(8).hex().upper()
self.mdns_server.advertise_service(
"_matterc",
"_udp",
self.UDP_PORT,
txt_records=txt_records,
instance_name="FA93546B21F5FB54",
instance_name=instance_name,
subtypes=[
f"_L{descriminator}._sub._matterc._udp",
"_CM._sub._matterc._udp",
Expand Down Expand Up @@ -612,7 +719,8 @@ def process_packets(self):
def process_packet(self, address, data):
# Print the received data and the address of the sender
# This is section 4.7.2
message = Message(data)
message = Message()
message.decode(data)
if message.secure_session:
# Decrypt the payload
pass
Expand Down Expand Up @@ -658,7 +766,11 @@ def process_packet(self, address, data):
params.iterations = self.nonvolatile["iteration-count"]
params.salt = binascii.a2b_base64(self.nonvolatile["salt"])
response.pbkdf_parameters = params
print(response)
exchange.send(
ProtocolId.SECURE_CHANNEL,
SecureProtocolOpcode.PBKDF_PARAM_RESPONSE,
response.encode(),
)

elif protocol_opcode == SecureProtocolOpcode.PBKDF_PARAM_RESPONSE:
print("Received PBKDF Parameter Response")
Expand All @@ -682,8 +794,6 @@ def process_packet(self, address, data):
print("Received ICD Check-in")

def __del__(self):
if self.avahi:
self.avahi.kill()
if self.recorded_packets and self.record_to:
with open(self.record_to, "w") as record_file:
json.dump(self.recorded_packets, record_file)
Loading

0 comments on commit fecabb3

Please sign in to comment.