2222import secrets
2323import time
2424import typing as t
25- from contextlib import contextmanager
2625from enum import Enum , IntEnum , auto
2726
2827import typing_extensions as tx
@@ -111,6 +110,30 @@ def __init__(self) -> None:
111110 super ().__init__ (self .__doc__ )
112111
113112
113+ class _ThpInteractiveContext :
114+
115+ def __init__ (self , channel : Channel , force_flush : bool = False ) -> None :
116+ self .channel = channel
117+ self .force_flush = force_flush
118+
119+ def __enter__ (self ) -> tx .Self :
120+ self .channel ._active_contexts .append (self )
121+ return self
122+
123+ def __exit__ (
124+ self ,
125+ exc_type : type [BaseException ] | None ,
126+ exc_value : BaseException | None ,
127+ traceback : t .Any ,
128+ ) -> None :
129+ assert self .channel ._active_contexts .pop () is self
130+ if not self .channel .is_ack_piggybacking_allowed :
131+ return
132+
133+ if not self .channel ._active_contexts or self .force_flush :
134+ self .channel ._flush_ack ()
135+
136+
114137class Channel :
115138 CHUNK_SIZE : t .ClassVar [int | None ] = None
116139
@@ -141,7 +164,7 @@ def __init__(
141164 self ._noise : NoiseConnection | None = None
142165 self .state = channel_state
143166 self .trezor_public_keys : TrezorPublicKeys | None = None
144- self ._active_workflow : object | None = None
167+ self ._active_contexts = []
145168
146169 @functools .cached_property
147170 def is_ack_piggybacking_allowed (self ) -> bool :
@@ -300,8 +323,7 @@ def _read_handshake_init_response(self) -> None:
300323 if e .code == exceptions .ThpErrorCode .DEVICE_LOCKED :
301324 raise DeviceLockedError from e
302325 raise
303- if not self .is_ack_piggybacking_allowed :
304- self ._send_ack (message )
326+ self ._send_ack (message )
305327 if not message .is_handshake_init_response ():
306328 raise ProtocolError (f"Not a valid handshake init response: { message } " )
307329
@@ -409,16 +431,19 @@ def should_back_off() -> bool:
409431 continue
410432 raise
411433
412- def _send_ack (self , acked_message : Message | None ) -> None :
413- if self .is_ack_piggybacking_allowed and self ._active_workflow is not None :
434+ def _send_ack (self , acked_message : Message ) -> None :
435+ if self .is_ack_piggybacking_allowed and self ._active_contexts :
436+ # Skip explicit THP ACK during workflow - the next request will piggyback the correct ACK bit
414437 return
415438
416- if acked_message is not None :
417- ack = control_byte .make_ack_for (acked_message .ctrl_byte )
418- ack_message = Message (ack , acked_message .cid , b"" )
419- else :
420- ack = control_byte .make_ack (not self .sync_bit_receive )
421- ack_message = Message (ack , self .channel_id , b"" )
439+ ack = control_byte .make_ack_for (acked_message .ctrl_byte )
440+ ack_message = Message (ack , acked_message .cid , b"" )
441+
442+ thp_io .write_payload_to_wire (self .transport , ack_message )
443+
444+ def _flush_ack (self ) -> None :
445+ ack = control_byte .make_ack (not self .sync_bit_receive )
446+ ack_message = Message (ack , self .channel_id , b"" )
422447
423448 thp_io .write_payload_to_wire (self .transport , ack_message )
424449
@@ -441,23 +466,6 @@ def _read_ack(self, message: Message) -> None:
441466 f"Failed to read ACK in { retries } retries for message: { message } "
442467 )
443468
444- @contextmanager
445- def piggyback_acks (self , marker : object ) -> t .Generator [None , None , None ]:
446- # Make sure the previous workflow is over.
447- assert self ._active_workflow is None
448- self ._active_workflow = marker
449- # Skip explicit ACKs during this workflow
450- try :
451- yield
452- finally :
453- active = self ._active_workflow
454- self ._active_workflow = None
455- assert active is marker
456- if self .is_ack_piggybacking_allowed :
457- # Explicitly ACK the latest received message. The device may restart
458- # the event loop, so the next request will be sent in a separate message.
459- self ._send_ack (None )
460-
461469 def write_chunk (self , data : bytes , / ) -> None :
462470 self ._assert_handshake_done ()
463471 encrypted_data = self .noise .encrypt (data )
0 commit comments