1
1
"""Pure Python implementation of the Matter IOT protocol."""
2
2
3
3
import enum
4
+ import pathlib
5
+ import json
6
+ import struct
7
+ import time
4
8
5
9
from . import tlv
6
10
21
25
22
26
# print(f"Listening on UDP port {UDP_PORT}")
23
27
24
- unsecured_session_context = {}
28
+ # Section 4.11.2
29
+ MSG_COUNTER_WINDOW_SIZE = 32
30
+ MSG_COUNTER_SYNC_REQ_JITTER_MS = 500
31
+ MSG_COUNTER_SYNC_TIMEOUT_MS = 400
25
32
26
33
27
34
class ProtocolId (enum .Enum ):
@@ -36,6 +43,8 @@ class SecurityFlags(enum.Flag):
36
43
P = 1 << 7
37
44
C = 1 << 6
38
45
MX = 1 << 5
46
+ # This is actually 2 bits but the top bit is reserved and always zero.
47
+ GROUP = 1 << 0
39
48
40
49
41
50
class ExchangeFlags (enum .Flag ):
@@ -173,3 +182,264 @@ class PBKDFParamResponse(tlv.TLVStructure):
173
182
responderSessionId = tlv .NumberMember (3 , "<H" )
174
183
pbkdf_parameters = tlv .StructMember (4 , Crypto_PBKDFParameterSet )
175
184
responderSessionParams = tlv .StructMember (5 , SessionParameterStruct , optional = True )
185
+
186
+
187
+ class MessageReceptionState :
188
+ def __init__ (self , starting_value , rollover = True , encrypted = False ):
189
+ """Implements 4.6.5.1"""
190
+ self .message_counter = starting_value
191
+ self .window_bitmap = (1 << MSG_COUNTER_WINDOW_SIZE ) - 1
192
+ self .mask = self .window_bitmap
193
+ self .encrypted = encrypted
194
+ self .rollover = rollover
195
+
196
+ def process_counter (self , counter ) -> bool :
197
+ """Returns True if the counter number is a duplicate"""
198
+ # Process the current window first. Behavior outside the window varies.
199
+ if counter == self .message_counter :
200
+ return True
201
+ if self .message_counter <= MSG_COUNTER_WINDOW_SIZE < counter :
202
+ # Window wraps
203
+ bit_position = 0xFFFFFFFF - counter + self .message_counter
204
+ else :
205
+ bit_position = self .message_counter - counter - 1
206
+ if 0 <= bit_position < MSG_COUNTER_WINDOW_SIZE :
207
+ if self .window_bitmap & (1 << bit_position ) != 0 :
208
+ # This is a duplicate message
209
+ return True
210
+ self .window_bitmap |= 1 << bit_position
211
+ return False
212
+
213
+ new_start = (self .message_counter + 1 ) & self .mask # Inclusive
214
+ new_end = (
215
+ self .message_counter - MSG_COUNTER_WINDOW_SIZE
216
+ ) & self .mask # Exclusive
217
+ if not self .rollover :
218
+ new_end = (1 << MSG_COUNTER_WINDOW_SIZE ) - 1
219
+ elif self .encrypted :
220
+ new_end = (
221
+ self .message_counter + (1 << (MSG_COUNTER_WINDOW_SIZE - 1 ))
222
+ ) & self .mask
223
+
224
+ if new_start <= new_end :
225
+ if not (new_start <= counter < new_end ):
226
+ return True
227
+ else :
228
+ if not (counter < new_end or new_start <= counter ):
229
+ return True
230
+
231
+ # This is a new message
232
+ shift = counter - self .message_counter
233
+ if counter < self .message_counter :
234
+ shift += 0x100000000
235
+ if shift > MSG_COUNTER_WINDOW_SIZE :
236
+ self .window_bitmap = 0
237
+ else :
238
+ new_bitmap = (self .window_bitmap << shift ) & self .mask
239
+ self .window_bitmap = new_bitmap
240
+ if 1 < shift < MSG_COUNTER_WINDOW_SIZE :
241
+ self .window_bitmap |= 1 << (shift - 1 )
242
+ self .message_counter = counter
243
+ return False
244
+
245
+
246
+ class UnsecuredSessionContext :
247
+ def __init__ (self , initiator , ephemeral_initiator_node_id ):
248
+ self .initiator = initiator
249
+ self .ephemeral_initiator_node_id = ephemeral_initiator_node_id
250
+ self .message_reception_state = None
251
+
252
+
253
+ class SecureSessionContext :
254
+ def __init__ (self , local_session_id ):
255
+ self .session_type = None
256
+ """Records whether the session was established using CASE or PASE."""
257
+ self .session_role = None
258
+ """Records whether the node is the session initiator or responder."""
259
+ self .local_session_id = local_session_id
260
+ """Individually selected by each participant in secure unicast communication during session establishment and used as a unique identifier to recover encryption keys, authenticate incoming messages and associate them to existing sessions."""
261
+ self .peer_session_id = None
262
+ """Assigned by the peer during session establishment"""
263
+ self .i2r_key = None
264
+ """Encrypts data in messages sent from the initiator of session establishment to the responder."""
265
+ self .r2i_key = None
266
+ """Encrypts data in messages sent from the session establishment responder to the initiator."""
267
+ self .shared_secret = None
268
+ """Computed during the CASE protocol execution and re-used when CASE session resumption is implemented."""
269
+ self .local_message_counter = None
270
+ """Secure Session Message Counter for outbound messages."""
271
+ self .message_reception_state = None
272
+ """Provides tracking for the Secure Session Message Counter of the remote"""
273
+ self .local_fabric_index = None
274
+ """Records the local Index for the session’s Fabric, which MAY be used to look up Fabric metadata related to the Fabric for which this session context applies."""
275
+ self .peer_node_id = None
276
+ """Records the authenticated node ID of the remote peer, when available."""
277
+ self .resumption_id = None
278
+ """The ID used when resuming a session between the local and remote peer."""
279
+ self .session_timestamp = None
280
+ """A timestamp indicating the time at which the last message was sent or received. This timestamp SHALL be initialized with the time the session was created."""
281
+ self .active_timestamp = None
282
+ """A timestamp indicating the time at which the last message was received. This timestamp SHALL be initialized with the time the session was created."""
283
+ self .session_idle_interval = None
284
+ self .session_active_interval = None
285
+ self .session_active_threshold = None
286
+
287
+ @property
288
+ def peer_active (self ):
289
+ return (time .monotonic () - self .active_timestamp ) < self .session_active_interval
290
+
291
+
292
+ class Message :
293
+ def __init__ (self , buffer ):
294
+ self .buffer = buffer
295
+ self .flags , self .session_id , self .security_flags , self .message_counter = (
296
+ struct .unpack_from ("<BHBI" , buffer )
297
+ )
298
+ offset = 8
299
+ self .source_node_id = None
300
+ if self .flags & (1 << 2 ):
301
+ self .source_node_id = struct .unpack_from ("<Q" , buffer , 8 )[0 ]
302
+ offset += 8
303
+
304
+ if (self .flags >> 4 ) != 0 :
305
+ raise RuntimeError ("Incorrect version" )
306
+ self .secure_session = self .security_flags & 0x3 != 0 or self .session_id != 0
307
+
308
+ if not self .secure_session :
309
+ self .payload = memoryview (buffer )[offset :]
310
+
311
+ context = UnsecuredSessionContext (False , self .source_node_id )
312
+ self .unsecured_session_context [self .source_node_id ] = context
313
+ else :
314
+ self .payload = None
315
+
316
+ def _parse_protocol_header (self ):
317
+ self .exchange_flags , self .protocol_opcode , self .exchange_id = (
318
+ struct .unpack_from ("<BBH" , self .payload )
319
+ )
320
+
321
+ self .exchange_flags = ExchangeFlags (self .exchange_flags )
322
+ decrypted_offset = 4
323
+ self .protocol_vendor_id = 0
324
+ if self .exchange_flags & ExchangeFlags .V :
325
+ self .protocol_vendor_id = struct .unpack_from (
326
+ "<H" , self .payload , decrypted_offset
327
+ )[0 ]
328
+ decrypted_offset += 2
329
+ protocol_id = struct .unpack_from ("<H" , self .payload , decrypted_offset )[0 ]
330
+ decrypted_offset += 2
331
+ self .protocol_id = ProtocolId (protocol_id )
332
+ self .protocol_opcode = PROTOCOL_OPCODES [protocol_id ](self .protocol_opcode )
333
+
334
+ self .acknowledged_message_counter = None
335
+ if self .exchange_flags & ExchangeFlags .A :
336
+ self .acknowledged_message_counter = struct .unpack_from (
337
+ "<I" , self .payload , decrypted_offset
338
+ )[0 ]
339
+ decrypted_offset += 4
340
+
341
+ def reply (self , payload , protocol_id = None , protocol_opcode = None ) -> memoryview :
342
+ reply = bytearray (1280 )
343
+ offset = 0
344
+
345
+ # struct.pack_into(
346
+ # "<BHBI", reply, offset, flags, session_id, security_flags, message_counter
347
+ # )
348
+ # offset += 8
349
+ return memoryview (reply )[:offset ]
350
+
351
+
352
+ class SessionManager :
353
+ def __init__ (self ):
354
+ persist_path = pathlib .Path ("counters.json" )
355
+ if persist_path .exists ():
356
+ self .nonvolatile = json .loads (persist_path .read_text ())
357
+ else :
358
+ self .nonvolatile = {}
359
+ self .nonvolatile ["unencrypted_message_counter" ] = 0
360
+ self .nonvolatile ["group_encrypted_data_message_counter" ] = 0
361
+ self .nonvolatile ["group_encrypted_control_message_counter" ] = 0
362
+ self .unencrypted_message_counter = self .nonvolatile [
363
+ "unencrypted_message_counter"
364
+ ]
365
+ self .group_encrypted_data_message_counter = self .nonvolatile [
366
+ "group_encrypted_data_message_counter"
367
+ ]
368
+ self .group_encrypted_control_message_counter = self .nonvolatile [
369
+ "group_encrypted_control_message_counter"
370
+ ]
371
+ self .check_in_counter = 0
372
+ self .unsecured_session_context = {}
373
+ self .secure_session_contexts = ["reserved" ]
374
+
375
+ def _increment (self , value ):
376
+ return (value + 1 ) % 0xFFFFFFFF
377
+
378
+ def counter_ok (self , message ):
379
+ """Implements 4.6.7"""
380
+ if message .secure_session :
381
+ if message .security_flags & SecurityFlags .GROUP :
382
+ if message .source_node_id is None :
383
+ return False
384
+ # TODO: Get MRS for source node id and message type
385
+ else :
386
+ session_context = self .secure_session_contexts [message .session_id ]
387
+ else :
388
+ if message .source_node_id not in self .unsecured_session_context :
389
+ self .unsecured_session_context [message .source_node_id ] = (
390
+ UnsecuredSessionContext (
391
+ initiator = False ,
392
+ ephemeral_initiator_node_id = message .source_node_id ,
393
+ )
394
+ )
395
+ session_context = self .unsecured_session_context [message .source_node_id ]
396
+
397
+ if session_context .message_reception_state is None :
398
+ session_context .message_reception_state = MessageReceptionState (
399
+ message .message_counter ,
400
+ rollover = False ,
401
+ encrypted = message .secure_session ,
402
+ )
403
+ return True
404
+
405
+ return session_context .message_reception_state .process_counter (
406
+ message .message_counter
407
+ )
408
+
409
+ def next_message_counter (self , message ):
410
+ """Implements 4.6.6"""
411
+ if not message .secure_session :
412
+ value = self .unencrypted_message_counter
413
+ self .unencrypted_message_counter = self ._increment (
414
+ self .unencrypted_message_counter
415
+ )
416
+ return value
417
+ elif message .security_flags & SecurityFlags .GROUP :
418
+ if message .security_flags & SecurityFlags .C :
419
+ value = self .group_encrypted_control_message_counter
420
+ self .group_encrypted_control_message_counter = self ._increment (
421
+ self .group_encrypted_control_message_counter
422
+ )
423
+ return value
424
+ else :
425
+ value = self .group_encrypted_data_message_counter
426
+ self .group_encrypted_data_message_counter = self ._increment (
427
+ self .group_encrypted_data_message_counter
428
+ )
429
+ return value
430
+ session = self .secure_session_contexts [message .session_id ]
431
+ value = session .local_message_counter
432
+ next_value = self ._increment (value )
433
+ session .local_message_counter = next_value
434
+ if next_value == 0 :
435
+ # TODO expire the encryption key
436
+ raise NotImplementedError ("Expire the encryption key 4.6.6" )
437
+ return next_value
438
+
439
+ def new_context (self ):
440
+ if None not in self .secure_session_contexts :
441
+ self .secure_session_contexts .append (None )
442
+ session_id = self .secure_session_contexts .index (None )
443
+
444
+ self .secure_session_contexts [session_id ] = SecureSessionContext (session_id )
445
+ return self .secure_session_contexts [session_id ]
0 commit comments