Skip to content

Commit

Permalink
Fix #575, regression tests (#582)
Browse files Browse the repository at this point in the history
* Refactor isValidUtf8 tests to enable shrinking

* Test isValidUtf8 on strings which are almost valid, with long ranges of ASCII

* Repair invalid UTF-8 issue, more tests

* Fix NEON ASCII check

* Refactor isValidUtf8 tests to enable shrinking

* Test isValidUtf8 on strings which are almost valid, with long ranges of ASCII

* Use ctzll to ensure 8 bytes get processed

* Set -fno-strict-aliasing

---------

Co-authored-by: Bodigrim <[email protected]>
  • Loading branch information
kozross and Bodigrim authored Jun 13, 2023
1 parent aa79cf8 commit dac5675
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 113 deletions.
1 change: 1 addition & 0 deletions bytestring.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ library

-- DNDEBUG disables asserts in cbits/
cc-options: -std=c11 -DNDEBUG=1
-fno-strict-aliasing

-- No need to link to libgcc on ghc-9.4 and later which uses a clang-based
-- toolchain.
Expand Down
128 changes: 64 additions & 64 deletions cbits/aarch64/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ SUCH DAMAGE.
*/
#pragma GCC push_options
#pragma GCC optimize("-O2")
#include <arm_neon.h>
#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>
#include <arm_neon.h>
#include <stdint.h>

// Fallback (for tails).
static inline int is_valid_utf8_fallback(uint8_t const *const src,
Expand Down Expand Up @@ -102,51 +102,60 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src,
}

static uint8_t const first_len_lookup[16] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
};

static uint8_t const first_range_lookup[16] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
};

static uint8_t const range_min_lookup[16] = {
0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
};

static uint8_t const range_max_lookup[16] = {
0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};

static uint8_t const range_adjust_lookup[32] = {
2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
};

static bool is_ascii (uint8x16_t const * const inputs) {
uint8x16_t const all_80 = vdupq_n_u8(0x80);
// A non-ASCII byte will have its highest-order bit set. Since this is
// preserved by OR, we can OR everything together.
uint8x16_t ored = vorrq_u8(vorrq_u8(inputs[0], inputs[1]),
vorrq_u8(inputs[2], inputs[3]));
// ANDing with 0x80 retains any set high-order bits. We then check for zeroes.
uint64x2_t result = vreinterpretq_u64_u8(vandq_u8(ored, all_80));
static bool is_ascii(uint8x16_t const *const inputs,
uint8x16_t const prev_first_len) {
// Check if we have ASCII, and also that we don't have to treat the prior
// block as special.
// First, verify that we didn't see any non-ASCII bytes in the first half of
// the stride.
uint8x16_t const first_half_clean = vorrq_u8(inputs[0], inputs[1]);
// Then we do the same for the second half of the stride.
uint8x16_t const second_half_clean = vorrq_u8(inputs[2], inputs[3]);
// Check cleanliness of the entire stride.
uint8x16_t const stride_clean = vorrq_u8(first_half_clean, second_half_clean);
// Leave only the high-order set bits.
uint8x16_t const masked = vandq_u8(stride_clean, vdupq_n_u8(0x80));
// Finally, check that we didn't have any leftover marker bytes in the
// previous block: these are indicated by non-zeroes in prev_first_len. In
// order to trigger a failure, we have to have non-zeroes set in the high bit
// of the lane: we do this by doing a greater-than comparison with a block of
// zeroes.
uint8x16_t const no_prior_dirt = vcgtq_u8(prev_first_len, vdupq_n_u8(0x00));
// Check for all-zero.
uint64x2_t const result =
vreinterpretq_u64_u8(vorrq_u8(masked, no_prior_dirt));
return !(vgetq_lane_u64(result, 0) || vgetq_lane_u64(result, 1));
}

static void check_block_neon(uint8x16_t const prev_input,
uint8x16_t const prev_first_len,
uint8x16_t* errors,
uint8x16_t const first_range_tbl,
uint8x16_t const range_min_tbl,
uint8x16_t const range_max_tbl,
uint8x16x2_t const range_adjust_tbl,
uint8x16_t const all_ones,
uint8x16_t const all_twos,
uint8x16_t const all_e0s,
uint8x16_t const input,
uint8x16_t const first_len) {
static void
check_block_neon(uint8x16_t const prev_input, uint8x16_t const prev_first_len,
uint8x16_t *errors, uint8x16_t const first_range_tbl,
uint8x16_t const range_min_tbl, uint8x16_t const range_max_tbl,
uint8x16x2_t const range_adjust_tbl, uint8x16_t const all_ones,
uint8x16_t const all_twos, uint8x16_t const all_e0s,
uint8x16_t const input, uint8x16_t const first_len) {
// Get the high 4-bits of the input.
uint8x16_t const high_nibbles = vshrq_n_u8(input, 4);
// Set range index to 8 for bytes in [C0, FF] by lookup (first byte).
Expand Down Expand Up @@ -182,20 +191,20 @@ static void check_block_neon(uint8x16_t const prev_input,
errors[1] = vorrq_u8(errors[1], vcgtq_u8(input, maxv));
}

int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
if (len == 0) {
return 1;
}
// We step 64 bytes at a time.
size_t const big_strides = len / 64;
size_t const remaining = len % 64;
uint8_t const * ptr = (uint8_t const *)src;
uint8_t const *ptr = (uint8_t const *)src;
// Tracking state
uint8x16_t prev_input = vdupq_n_u8(0);
uint8x16_t prev_first_len = vdupq_n_u8(0);
uint8x16_t errors[2] = {
vdupq_n_u8(0),
vdupq_n_u8(0),
vdupq_n_u8(0),
vdupq_n_u8(0),
};
// Load our lookup tables.
uint8x16_t const first_len_tbl = vld1q_u8(first_len_lookup);
Expand All @@ -209,40 +218,33 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
uint8x16_t const all_e0s = vdupq_n_u8(0xE0);
for (size_t i = 0; i < big_strides; i++) {
// Load 64 bytes
uint8x16_t const inputs[4] = {
vld1q_u8(ptr),
vld1q_u8(ptr + 16),
vld1q_u8(ptr + 32),
vld1q_u8(ptr + 48)
};
uint8x16_t const inputs[4] = {vld1q_u8(ptr), vld1q_u8(ptr + 16),
vld1q_u8(ptr + 32), vld1q_u8(ptr + 48)};
// Check if we have ASCII
if (is_ascii(inputs)) {
if (is_ascii(inputs, prev_first_len)) {
// Prev_first_len cheaply.
prev_first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
} else {
uint8x16_t first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
check_block_neon(prev_input, prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[0], first_len);
uint8x16_t first_len =
vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
check_block_neon(prev_input, prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[0], first_len);
prev_first_len = first_len;
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[1], 4));
check_block_neon(inputs[0], prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[1], first_len);
check_block_neon(inputs[0], prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[1], first_len);
prev_first_len = first_len;
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[2], 4));
check_block_neon(inputs[1], prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[2], first_len);
check_block_neon(inputs[1], prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[2], first_len);
prev_first_len = first_len;
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
check_block_neon(inputs[2], prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[3], first_len);
check_block_neon(inputs[2], prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[3], first_len);
prev_first_len = first_len;
}
// Set prev_input based on last block.
Expand All @@ -260,19 +262,17 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
vst1q_lane_u32(&token, vreinterpretq_u32_u8(prev_input), 3);
// We cast this pointer to avoid a redundant check against < 127, as any such
// value would be negative in signed form.
int8_t const * token_ptr = (int8_t const *)&token;
int8_t const *token_ptr = (int8_t const *)&token;
ptrdiff_t lookahead = 0;
if (token_ptr[3] > (int8_t)0xBF) {
lookahead = 1;
}
else if (token_ptr[2] > (int8_t)0xBF) {
} else if (token_ptr[2] > (int8_t)0xBF) {
lookahead = 2;
}
else if (token_ptr[1] > (int8_t)0xBF) {
} else if (token_ptr[1] > (int8_t)0xBF) {
lookahead = 3;
}
// Finish the job.
uint8_t const * const small_ptr = ptr - lookahead;
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
return is_valid_utf8_fallback(small_ptr, small_len);
}
Expand Down
91 changes: 65 additions & 26 deletions cbits/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ SUCH DAMAGE.
#include <string.h>

#ifdef __x86_64__
#include <cpuid.h>
#include <emmintrin.h>
#include <immintrin.h>
#include <cpuid.h>
#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || defined(__clang_major__)) && !defined(__STDC_NO_ATOMICS__)
#include <tmmintrin.h>
#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || \
defined(__clang_major__)) && \
!defined(__STDC_NO_ATOMICS__)
#include <stdatomic.h>
#include <tmmintrin.h>
#else
// This is needed to support CentOS 7, which has a very old GCC.
#define CRUFTY_GCC
Expand All @@ -64,7 +66,8 @@ static inline uint64_t read_uint64(const uint64_t *p) {
return r;
}

static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const len) {
static inline int is_valid_utf8_fallback(uint8_t const *const src,
size_t const len) {
uint8_t const *ptr = (uint8_t const *)src;
// This is 'one past the end' to make loop termination and bounds checks
// easier.
Expand All @@ -83,10 +86,11 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
// Non-ASCII bytes have a set MSB. Thus, if we AND with 0x80 in every
// 'lane', we will get 0 if everything is ASCII, and something else
// otherwise.
uint64_t results[4] = {to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
uint64_t results[4] = {
to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
if (results[0] == 0) {
ptr += 8;
if (results[1] == 0) {
Expand All @@ -96,16 +100,16 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
if (results[3] == 0) {
ptr += 8;
} else {
ptr += (__builtin_ctzl(results[3]) / 8);
ptr += (__builtin_ctzll(results[3]) / 8);
}
} else {
ptr += (__builtin_ctzl(results[2]) / 8);
ptr += (__builtin_ctzll(results[2]) / 8);
}
} else {
ptr += (__builtin_ctzl(results[1]) / 8);
ptr += (__builtin_ctzll(results[1]) / 8);
}
} else {
ptr += (__builtin_ctzl(results[0]) / 8);
ptr += (__builtin_ctzll(results[0]) / 8);
}
}
}
Expand Down Expand Up @@ -203,16 +207,16 @@ static inline int is_valid_utf8_sse2(uint8_t const *const src,
if (result == 0) {
ptr += 16;
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
}
}
Expand Down Expand Up @@ -331,10 +335,26 @@ static int8_t const ef_fe_lookup[16] = {
};

__attribute__((target("ssse3"))) static inline bool
is_ascii_sse2(__m128i const *src) {
is_ascii_sse2(__m128i const *src, __m128i const prev_first_len) {
// Check if we have ASCII, and also that we don't have to treat the prior
// block as special.
// First, verify that we didn't see any non-ASCII bytes in the first half of
// the stride.
__m128i const first_half_clean = _mm_or_si128(src[0], src[1]);
// Then do the same for the second half of the stride.
__m128i const second_half_clean = _mm_or_si128(src[2], src[3]);
// Check cleanliness of the entire stride.
__m128i const stride_clean =
_mm_or_si128(first_half_clean, second_half_clean);
// Finally, check that we didn't have any leftover marker bytes in the
// previous block: these are indicated by non-zeroes in prev_first_len. In
// order to trigger a failure, we have to have non-zeros set the high bit of
// the lane: we do this by doing a greater-than comparison with a block of
// zeroes.
__m128i const no_prior_dirt =
_mm_cmpgt_epi8(prev_first_len, _mm_setzero_si128());
// OR together everything, then check for a high bit anywhere.
__m128i const ored =
_mm_or_si128(_mm_or_si128(src[0], src[1]), _mm_or_si128(src[2], src[3]));
__m128i const ored = _mm_or_si128(stride_clean, no_prior_dirt);
return (_mm_movemask_epi8(ored) == 0);
}

Expand Down Expand Up @@ -415,7 +435,7 @@ is_valid_utf8_ssse3(uint8_t const *const src, size_t const len) {
_mm_loadu_si128(big_ptr), _mm_loadu_si128(big_ptr + 1),
_mm_loadu_si128(big_ptr + 2), _mm_loadu_si128(big_ptr + 3)};
// Check if we have ASCII.
if (is_ascii_sse2(inputs)) {
if (is_ascii_sse2(inputs, prev_first_len)) {
// Prev_first_len cheaply.
prev_first_len =
_mm_shuffle_epi8(first_len_tbl, high_nibbles_of(inputs[3]));
Expand Down Expand Up @@ -598,10 +618,26 @@ is_valid_utf8_avx2(uint8_t const *const src, size_t const len) {
__m256i const inputs[4] = {
_mm256_loadu_si256(big_ptr), _mm256_loadu_si256(big_ptr + 1),
_mm256_loadu_si256(big_ptr + 2), _mm256_loadu_si256(big_ptr + 3)};
// Check if we have ASCII.
bool is_ascii = _mm256_movemask_epi8(_mm256_or_si256(
_mm256_or_si256(inputs[0], inputs[1]),
_mm256_or_si256(inputs[2], inputs[3]))) == 0;
// Check if we have ASCII, and also that we don't have to treat the prior
// block as special.
// First, verify that we didn't see any non-ASCII bytes in the first half of
// the stride.
__m256i const first_half_clean = _mm256_or_si256(inputs[0], inputs[1]);
// Then do the same for the second half of the stride.
__m256i const second_half_clean = _mm256_or_si256(inputs[2], inputs[3]);
// Check cleanliness of the entire stride.
__m256i const stride_clean =
_mm256_or_si256(first_half_clean, second_half_clean);
// Finally, check that we didn't have any leftover marker bytes in the
// previous block: these are indicated by non-zeroes in prev_first_len.
// In order to trigger a failure, we have to have non-zeros set the high bit
// of the lane: we do this by doing a greater-than comparison with a block
// of zeroes.
__m256i const no_prior_dirt =
_mm256_cmpgt_epi8(prev_first_len, _mm256_setzero_si256());
// Combine all checks together, and check if any high bits are set.
bool is_ascii =
_mm256_movemask_epi8(_mm256_or_si256(stride_clean, no_prior_dirt)) == 0;
if (is_ascii) {
// Prev_first_len cheaply
prev_first_len =
Expand Down Expand Up @@ -683,7 +719,7 @@ static inline bool has_avx2() {
}
#endif

typedef int (*is_valid_utf8_t) (uint8_t const *const, size_t const);
typedef int (*is_valid_utf8_t)(uint8_t const *const, size_t const);

int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
if (len == 0) {
Expand All @@ -693,7 +729,10 @@ int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
static _Atomic is_valid_utf8_t s_impl = (is_valid_utf8_t)NULL;
is_valid_utf8_t impl = atomic_load_explicit(&s_impl, memory_order_relaxed);
if (!impl) {
impl = has_avx2() ? is_valid_utf8_avx2 : (has_ssse3() ? is_valid_utf8_ssse3 : (has_sse2() ? is_valid_utf8_sse2 : is_valid_utf8_fallback));
impl = has_avx2() ? is_valid_utf8_avx2
: (has_ssse3() ? is_valid_utf8_ssse3
: (has_sse2() ? is_valid_utf8_sse2
: is_valid_utf8_fallback));
atomic_store_explicit(&s_impl, impl, memory_order_relaxed);
}
return (*impl)(src, len);
Expand Down
Loading

0 comments on commit dac5675

Please sign in to comment.