Skip to content

Commit

Permalink
Merge pull request #45 from claroty/bugfix/rnew-26573
Browse files Browse the repository at this point in the history
Fix utcnow() issue and bump python to 3.10
  • Loading branch information
mixmind authored Dec 9, 2024
2 parents a45fed0 + e08da8d commit 5a7b81f
Show file tree
Hide file tree
Showing 14 changed files with 493 additions and 440 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v1
with:
python-version: 3.9
python-version: 3.10.15
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v1
with:
python-version: 3.9
python-version: 3.10.15
- name: Install PostgreSQL
uses: harmon758/postgresql-action@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.9-slim
FROM python:3.10-slim

# Install python requirements
WORKDIR /opt
Expand Down
5 changes: 2 additions & 3 deletions jwthenticator/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import absolute_import

from typing import Optional, Any, Dict
from datetime import datetime
from urllib.parse import urljoin
from http import HTTPStatus
from uuid import UUID
Expand All @@ -10,7 +9,7 @@
from aiohttp import ClientSession, ClientResponse

from jwthenticator import schemas, exceptions
from jwthenticator.utils import verify_url, fix_url_path
from jwthenticator.utils import verify_url, fix_url_path, utcnow
from jwthenticator.consts import JWT_ALGORITHM

JWT_DECODE_OPTIONS = {"verify_signature": False, "verify_exp": False}
Expand Down Expand Up @@ -82,7 +81,7 @@ def jwt(self, value: str) -> None:
def is_jwt_expired(self) -> bool:
if self._jwt_exp is None:
return True
return datetime.utcnow().timestamp() >= self._jwt_exp
return utcnow().timestamp() >= self._jwt_exp

@property
def refresh_token(self) -> Optional[str]:
Expand Down
4 changes: 2 additions & 2 deletions jwthenticator/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.utils import create_async_session_factory, utcnow
from jwthenticator.schemas import KeyData
from jwthenticator.models import Base, KeyInfo
from jwthenticator.exceptions import InvalidKeyError
Expand All @@ -32,7 +32,7 @@ async def create_key(self, key: str, identifier: UUID, expires_at: Optional[date
:return: Returns True if successfull, raises exception otherwise.
"""
if expires_at is None:
expires_at = datetime.utcnow() + timedelta(seconds=KEY_EXPIRY)
expires_at = utcnow() + timedelta(seconds=KEY_EXPIRY)
key_hash = sha512(key.encode()).hexdigest()

# If key already exists, update expiry date.
Expand Down
2 changes: 1 addition & 1 deletion jwthenticator/loop_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from jwthenticator.consts import DB_URI

main_event_loop = asyncio.new_event_loop()
db_lock = asyncio.Lock(loop=main_event_loop)
db_lock = asyncio.Lock()

def is_using_sqlite()->bool:
return "sqlite://" in DB_URI
59 changes: 42 additions & 17 deletions jwthenticator/models.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,63 @@
# pylint: disable=too-few-public-methods
from __future__ import absolute_import

from typing import Any
from datetime import datetime

from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, declarative_base
from datetime import datetime, timezone
from sqlalchemy import create_engine, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, DeclarativeBase, mapped_column
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy_utils.types.uuid import UUIDType

from jwthenticator.consts import DB_URI
from jwthenticator.utils import utcnow

engine = create_engine(DB_URI)
SessionMaker = sessionmaker(bind=engine)

Base = declarative_base() # type: Any # pylint: disable=invalid-name

class DateTimeMixin:
_created = mapped_column("created", DateTime, default=utcnow().replace(tzinfo=None))
_expires_at = mapped_column("expires_at", DateTime)

@property
def created(self)-> datetime:
if self._created and self._created.tzinfo is None:
return self._created.replace(tzinfo=timezone.utc)
return self._created

@created.setter
def created(self, created: datetime)-> None:
if created and created.tzinfo:
self._created = created.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._created = created

@property
def expires_at(self)-> datetime:
if self._expires_at and self._expires_at.tzinfo is None:
return self._expires_at.replace(tzinfo=timezone.utc)
return self._expires_at

@expires_at.setter
def expires_at(self, expires_at: datetime)-> None:
if expires_at and expires_at.tzinfo:
self._expires_at = expires_at.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._expires_at = expires_at

class Base(DeclarativeBase, DateTimeMixin):
pass

class KeyInfo(Base):
__tablename__ = "keys"
id = Column(Integer, primary_key=True, autoincrement=True)
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
key_hash = Column(String(256), unique=True)
identifier = Column(UUIDType(binary=False), nullable=False) # type: ignore
id = mapped_column(Integer, primary_key=True, autoincrement=True)
key_hash = mapped_column(String(256), unique=True)
identifier = mapped_column(UUIDType(binary=False), nullable=False)


class RefreshTokenInfo(Base):
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
token = Column(String(512))
key_id = Column(Integer, ForeignKey("keys.id"))
id = mapped_column(Integer, primary_key=True, autoincrement=True)
token = mapped_column(String(512))
key_id = mapped_column(Integer, ForeignKey("keys.id"))


# Create database + tables
Expand Down
8 changes: 4 additions & 4 deletions jwthenticator/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from marshmallow_dataclass import dataclass, NewType, add_schema

from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY
from jwthenticator.utils import utcnow

# Define the UUID type that uses Marshmallow's UUID + Python's UUID
UUID = NewType("UUID", uuid.UUID, field=fields.UUID)
Expand Down Expand Up @@ -40,7 +41,7 @@ class KeyData:
key: Optional[str] = field(default=None, repr=False, metadata={"load_only": True})

async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()
return self.expires_at > utcnow()


@dataclass
Expand All @@ -53,7 +54,7 @@ class RefreshTokenData:
key_id: int

async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()
return self.expires_at > utcnow()


# Skipping None values on dump since 'aud' is optional and can't be None/empty
Expand All @@ -68,8 +69,7 @@ class JWTPayloadData:
aud: Optional[List[str]] = None # JWT Audience

async def is_valid(self) -> bool:
return self.exp > datetime.utcnow().timestamp()

return self.exp > utcnow().timestamp()

# Request dataclasses
@dataclass
Expand Down
29 changes: 26 additions & 3 deletions jwthenticator/tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
from __future__ import absolute_import

import time
import os
from datetime import timedelta
from collections.abc import Callable, Generator
from uuid import uuid4

import pytest

from jwthenticator.tokens import TokenManager
from jwthenticator.keys import KeyManager
from jwthenticator.utils import create_rsa_key_pair
from jwthenticator.utils import create_rsa_key_pair, utcnow
from jwthenticator.schemas import RefreshTokenData
from jwthenticator.tests.utils import random_key, hash_key


@pytest.fixture
def set_timezone()-> Generator[Callable[[str], None], None, None]:
original_tz = os.environ.get('TZ')
def change_timezone(time_zone: str) -> None:
os.environ['TZ'] = time_zone
time.tzset() # Update the timezone for the process
yield change_timezone
# Restore the original timezone
if original_tz is not None:
os.environ['TZ'] = original_tz
else:
del os.environ['TZ']
time.tzset()

class TestTokens:

def setup_class(self) -> None:
Expand Down Expand Up @@ -42,12 +60,17 @@ async def _create_refresh_token(self) -> str:

# Create access token tests
@pytest.mark.asyncio
async def test_create_access_token(self) -> None:
async def test_create_access_token(self, set_timezone: Callable[[str], None]) -> None: # pylint: disable=redefined-outer-name
uuid = uuid4()
time_now = utcnow()
time_now_timestamp = int(time_now.timestamp())
time_plus_some_time = time_now + timedelta(seconds=10)
time_plus_some_time_timestamp = int(time_plus_some_time.timestamp())
set_timezone("America/Los_Angeles")
token = await self.token_manager.create_access_token(uuid)
token_data = await self.token_manager.load_access_token(token)
assert token_data.identifier == uuid

assert time_now_timestamp <= token_data.iat <= time_plus_some_time_timestamp

# Create refresh token tests
@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion jwthenticator/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def hash_key(key: str) -> str:


async def future_datetime(seconds: int = 0) -> datetime:
return datetime.utcnow() + timedelta(seconds=seconds)
return utils.utcnow() + timedelta(seconds=seconds)


def backup_environment(func): # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jwt
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.utils import create_async_session_factory, utcnow
from jwthenticator.models import Base, RefreshTokenInfo
from jwthenticator.schemas import JWTPayloadData, RefreshTokenData
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
Expand Down Expand Up @@ -51,7 +51,7 @@ async def create_access_token(self, identifier: UUID) -> str:
"""
if self.private_key is None:
raise RuntimeError("Private key required for JWT token creation")
utc_now = datetime.utcnow()
utc_now = utcnow()
payload = JWTPayloadData(
token_id=uuid4(),
identifier=identifier,
Expand Down Expand Up @@ -87,8 +87,8 @@ async def create_refresh_token(self, key_id: int, expires_at: Optional[datetime]
:return: The refresh token created.
"""
if expires_at is None:
expires_at = expires_at = datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_EXPIRY)
if expires_at <= datetime.utcnow():
expires_at = expires_at = utcnow() + timedelta(seconds=REFRESH_TOKEN_EXPIRY)
if expires_at <= utcnow():
raise RuntimeError("Refresh token can't be created in the past")

refresh_token_str = sha512(uuid4().bytes).hexdigest()
Expand Down
16 changes: 9 additions & 7 deletions jwthenticator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import asyncio
from os.path import isfile
from typing import Any, Dict, Tuple, Optional
from typing import Any, Dict, Tuple, Optional, Type
from urllib.parse import urlparse

from datetime import datetime, timezone
from jwt.utils import base64url_encode
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA1
try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
from sqlalchemy.orm import sessionmaker as async_sessionmaker # type:ignore
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.pool import NullPool

from jwthenticator.consts import RSA_KEY_STRENGTH, RSA_PUBLIC_KEY, RSA_PRIVATE_KEY, RSA_PUBLIC_KEY_PATH, RSA_PRIVATE_KEY_PATH
Expand Down Expand Up @@ -104,11 +103,11 @@ def fix_url_path(url: str) -> str:
"""
return url if url.endswith("/") else url + "/"

async def create_base(engine: AsyncEngine, base: Any) -> None:
async def create_base(engine: AsyncEngine, base: Type[DeclarativeBase]) -> None:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)

def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_kwargs: Dict[Any, Any]) -> Any:
def create_async_session_factory(uri: str, base: Optional[Type[DeclarativeBase]] = None, **engine_kwargs: Dict[Any, Any]) -> async_sessionmaker[AsyncSession]:
"""
:param uri: Database uniform resource identifier
:param base: Declarative SQLAlchemy class to base off table initialization
Expand All @@ -119,3 +118,6 @@ def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_
if base is not None:
asyncio.get_event_loop().run_until_complete(create_base(engine, base))
return async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

def utcnow() -> datetime:
return datetime.now(timezone.utc)
Loading

0 comments on commit 5a7b81f

Please sign in to comment.