Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix utcnow() issue and bump python to 3.10 #45

Merged
merged 8 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions jwthenticator/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import absolute_import

from typing import Optional, Any, Dict
from datetime import datetime
from urllib.parse import urljoin
from http import HTTPStatus
from uuid import UUID
Expand All @@ -10,7 +9,7 @@
from aiohttp import ClientSession, ClientResponse

from jwthenticator import schemas, exceptions
from jwthenticator.utils import verify_url, fix_url_path
from jwthenticator.utils import verify_url, fix_url_path, utcnow
from jwthenticator.consts import JWT_ALGORITHM

JWT_DECODE_OPTIONS = {"verify_signature": False, "verify_exp": False}
Expand Down Expand Up @@ -82,7 +81,7 @@ def jwt(self, value: str) -> None:
def is_jwt_expired(self) -> bool:
if self._jwt_exp is None:
return True
return datetime.utcnow().timestamp() >= self._jwt_exp
return utcnow().timestamp() >= self._jwt_exp

@property
def refresh_token(self) -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions jwthenticator/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.utils import create_async_session_factory, utcnow
from jwthenticator.schemas import KeyData
from jwthenticator.models import Base, KeyInfo
from jwthenticator.exceptions import InvalidKeyError
Expand All @@ -32,7 +32,7 @@ async def create_key(self, key: str, identifier: UUID, expires_at: Optional[date
:return: Returns True if successfull, raises exception otherwise.
"""
if expires_at is None:
expires_at = datetime.utcnow() + timedelta(seconds=KEY_EXPIRY)
expires_at = utcnow() + timedelta(seconds=KEY_EXPIRY)
key_hash = sha512(key.encode()).hexdigest()

# If key already exists, update expiry date.
Expand Down
59 changes: 42 additions & 17 deletions jwthenticator/models.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,63 @@
# pylint: disable=too-few-public-methods
from __future__ import absolute_import

from typing import Any
from datetime import datetime

from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, declarative_base
from datetime import datetime, timezone
from sqlalchemy import create_engine, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, DeclarativeBase, mapped_column
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy_utils.types.uuid import UUIDType

from jwthenticator.consts import DB_URI
from jwthenticator.utils import utcnow

engine = create_engine(DB_URI)
SessionMaker = sessionmaker(bind=engine)

Base = declarative_base() # type: Any # pylint: disable=invalid-name

class DateTimeMixin:
_created = mapped_column("created", DateTime, default=utcnow().replace(tzinfo=None))
_expires_at = mapped_column("expires_at", DateTime)

@property
def created(self)-> datetime:
if self._created and self._created.tzinfo is None:
return self._created.replace(tzinfo=timezone.utc)
return self._created

@created.setter
def created(self, created: datetime)-> None:
if created and created.tzinfo:
self._created = created.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._created = created

@property
def expires_at(self)-> datetime:
if self._expires_at and self._expires_at.tzinfo is None:
return self._expires_at.replace(tzinfo=timezone.utc)
return self._expires_at

@expires_at.setter
def expires_at(self, expires_at: datetime)-> None:
if expires_at and expires_at.tzinfo:
self._expires_at = expires_at.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._expires_at = expires_at

class Base(DeclarativeBase, DateTimeMixin):
pass

class KeyInfo(Base):
__tablename__ = "keys"
id = Column(Integer, primary_key=True, autoincrement=True)
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
key_hash = Column(String(256), unique=True)
identifier = Column(UUIDType(binary=False), nullable=False) # type: ignore
id = mapped_column(Integer, primary_key=True, autoincrement=True)
key_hash = mapped_column(String(256), unique=True)
identifier = mapped_column(UUIDType(binary=False), nullable=False)


class RefreshTokenInfo(Base):
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
token = Column(String(512))
key_id = Column(Integer, ForeignKey("keys.id"))
id = mapped_column(Integer, primary_key=True, autoincrement=True)
token = mapped_column(String(512))
key_id = mapped_column(Integer, ForeignKey("keys.id"))


# Create database + tables
Expand Down
8 changes: 4 additions & 4 deletions jwthenticator/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from marshmallow_dataclass import dataclass, NewType, add_schema

from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY
from jwthenticator.utils import utcnow

# Define the UUID type that uses Marshmallow's UUID + Python's UUID
UUID = NewType("UUID", uuid.UUID, field=fields.UUID)
Expand Down Expand Up @@ -40,7 +41,7 @@ class KeyData:
key: Optional[str] = field(default=None, repr=False, metadata={"load_only": True})

async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()
return self.expires_at > utcnow()


@dataclass
Expand All @@ -53,7 +54,7 @@ class RefreshTokenData:
key_id: int

async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()
return self.expires_at > utcnow()


# Skipping None values on dump since 'aud' is optional and can't be None/empty
Expand All @@ -68,8 +69,7 @@ class JWTPayloadData:
aud: Optional[List[str]] = None # JWT Audience

async def is_valid(self) -> bool:
return self.exp > datetime.utcnow().timestamp()

return self.exp > utcnow().timestamp()

# Request dataclasses
@dataclass
Expand Down
2 changes: 1 addition & 1 deletion jwthenticator/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def hash_key(key: str) -> str:


async def future_datetime(seconds: int = 0) -> datetime:
return datetime.utcnow() + timedelta(seconds=seconds)
return utils.utcnow() + timedelta(seconds=seconds)


def backup_environment(func): # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jwt
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.utils import create_async_session_factory, utcnow
from jwthenticator.models import Base, RefreshTokenInfo
from jwthenticator.schemas import JWTPayloadData, RefreshTokenData
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
Expand Down Expand Up @@ -51,7 +51,7 @@ async def create_access_token(self, identifier: UUID) -> str:
"""
if self.private_key is None:
raise RuntimeError("Private key required for JWT token creation")
utc_now = datetime.utcnow()
utc_now = utcnow()
payload = JWTPayloadData(
token_id=uuid4(),
identifier=identifier,
Expand Down Expand Up @@ -87,8 +87,8 @@ async def create_refresh_token(self, key_id: int, expires_at: Optional[datetime]
:return: The refresh token created.
"""
if expires_at is None:
expires_at = expires_at = datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_EXPIRY)
if expires_at <= datetime.utcnow():
expires_at = expires_at = utcnow() + timedelta(seconds=REFRESH_TOKEN_EXPIRY)
if expires_at <= utcnow():
raise RuntimeError("Refresh token can't be created in the past")

refresh_token_str = sha512(uuid4().bytes).hexdigest()
Expand Down
16 changes: 9 additions & 7 deletions jwthenticator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import asyncio
from os.path import isfile
from typing import Any, Dict, Tuple, Optional
from typing import Any, Dict, Tuple, Optional, Type
from urllib.parse import urlparse

from datetime import datetime, timezone
from jwt.utils import base64url_encode
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA1
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
from sqlalchemy.orm import sessionmaker as async_sessionmaker # type:ignore
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.pool import NullPool

from jwthenticator.consts import RSA_KEY_STRENGTH, RSA_PUBLIC_KEY, RSA_PRIVATE_KEY, RSA_PUBLIC_KEY_PATH, RSA_PRIVATE_KEY_PATH
Expand Down Expand Up @@ -104,11 +103,11 @@ def fix_url_path(url: str) -> str:
"""
return url if url.endswith("/") else url + "/"

async def create_base(engine: AsyncEngine, base: Any) -> None:
async def create_base(engine: AsyncEngine, base: Type[DeclarativeBase]) -> None:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)

def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_kwargs: Dict[Any, Any]) -> Any:
def create_async_session_factory(uri: str, base: Optional[Type[DeclarativeBase]] = None, **engine_kwargs: Dict[Any, Any]) -> async_sessionmaker[AsyncSession]:
"""
:param uri: Database uniform resource identifier
:param base: Declarative SQLAlchemy class to base off table initialization
Expand All @@ -119,3 +118,6 @@ def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_
if base is not None:
asyncio.get_event_loop().run_until_complete(create_base(engine, base))
return async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

def utcnow() -> datetime:
return datetime.now(timezone.utc)
65 changes: 30 additions & 35 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ exclude = ["jwthenticator/tests"]

[tool.poetry.dependencies]
python = "^3.9"
sqlalchemy = ">=1.3.24, < 2.1.0"
sqlalchemy = ">=2.0, < 2.1.0"
sqlalchemy-utils = ">=0.33.0, < 1.0.0"
pg8000 = "1.16.6" # Constant due to - https://github.com/tlocke/pg8000/issues/53
aiohttp = "^3.9.5"
pyjwt = ">= 1.7, < 3.0.0"
cryptography = ">=41.0.3, <= 42.0.8" # Required for pyjwt
cryptography = ">=41.0.3, <= 43.0.4" # Required for pyjwt
marshmallow = "^3.9"
marshmallow-dataclass = "^8.3"
pycryptodomex = "^3.9"
Expand Down
Loading