diff --git a/jwthenticator/api.py b/jwthenticator/api.py index acaaa93..2e3fd72 100644 --- a/jwthenticator/api.py +++ b/jwthenticator/api.py @@ -13,6 +13,8 @@ from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME, JWT_AUDIENCE from jwthenticator.utils import get_rsa_key_pair, calculate_key_signature from jwthenticator.exceptions import ExpiredError +from jwthenticator.loop_management import is_using_sqlite, db_lock + class JWThenticatorAPI: """ @@ -49,7 +51,13 @@ async def authenticate(self, request: schemas.AuthRequest) -> schemas.TokenRespo raise ExpiredError("Key is expired.") jwt_token = await self.token_manager.create_access_token(request.identifier) - refresh_token = await self.token_manager.create_refresh_token(key_obj.id) + + if is_using_sqlite(): + # this lock is designed to block multiple request to sqlite causing DB lock + async with db_lock: + refresh_token = await self.token_manager.create_refresh_token(key_obj.id) + else: + refresh_token = await self.token_manager.create_refresh_token(key_obj.id) return schemas.TokenResponse( jwt=jwt_token, diff --git a/jwthenticator/loop_management.py b/jwthenticator/loop_management.py new file mode 100644 index 0000000..fc42664 --- /dev/null +++ b/jwthenticator/loop_management.py @@ -0,0 +1,8 @@ +import asyncio +from jwthenticator.consts import DB_URI + +main_event_loop = asyncio.new_event_loop() +db_lock = asyncio.Lock(loop=main_event_loop) + +def is_using_sqlite()->bool: + return "sqlite://" in DB_URI diff --git a/jwthenticator/server.py b/jwthenticator/server.py index 990913a..aaca26f 100644 --- a/jwthenticator/server.py +++ b/jwthenticator/server.py @@ -16,6 +16,7 @@ from jwthenticator.utils import get_rsa_key_pair from jwthenticator.server_utils import extract_jwt from jwthenticator.consts import PORT, URL_PREFIX, DISABLE_EXTERNAL_API, DISABLE_INTERNAL_API +from jwthenticator.loop_management import main_event_loop class Server: @@ -59,7 +60,7 @@ def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(), ]) if start_server: - web.run_app(self.app, port=port) + web.run_app(self.app, port=port, loop=main_event_loop) async def authenticate(self, request: web.Request) -> web.Response: diff --git a/jwthenticator/tests/test_integration.py b/jwthenticator/tests/test_integration.py index 618d76f..7ee9a2b 100644 --- a/jwthenticator/tests/test_integration.py +++ b/jwthenticator/tests/test_integration.py @@ -53,18 +53,17 @@ class TestIntegration(AioHTTPTestCase): """ async def get_application(self) -> web.Application: - self.app = web.Application() - self.app.add_routes([web.get("/", secure_endpoint)]) - return self.app - + server = Server(start_server=False) + server.app.add_routes([web.get("/dummy", secure_endpoint)]) + return server.app async def setUpAsync(self) -> None: - server = Server(start_server=False) - runner = web.AppRunner(server.app) - await runner.setup() - site = web.TCPSite(runner, SERVER_HOST, SERVER_PORT) - self.loop.create_task(site.start()) + self.app = await self.get_application() + self.server = await self.get_server(self.app) + self.server.port = SERVER_PORT + self.client = await self.get_client(self.server) + await self.client.start_server() @unittest_run_loop async def test_client_and_authenticated_server(self) -> None: @@ -82,11 +81,11 @@ async def test_client_and_authenticated_server(self) -> None: await client.refresh() # Test get_with_auth with patch("aiohttp.ClientSession.__aenter__", ContextAwareClient(self.client)): - response = await client.get_with_auth("/") + response = await client.get_with_auth("/dummy") assert response.status == HTTPStatus.OK # Test client with JWT and see that doesn't try to refresh (will fail if tries) client2 = Client(SERVER_URL, uuid_identifier, jwt=client.jwt) with patch("aiohttp.ClientSession.__aenter__", ContextAwareClient(self.client)): - response = await client2.get_with_auth("/") + response = await client2.get_with_auth("/dummy") assert response.status == HTTPStatus.OK diff --git a/jwthenticator/tokens.py b/jwthenticator/tokens.py index d8433eb..124b02f 100644 --- a/jwthenticator/tokens.py +++ b/jwthenticator/tokens.py @@ -13,7 +13,6 @@ from jwthenticator.exceptions import InvalidTokenError, MissingJWTError from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI - class TokenManager: """ Class responsible for the creation and loading of tokens