Skip to content

Commit

Permalink
Fixed CR
Browse files Browse the repository at this point in the history
  • Loading branch information
omerabuddi committed Aug 17, 2023
1 parent f51dc1a commit a2e28fc
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 26 deletions.
11 changes: 5 additions & 6 deletions jwthenticator/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hashlib import sha512
from uuid import UUID

from sqlalchemy import select
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.schemas import KeyData
Expand Down Expand Up @@ -55,9 +55,8 @@ async def check_key_exists(self, key_hash: str) -> bool:
Check if a key exists in DB.
"""
async with self.async_session_factory() as session:
query = select(KeyInfo.id).where(KeyInfo.key_hash == key_hash)
result = await session.execute(query)
return len(result.all()) == 1
query = select(func.count(KeyInfo.id)).where(KeyInfo.key_hash == key_hash)
return (await session.scalar(query)) == 1


async def update_key_expiry(self, key_hash: str, expires_at: datetime) -> bool:
Expand All @@ -68,7 +67,7 @@ async def update_key_expiry(self, key_hash: str, expires_at: datetime) -> bool:
raise InvalidKeyError("Invalid key")
async with self.async_session_factory() as session:
query = select(KeyInfo).where(KeyInfo.key_hash == key_hash)
key_info_obj = (await session.execute(query)).scalars().first()
key_info_obj = await session.scalar(query)
key_info_obj.expires_at = expires_at
return True

Expand All @@ -81,6 +80,6 @@ async def get_key(self, key_hash: str) -> KeyData:
raise InvalidKeyError("Invalid key")
async with self.async_session_factory() as session:
query = select(KeyInfo).where(KeyInfo.key_hash == key_hash)
key_info_obj = (await session.execute(query)).scalars().first()
key_info_obj = await session.scalar(query)
key_data_obj = self.key_schema.load((self.key_schema.dump(key_info_obj)))
return key_data_obj
7 changes: 3 additions & 4 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from uuid import UUID, uuid4

import jwt
from sqlalchemy import select
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.models import Base, RefreshTokenInfo
Expand Down Expand Up @@ -105,9 +105,8 @@ async def check_refresh_token_exists(self, refresh_token: str) -> bool:
Check if a refresh token exists in DB.
"""
async with self.async_session_factory() as session:
query = select(RefreshTokenInfo).where(RefreshTokenInfo.token == refresh_token)
result = (await session.execute(query))
return len(result.all()) == 1
query = select(func.count(RefreshTokenInfo.id)).where(RefreshTokenInfo.token == refresh_token)
return (await session.scalar(query)) == 1


async def load_refresh_token(self, refresh_token: str) -> RefreshTokenData:
Expand Down
15 changes: 1 addition & 14 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ marshmallow-dataclass = "^8.3"
pycryptodomex = "^3.9"
environs = "^9.3.1"
asyncpg = "^0.28.0"
asyncio = "^3.4.3"
nest-asyncio = "^1.5.7"


[tool.poetry.dev-dependencies]
Expand All @@ -42,6 +40,9 @@ aiofiles = "^0.7.0"
typing-inspect = "0.7.1" # https://github.com/lovasoa/marshmallow_dataclass/issues/206


[tool.poetry.group.dev.dependencies]
nest-asyncio = "^1.5.7"

[tool.pylint.message_control]
disable = [
"missing-class-docstring",
Expand Down

0 comments on commit a2e28fc

Please sign in to comment.