diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index 069387f4..01547e8d 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -97,6 +97,8 @@ def __init__(self, connection, address, port, source, source_port, manager=None) self._wake_timer = asyncio.Condition() self._receiver_task = None self._sender_task = None + self._send_pending = False + self._check_for_events = False async def _receiver(self): try: @@ -117,7 +119,10 @@ async def _receiver(self): continue self._connection.receive_datagram(datagram, address, time.time()) # Wake up the timer in case the sender is sleeping, as there may be - # stuff to send now. + # stuff to send now. We need to set a flag as well as wake up the + # timer to avoid a race where we get a datagram and generate an + # event right before the sender is going to sleep. + self._check_for_events = True async with self._wake_timer: self._wake_timer.notify_all() except Exception: @@ -135,16 +140,19 @@ async def _wait_for_wake_timer(self): async def _sender(self): await self._socket_created.wait() while not self._done: + self._send_pending = False datagrams = self._connection.datagrams_to_send(time.time()) for datagram, address in datagrams: assert address == self._peer await self._socket.sendto(datagram, self._peer, None) (expiration, interval) = self._get_timer_values() - try: - await asyncio.wait_for(self._wait_for_wake_timer(), interval) - except Exception: - pass + if not (self._check_for_events or self._send_pending): + try: + await asyncio.wait_for(self._wait_for_wake_timer(), interval) + except Exception: + pass self._handle_timer(expiration) + self._check_for_events = False await self._handle_events() async def _handle_events(self): @@ -194,6 +202,7 @@ async def _handle_events(self): async def write(self, stream, data, is_end=False): self._connection.send_stream_data(stream, data, is_end) + self._send_pending = True async with self._wake_timer: self._wake_timer.notify_all()