Skip to content

Commit

Permalink
Fix utcnow() issue
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenstr committed Nov 26, 2024
1 parent a45fed0 commit de80a19
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 21 deletions.
5 changes: 2 additions & 3 deletions jwthenticator/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import absolute_import

from typing import Optional, Any, Dict
from datetime import datetime
from urllib.parse import urljoin
from http import HTTPStatus
from uuid import UUID
Expand All @@ -10,7 +9,7 @@
from aiohttp import ClientSession, ClientResponse

from jwthenticator import schemas, exceptions
from jwthenticator.utils import verify_url, fix_url_path
from jwthenticator.utils import verify_url, fix_url_path, utcnow
from jwthenticator.consts import JWT_ALGORITHM

JWT_DECODE_OPTIONS = {"verify_signature": False, "verify_exp": False}
Expand Down Expand Up @@ -82,7 +81,7 @@ def jwt(self, value: str) -> None:
def is_jwt_expired(self) -> bool:
if self._jwt_exp is None:
return True
return datetime.utcnow().timestamp() >= self._jwt_exp
return utcnow().timestamp() >= self._jwt_exp

@property
def refresh_token(self) -> Optional[str]:
Expand Down
6 changes: 3 additions & 3 deletions jwthenticator/keys.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import absolute_import

from typing import Optional
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from hashlib import sha512
from uuid import UUID

from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.utils import create_async_session_factory, utcnow
from jwthenticator.schemas import KeyData
from jwthenticator.models import Base, KeyInfo
from jwthenticator.exceptions import InvalidKeyError
Expand All @@ -32,7 +32,7 @@ async def create_key(self, key: str, identifier: UUID, expires_at: Optional[date
:return: Returns True if successfull, raises exception otherwise.
"""
if expires_at is None:
expires_at = datetime.utcnow() + timedelta(seconds=KEY_EXPIRY)
expires_at = utcnow() + timedelta(seconds=KEY_EXPIRY)
key_hash = sha512(key.encode()).hexdigest()

# If key already exists, update expiry date.
Expand Down
49 changes: 44 additions & 5 deletions jwthenticator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from __future__ import absolute_import

from typing import Any
from datetime import datetime

from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy_utils.types.uuid import UUIDType

from sqlalchemy.ext.hybrid import hybrid_property
from datetime import datetime, timezone
from jwthenticator.consts import DB_URI
from jwthenticator.utils import utcnow

engine = create_engine(DB_URI)
SessionMaker = sessionmaker(bind=engine)
Expand All @@ -20,20 +21,58 @@
class KeyInfo(Base):
__tablename__ = "keys"
id = Column(Integer, primary_key=True, autoincrement=True)
created = Column(DateTime, default=datetime.utcnow())
_created = Column("created", DateTime, default=utcnow())
expires_at = Column(DateTime)
key_hash = Column(String(256), unique=True)
identifier = Column(UUIDType(binary=False), nullable=False) # type: ignore

@hybrid_property
def created(self):
if self._created and self._created.tzinfo is None:
return self._created.replace(tzinfo=timezone.utc)
return self._created

@created.setter
def created(self, created: datetime):
if created and created.tzinfo:
self._created = created.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._created = created


class RefreshTokenInfo(Base):
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
created = Column(DateTime, default=datetime.utcnow())
expires_at = Column(DateTime)
_created = Column("created", DateTime, default=utcnow())
_expires_at = Column("expires_at", DateTime)
token = Column(String(512))
key_id = Column(Integer, ForeignKey("keys.id"))

@hybrid_property
def created(self):
if self._created and self._created.tzinfo is None:
return self._created.replace(tzinfo=timezone.utc)
return self._created

@created.setter
def created(self, created: datetime):
if created and created.tzinfo:
self._created = created.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._created = created

@hybrid_property
def expires_at(self):
if self._expires_at and self._expires_at.tzinfo is None:
return self._expires_at.replace(tzinfo=timezone.utc)
return self._expires_at

@expires_at.setter
def expires_at(self, expires_at: datetime):
if expires_at and expires_at.tzinfo:
self._expires_at = expires_at.astimezone(timezone.utc).replace(tzinfo=None)
else:
self._expires_at = expires_at

# Create database + tables
if not database_exists(DB_URI):
Expand Down
8 changes: 4 additions & 4 deletions jwthenticator/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from marshmallow_dataclass import dataclass, NewType, add_schema

from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY
from jwthenticator.utils import utcnow

# Define the UUID type that uses Marshmallow's UUID + Python's UUID
UUID = NewType("UUID", uuid.UUID, field=fields.UUID)
Expand Down Expand Up @@ -40,7 +41,7 @@ class KeyData:
key: Optional[str] = field(default=None, repr=False, metadata={"load_only": True})

async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()
return self.expires_at > utcnow()


@dataclass
Expand All @@ -53,7 +54,7 @@ class RefreshTokenData:
key_id: int

async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()
return self.expires_at > utcnow()


# Skipping None values on dump since 'aud' is optional and can't be None/empty
Expand All @@ -68,8 +69,7 @@ class JWTPayloadData:
aud: Optional[List[str]] = None # JWT Audience

async def is_valid(self) -> bool:
return self.exp > datetime.utcnow().timestamp()

return self.exp > utcnow().timestamp()

# Request dataclasses
@dataclass
Expand Down
4 changes: 2 additions & 2 deletions jwthenticator/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import random
from os import environ
from string import ascii_letters
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from hashlib import sha512
from uuid import uuid4

Expand All @@ -32,7 +32,7 @@ async def hash_key(key: str) -> str:


async def future_datetime(seconds: int = 0) -> datetime:
return datetime.utcnow() + timedelta(seconds=seconds)
return utils.utcnow() + timedelta(seconds=seconds)


def backup_environment(func): # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jwt
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.utils import create_async_session_factory, utcnow
from jwthenticator.models import Base, RefreshTokenInfo
from jwthenticator.schemas import JWTPayloadData, RefreshTokenData
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
Expand Down Expand Up @@ -51,7 +51,7 @@ async def create_access_token(self, identifier: UUID) -> str:
"""
if self.private_key is None:
raise RuntimeError("Private key required for JWT token creation")
utc_now = datetime.utcnow()
utc_now = utcnow()
payload = JWTPayloadData(
token_id=uuid4(),
identifier=identifier,
Expand Down Expand Up @@ -87,8 +87,8 @@ async def create_refresh_token(self, key_id: int, expires_at: Optional[datetime]
:return: The refresh token created.
"""
if expires_at is None:
expires_at = expires_at = datetime.utcnow() + timedelta(seconds=REFRESH_TOKEN_EXPIRY)
if expires_at <= datetime.utcnow():
expires_at = expires_at = utcnow() + timedelta(seconds=REFRESH_TOKEN_EXPIRY)
if expires_at <= utcnow():
raise RuntimeError("Refresh token can't be created in the past")

refresh_token_str = sha512(uuid4().bytes).hexdigest()
Expand Down
4 changes: 4 additions & 0 deletions jwthenticator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Dict, Tuple, Optional
from urllib.parse import urlparse

from datetime import datetime, timezone
from jwt.utils import base64url_encode
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA1
Expand Down Expand Up @@ -119,3 +120,6 @@ def create_async_session_factory(uri: str, base: Optional[Any] = None, **engine_
if base is not None:
asyncio.get_event_loop().run_until_complete(create_base(engine, base))
return async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

def utcnow() -> datetime:
return datetime.now(timezone.utc)

0 comments on commit de80a19

Please sign in to comment.