Skip to content

Commit

Permalink
Merge pull request #10 from guyzyl/main
Browse files Browse the repository at this point in the history
Added support for JWT payload audience ('aud')
  • Loading branch information
guyzyl authored Feb 25, 2021
2 parents 44a03bf + 444773c commit 2c9fb05
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 13 deletions.
10 changes: 5 additions & 5 deletions jwthenticator/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jwthenticator import schemas
from jwthenticator.tokens import TokenManager
from jwthenticator.keys import KeyManager
from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME
from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME, JWT_AUDIENCE
from jwthenticator.utils import get_rsa_key_pair
from jwthenticator.exceptions import ExpiredError

Expand All @@ -23,8 +23,9 @@ class JWThenticatorAPI:
"""

# pylint: disable=too-many-arguments
def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(), jwt_lease_time: int = JWT_LEASE_TIME,
jwt_algorithm: str = JWT_ALGORITHM, jwt_algorithm_family: str = JWT_ALGORITHM_FAMILY):
def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(),
jwt_lease_time: int = JWT_LEASE_TIME, jwt_algorithm: str = JWT_ALGORITHM,
jwt_algorithm_family: str = JWT_ALGORITHM_FAMILY, jwt_audience: Optional[str] = JWT_AUDIENCE):
"""
Class can be initiated without giving any parameter, will generate RSA key pair by itself.
:param rsa_key_pair: (public_key, private_key) RSA key pair. Will generate keys if not given
Expand All @@ -33,9 +34,8 @@ def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(),
self.public_key, self._private_key = rsa_key_pair
self.jwt_algorithm = jwt_algorithm
self.jwt_algorithm_family = jwt_algorithm_family
self.jwt_lease_time = jwt_lease_time

self.token_manager = TokenManager(self.public_key, self._private_key, self.jwt_algorithm, self.jwt_lease_time)
self.token_manager = TokenManager(self.public_key, self._private_key, self.jwt_algorithm, jwt_lease_time, jwt_audience)
self.key_manager = KeyManager()


Expand Down
1 change: 1 addition & 0 deletions jwthenticator/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
JWT_ALGORITHM_FAMILY = environ.get("JWT_ALGORITHM_FAMILY", "RSA")
JWT_LEASE_TIME = int(environ.get("JWT_LEASE_TIME", 30 * 60)) # In seconds - 30 minutes
RSA_KEY_STRENGTH = int(environ.get("RSA_KEY_STRENGTH", 2048))
JWT_AUDIENCE = environ.get("JWT_AUDIENCE", None)

# Token consts
KEY_EXPIRY = int(environ.get("KEY_EXPIRY", DAYS_TO_SECONDS(120))) # In seconds
Expand Down
22 changes: 19 additions & 3 deletions jwthenticator/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,31 @@

import uuid
from dataclasses import field
from typing import Optional, List, ClassVar, Type
from typing import Optional, List, ClassVar, Type, Dict, Any
from datetime import datetime

from marshmallow import Schema, fields
from marshmallow_dataclass import dataclass, NewType
from marshmallow import Schema, fields, post_dump
from marshmallow_dataclass import dataclass, NewType, add_schema

from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY

# Define the UUID type that uses Marshmallow's UUID + Python's UUID
UUID = NewType("UUID", uuid.UUID, field=fields.UUID)


# Marshmallow base schema for skipping None values on dump
class BaseSchema(Schema):
SKIP_VALUES = {None}

@post_dump
# pylint: disable=unused-argument
def remove_skip_values(self, data: Any, many: bool) -> Dict[Any, Any]:
return {
key: value for key, value in data.items()
if value not in self.SKIP_VALUES
}


# Data dataclasses (that match the sqlalchemy models)
@dataclass # pylint: disable=used-before-assignment
class KeyData:
Expand Down Expand Up @@ -42,13 +55,16 @@ async def is_valid(self) -> bool:
return self.expires_at > datetime.utcnow()


# Skipping None values on dump since aud is optional and can't be None/empty
@add_schema(base_schema=BaseSchema)
@dataclass
class JWTPayloadData:
Schema: ClassVar[Type[Schema]] = Schema
token_id: UUID # JWT token identifier
identifier: UUID # Machine the JWT was issued to identifier
iat: int # Issued at timestamp
exp: int # Expires at timestamp
aud: Optional[str] = None # JWT Audience

async def is_valid(self) -> bool:
return self.exp > datetime.utcnow().timestamp()
Expand Down
5 changes: 3 additions & 2 deletions jwthenticator/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import absolute_import

import inspect
from os.path import basename
from uuid import uuid4
from http import HTTPStatus
from typing import Union
Expand Down Expand Up @@ -38,8 +39,8 @@ def __init__(self, test_client: TestClient):

async def __call__(self) -> Union[TestClient, ClientSessionType]:
context = inspect.stack()
caller_file = context[1].filename
if any([i in caller_file for i in CLIENT_PATCH_FILES]):
caller_file = basename(context[1].filename)
if caller_file in CLIENT_PATCH_FILES:
return self.test_client
return ClientSession()

Expand Down
10 changes: 7 additions & 3 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
from jwthenticator.models import Base, RefreshTokenInfo
from jwthenticator.schemas import JWTPayloadData, RefreshTokenData
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, DB_URI
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI


class TokenManager:
"""
Class responsible for the creation and loading of tokens
"""

def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm: str = JWT_ALGORITHM, jwt_lease_time: int = JWT_LEASE_TIME):
# pylint: disable=too-many-arguments,too-many-instance-attributes
def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm: str = JWT_ALGORITHM,
jwt_lease_time: int = JWT_LEASE_TIME, jwt_audience: Optional[str] = JWT_AUDIENCE):
"""
Accepts public + private key pairs.
If only public key is given tokens can be loaded but not created.
Expand All @@ -32,6 +34,7 @@ def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm
self.private_key = private_key
self.algorithm = algorithm
self.jwt_lease_time = jwt_lease_time
self.jwt_audience = jwt_audience

self.refresh_token_schema = RefreshTokenData.Schema()
self.jwt_payload_data_schema = JWTPayloadData.Schema()
Expand All @@ -52,7 +55,8 @@ async def create_access_token(self, identifier: UUID) -> str:
token_id=uuid4(),
identifier=identifier,
iat=int(now.timestamp()),
exp=int((now + timedelta(seconds=self.jwt_lease_time)).timestamp())
exp=int((now + timedelta(seconds=self.jwt_lease_time)).timestamp()),
aud=self.jwt_audience
)
encoded_payload = self.jwt_payload_data_schema.dump(payload)
token_string = jwt.encode(encoded_payload, self.private_key, self.algorithm)
Expand Down

0 comments on commit 2c9fb05

Please sign in to comment.