Skip to content

Commit

Permalink
WIP: Decrypt first message
Browse files Browse the repository at this point in the history
Gotta figure out how to decode lists and arrays.
  • Loading branch information
tannewt committed Jul 29, 2024
1 parent 70d2e65 commit 7e235cd
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 9 deletions.
139 changes: 131 additions & 8 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from ecdsa.ellipticcurve import AbstractPoint, Point, PointJacobi
from ecdsa.curves import NIST256p

import cryptography
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

from typing import Optional

from . import tlv
Expand Down Expand Up @@ -121,8 +124,22 @@ class SecureProtocolOpcode(enum.IntEnum):
"""The Check-in message notifies a client that the ICD is available for communication."""


class InteractionModelOpcode(enum.IntEnum):
STATUS_RESPONSE = 0x01
READ_REQUEST = 0x02
SUBSCRIBE_REQUEST = 0x03
SUBSCRIBE_RESPONSE = 0x04
REPORT_DATA = 0x05
WRITE_REQUEST = 0x06
WRITE_RESPONSE = 0x07
INVOKE_REQUEST = 0x08
INVOKE_RESPONSE = 0x09
TIMED_REQUEST = 0x0A


PROTOCOL_OPCODES = {
ProtocolId.SECURE_CHANNEL: SecureProtocolOpcode,
ProtocolId.INTERACTION_MODEL: InteractionModelOpcode,
}


Expand Down Expand Up @@ -228,6 +245,56 @@ class PAKE3(tlv.TLVStructure):
cA = tlv.OctetStringMember(1, CRYPTO_HASH_LEN_BYTES)


class AttributePathIB(tlv.TLVList):
"""Section 10.6.2"""

EnableTagCompression = tlv.BoolMember(0)
Node = tlv.IntMember(1, signed=False, octets=8)
Endpoint = tlv.IntMember(2, signed=False, octets=2)
Cluster = tlv.IntMember(3, signed=False, octets=4)
Attribute = tlv.IntMember(4, signed=False, octets=4)
ListIndex = tlv.IntMember(5, signed=False, octets=2, nullable=True)
WildcardPathFlags = tlv.IntMember(6, signed=False, octets=4)


class EventPathIB(tlv.TLVList):
"""Section 10.6.8"""

Node = tlv.IntMember(0, signed=False, octets=8)
Endpoint = tlv.IntMember(1, signed=False, octets=2)
Cluster = tlv.IntMember(2, signed=False, octets=4)
Event = tlv.IntMember(3, signed=False, octets=4)
IsUrgent = tlv.BoolMember(4)


class EventFilterIB(tlv.TLVStructure):
"""Section 10.6.6"""

Node = tlv.IntMember(0, signed=False, octets=8)
EventMinimumInterval = tlv.IntMember(1, signed=False, octets=8)


class ClusterPathIB(tlv.TLVList):
Node = tlv.IntMember(0, signed=False, octets=8)
Endpoint = tlv.IntMember(1, signed=False, octets=2)
Cluster = tlv.IntMember(2, signed=False, octets=4)


class DataVersionFilterIB(tlv.TLVStructure):
Path = tlv.ContainerMember(0, ClusterPathIB)
DataVersion = tlv.IntMember(1, signed=False, octets=4)


class ReadRequestMessage(tlv.TLVStructure):
FabricFiltered = tlv.BoolMember(3)

def __init__(self):
self.AttributeRequests = tlv.ArrayMember(0, AttributePathIB)
self.EventRequests = tlv.ArrayMember(1, EventPathIB)
self.EventFilters = tlv.ArrayMember(2, EventFilterIB)
self.DataVersionFilters = tlv.ArrayMember(4, DataVersionFilterIB)


class MessageReceptionState:
def __init__(self, starting_value, rollover=True, encrypted=False):
"""Implements 4.6.5.1"""
Expand Down Expand Up @@ -404,7 +471,7 @@ class SecureSessionContext:
def __init__(self, local_session_id):
self.session_type = None
"""Records whether the session was established using CASE or PASE."""
self.session_role = None
self.session_role_initiator = False
"""Records whether the node is the session initiator or responder."""
self.local_session_id = local_session_id
"""Individually selected by each participant in secure unicast communication during session establishment and used as a unique identifier to recover encryption keys, authenticate incoming messages and associate them to existing sessions."""
Expand Down Expand Up @@ -435,10 +502,37 @@ def __init__(self, local_session_id):
self.session_active_threshold = None
self.exchanges = {}

self._nonce = bytearray(CRYPTO_AEAD_NONCE_LENGTH_BYTES)

@property
def peer_active(self):
return (time.monotonic() - self.active_timestamp) < self.session_active_interval

def decrypt_and_verify(self, message):
cipher = self.i2r
if self.session_role_initiator:
cipher = self.r2i
try:
source_node_id = 0 # for secure unicast messages
# TODO: Support group messages
struct.pack_into(
"<BIQ",
self._nonce,
0,
message.security_flags,
message.message_counter,
source_node_id,
)
decrypted_payload = cipher.decrypt(
self._nonce, bytes(message.payload), bytes(message.header)
)
except cryptography.exceptions.InvalidTag:
return False

message.decrypted = True
message.payload = decrypted_payload
return True


class Message:
def __init__(self):
Expand All @@ -450,6 +544,7 @@ def clear(self):
self.security_flags: SecurityFlags = SecurityFlags(0)
self.message_counter: Optional[int] = None
self.source_node_id = None
self.destination_node_id = None
self.secure_session: Optional[bool] = None
self.payload = None
self.duplicate: Optional[bool] = None
Expand All @@ -467,6 +562,8 @@ def clear(self):

self.source_ipaddress = None

self.header = None

def parse_protocol_header(self):
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
struct.unpack_from("<BBH", self.payload)
Expand Down Expand Up @@ -513,12 +610,10 @@ def decode(self, buffer):
self.secure_session = not (
not (self.security_flags & SecurityFlags.GROUP) and self.session_id == 0
)
self.decrypted = not self.secure_session

if not self.secure_session:
self.payload = memoryview(buffer)[offset:]
else:
self.payload = None

self.header = memoryview(buffer)[:offset]
self.payload = memoryview(buffer)[offset:]
self.duplicate = None

def encode_into(self, buffer):
Expand Down Expand Up @@ -596,7 +691,9 @@ def destination_node_id(self, value):
self._destination_node_id = value
# Clear the field
self.flags &= ~0x3
if value > 0xFFFF_FFFF_FFFF_0000:
if value is None:
pass
elif value > 0xFFFF_FFFF_FFFF_0000:
self.flags |= 2
elif value > 0:
self.flags |= 1
Expand Down Expand Up @@ -1110,7 +1207,15 @@ def process_packet(self, address, data):
message.source_ipaddress = address
if message.secure_session:
# Decrypt the payload
pass
print("decrypt message", message.session_id)
secure_session_context = self.manager.secure_session_contexts[
message.session_id
]
print(secure_session_context)
print(message)
print(message.payload.hex(" "))
ok = secure_session_context.decrypt_and_verify(message)
print("decrypt ok?", ok)
message.parse_protocol_header()
self.manager.mark_duplicate(message)

Expand Down Expand Up @@ -1254,10 +1359,18 @@ def process_packet(self, address, data):
secure_session_context.i2r_key = keys[
:CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES
]
secure_session_context.i2r = AESCCM(
secure_session_context.i2r_key,
tag_length=CRYPTO_AEAD_MIC_LENGTH_BYTES,
)
secure_session_context.r2i_key = keys[
CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES : 2
* CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES
]
secure_session_context.r2i = AESCCM(
secure_session_context.r2i_key,
tag_length=CRYPTO_AEAD_MIC_LENGTH_BYTES,
)
secure_session_context.attestation_challenge = keys[
2 * CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES :
]
Expand All @@ -1277,6 +1390,16 @@ def process_packet(self, address, data):
print(report)
elif protocol_opcode == SecureProtocolOpcode.ICD_CHECK_IN:
print("Received ICD Check-in")
elif message.protocol_id == ProtocolId.INTERACTION_MODEL:
print(message)
if protocol_opcode == InteractionModelOpcode.READ_REQUEST:
print("Received Read Request")
read_request = ReadRequestMessage(message.application_payload[1:-1])
print(read_request)
if protocol_opcode == InteractionModelOpcode.INVOKE_REQUEST:
print("Received Invoke Request")
elif protocol_opcode == InteractionModelOpcode.INVOKE_RESPONSE:
print("Received Invoke Response")

def __del__(self):
if self.recorded_packets and self.record_to:
Expand Down
29 changes: 29 additions & 0 deletions circuitmatter/tlv.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,32 @@ def encode_value_into(self, value, buffer: bytearray, offset: int) -> int:
offset = value.encode_into(buffer, offset)
buffer[offset] = ElementType.END_OF_CONTAINER
return offset + 1


class ArrayMember(Member[_TLVStruct, _OPT, _NULLABLE]):
def __init__(
self,
tag,
substruct_class: Type[_TLVStruct],
*,
optional: _OPT = False,
nullable: _NULLABLE = False,
**kwargs,
):
self.substruct_class = substruct_class
self.max_value_length = 1280
super().__init__(tag, optional=optional, nullable=nullable, **kwargs)

def decode(self, buffer, length, offset=0):
return self.substruct_class(buffer[offset : offset + length])

def _print(self, value):
return str(value)

def encode_element_type(self, value):
return ElementType.STRUCTURE

def encode_value_into(self, value, buffer: bytearray, offset: int) -> int:
offset = value.encode_into(buffer, offset)
buffer[offset] = ElementType.END_OF_CONTAINER
return offset + 1
2 changes: 1 addition & 1 deletion test_data/recorded_packets.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[["receive", 19144358180711, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 45718, 0, 0], "BAAAAGpkOgQwkwDA2XDoNAUgd1gAABUwASAqDGF6b0TLjYSHpiEP8ULWWJUbuZ1RXNOefpJ4KFje9SUCPPQkAwAoBDUFJQH0ASUCLAElA6APJAQRJAULJgYAAAMBJAcBGBg="], ["receive", 19144364732600, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 45718, 0, 0], "BAAAAGtkOgQwkwDA2XDoNAUid1gAABUwAUEEf8gHR/MR/oQGRHJRxZirzEa42qHq4qH8aIioXnNkXrWr8d2LZRb3nIYMKWt+QDmTvIp+9L79mr68C+OrH3WiYBg="], ["receive", 19144379606119, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 45718, 0, 0], "BAAAAGxkOgQwkwDA2XDoNAUkd1gAABUwASC8iRG3Ehkfyze+iIPOsRNSstOwlc67pZmCj1WLxvY0kBg="], ["receive", 19144379820804, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 45718, 0, 0], "AAEAAJslbgoieWTcgmp1jDKhidqXMqcxlhS1MBWeoTOdBWCmqQMUUKBjJiUx3DB2FWrCLJbwXxHwYmtYsReJ9DNcXrt7d46v984TaDqT+tt+ZjdRxOVJDP/wDlTrHWTZ/5ybvBXRBzZrtSkRC0K47EyiRQ/KXpyzAY1w66SPk/2IGjI="]]
[["receive", 22053600090440, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "BAAAALGoxwBQJ8u6A08zlwUgUU0AABUwASAsETJy1MZI35zvWSjBtjdwFK1FwzEpKJCepZW/hKLqLCUCzAUkAwAoBDUFJQH0ASUCLAElA6APJAQRJAULJgYAAAMBJAcBGBg="], ["receive", 22053606380662, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "BAAAALKoxwBQJ8u6A08zlwUiUU0AABUwAUEELPUQAb7V2HvHRBNMyJiVQrGlHjw9FIz+41h5wxnvYuZhg2wGT+hqfyP7RVOH4QcdAAFlIAdemfr8c3bRVKoIkRg="], ["receive", 22053621334402, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "BAAAALOoxwBQJ8u6A08zlwUkUU0AABUwASDjA9fk82ToT3rKrxDQMOK+7Pp1APSodFRzvO6UuIJzOhg="], ["receive", 22053621532375, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="], ["receive", 22054010148564, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="], ["receive", 22054413609739, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="], ["receive", 22054982238204, ["fd98:bbab:bd61:8040:642:1aff:fe0c:9f2a", 44432, 0, 0], "AAEAAALzMgF457821uGoHhEhDzUhrzWXXEH8HBRfHNlrizMRs0CHkhx7+odFx/bY8o8JFB1FPlM8SFCCdzcEVtwm03ncTQIYw/SYTFXpkkY/wJNjEndAfBVm1eJ1wmE6QkWVlAOsXg4/Ix4aCOR8PE6xx2E17/pSaHstrRh4rw+V73o="]]

0 comments on commit 7e235cd

Please sign in to comment.