From c4571e7415758655bd390f913c6f113a5e1b1ce1 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Tue, 26 Mar 2024 19:08:42 -0700 Subject: [PATCH] Overhaul nanonameserver, adding DoT, DoH, DoH3, and DoQ support. Co-authored-by: bwelling@xbill.org --- pyproject.toml | 2 + tests/doh.py | 65 ++++++++ tests/doq.py | 357 ++++++++++++++++++++++++++++++++++++++++ tests/nanonameserver.py | 307 ++++++++++++++++++++++++---------- tests/nanoquic.py | 137 --------------- tests/test_doq.py | 31 ++-- tests/tls/ca.crt | 36 ++-- tests/tls/private.pem | 30 +++- tests/tls/public.crt | 52 +++--- 9 files changed, 722 insertions(+), 295 deletions(-) create mode 100644 tests/doh.py create mode 100644 tests/doq.py delete mode 100644 tests/nanoquic.py diff --git a/pyproject.toml b/pyproject.toml index 749f2a834..7ebefccc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,11 +33,13 @@ dynamic = ["version"] dev = [ "black>=23.1.0", "coverage>=7.0", + "hypercorn>=0.16.0", "flake8>=7", "mypy>=1.8", "pylint>=3", "pytest>=7.4", "pytest-cov>=4.1.0", + "quart-trio>=0.11.0", "sphinx>=7.2.0", "twine>=4.0.0", "wheel>=0.42.0", diff --git a/tests/doh.py b/tests/doh.py new file mode 100644 index 000000000..97cad6139 --- /dev/null +++ b/tests/doh.py @@ -0,0 +1,65 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import base64 +import functools +import socket + +import hypercorn.config +import hypercorn.trio +import quart +import quart_trio + + +def setup(server, connection_type): + name = f"{__name__}-{connection_type.name}" + app = quart_trio.QuartTrio(name) + app.logger.handlers = [] + + @app.route("/dns-query", methods=["GET", "POST"]) + async def dns_query(): + if quart.request.method == "POST": + wire = await quart.request.body + else: + encoded = quart.request.args["dns"] + remainder = len(encoded) % 4 + if remainder != 0: + encoded += "=" * (4 - remainder) + wire = base64.urlsafe_b64decode(encoded) + for body in server.handle_wire( + wire, + quart.request.remote_addr, + quart.request.server, + connection_type, + ): + if body is not None: + return quart.Response(body, mimetype="application/dns-message") + else: + return quart.Response(status=500) + + return app + + +def make_server(server, sock, connection_type, tls_chain, tls_key): + doh_app = setup(server, connection_type) + hconfig = hypercorn.config.Config() + fd = sock.fileno() + if sock.type == socket.SOCK_STREAM: + # We put http/1.1 in the ALPN as we don't mind, but DoH is + # supposed to be H2 officially. + hconfig.alpn_protocols = ["h2", "http/1.1"] + hconfig.bind = [f"fd://{fd}"] + hconfig.quic_bind = [] + else: + hconfig.alpn_protocols = ["h3"] + # We should be able to pass bind=[], but that triggers a bug in + # hypercorn. So, create a dummy socket and bind to it. + tmp_sock = socket.create_server(("127.0.0.1", 0)) + hconfig.bind = [f"fd://{tmp_sock.fileno()}"] + tmp_sock.detach() + hconfig.quic_bind = [f"fd://{fd}"] + sock.detach() + hconfig.certfile = tls_chain + hconfig.keyfile = tls_key + hconfig.accesslog = None + hconfig.errorlog = None + return functools.partial(hypercorn.trio.serve, doh_app, hconfig) diff --git a/tests/doq.py b/tests/doq.py new file mode 100644 index 000000000..5e052bfa5 --- /dev/null +++ b/tests/doq.py @@ -0,0 +1,357 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Implement DNS-over-QUIC + +import secrets +import struct +import time +from typing import Optional + +import aioquic +import aioquic.buffer +import aioquic.quic.configuration +import aioquic.quic.connection +import aioquic.quic.events +import aioquic.quic.logger +import aioquic.quic.packet +import aioquic.quic.retry +import trio + +import dns.exception +from dns._asyncbackend import NullContext +from dns.quic._common import Buffer + +MAX_SAVED_SESSIONS = 100 + + +class Stream: + def __init__(self, connection, stream_id): + self.connection = connection + self.stream_id = stream_id + self.buffer = Buffer() + self.expecting = 0 + self.wake_up = trio.Condition() + self.headers = None + self.trailers = None + + async def wait_for(self, amount: int): + while True: + if self.buffer.have(amount): + return + self.expecting = amount + async with self.wake_up: + await self.wake_up.wait() + self.expecting = 0 + + async def receive(self, timeout: Optional[float] = None): + context: trio.CancelScope | NullContext + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self.wait_for(2) + (size,) = struct.unpack("!H", self.buffer.get(2)) + await self.wait_for(size) + return self.buffer.get(size) + raise dns.exception.Timeout + + async def send(self, datagram: bytes, is_end=False): + l = len(datagram) + data = struct.pack("!H", l) + datagram + await self.connection.write(self.stream_id, data, is_end) + + async def add_input(self, data: bytes, is_end: bool): + self.buffer.put(data, is_end) + # Note it is important that we wake up if we're ending! + if (self.expecting > 0 and self.buffer.have(self.expecting)) or is_end: + async with self.wake_up: + self.wake_up.notify() + + def seen_end(self) -> bool: + return self.buffer.seen_end() + + async def run(self): + try: + wire = await self.receive() + is_get = False + path: Optional[bytes] + for wire in self.connection.listener.server.handle_wire( + wire, + self.connection.peer, + self.connection.listener.socket.getsockname(), + self.connection.listener.connection_type, + ): + break + await self.send(wire, True) + except Exception: + if not self.seen_end(): + self.connection.reset(self.stream_id) + finally: + self.connection.stream_done(self) + + +class Connection: + def __init__(self, listener, cid, peer, retry_cid=None): + self.original_cid: bytes = cid + self.listener = listener + self.cids: set[bytes] = set() + self.cids.add(cid) + self.listener.connections[cid] = self + self.peer = peer + self.quic_connection = aioquic.quic.connection.QuicConnection( + configuration=listener.quic_config, + original_destination_connection_id=cid, + retry_source_connection_id=retry_cid, + session_ticket_fetcher=self.listener.pop_session_ticket, + session_ticket_handler=self.listener.store_session_ticket, + ) + self.cids.add(self.quic_connection.host_cid) + self.listener.connections[self.quic_connection.host_cid] = self + self.send_channel: trio.MemorySendChannel + self.receive_channel: trio.MemoryReceiveChannel + self.send_channel, self.receive_channel = trio.open_memory_channel(100) + self.send_pending = False + self.done = False + self.worker_scope = None + self.streams = {} + + def get_timer_values(self, now: float) -> tuple[float, float]: + expiration = self.quic_connection.get_timer() + if expiration is None: + expiration = now + 3600 # arbitrary "big" value + interval = max(expiration - now, 0) + return (expiration, interval) + + async def close_open_streams(self): + # We copy the list here as awaiting might let the dictionary change + # due to the stream finishing. + for stream in list(self.streams.values()): + if not stream.seen_end(): + await stream.add_input(b"", True) + + def create_stream(self, nursery: trio.Nursery, stream_id: int) -> Stream: + stream = Stream(self, stream_id) + self.streams[stream_id] = stream + nursery.start_soon(stream.run) + return stream + + async def handle_events(self, nursery: trio.Nursery): + count = 0 + while not self.done: + event = self.quic_connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self.streams.get(event.stream_id) + if stream is None: + stream = self.create_stream(nursery, event.stream_id) + await stream.add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.ConnectionTerminated): + await self.close_open_streams() + self.done = True + elif isinstance(event, aioquic.quic.events.ConnectionIdIssued): + cid = event.connection_id + if cid not in self.cids: + self.cids.add(cid) + self.listener.connections[cid] = self + else: + self.done = True + elif isinstance(event, aioquic.quic.events.ConnectionIdRetired): + cid = event.connection_id + if cid in self.cids: + # These should not fail but we eat them just in case so we + # don't crash the whole connection. + self.cids.remove(cid) + del self.listener.connections[cid] + else: + self.done = True + count += 1 + if count > 10: + # yield + count = 0 + await trio.sleep(0) + + async def run(self): + try: + async with trio.open_nursery() as nursery: + while not self.done: + now = time.time() + (expiration, interval) = self.get_timer_values(now) + # Note it must be trio.current_time() and not now due to how + # trio time works! + if self.send_pending: + interval = 0 + self.send_pending = False + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self.worker_scope: + (datagram, peer) = await self.receive_channel.receive() + self.quic_connection.receive_datagram(datagram, peer, now) + self.worker_scope = None + now = time.time() + if expiration <= now: + self.quic_connection.handle_timer(now) + await self.handle_events(nursery) + datagrams = self.quic_connection.datagrams_to_send(now) + for datagram, _ in datagrams: + await self.listener.socket.sendto(datagram, self.peer) + finally: + await self.close_open_streams() + for cid in self.cids: + try: + del self.listener.connections[cid] + except KeyError: + pass + + def maybe_wake_up(self): + self.send_pending = True + if self.worker_scope is not None: + self.worker_scope.cancel() + + async def write(self, stream: int, data: bytes, is_end=False): + if not self.done: + self.quic_connection.send_stream_data(stream, data, is_end) + self.maybe_wake_up() + + def reset(self, stream: int, error=0): + if not self.done: + self.quic_connection.reset_stream(stream, error) + self.maybe_wake_up() + + def stream_done(self, stream: Stream): + try: + del self.streams[stream.stream_id] + except KeyError: + pass + + +class Listener: + def __init__( + self, + server, + socket, + connection_type, + tls_chain, + tls_key, + quic_log_directory=None, + quic_retry=False, + ): + self.server = server + self.socket = socket # note this is a trio socket + self.connection_type = connection_type + self.connections = {} + self.session_tickets = {} + self.done = False + alpn_protocols = ["doq"] + self.quic_config = aioquic.quic.configuration.QuicConfiguration( + is_client=False, alpn_protocols=alpn_protocols + ) + if quic_log_directory is not None: + self.quic_config.quic_logger = aioquic.quic.logger.QuicFileLogger( + quic_log_directory + ) + self.quic_config.load_cert_chain(tls_chain, tls_key) + self.retry: Optional[aioquic.quic.retry.QuicRetryTokenHandler] + if quic_retry: + self.retry = aioquic.quic.retry.QuicRetryTokenHandler() + else: + self.retry = None + + def pop_session_ticket(self, key): + try: + return self.session_tickets.pop(key) + except KeyError: + return None + + def store_session_ticket(self, session_ticket): + self.session_tickets[session_ticket.ticket] = session_ticket + while len(self.session_tickets) > MAX_SAVED_SESSIONS: + # Grab the first key + key = next(iter(self.session_tickets.keys())) + del self.session_tickets[key] + + async def run(self): + async with trio.open_nursery() as nursery: + while True: + data = None + peer = None + try: + (data, peer) = await self.socket.recvfrom(65535) + except Exception: + continue + buffer = aioquic.buffer.Buffer(data=data) + try: + header = aioquic.quic.packet.pull_quic_header( + buffer, self.quic_config.connection_id_length + ) + except Exception: + continue + cid = header.destination_cid + connection = self.connections.get(cid) + if ( + connection is None + and header.version is not None + and len(data) >= 1200 + and header.version not in self.quic_config.supported_versions + ): + wire = aioquic.quic.packet.encode_quic_version_negotiation( + source_cid=cid, + destination_cid=header.source_cid, + supported_versions=self.quic_config.supported_versions, + ) + await self.socket.sendto(wire, peer) + continue + + if ( + connection is None + and len(data) >= 1200 + and header.packet_type == aioquic.quic.packet.PACKET_TYPE_INITIAL + ): + retry_cid = None + if self.retry is not None: + if not header.token: + if header.version is None: + continue + source_cid = secrets.token_bytes(8) + wire = aioquic.quic.packet.encode_quic_retry( + version=header.version, + source_cid=source_cid, + destination_cid=header.source_cid, + original_destination_cid=header.destination_cid, + retry_token=self.retry.create_token( + peer, header.destination_cid, source_cid + ), + ) + await self.socket.sendto(wire, peer) + continue + else: + try: + (cid, retry_cid) = self.retry.validate_token( + peer, header.token + ) + # We need to recheck the cid here in case of duplicates, + # as we don't want to kick off another connection! + connection = self.connections.get(cid) + if connection is not None: + # duplicate! + continue + except ValueError: + continue + + connection = Connection(self, cid, peer, retry_cid) + nursery.start_soon(connection.run) + + if connection is not None: + try: + connection.send_channel.send_nowait((data, peer)) + except trio.WouldBlock: + pass + + # Listeners are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py index bc655abc2..2727e42a8 100644 --- a/tests/nanonameserver.py +++ b/tests/nanonameserver.py @@ -4,14 +4,43 @@ import enum import errno import functools +import logging +import logging.config import socket +import ssl import struct import threading + import trio import dns.asyncquery +import dns.inet import dns.message import dns.rcode +from tests.util import here + +try: + import tests.doq + + have_doq = True +except Exception: + have_doq = False + +try: + import tests.doh + + have_doh = True +except Exception as e: + have_doh = False + + +class ConnectionType(enum.IntEnum): + UDP = 1 + TCP = 2 + DOT = 3 + DOH = 4 + DOQ = 5 + DOH3 = 6 async def read_exactly(stream, count): @@ -28,11 +57,6 @@ async def read_exactly(stream, count): return s -class ConnectionType(enum.IntEnum): - UDP = 1 - TCP = 2 - - class Request: def __init__(self, message, wire, peer, local, connection_type): self.message = message @@ -59,7 +83,6 @@ def qtype(self): class Server(threading.Thread): - """The nanoserver is a nameserver skeleton suitable for faking a DNS server for various testing purposes. It executes with a trio run loop in a dedicated thread, and is a context manager. Exiting the @@ -81,65 +104,108 @@ class Server(threading.Thread): def __init__( self, + *, address="127.0.0.1", port=0, - enable_udp=True, - enable_tcp=True, + dot_port=0, + doh_port=0, + protocols=tuple(p for p in ConnectionType), use_thread=True, origin=None, keyring=None, + tls_chain=here("tls/public.crt"), + tls_key=here("tls/private.pem"), ): super().__init__() self.address = address self.port = port - self.enable_udp = enable_udp - self.enable_tcp = enable_tcp + self.dot_port = dot_port + self.doh_port = doh_port + self.protocols = set(protocols) + if not have_doq: + self.protocols.discard(ConnectionType.DOQ) + if not have_doh: + self.protocols.discard(ConnectionType.DOH) + self.protocols.discard(ConnectionType.DOH3) self.use_thread = use_thread self.origin = origin self.keyring = keyring self.left = None self.right = None - self.udp = None - self.udp_address = None - self.tcp = None - self.tcp_address = None + self.sockets = {} + self.addresses = {} + self.tls_chain = tls_chain + self.tls_key = tls_key + + def get_address(self, connection_type): + return self.addresses[connection_type] + + # For backwards compatibility + @property + def udp_address(self): + return self.addresses[ConnectionType.UDP] + + @property + def tcp_address(self): + return self.addresses[ConnectionType.TCP] + + @property + def doq_address(self): + return self.addresses[ConnectionType.DOQ] + + def caught(self, who, e): + print(who, "caught", type(e), e) + + def open_sockets(self, port, udp_type, tcp_type): + want_udp = udp_type in self.protocols + want_tcp = tcp_type in self.protocols + udp = None + tcp = None + af = dns.inet.af_for_address(self.address) + + if port != 0 or not (want_udp and want_tcp): + if want_udp: + udp = socket.socket(af, socket.SOCK_DGRAM, 0) + udp.bind((self.address, port)) + self.sockets[udp_type] = udp + if want_tcp: + tcp = socket.create_server((self.address, port), family=af) + self.sockets[tcp_type] = tcp + return - def __enter__(self): - (self.left, self.right) = socket.socketpair() - # We're making the sockets now so they can be sent to by the - # caller immediately (i.e. no race with the listener starting - # in the thread). open_udp_sockets = [] try: while True: - if self.enable_udp: - self.udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) - self.udp.bind((self.address, self.port)) - self.udp_address = self.udp.getsockname() - if self.enable_tcp: - self.tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - self.tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if self.port == 0 and self.enable_udp: - try: - self.tcp.bind((self.address, self.udp_address[1])) - except OSError: - # We can get EADDRINUSE and other errors like EPERM, so - # we just remember to close the UDP socket later, try again, - # and hope we get a better port. You'd think the OS would - # know better... - if len(open_udp_sockets) < 100: - open_udp_sockets.append(self.udp) - continue - # 100 tries to find a port is enough! Give up! - raise - else: - self.tcp.bind((self.address, self.port)) - self.tcp.listen() - self.tcp_address = self.tcp.getsockname() - break + udp = socket.socket(af, socket.SOCK_DGRAM, 0) + udp.bind((self.address, port)) + try: + udp_port = udp.getsockname()[1] + tcp = socket.create_server((self.address, udp_port), family=af) + self.sockets[udp_type] = udp + self.sockets[tcp_type] = tcp + return + except OSError: + # We failed to open the corresponding TCP port. Keep the UDP socket + # open, try again, and hope we get a better port. + if len(open_udp_sockets) < 100: + open_udp_sockets.append(udp) + continue + # 100 tries to find a port is enough! Give up! + raise finally: for udp_socket in open_udp_sockets: udp_socket.close() + + def __enter__(self): + (self.left, self.right) = socket.socketpair() + # We're making the sockets now so they can be sent to by the + # caller immediately (i.e. no race with the listener starting + # in the thread). + self.open_sockets(self.port, ConnectionType.UDP, ConnectionType.TCP) + self.open_sockets(self.dot_port, ConnectionType.DOQ, ConnectionType.DOT) + self.open_sockets(self.doh_port, ConnectionType.DOH3, ConnectionType.DOH) + for proto, sock in self.sockets.items(): + self.addresses[proto] = sock.getsockname() if self.use_thread: self.start() return self @@ -151,10 +217,8 @@ def __exit__(self, ex_ty, ex_va, ex_tr): self.join() if self.right: self.right.close() - if self.udp: - self.udp.close() - if self.tcp: - self.tcp.close() + for sock in self.sockets.values(): + sock.close() async def wait_for_input_or_eof(self): # @@ -204,12 +268,6 @@ def handle_wire(self, wire, peer, local, connection_type): # It also handles any exceptions from handle() # # Returns a (possibly empty) list of wire format message to send. - # - # XXXRTH It might be nice to have a "debug mode" in the server - # where we'd print something in all the places we're eating - # exceptions. That way bugs in handle() would be easier to - # find. - # items = [] r = None try: @@ -236,8 +294,10 @@ def handle_wire(self, wire, peer, local, connection_type): if not items: request = Request(q, wire, peer, local, connection_type) items = self.maybe_listify(self.handle(request)) - except Exception: - # Exceptions from handle get a SERVFAIL response. + except Exception as e: + # Exceptions from handle get a SERVFAIL response, and a print because + # they are usually bugs in the the test! + self.caught("handle", e) r = dns.message.make_response(q) r.set_rcode(dns.rcode.SERVFAIL) items = [r] @@ -252,57 +312,100 @@ def handle_wire(self, wire, peer, local, connection_type): elif thing is not None: yield thing - async def serve_udp(self): - with trio.socket.from_stdlib_socket(self.udp) as sock: - self.udp = None # we own cleanup - local = self.udp_address + async def serve_udp(self, connection_type): + with trio.socket.from_stdlib_socket(self.sockets[connection_type]) as sock: + self.sockets.pop(connection_type) # we own cleanup + local = self.addresses[connection_type] while True: try: (wire, peer) = await sock.recvfrom(65535) - for wire in self.handle_wire(wire, peer, local, ConnectionType.UDP): + for wire in self.handle_wire(wire, peer, local, connection_type): await sock.sendto(wire, peer) - except Exception: - pass + except Exception as e: + self.caught("serve_udp", e) - async def serve_tcp(self, stream): + async def serve_tcp(self, connection_type, stream): try: - peer = stream.socket.getpeername() - local = stream.socket.getsockname() + if connection_type == ConnectionType.DOT: + peer = stream.transport_stream.socket.getpeername() + local = stream.transport_stream.socket.getsockname() + else: + assert connection_type == ConnectionType.TCP + peer = stream.socket.getpeername() + local = stream.socket.getsockname() while True: ldata = await read_exactly(stream, 2) (l,) = struct.unpack("!H", ldata) wire = await read_exactly(stream, l) - for wire in self.handle_wire(wire, peer, local, ConnectionType.TCP): + for wire in self.handle_wire(wire, peer, local, connection_type): l = len(wire) stream_message = struct.pack("!H", l) + wire await stream.send_all(stream_message) - except Exception: - pass + except Exception as e: + self.caught("serve_tcp", e) - async def orchestrate_tcp(self): - with trio.socket.from_stdlib_socket(self.tcp) as sock: - self.tcp = None # we own cleanup + async def orchestrate_tcp(self, connection_type): + with trio.socket.from_stdlib_socket(self.sockets[connection_type]) as sock: + self.sockets.pop(connection_type) # we own cleanup listener = trio.SocketListener(sock) + if connection_type == ConnectionType.DOT: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + ssl_context.load_cert_chain(self.tls_chain, self.tls_key) + listener = trio.SSLListener(listener, ssl_context) + serve = functools.partial(self.serve_tcp, connection_type) async with trio.open_nursery() as nursery: serve = functools.partial( trio.serve_listeners, - self.serve_tcp, + serve, [listener], handler_nursery=nursery, ) nursery.start_soon(serve) + async def serve_doq(self, connection_type) -> None: + with trio.socket.from_stdlib_socket(self.sockets[connection_type]) as sock: + self.sockets.pop(connection_type) # we own cleanup + async with tests.doq.Listener( + self, sock, connection_type, self.tls_chain, self.tls_key + ) as listener: + try: + await listener.run() + except Exception as e: + self.caught("serve_doq", e) + + async def serve_doh(self, connection_type) -> None: + server = tests.doh.make_server( + self, + self.sockets[connection_type], + connection_type, + self.tls_chain, + self.tls_key, + ) + try: + await server() + except Exception as e: + self.caught("serve_doh", e) + async def main(self): + handlers = { + ConnectionType.UDP: self.serve_udp, + ConnectionType.TCP: self.orchestrate_tcp, + ConnectionType.DOT: self.orchestrate_tcp, + ConnectionType.DOH: self.serve_doh, + ConnectionType.DOH3: self.serve_doh, + ConnectionType.DOQ: self.serve_doq, + } + try: async with trio.open_nursery() as nursery: if self.use_thread: nursery.start_soon(self.wait_for_input_or_eof) - if self.enable_udp: - nursery.start_soon(self.serve_udp) - if self.enable_tcp: - nursery.start_soon(self.orchestrate_tcp) - except Exception: - pass + for connection_type in self.protocols: + nursery.start_soon(handlers[connection_type], connection_type) + + except Exception as e: + self.caught("nanoserver main", e) def run(self): if not self.use_thread: @@ -314,24 +417,46 @@ def run(self): import sys import time + logger = logging.getLogger(__name__) + format = "%(asctime)s %(levelname)s: %(message)s" + logging.basicConfig(format=format, level=logging.INFO) + logging.config.dictConfig( + { + "version": 1, + "incremental": True, + "loggers": { + "quart.app": { + "level": "INFO", + }, + "quart.serving": { + "propagate": False, + "level": "ERROR", + }, + "quic": { + "level": "CRITICAL", + }, + }, + } + ) + async def trio_main(): try: - with Server(port=5354, use_thread=False) as server: - print( - f"Trio mode: listening on UDP: {server.udp_address}, " - + f"TCP: {server.tcp_address}" - ) + with Server( + port=5354, dot_port=5355, doh_port=5356, use_thread=False + ) as server: + print("Trio mode") + for proto, address in server.addresses.items(): + print(f" listening on {proto.name}: {address}") async with trio.open_nursery() as nursery: nursery.start_soon(server.main) - except Exception: - pass + except Exception as e: + print("trio_main caught", type(e), e) def threaded_main(): - with Server(port=5354) as server: - print( - f"Thread Mode: listening on UDP: {server.udp_address}, " - + f"TCP: {server.tcp_address}" - ) + with Server(port=5354, dot_port=5355, doh_port=5356) as server: + print("Thread mode") + for proto, address in server.addresses.items(): + print(f" listening on {proto.name}: {address}") time.sleep(300) if len(sys.argv) > 1 and sys.argv[1] == "trio": diff --git a/tests/nanoquic.py b/tests/nanoquic.py deleted file mode 100644 index 47c10431b..000000000 --- a/tests/nanoquic.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -try: - import asyncio - import socket - import struct - import threading - - import aioquic.asyncio - import aioquic.asyncio.server - import aioquic.quic.configuration - import aioquic.quic.events - - import dns.asyncquery - import dns.message - import dns.rcode - from tests.util import here - - have_quic = True - - class Request: - def __init__(self, message, wire): - self.message = message - self.wire = wire - - @property - def question(self): - return self.message.question[0] - - @property - def qname(self): - return self.question.name - - @property - def qclass(self): - return self.question.rdclass - - @property - def qtype(self): - return self.question.rdtype - - class NanoQuic(aioquic.asyncio.QuicConnectionProtocol): - def quic_event_received(self, event): - # This is a bit hackish and not fully general, but this is a test server! - if isinstance(event, aioquic.quic.events.StreamDataReceived): - data = bytes(event.data) - (wire_len,) = struct.unpack("!H", data[:2]) - wire = self.handle_wire(data[2 : 2 + wire_len]) - if wire is not None: - self._quic.send_stream_data(event.stream_id, wire, end_stream=True) - - def handle(self, request): - r = dns.message.make_response(request.message) - r.set_rcode(dns.rcode.REFUSED) - return r - - def handle_wire(self, wire): - response = None - try: - q = dns.message.from_wire(wire) - except dns.message.ShortHeader: - return - except Exception as e: - try: - q = dns.message.from_wire(wire, question_only=True) - response = dns.message.make_response(q) - response.set_rcode(dns.rcode.FORMERR) - except Exception: - return - if response is None: - try: - request = Request(q, wire) - response = self.handle(request) - except Exception: - response = dns.message.make_response(q) - response.set_rcode(dns.rcode.SERVFAIL) - wire = response.to_wire() - return struct.pack("!H", len(wire)) + wire - - class Server(threading.Thread): - def __init__(self, address="127.0.0.1"): - super().__init__() - self.address = address - self.transport = None - self.protocol = None - self.left = None - self.right = None - self.ready = threading.Event() - - def __enter__(self): - self.left, self.right = socket.socketpair() - self.start() - self.ready.wait(4) - return self - - def __exit__(self, ex_ty, ex_va, ex_tr): - if self.protocol is not None: - self.protocol.close() - if self.transport is not None: - self.transport.close() - if self.left: - self.left.close() - if self.is_alive(): - self.join() - if self.right: - self.right.close() - - async def arun(self): - reader, _ = await asyncio.open_connection(sock=self.right) - conf = aioquic.quic.configuration.QuicConfiguration( - alpn_protocols=["doq"], - is_client=False, - ) - conf.load_cert_chain(here("tls/public.crt"), here("tls/private.pem")) - loop = asyncio.get_event_loop() - (self.transport, self.protocol) = await loop.create_datagram_endpoint( - lambda: aioquic.asyncio.server.QuicServer( - configuration=conf, create_protocol=NanoQuic - ), - local_addr=(self.address, 0), - ) - info = self.transport.get_extra_info("sockname") - self.port = info[1] - self.ready.set() - try: - await reader.read(1) - except Exception: - pass - - def run(self): - asyncio.run(self.arun()) - -except ImportError: - have_quic = False - - class NanoQuic: - pass diff --git a/tests/test_doq.py b/tests/test_doq.py index 76cc3b64b..d749d8954 100644 --- a/tests/test_doq.py +++ b/tests/test_doq.py @@ -13,13 +13,15 @@ from .util import have_ipv4, have_ipv6, here +have_quic = False try: - from .nanoquic import Server + from .nanonameserver import Server - _nanoquic_available = True + have_quic = True except ImportError: - _nanoquic_available = False + pass +if not have_quic: class Server(object): pass @@ -31,15 +33,16 @@ class Server(object): addresses.append("::1") if len(addresses) == 0: # no networking - _nanoquic_available = False + have_quic = False -@pytest.mark.skipif(not _nanoquic_available, reason="requires aioquic") +@pytest.mark.skipif(not have_quic, reason="requires aioquic") def test_basic_sync(): q = dns.message.make_query("www.example.", "A") for address in addresses: - with Server(address) as server: - r = dns.query.quic(q, address, port=server.port, verify=here("tls/ca.crt")) + with Server(address=address) as server: + port = server.doq_address[1] + r = dns.query.quic(q, address, port=port, verify=here("tls/ca.crt")) assert r.rcode() == dns.rcode.REFUSED @@ -49,23 +52,25 @@ async def amain(address, port): assert r.rcode() == dns.rcode.REFUSED -@pytest.mark.skipif(not _nanoquic_available, reason="requires aioquic") +@pytest.mark.skipif(not have_quic, reason="requires aioquic") def test_basic_asyncio(): dns.asyncbackend.set_default_backend("asyncio") for address in addresses: - with Server(address) as server: - asyncio.run(amain(address, server.port)) + with Server(address=address) as server: + port = server.doq_address[1] + asyncio.run(amain(address, port)) try: import trio - @pytest.mark.skipif(not _nanoquic_available, reason="requires aioquic") + @pytest.mark.skipif(not have_quic, reason="requires aioquic") def test_basic_trio(): dns.asyncbackend.set_default_backend("trio") for address in addresses: - with Server(address) as server: - trio.run(amain, address, server.port) + with Server(address=address) as server: + port = server.doq_address[1] + trio.run(amain, address, port) except ImportError: pass diff --git a/tests/tls/ca.crt b/tests/tls/ca.crt index 81c768253..96a1297dc 100644 --- a/tests/tls/ca.crt +++ b/tests/tls/ca.crt @@ -1,20 +1,20 @@ -----BEGIN CERTIFICATE----- -MIIDTDCCAjSgAwIBAgIUUCWxpsMnzETqwNKJ38le9z7oFEEwDQYJKoZIhvcNAQEL -BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMDcwOTIyMjQw -N1oXDTMyMDcwNjIyMjQzN1owHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3Jn -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0AMlXDsx/7Kis4lUhAML -yaL4wtvhPGnqz20Gnhd/b2uAjZbtLtKDG2aRC0QtHL6N0vfBhj+KUV/unT60Mf7G -Pm2Z8fOxiwh/UJ8oxoJe59izklrwM0PL2iR21OMCCsiYcjiOOx75RUZ/6KEGMTgd -3wvqwEV320yd3WInkdO72n9jlQTN3VtwLwkIkSbINiuUCKgP9hy28K7HjMHvEIlf -QZfh9wIHhbqs/JP3dirRL7MKWFAv3MlmMffb/6NBBFb6FaRjS6WjojD8qaSTr14/ -tyqrK7zL32npKm/TbzxC8hFwYdwd3HURgpWInA6CRIcyZM/k4y7dHQlI4ID7hmcC -1QIDAQABo4GDMIGAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0G -A1UdDgQWBBQrNPKeL6rBhPV+Eb1RnvIkeax5sDAfBgNVHSMEGDAWgBQrNPKeL6rB -hPV+Eb1RnvIkeax5sDAdBgNVHREEFjAUghJxdWljLmRuc3B5dGhvbi5vcmcwDQYJ -KoZIhvcNAQELBQADggEBAADpAtDvceOrhn5FReYip9DlTW7KKrRDDFCo0SNdhvN3 -6mU/Hn3jNXYu9Ym3NDVL8q9UWzLRcSNLUo1qjkK3aOlgwcO6PuGKXukF7Zdd8wVb -pPdUqooBmj6akqmNvmloZyDmQ+aXcYhR83hcEHFOK+C7pGLqSFChN1mgDT1/mgBk -pODOZkcLtZI8YJyQ2sn3WhUJS52D6xfmPigliUcYqi6i+w1vxD45QilWbvqCwnN/ -6qmb3JQsMf+3MCtogVcSZjE9cf4CwlmKqgMxsBKz+/Qk9YPMpDuecEbd76L+Htdl -HWuDlemBzyhd5qO5y/UGarqmuh3MgkOdFVQWAUygcCM= +MIIDWDCCAkCgAwIBAgIUMNvX24hfDzdebzBu5Jfp8N5Y/W4wDQYJKoZIhvcNAQEL +BQAwITEfMB0GA1UEAxMWdGVzdC5wa2kuZG5zcHl0aG9uLm9yZzAeFw0yNDAzMjYy +MDU0NTZaFw0zNDAzMjQyMDU1MjZaMCExHzAdBgNVBAMTFnRlc3QucGtpLmRuc3B5 +dGhvbi5vcmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC3s6+DW75g +Z15vN9zMaWdIzMSJF09yZSzCy4btkqs/WK1TGZkUYmCxXCrQFuuin+TaXiv9TpOO +4FVrNzLpf5GIgqBwW5T4gf32CfEkeNoxJK+7bctP9i8hfsrOO5CPLogtSWGTpxE+ +S2B3L2riIMF7Fp77/wP29+t88PrFn/P7/NwNbrqwzJ7vmsDcCvSnRrzOpKq2T4B/ +CoV8SeUrj4PGOjQivh2Lda1dx4J3Nlg6aXEyqu+80QQBH0ya8ezqpVuHtaoYXiuY +U6efngSrBQsjuVyQumaZKc5G0GxR0WlNKrFMVSAFSgcZvEFm92MX/HX6Yh9EFOeA +rq4aRvKEB2BZAgMBAAGjgYcwgYQwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQF +MAMBAf8wHQYDVR0OBBYEFKPPasZqlmsMgS9hXhu/w2kOWtUMMB8GA1UdIwQYMBaA +FKPPasZqlmsMgS9hXhu/w2kOWtUMMCEGA1UdEQQaMBiCFnRlc3QucGtpLmRuc3B5 +dGhvbi5vcmcwDQYJKoZIhvcNAQELBQADggEBAAHNKgXtZP82TDbwc0qA3iJVxOcf +eAV8/S+o+ku4/f5dk3+kJJo4sfUPOG0M+gJXt21J7597bkVdTYubqgaZ42Z30tqS +mjxuM+KE4pg13CvVPsH/bvHatbKDGpGcICS+isHDe+0w+eHRp+AuVyl7/KDnBoTy +qsFf0kjR/qtWZ0qHrAWP/3pdgIo3G+jUJUiDxXj0N57HfPgYDswh2hY9rrtuy0m7 +m6v5W5aWH0ebp2o/FR+j6Z4vM8ibmqIevBd9VbhnDE2VOzTDR6r43q1OuWBRae+4 +j0BgirT00eD4QckjHYVCMNevAS6EKM4yA1C413YbNd18iSaDvyLs3B/fLRk= -----END CERTIFICATE----- diff --git a/tests/tls/private.pem b/tests/tls/private.pem index 06a01fadb..171b1b326 100644 --- a/tests/tls/private.pem +++ b/tests/tls/private.pem @@ -1,3 +1,27 @@ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEIL2OxuOo+awfhPvvm82EBZ4VA6ULQHlebxGCamZ/H5Rt ------END PRIVATE KEY----- +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA3erHsqChOB2LIP+I8MwreTxXIU6Vmqsbs0bXT7sETkD2lAqq +TZoKNK4ZiGi7tMmXE3bfE0PaPacSRj3+ocsYt0r/OqtnzVKE325cryZIPjGLNn0R +iWFrCLTa4HU3B88XuvJGOgdGGf82RlhqbI11Tjtl1DGYyhGHghGUnWbS0tVtVyoU +QCQDvxNIRkeKC8ygQxwTdeFFXPBOv9YnzTOhaZzz1zGLAtF6Cr5f52/PKFUWsOps +Xa0e/TqOjq1avwqNZ9ud1hTAdzAiJStDCxgPYWp3c8RN7me13zGj8Y1xpvtsbns6 +59+WAbKnB8tz2sNB60nsdChUKP1eDwQbtVnoJwIDAQABAoIBABLKJ0Bzo1LqBXa0 +zDZ/QwsP1dzHF5mx9TV3wRFKJ3Isw/QC9yp86XJOb3ECVSpYi0cloHu0Gg1wUPbL +lvzCAoS6f+PK0Y4r934njQdzeVhyv3PMoSh9rB9fmMy6f/9URJEapGGTbhcTadgI +8nHghFcsZlHSJTquw0d7D5dINn+NVEdNUEHG9CyhBu3+miwRqPSiRoJBEo9Tloh1 +2i6lvuqdszntZsdjhEt6Iyw6AS7SmmYeaxxg7hAqByFHeEDfK5xsl1czwe4czDfH +RTYX1xsDOOloTHUlqN6PORA8QC3BgHLCVrdNFrphg4/4I9fWsFh/DzABBy+8Kiwp +IHxTHsECgYEA6j7TnkdSZDhdPdldC8by2KihwG9lOyZkYbQdhBchcd4BabRBLuQ7 +Fmk3RhxZgH2ijKqbadXvKLk1nsG+SJk77cbKDHUBxUmEqT310m1W+YhkTxqiXW/S +TJKL3+M3oHBeIs1QHC2DIF3xt13Jc1BrrZSfmzctzQHYoc8WnGOBCaECgYEA8obZ +OeaWBQqN1PRjZ2bEaxcfVzL6OeXN82j8wBtitd9EWk/v6XwuO3K6O96vo6eGJcmO +2W8O6FQpgfyEVl1hK56pCcDkDmwl/XeKd1XIOh2/zltiKjAmUZ9Q8SeCmDysOcsd +Tc09OaSqSWte+L6iTwk7p8iSTSw/izGcu3VL7McCgYByHlmKasTA/pSuZQ7nhe0Z +kE39KkfvIS0WTGF00LACgV2+2YpIBfijWm8LQRR5fLuMPDGqxgbVmCV/SnQhekWv ++YDFwNsz+jUfHoh8E7ijqMb1oswnKSsTEvICCPg4uYWi/tNgZuvTAPGZm59hBnTv +A9EeFSvDDHs1mWYymmdrAQKBgAapPX4hny00RQD8VV6Zq/tk/y9d7xF4BlgRIiAE +oIluQGpal7RJ/NsVI5hRXXGZQE35YzsFmds3tIwla10T439XND1YVusufTyg8+Sj +LoSqHIKGcAPInsTPI2H8O9ICmJhdw8hHQs86fpLVqB4c3khdcI4DLEGCXZxtGGjt +p9AxAoGBANduq+U2AMvN6iIqw/j/TC2YHoEdVIMPEE9L6nZBFPVPGseVp5kGCCx3 +v4aErmFOZI/2yXuZ0BVEne0Fjm5TdRN7rEqDGx9DHpsupTnfArhQPeDRw+lRDyNc +JtkAZKoJCfAsx5DK5fUrpW/2g2h9xYLVjihbMUKvhNQ+IBwqINH3 +-----END RSA PRIVATE KEY----- diff --git a/tests/tls/public.crt b/tests/tls/public.crt index 96129a1b1..7587ce577 100644 --- a/tests/tls/public.crt +++ b/tests/tls/public.crt @@ -1,35 +1,21 @@ -----BEGIN CERTIFICATE----- -MIICZjCCAU6gAwIBAgIUBTlEzhtkXYQvZl5CYRNBxOG4GpEwDQYJKoZIhvcNAQEL -BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMTAwOTE2MjYw -OFoXDTMwMTIyNjE2MjYzOFowFDESMBAGA1UEAxMJbG9jYWxob3N0MCowBQYDK2Vw -AyEAKpQbO2JXhCGnQs2MrWmGBK5LcmJMWPXCzM2PfWbo1TyjgaAwgZ0wDgYDVR0P -AQH/BAQDAgOoMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4E -FgQUM2pZy8pH78CvP+FnuF190KEJkjUwHwYDVR0jBBgwFoAUKzTyni+qwYT1fhG9 -UZ7yJHmsebAwLAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAA -AAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQA0JlNLrLz3ajCzSVfOQsUdd3a3wR7Q -Dr28mYoDHSY9mhnJ9IQeInmGvPMLA4dgiRPFqxWsKh+lxzZObkbMjf1IAIVykfh6 -LynePm58/lnRrhdvf8vFfccuTyeb2aD0ZBA/RyhZam79J6JjRRovkSj9TyIqKfif -6T6QWXOXwAF89rH8YHAKnRSl32pqZuDhOnM0Ien+Sa6KpCvgIDogHQxIVbe1egZl -2Ec0LVQUaXhoICd1c6xoRoAa5UzDFJ7ujeu1XNGWKIiXESlcIo7SZjzusL2p5vv/ -frM+r43khtZ4s+F70A+B3AndcVSeKTQ5KlftN9CBuiQoYzhY29NmL93X ------END CERTIFICATE----- ------BEGIN CERTIFICATE----- -MIIDTDCCAjSgAwIBAgIUUCWxpsMnzETqwNKJ38le9z7oFEEwDQYJKoZIhvcNAQEL -BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMDcwOTIyMjQw -N1oXDTMyMDcwNjIyMjQzN1owHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3Jn -MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0AMlXDsx/7Kis4lUhAML -yaL4wtvhPGnqz20Gnhd/b2uAjZbtLtKDG2aRC0QtHL6N0vfBhj+KUV/unT60Mf7G -Pm2Z8fOxiwh/UJ8oxoJe59izklrwM0PL2iR21OMCCsiYcjiOOx75RUZ/6KEGMTgd -3wvqwEV320yd3WInkdO72n9jlQTN3VtwLwkIkSbINiuUCKgP9hy28K7HjMHvEIlf -QZfh9wIHhbqs/JP3dirRL7MKWFAv3MlmMffb/6NBBFb6FaRjS6WjojD8qaSTr14/ -tyqrK7zL32npKm/TbzxC8hFwYdwd3HURgpWInA6CRIcyZM/k4y7dHQlI4ID7hmcC -1QIDAQABo4GDMIGAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0G -A1UdDgQWBBQrNPKeL6rBhPV+Eb1RnvIkeax5sDAfBgNVHSMEGDAWgBQrNPKeL6rB -hPV+Eb1RnvIkeax5sDAdBgNVHREEFjAUghJxdWljLmRuc3B5dGhvbi5vcmcwDQYJ -KoZIhvcNAQELBQADggEBAADpAtDvceOrhn5FReYip9DlTW7KKrRDDFCo0SNdhvN3 -6mU/Hn3jNXYu9Ym3NDVL8q9UWzLRcSNLUo1qjkK3aOlgwcO6PuGKXukF7Zdd8wVb -pPdUqooBmj6akqmNvmloZyDmQ+aXcYhR83hcEHFOK+C7pGLqSFChN1mgDT1/mgBk -pODOZkcLtZI8YJyQ2sn3WhUJS52D6xfmPigliUcYqi6i+w1vxD45QilWbvqCwnN/ -6qmb3JQsMf+3MCtogVcSZjE9cf4CwlmKqgMxsBKz+/Qk9YPMpDuecEbd76L+Htdl -HWuDlemBzyhd5qO5y/UGarqmuh3MgkOdFVQWAUygcCM= +MIIDZDCCAkygAwIBAgIUE4MNnbLX3YqTfCPwoka5eoDnrd4wDQYJKoZIhvcNAQEL +BQAwITEfMB0GA1UEAxMWdGVzdC5wa2kuZG5zcHl0aG9uLm9yZzAeFw0yNDAzMjYy +MDU2MjBaFw0yNDA0MjcyMDU2NTBaMBQxEjAQBgNVBAMTCWxvY2FsaG9zdDCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAN3qx7KgoTgdiyD/iPDMK3k8VyFO +lZqrG7NG10+7BE5A9pQKqk2aCjSuGYhou7TJlxN23xND2j2nEkY9/qHLGLdK/zqr +Z81ShN9uXK8mSD4xizZ9EYlhawi02uB1NwfPF7ryRjoHRhn/NkZYamyNdU47ZdQx +mMoRh4IRlJ1m0tLVbVcqFEAkA78TSEZHigvMoEMcE3XhRVzwTr/WJ80zoWmc89cx +iwLRegq+X+dvzyhVFrDqbF2tHv06jo6tWr8KjWfbndYUwHcwIiUrQwsYD2Fqd3PE +Te5ntd8xo/GNcab7bG57OufflgGypwfLc9rDQetJ7HQoVCj9Xg8EG7VZ6CcCAwEA +AaOBoDCBnTAOBgNVHQ8BAf8EBAMCA6gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsG +AQUFBwMCMB0GA1UdDgQWBBQc2JmwrD0UY3gaKENUGLbh0CNxLjAfBgNVHSMEGDAW +gBSjz2rGapZrDIEvYV4bv8NpDlrVDDAsBgNVHREEJTAjgglsb2NhbGhvc3SHBH8A +AAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZIhvcNAQELBQADggEBAH3CiiXBzmFe +nEoj9JiyE+yStaVdyA0wG0jsHDu5yCbwMVqSdNbRTGeWCoQ5j0zmf+cIci5uSoRJ +U7SaNnHzx8yk24k7RKi12iUt2sNL101dLy1Fk6F5kF3DKXo57W31I4jE0v9CSDfg +CcbEPl1KFFJTJIEC0C2H+XuHbGkaOp0LxMdRTpnlH06abusU39OsMDs2gixjw1Xw +z+PWbRqkbXbhBLznAgb3MfhTSrKvS3bUQLPCe5RGCAlwH8QHZkKMxKnFmZaiQwyI +uZIYXHXUbWaT031cD+hwVF76rJ9GcDXC63k9rmWTZNUurftfHytFf4yxNhV6/1ra +a7rsc/ziASQ= -----END CERTIFICATE-----