Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user based secret key #51

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ Some of Simple JWT's behavior can be customized through settings variables in
'BLACKLIST_AFTER_ROTATION': True,

'ALGORITHM': 'HS256',
'SIGNING_KEY': settings.SECRET_KEY,
'JWT_SECRET_KEY': settings.SECRET_KEY,
'SIGNING_KEY': None,
'VERIFYING_KEY': None,
'GET_USER_SECRET_KEY': None

'AUTH_HEADER_TYPES': ('Bearer',),
'USER_ID_FIELD': 'id',
Expand Down Expand Up @@ -202,24 +204,34 @@ ALGORITHM
key. Likewise, the ``VERIFYING_KEY`` setting must be set to a string which
contains an RSA public key.

SIGNING_KEY
The signing key which is used to sign the content of generated tokens. For
HMAC signing, this should be a random string with at least as many bits of
data as is required by the signing protocol. For RSA signing, this
should be a string which contains an RSA private key which is 2048 bits or
longer. Since Simple JWT defaults to using 256-bit HMAC signing, the
``SIGNING_KEY`` setting defaults to the value of the ``SECRET_KEY`` setting
JWT_SECRET_KEY
The signing key which is used to sign the content of generated tokens when HMAC algorithm
is chosen, this should be a random string with at least as many bits of
data as is required by the signing protocol.
Since Simple JWT defaults to using 256-bit HMAC signing, the
``JWT_SECRET_KEY`` setting defaults to the value of the ``SECRET_KEY`` setting
for your django project. Although this is the most reasonable default that
Simple JWT can provide, it is recommended that developers change this setting
to a value which is independent from the django project secret key. This
will make changing the signing key used for tokens easier in the event that
to a value which is independent from the django project secret key.
This will make changing the signing key used for tokens easier in the event that
it is compromised.

GET_USER_SECRET_KEY
A function that will be called with the value of ``USER_ID_FIELD`` when decoding a token.
It should return a secret key that is user dependant. This lets you specify a secret key for
a user and for instance make the key change once the user changes his password in order to
invalidate token created before the password change.

SIGNING_KEY
The signing key which is used to sign the content of generated tokens when RSA algorithm
is chosen, this should be a string which contains an RSA private key which is 2048 bits or
longer.

VERIFYING_KEY
The verifying key which is used to verify the content of generated tokens.
If an HMAC algorithm has been specified by the ``ALGORITHM`` setting, the
``VERIFYING_KEY`` setting will be ignored and the value of the
``SIGNING_KEY`` setting will be used. If an RSA algorithm has been specified
``JWT_SECRET_KEY`` setting will be used. If an RSA algorithm has been specified
by the ``ALGORITHM`` setting, the ``VERIFYING_KEY`` setting must be set to a
string which contains an RSA public key.

Expand Down
30 changes: 22 additions & 8 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import unicode_literals

import jwt
from django.utils.translation import ugettext_lazy as _
from jwt import InvalidTokenError
import jwt

from .exceptions import TokenBackendError
from .settings import api_settings
from .utils import format_lazy

ALLOWED_ALGORITHMS = (
Expand All @@ -18,22 +19,33 @@


class TokenBackend(object):
def __init__(self, algorithm, signing_key=None, verifying_key=None):
def __init__(self, algorithm, signing_key=None, verifying_key=None, secret_key=None, get_user_secret_key=None):
if algorithm not in ALLOWED_ALGORITHMS:
raise TokenBackendError(format_lazy(_("Unrecognized algorithm type '{}'"), algorithm))

self.algorithm = algorithm
self.signing_key = signing_key
if algorithm.startswith('HS'):
self.verifying_key = signing_key
else:
self.verifying_key = verifying_key
self.verifying_key = verifying_key
self.secret_key = secret_key
self.get_user_secret_key = get_user_secret_key

def get_secret_key(self, payload):
if self.get_user_secret_key:
return self.get_user_secret_key(payload[api_settings.USER_ID_FIELD])
return self.secret_key

def get_signing_key(self, payload):
return self.signing_key or self.get_secret_key(payload)

def get_verifying_key(self, payload):
return self.verifying_key or self.get_secret_key(payload)

def encode(self, payload):
"""
Returns an encoded token for the given payload dictionary.
"""
token = jwt.encode(payload, self.signing_key, algorithm=self.algorithm)
signing_key = self.get_signing_key(payload)
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
return token.decode('utf-8')

def decode(self, token, verify=True):
Expand All @@ -45,6 +57,8 @@ def decode(self, token, verify=True):
signature check fails, or if its 'exp' claim indicates it has expired.
"""
try:
return jwt.decode(token, self.verifying_key, algorithms=[self.algorithm], verify=verify)
unverified_payload = jwt.decode(token, None, False)
verifying_key = self.get_verifying_key(unverified_payload)
return jwt.decode(token, verifying_key, algorithms=[self.algorithm], verify=verify)
except InvalidTokenError:
raise TokenBackendError(_('Token is invalid or expired'))
4 changes: 3 additions & 1 deletion rest_framework_simplejwt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
'BLACKLIST_AFTER_ROTATION': True,

'ALGORITHM': 'HS256',
'SIGNING_KEY': settings.SECRET_KEY,
'JWT_SECRET_KEY': settings.SECRET_KEY,
'GET_USER_SECRET_KEY': None,
'SIGNING_KEY': None,
'VERIFYING_KEY': None,

'AUTH_HEADER_TYPES': ('Bearer',),
Expand Down
9 changes: 7 additions & 2 deletions rest_framework_simplejwt/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,10 @@
from .settings import api_settings

User = get_user_model()
token_backend = TokenBackend(api_settings.ALGORITHM, api_settings.SIGNING_KEY,
api_settings.VERIFYING_KEY)
token_backend = TokenBackend(
api_settings.ALGORITHM,
signing_key=api_settings.SIGNING_KEY,
verifying_key=api_settings.VERIFYING_KEY,
secret_key=api_settings.JWT_SECRET_KEY,
get_user_secret_key=api_settings.GET_USER_SECRET_KEY
)
35 changes: 31 additions & 4 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import unicode_literals

from datetime import datetime, timedelta
from random import randint

import jwt
from django.test import TestCase
Expand Down Expand Up @@ -55,8 +56,9 @@

class TestTokenBackend(TestCase):
def setUp(self):
self.hmac_token_backend = TokenBackend('HS256', SECRET)
self.rsa_token_backend = TokenBackend('RS256', PRIVATE_KEY, PUBLIC_KEY)
self.hmac_token_backend = TokenBackend('HS256', secret_key=SECRET)
self.hmac_custom_user_key_token_backend = TokenBackend('HS256', get_user_secret_key=lambda a: a)
self.rsa_token_backend = TokenBackend('RS256', signing_key=PRIVATE_KEY, verifying_key=PUBLIC_KEY)
self.payload = {'foo': 'bar'}

def test_init(self):
Expand Down Expand Up @@ -96,6 +98,31 @@ def test_encode_rsa(self):
),
)

def test_encode_hmac_custom(self):
payload = {
'id': '1234'
}
hmac_token = self.hmac_custom_user_key_token_backend.encode(payload)
self.assertIn(
hmac_token,
(
'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6IjEyMzQifQ.RMZuO9SRBYS0pLh8DVhvknBSs80OfFvzxbl-y9b5pnc'
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjEyMzQifQ.1weakZ8seqK0UKjpLrpfJaCEr9B7rJiYLGa1TR64QF4'
)
)

def test_decode_hmac_custom_changed_user(self):
token_backend = TokenBackend('HS256', get_user_secret_key=lambda a: str(a) + str(randint(0, 1000000000)))
payload = {
'exp': make_utc(datetime(year=2000, month=1, day=1)),
'id': '1234'
}

token = token_backend.encode(payload)

with self.assertRaises(TokenBackendError):
self.hmac_token_backend.decode(token)

def test_decode_hmac_with_no_expiry(self):
no_exp_token = jwt.encode(self.payload, SECRET, algorithm='HS256')

Expand Down Expand Up @@ -162,7 +189,7 @@ def test_decode_rsa_with_no_expiry_no_verify(self):
no_exp_token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')

self.assertEqual(
self.hmac_token_backend.decode(no_exp_token, verify=False),
self.rsa_token_backend.decode(no_exp_token, verify=False),
self.payload,
)

Expand Down Expand Up @@ -198,7 +225,7 @@ def test_decode_rsa_with_invalid_sig_no_verify(self):
invalid_token = token_2_payload + '.' + token_1_sig

self.assertEqual(
self.hmac_token_backend.decode(invalid_token, verify=False),
self.rsa_token_backend.decode(invalid_token, verify=False),
self.payload,
)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def test_init_bad_sig_token_given(self):
# Test backend rejects encoded token (expired or bad signature)
payload = {'foo': 'bar'}
payload['exp'] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_1 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')
payload['foo'] = 'baz'
token_2 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_2 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')

token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
Expand All @@ -112,9 +112,9 @@ def test_init_bad_sig_token_given_no_verify(self):
# Test backend rejects encoded token (expired or bad signature)
payload = {'foo': 'bar'}
payload['exp'] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_1 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')
payload['foo'] = 'baz'
token_2 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_2 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')

token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
Expand Down