diff --git a/Makefile b/Makefile index f1441bf..794fb36 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ dev-deps: deps fmt: ruff format $(SOURCES) - ruff check $(SOURCES) --fix --unsafe-fixes + ruff check --select I --fix $(SOURCES) lint: dotenv-linter env.example diff --git a/pyproject.toml b/pyproject.toml index 4de5024..601a7aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,16 +34,56 @@ dev = [ [tool.ruff] line-length = 160 +src = ["src"] -[tool.roof.lint] +[tool.ruff.lint] select = ["ALL"] +ignore = [ + "ANN101", # Missing type annotation for self in method + "ANN102", # missing type annotation for `cls` in classmethod + "ANN401", # dynamically typed expressions (typing.Any) are disallowed in `{}` + "COM812", # Trailing comma missing + "D100", # missing docstring in public module + "D101", # missing docstring in public class + "D102", # missing docstring in public method + "D103", # missing docstring in public function + "D104", # missing docstring in public package + "D105", # missing docstring in magic method + "D106", # missing docstring in public nested class + "D107", # missing docstring in `__init__` + "D203", # one blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "EM101", # exception must not use a string literal, assign to variable first + "EM102", # expection must not use an f-string literal, assign to variable first + "INP001", # file `%filename%` is part of an implicit namespace package. Add an `__init__.py` + "ISC001", # implicitly concatenated string literals on one line + "N818", # exception name `{}` should be named with an Error suffix + "PT001", # use `@pytest.fixture()` over `@pytest.fixture` + "TRY003", # avoid specifying long messages outside the exception class +] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" +[tool.ruff.lint.per-file-ignores] +"**/tests/*" = [ + "ANN", # flake8-annotations + "ARG001", # Unused function argument + "PLR0913", # Too many arguments in function definition + "PLR2004", # Magic value used in comparison, consider replacing `%value%` with a constant variable + "S101", # Use of `assert` detected +] +"**/fixtures.py" = [ + "ANN", # flake8-annotations +] + +[tool.ruff.lint.isort] +extra-standard-library = ["pytest"] + + [tool.pytest.ini_options] pythonpath = ["src"] testpaths = ["src"] diff --git a/src/a12n/jwk_client.py b/src/a12n/jwk_client.py index 29e0574..dbefd32 100644 --- a/src/a12n/jwk_client.py +++ b/src/a12n/jwk_client.py @@ -1,20 +1,17 @@ import asyncio -from dataclasses import dataclass import json import logging +from dataclasses import dataclass import httpx import jwt -from jwt.api_jwk import PyJWK -from jwt.api_jwk import PyJWKSet +from jwt.api_jwk import PyJWK, PyJWKSet from jwt.api_jwt import decode_complete as decode_token from jwt.exceptions import PyJWKSetError from jwt.jwk_set_cache import JWKSetCache - from app.types import DecodedValidToken - logger = logging.getLogger(__name__) @@ -24,7 +21,8 @@ class AsyncJWKClientException(Exception): @dataclass class AsyncJWKClient: - """ + """Async JW Keys client. + Inspired and partially copy-pasted from 'jwt.jwks_client.PyJWKClient'. The purpose is the same but querying the JWKS endpoint is async. """ @@ -68,7 +66,7 @@ async def fetch_data(self) -> PyJWKSet: except PyJWKSetError as exc: raise AsyncJWKClientException(exc) from exc - async def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: + async def get_jwk_set(self, *, refresh: bool = False) -> PyJWKSet: jwk_set: PyJWKSet | None = None while self.fetch_data_lock.locked(): @@ -82,8 +80,8 @@ async def get_jwk_set(self, refresh: bool = False) -> PyJWKSet: return jwk_set - async def get_signing_keys(self, refresh: bool = False) -> list[PyJWK]: - jwk_set = await self.get_jwk_set(refresh) + async def get_signing_keys(self, *, refresh: bool = False) -> list[PyJWK]: + jwk_set = await self.get_jwk_set(refresh=refresh) signing_keys = [ jwk_set_key @@ -102,7 +100,7 @@ async def get_signing_keys(self, refresh: bool = False) -> list[PyJWK]: return signing_keys async def get_signing_key(self, kid: str) -> PyJWK: - signing_keys = await self.get_signing_keys() + signing_keys = await self.get_signing_keys(refresh=False) signing_key = self.match_kid(signing_keys, kid) if not signing_key: diff --git a/src/a12n/tests/async_jwk_client/conftest.py b/src/a12n/tests/async_jwk_client/conftest.py index 01a41a3..0046681 100644 --- a/src/a12n/tests/async_jwk_client/conftest.py +++ b/src/a12n/tests/async_jwk_client/conftest.py @@ -1,7 +1,6 @@ import pytest -from respx import MockRouter -from respx import Route +from respx import MockRouter, Route from a12n.jwk_client import AsyncJWKClient @@ -10,13 +9,13 @@ @pytest.fixture def expired_token(): - return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjNMcjhuTjh1R29wUElMZlFvUGpfRCJ9.eyJpc3MiOiJodHRwczovL2Rldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbS8iLCJhdWQiOiIxV1NiR1hGUnlhS0NsTHJvWHZteTlXdndrZUtHb1JvayIsImlhdCI6MTY5ODUyODU1OSwiZXhwIjoxNjk4NTI4ODU5LCJzdWIiOiJhdXRoMHw2NTNjMzI2MGEzMDQ0OGM1OTRhNjllMTIiLCJzaWQiOiI5MEZ3WFNDSFUtd0N3QmY0Y1YyQ3NZTnpBMldieDNUcSIsIm5vbmNlIjoiZTIxZWVhNTljNGY1MDg0N2Q3YzFhOGUzZjQ0NjVjYTcifQ.FO_xoMA9RGI7uAVauv00-zdORgkvCwyWfeAPd7lmU_nKzGp5avPa2MN66S0fjLKOxb8tgzrfpXYLUhDl1nqUvtj1A54-PfNW0n0ctdn2zk_CCOxsAjKyImlIgq7Y4DIuil0wikj7FdoWkB-bCBrKs7JaOoWkSHws9uQxRyvZzBwPHExW0myHWvB3G0x8g23PfSv2oALbvXBp0OAniGwru2Br9e2iXCVyGAUMTCpQmjPDAyfeYXGxF9BhxuX3e-GL80oyngBQK0kTxw-2Xz8LDSC-MI2jTs1gUo9qdVrg_1fzQtvAW9LGaWg5L_CJe92ZH3l1fBPfSh7Gc6uBtwF-YA" + return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjNMcjhuTjh1R29wUElMZlFvUGpfRCJ9.eyJpc3MiOiJodHRwczovL2Rldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbS8iLCJhdWQiOiIxV1NiR1hGUnlhS0NsTHJvWHZteTlXdndrZUtHb1JvayIsImlhdCI6MTY5ODUyODU1OSwiZXhwIjoxNjk4NTI4ODU5LCJzdWIiOiJhdXRoMHw2NTNjMzI2MGEzMDQ0OGM1OTRhNjllMTIiLCJzaWQiOiI5MEZ3WFNDSFUtd0N3QmY0Y1YyQ3NZTnpBMldieDNUcSIsIm5vbmNlIjoiZTIxZWVhNTljNGY1MDg0N2Q3YzFhOGUzZjQ0NjVjYTcifQ.FO_xoMA9RGI7uAVauv00-zdORgkvCwyWfeAPd7lmU_nKzGp5avPa2MN66S0fjLKOxb8tgzrfpXYLUhDl1nqUvtj1A54-PfNW0n0ctdn2zk_CCOxsAjKyImlIgq7Y4DIuil0wikj7FdoWkB-bCBrKs7JaOoWkSHws9uQxRyvZzBwPHExW0myHWvB3G0x8g23PfSv2oALbvXBp0OAniGwru2Br9e2iXCVyGAUMTCpQmjPDAyfeYXGxF9BhxuX3e-GL80oyngBQK0kTxw-2Xz8LDSC-MI2jTs1gUo9qdVrg_1fzQtvAW9LGaWg5L_CJe92ZH3l1fBPfSh7Gc6uBtwF-YA" # noqa: E501 @pytest.fixture def token(): # The token won't expire in ~100 years (expiration date 2123-10-05, it's more than enough to rely on it in test) - return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjNMcjhuTjh1R29wUElMZlFvUGpfRCJ9.eyJpc3MiOiJodHRwczovL2Rldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbS8iLCJhdWQiOiIxV1NiR1hGUnlhS0NsTHJvWHZteTlXdndrZUtHb1JvayIsImlhdCI6MTY5ODUyODE3MCwiZXhwIjo0ODUyMTI4MTcwLCJzdWIiOiJhdXRoMHw2NTNjMzI2MGEzMDQ0OGM1OTRhNjllMTIiLCJzaWQiOiI5MEZ3WFNDSFUtd0N3QmY0Y1YyQ3NZTnpBMldieDNUcSIsIm5vbmNlIjoiMTVhNWI2M2Y3MzI5MDcwMmU3MGViZmJlMDc5ODgxYmIifQ.FQYBaTnjKJHcskRl1WsB4kKQmyvXRcG8RDWlB2woSbzukZx7SnWghC1qRhYeqOLBUBpe3Iu_EzxgF26YDZJ28bKKNgL4fVmYak3jOg2nRP2lulrkF8USmkqT9Vx85hlIEVCisYOS6DJE0bHJL5WbHjCmDjQ6RGRyVZ3s6UPFXIwe2CMC_egAdWrsLYrgA1mqozQhwLJN2zSuObkDffkpHbX9XXB225v3-ryY-Rr0rPh9AOfKtEeMUEmNG0gsGyIbi0DoPDjAxlxCDx7ULVSChIKhUv4DKICqrqzHyopA7oE8LlpDbPTshQsL6L4u1EwUT7maP9VTcEQUTnp3Cu5msw" + return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjNMcjhuTjh1R29wUElMZlFvUGpfRCJ9.eyJpc3MiOiJodHRwczovL2Rldi1wcm50ZG1vMTYzc2NsczR4LnVzLmF1dGgwLmNvbS8iLCJhdWQiOiIxV1NiR1hGUnlhS0NsTHJvWHZteTlXdndrZUtHb1JvayIsImlhdCI6MTY5ODUyODE3MCwiZXhwIjo0ODUyMTI4MTcwLCJzdWIiOiJhdXRoMHw2NTNjMzI2MGEzMDQ0OGM1OTRhNjllMTIiLCJzaWQiOiI5MEZ3WFNDSFUtd0N3QmY0Y1YyQ3NZTnpBMldieDNUcSIsIm5vbmNlIjoiMTVhNWI2M2Y3MzI5MDcwMmU3MGViZmJlMDc5ODgxYmIifQ.FQYBaTnjKJHcskRl1WsB4kKQmyvXRcG8RDWlB2woSbzukZx7SnWghC1qRhYeqOLBUBpe3Iu_EzxgF26YDZJ28bKKNgL4fVmYak3jOg2nRP2lulrkF8USmkqT9Vx85hlIEVCisYOS6DJE0bHJL5WbHjCmDjQ6RGRyVZ3s6UPFXIwe2CMC_egAdWrsLYrgA1mqozQhwLJN2zSuObkDffkpHbX9XXB225v3-ryY-Rr0rPh9AOfKtEeMUEmNG0gsGyIbi0DoPDjAxlxCDx7ULVSChIKhUv4DKICqrqzHyopA7oE8LlpDbPTshQsL6L4u1EwUT7maP9VTcEQUTnp3Cu5msw" # noqa: E501 @pytest.fixture @@ -26,7 +25,7 @@ def matching_kid_data(): { "kty": "RSA", "use": "sig", - "n": "oIQkRCY4X-_ItMUPt65wVIGewOJfjMhlu6HG_rHik5-dTK0o6oyUne2Gevetn2Vrn8NSIaARobLZ8expuJBYDS121w_RloC6MCuzlc-j_nHj-BcBOCqGWPVwKX4un0HueD3aW3buqzYcmX_9LhdSE8ARyN0S9O6RbYWDCTKFhrRXtIP4wzP8vdPGXGurtGIiBbhVCK1LHG2lO5Gt8IIQ_DAcX6swnXCfbHwR1OXc9Do06o8c7ZsZdjMty5b4Fpv8rAKA-HTP_One4yhKtqCMYs3_gcTeQdHi-0w634VnpdzC_0f_MMzNIgvXC8VdJgkGpa6jLBp3mTqaFUdkAXFYlw", + "n": "oIQkRCY4X-_ItMUPt65wVIGewOJfjMhlu6HG_rHik5-dTK0o6oyUne2Gevetn2Vrn8NSIaARobLZ8expuJBYDS121w_RloC6MCuzlc-j_nHj-BcBOCqGWPVwKX4un0HueD3aW3buqzYcmX_9LhdSE8ARyN0S9O6RbYWDCTKFhrRXtIP4wzP8vdPGXGurtGIiBbhVCK1LHG2lO5Gt8IIQ_DAcX6swnXCfbHwR1OXc9Do06o8c7ZsZdjMty5b4Fpv8rAKA-HTP_One4yhKtqCMYs3_gcTeQdHi-0w634VnpdzC_0f_MMzNIgvXC8VdJgkGpa6jLBp3mTqaFUdkAXFYlw", # noqa: E501 "e": "AQAB", "kid": "3Lr8nN8uGopPILfQoPj_D", "x5t": "f93zLhSTsgVJiS9JA0x8sHkaLMg", @@ -46,7 +45,7 @@ def not_matching_kid_data(): { "kty": "RSA", "use": "sig", - "n": "zB0xsH539lpLVejR6Hq1bHN3EzDt_0tJyr5JVHz3GSnNYAaZzkqL7HyLlhwttl7_bRyZJeZ8X6aasBxVK2JCDc9U-0KMJXmSoJs1oWYRo79DqdzCXK3ZYXcgkvI9OWF1qVx76vbZVwiRv5qUzpINdLnsX2CXChyd0LFkg14bYrSfdN-eMmG1PXtHZufeKG6HW17PFXS7OwesMQIfQ9kFfSvgFkJgkNM0o6NaeB-ZPDvzfKmmpBXjtGcze0A56NdQ7Z42DRDURROS82sPISrX-iAt93tZ1F0IW_U4niIYc6NFcWPPXpQpiVDDwdrz-L1H63mSUDSDFsWVcv2xWry6kQ", + "n": "zB0xsH539lpLVejR6Hq1bHN3EzDt_0tJyr5JVHz3GSnNYAaZzkqL7HyLlhwttl7_bRyZJeZ8X6aasBxVK2JCDc9U-0KMJXmSoJs1oWYRo79DqdzCXK3ZYXcgkvI9OWF1qVx76vbZVwiRv5qUzpINdLnsX2CXChyd0LFkg14bYrSfdN-eMmG1PXtHZufeKG6HW17PFXS7OwesMQIfQ9kFfSvgFkJgkNM0o6NaeB-ZPDvzfKmmpBXjtGcze0A56NdQ7Z42DRDURROS82sPISrX-iAt93tZ1F0IW_U4niIYc6NFcWPPXpQpiVDDwdrz-L1H63mSUDSDFsWVcv2xWry6kQ", # noqa: E501 "e": "AQAB", "kid": "ICOpsXGmpNaDPiljjRjiE", "x5t": "1GDK6kGV6HvZ1m_-VdSKIFNEtEU", diff --git a/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py b/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py index 20ff755..dfe4afe 100644 --- a/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py +++ b/src/a12n/tests/async_jwk_client/tests_async_jwk_decode.py @@ -1,6 +1,7 @@ +import pytest from contextlib import nullcontext as does_not_raise + from a12n.jwk_client import AsyncJWKClientException -import pytest pytestmark = [ pytest.mark.usefixtures("mock_success_response"), diff --git a/src/app/conf/__init__.py b/src/app/conf/__init__.py index 8c67b13..e051d76 100644 --- a/src/app/conf/__init__.py +++ b/src/app/conf/__init__.py @@ -1,5 +1,4 @@ -from app.conf.settings import get_app_settings -from app.conf.settings import Settings +from app.conf.settings import Settings, get_app_settings __all__ = [ "Settings", diff --git a/src/app/conf/settings.py b/src/app/conf/settings.py index 1afceb4..d4dcf4f 100644 --- a/src/app/conf/settings.py +++ b/src/app/conf/settings.py @@ -2,8 +2,7 @@ from typing import Literal from pydantic import AmqpDsn -from pydantic_settings import BaseSettings -from pydantic_settings import SettingsConfigDict +from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): @@ -33,4 +32,4 @@ class Settings(BaseSettings): @lru_cache def get_app_settings() -> Settings: - return Settings() # type: ignore + return Settings() # type: ignore[call-arg] diff --git a/src/app/fixtures.py b/src/app/fixtures.py index 581197b..b2b1e45 100644 --- a/src/app/fixtures.py +++ b/src/app/fixtures.py @@ -1,18 +1,19 @@ import pytest +from collections.abc import Callable from app.testing import MockedWebSocketServerProtocol @pytest.fixture -def create_ws(): +def create_ws() -> Callable[[], MockedWebSocketServerProtocol]: return lambda: MockedWebSocketServerProtocol() @pytest.fixture -def ws(create_ws): +def ws(create_ws) -> MockedWebSocketServerProtocol: return create_ws() @pytest.fixture -def ya_ws(create_ws): +def ya_ws(create_ws) -> MockedWebSocketServerProtocol: return create_ws() diff --git a/src/app/services.py b/src/app/services.py index a7b732f..7734c8d 100644 --- a/src/app/services.py +++ b/src/app/services.py @@ -1,10 +1,11 @@ -from abc import ABC -from abc import abstractmethod -from typing import Any, Callable +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any class BaseService(ABC): - """This is a template of a a base service. + """Template of a a base service. + All services in the app should follow this rules: * Input variables should be done at the __init__ phase * Service should implement a single entrypoint without arguments diff --git a/src/app/testing.py b/src/app/testing.py index 3868afc..ef6543b 100644 --- a/src/app/testing.py +++ b/src/app/testing.py @@ -33,7 +33,7 @@ async def send(self, message: str) -> None: # type: ignore[override] await asyncio.sleep(0) await self.send_queue.put(message) - async def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: + async def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # noqa: ARG002 self.state = State.CLOSED async def wait_messages_to_be_sent(self) -> None: @@ -46,10 +46,8 @@ async def count_sent_to_client(self) -> int: def client_send(self, message: dict) -> None: self.recv_queue.put_nowait(json.dumps(message)) - async def client_recv(self, skip_count_first_messages=0) -> dict | None: - """Convenient for testing. - Receive one message at time. First messages could be discarded with 'skip_count_first_messages' parameter. - """ + async def client_recv(self, skip_count_first_messages: int = 0) -> dict | None: + """Skip 'skip_count_first_messages' messages and return the next one. Convenient for testing.""" await self.wait_messages_to_be_sent() if self.send_queue.empty(): diff --git a/src/app/types.py b/src/app/types.py index f3bd73e..b240d16 100644 --- a/src/app/types.py +++ b/src/app/types.py @@ -1,4 +1,4 @@ -from typing import NewType, NamedTuple +from typing import NamedTuple, NewType UserId = NewType("UserId", str) Event = NewType("Event", str) diff --git a/src/conftest.py b/src/conftest.py index 6927997..042509a 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -1,6 +1,6 @@ import pytest -from app.conf import get_app_settings +from app.conf import Settings, get_app_settings pytest_plugins = [ "app.fixtures", @@ -9,5 +9,5 @@ @pytest.fixture -def settings(): +def settings() -> Settings: return get_app_settings() diff --git a/src/consumer/consumer.py b/src/consumer/consumer.py index 9f108e6..fc3f64e 100644 --- a/src/consumer/consumer.py +++ b/src/consumer/consumer.py @@ -1,17 +1,15 @@ import asyncio -from dataclasses import dataclass import logging +from dataclasses import dataclass from typing import Protocol import aio_pika +import websockets +from pydantic import ValidationError from app import conf -from consumer.dto import ConsumedMessage -from consumer.dto import OutgoingMessage +from consumer.dto import ConsumedMessage, OutgoingMessage from storage.subscription_storage import SubscriptionStorage -from pydantic import ValidationError -import websockets - logger = logging.getLogger(__name__) @@ -61,7 +59,7 @@ def parse_message(raw_message: aio_pika.abc.AbstractIncomingMessage) -> Consumed try: return ConsumedMessage.model_validate_json(raw_message.body) except ValidationError as exc: - logger.error("Consumed message not in expected format. Errors: %s", exc.errors()) + logger.error("Consumed message not in expected format. Errors: %s", exc.errors()) # noqa: TRY400 return None @staticmethod diff --git a/src/consumer/dto.py b/src/consumer/dto.py index fd1b87a..1d415b4 100644 --- a/src/consumer/dto.py +++ b/src/consumer/dto.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel -from pydantic import ConfigDict from typing import Literal + +from pydantic import BaseModel, ConfigDict + from app.types import Event diff --git a/src/consumer/tests/conftest.py b/src/consumer/tests/conftest.py index ab64f35..6c1e211 100644 --- a/src/consumer/tests/conftest.py +++ b/src/consumer/tests/conftest.py @@ -1,9 +1,9 @@ import pytest +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from dataclasses import dataclass from consumer.consumer import Consumer -from dataclasses import dataclass -from contextlib import asynccontextmanager -from typing import AsyncGenerator @pytest.fixture(autouse=True) diff --git a/src/consumer/tests/tests_consumer_on_message.py b/src/consumer/tests/tests_consumer_on_message.py index 74a527d..a09b14d 100644 --- a/src/consumer/tests/tests_consumer_on_message.py +++ b/src/consumer/tests/tests_consumer_on_message.py @@ -1,6 +1,6 @@ -from contextlib import nullcontext as does_not_raise import json import pytest +from contextlib import nullcontext as does_not_raise from consumer.tests.conftest import MockedIncomingMessage diff --git a/src/entrypoint.py b/src/entrypoint.py index 460a05d..0828ec3 100644 --- a/src/entrypoint.py +++ b/src/entrypoint.py @@ -1,14 +1,13 @@ import asyncio +import logging import signal import websockets -import logging from app import conf -from handlers import WebSocketsHandler -from handlers import WebSocketsAccessGuardian -from storage.subscription_storage import SubscriptionStorage from consumer import Consumer +from handlers import WebSocketsAccessGuardian, WebSocketsHandler +from storage.subscription_storage import SubscriptionStorage logging.basicConfig(level=logging.INFO) @@ -27,14 +26,16 @@ async def app_runner( consumer: Consumer, stop_signal: asyncio.Future, ) -> None: - async with websockets.serve( - ws_handler=websockets_handler.websockets_handler, - host=settings.WEBSOCKETS_HOST, - port=settings.WEBSOCKETS_PORT, + async with ( + websockets.serve( + ws_handler=websockets_handler.websockets_handler, + host=settings.WEBSOCKETS_HOST, + port=settings.WEBSOCKETS_PORT, + ), + asyncio.TaskGroup() as task_group, ): - async with asyncio.TaskGroup() as task_group: - task_group.create_task(access_guardian.run(stop_signal)) - task_group.create_task(consumer.consume(stop_signal)) + task_group.create_task(access_guardian.run(stop_signal)) + task_group.create_task(consumer.consume(stop_signal)) async def main() -> None: diff --git a/src/handlers/dto.py b/src/handlers/dto.py index 25e9834..a479a90 100644 --- a/src/handlers/dto.py +++ b/src/handlers/dto.py @@ -1,13 +1,11 @@ from typing import Literal +from pydantic import BaseModel, SecretStr from pydantic_core import ErrorDetails -from pydantic import SecretStr -from pydantic import BaseModel from app.types import Event - -messageId = int | str +MessageId = int | str class AuthMessageParams(BaseModel): @@ -15,7 +13,7 @@ class AuthMessageParams(BaseModel): class AuthMessage(BaseModel): - message_id: messageId + message_id: MessageId message_type: Literal["Authenticate"] params: AuthMessageParams @@ -25,13 +23,13 @@ class SubscribeParams(BaseModel): class SubscribeMessage(BaseModel): - message_id: messageId + message_id: MessageId message_type: Literal["Subscribe"] params: SubscribeParams class UnsubscribeMessage(BaseModel): - message_id: messageId + message_id: MessageId message_type: Literal["Unsubscribe"] params: SubscribeParams diff --git a/src/handlers/exceptions.py b/src/handlers/exceptions.py index 32ec7d4..a731f0b 100644 --- a/src/handlers/exceptions.py +++ b/src/handlers/exceptions.py @@ -1,5 +1,4 @@ -from handlers.dto import IncomingMessage -from handlers.dto import ErrorResponseMessage +from handlers.dto import ErrorResponseMessage, IncomingMessage class WebsocketMessageException(Exception): diff --git a/src/handlers/messages_handler.py b/src/handlers/messages_handler.py index 70ea6ba..0b04c50 100644 --- a/src/handlers/messages_handler.py +++ b/src/handlers/messages_handler.py @@ -1,22 +1,16 @@ +from collections.abc import Callable, Coroutine from dataclasses import dataclass -from typing import Coroutine, Any, Callable +from typing import Any from websockets import WebSocketServerProtocol -from a12n.jwk_client import AsyncJWKClient -from a12n.jwk_client import AsyncJWKClientException +from a12n.jwk_client import AsyncJWKClient, AsyncJWKClientException from app import conf -from handlers.dto import AuthMessage -from handlers.dto import SubscribeMessage -from handlers.dto import UnsubscribeMessage -from handlers.dto import SuccessResponseMessage +from handlers.dto import AuthMessage, IncomingMessage, SubscribeMessage, SuccessResponseMessage, UnsubscribeMessage from handlers.exceptions import WebsocketMessageException -from storage.exceptions import StorageOperationException -from storage.storage_updaters import StorageWebSocketRegister from storage import SubscriptionStorage -from storage.storage_updaters import StorageUserSubscriber -from storage.storage_updaters import StorageUserUnsubscriber -from handlers.dto import IncomingMessage +from storage.exceptions import StorageOperationException +from storage.storage_updaters import StorageUserSubscriber, StorageUserUnsubscriber, StorageWebSocketRegister AsyncMessageHandler = Callable[[WebSocketServerProtocol, Any], Coroutine[Any, Any, SuccessResponseMessage]] @@ -43,7 +37,7 @@ async def handle_auth_message(self, websocket: WebSocketServerProtocol, message: validated_token = await self.jwk_client.decode(message.params.token.get_secret_value()) StorageWebSocketRegister(storage=self.storage, websocket=websocket, validated_token=validated_token)() except (AsyncJWKClientException, StorageOperationException) as exc: - raise WebsocketMessageException(str(exc), message) + raise WebsocketMessageException(str(exc), message) from exc return SuccessResponseMessage.model_construct(incoming_message=message) diff --git a/src/handlers/tests/messages_handler/conftest.py b/src/handlers/tests/messages_handler/conftest.py index e827e24..f91052b 100644 --- a/src/handlers/tests/messages_handler/conftest.py +++ b/src/handlers/tests/messages_handler/conftest.py @@ -1,7 +1,7 @@ import pytest -from handlers.messages_handler import WebSocketMessagesHandler from handlers.dto import AuthMessage, SubscribeMessage, UnsubscribeMessage +from handlers.messages_handler import WebSocketMessagesHandler @pytest.fixture(autouse=True) diff --git a/src/handlers/tests/messages_handler/tests_message_handler_common.py b/src/handlers/tests/messages_handler/tests_message_handler_common.py index 5a811bb..bc0ecd0 100644 --- a/src/handlers/tests/messages_handler/tests_message_handler_common.py +++ b/src/handlers/tests/messages_handler/tests_message_handler_common.py @@ -1,4 +1,5 @@ import pytest + from handlers.messages_handler import WebSocketMessagesHandler diff --git a/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py b/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py index f266259..3794220 100644 --- a/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py +++ b/src/handlers/tests/messages_handler/tests_subscirbe_message_handler.py @@ -1,4 +1,5 @@ import pytest + from handlers.messages_handler import WebSocketMessagesHandler from storage.storage_updaters import StorageUserSubscriber diff --git a/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py b/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py index 61e5c55..546db33 100644 --- a/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py +++ b/src/handlers/tests/messages_handler/tests_unsubscirbe_message_handler.py @@ -1,4 +1,5 @@ import pytest + from handlers.messages_handler import WebSocketMessagesHandler from storage.storage_updaters import StorageUserUnsubscriber diff --git a/src/handlers/tests/tests_websockets_access_guardian.py b/src/handlers/tests/tests_websockets_access_guardian.py index 27916cc..3ed4c6c 100644 --- a/src/handlers/tests/tests_websockets_access_guardian.py +++ b/src/handlers/tests/tests_websockets_access_guardian.py @@ -1,12 +1,11 @@ import asyncio -from datetime import datetime import json import pytest +from datetime import datetime from app.types import DecodedValidToken from handlers.websockets_access_guardian import WebSocketsAccessGuardian -from storage.storage_updaters import StorageWebSocketRegister -from storage.storage_updaters import StorageUserSubscriber +from storage.storage_updaters import StorageUserSubscriber, StorageWebSocketRegister pytestmark = [ pytest.mark.slow, @@ -30,8 +29,6 @@ def guardian(storage): @pytest.fixture(autouse=True) async def guardian_as_task(guardian): - # event_loop = asyncio.get_event_loop() - stop_signal = asyncio.get_event_loop().create_future() runner_task = asyncio.create_task(guardian.run(stop_signal)) diff --git a/src/handlers/websockets_access_guardian.py b/src/handlers/websockets_access_guardian.py index 8c5cc75..fa063a2 100644 --- a/src/handlers/websockets_access_guardian.py +++ b/src/handlers/websockets_access_guardian.py @@ -1,6 +1,6 @@ import asyncio -from dataclasses import dataclass import logging +from dataclasses import dataclass import websockets @@ -37,9 +37,6 @@ def monitor_and_manage_access(self) -> None: self.broadcast_authentication_expired(expired_websockets) self.remove_expired_websockets(expired_websockets) - e = self.storage.get_expired_websockets() - assert e is not None - def broadcast_authentication_expired(self, expired_websockets: list[websockets.WebSocketServerProtocol]) -> None: error_message = ErrorResponseMessage(errors=["Token expired, user subscriptions disabled or removed"], incoming_message=None) websockets.broadcast(websockets=expired_websockets, message=error_message.model_dump_json(exclude_none=True)) diff --git a/src/handlers/websockets_handler.py b/src/handlers/websockets_handler.py index 9c28749..379dcc3 100644 --- a/src/handlers/websockets_handler.py +++ b/src/handlers/websockets_handler.py @@ -1,22 +1,18 @@ -from dataclasses import dataclass import logging +from dataclasses import dataclass from typing import Annotated import pydantic -from pydantic import Field -from pydantic import TypeAdapter +from pydantic import Field, TypeAdapter from websockets import WebSocketServerProtocol from websockets.exceptions import ConnectionClosedError from app import conf -from handlers.dto import AuthMessage -from handlers.dto import ErrorResponseMessage -from handlers.dto import IncomingMessage -from handlers.dto import SuccessResponseMessage +from handlers.dto import AuthMessage, ErrorResponseMessage, IncomingMessage, SuccessResponseMessage from handlers.exceptions import WebsocketMessageException +from handlers.messages_handler import WebSocketMessagesHandler from storage.storage_updaters import StorageWebSocketRemover from storage.subscription_storage import SubscriptionStorage -from handlers.messages_handler import WebSocketMessagesHandler logger = logging.getLogger(__name__) diff --git a/src/storage/fixtures.py b/src/storage/fixtures.py index 7742bd4..4184d84 100644 --- a/src/storage/fixtures.py +++ b/src/storage/fixtures.py @@ -1,9 +1,9 @@ import pytest from app.types import DecodedValidToken -from storage.storage_updaters.storage_websocket_register import StorageWebSocketRegister -from storage.storage_updaters.storage_user_subscriber import StorageUserSubscriber from storage import SubscriptionStorage +from storage.storage_updaters.storage_user_subscriber import StorageUserSubscriber +from storage.storage_updaters.storage_websocket_register import StorageWebSocketRegister @pytest.fixture diff --git a/src/storage/storage_updaters/__init__.py b/src/storage/storage_updaters/__init__.py index 201c3dc..6b86e7e 100644 --- a/src/storage/storage_updaters/__init__.py +++ b/src/storage/storage_updaters/__init__.py @@ -1,7 +1,7 @@ -from storage.storage_updaters.storage_websocket_register import StorageWebSocketRegister -from storage.storage_updaters.storage_websocket_remover import StorageWebSocketRemover from storage.storage_updaters.storage_user_subscriber import StorageUserSubscriber from storage.storage_updaters.storage_user_unsubscriber import StorageUserUnsubscriber +from storage.storage_updaters.storage_websocket_register import StorageWebSocketRegister +from storage.storage_updaters.storage_websocket_remover import StorageWebSocketRemover __all__ = [ "StorageWebSocketRegister", diff --git a/src/storage/storage_updaters/storage_user_subscriber.py b/src/storage/storage_updaters/storage_user_subscriber.py index 7de0136..ac55953 100644 --- a/src/storage/storage_updaters/storage_user_subscriber.py +++ b/src/storage/storage_updaters/storage_user_subscriber.py @@ -1,12 +1,11 @@ +import logging from dataclasses import dataclass from functools import cached_property -import logging from websockets.server import WebSocketServerProtocol from app.services import BaseService -from app.types import UserId -from app.types import Event +from app.types import Event, UserId from storage.exceptions import StorageOperationException from storage.subscription_storage import SubscriptionStorage diff --git a/src/storage/storage_updaters/storage_user_unsubscriber.py b/src/storage/storage_updaters/storage_user_unsubscriber.py index eb0be8f..7e8fac5 100644 --- a/src/storage/storage_updaters/storage_user_unsubscriber.py +++ b/src/storage/storage_updaters/storage_user_unsubscriber.py @@ -1,12 +1,11 @@ +import logging from dataclasses import dataclass from functools import cached_property -import logging from websockets.server import WebSocketServerProtocol from app.services import BaseService -from app.types import UserId -from app.types import Event +from app.types import Event, UserId from storage.exceptions import StorageOperationException from storage.subscription_storage import SubscriptionStorage diff --git a/src/storage/storage_updaters/storage_websocket_register.py b/src/storage/storage_updaters/storage_websocket_register.py index 9c7a728..9ae41ba 100644 --- a/src/storage/storage_updaters/storage_websocket_register.py +++ b/src/storage/storage_updaters/storage_websocket_register.py @@ -1,26 +1,23 @@ +import logging from dataclasses import dataclass from functools import cached_property -import logging from websockets.server import WebSocketServerProtocol from app.services import BaseService -from app.types import DecodedValidToken -from app.types import UserId +from app.types import DecodedValidToken, UserId from storage.exceptions import StorageOperationException from storage.subscription_storage import SubscriptionStorage -from storage.types import ConnectedUserMeta -from storage.types import WebSocketMeta +from storage.types import ConnectedUserMeta, WebSocketMeta logger = logging.getLogger(__name__) @dataclass class StorageWebSocketRegister(BaseService): - """Add or update websocket in storage + """Add or update websocket in storage. If websocket not registered: just register it. - If websocket is registered already: - if token's 'user_id' is the same then update websocket's expiration timestamp - if token's 'user_id' is different then do not change existed registered websocket and raise `StorageOperationException` @@ -47,7 +44,7 @@ def act(self) -> None: if self.existed_websocket_meta: self.update_registered_websocket_meta(self.existed_websocket_meta) - return None + return self.register_new_websocket() diff --git a/src/storage/storage_updaters/storage_websocket_remover.py b/src/storage/storage_updaters/storage_websocket_remover.py index 6ed11fa..fbec4bf 100644 --- a/src/storage/storage_updaters/storage_websocket_remover.py +++ b/src/storage/storage_updaters/storage_websocket_remover.py @@ -1,9 +1,9 @@ -from dataclasses import dataclass import logging +from dataclasses import dataclass +from functools import cached_property from typing import cast from websockets.server import WebSocketServerProtocol -from functools import cached_property from app.services import BaseService from app.types import UserId @@ -15,7 +15,7 @@ @dataclass class StorageWebSocketRemover(BaseService): - """ "Remove connection from storage. + """Remove connection from storage. If websocket is not registered then nothing to do. If websocket is registered and it's last connection then unsubscribe user from all subscriptions. @@ -31,7 +31,7 @@ def websocket_user_id(self) -> UserId: def act(self) -> None: if not self.storage.is_websocket_registered(self.websocket): - return None + return if self.is_last_user_websocket(): self.remove_user_subscriptions() diff --git a/src/storage/subscription_storage.py b/src/storage/subscription_storage.py index 5d08477..82b6d0e 100644 --- a/src/storage/subscription_storage.py +++ b/src/storage/subscription_storage.py @@ -1,13 +1,10 @@ -from dataclasses import dataclass -from dataclasses import field import time +from dataclasses import dataclass, field from websockets import WebSocketServerProtocol -from app.types import Event -from app.types import UserId -from storage.types import ConnectedUserMeta -from storage.types import WebSocketMeta +from app.types import Event, UserId +from storage.types import ConnectedUserMeta, WebSocketMeta @dataclass diff --git a/src/storage/tests/storage_updaters/tests_storage_websocket_register.py b/src/storage/tests/storage_updaters/tests_storage_websocket_register.py index 47b7990..3207391 100644 --- a/src/storage/tests/storage_updaters/tests_storage_websocket_register.py +++ b/src/storage/tests/storage_updaters/tests_storage_websocket_register.py @@ -3,8 +3,7 @@ from storage.exceptions import StorageOperationException from storage.storage_updaters import StorageWebSocketRegister from storage.subscription_storage import SubscriptionStorage -from storage.types import ConnectedUserMeta -from storage.types import WebSocketMeta +from storage.types import ConnectedUserMeta, WebSocketMeta @pytest.fixture diff --git a/src/storage/tests/storage_updaters/tests_storage_websocket_remover.py b/src/storage/tests/storage_updaters/tests_storage_websocket_remover.py index 109c8e3..8b316f4 100644 --- a/src/storage/tests/storage_updaters/tests_storage_websocket_remover.py +++ b/src/storage/tests/storage_updaters/tests_storage_websocket_remover.py @@ -1,5 +1,5 @@ -from contextlib import nullcontext as does_not_raise import pytest +from contextlib import nullcontext as does_not_raise from storage.storage_updaters import StorageWebSocketRemover diff --git a/src/storage/types.py b/src/storage/types.py index 464d993..95af629 100644 --- a/src/storage/types.py +++ b/src/storage/types.py @@ -3,8 +3,7 @@ from websockets.server import WebSocketServerProtocol -from app.types import Event -from app.types import UserId +from app.types import Event, UserId @dataclass diff --git a/src/tests/functional/conftest.py b/src/tests/functional/conftest.py index 18458b9..5424b2b 100644 --- a/src/tests/functional/conftest.py +++ b/src/tests/functional/conftest.py @@ -4,10 +4,9 @@ from websockets import client -from entrypoint import app_runner -from handlers import WebSocketsHandler -from handlers import WebSocketsAccessGuardian from consumer import Consumer +from entrypoint import app_runner +from handlers import WebSocketsAccessGuardian, WebSocketsHandler @pytest.fixture @@ -18,7 +17,7 @@ def force_token_validation(mocker, valid_token): @pytest.fixture(autouse=True) def _adjust_settings(settings, unused_tcp_port): settings.BROKER_QUEUE = None # force consumer to create a queue with a random name - settings.WEBSOCKETS_HOST = "0.0.0.0" + settings.WEBSOCKETS_HOST = "0.0.0.0" # noqa: S104 settings.WEBSOCKETS_PORT = unused_tcp_port diff --git a/src/tests/functional/tests_remove_registered_on_token_expiration.py b/src/tests/functional/tests_remove_registered_on_token_expiration.py index cc42b2a..e2293fa 100644 --- a/src/tests/functional/tests_remove_registered_on_token_expiration.py +++ b/src/tests/functional/tests_remove_registered_on_token_expiration.py @@ -9,7 +9,7 @@ @pytest.fixture def set_storage_connections_expired(storage): def set_expired(): - for _, websocket_meta in storage.registered_websockets.items(): + for websocket_meta in storage.registered_websockets.values(): websocket_meta.expiration_timestamp = 1000 # far in the past return set_expired