Skip to content

Commit

Permalink
Properly fix asyncio QUIC shutdown races [#1069].
Browse files Browse the repository at this point in the history
There were two basic issues:

1) We did not wake up the sender thread to do work in some cases, and could
   sleep for a long time.
2) asyncio.wait_for() does not instantly run the function, it just schedules
   it, and our guards against lost wakeups were thus in the wrong place.
  • Loading branch information
rthalley committed Mar 25, 2024
1 parent 0aa713d commit 7fb6e92
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions dns/quic/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def __init__(self, connection, address, port, source, source_port, manager=None)
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._send_pending = False
self._check_for_events = False
self._wake_pending = False

async def _receiver(self):
try:
Expand All @@ -119,40 +118,39 @@ async def _receiver(self):
continue
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now. We need to set a flag as well as wake up the
# timer to avoid a race where we get a datagram and generate an
# event right before the sender is going to sleep.
self._check_for_events = True
async with self._wake_timer:
self._wake_timer.notify_all()
# stuff to send now.
await self._wakeup()
except Exception:
pass
finally:
self._done = True
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()
self._handshake_complete.set()

async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()

async def _wait_for_wake_timer(self):
async with self._wake_timer:
await self._wake_timer.wait()
if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False

async def _sender(self):
await self._socket_created.wait()
while not self._done:
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams:
assert address == self._peer
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
if not (self._check_for_events or self._send_pending):
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
self._check_for_events = False
await self._handle_events()

async def _handle_events(self):
Expand Down Expand Up @@ -202,9 +200,7 @@ async def _handle_events(self):

async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()

def run(self):
if self._closed:
Expand All @@ -231,8 +227,7 @@ async def close(self):
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
async with self._wake_timer:
self._wake_timer.notify_all()
await self._wakeup()
try:
await self._receiver_task
except asyncio.CancelledError:
Expand Down

0 comments on commit 7fb6e92

Please sign in to comment.