Skip to content

Commit 51d3a31

Browse files
committed
feat(python): enable per-workflow THP ACK piggybacking
Following #6563 (comment). [no changelog]
1 parent 5865acd commit 51d3a31

8 files changed

Lines changed: 192 additions & 186 deletions

File tree

python/src/trezorlib/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import unicodedata
2525
import warnings
2626
from abc import ABCMeta, abstractmethod
27+
from contextlib import AbstractContextManager, nullcontext
2728
from dataclasses import dataclass
2829

2930
import typing_extensions as tx
@@ -245,6 +246,7 @@ def __init__(
245246
self._mapping = mapping
246247
self._features = None
247248
self.pairing = pairing
249+
self._interact_ctx = self._interact()
248250

249251
# ===== Internal methods for overriding in subclasses =====
250252

@@ -284,11 +286,15 @@ def _get_session(
284286
"""
285287
raise NotImplementedError
286288

289+
def _interact(self, *, force_flush: bool = False) -> AbstractContextManager:
290+
return nullcontext()
291+
287292
# ===== Common implementations =====
288293

289294
def __enter__(self) -> tx.Self:
290295
"""(Re)Open a connection to the device."""
291296
self.transport.__enter__()
297+
self._interact_ctx.__enter__()
292298
return self
293299

294300
def __exit__(
@@ -297,7 +303,10 @@ def __exit__(
297303
exc_value: BaseException | None,
298304
traceback: t.Any,
299305
) -> None:
300-
self.transport.__exit__(exc_type, exc_value, traceback)
306+
try:
307+
self._interact_ctx.__exit__(exc_type, exc_value, traceback)
308+
finally:
309+
self.transport.__exit__(exc_type, exc_value, traceback)
301310

302311
def connect(self) -> None:
303312
"""Establish a connection to the device.

python/src/trezorlib/thp/channel.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import secrets
2323
import time
2424
import typing as t
25-
from contextlib import contextmanager
2625
from enum import Enum, IntEnum, auto
2726

2827
import typing_extensions as tx
@@ -111,6 +110,30 @@ def __init__(self) -> None:
111110
super().__init__(self.__doc__)
112111

113112

113+
class _ThpInteractiveContext:
114+
115+
def __init__(self, channel: Channel, force_flush: bool = False) -> None:
116+
self.channel = channel
117+
self.force_flush = force_flush
118+
119+
def __enter__(self) -> tx.Self:
120+
self.channel._active_contexts.append(self)
121+
return self
122+
123+
def __exit__(
124+
self,
125+
exc_type: type[BaseException] | None,
126+
exc_value: BaseException | None,
127+
traceback: t.Any,
128+
) -> None:
129+
assert self.channel._active_contexts.pop() is self
130+
if not self.channel.is_ack_piggybacking_allowed:
131+
return
132+
133+
if not self.channel._active_contexts or self.force_flush:
134+
self.channel._flush_ack()
135+
136+
114137
class Channel:
115138
CHUNK_SIZE: t.ClassVar[int | None] = None
116139

@@ -141,7 +164,7 @@ def __init__(
141164
self._noise: NoiseConnection | None = None
142165
self.state = channel_state
143166
self.trezor_public_keys: TrezorPublicKeys | None = None
144-
self._active_workflow: object | None = None
167+
self._active_contexts = []
145168

146169
@functools.cached_property
147170
def is_ack_piggybacking_allowed(self) -> bool:
@@ -300,8 +323,7 @@ def _read_handshake_init_response(self) -> None:
300323
if e.code == exceptions.ThpErrorCode.DEVICE_LOCKED:
301324
raise DeviceLockedError from e
302325
raise
303-
if not self.is_ack_piggybacking_allowed:
304-
self._send_ack(message)
326+
self._send_ack(message)
305327
if not message.is_handshake_init_response():
306328
raise ProtocolError(f"Not a valid handshake init response: {message}")
307329

@@ -409,16 +431,19 @@ def should_back_off() -> bool:
409431
continue
410432
raise
411433

412-
def _send_ack(self, acked_message: Message | None) -> None:
413-
if self.is_ack_piggybacking_allowed and self._active_workflow is not None:
434+
def _send_ack(self, acked_message: Message) -> None:
435+
if self.is_ack_piggybacking_allowed and self._active_contexts:
436+
# Skip explicit THP ACK during workflow - the next request will piggyback the correct ACK bit
414437
return
415438

416-
if acked_message is not None:
417-
ack = control_byte.make_ack_for(acked_message.ctrl_byte)
418-
ack_message = Message(ack, acked_message.cid, b"")
419-
else:
420-
ack = control_byte.make_ack(not self.sync_bit_receive)
421-
ack_message = Message(ack, self.channel_id, b"")
439+
ack = control_byte.make_ack_for(acked_message.ctrl_byte)
440+
ack_message = Message(ack, acked_message.cid, b"")
441+
442+
thp_io.write_payload_to_wire(self.transport, ack_message)
443+
444+
def _flush_ack(self) -> None:
445+
ack = control_byte.make_ack(not self.sync_bit_receive)
446+
ack_message = Message(ack, self.channel_id, b"")
422447

423448
thp_io.write_payload_to_wire(self.transport, ack_message)
424449

@@ -441,23 +466,6 @@ def _read_ack(self, message: Message) -> None:
441466
f"Failed to read ACK in {retries} retries for message: {message}"
442467
)
443468

444-
@contextmanager
445-
def piggyback_acks(self, marker: object) -> t.Generator[None, None, None]:
446-
# Make sure the previous workflow is over.
447-
assert self._active_workflow is None
448-
self._active_workflow = marker
449-
# Skip explicit ACKs during this workflow
450-
try:
451-
yield
452-
finally:
453-
active = self._active_workflow
454-
self._active_workflow = None
455-
assert active is marker
456-
if self.is_ack_piggybacking_allowed:
457-
# Explicitly ACK the latest received message. The device may restart
458-
# the event loop, so the next request will be sent in a separate message.
459-
self._send_ack(None)
460-
461469
def write_chunk(self, data: bytes, /) -> None:
462470
self._assert_handshake_done()
463471
encrypted_data = self.noise.encrypt(data)

python/src/trezorlib/thp/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from .. import client, exceptions, messages, models, protobuf
2525
from ..log import DUMP_BYTES
26-
from .channel import Channel
26+
from .channel import Channel, _ThpInteractiveContext
2727
from .pairing import PairingController
2828

2929
if t.TYPE_CHECKING:
@@ -67,6 +67,7 @@ def __init__(
6767
mapping: ProtobufMapping | None,
6868
model: models.TrezorModel | None,
6969
) -> None:
70+
# used to override channel creation logic in tests
7071
channel = Channel.allocate(transport)
7172
try:
7273
# try to open the channel
@@ -135,6 +136,14 @@ def _get_session(
135136
session.derive(passphrase, derive_cardano)
136137
return session
137138

139+
def _interact(self, *, force_flush: bool = False) -> _ThpInteractiveContext:
140+
"""
141+
Used internally by `TrezorClient` to create an THP ACK piggybacking context.
142+
143+
Can be also used by tests for unconditionally sending THP ACKs.
144+
"""
145+
return _ThpInteractiveContext(channel=self.channel, force_flush=force_flush)
146+
138147
def _invalidate(self) -> None:
139148
super()._invalidate()
140149
# Close the channel. The client cannot be used until a channel is

python/src/trezorlib/thp/pairing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def setup(self) -> None:
262262
messages.ThpCodeEntryChallenge(challenge=challenge),
263263
expect=messages.ThpCodeEntryCpaceTrezor,
264264
)
265+
self.controller.channel._flush_ack()
265266
self.code_entry_state = CodeEntryState(
266267
challenge=challenge,
267268
commitment=commitment_msg.commitment,

tests/device_tests/evolu/common.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,12 @@ def __init__(self, session, credential):
4747
self.credential: ThpCredentialResponse = credential
4848

4949

50-
def pair_and_get_credential(client: Client) -> ThpPairingResult:
51-
from ..thp.connect import nfc_pairing, prepare_channel_for_pairing
50+
def pair_and_get_credential(test_ctx: Client) -> ThpPairingResult:
51+
from ..thp.connect import new_thp_client
5252

53-
prepare_channel_for_pairing(
54-
client, host_static_privkey=TEST_host_static_private_key
53+
client = new_thp_client(
54+
test_ctx, host_static_privkey=TEST_host_static_private_key, nfc_pairing=True
5555
)
56-
nfc_pairing(client)
5756
credential = client.pairing.request_credential(autoconnect=False)
5857
client.pairing.finish()
5958

@@ -78,7 +77,7 @@ def get_delegated_identity_key(client: Client) -> bytes:
7877
if client.is_thp():
7978
pairing_data = pair_and_get_credential(client)
8079
return evolu.get_delegated_identity_key(
81-
client.get_session(),
80+
session=pairing_data.session,
8281
thp_credential=pairing_data.credential.credential,
8382
)
8483
elif client.is_protocol_v1():

tests/device_tests/test_msg_backup_device.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,13 @@ def _try_to_cancel(
6060
next(gen)
6161
while True:
6262
br = yield
63-
# Try to cancel the backup flow on Core
64-
resp = session.call_raw(messages.Cancel())
65-
# Following #6483, backup is not cancellable
66-
assert resp == BACKUP_IN_PROGRESS
63+
# Entering session's context will send an explicit THP ACK after `BACKUP_IN_PROGRESS` is received.
64+
with session.client._interact(force_flush=True):
65+
# Try to cancel the backup flow on Core
66+
with pytest.raises(TrezorFailure) as exc_info:
67+
session.call(messages.Cancel(), expect=messages.Failure)
68+
# Following #6483, backup is not cancellable
69+
assert exc_info.value.failure == BACKUP_IN_PROGRESS
6770
try:
6871
gen.send(br)
6972
except StopIteration:

tests/device_tests/thp/connect.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,63 +39,62 @@ def prepare_channel_for_handshake(test_ctx: TrezorTestContext) -> None:
3939
test_ctx.channel._init_noise()
4040

4141

42-
def prepare_channel_for_pairing(
42+
def new_thp_client(
4343
test_ctx: TrezorTestContext,
44+
*,
4445
credential: Credential | None = None,
4546
host_static_privkey: bytes | None = None,
4647
fixed_entropy: bool = False,
47-
) -> None:
48-
"""Create a fresh channel, perform the handshake using the provided fixed entropy
49-
and credentials, and leave it in the pairing phase.
50-
"""
48+
nfc_pairing: bool = False,
49+
) -> TrezorClientThp:
5150
# set up a fresh channel
52-
prepare_channel_for_handshake(test_ctx)
51+
assert isinstance(test_ctx.client, TrezorClientThp)
52+
test_ctx.channel = Channel.allocate(test_ctx.transport)
53+
5354
assert host_static_privkey is None or not fixed_entropy, "don't use both"
5455
if host_static_privkey is not None:
5556
test_ctx.channel._init_noise(static_privkey=host_static_privkey)
5657
elif fixed_entropy:
5758
test_ctx.channel._init_noise(
5859
static_privkey=b"\x12" * 32, ephemeral_privkey=b"\x24" * 32
5960
)
61+
else:
62+
test_ctx.channel._init_noise()
63+
6064
credentials = []
6165
if credential is not None:
6266
credentials.append(credential)
6367

6468
# run the handshake
6569
test_ctx.channel.open(credentials)
66-
assert isinstance(test_ctx.client, TrezorClientThp)
70+
71+
test_ctx.client._interact_ctx = test_ctx.client._interact()
6772
test_ctx.client.pairing = test_ctx.pairing = PairingController(test_ctx.client)
6873

74+
if nfc_pairing:
75+
method = Nfc(test_ctx.pairing)
76+
# NFC screen shown
77+
78+
# Read `nfc_secret` and `handshake_hash` from Trezor using debuglink
79+
pairing_info = test_ctx.debug.pairing_info(
80+
thp_channel_id=test_ctx.channel.channel_id.to_bytes(2, "big"),
81+
handshake_hash=test_ctx.channel.handshake_hash,
82+
nfc_secret_host=method.nfc_host_secret,
83+
)
84+
assert pairing_info.handshake_hash is not None
85+
assert pairing_info.nfc_secret_trezor is not None
86+
assert pairing_info.handshake_hash[:16] == test_ctx.channel.handshake_hash[:16]
6987

70-
def get_encrypted_transport_protocol(test_ctx: TrezorTestContext) -> None:
71-
prepare_channel_for_pairing(test_ctx)
72-
test_ctx.pairing.skip()
88+
method.send_nfc_tag(pairing_info.nfc_secret_trezor)
7389

90+
return test_ctx.client
7491

75-
def break_channel(test_ctx: TrezorTestContext) -> None:
76-
cse = test_ctx.channel._noise.noise_protocol.cipher_state_encrypt
92+
93+
def break_channel(client: TrezorClientThp) -> None:
94+
cse = client.channel._noise.noise_protocol.cipher_state_encrypt
7795
cse.n = cse.n + 1
7896

79-
session = test_ctx.client._get_any_session()
97+
session = client._get_any_session()
8098
session.write(messages.ButtonAck())
8199
with pytest.raises(ThpError):
82-
session.read(1)
83-
84-
85-
def nfc_pairing(test_ctx: TrezorTestContext) -> None:
86-
assert isinstance(test_ctx.client, TrezorClientThp)
87-
method = Nfc(test_ctx.client.pairing)
88-
89-
# NFC screen shown
90-
91-
# Read `nfc_secret` and `handshake_hash` from Trezor using debuglink
92-
pairing_info = test_ctx.debug.pairing_info(
93-
thp_channel_id=test_ctx.channel.channel_id.to_bytes(2, "big"),
94-
handshake_hash=test_ctx.channel.handshake_hash,
95-
nfc_secret_host=method.nfc_host_secret,
96-
)
97-
assert pairing_info.handshake_hash is not None
98-
assert pairing_info.nfc_secret_trezor is not None
99-
assert pairing_info.handshake_hash[:16] == test_ctx.channel.handshake_hash[:16]
100-
101-
method.send_nfc_tag(pairing_info.nfc_secret_trezor)
100+
session.read(timeout=1)

0 commit comments

Comments
 (0)