Skip to content

Commit

Permalink
Added support for user models other than Django's AUTH_USER_MODEL.
Browse files Browse the repository at this point in the history
  • Loading branch information
greyhare committed Mar 23, 2019
1 parent 2b6de9d commit 3561ba5
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 24 deletions.
8 changes: 5 additions & 3 deletions rest_framework_simplejwt/authentication.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from django.contrib.auth import get_user_model
from django.utils.translation import ugettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, authentication

from .exceptions import AuthenticationFailed, InvalidToken, TokenError
from .models import TokenUser
from .settings import api_settings
from .state import User

AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES

Expand All @@ -24,6 +24,8 @@ class JWTAuthentication(authentication.BaseAuthentication):
"""
www_authenticate_realm = 'api'

user_class = get_user_model()

def authenticate(self, request):
header = self.get_header(request)
if header is None:
Expand Down Expand Up @@ -108,8 +110,8 @@ def get_user(self, validated_token):
raise InvalidToken(_('Token contained no recognizable user identification'))

try:
user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id})
except User.DoesNotExist:
user = self.user_class.objects.get(**{api_settings.USER_ID_FIELD: user_id})
except self.user_class.DoesNotExist:
raise AuthenticationFailed(_('User not found'), code='user_not_found')

if not user.is_active:
Expand Down
5 changes: 2 additions & 3 deletions rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from django.contrib.auth import authenticate
from django.contrib.auth import authenticate, get_user_model
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers

from .settings import api_settings
from .state import User
from .tokens import RefreshToken, SlidingToken, UntypedToken


Expand All @@ -18,7 +17,7 @@ def __init__(self, *args, **kwargs):


class TokenObtainSerializer(serializers.Serializer):
username_field = User.USERNAME_FIELD
username_field = get_user_model().USERNAME_FIELD

default_error_messages = {
'no_active_account': _('No active account found with the given credentials')
Expand Down
8 changes: 0 additions & 8 deletions rest_framework_simplejwt/state.py

This file was deleted.

13 changes: 7 additions & 6 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.conf import settings
from django.utils.translation import ugettext_lazy as _

from .backends import TokenBackend
from .exceptions import TokenBackendError, TokenError
from .settings import api_settings
from .token_blacklist.models import BlacklistedToken, OutstandingToken
Expand All @@ -17,6 +18,10 @@ class Token:
A class which validates and wraps an existing JWT or can be used to build a
new JWT.
"""
token_backend = TokenBackend(
api_settings.ALGORITHM, api_settings.SIGNING_KEY,
api_settings.VERIFYING_KEY
)
token_type = None
lifetime = None

Expand All @@ -35,11 +40,9 @@ def __init__(self, token=None, verify=True):
# Set up token
if token is not None:
# An encoded token was provided
from .state import token_backend

# Decode token
try:
self.payload = token_backend.decode(token, verify=verify)
self.payload = self.token_backend.decode(token, verify=verify)
except TokenBackendError:
raise TokenError(_('Token is invalid or expired'))

Expand Down Expand Up @@ -77,9 +80,7 @@ def __str__(self):
"""
Signs and returns a token as a base64 encoded string.
"""
from .state import token_backend

return token_backend.encode(self.payload)
return self.token_backend.encode(self.payload)

def verify(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import timedelta

from django.contrib.auth import get_user_model

from rest_framework_simplejwt.compat import reverse
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import User
from rest_framework_simplejwt.tokens import AccessToken

from .utils import APIViewTestCase, override_api_settings

User = get_user_model()

class TestTestView(APIViewTestCase):
view_name = 'test_view'
Expand Down
4 changes: 3 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
from unittest.mock import MagicMock, patch

from django.contrib.auth import get_user_model
from django.test import TestCase

from rest_framework_simplejwt.exceptions import TokenError
Expand All @@ -10,7 +11,6 @@
TokenRefreshSlidingSerializer, TokenVerifySerializer,
)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import User
from rest_framework_simplejwt.token_blacklist.models import (
BlacklistedToken, OutstandingToken,
)
Expand All @@ -23,6 +23,8 @@

from .utils import override_api_settings

User = get_user_model()


class TestTokenObtainSerializer(TestCase):
def setUp(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from datetime import datetime, timedelta
from unittest.mock import patch

from django.contrib.auth import get_user_model
from django.test import TestCase
from jose import jwt

from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import User
from rest_framework_simplejwt.tokens import (
AccessToken, RefreshToken, SlidingToken, Token, UntypedToken,
)
Expand All @@ -16,6 +16,8 @@

from .utils import override_api_settings

User = get_user_model()


class MyToken(Token):
token_type = 'test'
Expand Down
5 changes: 4 additions & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import timedelta
from unittest.mock import patch

from django.contrib.auth import get_user_model

from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import User
from rest_framework_simplejwt.tokens import (
AccessToken, RefreshToken, SlidingToken,
)
Expand All @@ -12,6 +13,8 @@

from .utils import APIViewTestCase

User = get_user_model()


class TestTokenObtainPairView(APIViewTestCase):
view_name = 'token_obtain_pair'
Expand Down

0 comments on commit 3561ba5

Please sign in to comment.