Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lnpeer: async maybe_send_commitment #7836

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 41 additions & 27 deletions electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def ping_if_required(self):
self.pong_event.clear()
await self.pong_event.wait()

def process_message(self, message):
async def process_message(self, message: bytes) -> None:
try:
message_type, payload = decode_msg(message)
except UnknownOptionalMsgType as e:
Expand Down Expand Up @@ -226,9 +226,19 @@ def process_message(self, message):
# raw message is needed to check signature
if message_type in ['node_announcement', 'channel_announcement', 'channel_update']:
payload['raw'] = message
execution_result = f(*args)
# note: the message handler might be async or non-async. In either case, by default,
# we wait for it to complete before we return, i.e. before the next message is processed.
if asyncio.iscoroutinefunction(f):
asyncio.ensure_future(self.taskgroup.spawn(execution_result))
await f(*args)
else:
f(*args)

def runs_in_taskgroup(func):
assert asyncio.iscoroutinefunction(func), 'func needs to be a coroutine'
@functools.wraps(func)
async def wrapper(self: 'Peer', *args, **kwargs):
return await self.taskgroup.spawn(func(self, *args, **kwargs))
return wrapper

def on_warning(self, payload):
chan_id = payload.get("channel_id")
Expand Down Expand Up @@ -576,7 +586,7 @@ async def _message_loop(self):
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
raise GracefulDisconnect(f'initialize failed: {repr(e)}') from e
async for msg in self.transport.read_messages():
self.process_message(msg)
await self.process_message(msg)
if self.DELAY_INC_MSG_PROCESSING_SLEEP:
# rate-limit message-processing a bit, to make it harder
# for a single peer to bog down the event loop / cpu:
Expand Down Expand Up @@ -899,6 +909,7 @@ def create_channel_storage(self, channel_id, outpoint, local_config, remote_conf
}
return StoredDict(chan_dict, self.lnworker.db if self.lnworker else None, [])

@runs_in_taskgroup
async def on_open_channel(self, payload):
"""Implements the channel acceptance flow.

Expand Down Expand Up @@ -1392,14 +1403,15 @@ def send_announcement_signatures(self, chan: Channel):
)
return msg_hash, node_signature, bitcoin_signature

def on_update_fail_htlc(self, chan: Channel, payload):
async def on_update_fail_htlc(self, chan: Channel, payload):
htlc_id = payload["id"]
reason = payload["reason"]
self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.receive_fail_htlc(htlc_id, error_bytes=reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)

def maybe_send_commitment(self, chan: Channel) -> bool:
async def maybe_send_commitment(self, chan: Channel) -> bool:
await self.ping_if_required()
# REMOTE should revoke first before we can sign a new ctx
if chan.hm.is_revack_pending(REMOTE):
return False
Expand All @@ -1411,7 +1423,8 @@ def maybe_send_commitment(self, chan: Channel) -> bool:
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
return True

def pay(self, *,
async def pay(
self, *,
route: 'LNPaymentRoute',
chan: Channel,
amount_msat: int,
Expand Down Expand Up @@ -1470,20 +1483,20 @@ def pay(self, *,
amount_msat=htlc.amount_msat,
payment_hash=htlc.payment_hash,
onion_routing_packet=onion.to_bytes())
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)
return htlc

def send_revoke_and_ack(self, chan: Channel):
async def send_revoke_and_ack(self, chan: Channel):
self.logger.info(f'send_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(LOCAL)}')
rev = chan.revoke_current_commitment()
self.lnworker.save_channel(chan)
self.send_message("revoke_and_ack",
channel_id=chan.channel_id,
per_commitment_secret=rev.per_commitment_secret,
next_per_commitment_point=rev.next_per_commitment_point)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)

def on_commitment_signed(self, chan: Channel, payload):
async def on_commitment_signed(self, chan: Channel, payload):
if chan.peer_state == PeerState.BAD:
return
self.logger.info(f'on_commitment_signed. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(LOCAL)}.')
Expand All @@ -1499,20 +1512,20 @@ def on_commitment_signed(self, chan: Channel, payload):
data = payload["htlc_signature"]
htlc_sigs = list(chunks(data, 64))
chan.receive_new_commitment(payload["signature"], htlc_sigs)
self.send_revoke_and_ack(chan)
await self.send_revoke_and_ack(chan)
self.received_commitsig_event.set()
self.received_commitsig_event.clear()

def on_update_fulfill_htlc(self, chan: Channel, payload):
async def on_update_fulfill_htlc(self, chan: Channel, payload):
preimage = payload["payment_preimage"]
payment_hash = sha256(preimage)
htlc_id = payload["id"]
self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.receive_htlc_settle(preimage, htlc_id) # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
self.lnworker.save_preimage(payment_hash, preimage)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)

def on_update_fail_malformed_htlc(self, chan: Channel, payload):
async def on_update_fail_malformed_htlc(self, chan: Channel, payload):
htlc_id = payload["id"]
failure_code = payload["failure_code"]
self.logger.info(f"on_update_fail_malformed_htlc. chan {chan.get_id_for_log()}. "
Expand All @@ -1522,7 +1535,7 @@ def on_update_fail_malformed_htlc(self, chan: Channel, payload):
raise RemoteMisbehaving(f"received update_fail_malformed_htlc with unexpected failure code: {failure_code}")
reason = OnionRoutingFailure(code=failure_code, data=payload["sha256_of_onion"])
chan.receive_fail_htlc(htlc_id, error_bytes=None, reason=reason)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)

def on_update_add_htlc(self, chan: Channel, payload):
payment_hash = payload["payment_hash"]
Expand All @@ -1546,7 +1559,7 @@ def on_update_add_htlc(self, chan: Channel, payload):
chan.receive_htlc(htlc, onion_packet)
util.trigger_callback('htlc_added', chan, htlc, RECEIVED)

def maybe_forward_htlc(
async def maybe_forward_htlc(
self, *,
htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket) -> Tuple[bytes, int]:
Expand Down Expand Up @@ -1631,7 +1644,7 @@ def maybe_forward_htlc(
except BaseException as e:
self.logger.info(f"failed to forward htlc: error sending message. {e}")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
next_peer.maybe_send_commitment(next_chan)
await next_peer.maybe_send_commitment(next_chan)
return next_chan_scid, next_htlc.htlc_id

def maybe_forward_trampoline(
Expand Down Expand Up @@ -1841,14 +1854,14 @@ def fail_malformed_htlc(self, *, chan: Channel, htlc_id: int, reason: OnionRouti
sha256_of_onion=reason.data,
failure_code=reason.code)

def on_revoke_and_ack(self, chan: Channel, payload):
async def on_revoke_and_ack(self, chan: Channel, payload):
if chan.peer_state == PeerState.BAD:
return
self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}')
rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])
chan.receive_revocation(rev)
self.lnworker.save_channel(chan)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)
self._received_revack_event.set()
self._received_revack_event.clear()

Expand Down Expand Up @@ -1894,7 +1907,7 @@ async def maybe_update_fee(self, chan: Channel):
"update_fee",
channel_id=chan.channel_id,
feerate_per_kw=feerate_per_kw)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)

@log_exceptions
async def close_channel(self, chan_id: bytes):
Expand All @@ -1912,6 +1925,7 @@ async def close_channel(self, chan_id: bytes):
raise Exception('The remote peer did not send their final signature. The channel may not have been be closed')
return txid

@runs_in_taskgroup
async def on_shutdown(self, chan: Channel, payload):
# TODO: A receiving node: if it hasn't received a funding_signed (if it is a
# funder) or a funding_created (if it is a fundee):
Expand Down Expand Up @@ -2195,7 +2209,7 @@ async def htlc_switch(self):
for chan_id, chan in self.channels.items():
if not chan.can_send_ctx_updates():
continue
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)
done = set()
unfulfilled = chan.unfulfilled_htlcs
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items():
Expand All @@ -2216,7 +2230,7 @@ async def htlc_switch(self):
error_reason = e
else:
try:
preimage, fw_info, error_bytes = self.process_unfulfilled_htlc(
preimage, fw_info, error_bytes = await self.process_unfulfilled_htlc(
chan=chan,
htlc=htlc,
forwarding_info=forwarding_info,
Expand Down Expand Up @@ -2248,7 +2262,7 @@ async def htlc_switch(self):
local_ctn, remote_ctn, onion_packet_hex, forwarding_info = unfulfilled.pop(htlc_id)
if forwarding_info:
self.lnworker.downstream_htlc_to_upstream_peer_map.pop(forwarding_info, None)
self.maybe_send_commitment(chan)
await self.maybe_send_commitment(chan)

def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
done = set()
Expand All @@ -2273,7 +2287,7 @@ async def htlc_switch_iteration():
await group.spawn(htlc_switch_iteration())
await group.spawn(self.got_disconnected.wait())

def process_unfulfilled_htlc(
async def process_unfulfilled_htlc(
self, *,
chan: Channel,
htlc: UpdateAddHtlc,
Expand Down Expand Up @@ -2332,7 +2346,7 @@ def process_unfulfilled_htlc(
# HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
return None, None, None
next_chan_id, next_htlc_id = self.maybe_forward_htlc(
next_chan_id, next_htlc_id = await self.maybe_forward_htlc(
htlc=htlc,
processed_onion=processed_onion)
fw_info = (next_chan_id.hex(), next_htlc_id)
Expand Down
2 changes: 1 addition & 1 deletion electrum/lnworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ async def pay_to_route(
if not peer:
raise PaymentFailure('Dropped peer')
await peer.initialized
htlc = peer.pay(
htlc = await peer.pay(
route=route,
chan=chan,
amount_msat=amount_msat,
Expand Down
16 changes: 8 additions & 8 deletions electrum/tests/test_lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ async def f():
self._send_fake_htlc(p2, chan_BA)
self._send_fake_htlc(p1, chan_AB)
p2.transport.queue.put_nowait(asyncio.Event()) # break Bob's incoming pipe
self.assertTrue(p2.maybe_send_commitment(chan_BA))
self.assertTrue(await p2.maybe_send_commitment(chan_BA))
await p1.received_commitsig_event.wait()
await group.cancel_remaining()
# simulating disconnection. recreate transports.
Expand Down Expand Up @@ -666,8 +666,8 @@ async def f():
self._send_fake_htlc(p2, chan_BA)
self._send_fake_htlc(p1, chan_AB)
p2.transport.queue.put_nowait(asyncio.Event()) # break Bob's incoming pipe
self.assertTrue(p1.maybe_send_commitment(chan_AB))
self.assertTrue(p2.maybe_send_commitment(chan_BA))
self.assertTrue(await p1.maybe_send_commitment(chan_AB))
self.assertTrue(await p2.maybe_send_commitment(chan_BA))
await p1.received_commitsig_event.wait()
await group.cancel_remaining()
# simulating disconnection. recreate transports.
Expand Down Expand Up @@ -737,7 +737,7 @@ async def pay():
q1 = w1.sent_htlcs[lnaddr2.paymenthash]
q2 = w2.sent_htlcs[lnaddr1.paymenthash]
# alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None
p1.maybe_send_commitment = lambda x: asyncio.sleep(0)
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0]
amount_msat = lnaddr2.get_amount_msat()
await w1.pay_to_route(
Expand All @@ -752,7 +752,7 @@ async def pay():
)
p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None
p2.maybe_send_commitment = lambda x: asyncio.sleep(0)
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0]
amount_msat = lnaddr1.get_amount_msat()
await w2.pay_to_route(
Expand All @@ -769,8 +769,8 @@ async def pay():
# sleep a bit so that they both receive msgs sent so far
await asyncio.sleep(0.2)
# now they both send COMMITMENT_SIGNED
p1.maybe_send_commitment(alice_channel)
p2.maybe_send_commitment(bob_channel)
await p1.maybe_send_commitment(alice_channel)
await p2.maybe_send_commitment(bob_channel)

htlc_log1 = await q1.get()
assert htlc_log1.success
Expand Down Expand Up @@ -1241,7 +1241,7 @@ async def pay():
await asyncio.wait_for(p2.initialized, 1)
# alice sends htlc
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
p1.pay(route=route,
await p1.pay(route=route,
chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(),
total_msat=lnaddr.get_amount_msat(),
Expand Down