From 9b45fbd426faa5847eb2cbd99ff580a7da71bde2 Mon Sep 17 00:00:00 2001 From: kobim Date: Wed, 26 Jan 2022 10:40:04 +0200 Subject: [PATCH 1/4] - add lock on writing, since this is mostly used on sqlite, (remove the lock incase of non sqlite) - use the same loop for lock and server --- jwthenticator/loop_management.py | 4 ++++ jwthenticator/server.py | 3 ++- jwthenticator/tokens.py | 21 +++++++++++---------- 3 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 jwthenticator/loop_management.py diff --git a/jwthenticator/loop_management.py b/jwthenticator/loop_management.py new file mode 100644 index 0000000..779a8e7 --- /dev/null +++ b/jwthenticator/loop_management.py @@ -0,0 +1,4 @@ +import asyncio + +loop = asyncio.new_event_loop() +db_lock = asyncio.Lock(loop=loop) \ No newline at end of file diff --git a/jwthenticator/server.py b/jwthenticator/server.py index 990913a..0ae5d5e 100644 --- a/jwthenticator/server.py +++ b/jwthenticator/server.py @@ -16,6 +16,7 @@ from jwthenticator.utils import get_rsa_key_pair from jwthenticator.server_utils import extract_jwt from jwthenticator.consts import PORT, URL_PREFIX, DISABLE_EXTERNAL_API, DISABLE_INTERNAL_API +from loop_management import loop class Server: @@ -59,7 +60,7 @@ def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(), ]) if start_server: - web.run_app(self.app, port=port) + web.run_app(self.app, port=port, loop=loop) async def authenticate(self, request: web.Request) -> web.Response: diff --git a/jwthenticator/tokens.py b/jwthenticator/tokens.py index d8433eb..ec773bb 100644 --- a/jwthenticator/tokens.py +++ b/jwthenticator/tokens.py @@ -12,7 +12,7 @@ from jwthenticator.schemas import JWTPayloadData, RefreshTokenData from jwthenticator.exceptions import InvalidTokenError, MissingJWTError from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI - +from loop_management import db_lock class TokenManager: """ @@ -88,15 +88,16 @@ async def create_refresh_token(self, key_id: int, expires_at: Optional[datetime] raise Exception("Refresh token can't be created in the past") refresh_token_str = sha512(uuid4().bytes).hexdigest() - async with self.session_factory() as session: - refresh_token_info_obj = RefreshTokenInfo( - expires_at=expires_at, - token=refresh_token_str, - key_id=key_id - ) - await session.add(refresh_token_info_obj) - await session.flush() - return refresh_token_str + async with db_lock: + async with self.session_factory() as session: + refresh_token_info_obj = RefreshTokenInfo( + expires_at=expires_at, + token=refresh_token_str, + key_id=key_id + ) + await session.add(refresh_token_info_obj) + await session.flush() + return refresh_token_str async def check_refresh_token_exists(self, refresh_token: str) -> bool: From 587ee7cbc6a751afabbd130e2ecc314555ec4424 Mon Sep 17 00:00:00 2001 From: kobim Date: Wed, 26 Jan 2022 16:32:04 +0200 Subject: [PATCH 2/4] enable the db lock only when sqlite is used --- jwthenticator/api.py | 10 +++++++++- jwthenticator/loop_management.py | 6 +++++- jwthenticator/server.py | 2 +- jwthenticator/tokens.py | 21 ++++++++++----------- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/jwthenticator/api.py b/jwthenticator/api.py index acaaa93..2e3fd72 100644 --- a/jwthenticator/api.py +++ b/jwthenticator/api.py @@ -13,6 +13,8 @@ from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME, JWT_AUDIENCE from jwthenticator.utils import get_rsa_key_pair, calculate_key_signature from jwthenticator.exceptions import ExpiredError +from jwthenticator.loop_management import is_using_sqlite, db_lock + class JWThenticatorAPI: """ @@ -49,7 +51,13 @@ async def authenticate(self, request: schemas.AuthRequest) -> schemas.TokenRespo raise ExpiredError("Key is expired.") jwt_token = await self.token_manager.create_access_token(request.identifier) - refresh_token = await self.token_manager.create_refresh_token(key_obj.id) + + if is_using_sqlite(): + # this lock is designed to block multiple request to sqlite causing DB lock + async with db_lock: + refresh_token = await self.token_manager.create_refresh_token(key_obj.id) + else: + refresh_token = await self.token_manager.create_refresh_token(key_obj.id) return schemas.TokenResponse( jwt=jwt_token, diff --git a/jwthenticator/loop_management.py b/jwthenticator/loop_management.py index 779a8e7..41eaeb3 100644 --- a/jwthenticator/loop_management.py +++ b/jwthenticator/loop_management.py @@ -1,4 +1,8 @@ import asyncio +from jwthenticator.consts import DB_URI loop = asyncio.new_event_loop() -db_lock = asyncio.Lock(loop=loop) \ No newline at end of file +db_lock = asyncio.Lock(loop=loop) + +def is_using_sqlite()->bool: + return "sqlite" in DB_URI \ No newline at end of file diff --git a/jwthenticator/server.py b/jwthenticator/server.py index 0ae5d5e..1e319b6 100644 --- a/jwthenticator/server.py +++ b/jwthenticator/server.py @@ -16,7 +16,7 @@ from jwthenticator.utils import get_rsa_key_pair from jwthenticator.server_utils import extract_jwt from jwthenticator.consts import PORT, URL_PREFIX, DISABLE_EXTERNAL_API, DISABLE_INTERNAL_API -from loop_management import loop +from jwthenticator.loop_management import loop class Server: diff --git a/jwthenticator/tokens.py b/jwthenticator/tokens.py index ec773bb..83ac62c 100644 --- a/jwthenticator/tokens.py +++ b/jwthenticator/tokens.py @@ -12,7 +12,7 @@ from jwthenticator.schemas import JWTPayloadData, RefreshTokenData from jwthenticator.exceptions import InvalidTokenError, MissingJWTError from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI -from loop_management import db_lock +from jwthenticator.loop_management import db_lock class TokenManager: """ @@ -88,16 +88,15 @@ async def create_refresh_token(self, key_id: int, expires_at: Optional[datetime] raise Exception("Refresh token can't be created in the past") refresh_token_str = sha512(uuid4().bytes).hexdigest() - async with db_lock: - async with self.session_factory() as session: - refresh_token_info_obj = RefreshTokenInfo( - expires_at=expires_at, - token=refresh_token_str, - key_id=key_id - ) - await session.add(refresh_token_info_obj) - await session.flush() - return refresh_token_str + async with self.session_factory() as session: + refresh_token_info_obj = RefreshTokenInfo( + expires_at=expires_at, + token=refresh_token_str, + key_id=key_id + ) + await session.add(refresh_token_info_obj) + await session.flush() + return refresh_token_str async def check_refresh_token_exists(self, refresh_token: str) -> bool: From 400fb3885b15399c98529c57650fc0a7dbb3cb7a Mon Sep 17 00:00:00 2001 From: kobim Date: Sun, 30 Jan 2022 16:27:14 +0200 Subject: [PATCH 3/4] - rename loop to a more readable name - remove unused code --- jwthenticator/loop_management.py | 6 +++--- jwthenticator/server.py | 4 ++-- jwthenticator/tokens.py | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/jwthenticator/loop_management.py b/jwthenticator/loop_management.py index 41eaeb3..fc42664 100644 --- a/jwthenticator/loop_management.py +++ b/jwthenticator/loop_management.py @@ -1,8 +1,8 @@ import asyncio from jwthenticator.consts import DB_URI -loop = asyncio.new_event_loop() -db_lock = asyncio.Lock(loop=loop) +main_event_loop = asyncio.new_event_loop() +db_lock = asyncio.Lock(loop=main_event_loop) def is_using_sqlite()->bool: - return "sqlite" in DB_URI \ No newline at end of file + return "sqlite://" in DB_URI diff --git a/jwthenticator/server.py b/jwthenticator/server.py index 1e319b6..aaca26f 100644 --- a/jwthenticator/server.py +++ b/jwthenticator/server.py @@ -16,7 +16,7 @@ from jwthenticator.utils import get_rsa_key_pair from jwthenticator.server_utils import extract_jwt from jwthenticator.consts import PORT, URL_PREFIX, DISABLE_EXTERNAL_API, DISABLE_INTERNAL_API -from jwthenticator.loop_management import loop +from jwthenticator.loop_management import main_event_loop class Server: @@ -60,7 +60,7 @@ def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(), ]) if start_server: - web.run_app(self.app, port=port, loop=loop) + web.run_app(self.app, port=port, loop=main_event_loop) async def authenticate(self, request: web.Request) -> web.Response: diff --git a/jwthenticator/tokens.py b/jwthenticator/tokens.py index 83ac62c..124b02f 100644 --- a/jwthenticator/tokens.py +++ b/jwthenticator/tokens.py @@ -12,7 +12,6 @@ from jwthenticator.schemas import JWTPayloadData, RefreshTokenData from jwthenticator.exceptions import InvalidTokenError, MissingJWTError from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI -from jwthenticator.loop_management import db_lock class TokenManager: """ From 80c7348f74f5626f33ed05583a20b12e5d1046ee Mon Sep 17 00:00:00 2001 From: kobim Date: Sun, 30 Jan 2022 19:55:27 +0200 Subject: [PATCH 4/4] fixed integration test --- jwthenticator/tests/test_integration.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/jwthenticator/tests/test_integration.py b/jwthenticator/tests/test_integration.py index 618d76f..7ee9a2b 100644 --- a/jwthenticator/tests/test_integration.py +++ b/jwthenticator/tests/test_integration.py @@ -53,18 +53,17 @@ class TestIntegration(AioHTTPTestCase): """ async def get_application(self) -> web.Application: - self.app = web.Application() - self.app.add_routes([web.get("/", secure_endpoint)]) - return self.app - + server = Server(start_server=False) + server.app.add_routes([web.get("/dummy", secure_endpoint)]) + return server.app async def setUpAsync(self) -> None: - server = Server(start_server=False) - runner = web.AppRunner(server.app) - await runner.setup() - site = web.TCPSite(runner, SERVER_HOST, SERVER_PORT) - self.loop.create_task(site.start()) + self.app = await self.get_application() + self.server = await self.get_server(self.app) + self.server.port = SERVER_PORT + self.client = await self.get_client(self.server) + await self.client.start_server() @unittest_run_loop async def test_client_and_authenticated_server(self) -> None: @@ -82,11 +81,11 @@ async def test_client_and_authenticated_server(self) -> None: await client.refresh() # Test get_with_auth with patch("aiohttp.ClientSession.__aenter__", ContextAwareClient(self.client)): - response = await client.get_with_auth("/") + response = await client.get_with_auth("/dummy") assert response.status == HTTPStatus.OK # Test client with JWT and see that doesn't try to refresh (will fail if tries) client2 = Client(SERVER_URL, uuid_identifier, jwt=client.jwt) with patch("aiohttp.ClientSession.__aenter__", ContextAwareClient(self.client)): - response = await client2.get_with_auth("/") + response = await client2.get_with_auth("/dummy") assert response.status == HTTPStatus.OK