Skip to content

Commit

Permalink
move b64 out of helper
Browse files Browse the repository at this point in the history
  • Loading branch information
DmitriyMusatkin committed Nov 19, 2024
1 parent ffbfe51 commit 1b36706
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 48 deletions.
4 changes: 2 additions & 2 deletions awscrt/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def new_public_key_from_pem_data(pem_data: Union[str, bytes, bytearray, memoryvi
return RSA(binding=_awscrt.rsa_public_key_from_pem_data(pem_data))

@staticmethod
def new_private_key_from_der_data(der_data: Union[str, bytes, bytearray, memoryview]) -> 'RSA':
def new_private_key_from_der_data(der_data: Union[bytes, bytearray, memoryview]) -> 'RSA':
"""
Creates a new instance of private RSA key pair from der data.
Expects key in PKCS1 format.
Expand All @@ -133,7 +133,7 @@ def new_private_key_from_der_data(der_data: Union[str, bytes, bytearray, memoryv
return RSA(binding=_awscrt.rsa_private_key_from_der_data(der_data))

@staticmethod
def new_public_key_from_der_data(der_data: Union[str, bytes, bytearray, memoryview]) -> 'RSA':
def new_public_key_from_der_data(der_data: Union[bytes, bytearray, memoryview]) -> 'RSA':
"""
Creates a new instance of public RSA key pair from der data.
Expects key in PKCS1 format.
Expand Down
46 changes: 4 additions & 42 deletions source/crypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -355,32 +355,14 @@ PyObject *aws_py_rsa_private_key_from_der_data(PyObject *self, PyObject *args) {
(void)self;

struct aws_byte_cursor der_data_cur;
if (!PyArg_ParseTuple(args, "s#", &der_data_cur.ptr, &der_data_cur.len)) {
if (!PyArg_ParseTuple(args, "y#", &der_data_cur.ptr, &der_data_cur.len)) {
return NULL;
}

PyObject *capsule = NULL;
struct aws_allocator *allocator = aws_py_get_allocator();

struct aws_byte_buf decoded_buffer;
AWS_ZERO_STRUCT(decoded_buffer);

size_t decoded_len = 0;
if (aws_base64_compute_decoded_len(&der_data_cur, &decoded_len)) {
PyErr_AwsLastError();
goto on_done;
}

aws_byte_buf_init(&decoded_buffer, allocator, decoded_len);

if (aws_base64_decode(&der_data_cur, &decoded_buffer)) {
PyErr_AwsLastError();
goto on_done;
}

struct aws_byte_cursor raw_der = aws_byte_cursor_from_buf(&decoded_buffer);

struct aws_rsa_key_pair *key_pair = aws_rsa_key_pair_new_from_private_key_pkcs1(allocator, raw_der);
struct aws_rsa_key_pair *key_pair = aws_rsa_key_pair_new_from_private_key_pkcs1(allocator, der_data_cur);

if (key_pair == NULL) {
PyErr_AwsLastError();
Expand All @@ -394,40 +376,21 @@ PyObject *aws_py_rsa_private_key_from_der_data(PyObject *self, PyObject *args) {
}

on_done:
aws_byte_buf_clean_up_secure(&decoded_buffer);
return capsule;
}

PyObject *aws_py_rsa_public_key_from_der_data(PyObject *self, PyObject *args) {
(void)self;

struct aws_byte_cursor der_data_cur;
if (!PyArg_ParseTuple(args, "s#", &der_data_cur.ptr, &der_data_cur.len)) {
if (!PyArg_ParseTuple(args, "y#", &der_data_cur.ptr, &der_data_cur.len)) {
return NULL;
}

PyObject *capsule = NULL;
struct aws_allocator *allocator = aws_py_get_allocator();

struct aws_byte_buf decoded_buffer;
AWS_ZERO_STRUCT(decoded_buffer);

size_t decoded_len = 0;
if (aws_base64_compute_decoded_len(&der_data_cur, &decoded_len)) {
PyErr_AwsLastError();
goto on_done;
}

aws_byte_buf_init(&decoded_buffer, allocator, decoded_len);

if (aws_base64_decode(&der_data_cur, &decoded_buffer)) {
PyErr_AwsLastError();
goto on_done;
}

struct aws_byte_cursor raw_der = aws_byte_cursor_from_buf(&decoded_buffer);

struct aws_rsa_key_pair *key_pair = aws_rsa_key_pair_new_from_public_key_pkcs1(allocator, raw_der);
struct aws_rsa_key_pair *key_pair = aws_rsa_key_pair_new_from_public_key_pkcs1(allocator, der_data_cur);

if (key_pair == NULL) {
PyErr_AwsLastError();
Expand All @@ -441,7 +404,6 @@ PyObject *aws_py_rsa_public_key_from_der_data(PyObject *self, PyObject *args) {
}

on_done:
aws_byte_buf_clean_up_secure(&decoded_buffer);
return capsule;
}

Expand Down
13 changes: 9 additions & 4 deletions test/test_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from test import NativeResourceTest
from awscrt.crypto import Hash, RSA, RSAEncryptionAlgorithm, RSASignatureAlgorithm
import base64
import unittest

RSA_PRIVATE_KEY_PEM = """
Expand Down Expand Up @@ -177,12 +178,14 @@ def test_rsa_encryption_roundtrip_der(self):
for p in param_list:
with self.subTest(msg="RSA Encryption Roundtrip using algo p", p=p):
test_pt = b'totally original test string'
rsa = RSA.new_private_key_from_der_data(RSA_PRIVATE_KEY_DER)
decoded_private_key = base64.b64decode(RSA_PRIVATE_KEY_DER)
rsa = RSA.new_private_key_from_der_data(decoded_private_key)
ct = rsa.encrypt(p, test_pt)
pt = rsa.decrypt(p, ct)
self.assertEqual(test_pt, pt)

rsa_pub = RSA.new_public_key_from_der_data(RSA_PUBLIC_KEY_DER)
decoded_public_key = base64.b64decode(RSA_PUBLIC_KEY_DER)
rsa_pub = RSA.new_public_key_from_der_data(decoded_public_key)
ct_pub = rsa_pub.encrypt(p, test_pt)
pt_pub = rsa.decrypt(p, ct_pub)
self.assertEqual(test_pt, pt_pub)
Expand Down Expand Up @@ -222,11 +225,13 @@ def test_rsa_signing_roundtrip_der(self):
h.update(b'totally original test string')
digest = h.digest()

rsa = RSA.new_private_key_from_der_data(RSA_PRIVATE_KEY_DER)
decoded_private_key = base64.b64decode(RSA_PRIVATE_KEY_DER)
rsa = RSA.new_private_key_from_der_data(decoded_private_key)
signature = rsa.sign(p, digest)
self.assertTrue(rsa.verify(p, digest, signature))

rsa_pub = RSA.new_public_key_from_der_data(RSA_PUBLIC_KEY_DER)
decoded_private_key = base64.b64decode(RSA_PUBLIC_KEY_DER)
rsa_pub = RSA.new_public_key_from_der_data(decoded_private_key)
self.assertTrue(rsa_pub.verify(p, digest, signature))

def test_rsa_load_error(self):
Expand Down

0 comments on commit 1b36706

Please sign in to comment.