Skip to content

Commit

Permalink
Fix a number of timeout bugs with QUIC [#954].
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Jul 13, 2023
1 parent 000c37b commit 60253ac
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 23 deletions.
7 changes: 4 additions & 3 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_compute_times,
_have_http2,
_matches_destination,
_remaining,
have_doh,
ssl,
)
Expand Down Expand Up @@ -736,11 +737,11 @@ async def quic(
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
start = time.time()
stream = await the_connection.make_stream()
(start, expiration) = _compute_times(timeout)
stream = await the_connection.make_stream(timeout)
async with stream:
await stream.send(wire, True)
wire = await stream.receive(timeout)
wire = await stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
Expand Down
6 changes: 3 additions & 3 deletions dns/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,10 +1186,10 @@ def quic(
with manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
start = time.time()
with the_connection.make_stream() as stream:
(start, expiration) = _compute_times(timeout)
with the_connection.make_stream(timeout) as stream:
stream.send(wire, True)
wire = stream.receive(timeout)
wire = stream.receive(_remaining(expiration))
finish = time.time()
r = dns.message.from_wire(
wire,
Expand Down
12 changes: 8 additions & 4 deletions dns/quic/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aioquic.quic.events # type: ignore

import dns.asyncbackend
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
Expand Down Expand Up @@ -38,8 +39,8 @@ async def wait_for(self, amount, expiration):
self._expecting = amount
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except Exception:
pass
except TimeoutError:
raise dns.exception.Timeout
self._expecting = 0

async def receive(self, timeout=None):
Expand Down Expand Up @@ -166,8 +167,11 @@ def run(self):
self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender())

async def make_stream(self):
await self._handshake_complete.wait()
async def make_stream(self, timeout=None):
try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
except TimeoutError:
raise dns.exception.Timeout
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
Expand Down
4 changes: 2 additions & 2 deletions dns/quic/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
import struct
import time
from typing import Any
from typing import Any, Optional

import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
Expand Down Expand Up @@ -134,7 +134,7 @@ def _handle_timer(self, expiration):


class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self) -> Any:
async def make_stream(self, timeout: Optional[float] = None) -> Any:
pass


Expand Down
8 changes: 5 additions & 3 deletions dns/quic/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore

import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
Expand Down Expand Up @@ -42,7 +43,7 @@ def wait_for(self, amount, expiration):
self._expecting = amount
with self._wake_up:
if not self._wake_up.wait(timeout):
raise TimeoutError
raise dns.exception.Timeout
self._expecting = 0

def receive(self, timeout=None):
Expand Down Expand Up @@ -171,8 +172,9 @@ def run(self):
self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start()

def make_stream(self):
self._handshake_complete.wait()
def make_stream(self, timeout=None):
if not self._handshake_complete.wait(timeout):
raise dns.exception.Timeout
with self._lock:
if self._done:
raise UnexpectedEOF
Expand Down
24 changes: 16 additions & 8 deletions dns/quic/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import aioquic.quic.events # type: ignore
import trio

import dns.exception
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
Expand Down Expand Up @@ -45,6 +46,7 @@ async def receive(self, timeout=None):
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
raise dns.exception.Timeout

async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
Expand Down Expand Up @@ -137,14 +139,20 @@ async def run(self):
nursery.start_soon(self._worker)
self._run_done.set()

async def make_stream(self):
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def make_stream(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
raise dns.exception.Timeout

async def close(self):
if not self._closed:
Expand Down

0 comments on commit 60253ac

Please sign in to comment.