|
9 | 9 | from cryptography.hazmat.backends import default_backend
|
10 | 10 | from cryptography.hazmat.primitives import hashes # type: ignore
|
11 | 11 | from cryptography.hazmat.primitives import serialization
|
12 |
| -from cryptography.hazmat.primitives.asymmetric import ec # type: ignore |
| 12 | +from cryptography.hazmat.primitives.asymmetric import ec |
13 | 13 | from cryptography.hazmat.primitives.asymmetric import rsa
|
14 | 14 |
|
15 | 15 | from josepy import errors, json_util, util
|
@@ -121,27 +121,114 @@ def load(cls, data, password=None, backend=None):
|
121 | 121 |
|
122 | 122 |
|
123 | 123 | @JWK.register
|
124 |
| -class JWKES(JWK): # pragma: no cover |
| 124 | +class JWKEC(JWK): # pragma: no cover |
125 | 125 | # pylint: disable=abstract-class-not-used
|
126 |
| - """ES JWK. |
| 126 | + """EC JWK. |
127 | 127 |
|
128 | 128 | .. warning:: This is not yet implemented!
|
129 | 129 |
|
130 | 130 | """
|
131 |
| - typ = 'ES' |
| 131 | + typ = 'EC' |
| 132 | + __slots__ = ('key',) |
132 | 133 | cryptography_key_types = (
|
133 | 134 | ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
|
134 | 135 | required = ('crv', JWK.type_field_name, 'x', 'y')
|
135 | 136 |
|
| 137 | + def __init__(self, *args, **kwargs): |
| 138 | + if 'key' in kwargs and not isinstance( |
| 139 | + kwargs['key'], util.ComparableECKey): |
| 140 | + kwargs['key'] = util.ComparableECKey(kwargs['key']) |
| 141 | + super(JWKEC, self).__init__(*args, **kwargs) |
| 142 | + |
| 143 | + @classmethod |
| 144 | + def _encode_param(cls, data): |
| 145 | + """Encode Base64urlUInt. |
| 146 | +
|
| 147 | + :type data: long |
| 148 | + :rtype: unicode |
| 149 | +
|
| 150 | + """ |
| 151 | + def _leading_zeros(arg): |
| 152 | + if len(arg) % 2: |
| 153 | + return '0' + arg |
| 154 | + return arg |
| 155 | + |
| 156 | + return json_util.encode_b64jose(binascii.unhexlify( |
| 157 | + _leading_zeros(hex(data)[2:].rstrip('L')))) |
| 158 | + |
| 159 | + @classmethod |
| 160 | + def _decode_param(cls, data, name, expected_length): |
| 161 | + """Decode Base64urlUInt.""" |
| 162 | + try: |
| 163 | + binary = json_util.decode_b64jose(data) |
| 164 | + if len(binary) != expected_length: |
| 165 | + raise errors.Error( |
| 166 | + 'Expected {name} to be {expected_length} bytes after base64-decoding; got {length}', |
| 167 | + name=name, expected_length=expected_length, length=len(binary)) |
| 168 | + return int(binascii.hexlify(binary), 16) |
| 169 | + except ValueError: # invalid literal for long() with base 16 |
| 170 | + raise errors.DeserializationError() |
| 171 | + |
136 | 172 | def fields_to_partial_json(self):
|
137 |
| - raise NotImplementedError() |
| 173 | + params = {} |
| 174 | + if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey): |
| 175 | + public = self.key.public_numbers() |
| 176 | + elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey): |
| 177 | + private = self.key.private_numbers() |
| 178 | + public = self.key.public_key().public_numbers() |
| 179 | + params.update({ |
| 180 | + 'd': private.private_value, |
| 181 | + }) |
| 182 | + else: raise AssertionError( |
| 183 | + "key was not an EllipticCurvePublicKey or EllipticCurvePrivateKey") |
| 184 | + |
| 185 | + params.update({ |
| 186 | + 'x': public.x, |
| 187 | + 'y': public.y, |
| 188 | + }) |
| 189 | + params = dict((key, self._encode_param(value)) |
| 190 | + for key, value in six.iteritems(params)) |
| 191 | + params['crv'] = self._curve_name_to_crv(public.curve.name) |
| 192 | + return params |
| 193 | + |
| 194 | + @classmethod |
| 195 | + def _curve_name_to_crv(cls, curve_name): |
| 196 | + if curve_name == "secp256r1": return "P-256" |
| 197 | + if curve_name == "secp384r1": return "P-384" |
| 198 | + if curve_name == "secp521r1": return "P-521" |
| 199 | + raise errors.SerializationError() |
| 200 | + |
| 201 | + @classmethod |
| 202 | + def _crv_to_curve(cls, crv): |
| 203 | + # crv is case-sensitive |
| 204 | + if crv == "P-256": return ec.SECP256R1() |
| 205 | + if crv == "P-384": return ec.SECP384R1() |
| 206 | + if crv == "P-521": return ec.SECP521R1() |
| 207 | + raise errors.DeserializationError() |
138 | 208 |
|
139 | 209 | @classmethod
|
140 | 210 | def fields_from_json(cls, jobj):
|
141 |
| - raise NotImplementedError() |
| 211 | + # pylint: disable=invalid-name |
| 212 | + curve = cls._crv_to_curve(jobj['crv']) |
| 213 | + coord_length = (curve.key_size+7)//8 |
| 214 | + x, y = (cls._decode_param(jobj[n], n, coord_length) for n in ('x', 'y')) |
| 215 | + public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve) |
| 216 | + if 'd' not in jobj: # public key |
| 217 | + key = public_numbers.public_key(default_backend()) |
| 218 | + else: # private key |
| 219 | + exp_length = (curve.key_size.bit_length()+7)//8 |
| 220 | + d = cls._decode_param(jobj['d'], 'd', exp_length) |
| 221 | + key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key( |
| 222 | + default_backend()) |
| 223 | + return cls(key=key) |
142 | 224 |
|
143 | 225 | def public_key(self):
|
144 |
| - raise NotImplementedError() |
| 226 | + # Unlike RSAPrivateKey, EllipticCurvePrivateKey does not contain public_key() |
| 227 | + if hasattr(self.key, 'public_key'): |
| 228 | + key = self.key.public_key() |
| 229 | + else: |
| 230 | + key = self.key.public_numbers().public_key(default_backend()) |
| 231 | + return type(self)(key=key) |
145 | 232 |
|
146 | 233 |
|
147 | 234 | @JWK.register
|
|
0 commit comments