Skip to content

Commit

Permalink
Merge pull request #23 from kobimic/sqlite_db_lock
Browse files Browse the repository at this point in the history
- add lock on writing, since this is mostly used on sqlite
  • Loading branch information
kobimic authored Jan 31, 2022
2 parents 9d77054 + 80c7348 commit 027ed6f
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 14 deletions.
10 changes: 9 additions & 1 deletion jwthenticator/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME, JWT_AUDIENCE
from jwthenticator.utils import get_rsa_key_pair, calculate_key_signature
from jwthenticator.exceptions import ExpiredError
from jwthenticator.loop_management import is_using_sqlite, db_lock


class JWThenticatorAPI:
"""
Expand Down Expand Up @@ -49,7 +51,13 @@ async def authenticate(self, request: schemas.AuthRequest) -> schemas.TokenRespo
raise ExpiredError("Key is expired.")

jwt_token = await self.token_manager.create_access_token(request.identifier)
refresh_token = await self.token_manager.create_refresh_token(key_obj.id)

if is_using_sqlite():
# this lock is designed to block multiple request to sqlite causing DB lock
async with db_lock:
refresh_token = await self.token_manager.create_refresh_token(key_obj.id)
else:
refresh_token = await self.token_manager.create_refresh_token(key_obj.id)

return schemas.TokenResponse(
jwt=jwt_token,
Expand Down
8 changes: 8 additions & 0 deletions jwthenticator/loop_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import asyncio
from jwthenticator.consts import DB_URI

main_event_loop = asyncio.new_event_loop()
db_lock = asyncio.Lock(loop=main_event_loop)

def is_using_sqlite()->bool:
return "sqlite://" in DB_URI
3 changes: 2 additions & 1 deletion jwthenticator/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from jwthenticator.utils import get_rsa_key_pair
from jwthenticator.server_utils import extract_jwt
from jwthenticator.consts import PORT, URL_PREFIX, DISABLE_EXTERNAL_API, DISABLE_INTERNAL_API
from jwthenticator.loop_management import main_event_loop


class Server:
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(),
])

if start_server:
web.run_app(self.app, port=port)
web.run_app(self.app, port=port, loop=main_event_loop)


async def authenticate(self, request: web.Request) -> web.Response:
Expand Down
21 changes: 10 additions & 11 deletions jwthenticator/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,17 @@ class TestIntegration(AioHTTPTestCase):
"""

async def get_application(self) -> web.Application:
self.app = web.Application()
self.app.add_routes([web.get("/", secure_endpoint)])
return self.app

server = Server(start_server=False)
server.app.add_routes([web.get("/dummy", secure_endpoint)])
return server.app

async def setUpAsync(self) -> None:
server = Server(start_server=False)
runner = web.AppRunner(server.app)
await runner.setup()
site = web.TCPSite(runner, SERVER_HOST, SERVER_PORT)
self.loop.create_task(site.start())
self.app = await self.get_application()
self.server = await self.get_server(self.app)
self.server.port = SERVER_PORT
self.client = await self.get_client(self.server)

await self.client.start_server()

@unittest_run_loop
async def test_client_and_authenticated_server(self) -> None:
Expand All @@ -82,11 +81,11 @@ async def test_client_and_authenticated_server(self) -> None:
await client.refresh()
# Test get_with_auth
with patch("aiohttp.ClientSession.__aenter__", ContextAwareClient(self.client)):
response = await client.get_with_auth("/")
response = await client.get_with_auth("/dummy")
assert response.status == HTTPStatus.OK

# Test client with JWT and see that doesn't try to refresh (will fail if tries)
client2 = Client(SERVER_URL, uuid_identifier, jwt=client.jwt)
with patch("aiohttp.ClientSession.__aenter__", ContextAwareClient(self.client)):
response = await client2.get_with_auth("/")
response = await client2.get_with_auth("/dummy")
assert response.status == HTTPStatus.OK
1 change: 0 additions & 1 deletion jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI


class TokenManager:
"""
Class responsible for the creation and loading of tokens
Expand Down

0 comments on commit 027ed6f

Please sign in to comment.