From 3e8efe73dcb3abfdfcd4a1ac2d10d86c1222c4a3 Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Sun, 6 Jun 2021 13:05:28 +0900 Subject: [PATCH] anyio port WIP --- .travis.yml | 4 ++-- Makefile | 2 +- requirements-dev.txt | 11 ++++++--- setup.py | 7 +++++- trio_websocket/_impl.py | 53 ++++++++++++++++++++++------------------- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5095bc3..bf29769 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,12 +13,12 @@ jobs: - python: pypy3 - name: "latest deps" python: 3.9 - env: UPGRADE="pip install --upgrade trio wsproto" + env: UPGRADE="pip install --upgrade anyio trio wsproto" install: - pip install -r requirements-dev.txt - $UPGRADE - - pip install -e . + - pip install -e .[trio] script: - make test diff --git a/Makefile b/Makefile index b496fff..458bc32 100644 --- a/Makefile +++ b/Makefile @@ -32,5 +32,5 @@ publish: # make -W requirements-dev.{in,txt} PIP_COMPILE_ARGS="-P foo" ifneq ($(PIP_COMPILE_ARGS),) requirements-dev.txt: setup.py requirements-dev.in - pip-compile -q $(PIP_COMPILE_ARGS) --output-file $@ $^ + pip-compile -q $(PIP_COMPILE_ARGS) --extra trio --output-file $@ $^ endif diff --git a/requirements-dev.txt b/requirements-dev.txt index 1114142..4e2009d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,10 +2,12 @@ # This file is autogenerated by pip-compile # To update, run: # -# pip-compile --output-file=requirements-dev.txt requirements-dev.in setup.py +# pip-compile --extra=trio --output-file=requirements-dev.txt requirements-dev.in setup.py # alabaster==0.7.12 # via sphinx +anyio[trio]==3.1.0 + # via trio-websocket (setup.py) astroid==2.4.2 # via pylint async-generator==1.10 @@ -51,6 +53,7 @@ h11==0.11.0 # via wsproto idna==2.10 # via + # anyio # requests # trio # trustme @@ -130,7 +133,9 @@ six==1.15.0 # packaging # readme-renderer sniffio==1.2.0 - # via trio + # via + # anyio + # trio snowballstemmer==2.0.0 # via sphinx sortedcontainers==2.3.0 @@ -166,8 +171,8 @@ tqdm==4.51.0 trio==0.17.0 # via # -r requirements-dev.in + # anyio # pytest-trio - # trio-websocket (setup.py) trustme==0.6.0 # via -r requirements-dev.in twine==3.2.0 diff --git a/setup.py b/setup.py index ab84d41..0e8d01e 100644 --- a/setup.py +++ b/setup.py @@ -39,10 +39,15 @@ keywords='websocket client server trio', packages=find_packages(exclude=['docs', 'examples', 'tests']), install_requires=[ + 'anyio ~= 3.0', 'async_generator>=1.10', - 'trio>=0.11', 'wsproto>=0.14', ], + extras_require={ + 'trio': [ + 'anyio[trio] ~= 3.0', + ] + }, project_urls={ 'Bug Reports': 'https://github.com/HyperionGray/trio-websocket/issues', 'Source': 'https://github.com/HyperionGray/trio-websocket', diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 440d0e2..88658b8 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -9,6 +9,7 @@ import struct import urllib.parse +import anyio from async_generator import asynccontextmanager import trio import trio.abc @@ -101,15 +102,15 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. ''' - async with trio.open_nursery() as new_nursery: + async with anyio.create_task_group() as new_nursery: try: - with trio.fail_after(connect_timeout): + with anyio.fail_after(connect_timeout): connection = await connect_websocket(new_nursery, host, port, resource, use_ssl=use_ssl, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, max_message_size=max_message_size) - except trio.TooSlowError: + except TimeoutError: raise ConnectionTimeout from None except OSError as e: raise HandshakeError from e @@ -117,9 +118,9 @@ async def open_websocket(host, port, resource, *, use_ssl, subprotocols=None, yield connection finally: try: - with trio.fail_after(disconnect_timeout): + with anyio.fail_after(disconnect_timeout): await connection.aclose() - except trio.TooSlowError: + except TimeoutError: raise DisconnectionTimeout from None @@ -368,7 +369,7 @@ async def wrap_server_stream(nursery, stream, async def serve_websocket(handler, host, port, ssl_context, *, handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT, - disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED): + disconnect_timeout=CONN_TIMEOUT, task_status=anyio.TASK_STATUS_IGNORED): ''' Serve a WebSocket over TCP. @@ -524,7 +525,7 @@ class Future: def __init__(self): ''' Constructor. ''' self._value = None - self._value_event = trio.Event() + self._value_event = anyio.Event() def set_value(self, value): ''' @@ -723,7 +724,7 @@ def __init__(self, stream, wsproto, *, host=None, path=None, self._reject_status = None self._reject_headers = None self._reject_body = b'' - self._send_channel, self._recv_channel = trio.open_memory_channel( + self._send_channel, self._recv_channel = anyio.create_memory_object_stream( message_queue_size) self._pings = OrderedDict() # Set when the server has received a connection request event. This @@ -731,13 +732,13 @@ def __init__(self, stream, wsproto, *, host=None, path=None, self._connection_proposal = Future() # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. - self._open_handshake = trio.Event() + self._open_handshake = anyio.Event() # Set once a WebSocket closed handshake takes place, i.e after a close # frame has been sent and a close frame has been received. - self._close_handshake = trio.Event() + self._close_handshake = anyio.Event() # Set immediately upon receiving closed event from peer. Used to # test close race conditions between client and server. - self._for_testing_peer_closed_connection = trio.Event() + self._for_testing_peer_closed_connection = anyio.Event() @property def closed(self): @@ -868,7 +869,7 @@ async def get_message(self): ''' try: message = await self._recv_channel.receive() - except (trio.ClosedResourceError, trio.EndOfChannel): + except (anyio.ClosedResourceError, anyio.EndOfStream): raise ConnectionClosed(self._close_reason) from None return message @@ -899,7 +900,7 @@ async def ping(self, payload=None): format(payload)) if payload is None: payload = struct.pack('!I', random.getrandbits(32)) - event = trio.Event() + event = anyio.Event() self._pings[payload] = event await self._send(Ping(payload=payload)) await event.wait() @@ -1003,7 +1004,7 @@ async def _close_stream(self): try: with _preserve_current_exception(): await self._stream.aclose() - except trio.BrokenResourceError: + except (trio.BrokenResourceError, anyio.BrokenResourceError): # This means the TCP connection is already dead. pass @@ -1088,7 +1089,7 @@ async def _handle_close_connection_event(self, event): :param wsproto.events.CloseConnection event: ''' self._for_testing_peer_closed_connection.set() - await trio.sleep(0) + await anyio.sleep(0) if self._wsproto.state == ConnectionState.REMOTE_CLOSING: await self._send(event.response()) await self._close_web_socket(event.code, event.reason or None) @@ -1125,7 +1126,7 @@ async def _handle_message_event(self, event): self._message_parts = [] try: await self._send_channel.send(msg) - except (trio.ClosedResourceError, trio.BrokenResourceError): + except (trio.ClosedResourceError, trio.BrokenResourceError, anyio.BrokenResourceError): # The receive channel is closed, probably because somebody # called ``aclose()``. We don't want to abort the reader task, # and there's no useful cleanup that we can do here. @@ -1212,7 +1213,7 @@ async def _reader_task(self): # Get network data. try: data = await self._stream.receive_some(RECEIVE_BYTES) - except (trio.BrokenResourceError, trio.ClosedResourceError): + except (trio.BrokenResourceError, anyio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() break if len(data) == 0: @@ -1250,7 +1251,7 @@ async def _send(self, event): logger.debug('%s sending %d bytes', self, len(data)) try: await self._stream.send_all(data) - except (trio.BrokenResourceError, trio.ClosedResourceError): + except (trio.BrokenResourceError, anyio.BrokenResourceError, trio.ClosedResourceError): await self._abort_web_socket() raise ConnectionClosed(self._close_reason) from None @@ -1377,7 +1378,7 @@ def listeners(self): listeners.append(repr(listener)) return listeners - async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): + async def run(self, *, task_status=anyio.TASK_STATUS_IGNORED): ''' Start serving incoming connections requests. @@ -1388,7 +1389,7 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): :param task_status: Part of the Trio nursery start protocol. :returns: This method never returns unless cancelled. ''' - async with trio.open_nursery() as nursery: + async with anyio.create_task_group() as nursery: serve_listeners = partial(trio.serve_listeners, self._handle_connection, self._listeners, handler_nursery=self._handler_nursery) @@ -1396,7 +1397,7 @@ async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): logger.debug('Listening on %s', ','.join([str(l) for l in self.listeners])) task_status.started(self) - await trio.sleep_forever() + await anyio.sleep_forever() async def _handle_connection(self, stream): ''' @@ -1406,22 +1407,24 @@ async def _handle_connection(self, stream): :param stream: :type stream: trio.abc.Stream ''' - async with trio.open_nursery() as nursery: + async with anyio.create_task_group() as nursery: wsproto = WSConnection(ConnectionType.SERVER) connection = WebSocketConnection(stream, wsproto, message_queue_size=self._message_queue_size, max_message_size=self._max_message_size) nursery.start_soon(connection._reader_task) - with trio.move_on_after(self._connect_timeout) as connect_scope: + have_request = False + with anyio.move_on_after(self._connect_timeout) as connect_scope: request = await connection._get_request() - if connect_scope.cancelled_caught: + have_request = True + if not have_request: nursery.cancel_scope.cancel() await stream.aclose() return try: await self._handler(request) finally: - with trio.move_on_after(self._disconnect_timeout): + with anyio.move_on_after(self._disconnect_timeout): # aclose() will shut down the reader task even if it's # cancelled: await connection.aclose()