Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(websocket): add support for reason in WebSocket.close() #2056

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion falcon/asgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .request import Request
from .response import Response
from .structures import SSEvent
from .ws import check_support_reason
from .ws import WebSocket
from .ws import WebSocketOptions

Expand Down Expand Up @@ -974,7 +975,11 @@ async def _handle_websocket(self, ver, scope, receive, send):
# we don't support, so bail out. This also fulfills the ASGI
# spec requirement to only process the request after
# receiving and verifying the first event.
await send({'type': EventType.WS_CLOSE, 'code': WSCloseCode.SERVER_ERROR})
response = {'type': EventType.WS_CLOSE, 'code': WSCloseCode.SERVER_ERROR}
if check_support_reason(ver):
response['reason'] = 'Internal Server Error'

await send(response)
return

req = self._request_type(scope, receive, options=self.req_options)
Expand All @@ -986,6 +991,7 @@ async def _handle_websocket(self, ver, scope, receive, send):
send,
self.ws_options.media_handlers,
self.ws_options.max_receive_queue,
self.ws_options.default_close_reasons,
)

on_websocket = None
Expand Down
48 changes: 41 additions & 7 deletions falcon/asgi/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class WebSocket:
'_asgi_send',
'_buffered_receiver',
'_close_code',
'_close_reasons',
'_supports_accept_headers',
'_supports_reason',
'_mh_bin_deserialize',
'_mh_bin_serialize',
'_mh_text_deserialize',
Expand All @@ -69,8 +71,10 @@ def __init__(
Union[media.BinaryBaseHandlerWS, media.TextBaseHandlerWS],
],
max_receive_queue: int,
default_close_reasons: Dict[Optional[int], str],
):
self._supports_accept_headers = ver != '2.0'
self._supports_reason = check_support_reason(ver)

# NOTE(kgriffs): Normalize the iterable to a stable tuple; note that
# ordering is significant, and so we preserve it here.
Expand All @@ -94,6 +98,7 @@ def __init__(
self._mh_bin_serialize = mh_bin.serialize
self._mh_bin_deserialize = mh_bin.deserialize

self._close_reasons = default_close_reasons
self._state = _WebSocketState.HANDSHAKE
self._close_code = None # type: Optional[int]

Expand Down Expand Up @@ -257,12 +262,15 @@ async def close(self, code: Optional[int] = None) -> None:
if self.closed:
return

await self._asgi_send(
{
'type': EventType.WS_CLOSE,
'code': code,
}
)
response = {'type': EventType.WS_CLOSE, 'code': code}

if self._supports_reason:
if code in self._close_reasons:
response['reason'] = self._close_reasons[code]
elif 3100 <= code <= 3999:
response['reason'] = falcon.util.code_to_http_status(code - 3000)

await self._asgi_send(response)

self._state = _WebSocketState.CLOSED
self._close_code = code
Expand Down Expand Up @@ -512,6 +520,10 @@ class WebSocketOptions:
unhandled error is raised while handling a WebSocket connection
(default ``1011``). For a list of valid close codes and ranges,
see also: https://tools.ietf.org/html/rfc6455#section-7.4
default_close_reasons (dict): A default mapping between the Websocket
close code and the reason why the connection is close. Close codes
corresponding to HTTPerrors are not included as they will be rendered
automatically using HTTP status.
media_handlers (dict): A dict-like object for configuring media handlers
according to the WebSocket payload type (TEXT vs. BINARY) of a
given message. See also: :ref:`ws_media_handlers`.
Expand All @@ -528,7 +540,12 @@ class WebSocketOptions:

"""

__slots__ = ['error_close_code', 'max_receive_queue', 'media_handlers']
__slots__ = [
'error_close_code',
'default_close_reasons',
'max_receive_queue',
'media_handlers',
]

def __init__(self):
try:
Expand Down Expand Up @@ -557,6 +574,12 @@ def __init__(self):
#
self.error_close_code: int = WSCloseCode.SERVER_ERROR

self.default_close_reasons: Dict[int, str] = {
1000: 'Normal Closure',
1011: 'Internal Server Error',
3011: 'Internal Server Error',
}

# NOTE(kgriffs): The websockets library itself will buffer, so we keep
# this value fairly small by default to mitigate buffer bloat. But in
# the case that we have a large spillover from the websocket server's
Expand Down Expand Up @@ -701,3 +724,14 @@ async def _pump(self):
if self._pop_message_waiter is not None:
self._pop_message_waiter.set_result(None)
self._pop_message_waiter = None


def check_support_reason(asgi_ver):
target_ver = [2, 3]
current_ver = asgi_ver.split('.')

for i in range(2):
if int(current_ver[i]) < target_ver[i]:
return False

return True
26 changes: 24 additions & 2 deletions falcon/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(self):
self._state = _WebSocketState.CONNECT
self._disconnect_emitted = False
self._close_code = None
self._close_reason = None
self._accepted_subprotocol = None
self._accepted_headers = None
self._collected_server_events = deque()
Expand All @@ -427,6 +428,10 @@ def closed(self) -> bool:
def close_code(self) -> int:
return self._close_code

@property
def close_reason(self) -> str:
return self._close_reason
CaselIT marked this conversation as resolved.
Show resolved Hide resolved

@property
def subprotocol(self) -> str:
return self._accepted_subprotocol
Expand Down Expand Up @@ -464,12 +469,14 @@ async def wait_ready(self, timeout: Optional[int] = 5):
# NOTE(kgriffs): This is a coroutine just in case we need it to be
# in a future code revision. It also makes it more consistent
# with the other methods.
async def close(self, code: Optional[int] = None):
async def close(self, code: Optional[int] = None, reason: Optional[str] = None):
"""Close the simulated connection.

Keyword Args:
code (int): The WebSocket close code to send to the application
per the WebSocket spec (default: ``1000``).
reason (str): The WebSocket close reason to send to the application
per the WebSocket spec (default: empty string).
"""

# NOTE(kgriffs): Give our collector a chance in case the
Expand All @@ -488,8 +495,12 @@ async def close(self, code: Optional[int] = None):
if code is None:
code = WSCloseCode.NORMAL

if reason is None:
reason = ''

self._state = _WebSocketState.CLOSED
self._close_code = code
self._close_reason = reason

async def send_text(self, payload: str):
"""Send a message to the app with a Unicode string payload.
Expand Down Expand Up @@ -727,6 +738,7 @@ async def _collect(self, event: Dict[str, Any]):
self._state = _WebSocketState.DENIED

desired_code = event.get('code', WSCloseCode.NORMAL)
reason = event.get('reason', '')
CaselIT marked this conversation as resolved.
Show resolved Hide resolved
if desired_code == WSCloseCode.SERVER_ERROR or (
3000 <= desired_code < 4000
):
Expand All @@ -735,12 +747,16 @@ async def _collect(self, event: Dict[str, Any]):
# different raised error types or to pass through a
# raised HTTPError status code.
self._close_code = desired_code
self._close_reason = reason
else:
# NOTE(kgriffs): Force the close code to this since it is
# similar to what happens with a real web server (the HTTP
# connection is closed with a 403 and there is no websocket
# close code).
self._close_code = WSCloseCode.FORBIDDEN
self._close_reason = falcon.util.code_to_http_status(
WSCloseCode.FORBIDDEN - 3000
)

self._event_handshake_complete.set()

Expand All @@ -755,6 +771,7 @@ async def _collect(self, event: Dict[str, Any]):
if event_type == EventType.WS_CLOSE:
self._state = _WebSocketState.CLOSED
self._close_code = event.get('code', WSCloseCode.NORMAL)
self._close_reason = event.get('reason', '')
CaselIT marked this conversation as resolved.
Show resolved Hide resolved
else:
assert event_type == EventType.WS_SEND
self._collected_server_events.append(event)
Expand All @@ -780,7 +797,12 @@ def _create_checked_disconnect(self) -> Dict[str, Any]:
)

self._disconnect_emitted = True
return {'type': EventType.WS_DISCONNECT, 'code': self._close_code}
response = {'type': EventType.WS_DISCONNECT, 'code': self._close_code}

if self._close_reason:
response['reason'] = self._close_reason

return response


# get_encoding_from_headers() is Copyright 2016 Kenneth Reitz, and is
Expand Down
81 changes: 78 additions & 3 deletions tests/asgi/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,11 +886,11 @@ async def test_bad_http_version(version, conductor):


@pytest.mark.asyncio
async def test_bad_first_event():
@pytest.mark.parametrize('version', ['2.1', '2.3', '2.10.3'])
async def test_bad_first_event(version):
app = App()

scope = testing.create_scope_ws()
del scope['asgi']['spec_version']
scope = testing.create_scope_ws(spec_version=version)

ws = testing.ASGIWebSocketSimulator()
wrapped_emit = ws._emit
Expand All @@ -910,6 +910,10 @@ async def _emit():

assert ws.closed
assert ws.close_code == CloseCode.SERVER_ERROR
if version != '2.1':
assert ws.close_reason == 'Internal Server Error'
else:
assert ws.close_reason == ''


@pytest.mark.asyncio
Expand Down Expand Up @@ -1092,3 +1096,74 @@ def test_msgpack_missing():

with pytest.raises(RuntimeError):
handler.deserialize(b'{}')


@pytest.mark.asyncio
@pytest.mark.parametrize('reason', ['Client closing connection', '', None])
async def test_client_close_with_reason(reason, conductor):
class Resource:
def __init__(self):
pass

async def on_websocket(self, req, ws):
await ws.accept()
while True:
try:
await ws.receive_data()

except falcon.WebSocketDisconnected:
break

resource = Resource()
conductor.app.add_route('/', resource)

async with conductor as c:
async with c.simulate_ws('/', spec_version='2.3') as ws:
await ws.close(4099, reason)

assert ws.close_code == 4099
if reason:
assert ws.close_reason == reason
else:
assert ws.close_reason == ''


@pytest.mark.asyncio
@pytest.mark.parametrize('no_default', [True, False])
@pytest.mark.parametrize('code', [None, 1011, 4099, 4042, 3405])
async def test_no_reason_mapping(no_default, code, conductor):
class Resource:
def __init__(self):
pass

async def on_websocket(self, req, ws):
await ws.accept()
await ws.close(code)

resource = Resource()
conductor.app.add_route('/', resource)
if no_default:
conductor.app.ws_options.default_close_reasons = {}
else:
conductor.app.ws_options.default_close_reasons[4099] = '4099 reason'

async with conductor as c:
with pytest.raises(falcon.WebSocketDisconnected):
async with c.simulate_ws('/', spec_version='2.10.3') as ws:
await ws.receive_data()

if code:
assert ws.close_code == code
else:
assert ws.close_code == CloseCode.NORMAL

if 3100 <= ws.close_code <= 3999:
assert ws.close_reason == falcon.util.code_to_http_status(ws.close_code - 3000)
elif (
no_default
or ws.close_code not in conductor.app.ws_options.default_close_reasons
):
assert ws.close_reason == ''
else:
reason = conductor.app.ws_options.default_close_reasons[ws.close_code]
assert ws.close_reason == reason