Skip to content

Commit 60997d4

Browse files
authored
Websockets access guardian (#9)
* WebSocketAccess guardian * WebSocketsAccessGuardian and test fot it
1 parent 3813945 commit 60997d4

File tree

6 files changed

+180
-6
lines changed

6 files changed

+180
-6
lines changed

src/entrypoint.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import websockets
55

66
from app import conf
7-
from handlers.websockets_handler import WebSocketsHandler
7+
from handlers import WebSocketsHandler
8+
from handlers import WebSocketsAccessGuardian
89
from storage.subscription_storage import SubscriptionStorage
910

1011

@@ -15,23 +16,27 @@ def create_stop_signal() -> asyncio.Future[None]:
1516
return stop_signal
1617

1718

18-
async def app_runner(settings: conf.Settings, websockets_handler: WebSocketsHandler) -> None:
19+
async def app_runner(settings: conf.Settings, websockets_handler: WebSocketsHandler, access_guardian: WebSocketsAccessGuardian) -> None:
1920
async with websockets.serve(
2021
ws_handler=websockets_handler.websockets_handler,
2122
host=settings.WEBSOCKETS_HOST,
2223
port=settings.WEBSOCKETS_PORT,
2324
):
24-
await asyncio.Future()
25+
await access_guardian.run()
2526

2627

2728
async def main() -> None:
2829
settings = conf.get_app_settings()
30+
stop_signal = create_stop_signal()
31+
2932
storage = SubscriptionStorage()
3033
websockets_handler = WebSocketsHandler(storage=storage)
34+
access_guardian = WebSocketsAccessGuardian(storage=storage, check_interval=60.0, stop_signal=stop_signal)
3135

3236
await app_runner(
3337
settings=settings,
3438
websockets_handler=websockets_handler,
39+
access_guardian=access_guardian,
3540
)
3641

3742

src/handlers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from handlers.websockets_access_guardian import WebSocketsAccessGuardian
12
from handlers.websockets_handler import WebSocketsHandler
23

34
__all__ = [
5+
"WebSocketsAccessGuardian",
46
"WebSocketsHandler",
57
]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import asyncio
2+
from datetime import datetime
3+
import json
4+
import pytest
5+
6+
from app.types import DecodedValidToken
7+
from handlers.websockets_access_guardian import WebSocketsAccessGuardian
8+
from storage.storage_updaters import StorageWebSocketRegister
9+
from storage.storage_updaters import StorageUserSubscriber
10+
11+
pytestmark = [
12+
pytest.mark.slow,
13+
]
14+
15+
16+
@pytest.fixture(autouse=True)
17+
def mock_broadcast(mocker):
18+
return mocker.patch("websockets.broadcast")
19+
20+
21+
@pytest.fixture
22+
def get_guardian_enough_time_to_do_its_job():
23+
return lambda: asyncio.sleep(0.4)
24+
25+
26+
@pytest.fixture
27+
async def guardian(storage):
28+
stop_signal = asyncio.Future() # it's better to use loop.create_future() but ok for tests
29+
30+
yield WebSocketsAccessGuardian(storage=storage, check_interval=0.1, stop_signal=stop_signal)
31+
32+
stop_signal.set_result(None)
33+
34+
35+
@pytest.fixture(autouse=True)
36+
def guardian_as_task(guardian, event_loop):
37+
return event_loop.create_task(guardian.run())
38+
39+
40+
@pytest.fixture(autouse=True)
41+
def ws_subscribed(ws, storage, event):
42+
token_expiration_timestamp = int(datetime.fromisoformat("2023-01-01 12:23:00Z").timestamp())
43+
44+
valid_token = DecodedValidToken(sub="user1", exp=token_expiration_timestamp)
45+
46+
StorageWebSocketRegister(storage, ws, valid_token)()
47+
StorageUserSubscriber(storage, ws, event)()
48+
49+
return ws
50+
51+
52+
def test_guardian_monitor_and_manage_access_remove_expired_websockets(guardian, storage, ws_subscribed):
53+
guardian.monitor_and_manage_access()
54+
55+
registered_websockets = storage.get_registered_websockets()
56+
assert registered_websockets == []
57+
58+
59+
async def test_remove_expired_websockets_from_storage(get_guardian_enough_time_to_do_its_job, storage, mock_broadcast, ws_subscribed, mocker):
60+
await get_guardian_enough_time_to_do_its_job()
61+
62+
registered_websockets = storage.get_registered_websockets()
63+
assert registered_websockets == []
64+
assert ws_subscribed.closed is False, "Do not close connection when token expired, just remove from storage"
65+
mock_broadcast.assert_called_once_with(websockets=[ws_subscribed], message=mocker.ANY)
66+
67+
68+
async def test_broadcasted_message(get_guardian_enough_time_to_do_its_job, mock_broadcast):
69+
await get_guardian_enough_time_to_do_its_job()
70+
71+
broadcasted_message_as_json = json.loads(mock_broadcast.call_args.kwargs["message"])
72+
assert len(broadcasted_message_as_json) == 2
73+
assert broadcasted_message_as_json["message_type"] == "ErrorResponse"
74+
assert broadcasted_message_as_json["errors"] == ["Token expired, user subscriptions disabled or removed"]
75+
76+
77+
@pytest.mark.freeze_time("2023-01-01 12:22:55Z", tick=True) # 5 seconds before expiration
78+
async def test_do_not_remove_not_expired_connections(get_guardian_enough_time_to_do_its_job, storage, ws_subscribed, mock_broadcast):
79+
await get_guardian_enough_time_to_do_its_job()
80+
81+
registered_websockets = storage.get_registered_websockets()
82+
assert registered_websockets == [ws_subscribed]
83+
mock_broadcast.assert_not_called()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import asyncio
2+
from dataclasses import dataclass, field
3+
import logging
4+
5+
import websockets
6+
7+
from handlers.dto import ErrorResponseMessage
8+
from storage.storage_updaters import StorageWebSocketRemover
9+
from storage.subscription_storage import SubscriptionStorage
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
@dataclass
15+
class WebSocketsAccessGuardian:
16+
storage: SubscriptionStorage
17+
check_interval: float = 60.0 # in seconds
18+
stop_signal: asyncio.Future[None] = field(default_factory=asyncio.Future) # default feature will run forever
19+
20+
async def run(self) -> None:
21+
while True:
22+
await asyncio.sleep(self.check_interval)
23+
24+
self.monitor_and_manage_access()
25+
26+
def monitor_and_manage_access(self) -> None:
27+
expired_websockets = self.storage.get_expired_websockets()
28+
29+
if expired_websockets:
30+
logger.warning("Notify and remove from storage for expired websockets: '%s'", (", ").join([str(websocket.id) for websocket in expired_websockets]))
31+
32+
self.broadcast_authentication_expired(expired_websockets)
33+
self.remove_expired_websockets(expired_websockets)
34+
35+
e = self.storage.get_expired_websockets()
36+
assert e is not None
37+
38+
def broadcast_authentication_expired(self, expired_websockets: list[websockets.WebSocketServerProtocol]) -> None:
39+
error_message = ErrorResponseMessage(errors=["Token expired, user subscriptions disabled or removed"], incoming_message=None)
40+
websockets.broadcast(websockets=expired_websockets, message=error_message.model_dump_json(exclude_none=True))
41+
42+
def remove_expired_websockets(self, expired_websockets: list[websockets.WebSocketServerProtocol]) -> None:
43+
for websocket in expired_websockets:
44+
StorageWebSocketRemover(storage=self.storage, websocket=websocket)()

src/tests/functional/conftest.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from websockets import client
66

77
from entrypoint import app_runner
8-
from handlers.websockets_handler import WebSocketsHandler
8+
from handlers import WebSocketsHandler
9+
from handlers import WebSocketsAccessGuardian
910

1011

1112
@pytest.fixture
@@ -24,18 +25,31 @@ def websockets_handler(storage):
2425
return WebSocketsHandler(storage=storage)
2526

2627

28+
@pytest.fixture
29+
async def stop_signal():
30+
return asyncio.Future()
31+
32+
33+
@pytest.fixture
34+
async def access_guardian(storage, stop_signal):
35+
return WebSocketsAccessGuardian(storage=storage, check_interval=0.5, stop_signal=stop_signal)
36+
37+
2738
@pytest.fixture(autouse=True)
28-
async def serve_app_runner(settings, websockets_handler):
39+
async def serve_app_runner(settings, websockets_handler, access_guardian, stop_signal):
2940
serve_task = asyncio.get_running_loop().create_task(
3041
app_runner(
3142
settings=settings,
3243
websockets_handler=websockets_handler,
44+
access_guardian=access_guardian,
3345
),
3446
)
3547

3648
await asyncio.sleep(0.1) # give enough time to start the server
3749
assert serve_task.done() is False # be sure server is running
38-
return serve_task
50+
yield serve_task
51+
52+
stop_signal.set_result(None)
3953

4054

4155
@pytest.fixture
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import asyncio
2+
import pytest
3+
4+
pytestmark = [
5+
pytest.mark.slow,
6+
]
7+
8+
9+
@pytest.fixture
10+
def set_storage_connections_expired(storage):
11+
def set_expired():
12+
for _, websocket_meta in storage.registered_websockets.items():
13+
websocket_meta.expiration_timestamp = 1000 # far in the past
14+
15+
return set_expired
16+
17+
18+
async def test_expired_connections_removed_from_active_connections(ws_client_authenticated, ws_client_recv_decoded, set_storage_connections_expired):
19+
set_storage_connections_expired()
20+
await asyncio.sleep(1.1) # give enough time to validator to do its job
21+
22+
received = await ws_client_recv_decoded(ws_client_authenticated)
23+
24+
assert len(received) == 2
25+
assert received["message_type"] == "ErrorResponse"
26+
assert received["errors"] == ["Token expired, user subscriptions disabled or removed"]

0 commit comments

Comments
 (0)