Skip to content

Commit

Permalink
Make: Remove mbedtls (#97)
Browse files Browse the repository at this point in the history
Co-authored-by: Ash Vardanian <[email protected]>
  • Loading branch information
MarkReedZ and ashvardanian authored May 12, 2024
1 parent 3737d94 commit 611921c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 160 deletions.
20 changes: 3 additions & 17 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ include_directories(${cxxopts_SOURCE_DIR}/include)
# export CPATH=/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk/usr/include/
FetchContent_Declare(
picohttpparser
GIT_REPOSITORY https://github.com/unum-cloud/picohttpparser.git
#GIT_REPOSITORY https://github.com/unum-cloud/picohttpparser.git
GIT_REPOSITORY https://github.com/MarkReedZ/picohttpparser.git
GIT_SHALLOW 1
)
FetchContent_MakeAvailable(picohttpparser)
Expand All @@ -123,21 +124,6 @@ FetchContent_MakeAvailable(tb64)
include_directories(${tb64_SOURCE_DIR})
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

FetchContent_Declare(
mbedtls
GIT_REPOSITORY https://github.com/Mbed-TLS/mbedtls/
GIT_TAG v3.4.0
CMAKE_ARGS
-DENABLE_PROGRAMS=OFF
-DENABLE_TESTING=OFF
-DUSE_SHARED_MBEDTLS_LIBRARY=OFF
-DUSE_STATIC_MBEDTLS_LIBRARY=ON
)

FetchContent_MakeAvailable(mbedtls)
include_directories(${mbedtls_SOURCE_DIR}/include)
set(mbedtls_LIBS mbedtls mbedcrypto mbedx509)

# LibUring
if(LINUX)
set(URING_DIR ${CMAKE_BINARY_DIR}/_deps/liburing-ep)
Expand Down Expand Up @@ -168,7 +154,7 @@ find_package(Threads REQUIRED)
include_directories(include/ src/)

add_library(ucall_server_posix src/engine_posix.cpp)
target_link_libraries(ucall_server_posix simdjson::simdjson Threads::Threads ${mbedtls_LIBS})
target_link_libraries(ucall_server_posix simdjson::simdjson Threads::Threads )
set(PYTHON_BACKEND ucall_server_posix)

add_executable(ucall_example_login_posix examples/login/ucall_server.cpp)
Expand Down
150 changes: 7 additions & 143 deletions src/engine_posix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@
#include <charconv> // `std::to_chars`
#include <chrono> // `std::chrono`

#include "mbedtls/config.h"
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <mbedtls/net_sockets.h>
#include <mbedtls/ssl.h>
#include <mbedtls/ssl_cache.h>

#include "ucall/ucall.h"

#include "helpers/log.hpp"
Expand All @@ -59,80 +52,11 @@ using time_point_t = std::chrono::time_point<time_clock_t>;

static constexpr std::size_t initial_buffer_size_k = ram_page_size_k * 4;

struct ucall_ssl_context_t {

~ucall_ssl_context_t() noexcept {
mbedtls_x509_crt_free(&srvcert);
mbedtls_pk_free(&pkey);
mbedtls_ssl_free(&ssl);
mbedtls_ssl_config_free(&conf);
mbedtls_ssl_cache_free(&cache);
mbedtls_ctr_drbg_free(&ctr_drbg);
mbedtls_entropy_free(&entropy);
}

int init(const char* pk_path, const char** crts_path, size_t crts_cnt) {
mbedtls_ssl_init(&ssl);
mbedtls_ssl_config_init(&conf);
mbedtls_ssl_cache_init(&cache);
mbedtls_x509_crt_init(&srvcert);
mbedtls_pk_init(&pkey);
mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_init(&ctr_drbg);
int ret = 0;

// Seed the RNG
if ((ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0)
// TODO Use personalization string. Required or Optional ?
return ret;

// Load Private Key
if ((ret = mbedtls_pk_parse_keyfile(&pkey, pk_path, NULL, NULL, &ctr_drbg)) != 0)
// TODO Use Password. Required or Optional ?
return ret;

// Load Certificates
for (size_t i = 0; i < crts_cnt; ++i)
if ((ret = mbedtls_x509_crt_parse_file(&srvcert, crts_path[i])) != 0)
// TODO Notify which certificate was invalid ?
return ret;

if ((ret = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
return ret;

mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg);

mbedtls_ssl_conf_session_cache(&conf, &cache, mbedtls_ssl_cache_get, mbedtls_ssl_cache_set);
mbedtls_ssl_conf_renegotiation(&conf, MBEDTLS_SSL_RENEGOTIATION_DISABLED);

mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL);
if ((ret = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
return ret;

if ((ret = mbedtls_ssl_setup(&ssl, &conf)) != 0)
return ret;

return 0;
}

mbedtls_ssl_context ssl{};
mbedtls_ssl_config conf{};
mbedtls_pk_context pkey{};
mbedtls_x509_crt srvcert{};
mbedtls_entropy_context entropy{};
mbedtls_ssl_cache_context cache{};
mbedtls_ctr_drbg_context ctr_drbg{};
};

struct engine_t {
~engine_t() noexcept { delete ssl_ctx; }
~engine_t() noexcept { }

descriptor_t socket{};

/// @brief Establishes an SSL connection if SSL is enabled, otherwise the `ssl_ctx` is unused and uninitialized.
ucall_ssl_context_t* ssl_ctx = nullptr;

/// @brief The file descriptor of the stateful connection over TCP.
descriptor_t connection{};
/// @brief A small memory buffer to store small requests.
Expand Down Expand Up @@ -169,13 +93,8 @@ void send_message(engine_t& engine, array_gt<char> const& message) noexcept {
long idx = 0;
long res = 0;

if (engine.ssl_ctx)
while (idx < len && (res = mbedtls_ssl_write(&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t const*>(buf + idx),
(len - idx))) > 0)
idx += res;
else
while (idx < len && (res = send(engine.connection, buf + idx, len - idx, 0)) > 0)
idx += res;
while (idx < len && (res = send(engine.connection, buf + idx, len - idx, 0)) > 0)
idx += res;

if (res < 0) {
if (errno == EMSGSIZE)
Expand Down Expand Up @@ -264,29 +183,12 @@ void forward_packet(engine_t& engine) noexcept {
return forward_call_or_calls(engine);
}

int ssl_send(void* ctx, const unsigned char* buf, size_t len) {
mbedtls_net_context* conn = reinterpret_cast<mbedtls_net_context*>(ctx);
ssize_t ret = send(conn->fd, reinterpret_cast<char const*>(buf), len, 0);
return ret;
}

int ssl_recv(void* ctx, unsigned char* buf, size_t len) {
mbedtls_net_context* conn = reinterpret_cast<mbedtls_net_context*>(ctx);
ssize_t ret = recv(conn->fd, reinterpret_cast<char*>(buf), len, 0);
return ret;
}

int recv_all(engine_t& engine, char* buf, size_t len) {
size_t idx = 0;
int res = 0;

if (engine.ssl_ctx)
while (idx < len &&
(res = mbedtls_ssl_read(&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t*>(buf + idx), (len - idx))) > 0)
idx += res;
else
while (idx < len && (res = recv(engine.connection, buf + idx, len - idx, 0)) > 0)
idx += res;
while (idx < len && (res = recv(engine.connection, buf + idx, len - idx, 0)) > 0)
idx += res;

return idx;
}
Expand Down Expand Up @@ -320,32 +222,14 @@ void ucall_take_call(ucall_server_t server, uint16_t) {
return;
}

mbedtls_net_context client_ctx;

if (engine.ssl_ctx) {
client_ctx.fd = connection_fd;
mbedtls_ssl_set_bio(&engine.ssl_ctx->ssl, &client_ctx, ssl_send, ssl_recv, NULL);
int ret = 0;
while ((ret = mbedtls_ssl_handshake(&engine.ssl_ctx->ssl)) != 0)
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
mbedtls_net_free(&client_ctx);
mbedtls_ssl_session_reset(&engine.ssl_ctx->ssl);
return;
}
}

// Wait until we have input.
engine.connection = descriptor_t{connection_fd};
engine.stats.added_connections++;
engine.stats.closed_connections++;
char* buffer_ptr = &engine.packet_buffer[0];

size_t bytes_received = 0, bytes_expected = 0;
if (engine.ssl_ctx)
bytes_received =
mbedtls_ssl_read(&engine.ssl_ctx->ssl, reinterpret_cast<uint8_t*>(buffer_ptr), http_head_max_size_k);
else
bytes_received = recv(engine.connection, buffer_ptr, http_head_max_size_k, 0);
bytes_received = recv(engine.connection, buffer_ptr, http_head_max_size_k, 0);

auto json_or_error = split_body_headers(std::string_view(buffer_ptr, bytes_received));
if (auto error_ptr = std::get_if<default_error_t>(&json_or_error); error_ptr)
Expand Down Expand Up @@ -401,14 +285,6 @@ void ucall_take_call(ucall_server_t server, uint16_t) {
buffer_ptr = nullptr;
}

if (engine.ssl_ctx) {
int ret = 0;
while ((ret = mbedtls_ssl_close_notify(&engine.ssl_ctx->ssl)) < 0)
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
break;

mbedtls_ssl_session_reset(&engine.ssl_ctx->ssl);
}
shutdown(connection_fd, SHUT_WR);
// If later on some UB is detected for client not recieving full data,
// then it may be required to put a `recv` with timeout between `shutdown` and `close`
Expand All @@ -431,9 +307,6 @@ void ucall_init(ucall_config_t* config_inout, ucall_server_t* server_out) {
config.max_callbacks = 128u;
if (!config.hostname)
config.hostname = "0.0.0.0";
if (config.use_ssl &&
!(config.ssl_private_key_path || config.ssl_certificates_paths || config.ssl_certificates_count))
return;

// Some limitations are hard-coded for this non-concurrent implementation
config.max_threads = 1u;
Expand All @@ -447,7 +320,6 @@ void ucall_init(ucall_config_t* config_inout, ucall_server_t* server_out) {
engine_t* server_ptr = nullptr;
array_gt<char> buffer;
array_gt<named_callback_t> embedded_callbacks;
ucall_ssl_context_t* ssl_context = nullptr;
sjd::parser parser;

// By default, let's open TCP port for IPv4.
Expand Down Expand Up @@ -475,12 +347,6 @@ void ucall_init(ucall_config_t* config_inout, ucall_server_t* server_out) {
goto cleanup;
if (listen(socket_descriptor, config.queue_depth) < 0)
goto cleanup;
if (config.use_ssl) {
ssl_context = new ucall_ssl_context_t();
if (ssl_context->init(config.ssl_private_key_path, config.ssl_certificates_paths,
config.ssl_certificates_count) != 0)
goto cleanup;
}
if (parser.allocate(ram_page_size_k, ram_page_size_k / 2) != sj::SUCCESS)
goto cleanup;

Expand All @@ -493,7 +359,6 @@ void ucall_init(ucall_config_t* config_inout, ucall_server_t* server_out) {
server_ptr->logs_file_descriptor = config.logs_file_descriptor;
server_ptr->logs_format = config.logs_format ? std::string_view(config.logs_format) : std::string_view();
server_ptr->log_last_time = time_clock_t::now();
server_ptr->ssl_ctx = ssl_context;
*server_out = (ucall_server_t)server_ptr;
return;

Expand All @@ -503,7 +368,6 @@ void ucall_init(ucall_config_t* config_inout, ucall_server_t* server_out) {
close(socket_descriptor);
std::free(server_ptr);
*server_out = nullptr;
delete ssl_context;
}

void ucall_add_procedure(ucall_server_t server, ucall_str_t name, ucall_callback_t callback,
Expand Down Expand Up @@ -705,4 +569,4 @@ bool ucall_param_positional_str(ucall_call_t call, size_t position, ucall_str_t*
return true;
} else
return false;
}
}

0 comments on commit 611921c

Please sign in to comment.