Skip to content

Commit

Permalink
Merge pull request #40 from claroty/fix/reuvens/sqlalchemy-14-fixes
Browse files Browse the repository at this point in the history
Fix sqalchemy issues and bump packages
  • Loading branch information
mixmind authored Nov 25, 2024
2 parents 245c31c + 200bf8b commit a45fed0
Show file tree
Hide file tree
Showing 9 changed files with 875 additions and 669 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install poetry
pip install poetry setuptools==75.5.0
poetry config virtualenvs.create false
poetry install
- name: Analysing the code with pylint + mypy
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine poetry poetry-dynamic-versioning
pip install setuptools==75.5.0 wheel twine poetry poetry-dynamic-versioning
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install poetry
pip install poetry setuptools==75.5.0
poetry config virtualenvs.create false
poetry install
- name: Run Pytest
Expand Down
2 changes: 1 addition & 1 deletion jwthenticator/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def check_key_exists(self, key_hash: str) -> bool:
Check if a key exists in DB.
"""
async with self.async_session_factory() as session:
query = select(func.count(KeyInfo.id)).where(KeyInfo.key_hash == key_hash)
query = select(func.count(KeyInfo.id)).where(KeyInfo.key_hash == key_hash) # pylint: disable=not-callable
return (await session.scalar(query)) == 1


Expand Down
5 changes: 2 additions & 3 deletions jwthenticator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from datetime import datetime

from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy_utils.types.uuid import UUIDType

Expand All @@ -24,7 +23,7 @@ class KeyInfo(Base):
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
key_hash = Column(String(256), unique=True)
identifier = Column(UUIDType(binary=False), nullable=False)
identifier = Column(UUIDType(binary=False), nullable=False) # type: ignore


class RefreshTokenInfo(Base):
Expand Down
9 changes: 7 additions & 2 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from hashlib import sha512
from uuid import UUID, uuid4

import logging
import jwt
from sqlalchemy import select, func

Expand Down Expand Up @@ -70,7 +71,11 @@ async def load_access_token(self, token_string: str) -> JWTPayloadData:
"""
if not token_string:
raise MissingJWTError
token_dict: dict = jwt.decode(token_string, self.public_key, algorithms=[self.algorithm], options={"verify_exp": False})
try:
token_dict: dict = jwt.decode(token_string, self.public_key, algorithms=[self.algorithm], options={"verify_exp": False})
except Exception:
logging.error("Exception occured during token decode. Token %s", token_string)
raise
token_data = self.jwt_payload_data_schema.load(token_dict)
return token_data

Expand Down Expand Up @@ -105,7 +110,7 @@ async def check_refresh_token_exists(self, refresh_token: str) -> bool:
Check if a refresh token exists in DB.
"""
async with self.async_session_factory() as session:
query = select(func.count(RefreshTokenInfo.id)).where(RefreshTokenInfo.token == refresh_token)
query = select(func.count(RefreshTokenInfo.id)).where(RefreshTokenInfo.token == refresh_token) # pylint: disable=not-callable
return (await session.scalar(query)) == 1


Expand Down
14 changes: 7 additions & 7 deletions jwthenticator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from jwt.utils import base64url_encode
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA1
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import DeclarativeMeta
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
from sqlalchemy.orm import sessionmaker as async_sessionmaker # type:ignore
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.pool import NullPool

Expand Down Expand Up @@ -102,13 +104,11 @@ def fix_url_path(url: str) -> str:
"""
return url if url.endswith("/") else url + "/"


async def create_base(engine: AsyncEngine, base: DeclarativeMeta) -> None:
async def create_base(engine: AsyncEngine, base: Any) -> None:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)


def create_async_session_factory(uri: str, base: Optional[DeclarativeMeta] = None, **engine_kwargs: Dict[Any, Any]) -> sessionmaker:
def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_kwargs: Dict[Any, Any]) -> Any:
"""
:param uri: Database uniform resource identifier
:param base: Declarative SQLAlchemy class to base off table initialization
Expand All @@ -118,4 +118,4 @@ def create_async_session_factory(uri: str, base: Optional[DeclarativeMeta] = Non
engine = create_async_engine(uri, **engine_kwargs, poolclass=NullPool)
if base is not None:
asyncio.get_event_loop().run_until_complete(create_base(engine, base))
return sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
return async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
Loading

0 comments on commit a45fed0

Please sign in to comment.