Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sqlalchemy 1.4/2.0 #40

Merged
merged 6 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install poetry
pip install poetry setuptools==75.5.0
poetry config virtualenvs.create false
poetry install
- name: Analysing the code with pylint + mypy
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine poetry poetry-dynamic-versioning
pip install setuptools==75.5.0 wheel twine poetry poetry-dynamic-versioning
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install poetry
pip install poetry setuptools==75.5.0
poetry config virtualenvs.create false
poetry install
- name: Run Pytest
Expand Down
2 changes: 1 addition & 1 deletion jwthenticator/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def check_key_exists(self, key_hash: str) -> bool:
Check if a key exists in DB.
"""
async with self.async_session_factory() as session:
query = select(func.count(KeyInfo.id)).where(KeyInfo.key_hash == key_hash)
query = select(func.count(KeyInfo.id)).where(KeyInfo.key_hash == key_hash) # pylint: disable=not-callable
return (await session.scalar(query)) == 1


Expand Down
5 changes: 2 additions & 3 deletions jwthenticator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from datetime import datetime

from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy_utils.types.uuid import UUIDType

Expand All @@ -24,7 +23,7 @@ class KeyInfo(Base):
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
key_hash = Column(String(256), unique=True)
identifier = Column(UUIDType(binary=False), nullable=False)
identifier = Column(UUIDType(binary=False), nullable=False) # type: ignore


class RefreshTokenInfo(Base):
Expand Down
9 changes: 7 additions & 2 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from hashlib import sha512
from uuid import UUID, uuid4

import logging
import jwt
from sqlalchemy import select, func

Expand Down Expand Up @@ -70,7 +71,11 @@ async def load_access_token(self, token_string: str) -> JWTPayloadData:
"""
if not token_string:
raise MissingJWTError
token_dict: dict = jwt.decode(token_string, self.public_key, algorithms=[self.algorithm], options={"verify_exp": False})
try:
token_dict: dict = jwt.decode(token_string, self.public_key, algorithms=[self.algorithm], options={"verify_exp": False})
except Exception:
logging.error("Exception occured during token decode. Token %s", token_string)
raise
token_data = self.jwt_payload_data_schema.load(token_dict)
return token_data

Expand Down Expand Up @@ -105,7 +110,7 @@ async def check_refresh_token_exists(self, refresh_token: str) -> bool:
Check if a refresh token exists in DB.
"""
async with self.async_session_factory() as session:
query = select(func.count(RefreshTokenInfo.id)).where(RefreshTokenInfo.token == refresh_token)
query = select(func.count(RefreshTokenInfo.id)).where(RefreshTokenInfo.token == refresh_token) # pylint: disable=not-callable
return (await session.scalar(query)) == 1


Expand Down
14 changes: 7 additions & 7 deletions jwthenticator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from jwt.utils import base64url_encode
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA1
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import DeclarativeMeta
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
from sqlalchemy.orm import sessionmaker as async_sessionmaker # type:ignore
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.pool import NullPool

Expand Down Expand Up @@ -102,13 +104,11 @@ def fix_url_path(url: str) -> str:
"""
return url if url.endswith("/") else url + "/"


async def create_base(engine: AsyncEngine, base: DeclarativeMeta) -> None:
async def create_base(engine: AsyncEngine, base: Any) -> None:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)


def create_async_session_factory(uri: str, base: Optional[DeclarativeMeta] = None, **engine_kwargs: Dict[Any, Any]) -> sessionmaker:
def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_kwargs: Dict[Any, Any]) -> Any:
"""
:param uri: Database uniform resource identifier
:param base: Declarative SQLAlchemy class to base off table initialization
Expand All @@ -118,4 +118,4 @@ def create_async_session_factory(uri: str, base: Optional[DeclarativeMeta] = Non
engine = create_async_engine(uri, **engine_kwargs, poolclass=NullPool)
if base is not None:
asyncio.get_event_loop().run_until_complete(create_base(engine, base))
return sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
return async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
Loading
Loading