Skip to content

Commit 3813945

Browse files
authored
Websockets handler and functional test (#8)
* Rename `WebSocketMessageHandler` to `WebSocketMessagesHandler` * WebsocketsHandler and functional tests
1 parent 32646f7 commit 3813945

15 files changed

+393
-31
lines changed

src/entrypoint.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import asyncio
2+
import signal
3+
4+
import websockets
5+
6+
from app import conf
7+
from handlers.websockets_handler import WebSocketsHandler
8+
from storage.subscription_storage import SubscriptionStorage
9+
10+
11+
def create_stop_signal() -> asyncio.Future[None]:
12+
loop = asyncio.get_running_loop()
13+
stop_signal = loop.create_future()
14+
loop.add_signal_handler(signal.SIGTERM, stop_signal.set_result, None)
15+
return stop_signal
16+
17+
18+
async def app_runner(settings: conf.Settings, websockets_handler: WebSocketsHandler) -> None:
19+
async with websockets.serve(
20+
ws_handler=websockets_handler.websockets_handler,
21+
host=settings.WEBSOCKETS_HOST,
22+
port=settings.WEBSOCKETS_PORT,
23+
):
24+
await asyncio.Future()
25+
26+
27+
async def main() -> None:
28+
settings = conf.get_app_settings()
29+
storage = SubscriptionStorage()
30+
websockets_handler = WebSocketsHandler(storage=storage)
31+
32+
await app_runner(
33+
settings=settings,
34+
websockets_handler=websockets_handler,
35+
)
36+
37+
38+
if __name__ == "__main__":
39+
asyncio.run(main())

src/handlers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from handlers.message_handler import WebSocketMessageHandler
1+
from handlers.websockets_handler import WebSocketsHandler
22

33
__all__ = [
4-
"WebSocketMessageHandler",
4+
"WebSocketsHandler",
55
]

src/handlers/dto.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Literal
22

3-
from pydantic import Field
3+
from pydantic_core import ErrorDetails
4+
from pydantic import SecretStr
45

56
from pydantic import BaseModel
67
from app.types import Event
@@ -10,13 +11,13 @@
1011

1112

1213
class AuthMessageParams(BaseModel):
13-
token: str
14+
token: SecretStr
1415

1516

1617
class AuthMessage(BaseModel):
1718
message_id: messageId
1819
message_type: Literal["Authenticate"]
19-
params: AuthMessageParams = Field(exclude=True)
20+
params: AuthMessageParams
2021

2122

2223
class SubscribeParams(BaseModel):
@@ -45,5 +46,5 @@ class SuccessResponseMessage(BaseModel):
4546

4647
class ErrorResponseMessage(BaseModel):
4748
message_type: Literal["ErrorResponse"] = "ErrorResponse"
48-
error_detail: str
49+
errors: list[ErrorDetails | str]
4950
incoming_message: IncomingMessage | None # may be null if incoming message was not valid

src/handlers/exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
class WebsocketMessageException(Exception):
66
"""Raise if error occurred during message handling."""
77

8-
def __init__(self, error_detail: str, incoming_message: IncomingMessage | None = None) -> None:
9-
self.error_detail = error_detail
8+
def __init__(self, error_detail: str, incoming_message: IncomingMessage) -> None:
9+
self.errors = [error_detail]
1010
self.incoming_message = incoming_message
1111

1212
def as_error_message(self) -> ErrorResponseMessage:
13-
return ErrorResponseMessage.model_construct(error_detail=self.error_detail, incoming_message=self.incoming_message)
13+
return ErrorResponseMessage.model_construct(errors=self.errors, incoming_message=self.incoming_message)

src/handlers/message_handler.py renamed to src/handlers/messages_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
@dataclass
25-
class WebSocketMessageHandler:
25+
class WebSocketMessagesHandler:
2626
storage: SubscriptionStorage
2727

2828
def __post_init__(self) -> None:
@@ -40,7 +40,7 @@ async def handle_message(self, websocket: WebSocketServerProtocol, message: Inco
4040

4141
async def handle_auth_message(self, websocket: WebSocketServerProtocol, message: AuthMessage) -> SuccessResponseMessage:
4242
try:
43-
validated_token = await self.jwk_client.decode(message.params.token)
43+
validated_token = await self.jwk_client.decode(message.params.token.get_secret_value())
4444
StorageWebSocketRegister(storage=self.storage, websocket=websocket, validated_token=validated_token)()
4545
except (AsyncJWKClientException, StorageOperationException) as exc:
4646
raise WebsocketMessageException(str(exc), message)

src/handlers/tests/message_handler/conftest.py renamed to src/handlers/tests/messages_handler/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from handlers import WebSocketMessageHandler
3+
from handlers.messages_handler import WebSocketMessagesHandler
44
from handlers.dto import AuthMessage, SubscribeMessage, UnsubscribeMessage
55

66

@@ -12,13 +12,13 @@ def settings(settings):
1212

1313

1414
@pytest.fixture
15-
def force_token_on_validation(mocker, valid_token):
15+
def force_token_validation(mocker, valid_token):
1616
return mocker.patch("a12n.jwk_client.AsyncJWKClient.decode", return_value=valid_token)
1717

1818

1919
@pytest.fixture
2020
def message_handler(storage):
21-
return WebSocketMessageHandler(storage=storage)
21+
return WebSocketMessagesHandler(storage=storage)
2222

2323

2424
@pytest.fixture

src/handlers/tests/message_handler/tests_auth_message_handler.py renamed to src/handlers/tests/messages_handler/tests_auth_message_handler.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from app.types import DecodedValidToken
55
from handlers.dto import SuccessResponseMessage
66
from handlers.exceptions import WebsocketMessageException
7-
from handlers import WebSocketMessageHandler
7+
from handlers.messages_handler import WebSocketMessagesHandler
88
from storage.storage_updaters import StorageWebSocketRegister
99

1010
pytestmark = [
11-
pytest.mark.usefixtures("force_token_on_validation"),
11+
pytest.mark.usefixtures("force_token_validation"),
1212
]
1313

1414

@@ -18,7 +18,7 @@ def ya_user_decoded_valid_token():
1818

1919

2020
@pytest.fixture
21-
def auth_handler(message_handler: WebSocketMessageHandler, ws):
21+
def auth_handler(message_handler: WebSocketMessagesHandler, ws):
2222
return lambda auth_message: message_handler.handle_auth_message(ws, auth_message)
2323

2424

@@ -50,18 +50,18 @@ async def test_auth_handler_raise_if_user_send_token_for_different_user(auth_han
5050
await auth_handler(auth_message) # send valid user1 token while connection registered with ya_user
5151

5252
raised_exception = exc_info.value
53-
assert raised_exception.error_detail == "The user has different public id"
53+
assert raised_exception.errors == ["The user has different public id"]
5454
assert raised_exception.incoming_message == auth_message
5555
assert storage.is_websocket_registered(ws) is True, "The existed connection should not be touched"
5656

5757

58-
async def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(auth_handler, ws, auth_message, force_token_on_validation, storage):
59-
force_token_on_validation.side_effect = AsyncJWKClientException("The token is expired")
58+
async def test_auth_handler_raise_if_user_try_to_auth_with_expired_token(auth_handler, ws, auth_message, force_token_validation, storage):
59+
force_token_validation.side_effect = AsyncJWKClientException("The token is expired")
6060

6161
with pytest.raises(WebsocketMessageException) as exc_info:
6262
await auth_handler(auth_message)
6363

6464
raised_exception = exc_info.value
65-
assert raised_exception.error_detail == "The token is expired"
65+
assert raised_exception.errors == ["The token is expired"]
6666
assert raised_exception.incoming_message == auth_message
6767
assert storage.is_websocket_registered(ws) is False, "The ws should not be added to registered websockets"

src/handlers/tests/message_handler/tests_message_handler_common.py renamed to src/handlers/tests/messages_handler/tests_message_handler_common.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,36 @@
11
import pytest
2-
from handlers.message_handler import WebSocketMessageHandler
2+
from handlers.messages_handler import WebSocketMessagesHandler
33

44

55
@pytest.fixture
66
def get_message_handler(storage):
7-
return lambda: WebSocketMessageHandler(storage)
7+
return lambda: WebSocketMessagesHandler(storage)
88

99

1010
def test_message_handler_jwk_client_settings(message_handler):
1111
assert message_handler.jwk_client.jwks_url == "https://auth.clowns.com/auth/realms/clowns-realm/protocol/openid-connect/certs"
1212
assert message_handler.jwk_client.supported_signing_algorithms == ["RS256"]
1313

1414

15-
@pytest.mark.usefixtures("force_token_on_validation")
15+
@pytest.mark.usefixtures("force_token_validation")
1616
async def test_message_handler_call_auth_handler_on_auth_message(get_message_handler, auth_message, mocker, ws):
17-
spy_auth_handler = mocker.spy(WebSocketMessageHandler, "handle_auth_message")
17+
spy_auth_handler = mocker.spy(WebSocketMessagesHandler, "handle_auth_message")
1818

1919
await get_message_handler().handle_message(ws, auth_message)
2020

2121
spy_auth_handler.assert_awaited_once()
2222

2323

2424
async def test_message_handler_call_subscribe_handler_on_subscribe_message(get_message_handler, subscribe_message, mocker, ws_registered):
25-
spy_subscribe_handler = mocker.spy(WebSocketMessageHandler, "handle_subscribe_message")
25+
spy_subscribe_handler = mocker.spy(WebSocketMessagesHandler, "handle_subscribe_message")
2626

2727
await get_message_handler().handle_message(ws_registered, subscribe_message)
2828

2929
spy_subscribe_handler.assert_awaited_once()
3030

3131

3232
async def test_message_handler_call_unsubscribe_handler_on_unsubscribe_message(get_message_handler, unsubscribe_message, mocker, ws_subscribed):
33-
spy_unsubscribe_handler = mocker.spy(WebSocketMessageHandler, "handle_unsubscribe_message")
33+
spy_unsubscribe_handler = mocker.spy(WebSocketMessagesHandler, "handle_unsubscribe_message")
3434

3535
await get_message_handler().handle_message(ws_subscribed, unsubscribe_message)
3636

src/handlers/tests/message_handler/tests_subscirbe_message_handler.py renamed to src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import pytest
2-
from handlers.message_handler import WebSocketMessageHandler
2+
from handlers.messages_handler import WebSocketMessagesHandler
33
from storage.storage_updaters import StorageUserSubscriber
44

55

66
@pytest.fixture
7-
def subscribe_handler(message_handler: WebSocketMessageHandler, ws_registered):
7+
def subscribe_handler(message_handler: WebSocketMessagesHandler, ws_registered):
88
return lambda subscribe_message: message_handler.handle_subscribe_message(ws_registered, subscribe_message)
99

1010

src/handlers/tests/message_handler/tests_unsubscirbe_message_handler.py renamed to src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import pytest
2-
from handlers.message_handler import WebSocketMessageHandler
2+
from handlers.messages_handler import WebSocketMessagesHandler
33
from storage.storage_updaters import StorageUserUnsubscriber
44

55

66
@pytest.fixture
7-
def unsubscribe_handler(message_handler: WebSocketMessageHandler, ws_subscribed):
7+
def unsubscribe_handler(message_handler: WebSocketMessagesHandler, ws_subscribed):
88
return lambda unsubscribe_message: message_handler.handle_unsubscribe_message(ws_subscribed, unsubscribe_message)
99

1010

src/handlers/websockets_handler.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from dataclasses import dataclass
2+
import logging
3+
from typing import Annotated
4+
5+
import pydantic
6+
from pydantic import Field
7+
from pydantic import TypeAdapter
8+
from websockets import WebSocketServerProtocol
9+
from websockets.exceptions import ConnectionClosedError
10+
11+
from app import conf
12+
from handlers.dto import AuthMessage
13+
from handlers.dto import ErrorResponseMessage
14+
from handlers.dto import IncomingMessage
15+
from handlers.dto import SuccessResponseMessage
16+
from handlers.exceptions import WebsocketMessageException
17+
from storage.storage_updaters import StorageWebSocketRemover
18+
from storage.subscription_storage import SubscriptionStorage
19+
from handlers.messages_handler import WebSocketMessagesHandler
20+
21+
logger = logging.getLogger(__name__)
22+
23+
IncomingMessageAdapter = TypeAdapter(Annotated[IncomingMessage, Field(discriminator="message_type")])
24+
AuthMessageAdapter = TypeAdapter(Annotated[AuthMessage, Field(discriminator="message_type")])
25+
26+
27+
@dataclass
28+
class WebSocketsHandler:
29+
storage: SubscriptionStorage
30+
31+
def __post_init__(self) -> None:
32+
settings = conf.get_app_settings()
33+
self.websockets_path = settings.WEBSOCKETS_PATH
34+
35+
self.messages_handler = WebSocketMessagesHandler(storage=self.storage)
36+
37+
async def websockets_handler(self, websocket: WebSocketServerProtocol) -> None:
38+
if websocket.path != self.websockets_path:
39+
return
40+
41+
try:
42+
async for message in websocket:
43+
response_message = await self.process_message(websocket=websocket, raw_message=message)
44+
await websocket.send(response_message.model_dump_json(exclude_none=True))
45+
except ConnectionClosedError:
46+
logger.warning("Trying to send message to closed connection. Connection id: '%s'", websocket.id)
47+
finally:
48+
StorageWebSocketRemover(storage=self.storage, websocket=websocket)()
49+
50+
async def process_message(self, websocket: WebSocketServerProtocol, raw_message: str | bytes) -> SuccessResponseMessage | ErrorResponseMessage:
51+
try:
52+
message = self.parse_raw_message(websocket, raw_message)
53+
except pydantic.ValidationError as exc:
54+
return ErrorResponseMessage.model_construct(errors=exc.errors(include_url=False, include_context=False), incoming_message=None)
55+
56+
try:
57+
success_response = await self.messages_handler.handle_message(websocket, message)
58+
except WebsocketMessageException as exc:
59+
return exc.as_error_message()
60+
61+
return success_response
62+
63+
def parse_raw_message(self, websocket: WebSocketServerProtocol, raw_message: str | bytes) -> IncomingMessage:
64+
adapter = self.get_message_adapter(websocket)
65+
return adapter.validate_json(raw_message)
66+
67+
def get_message_adapter(self, websocket: WebSocketServerProtocol) -> TypeAdapter:
68+
"""Only registered websockets can send all messages. Unregistered websockets can only send Auth messages."""
69+
if self.storage.is_websocket_registered(websocket):
70+
return IncomingMessageAdapter
71+
72+
return AuthMessageAdapter

0 commit comments

Comments
 (0)