Skip to content

Commit 267c757

Browse files
Fix: Signed char handling in sz_find_byte_serial (#308)
Closes #306
Co-Authored-By: David Mollitor <dmollitor@apache.org>
Co-Authored-By: David Mollitor <12578579+belugabehr@users.noreply.github.com>
Co-authored-by: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
1 parent 7d72c96 commit 267c757

2 files changed

Lines changed: 95 additions & 13 deletions

File tree

include/stringzilla/find.h

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -377,66 +377,73 @@ SZ_PUBLIC sz_cptr_t sz_rfind_byteset_serial(sz_cptr_t text, sz_size_t length, sz
377377
* This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time.
378378
* Identical to `memchr(haystack, needle[0], haystack_length)`.
379379
*/
380-
SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
380+
SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h_chars, sz_size_t h_length, sz_cptr_t n_chars) {
381381

382382
// An empty haystack cannot contain the needle.
if (!h_length) return SZ_NULL_CHAR;
383-
sz_cptr_t const h_end = h + h_length;
383+
// Reinterpret as unsigned bytes so the SWAR broadcast below cannot sign-extend
384+
// on platforms where `char` is signed (e.g. `-fsigned-char`). See issue #306.
385+
sz_u8_t const *h = (sz_u8_t const *)h_chars;
386+
sz_u8_t const *const n = (sz_u8_t const *)n_chars;
387+
sz_u8_t const *const h_end = h + h_length;
384388

385389
#if !SZ_IS_BIG_ENDIAN_ // Use SWAR only on little-endian platforms for brevity.
386390
#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads.
387391
// Advance one byte at a time until `h` is 8-byte aligned or the haystack is exhausted.
for (; ((sz_size_t)h & 7ull) && h < h_end; ++h)
388-
if (*h == *n) return h;
392+
if (*h == *n) return (sz_cptr_t)h;
389393
#endif
390394

391395
// Broadcast the n into every byte of a 64-bit integer to use SWAR
392396
// techniques and process eight characters at a time.
393397
sz_u64_vec_t h_vec, n_vec, match_vec;
394398
match_vec.u64 = 0;
395-
n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
399+
// With `*n` read as an unsigned byte, the multiplication copies it into all eight byte lanes;
// a negative `char` would first be widened to a 64-bit value with the upper bits set.
n_vec.u64 = (sz_u64_t)*n * 0x0101010101010101ull;
396400
for (; h + 8 <= h_end; h += 8) {
397401
h_vec.u64 = *(sz_u64_t const *)h;
398402
match_vec = sz_u64_each_byte_equal_(h_vec, n_vec);
399-
if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8;
403+
// Dividing a bit index by 8 yields a byte offset inside the loaded word. This presumes
// `sz_u64_ctz` counts trailing zero bits and `sz_u64_each_byte_equal_` sets bits only in
// the byte lanes that matched - confirm against their definitions.
if (match_vec.u64) return (sz_cptr_t)(h + sz_u64_ctz(match_vec.u64) / 8);
400404
}
401405
#endif
402406

403407
// Handle the remaining tail (fewer than 8 bytes after the SWAR loop), or the entire
// haystack when the SWAR section above is compiled out on big-endian builds.
404408
for (; h < h_end; ++h)
405-
if (*h == *n) return h;
409+
if (*h == *n) return (sz_cptr_t)h;
406410
return SZ_NULL_CHAR;
407411
}
408412

409413
/* Find the last occurrence of a @b single-character needle in an arbitrary length haystack.
410414
* This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time.
411415
* Identical to `memrchr(haystack, needle[0], haystack_length)`.
412416
*/
413-
sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
417+
sz_cptr_t sz_rfind_byte_serial(sz_cptr_t h_chars, sz_size_t h_length, sz_cptr_t n_chars) {
414418

415419
// An empty haystack cannot contain the needle.
if (!h_length) return SZ_NULL_CHAR;
416-
sz_cptr_t const h_start = h;
420+
// Reinterpret as unsigned bytes so the SWAR broadcast below cannot sign-extend
421+
// on platforms where `char` is signed (e.g. `-fsigned-char`). See issue #306.
422+
sz_u8_t const *const h_start = (sz_u8_t const *)h_chars;
423+
sz_u8_t const *const n = (sz_u8_t const *)n_chars;
417424

418425
// Reposition the `h` pointer to the end, as we will be walking backwards.
// NOTE(review): when no match is found the loops below leave `h` at `h_start - 1`.
// If `h_chars` is the first element of its array, forming that pointer is formally
// undefined behavior even though it is never dereferenced.
419-
h = h + h_length - 1;
426+
sz_u8_t const *h = h_start + h_length - 1;
420427

421428
#if !SZ_IS_BIG_ENDIAN_ // Use SWAR only on little-endian platforms for brevity.
422429
#if !SZ_USE_MISALIGNED_LOADS // Process the misaligned head, to avoid UB on unaligned 64-bit loads.
423430
// Step backwards one byte at a time until `h + 1` is 8-byte aligned, so that the
// 8-byte word ending at `h` (starting at `h - 7`) begins on an aligned address.
for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h)
424-
if (*h == *n) return h;
431+
if (*h == *n) return (sz_cptr_t)h;
425432
#endif
426433

427434
// Broadcast the n into every byte of a 64-bit integer to use SWAR
428435
// techniques and process eight characters at a time.
429436
sz_u64_vec_t h_vec, n_vec, match_vec;
430-
n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
437+
// With `*n` read as an unsigned byte, the multiplication copies it into all eight byte lanes;
// a negative `char` would first be widened to a 64-bit value with the upper bits set.
n_vec.u64 = (sz_u64_t)*n * 0x0101010101010101ull;
431438
for (; h >= h_start + 7; h -= 8) {
432439
// `h` addresses the last byte of the current word, so the load starts 7 bytes earlier.
h_vec.u64 = *(sz_u64_t const *)(h - 7);
433440
match_vec = sz_u64_each_byte_equal_(h_vec, n_vec);
434-
if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8;
441+
// Dividing a bit count by 8 yields how many bytes before `h` the last match sits. This
// presumes `sz_u64_clz` counts leading zero bits and `sz_u64_each_byte_equal_` sets bits
// only in the byte lanes that matched - confirm against their definitions.
if (match_vec.u64) return (sz_cptr_t)(h - sz_u64_clz(match_vec.u64) / 8);
435442
}
436443
#endif
437444

438445
// Scan the bytes still left at the front (fewer than 8 after the SWAR loop), or the
// entire haystack when the SWAR section above is compiled out on big-endian builds.
for (; h >= h_start; --h)
439-
if (*h == *n) return h;
446+
if (*h == *n) return (sz_cptr_t)h;
440447
return SZ_NULL_CHAR;
441448
}
442449

scripts/test_stringzilla.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,6 +2256,79 @@ void test_comparisons() {
22562256
assert("a\0"_sv == "a\0"_sv);
22572257
}
22582258

2259+
/**
2260+
* @brief Regression test for signed-char bug in `sz_find_byte_serial` and `sz_rfind_byte_serial`.
2261+
* When compiled with `-fsigned-char`, bytes > 0x7F would be sign-extended during the SWAR
2262+
* broadcast multiplication, producing incorrect results. This test verifies that single-byte
2263+
* search works correctly for all byte values 0x00-0xFF across various haystack lengths.
2264+
* @see https://github.com/ashvardanian/StringZilla/issues/306
2265+
*/
2266+
void test_find_byte_serial_high_bytes() {
2267+
2268+
// Test every byte value in a haystack long enough to exercise the SWAR loop (>=8 bytes)
2269+
// and the scalar tail. We place the target byte at different positions to cover both paths.
2270+
sz_u8_t haystack_bytes[64];
2271+
2272+
// Fill the haystack with a neutral byte that won't match our needle
2273+
std::memset(haystack_bytes, 0x00, sizeof(haystack_bytes));
2274+
2275+
for (unsigned needle_byte = 0x80; needle_byte <= 0xFF; ++needle_byte) {
2276+
sz_u8_t needle_u8 = (sz_u8_t)needle_byte;
2277+
char const *needle = (char const *)&needle_u8;
2278+
2279+
// Test 1: needle in the SWAR-processed region (position 5, within first 8-byte block)
2280+
haystack_bytes[5] = needle_u8;
2281+
{
2282+
sz_cptr_t result = sz_find_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2283+
assert(result != SZ_NULL_CHAR && "sz_find_byte_serial must find high bytes in SWAR region");
2284+
assert((sz_size_t)(result - (sz_cptr_t)haystack_bytes) == 5);
2285+
}
2286+
{
2287+
sz_cptr_t result = sz_rfind_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2288+
assert(result != SZ_NULL_CHAR && "sz_rfind_byte_serial must find high bytes in SWAR region");
2289+
assert((sz_size_t)(result - (sz_cptr_t)haystack_bytes) == 5);
2290+
}
2291+
haystack_bytes[5] = 0x00;
2292+
2293+
// Test 2: needle in the scalar tail region (position 61, within last few bytes)
2294+
haystack_bytes[61] = needle_u8;
2295+
{
2296+
sz_cptr_t result = sz_find_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2297+
assert(result != SZ_NULL_CHAR && "sz_find_byte_serial must find high bytes in scalar tail");
2298+
assert((sz_size_t)(result - (sz_cptr_t)haystack_bytes) == 61);
2299+
}
2300+
{
2301+
sz_cptr_t result = sz_rfind_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2302+
assert(result != SZ_NULL_CHAR && "sz_rfind_byte_serial must find high bytes in scalar tail");
2303+
assert((sz_size_t)(result - (sz_cptr_t)haystack_bytes) == 61);
2304+
}
2305+
haystack_bytes[61] = 0x00;
2306+
2307+
// Test 3: needle not present - must return NULL
2308+
{
2309+
sz_cptr_t result = sz_find_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2310+
assert(result == SZ_NULL_CHAR && "sz_find_byte_serial must return NULL when byte is absent");
2311+
}
2312+
{
2313+
sz_cptr_t result = sz_rfind_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2314+
assert(result == SZ_NULL_CHAR && "sz_rfind_byte_serial must return NULL when byte is absent");
2315+
}
2316+
}
2317+
2318+
// Test 4: multiple occurrences - find returns first, rfind returns last
2319+
std::memset(haystack_bytes, 0x00, sizeof(haystack_bytes));
2320+
haystack_bytes[3] = 0xBE;
2321+
haystack_bytes[40] = 0xBE;
2322+
{
2323+
sz_u8_t n = 0xBE;
2324+
char const *needle = (char const *)&n;
2325+
sz_cptr_t first = sz_find_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2326+
sz_cptr_t last = sz_rfind_byte_serial((sz_cptr_t)haystack_bytes, sizeof(haystack_bytes), needle);
2327+
assert(first != SZ_NULL_CHAR && (sz_size_t)(first - (sz_cptr_t)haystack_bytes) == 3);
2328+
assert(last != SZ_NULL_CHAR && (sz_size_t)(last - (sz_cptr_t)haystack_bytes) == 40);
2329+
}
2330+
}
2331+
22592332
/**
22602333
* @brief Tests the correctness of the string class search methods, such as `find` and `find_first_of`.
22612334
* This covers haystacks and needles of different lengths, as well as character-sets.
@@ -4567,6 +4640,8 @@ int main(int argc, char const **argv) {
45674640
test_updates();
45684641

45694642
std::printf("\n=== Search and Comparison ===\n");
4643+
std::printf("- test_find_byte_serial_high_bytes...\n");
4644+
test_find_byte_serial_high_bytes();
45704645
std::printf("- test_comparisons...\n");
45714646
test_comparisons();
45724647
std::printf("- test_search...\n");

0 commit comments

Comments
 (0)