Skip to content

Commit

Permalink
Add support for user based secret key
Browse files Browse the repository at this point in the history
In order to address this issue :
#17

I propose the following solution:
- Add a JWT_SECRET_KEY setting that should be used for hmac algorithm
- The SIGNING_KEY and VERIFYING_KEY should be used only for asymmetric
algorithm
- Add a GET_USER_SECRET_KEY setting which expect a function that will
be called with the user id as defined by the USER_ID_FIELD and return a
key that can change for instance when the user changed his password
  • Loading branch information
lucas-foodles committed Oct 29, 2018
1 parent c6bad8f commit af3785c
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 29 deletions.
34 changes: 23 additions & 11 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ Some of Simple JWT's behavior can be customized through settings variables in
'BLACKLIST_AFTER_ROTATION': True,
'ALGORITHM': 'HS256',
'SIGNING_KEY': settings.SECRET_KEY,
'JWT_SECRET_KEY': settings.SECRET_KEY,
'SIGNING_KEY': None,
'VERIFYING_KEY': None,
'GET_USER_SECRET_KEY': None
'AUTH_HEADER_TYPES': ('Bearer',),
'USER_ID_FIELD': 'id',
Expand Down Expand Up @@ -202,24 +204,34 @@ ALGORITHM
key. Likewise, the ``VERIFYING_KEY`` setting must be set to a string which
contains an RSA public key.

SIGNING_KEY
The signing key which is used to sign the content of generated tokens. For
HMAC signing, this should be a random string with at least as many bits of
data as is required by the signing protocol. For RSA signing, this
should be a string which contains an RSA private key which is 2048 bits or
longer. Since Simple JWT defaults to using 256-bit HMAC signing, the
``SIGNING_KEY`` setting defaults to the value of the ``SECRET_KEY`` setting
JWT_SECRET_KEY
The signing key which is used to sign the content of generated tokens when HMAC algorithm
is chosen, this should be a random string with at least as many bits of
data as is required by the signing protocol.
Since Simple JWT defaults to using 256-bit HMAC signing, the
``JWT_SECRET_KEY`` setting defaults to the value of the ``SECRET_KEY`` setting
for your django project. Although this is the most reasonable default that
Simple JWT can provide, it is recommended that developers change this setting
to a value which is independent from the django project secret key. This
will make changing the signing key used for tokens easier in the event that
to a value which is independent from the django project secret key.
This will make changing the signing key used for tokens easier in the event that
it is compromised.

GET_USER_SECRET_KEY
A function that will be called with the value of ``USER_ID_FIELD`` when decoding a token.
It should return a secret key that is user dependant. This lets you specify a secret key for
a user and for instance make the key change once the user changes his password in order to
invalidate token created before the password change.

SIGNING_KEY
The signing key which is used to sign the content of generated tokens when RSA algorithm
is chosen, this should be a string which contains an RSA private key which is 2048 bits or
longer.

VERIFYING_KEY
The verifying key which is used to verify the content of generated tokens.
If an HMAC algorithm has been specified by the ``ALGORITHM`` setting, the
``VERIFYING_KEY`` setting will be ignored and the value of the
``SIGNING_KEY`` setting will be used. If an RSA algorithm has been specified
``JWT_SECRET_KEY`` setting will be used. If an RSA algorithm has been specified
by the ``ALGORITHM`` setting, the ``VERIFYING_KEY`` setting must be set to a
string which contains an RSA public key.

Expand Down
28 changes: 21 additions & 7 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .exceptions import TokenBackendError
from .utils import format_lazy
from .settings import api_settings

ALLOWED_ALGORITHMS = (
'HS256',
Expand All @@ -18,22 +19,33 @@


class TokenBackend(object):
def __init__(self, algorithm, signing_key=None, verifying_key=None):
def __init__(self, algorithm, signing_key=None, verifying_key=None, secret_key=None, get_user_secret_key=None):
if algorithm not in ALLOWED_ALGORITHMS:
raise TokenBackendError(format_lazy(_("Unrecognized algorithm type '{}'"), algorithm))

self.algorithm = algorithm
self.signing_key = signing_key
if algorithm.startswith('HS'):
self.verifying_key = signing_key
else:
self.verifying_key = verifying_key
self.verifying_key = verifying_key
self.secret_key = secret_key
self.get_user_secret_key = get_user_secret_key

def get_secret_key(self, payload):
if self.get_user_secret_key:
return self.get_user_secret_key(payload[api_settings.USER_ID_FIELD])
return self.secret_key

def get_signing_key(self, payload):
return self.signing_key or self.get_secret_key(payload)

def get_verifying_key(self, payload):
return self.verifying_key or self.get_secret_key(payload)

def encode(self, payload):
"""
Returns an encoded token for the given payload dictionary.
"""
token = jwt.encode(payload, self.signing_key, algorithm=self.algorithm)
signing_key = self.get_signing_key(payload)
token = jwt.encode(payload, signing_key, algorithm=self.algorithm)
return token.decode('utf-8')

def decode(self, token, verify=True):
Expand All @@ -45,6 +57,8 @@ def decode(self, token, verify=True):
signature check fails, or if its 'exp' claim indicates it has expired.
"""
try:
return jwt.decode(token, self.verifying_key, algorithms=[self.algorithm], verify=verify)
unverified_payload = jwt.decode(token, None, False)
verifying_key = self.get_verifying_key(unverified_payload)
return jwt.decode(token, verifying_key, algorithms=[self.algorithm], verify=verify)
except InvalidTokenError:
raise TokenBackendError(_('Token is invalid or expired'))
4 changes: 3 additions & 1 deletion rest_framework_simplejwt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
'BLACKLIST_AFTER_ROTATION': True,

'ALGORITHM': 'HS256',
'SIGNING_KEY': settings.SECRET_KEY,
'JWT_SECRET_KEY': settings.SECRET_KEY,
'GET_USER_SECRET_KEY': None,
'SIGNING_KEY': None,
'VERIFYING_KEY': None,

'AUTH_HEADER_TYPES': ('Bearer',),
Expand Down
9 changes: 7 additions & 2 deletions rest_framework_simplejwt/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,10 @@
from .settings import api_settings

User = get_user_model()
token_backend = TokenBackend(api_settings.ALGORITHM, api_settings.SIGNING_KEY,
api_settings.VERIFYING_KEY)
token_backend = TokenBackend(
api_settings.ALGORITHM,
signing_key=api_settings.SIGNING_KEY,
verifying_key=api_settings.VERIFYING_KEY,
secret_key=api_settings.JWT_SECRET_KEY,
get_user_secret_key=api_settings.GET_USER_SECRET_KEY
)
35 changes: 31 additions & 4 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
from rest_framework_simplejwt.utils import aware_utcnow, make_utc
from random import randint

SECRET = 'not_secret'

Expand Down Expand Up @@ -55,8 +56,9 @@

class TestTokenBackend(TestCase):
def setUp(self):
self.hmac_token_backend = TokenBackend('HS256', SECRET)
self.rsa_token_backend = TokenBackend('RS256', PRIVATE_KEY, PUBLIC_KEY)
self.hmac_token_backend = TokenBackend('HS256', secret_key=SECRET)
self.hmac_custom_user_key_token_backend = TokenBackend('HS256', get_user_secret_key= lambda a: a)
self.rsa_token_backend = TokenBackend('RS256', signing_key=PRIVATE_KEY, verifying_key=PUBLIC_KEY)
self.payload = {'foo': 'bar'}

def test_init(self):
Expand Down Expand Up @@ -96,6 +98,31 @@ def test_encode_rsa(self):
),
)

def test_encode_hmac_custom(self):
payload = {
'id': '1234'
}
hmac_token = self.hmac_custom_user_key_token_backend.encode(payload)
self.assertIn(
hmac_token,
(
'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6IjEyMzQifQ.RMZuO9SRBYS0pLh8DVhvknBSs80OfFvzxbl-y9b5pnc'
'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6IjEyMzQifQ.1weakZ8seqK0UKjpLrpfJaCEr9B7rJiYLGa1TR64QF4'
)
)

def test_decode_hmac_custom_changed_user(self):
token_backend = TokenBackend('HS256', get_user_secret_key= lambda a: str(a) + str(randint(0, 1000000000)))
payload = {
'exp': make_utc(datetime(year=2000, month=1, day=1)),
'id': '1234'
}

token = token_backend.encode(payload)

with self.assertRaises(TokenBackendError):
self.hmac_token_backend.decode(token)

def test_decode_hmac_with_no_expiry(self):
no_exp_token = jwt.encode(self.payload, SECRET, algorithm='HS256')

Expand Down Expand Up @@ -162,7 +189,7 @@ def test_decode_rsa_with_no_expiry_no_verify(self):
no_exp_token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')

self.assertEqual(
self.hmac_token_backend.decode(no_exp_token, verify=False),
self.rsa_token_backend.decode(no_exp_token, verify=False),
self.payload,
)

Expand Down Expand Up @@ -198,7 +225,7 @@ def test_decode_rsa_with_invalid_sig_no_verify(self):
invalid_token = token_2_payload + '.' + token_1_sig

self.assertEqual(
self.hmac_token_backend.decode(invalid_token, verify=False),
self.rsa_token_backend.decode(invalid_token, verify=False),
self.payload,
)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def test_init_bad_sig_token_given(self):
# Test backend rejects encoded token (expired or bad signature)
payload = {'foo': 'bar'}
payload['exp'] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_1 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')
payload['foo'] = 'baz'
token_2 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_2 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')

token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
Expand All @@ -112,9 +112,9 @@ def test_init_bad_sig_token_given_no_verify(self):
# Test backend rejects encoded token (expired or bad signature)
payload = {'foo': 'bar'}
payload['exp'] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_1 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')
payload['foo'] = 'baz'
token_2 = jwt.encode(payload, api_settings.SIGNING_KEY, algorithm='HS256')
token_2 = jwt.encode(payload, api_settings.JWT_SECRET_KEY, algorithm='HS256')

token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
Expand Down

0 comments on commit af3785c

Please sign in to comment.