Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent incorrect usage of Token.for_user #804

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def validate(self, attrs: Dict[str, Any]) -> Dict[Any, Any]:

@classmethod
def get_token(cls, user: AuthUser) -> Token:
return cls.token_class.for_user(user) # type: ignore
return cls.token_class.for_validated_user(user) # type: ignore


class TokenObtainPairSerializer(TokenObtainSerializer):
Expand Down
12 changes: 10 additions & 2 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ def for_user(cls: Type[T], user: AuthUser) -> T:
Returns an authorization token for the given user that will be provided
after authenticating the user's credentials.
"""
# Prevent incorrect usage of Token.for_user (when creating tokens manually)
# see https://github.com/jazzband/djangorestframework-simplejwt/issues/779
if not api_settings.USER_AUTHENTICATION_RULE(user):
raise TokenError(_("Token is invalid or expired"))
return cls.for_validated_user(user)

@classmethod
def for_validated_user(cls: Type[T], user: AuthUser) -> T:
user_id = getattr(user, api_settings.USER_ID_FIELD)
pchiquet marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(user_id, int):
user_id = str(user_id)
Expand Down Expand Up @@ -278,11 +286,11 @@ def blacklist(self) -> BlacklistedToken:
return BlacklistedToken.objects.get_or_create(token=token)

@classmethod
def for_user(cls: Type[T], user: AuthUser) -> T:
def for_validated_user(cls: Type[T], user: AuthUser) -> T:
"""
Adds this token to the outstanding token list.
"""
token = super().for_user(user) # type: ignore
token = super().for_validated_user(user) # type: ignore

jti = token[api_settings.JTI_CLAIM]
exp = token["exp"]
Expand Down
9 changes: 9 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,15 @@ def test_for_user_with_username(self):
token = MyToken.for_user(self.user)
self.assertEqual(token[api_settings.USER_ID_CLAIM], self.username)

def test_for_user_fails_if_is_active_false(self):
# works with is_active=True
token = MyToken.for_user(self.user)

# fails with is_active=False
self.user.is_active = False
with self.assertRaises(TokenError):
token = MyToken.for_user(self.user)

@override_api_settings(CHECK_REVOKE_TOKEN=True)
def test_revoke_token_claim_included_in_authorization_token(self):
token = MyToken.for_user(self.user)
Expand Down