Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/.changelog.d/6202.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support receive-side THP ACK piggybacking.
3 changes: 1 addition & 2 deletions core/src/trezor/wire/thp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties:
internal_model=utils.INTERNAL_MODEL,
model_variant=model_variant,
protocol_version_major=2,
# TODO: re-enable THP ACK piggybacking after #6506 is fixed
protocol_version_minor=0,
protocol_version_minor=1,
)


Expand Down
9 changes: 4 additions & 5 deletions core/src/trezor/wire/thp/received_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,10 @@ async def _handle_state_handshake(
def _handshake_callback(ctrl_byte: int) -> bool:
success = control_byte.is_handshake_init_req(ctrl_byte)

# TODO: re-enable THP ACK piggybacking after #6506 is fixed
# if success and control_byte.get_ack_bit(ctrl_byte) == 1:
# # Newer Suite versions will send `handshake_init_req` with a non-zero ACK bit.
# # The device should not use ACK piggybacking with older Suite versions.
# ABP.allow_ack_piggybacking(ctx.channel_cache)
if success and control_byte.get_ack_bit(ctrl_byte) == 1:
# Newer Suite versions will send `handshake_init_req` with a non-zero ACK bit.
# The device should not use ACK piggybacking with older Suite versions.
ABP.allow_ack_piggybacking(ctx.channel_cache)

if __debug__:
ctx._log(
Expand Down
171 changes: 86 additions & 85 deletions python/src/trezorlib/btc.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,96 +326,97 @@ def sign_tx(
elif preauthorized:
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)

res = session.call(signtx, expect=messages.TxRequest)

# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
serialized_tx = b""

def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
tx_copy = copy(tx)
# clear fields
tx_copy.inputs_cnt = len(tx.inputs)
tx_copy.inputs = []
tx_copy.outputs_cnt = len(tx.bin_outputs or tx.outputs)
tx_copy.outputs = []
tx_copy.bin_outputs = []
tx_copy.extra_data_len = len(tx.extra_data or b"")
tx_copy.extra_data = None
return tx_copy

this_tx = messages.TransactionType(
inputs=inputs,
outputs=outputs,
inputs_cnt=len(inputs),
outputs_cnt=len(outputs),
# pick either kw-provided or default value from the SignTx request
version=signtx.version,
)
with session.interact() as ctx:
res = ctx.call(signtx, expect=messages.TxRequest)

# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
serialized_tx = b""

def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
tx_copy = copy(tx)
# clear fields
tx_copy.inputs_cnt = len(tx.inputs)
tx_copy.inputs = []
tx_copy.outputs_cnt = len(tx.bin_outputs or tx.outputs)
tx_copy.outputs = []
tx_copy.bin_outputs = []
tx_copy.extra_data_len = len(tx.extra_data or b"")
tx_copy.extra_data = None
return tx_copy

this_tx = messages.TransactionType(
inputs=inputs,
outputs=outputs,
inputs_cnt=len(inputs),
outputs_cnt=len(outputs),
# pick either kw-provided or default value from the SignTx request
version=signtx.version,
)

R = messages.RequestType
while True:
# If there's some part of signed transaction, let's add it
if res.serialized:
if res.serialized.serialized_tx:
serialized_tx += res.serialized.serialized_tx

if res.serialized.signature_index is not None:
idx = res.serialized.signature_index
sig = res.serialized.signature
if signatures[idx] is not None:
raise ValueError(f"Signature for index {idx} already filled")
signatures[idx] = sig

if res.request_type == R.TXFINISHED:
break

assert res.details is not None, "device did not provide details"

# Device asked for one more information, let's process it.
if res.details.tx_hash is not None:
if res.details.tx_hash not in prev_txes:
raise ValueError(
f"Previous transaction {res.details.tx_hash.hex()} not available"
)
current_tx = prev_txes[res.details.tx_hash]
else:
current_tx = this_tx
R = messages.RequestType
while True:
# If there's some part of signed transaction, let's add it
if res.serialized:
if res.serialized.serialized_tx:
serialized_tx += res.serialized.serialized_tx

if res.serialized.signature_index is not None:
idx = res.serialized.signature_index
sig = res.serialized.signature
if signatures[idx] is not None:
raise ValueError(f"Signature for index {idx} already filled")
signatures[idx] = sig

if res.request_type == R.TXFINISHED:
break

assert res.details is not None, "device did not provide details"

# Device asked for one more information, let's process it.
if res.details.tx_hash is not None:
if res.details.tx_hash not in prev_txes:
raise ValueError(
f"Previous transaction {res.details.tx_hash.hex()} not available"
)
current_tx = prev_txes[res.details.tx_hash]
else:
current_tx = this_tx

if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index]
res = session.call(msg, expect=messages.TxRequest)
else:
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
msg = copy_tx_meta(current_tx)
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
assert res.details.request_index is not None
msg.inputs = [current_tx.inputs[res.details.request_index]]
elif res.request_type == R.TXOUTPUT:
if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
if res.details.tx_hash:
msg.bin_outputs = [
current_tx.bin_outputs[res.details.request_index]
]
else:
msg.outputs = [current_tx.outputs[res.details.request_index]]
elif res.request_type == R.TXORIGOUTPUT:
assert res.details.request_index is not None
msg.outputs = [current_tx.outputs[res.details.request_index]]
elif res.request_type == R.TXEXTRADATA:
assert res.details.extra_data_offset is not None
assert res.details.extra_data_len is not None
assert current_tx.extra_data is not None
o, l = res.details.extra_data_offset, res.details.extra_data_len
msg.extra_data = current_tx.extra_data[o : o + l]
msg = payment_reqs[res.details.request_index]
res = ctx.call(msg, expect=messages.TxRequest)
else:
raise exceptions.TrezorException(
f"Unknown request type - {res.request_type}."
)
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
msg = copy_tx_meta(current_tx)
elif res.request_type in (R.TXINPUT, R.TXORIGINPUT):
assert res.details.request_index is not None
msg.inputs = [current_tx.inputs[res.details.request_index]]
elif res.request_type == R.TXOUTPUT:
assert res.details.request_index is not None
if res.details.tx_hash:
msg.bin_outputs = [
current_tx.bin_outputs[res.details.request_index]
]
else:
msg.outputs = [current_tx.outputs[res.details.request_index]]
elif res.request_type == R.TXORIGOUTPUT:
assert res.details.request_index is not None
msg.outputs = [current_tx.outputs[res.details.request_index]]
elif res.request_type == R.TXEXTRADATA:
assert res.details.extra_data_offset is not None
assert res.details.extra_data_len is not None
assert current_tx.extra_data is not None
o, l = res.details.extra_data_offset, res.details.extra_data_len
msg.extra_data = current_tx.extra_data[o : o + l]
else:
raise exceptions.TrezorException(
f"Unknown request type - {res.request_type}."
)

res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
res = ctx.call(messages.TxAck(tx=msg), expect=messages.TxRequest)

for i, sig in zip(inputs, signatures):
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
Expand Down
45 changes: 42 additions & 3 deletions python/src/trezorlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,33 @@ class PassphraseSetting(enum.Enum):
)


@dataclass
class InteractionContext(t.Generic[ClientType, SessionType]):
client: ClientType
session: SessionType

def __enter__(self) -> tx.Self:
self.client.__enter__()
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: t.Any,
) -> None:
return self.client.__exit__(exc_type, exc_value, traceback)

def call(
self,
msg: MessageType,
*,
expect: type[MT] = MessageType,
timeout: float | None = None,
) -> MT:
return self.client._call(self.session, msg, expect=expect, timeout=timeout)


class Session(t.Generic[ClientType, SessionIdType]):
def __init__(
self,
Expand Down Expand Up @@ -107,11 +134,19 @@ def call(
expect: type[MT] = MessageType,
timeout: float | None = None,
) -> MT:
"""Call a method on this session, process and return the response."""
"""
Call a method on this session, process and return the response.

Use `self.interact()` for consecutive calls (to allow THP ACK piggybacking).
"""
with self.interact() as ctx:
return ctx.call(msg, expect=expect, timeout=timeout)

def interact(self) -> InteractionContext:
"""Use the returned context manager to call methods on this session."""
if self.is_invalid:
raise exceptions.InvalidSessionError(self.id)
with self:
return self.client._call(self, msg, expect=expect, timeout=timeout)
return self.client.interact(self)

def call_raw(self, msg: MessageType, timeout: float | None = None) -> MessageType:
"""Invoke a single call-response round-trip to the device.
Expand Down Expand Up @@ -299,6 +334,10 @@ def __exit__(
) -> None:
self.transport.__exit__(exc_type, exc_value, traceback)

def interact(self, session: SessionType) -> InteractionContext:
"""Use the returned context manager to call methods on this session."""
return InteractionContext(self, session)

def connect(self) -> None:
"""Establish a connection to the device.

Expand Down
24 changes: 13 additions & 11 deletions python/src/trezorlib/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .tools import Address, parse_path, workflow

if TYPE_CHECKING:
from .client import Session
from .client import InteractionContext, Session


RECOVERY_BACK = "\x08" # backspace character, sent literally
Expand Down Expand Up @@ -77,16 +77,17 @@ def apply_settings(

if homescreen and session.version >= HOMESCREEN_STREAMING_MIN_VERSION:
settings.homescreen_length = len(homescreen)
response = session.call(settings, expect=messages.DataChunkRequest)
_send_chunked_data(session, response, homescreen)
with session.interact() as ctx:
response = ctx.call(settings, expect=messages.DataChunkRequest)
_send_chunked_data(ctx, response, homescreen)
else:
settings.homescreen = homescreen
session.call(settings, expect=messages.Success)
session.refresh_features()


def _send_chunked_data(
session: "Session",
ctx: "InteractionContext",
request: "messages.DataChunkRequest",
language_data: bytes,
) -> None:
Expand All @@ -96,7 +97,7 @@ def _send_chunked_data(
data_length = response.data_length
data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length]
response = session.call(messages.DataChunkAck(data_chunk=chunk))
response = ctx.call(messages.DataChunkAck(data_chunk=chunk))


@workflow()
Expand All @@ -108,12 +109,13 @@ def change_language(
data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)

response = session.call(msg)
if data_length > 0:
response = messages.DataChunkRequest.ensure_isinstance(response)
_send_chunked_data(session, response, language_data)
else:
messages.Success.ensure_isinstance(response)
with session.interact() as ctx:
response = ctx.call(msg)
if data_length > 0:
response = messages.DataChunkRequest.ensure_isinstance(response)
_send_chunked_data(ctx, response, language_data)
else:
messages.Success.ensure_isinstance(response)
session.refresh_features() # changing the language in features


Expand Down
Loading