Skip to content

Commit 9d2687b

Browse files
committed
Add message reception state for counter verification
1 parent 6f28636 commit 9d2687b

File tree

3 files changed

+423
-67
lines changed

3 files changed

+423
-67
lines changed

circuitmatter/__init__.py

Lines changed: 271 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Pure Python implementation of the Matter IOT protocol."""
22

33
import enum
4+
import pathlib
5+
import json
6+
import struct
7+
import time
48

59
from . import tlv
610

@@ -21,7 +25,10 @@
2125

2226
# print(f"Listening on UDP port {UDP_PORT}")
2327

24-
unsecured_session_context = {}
28+
# Section 4.11.2
29+
MSG_COUNTER_WINDOW_SIZE = 32
30+
MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
31+
MSG_COUNTER_SYNC_TIMEOUT_MS = 400
2532

2633

2734
class ProtocolId(enum.Enum):
@@ -36,6 +43,8 @@ class SecurityFlags(enum.Flag):
3643
P = 1 << 7
3744
C = 1 << 6
3845
MX = 1 << 5
46+
# This is actually 2 bits but the top bit is reserved and always zero.
47+
GROUP = 1 << 0
3948

4049

4150
class ExchangeFlags(enum.Flag):
@@ -173,3 +182,264 @@ class PBKDFParamResponse(tlv.TLVStructure):
173182
responderSessionId = tlv.NumberMember(3, "<H")
174183
pbkdf_parameters = tlv.StructMember(4, Crypto_PBKDFParameterSet)
175184
responderSessionParams = tlv.StructMember(5, SessionParameterStruct, optional=True)
185+
186+
187+
class MessageReceptionState:
188+
def __init__(self, starting_value, rollover=True, encrypted=False):
189+
"""Implements 4.6.5.1"""
190+
self.message_counter = starting_value
191+
self.window_bitmap = (1 << MSG_COUNTER_WINDOW_SIZE) - 1
192+
self.mask = self.window_bitmap
193+
self.encrypted = encrypted
194+
self.rollover = rollover
195+
196+
def process_counter(self, counter) -> bool:
197+
"""Returns True if the counter number is a duplicate"""
198+
# Process the current window first. Behavior outside the window varies.
199+
if counter == self.message_counter:
200+
return True
201+
if self.message_counter <= MSG_COUNTER_WINDOW_SIZE < counter:
202+
# Window wraps
203+
bit_position = 0xFFFFFFFF - counter + self.message_counter
204+
else:
205+
bit_position = self.message_counter - counter - 1
206+
if 0 <= bit_position < MSG_COUNTER_WINDOW_SIZE:
207+
if self.window_bitmap & (1 << bit_position) != 0:
208+
# This is a duplicate message
209+
return True
210+
self.window_bitmap |= 1 << bit_position
211+
return False
212+
213+
new_start = (self.message_counter + 1) & self.mask # Inclusive
214+
new_end = (
215+
self.message_counter - MSG_COUNTER_WINDOW_SIZE
216+
) & self.mask # Exclusive
217+
if not self.rollover:
218+
new_end = (1 << MSG_COUNTER_WINDOW_SIZE) - 1
219+
elif self.encrypted:
220+
new_end = (
221+
self.message_counter + (1 << (MSG_COUNTER_WINDOW_SIZE - 1))
222+
) & self.mask
223+
224+
if new_start <= new_end:
225+
if not (new_start <= counter < new_end):
226+
return True
227+
else:
228+
if not (counter < new_end or new_start <= counter):
229+
return True
230+
231+
# This is a new message
232+
shift = counter - self.message_counter
233+
if counter < self.message_counter:
234+
shift += 0x100000000
235+
if shift > MSG_COUNTER_WINDOW_SIZE:
236+
self.window_bitmap = 0
237+
else:
238+
new_bitmap = (self.window_bitmap << shift) & self.mask
239+
self.window_bitmap = new_bitmap
240+
if 1 < shift < MSG_COUNTER_WINDOW_SIZE:
241+
self.window_bitmap |= 1 << (shift - 1)
242+
self.message_counter = counter
243+
return False
244+
245+
246+
class UnsecuredSessionContext:
247+
def __init__(self, initiator, ephemeral_initiator_node_id):
248+
self.initiator = initiator
249+
self.ephemeral_initiator_node_id = ephemeral_initiator_node_id
250+
self.message_reception_state = None
251+
252+
253+
class SecureSessionContext:
254+
def __init__(self, local_session_id):
255+
self.session_type = None
256+
"""Records whether the session was established using CASE or PASE."""
257+
self.session_role = None
258+
"""Records whether the node is the session initiator or responder."""
259+
self.local_session_id = local_session_id
260+
"""Individually selected by each participant in secure unicast communication during session establishment and used as a unique identifier to recover encryption keys, authenticate incoming messages and associate them to existing sessions."""
261+
self.peer_session_id = None
262+
"""Assigned by the peer during session establishment"""
263+
self.i2r_key = None
264+
"""Encrypts data in messages sent from the initiator of session establishment to the responder."""
265+
self.r2i_key = None
266+
"""Encrypts data in messages sent from the session establishment responder to the initiator."""
267+
self.shared_secret = None
268+
"""Computed during the CASE protocol execution and re-used when CASE session resumption is implemented."""
269+
self.local_message_counter = None
270+
"""Secure Session Message Counter for outbound messages."""
271+
self.message_reception_state = None
272+
"""Provides tracking for the Secure Session Message Counter of the remote"""
273+
self.local_fabric_index = None
274+
"""Records the local Index for the session’s Fabric, which MAY be used to look up Fabric metadata related to the Fabric for which this session context applies."""
275+
self.peer_node_id = None
276+
"""Records the authenticated node ID of the remote peer, when available."""
277+
self.resumption_id = None
278+
"""The ID used when resuming a session between the local and remote peer."""
279+
self.session_timestamp = None
280+
"""A timestamp indicating the time at which the last message was sent or received. This timestamp SHALL be initialized with the time the session was created."""
281+
self.active_timestamp = None
282+
"""A timestamp indicating the time at which the last message was received. This timestamp SHALL be initialized with the time the session was created."""
283+
self.session_idle_interval = None
284+
self.session_active_interval = None
285+
self.session_active_threshold = None
286+
287+
@property
288+
def peer_active(self):
289+
return (time.monotonic() - self.active_timestamp) < self.session_active_interval
290+
291+
292+
class Message:
293+
def __init__(self, buffer):
294+
self.buffer = buffer
295+
self.flags, self.session_id, self.security_flags, self.message_counter = (
296+
struct.unpack_from("<BHBI", buffer)
297+
)
298+
offset = 8
299+
self.source_node_id = None
300+
if self.flags & (1 << 2):
301+
self.source_node_id = struct.unpack_from("<Q", buffer, 8)[0]
302+
offset += 8
303+
304+
if (self.flags >> 4) != 0:
305+
raise RuntimeError("Incorrect version")
306+
self.secure_session = self.security_flags & 0x3 != 0 or self.session_id != 0
307+
308+
if not self.secure_session:
309+
self.payload = memoryview(buffer)[offset:]
310+
311+
context = UnsecuredSessionContext(False, self.source_node_id)
312+
self.unsecured_session_context[self.source_node_id] = context
313+
else:
314+
self.payload = None
315+
316+
def _parse_protocol_header(self):
317+
self.exchange_flags, self.protocol_opcode, self.exchange_id = (
318+
struct.unpack_from("<BBH", self.payload)
319+
)
320+
321+
self.exchange_flags = ExchangeFlags(self.exchange_flags)
322+
decrypted_offset = 4
323+
self.protocol_vendor_id = 0
324+
if self.exchange_flags & ExchangeFlags.V:
325+
self.protocol_vendor_id = struct.unpack_from(
326+
"<H", self.payload, decrypted_offset
327+
)[0]
328+
decrypted_offset += 2
329+
protocol_id = struct.unpack_from("<H", self.payload, decrypted_offset)[0]
330+
decrypted_offset += 2
331+
self.protocol_id = ProtocolId(protocol_id)
332+
self.protocol_opcode = PROTOCOL_OPCODES[protocol_id](self.protocol_opcode)
333+
334+
self.acknowledged_message_counter = None
335+
if self.exchange_flags & ExchangeFlags.A:
336+
self.acknowledged_message_counter = struct.unpack_from(
337+
"<I", self.payload, decrypted_offset
338+
)[0]
339+
decrypted_offset += 4
340+
341+
def reply(self, payload, protocol_id=None, protocol_opcode=None) -> memoryview:
342+
reply = bytearray(1280)
343+
offset = 0
344+
345+
# struct.pack_into(
346+
# "<BHBI", reply, offset, flags, session_id, security_flags, message_counter
347+
# )
348+
# offset += 8
349+
return memoryview(reply)[:offset]
350+
351+
352+
class SessionManager:
353+
def __init__(self):
354+
persist_path = pathlib.Path("counters.json")
355+
if persist_path.exists():
356+
self.nonvolatile = json.loads(persist_path.read_text())
357+
else:
358+
self.nonvolatile = {}
359+
self.nonvolatile["unencrypted_message_counter"] = 0
360+
self.nonvolatile["group_encrypted_data_message_counter"] = 0
361+
self.nonvolatile["group_encrypted_control_message_counter"] = 0
362+
self.unencrypted_message_counter = self.nonvolatile[
363+
"unencrypted_message_counter"
364+
]
365+
self.group_encrypted_data_message_counter = self.nonvolatile[
366+
"group_encrypted_data_message_counter"
367+
]
368+
self.group_encrypted_control_message_counter = self.nonvolatile[
369+
"group_encrypted_control_message_counter"
370+
]
371+
self.check_in_counter = 0
372+
self.unsecured_session_context = {}
373+
self.secure_session_contexts = ["reserved"]
374+
375+
def _increment(self, value):
376+
return (value + 1) % 0xFFFFFFFF
377+
378+
def counter_ok(self, message):
379+
"""Implements 4.6.7"""
380+
if message.secure_session:
381+
if message.security_flags & SecurityFlags.GROUP:
382+
if message.source_node_id is None:
383+
return False
384+
# TODO: Get MRS for source node id and message type
385+
else:
386+
session_context = self.secure_session_contexts[message.session_id]
387+
else:
388+
if message.source_node_id not in self.unsecured_session_context:
389+
self.unsecured_session_context[message.source_node_id] = (
390+
UnsecuredSessionContext(
391+
initiator=False,
392+
ephemeral_initiator_node_id=message.source_node_id,
393+
)
394+
)
395+
session_context = self.unsecured_session_context[message.source_node_id]
396+
397+
if session_context.message_reception_state is None:
398+
session_context.message_reception_state = MessageReceptionState(
399+
message.message_counter,
400+
rollover=False,
401+
encrypted=message.secure_session,
402+
)
403+
return True
404+
405+
return session_context.message_reception_state.process_counter(
406+
message.message_counter
407+
)
408+
409+
def next_message_counter(self, message):
410+
"""Implements 4.6.6"""
411+
if not message.secure_session:
412+
value = self.unencrypted_message_counter
413+
self.unencrypted_message_counter = self._increment(
414+
self.unencrypted_message_counter
415+
)
416+
return value
417+
elif message.security_flags & SecurityFlags.GROUP:
418+
if message.security_flags & SecurityFlags.C:
419+
value = self.group_encrypted_control_message_counter
420+
self.group_encrypted_control_message_counter = self._increment(
421+
self.group_encrypted_control_message_counter
422+
)
423+
return value
424+
else:
425+
value = self.group_encrypted_data_message_counter
426+
self.group_encrypted_data_message_counter = self._increment(
427+
self.group_encrypted_data_message_counter
428+
)
429+
return value
430+
session = self.secure_session_contexts[message.session_id]
431+
value = session.local_message_counter
432+
next_value = self._increment(value)
433+
session.local_message_counter = next_value
434+
if next_value == 0:
435+
# TODO expire the encryption key
436+
raise NotImplementedError("Expire the encryption key 4.6.6")
437+
return next_value
438+
439+
def new_context(self):
440+
if None not in self.secure_session_contexts:
441+
self.secure_session_contexts.append(None)
442+
session_id = self.secure_session_contexts.index(None)
443+
444+
self.secure_session_contexts[session_id] = SecureSessionContext(session_id)
445+
return self.secure_session_contexts[session_id]

0 commit comments

Comments
 (0)