Skip to content

Commit 2c9fb05

Browse files
authored
Merge pull request #10 from guyzyl/main
Added support for JWT payload audience ('aud')
2 parents 44a03bf + 444773c commit 2c9fb05

File tree

5 files changed

+35
-13
lines changed

5 files changed

+35
-13
lines changed

jwthenticator/api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from jwthenticator import schemas
1010
from jwthenticator.tokens import TokenManager
1111
from jwthenticator.keys import KeyManager
12-
from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME
12+
from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY, JWT_LEASE_TIME, JWT_AUDIENCE
1313
from jwthenticator.utils import get_rsa_key_pair
1414
from jwthenticator.exceptions import ExpiredError
1515

@@ -23,8 +23,9 @@ class JWThenticatorAPI:
2323
"""
2424

2525
# pylint: disable=too-many-arguments
26-
def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(), jwt_lease_time: int = JWT_LEASE_TIME,
27-
jwt_algorithm: str = JWT_ALGORITHM, jwt_algorithm_family: str = JWT_ALGORITHM_FAMILY):
26+
def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(),
27+
jwt_lease_time: int = JWT_LEASE_TIME, jwt_algorithm: str = JWT_ALGORITHM,
28+
jwt_algorithm_family: str = JWT_ALGORITHM_FAMILY, jwt_audience: Optional[str] = JWT_AUDIENCE):
2829
"""
2930
Class can be initiated without giving any parameter, will generate RSA key pair by itself.
3031
:param rsa_key_pair: (public_key, private_key) RSA key pair. Will generate keys if not given
@@ -33,9 +34,8 @@ def __init__(self, rsa_key_pair: Tuple[str, Optional[str]] = get_rsa_key_pair(),
3334
self.public_key, self._private_key = rsa_key_pair
3435
self.jwt_algorithm = jwt_algorithm
3536
self.jwt_algorithm_family = jwt_algorithm_family
36-
self.jwt_lease_time = jwt_lease_time
3737

38-
self.token_manager = TokenManager(self.public_key, self._private_key, self.jwt_algorithm, self.jwt_lease_time)
38+
self.token_manager = TokenManager(self.public_key, self._private_key, self.jwt_algorithm, jwt_lease_time, jwt_audience)
3939
self.key_manager = KeyManager()
4040

4141

jwthenticator/consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
JWT_ALGORITHM_FAMILY = environ.get("JWT_ALGORITHM_FAMILY", "RSA")
1717
JWT_LEASE_TIME = int(environ.get("JWT_LEASE_TIME", 30 * 60)) # In seconds - 30 minutes
1818
RSA_KEY_STRENGTH = int(environ.get("RSA_KEY_STRENGTH", 2048))
19+
JWT_AUDIENCE = environ.get("JWT_AUDIENCE", None)
1920

2021
# Token consts
2122
KEY_EXPIRY = int(environ.get("KEY_EXPIRY", DAYS_TO_SECONDS(120))) # In seconds

jwthenticator/schemas.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,31 @@
33

44
import uuid
55
from dataclasses import field
6-
from typing import Optional, List, ClassVar, Type
6+
from typing import Optional, List, ClassVar, Type, Dict, Any
77
from datetime import datetime
88

9-
from marshmallow import Schema, fields
10-
from marshmallow_dataclass import dataclass, NewType
9+
from marshmallow import Schema, fields, post_dump
10+
from marshmallow_dataclass import dataclass, NewType, add_schema
1111

1212
from jwthenticator.consts import JWT_ALGORITHM, JWT_ALGORITHM_FAMILY
1313

1414
# Define the UUID type that uses Marshmallow's UUID + Python's UUID
1515
UUID = NewType("UUID", uuid.UUID, field=fields.UUID)
1616

1717

18+
# Marshmallow base schema for skipping None values on dump
19+
class BaseSchema(Schema):
20+
SKIP_VALUES = {None}
21+
22+
@post_dump
23+
# pylint: disable=unused-argument
24+
def remove_skip_values(self, data: Any, many: bool) -> Dict[Any, Any]:
25+
return {
26+
key: value for key, value in data.items()
27+
if value not in self.SKIP_VALUES
28+
}
29+
30+
1831
# Data dataclasses (that match the sqlalchemy models)
1932
@dataclass # pylint: disable=used-before-assignment
2033
class KeyData:
@@ -42,13 +55,16 @@ async def is_valid(self) -> bool:
4255
return self.expires_at > datetime.utcnow()
4356

4457

58+
# Skipping None values on dump since aud is optional and can't be None/empty
59+
@add_schema(base_schema=BaseSchema)
4560
@dataclass
4661
class JWTPayloadData:
4762
Schema: ClassVar[Type[Schema]] = Schema
4863
token_id: UUID # JWT token identifier
4964
identifier: UUID # Machine the JWT was issued to identifier
5065
iat: int # Issued at timestamp
5166
exp: int # Expires at timestamp
67+
aud: Optional[str] = None # JWT Audience
5268

5369
async def is_valid(self) -> bool:
5470
return self.exp > datetime.utcnow().timestamp()

jwthenticator/tests/test_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import absolute_import
33

44
import inspect
5+
from os.path import basename
56
from uuid import uuid4
67
from http import HTTPStatus
78
from typing import Union
@@ -38,8 +39,8 @@ def __init__(self, test_client: TestClient):
3839

3940
async def __call__(self) -> Union[TestClient, ClientSessionType]:
4041
context = inspect.stack()
41-
caller_file = context[1].filename
42-
if any([i in caller_file for i in CLIENT_PATCH_FILES]):
42+
caller_file = basename(context[1].filename)
43+
if caller_file in CLIENT_PATCH_FILES:
4344
return self.test_client
4445
return ClientSession()
4546

jwthenticator/tokens.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@
1111
from jwthenticator.models import Base, RefreshTokenInfo
1212
from jwthenticator.schemas import JWTPayloadData, RefreshTokenData
1313
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
14-
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, DB_URI
14+
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI
1515

1616

1717
class TokenManager:
1818
"""
1919
Class responsible for the creation and loading of tokens
2020
"""
2121

22-
def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm: str = JWT_ALGORITHM, jwt_lease_time: int = JWT_LEASE_TIME):
22+
# pylint: disable=too-many-arguments,too-many-instance-attributes
23+
def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm: str = JWT_ALGORITHM,
24+
jwt_lease_time: int = JWT_LEASE_TIME, jwt_audience: Optional[str] = JWT_AUDIENCE):
2325
"""
2426
Accepts public + private key pairs.
2527
If only public key is given tokens can be loaded but not created.
@@ -32,6 +34,7 @@ def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm
3234
self.private_key = private_key
3335
self.algorithm = algorithm
3436
self.jwt_lease_time = jwt_lease_time
37+
self.jwt_audience = jwt_audience
3538

3639
self.refresh_token_schema = RefreshTokenData.Schema()
3740
self.jwt_payload_data_schema = JWTPayloadData.Schema()
@@ -52,7 +55,8 @@ async def create_access_token(self, identifier: UUID) -> str:
5255
token_id=uuid4(),
5356
identifier=identifier,
5457
iat=int(now.timestamp()),
55-
exp=int((now + timedelta(seconds=self.jwt_lease_time)).timestamp())
58+
exp=int((now + timedelta(seconds=self.jwt_lease_time)).timestamp()),
59+
aud=self.jwt_audience
5660
)
5761
encoded_payload = self.jwt_payload_data_schema.dump(payload)
5862
token_string = jwt.encode(encoded_payload, self.private_key, self.algorithm)

0 commit comments

Comments
 (0)