Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MONGOCRYPT-769 Implement changes from OST-v12 and v13 #952

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 34 additions & 32 deletions src/mc-text-search-str-encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,26 @@

// 16MiB - maximum length in bytes of a string to be encoded.
#define MAX_ENCODE_BYTE_LEN 16777216
// Number of bytes which are added to the base string before encryption.
#define OVERHEAD_BYTES 5

static mc_affix_set_t *generate_prefix_or_suffix_tree(const mc_utf8_string_with_bad_char_t *base_str,
uint32_t unfolded_codepoint_len,
uint32_t unfolded_byte_len,
uint32_t lb,
uint32_t ub,
bool is_prefix) {
BSON_ASSERT_PARAM(base_str);
// 16 * ceil(unfolded codepoint len / 16)
uint32_t cbclen = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
if (cbclen < lb) {
// We encrypt (unfolded string + 5 bytes of extra BSON info) with a 16-byte block cipher.
uint32_t encrypted_len = 16 * (uint32_t)((unfolded_byte_len + OVERHEAD_BYTES + 15) / 16);
// Max len of a string that has this encrypted len.
uint32_t padded_len = encrypted_len - OVERHEAD_BYTES;
if (padded_len < lb) {
// No valid substrings, return empty tree
return NULL;
}

// Total number of substrings
uint32_t msize = BSON_MIN(cbclen, ub) - lb + 1;
uint32_t msize = BSON_MIN(padded_len, ub) - lb + 1;
uint32_t folded_codepoint_len = base_str->codepoint_len - 1; // remove one codepoint for 0xFF
uint32_t real_max_len = BSON_MIN(folded_codepoint_len, ub);
// Number of actual substrings, excluding padding
Expand Down Expand Up @@ -67,19 +71,19 @@ static mc_affix_set_t *generate_prefix_or_suffix_tree(const mc_utf8_string_with_
}

static mc_affix_set_t *generate_suffix_tree(const mc_utf8_string_with_bad_char_t *base_str,
uint32_t unfolded_codepoint_len,
uint32_t unfolded_byte_len,
const mc_FLE2SuffixInsertSpec_t *spec) {
BSON_ASSERT_PARAM(base_str);
BSON_ASSERT_PARAM(spec);
return generate_prefix_or_suffix_tree(base_str, unfolded_codepoint_len, spec->lb, spec->ub, false);
return generate_prefix_or_suffix_tree(base_str, unfolded_byte_len, spec->lb, spec->ub, false);
}

static mc_affix_set_t *generate_prefix_tree(const mc_utf8_string_with_bad_char_t *base_str,
uint32_t unfolded_codepoint_len,
uint32_t unfolded_byte_len,
const mc_FLE2PrefixInsertSpec_t *spec) {
BSON_ASSERT_PARAM(base_str);
BSON_ASSERT_PARAM(spec);
return generate_prefix_or_suffix_tree(base_str, unfolded_codepoint_len, spec->lb, spec->ub, true);
return generate_prefix_or_suffix_tree(base_str, unfolded_byte_len, spec->lb, spec->ub, true);
}

static uint32_t calc_number_of_substrings(uint32_t strlen, uint32_t lb, uint32_t ub) {
Expand All @@ -97,13 +101,15 @@ static uint32_t calc_number_of_substrings(uint32_t strlen, uint32_t lb, uint32_t
}

static mc_substring_set_t *generate_substring_tree(const mc_utf8_string_with_bad_char_t *base_str,
uint32_t unfolded_codepoint_len,
uint32_t unfolded_byte_len,
const mc_FLE2SubstringInsertSpec_t *spec) {
BSON_ASSERT_PARAM(base_str);
BSON_ASSERT_PARAM(spec);
// 16 * ceil(unfolded len / 16)
uint32_t cbclen = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
if (unfolded_codepoint_len > spec->mlen || cbclen < spec->lb) {
// We encrypt (unfolded string + 5 bytes of extra BSON info) with a 16-byte block cipher.
uint32_t encrypted_len = 16 * (uint32_t)((unfolded_byte_len + OVERHEAD_BYTES + 15) / 16);
// Max len of a string that has this encrypted len.
uint32_t padded_len = encrypted_len - OVERHEAD_BYTES;
if (padded_len < spec->lb) {
// No valid substrings, return empty tree
return NULL;
}
Expand All @@ -112,30 +118,30 @@ static mc_substring_set_t *generate_substring_tree(const mc_utf8_string_with_bad
// justifies why that calculation and this calculation are equivalent.
// At this point, it is established that:
// beta <= mlen
// lb <= cbclen
// lb <= padded_len
// lb <= ub <= mlen
//
// So, the following formula for msize in the OST paper:
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1))
// maxkgram_2 = sum_(j=lb, min(ub, cbclen), (cbclen - j + 1))
// maxkgram_2 = sum_(j=lb, min(ub, padded_len), (padded_len - j + 1))
// msize = min(maxkgram_1, maxkgram_2)
// can be simplified to:
// msize = sum_(j=lb, min(ub, cbclen), (min(mlen, cbclen) - j + 1))
// msize = sum_(j=lb, min(ub, padded_len), (min(mlen, padded_len) - j + 1))
//
// because if cbclen <= ub, then it follows that cbclen <= ub <= mlen, and so
// because if padded_len <= ub, then it follows that padded_len <= ub <= mlen, and so
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1)) # as above
// maxkgram_2 = sum_(j=lb, cbclen, (cbclen - j + 1)) # less or equal to maxkgram_1
// maxkgram_2 = sum_(j=lb, padded_len, (padded_len - j + 1)) # less or equal to maxkgram_1
// msize = maxkgram_2
// and if cbclen > ub, then it follows that:
// and if padded_len > ub, then it follows that:
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1)) # as above
// maxkgram_2 = sum_(j=lb, ub, (cbclen - j + 1)) # same sum bounds as maxkgram_1
// msize = sum_(j=lb, ub, (min(mlen, cbclen) - j + 1))
// maxkgram_2 = sum_(j=lb, ub, (padded_len - j + 1)) # same sum bounds as maxkgram_1
// msize = sum_(j=lb, ub, (min(mlen, padded_len) - j + 1))
// in both cases, msize can be rewritten as:
// msize = sum_(j=lb, min(ub, cbclen), (min(mlen, cbclen) - j + 1))
// msize = sum_(j=lb, min(ub, padded_len), (min(mlen, padded_len) - j + 1))

uint32_t folded_codepoint_len = base_str->codepoint_len - 1;
// If mlen < cbclen, we only need to pad to mlen
uint32_t padded_len = BSON_MIN(spec->mlen, cbclen);
// If mlen < padded_len, we only need to pad to mlen
padded_len = BSON_MIN(spec->mlen, padded_len);
// Total number of substrings -- i.e. the number of valid substrings IF the string spanned the full padded length
uint32_t msize = calc_number_of_substrings(padded_len, spec->lb, spec->ub);
uint32_t n_real_substrings = 0;
Expand Down Expand Up @@ -185,11 +191,6 @@ mc_str_encode_sets_t *mc_text_search_str_encode(const mc_FLE2TextSearchInsertSpe
CLIENT_ERR("StrEncode: String passed in was not valid UTF-8");
return NULL;
}
uint32_t unfolded_codepoint_len = mc_get_utf8_codepoint_length(spec->v, spec->len);
if (unfolded_codepoint_len == 0) {
// Empty string: We set unfolded length to 1 so that we generate fake tokens.
unfolded_codepoint_len = 1;
}

mc_utf8_string_with_bad_char_t *base_string;
if (spec->casef || spec->diacf) {
Expand All @@ -213,12 +214,13 @@ mc_str_encode_sets_t *mc_text_search_str_encode(const mc_FLE2TextSearchInsertSpe
// Base string is the folded string plus the 0xFF character
sets->base_string = base_string;
if (spec->suffix.set) {
sets->suffix_set = generate_suffix_tree(sets->base_string, unfolded_codepoint_len, &spec->suffix.value);
sets->suffix_set = generate_suffix_tree(sets->base_string, spec->len, &spec->suffix.value);
}
if (spec->prefix.set) {
sets->prefix_set = generate_prefix_tree(sets->base_string, unfolded_codepoint_len, &spec->prefix.value);
sets->prefix_set = generate_prefix_tree(sets->base_string, spec->len, &spec->prefix.value);
}
if (spec->substr.set) {
uint32_t unfolded_codepoint_len = mc_get_utf8_codepoint_length(spec->v, spec->len);
if (unfolded_codepoint_len > spec->substr.value.mlen) {
CLIENT_ERR("StrEncode: String passed in was longer than the maximum length for substring indexing -- "
"String len: %u, max len: %u",
Expand All @@ -227,7 +229,7 @@ mc_str_encode_sets_t *mc_text_search_str_encode(const mc_FLE2TextSearchInsertSpe
mc_str_encode_sets_destroy(sets);
return NULL;
}
sets->substring_set = generate_substring_tree(sets->base_string, unfolded_codepoint_len, &spec->substr.value);
sets->substring_set = generate_substring_tree(sets->base_string, spec->len, &spec->substr.value);
}
// Exact string is always equal to the base string up until the bad character
_mongocrypt_buffer_from_data(&sets->exact, sets->base_string->buf.data, (uint32_t)sets->base_string->buf.len - 1);
Expand Down
50 changes: 37 additions & 13 deletions test/test-mc-text-search-str-encode.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
uint32_t byte_len = (uint32_t)strlen(str);
uint32_t unfolded_codepoint_len = byte_len == 0 ? 1 : get_utf8_codepoint_length(str, byte_len);
uint32_t folded_codepoint_len = byte_len == 0 ? 0 : unfolded_codepoint_len - foldable_codepoints;
uint32_t max_padded_len = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
uint32_t padded_len = 16 * (uint32_t)((byte_len + 5 + 15) / 16) - 5;
uint32_t max_affix_len = BSON_MIN(ub, folded_codepoint_len);
uint32_t n_real_affixes = max_affix_len >= lb ? max_affix_len - lb + 1 : 0;
uint32_t n_affixes = BSON_MIN(ub, max_padded_len) - lb + 1;
uint32_t n_affixes = BSON_MIN(ub, padded_len) - lb + 1;
uint32_t n_padding = n_affixes - n_real_affixes;

mc_str_encode_sets_t *sets;
Expand Down Expand Up @@ -86,7 +86,7 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
ASSERT_CMPUINT32(sets->exact.len, ==, sets->base_string->buf.len - 1);
ASSERT_CMPINT(0, ==, memcmp(sets->exact.data, sets->base_string->buf.data, sets->exact.len));

if (lb > max_padded_len) {
if (lb > padded_len) {
ASSERT(sets->suffix_set == NULL);
ASSERT(sets->prefix_set == NULL);
goto CONTINUE;
Expand Down Expand Up @@ -230,8 +230,8 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
uint32_t byte_len = (uint32_t)strlen(str);
uint32_t unfolded_codepoint_len = byte_len == 0 ? 1 : get_utf8_codepoint_length(str, byte_len);
uint32_t folded_codepoint_len = byte_len == 0 ? 0 : unfolded_codepoint_len - foldable_codepoints;
uint32_t max_padded_len = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
uint32_t n_substrings = calc_number_of_substrings(BSON_MIN(max_padded_len, mlen), lb, ub);
uint32_t padded_len = 16 * (uint32_t)((byte_len + 5 + 15) / 16) - 5;
uint32_t n_substrings = calc_number_of_substrings(BSON_MIN(padded_len, mlen), lb, ub);

mongocrypt_status_t *status = mongocrypt_status_new();
mc_str_encode_sets_t *sets;
Expand Down Expand Up @@ -260,7 +260,7 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
ASSERT_CMPUINT32(sets->exact.len, ==, sets->base_string->buf.len - 1);
ASSERT_CMPINT(0, ==, memcmp(sets->exact.data, sets->base_string->buf.data, sets->base_string->buf.len - 1));

if (lb > max_padded_len) {
if (lb > padded_len) {
ASSERT(sets->substring_set == NULL);
goto cleanup;
} else {
Expand Down Expand Up @@ -325,17 +325,39 @@ static void test_nofold_substring_case_multiple_mlen(_mongocrypt_tester_t *teste
bool casef,
bool diacf,
int foldable_codepoints) {
// mlen < unfolded_codepoint_len
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len - 1, casef, diacf, foldable_codepoints);
if (unfolded_codepoint_len > 1) {
// mlen < unfolded_codepoint_len
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len - 1, casef, diacf, foldable_codepoints);
}
// mlen = unfolded_codepoint_len
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len, casef, diacf, foldable_codepoints);
// mlen > unfolded_codepoint_len
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len + 1, casef, diacf, foldable_codepoints);
// mlen >> unfolded_codepoint_len
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len + 64, casef, diacf, foldable_codepoints);
// mlen = cbclen
uint32_t max_padded_len = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
test_nofold_substring_case(tester, str, lb, ub, max_padded_len, casef, diacf, foldable_codepoints);

uint32_t byte_len = (uint32_t)strlen(str);
if (byte_len > 1) {
// mlen < byte_len
test_nofold_substring_case(tester, str, lb, ub, byte_len - 1, casef, diacf, foldable_codepoints);
}
if (byte_len > 0) {
// mlen = byte_len
test_nofold_substring_case(tester, str, lb, ub, byte_len, casef, diacf, foldable_codepoints);
}
// mlen > byte_len
test_nofold_substring_case(tester, str, lb, ub, byte_len + 1, casef, diacf, foldable_codepoints);
// mlen = padded_len
test_nofold_substring_case(tester,
str,
lb,
ub,
16 * (uint32_t)((byte_len + 5 + 15) / 16) - 5,
casef,
diacf,
foldable_codepoints);
// mlen >> byte_len
test_nofold_substring_case(tester, str, lb, ub, byte_len + 64, casef, diacf, foldable_codepoints);
}

const char *normal_ascii_strings[] = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f",
Expand All @@ -353,6 +375,8 @@ const char *unicode_diacritics[] = {"̀", "́", "̂", "̃", "̄", "̅", "̆",

// Build a random string which has unfolded_len codepoints, but folds to folded_len codepoints after diacritic folding.
char *build_random_string_to_fold(uint32_t folded_len, uint32_t unfolded_len) {
// 1/3 to generate all unicode, 1/3 to be half and half, 1/3 to be all ascii.
int ascii_ratio = rand() % 3;
ASSERT_CMPUINT32(unfolded_len, >=, folded_len);
// Max size in bytes is # unicode characters * 4 bytes for each character + 1 null terminator.
char *str = malloc(unfolded_len * 4 + 1);
Expand All @@ -366,7 +390,7 @@ char *build_random_string_to_fold(uint32_t folded_len, uint32_t unfolded_len) {
bool must_add_normal = n_codepoints - folded_size == diacritics;
if (must_add_diacritic || (!must_add_normal && (rand() % 1000 < dia_prob))) {
// Add diacritic.
if (rand() % 2) {
if (rand() % 2 < ascii_ratio) {
int i = rand() % (sizeof(ascii_diacritics) / sizeof(char *));
src_ptr = ascii_diacritics[i];
} else {
Expand All @@ -375,7 +399,7 @@ char *build_random_string_to_fold(uint32_t folded_len, uint32_t unfolded_len) {
}
} else {
// Add normal character.
if (rand() % 2) {
if (rand() % 2 < ascii_ratio) {
int i = rand() % (sizeof(normal_ascii_strings) / sizeof(char *));
src_ptr = normal_ascii_strings[i];
} else {
Expand Down
30 changes: 30 additions & 0 deletions test/test-mongocrypt-crypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <mongocrypt-crypto-private.h>
#include <mongocrypt.h>

#include "test-mongocrypt-assert.h"
#include "test-mongocrypt.h"

typedef struct {
Expand Down Expand Up @@ -432,9 +433,38 @@ static void _test_random_int64(_mongocrypt_tester_t *tester) {
mongocrypt_destroy(crypt);
}

// Verifies the step shape of the FLE2v2 AEAD ciphertext length as a function of plaintext length.
//
// QE text search relies on the following assumption for its leakage profile:
//   * all plaintext lengths within one 16-byte step (16k .. 16k + 15) yield the same ciphertext length, and
//   * moving from the start of one step (16(k-1)) to the start of the next (16k) grows the ciphertext by
//     exactly 16 bytes.
// The test walks every plaintext length from 0 up to 16 * 16 + 15 in increasing order and checks both
// properties against the length reported by the algorithm's get_ciphertext_len callback.
static void _test_aes_256_aead_steps_consistent(_mongocrypt_tester_t *tester) {
    mongocrypt_status_t *status = mongocrypt_status_new();
    const _mongocrypt_value_encryption_algorithm_t *alg = _mcFLE2v2AEADAlgorithm();
    const int step = 16;
    const int n_lengths = 17 * step; // steps k = 0..16, each covering 16 consecutive lengths
    // Ciphertext length observed at the start of the step currently being walked.
    size_t ciphertext_len = 0;

    for (int plaintext_len = 0; plaintext_len < n_lengths; plaintext_len++) {
        if (plaintext_len % step == 0) {
            // First length of a step: must be exactly one block longer than the previous step.
            size_t new_ct_len = alg->get_ciphertext_len(plaintext_len, status);
            if (new_ct_len == 0) {
                TEST_ERROR("get_ciphertext_len failed");
            }
            if (plaintext_len != 0) {
                ASSERT_CMPSIZE_T(new_ct_len, ==, ciphertext_len + 16);
            }
            ciphertext_len = new_ct_len;
            continue;
        }

        // Any other length within the step: must match the length seen at the start of the step.
        size_t ct_len = alg->get_ciphertext_len(plaintext_len, status);
        if (ct_len == 0) {
            TEST_ERROR("get_ciphertext_len failed");
        }
        ASSERT_CMPSIZE_T(ct_len, ==, ciphertext_len);
    }

    mongocrypt_status_destroy(status);
}

// Installs the crypto tests onto `tester`, in the order listed below.
// NOTE(review): INSTALL_TEST / INSTALL_TEST_CRYPTO are defined outside this view; presumably they register the
// named function on the local `tester` -- confirm against the test harness header.
void _mongocrypt_tester_install_crypto(_mongocrypt_tester_t *tester) {
INSTALL_TEST(_test_roundtrip);
INSTALL_TEST(_test_native_crypto_hmac_sha_256);
// NOTE(review): CRYPTO_OPTIONAL presumably marks a test that does not strictly require native crypto -- confirm.
INSTALL_TEST_CRYPTO(_test_mongocrypt_hmac_sha_256_hook, CRYPTO_OPTIONAL);
INSTALL_TEST(_test_random_int64);
// Checks that FLE2v2 AEAD ciphertext lengths grow in 16-byte steps with plaintext length.
INSTALL_TEST(_test_aes_256_aead_steps_consistent);
}
Loading