Skip to content

Commit

Permalink
Encrypts and sends but isn't decryptable
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Sep 17, 2024
1 parent ba30cd9 commit 39c27d0
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 91 deletions.
142 changes: 115 additions & 27 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def send(self, message):


class SecureSessionContext:
def __init__(self, local_session_id):
def __init__(self, socket, local_session_id):
self.session_type = None
"""Records whether the session was established using CASE or PASE."""
self.session_role_initiator = False
Expand All @@ -320,13 +320,13 @@ def __init__(self, local_session_id):
"""Encrypts data in messages sent from the session establishment responder to the initiator."""
self.shared_secret = None
"""Computed during the CASE protocol execution and re-used when CASE session resumption is implemented."""
self.local_message_counter = None
self.local_message_counter = MessageCounter()
"""Secure Session Message Counter for outbound messages."""
self.message_reception_state = None
"""Provides tracking for the Secure Session Message Counter of the remote"""
self.local_fabric_index = None
"""Records the local Index for the session’s Fabric, which MAY be used to look up Fabric metadata related to the Fabric for which this session context applies."""
self.peer_node_id = None
self.peer_node_id = 0
"""Records the authenticated node ID of the remote peer, when available."""
self.resumption_id = None
"""The ID used when resuming a session between the local and remote peer."""
Expand All @@ -340,6 +340,8 @@ def __init__(self, local_session_id):
self.exchanges = {}

self._nonce = bytearray(session.CRYPTO_AEAD_NONCE_LENGTH_BYTES)
self.socket = socket
self.node_ipaddress = None

@property
def peer_active(self):
Expand Down Expand Up @@ -370,6 +372,22 @@ def decrypt_and_verify(self, message):
message.payload = decrypted_payload
return True

def send(self, message):
message.session_id = self.peer_session_id
cipher = self.r2i
if self.session_role_initiator:
cipher = self.i2r

self.session_timestamp = time.monotonic()

message.destination_node_id = self.peer_node_id
if message.message_counter is None:
message.message_counter = next(self.local_message_counter)

buf = memoryview(bytearray(1280))
nbytes = message.encode_into(buf, cipher)
self.socket.sendto(buf[:nbytes], self.node_ipaddress)


class Message:
def __init__(self):
Expand All @@ -380,8 +398,8 @@ def clear(self):
self.session_id: int = 0
self.security_flags: SecurityFlags = SecurityFlags(0)
self.message_counter: Optional[int] = None
self.source_node_id = None
self.destination_node_id = None
self.source_node_id = 0
self.destination_node_id = 0
self.secure_session: Optional[bool] = None
self.payload = None
self.duplicate: Optional[bool] = None
Expand Down Expand Up @@ -440,7 +458,7 @@ def decode(self, buffer):
self.source_node_id = struct.unpack_from("<Q", buffer, 8)[0]
offset += 8
else:
self.source_node_id = None
self.source_node_id = 0

if (self.flags >> 4) != 0:
raise RuntimeError("Incorrect version")
Expand All @@ -453,7 +471,7 @@ def decode(self, buffer):
self.payload = memoryview(buffer)[offset:]
self.duplicate = None

def encode_into(self, buffer):
def encode_into(self, buffer, cipher=None):
offset = 0
struct.pack_into(
"<BHBI",
Expand All @@ -464,11 +482,15 @@ def encode_into(self, buffer):
self.security_flags,
self.message_counter,
)
print(self.flags, self.session_id)
nonce_start = 3
nonce_end = nonce_start + 1 + 4
offset += 8
if self.source_node_id is not None:
if self.source_node_id > 0:
struct.pack_into("<Q", buffer, offset, self.source_node_id)
offset += 8
if self.destination_node_id is not None:
nonce_end += 8
if self.destination_node_id > 0:
if self.destination_node_id > 0xFFFF_FFFF_FFFF_0000:
struct.pack_into(
"<H", buffer, offset, self.destination_node_id & 0xFFFF
Expand All @@ -490,21 +512,58 @@ def encode_into(self, buffer):
if self.acknowledged_message_counter is not None:
struct.pack_into("I", buffer, offset, self.acknowledged_message_counter)
offset += 4

if cipher is not None:
unencrypted_buffer = memoryview(bytearray(1280))
unencrypted_offset = 0
else:
unencrypted_buffer = buffer
unencrypted_offset = offset

if self.application_payload is not None:
if isinstance(self.application_payload, tlv.TLVStructure):
# Wrap the structure in an anonymous tag.
buffer[offset] = 0x15
offset += 1
offset = self.application_payload.encode_into(buffer, offset)
buffer[offset] = 0x18
offset += 1
unencrypted_buffer[unencrypted_offset] = 0x15
unencrypted_offset += 1
unencrypted_offset = self.application_payload.encode_into(
unencrypted_buffer, unencrypted_offset
)
unencrypted_buffer[unencrypted_offset] = 0x18
unencrypted_offset += 1
elif isinstance(self.application_payload, StatusReport):
offset = self.application_payload.encode_into(buffer, offset)
else:
buffer[offset : offset + len(self.application_payload)] = (
self.application_payload
unencrypted_offset = self.application_payload.encode_into(
unencrypted_buffer, unencrypted_offset
)
offset += len(self.application_payload)
else:
# Skip a copy operation if we're using a separate unencrypted buffer
if unencrypted_offset == 0:
unencrypted_buffer = self.application_payload
else:
unencrypted_buffer[
unencrypted_offset : unencrypted_offset
+ len(self.application_payload)
] = self.application_payload
unencrypted_offset += len(self.application_payload)

# Encrypt the payload
if cipher is not None:
# The message may not include the source_node_id so we encode the nonce separately.
print(self.message_counter)
nonce = struct.pack(
"<BIQ", self.security_flags, self.message_counter, self.source_node_id
)
print("nonce", nonce_end - nonce_start, nonce.hex(" "))
additional = buffer[:offset]
self.payload = cipher.encrypt(
nonce, bytes(unencrypted_buffer[:unencrypted_offset]), bytes(additional)
)
print("encrypted", len(self.payload), self.payload.hex(" "))
buffer[offset : offset + len(self.payload)] = self.payload
offset += len(self.payload)
else:
offset = unencrypted_offset

print("encoded", buffer[:offset].hex(" "))
return offset

@property
Expand All @@ -514,7 +573,7 @@ def source_node_id(self):
@source_node_id.setter
def source_node_id(self, value):
self._source_node_id = value
if value is not None:
if value > 0:
self.flags |= 1 << 2
else:
self.flags &= ~(1 << 2)
Expand All @@ -528,7 +587,7 @@ def destination_node_id(self, value):
self._destination_node_id = value
# Clear the field
self.flags &= ~0x3
if value is None:
if value == 0:
pass
elif value > 0xFFFF_FFFF_FFFF_0000:
self.flags |= 2
Expand Down Expand Up @@ -708,6 +767,7 @@ def get_session(self, message):
# TODO: Get MRS for source node id and message type
else:
session_context = self.secure_session_contexts[message.session_id]
session_context.node_ipaddress = message.source_ipaddress
else:
if message.source_node_id not in self.unsecured_session_context:
self.unsecured_session_context[message.source_node_id] = (
Expand Down Expand Up @@ -774,7 +834,9 @@ def new_context(self):
self.secure_session_contexts.append(None)
session_id = self.secure_session_contexts.index(None)

self.secure_session_contexts[session_id] = SecureSessionContext(session_id)
self.secure_session_contexts[session_id] = SecureSessionContext(
self.socket, session_id
)
return self.secure_session_contexts[session_id]

def process_exchange(self, message):
Expand Down Expand Up @@ -814,7 +876,15 @@ def process_exchange(self, message):


class CircuitMatter:
def __init__(self, socketpool, mdns_server, random_source, state_filename):
def __init__(
self,
socketpool,
mdns_server,
random_source,
state_filename,
vendor_id=0xFFF1,
product_id=0,
):
self.socketpool = socketpool
self.mdns_server = mdns_server
self.random = random_source
Expand Down Expand Up @@ -851,9 +921,21 @@ def __init__(self, socketpool, mdns_server, random_source, state_filename):
self.start_commissioning()

self._endpoints = {}
self.add_cluster(0, data_model.BasicInformationCluster())
self.add_cluster(0, data_model.NetworkCommissioningCluster())
self.add_cluster(0, data_model.GeneralCommissioningCluster())
basic_info = data_model.BasicInformationCluster()
basic_info.vendor_id = vendor_id
basic_info.product_id = product_id
self.add_cluster(0, basic_info)
network_info = data_model.NetworkCommissioningCluster()
network_info.connect_max_time_seconds = 10
self.add_cluster(0, network_info)
general_commissioning = data_model.GeneralCommissioningCluster()
basic_commissioning_info = (
data_model.GeneralCommissioningCluster.BasicCommissioningInfo()
)
basic_commissioning_info.FailSafeExpiryLengthSeconds = 10
basic_commissioning_info.MaxCumulativeFailsafeSeconds = 900
general_commissioning.basic_commissioning_info = basic_commissioning_info
self.add_cluster(0, general_commissioning)

def start_commissioning(self):
descriminator = self.nonvolatile["descriminator"]
Expand Down Expand Up @@ -903,6 +985,8 @@ def get_report(self, cluster, path):
astatus = interaction_model.AttributeStatusIB()
astatus.Path = path
status = interaction_model.StatusIB()
status.Status = 0
status.ClusterStatus = 0
astatus.Status = status
report.AttributeStatus = astatus
report.AttributeData = cluster.get_attribute_data(path)
Expand Down Expand Up @@ -1093,7 +1177,11 @@ def process_packet(self, address, data):
print(f"Cluster 0x{path.Cluster:02x} not found")
response = interaction_model.ReportDataMessage()
response.AttributeReports = attribute_reports
print(response)
exchange.send(
ProtocolId.INTERACTION_MODEL,
InteractionModelOpcode.REPORT_DATA,
response,
)
if protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
Expand Down
Loading

0 comments on commit 39c27d0

Please sign in to comment.