Skip to content

Commit b78ac9a

Browse files
committed
Add support for Elliptic Curve keys
1 parent bc51f5f commit b78ac9a

File tree

2 files changed

+124
-8
lines changed

2 files changed

+124
-8
lines changed

src/josepy/jwk.py

+94-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from cryptography.hazmat.backends import default_backend
1010
from cryptography.hazmat.primitives import hashes # type: ignore
1111
from cryptography.hazmat.primitives import serialization
12-
from cryptography.hazmat.primitives.asymmetric import ec # type: ignore
12+
from cryptography.hazmat.primitives.asymmetric import ec
1313
from cryptography.hazmat.primitives.asymmetric import rsa
1414

1515
from josepy import errors, json_util, util
@@ -121,27 +121,114 @@ def load(cls, data, password=None, backend=None):
121121

122122

123123
@JWK.register
124-
class JWKES(JWK): # pragma: no cover
124+
class JWKEC(JWK): # pragma: no cover
125125
# pylint: disable=abstract-class-not-used
126-
"""ES JWK.
126+
"""EC JWK.
127127
128128
.. warning:: This is not yet implemented!
129129
130130
"""
131-
typ = 'ES'
131+
typ = 'EC'
132+
__slots__ = ('key',)
132133
cryptography_key_types = (
133134
ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
134135
required = ('crv', JWK.type_field_name, 'x', 'y')
135136

137+
def __init__(self, *args, **kwargs):
138+
if 'key' in kwargs and not isinstance(
139+
kwargs['key'], util.ComparableECKey):
140+
kwargs['key'] = util.ComparableECKey(kwargs['key'])
141+
super(JWKEC, self).__init__(*args, **kwargs)
142+
143+
@classmethod
144+
def _encode_param(cls, data):
145+
"""Encode Base64urlUInt.
146+
147+
:type data: long
148+
:rtype: unicode
149+
150+
"""
151+
def _leading_zeros(arg):
152+
if len(arg) % 2:
153+
return '0' + arg
154+
return arg
155+
156+
return json_util.encode_b64jose(binascii.unhexlify(
157+
_leading_zeros(hex(data)[2:].rstrip('L'))))
158+
159+
@classmethod
160+
def _decode_param(cls, data, name, expected_length):
161+
"""Decode Base64urlUInt."""
162+
try:
163+
binary = json_util.decode_b64jose(data)
164+
if len(binary) != expected_length:
165+
raise errors.Error(
166+
'Expected {name} to be {expected_length} bytes after base64-decoding; got {length}',
167+
name=name, expected_length=expected_length, length=len(binary))
168+
return int(binascii.hexlify(binary), 16)
169+
except ValueError: # invalid literal for long() with base 16
170+
raise errors.DeserializationError()
171+
136172
def fields_to_partial_json(self):
137-
raise NotImplementedError()
173+
params = {}
174+
if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey):
175+
public = self.key.public_numbers()
176+
elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey):
177+
private = self.key.private_numbers()
178+
public = self.key.public_key().public_numbers()
179+
params.update({
180+
'd': private.private_value,
181+
})
182+
else: raise AssertionError(
183+
"key was not an EllipticCurvePublicKey or EllipticCurvePrivateKey")
184+
185+
params.update({
186+
'x': public.x,
187+
'y': public.y,
188+
})
189+
params = dict((key, self._encode_param(value))
190+
for key, value in six.iteritems(params))
191+
params['crv'] = self._curve_name_to_crv(public.curve.name)
192+
return params
193+
194+
@classmethod
195+
def _curve_name_to_crv(cls, curve_name):
196+
if curve_name == "secp256r1": return "P-256"
197+
if curve_name == "secp384r1": return "P-384"
198+
if curve_name == "secp521r1": return "P-521"
199+
raise errors.SerializationError()
200+
201+
@classmethod
202+
def _crv_to_curve(cls, crv):
203+
# crv is case-sensitive
204+
if crv == "P-256": return ec.SECP256R1()
205+
if crv == "P-384": return ec.SECP384R1()
206+
if crv == "P-521": return ec.SECP521R1()
207+
raise errors.DeserializationError()
138208

139209
@classmethod
140210
def fields_from_json(cls, jobj):
141-
raise NotImplementedError()
211+
# pylint: disable=invalid-name
212+
curve = cls._crv_to_curve(jobj['crv'])
213+
coord_length = (curve.key_size+7)//8
214+
x, y = (cls._decode_param(jobj[n], n, coord_length) for n in ('x', 'y'))
215+
public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve)
216+
if 'd' not in jobj: # public key
217+
key = public_numbers.public_key(default_backend())
218+
else: # private key
219+
exp_length = (curve.key_size.bit_length()+7)//8
220+
d = cls._decode_param(jobj['d'], 'd', exp_length)
221+
key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key(
222+
default_backend())
223+
return cls(key=key)
142224

143225
def public_key(self):
144-
raise NotImplementedError()
226+
# Unlike RSAPrivateKey, EllipticCurvePrivateKey does not contain public_key()
227+
if hasattr(self.key, 'public_key'):
228+
key = self.key.public_key()
229+
else:
230+
key = self.key.public_numbers().public_key(default_backend())
231+
return type(self)(key=key)
145232

146233

147234
@JWK.register

src/josepy/util.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import OpenSSL
55
import six
6-
from cryptography.hazmat.primitives.asymmetric import rsa
6+
from cryptography.hazmat.backends import default_backend
7+
from cryptography.hazmat.primitives.asymmetric import ec, rsa
78

89

910
class abstractclassmethod(classmethod):
@@ -134,6 +135,34 @@ def __hash__(self):
134135
pub = self.public_numbers()
135136
return hash((self.__class__, pub.n, pub.e))
136137

138+
class ComparableECKey(ComparableKey): # pylint: disable=too-few-public-methods
139+
"""Wrapper for ``cryptography`` RSA keys.
140+
141+
Wraps around:
142+
143+
- :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey`
144+
- :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey`
145+
146+
"""
147+
148+
def __hash__(self):
149+
# public_numbers() hasn't got stable hash!
150+
# https://github.com/pyca/cryptography/issues/2143
151+
if isinstance(self._wrapped, ec.EllipticCurvePrivateKeyWithSerialization):
152+
priv = self.private_numbers()
153+
pub = priv.public_numbers
154+
return hash((self.__class__, pub.curve.name, pub.x, pub.y, priv.d))
155+
elif isinstance(self._wrapped, ec.EllipticCurvePublicKeyWithSerialization):
156+
pub = self.public_numbers()
157+
return hash((self.__class__, pub.curve.name, pub.x, pub.y))
158+
def public_key(self):
159+
"""Get wrapped public key."""
160+
# Unlike RSAPrivateKey, EllipticCurvePrivateKey does not have public_key()
161+
if hasattr(self._wrapped, 'public_key'):
162+
key = self._wrapped.public_key()
163+
else:
164+
key = self._wrapped.public_numbers().public_key(default_backend())
165+
return self.__class__(key)
137166

138167
class ImmutableMap(collections.Mapping, collections.Hashable): # type: ignore
139168
# pylint: disable=too-few-public-methods

0 commit comments

Comments
 (0)