Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 575, regression tests #582

Merged
merged 11 commits into from
Jun 13, 2023
1 change: 1 addition & 0 deletions bytestring.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ library

-- DNDEBUG disables asserts in cbits/
cc-options: -std=c11 -DNDEBUG=1
-fno-strict-aliasing

-- No need to link to libgcc on ghc-9.4 and later which uses a clang-based
-- toolchain.
Expand Down
128 changes: 64 additions & 64 deletions cbits/aarch64/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ SUCH DAMAGE.
*/
#pragma GCC push_options
#pragma GCC optimize("-O2")
#include <arm_neon.h>
#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>
#include <arm_neon.h>
#include <stdint.h>

// Fallback (for tails).
static inline int is_valid_utf8_fallback(uint8_t const *const src,
Expand Down Expand Up @@ -102,51 +102,60 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src,
}

static uint8_t const first_len_lookup[16] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
};

static uint8_t const first_range_lookup[16] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
};

static uint8_t const range_min_lookup[16] = {
0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
};

static uint8_t const range_max_lookup[16] = {
0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};

static uint8_t const range_adjust_lookup[32] = {
2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
};

static bool is_ascii (uint8x16_t const * const inputs) {
uint8x16_t const all_80 = vdupq_n_u8(0x80);
// A non-ASCII byte will have its highest-order bit set. Since this is
// preserved by OR, we can OR everything together.
uint8x16_t ored = vorrq_u8(vorrq_u8(inputs[0], inputs[1]),
vorrq_u8(inputs[2], inputs[3]));
// ANDing with 0x80 retains any set high-order bits. We then check for zeroes.
uint64x2_t result = vreinterpretq_u64_u8(vandq_u8(ored, all_80));
static bool is_ascii(uint8x16_t const *const inputs,
uint8x16_t const prev_first_len) {
// Check if we have ASCII, and also that we don't have to treat the prior
// block as special.
// First, verify that we didn't see any non-ASCII bytes in the first half of
// the stride.
uint8x16_t const first_half_clean = vorrq_u8(inputs[0], inputs[1]);
// Then we do the same for the second half of the stride.
uint8x16_t const second_half_clean = vorrq_u8(inputs[2], inputs[3]);
// Check cleanliness of the entire stride.
uint8x16_t const stride_clean = vorrq_u8(first_half_clean, second_half_clean);
// Leave only the high-order set bits.
uint8x16_t const masked = vandq_u8(stride_clean, vdupq_n_u8(0x80));
// Finally, check that we didn't have any leftover marker bytes in the
// previous block: these are indicated by non-zeroes in prev_first_len. In
// order to trigger a failure, we have to have non-zeroes set in the high bit
// of the lane: we do this by doing a greater-than comparison with a block of
// zeroes.
uint8x16_t const no_prior_dirt = vcgtq_u8(prev_first_len, vdupq_n_u8(0x00));
// Check for all-zero.
uint64x2_t const result =
vreinterpretq_u64_u8(vorrq_u8(masked, no_prior_dirt));
return !(vgetq_lane_u64(result, 0) || vgetq_lane_u64(result, 1));
}

static void check_block_neon(uint8x16_t const prev_input,
uint8x16_t const prev_first_len,
uint8x16_t* errors,
uint8x16_t const first_range_tbl,
uint8x16_t const range_min_tbl,
uint8x16_t const range_max_tbl,
uint8x16x2_t const range_adjust_tbl,
uint8x16_t const all_ones,
uint8x16_t const all_twos,
uint8x16_t const all_e0s,
uint8x16_t const input,
uint8x16_t const first_len) {
static void
check_block_neon(uint8x16_t const prev_input, uint8x16_t const prev_first_len,
uint8x16_t *errors, uint8x16_t const first_range_tbl,
uint8x16_t const range_min_tbl, uint8x16_t const range_max_tbl,
uint8x16x2_t const range_adjust_tbl, uint8x16_t const all_ones,
uint8x16_t const all_twos, uint8x16_t const all_e0s,
uint8x16_t const input, uint8x16_t const first_len) {
// Get the high 4-bits of the input.
uint8x16_t const high_nibbles = vshrq_n_u8(input, 4);
// Set range index to 8 for bytes in [C0, FF] by lookup (first byte).
Expand Down Expand Up @@ -182,20 +191,20 @@ static void check_block_neon(uint8x16_t const prev_input,
errors[1] = vorrq_u8(errors[1], vcgtq_u8(input, maxv));
}

int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
if (len == 0) {
return 1;
}
// We step 64 bytes at a time.
size_t const big_strides = len / 64;
size_t const remaining = len % 64;
uint8_t const * ptr = (uint8_t const *)src;
uint8_t const *ptr = (uint8_t const *)src;
// Tracking state
uint8x16_t prev_input = vdupq_n_u8(0);
uint8x16_t prev_first_len = vdupq_n_u8(0);
uint8x16_t errors[2] = {
vdupq_n_u8(0),
vdupq_n_u8(0),
vdupq_n_u8(0),
vdupq_n_u8(0),
};
// Load our lookup tables.
uint8x16_t const first_len_tbl = vld1q_u8(first_len_lookup);
Expand All @@ -209,40 +218,33 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
uint8x16_t const all_e0s = vdupq_n_u8(0xE0);
for (size_t i = 0; i < big_strides; i++) {
// Load 64 bytes
uint8x16_t const inputs[4] = {
vld1q_u8(ptr),
vld1q_u8(ptr + 16),
vld1q_u8(ptr + 32),
vld1q_u8(ptr + 48)
};
uint8x16_t const inputs[4] = {vld1q_u8(ptr), vld1q_u8(ptr + 16),
vld1q_u8(ptr + 32), vld1q_u8(ptr + 48)};
// Check if we have ASCII
if (is_ascii(inputs)) {
if (is_ascii(inputs, prev_first_len)) {
// Prev_first_len cheaply.
prev_first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
} else {
uint8x16_t first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
check_block_neon(prev_input, prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[0], first_len);
uint8x16_t first_len =
vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
check_block_neon(prev_input, prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[0], first_len);
prev_first_len = first_len;
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[1], 4));
check_block_neon(inputs[0], prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[1], first_len);
check_block_neon(inputs[0], prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[1], first_len);
prev_first_len = first_len;
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[2], 4));
check_block_neon(inputs[1], prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[2], first_len);
check_block_neon(inputs[1], prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[2], first_len);
prev_first_len = first_len;
first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
check_block_neon(inputs[2], prev_first_len, errors,
first_range_tbl, range_min_tbl, range_max_tbl,
range_adjust_tbl, all_ones, all_twos, all_e0s,
inputs[3], first_len);
check_block_neon(inputs[2], prev_first_len, errors, first_range_tbl,
range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
all_twos, all_e0s, inputs[3], first_len);
prev_first_len = first_len;
}
// Set prev_input based on last block.
Expand All @@ -260,19 +262,17 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
vst1q_lane_u32(&token, vreinterpretq_u32_u8(prev_input), 3);
// We cast this pointer to avoid a redundant check against < 127, as any such
// value would be negative in signed form.
int8_t const * token_ptr = (int8_t const *)&token;
int8_t const *token_ptr = (int8_t const *)&token;
clyring marked this conversation as resolved.
Show resolved Hide resolved
ptrdiff_t lookahead = 0;
if (token_ptr[3] > (int8_t)0xBF) {
lookahead = 1;
}
else if (token_ptr[2] > (int8_t)0xBF) {
} else if (token_ptr[2] > (int8_t)0xBF) {
lookahead = 2;
}
else if (token_ptr[1] > (int8_t)0xBF) {
} else if (token_ptr[1] > (int8_t)0xBF) {
lookahead = 3;
}
// Finish the job.
uint8_t const * const small_ptr = ptr - lookahead;
uint8_t const *const small_ptr = ptr - lookahead;
size_t const small_len = remaining + lookahead;
return is_valid_utf8_fallback(small_ptr, small_len);
}
Expand Down
91 changes: 65 additions & 26 deletions cbits/is-valid-utf8.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ SUCH DAMAGE.
#include <string.h>

#ifdef __x86_64__
#include <cpuid.h>
#include <emmintrin.h>
#include <immintrin.h>
#include <cpuid.h>
#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || defined(__clang_major__)) && !defined(__STDC_NO_ATOMICS__)
#include <tmmintrin.h>
#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || \
defined(__clang_major__)) && \
!defined(__STDC_NO_ATOMICS__)
#include <stdatomic.h>
#include <tmmintrin.h>
#else
// This is needed to support CentOS 7, which has a very old GCC.
#define CRUFTY_GCC
Expand All @@ -64,7 +66,8 @@ static inline uint64_t read_uint64(const uint64_t *p) {
return r;
}

static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const len) {
static inline int is_valid_utf8_fallback(uint8_t const *const src,
size_t const len) {
uint8_t const *ptr = (uint8_t const *)src;
// This is 'one past the end' to make loop termination and bounds checks
// easier.
Expand All @@ -83,10 +86,11 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
// Non-ASCII bytes have a set MSB. Thus, if we AND with 0x80 in every
// 'lane', we will get 0 if everything is ASCII, and something else
// otherwise.
uint64_t results[4] = {to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
uint64_t results[4] = {
to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
if (results[0] == 0) {
ptr += 8;
if (results[1] == 0) {
Expand All @@ -96,16 +100,16 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
if (results[3] == 0) {
ptr += 8;
} else {
ptr += (__builtin_ctzl(results[3]) / 8);
ptr += (__builtin_ctzll(results[3]) / 8);
}
} else {
ptr += (__builtin_ctzl(results[2]) / 8);
ptr += (__builtin_ctzll(results[2]) / 8);
}
} else {
ptr += (__builtin_ctzl(results[1]) / 8);
ptr += (__builtin_ctzll(results[1]) / 8);
}
} else {
ptr += (__builtin_ctzl(results[0]) / 8);
ptr += (__builtin_ctzll(results[0]) / 8);
}
}
}
Expand Down Expand Up @@ -203,16 +207,16 @@ static inline int is_valid_utf8_sse2(uint8_t const *const src,
if (result == 0) {
ptr += 16;
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
} else {
ptr += __builtin_ctz(result);
ptr += __builtin_ctzll(result);
}
}
}
Expand Down Expand Up @@ -331,10 +335,26 @@ static int8_t const ef_fe_lookup[16] = {
};

__attribute__((target("ssse3"))) static inline bool
is_ascii_sse2(__m128i const *src) {
is_ascii_sse2(__m128i const *src, __m128i const prev_first_len) {
clyring marked this conversation as resolved.
Show resolved Hide resolved
// Check if we have ASCII, and also that we don't have to treat the prior
// block as special.
// First, verify that we didn't see any non-ASCII bytes in the first half of
// the stride.
__m128i const first_half_clean = _mm_or_si128(src[0], src[1]);
// Then do the same for the second half of the stride.
__m128i const second_half_clean = _mm_or_si128(src[2], src[3]);
// Check cleanliness of the entire stride.
__m128i const stride_clean =
_mm_or_si128(first_half_clean, second_half_clean);
// Finally, check that we didn't have any leftover marker bytes in the
// previous block: these are indicated by non-zeroes in prev_first_len. In
// order to trigger a failure, we have to have non-zeros set the high bit of
// the lane: we do this by doing a greater-than comparison with a block of
// zeroes.
__m128i const no_prior_dirt =
_mm_cmpgt_epi8(prev_first_len, _mm_setzero_si128());
clyring marked this conversation as resolved.
Show resolved Hide resolved
// OR together everything, then check for a high bit anywhere.
__m128i const ored =
_mm_or_si128(_mm_or_si128(src[0], src[1]), _mm_or_si128(src[2], src[3]));
__m128i const ored = _mm_or_si128(stride_clean, no_prior_dirt);
return (_mm_movemask_epi8(ored) == 0);
}

Expand Down Expand Up @@ -415,7 +435,7 @@ is_valid_utf8_ssse3(uint8_t const *const src, size_t const len) {
_mm_loadu_si128(big_ptr), _mm_loadu_si128(big_ptr + 1),
_mm_loadu_si128(big_ptr + 2), _mm_loadu_si128(big_ptr + 3)};
// Check if we have ASCII.
if (is_ascii_sse2(inputs)) {
if (is_ascii_sse2(inputs, prev_first_len)) {
// Prev_first_len cheaply.
prev_first_len =
_mm_shuffle_epi8(first_len_tbl, high_nibbles_of(inputs[3]));
Expand Down Expand Up @@ -598,10 +618,26 @@ is_valid_utf8_avx2(uint8_t const *const src, size_t const len) {
__m256i const inputs[4] = {
_mm256_loadu_si256(big_ptr), _mm256_loadu_si256(big_ptr + 1),
_mm256_loadu_si256(big_ptr + 2), _mm256_loadu_si256(big_ptr + 3)};
// Check if we have ASCII.
bool is_ascii = _mm256_movemask_epi8(_mm256_or_si256(
_mm256_or_si256(inputs[0], inputs[1]),
_mm256_or_si256(inputs[2], inputs[3]))) == 0;
// Check if we have ASCII, and also that we don't have to treat the prior
// block as special.
// First, verify that we didn't see any non-ASCII bytes in the first half of
// the stride.
__m256i const first_half_clean = _mm256_or_si256(inputs[0], inputs[1]);
// Then do the same for the second half of the stride.
__m256i const second_half_clean = _mm256_or_si256(inputs[2], inputs[3]);
// Check cleanliness of the entire stride.
__m256i const stride_clean =
_mm256_or_si256(first_half_clean, second_half_clean);
// Finally, check that we didn't have any leftover marker bytes in the
// previous block: these are indicated by non-zeroes in prev_first_len.
// In order to trigger a failure, we have to have non-zeros set the high bit
// of the lane: we do this by doing a greater-than comparison with a block
// of zeroes.
__m256i const no_prior_dirt =
_mm256_cmpgt_epi8(prev_first_len, _mm256_setzero_si256());
// Combine all checks together, and check if any high bits are set.
bool is_ascii =
clyring marked this conversation as resolved.
Show resolved Hide resolved
_mm256_movemask_epi8(_mm256_or_si256(stride_clean, no_prior_dirt)) == 0;
if (is_ascii) {
// Prev_first_len cheaply
prev_first_len =
Expand Down Expand Up @@ -683,7 +719,7 @@ static inline bool has_avx2() {
}
#endif

typedef int (*is_valid_utf8_t) (uint8_t const *const, size_t const);
typedef int (*is_valid_utf8_t)(uint8_t const *const, size_t const);

int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
if (len == 0) {
Expand All @@ -693,7 +729,10 @@ int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
static _Atomic is_valid_utf8_t s_impl = (is_valid_utf8_t)NULL;
is_valid_utf8_t impl = atomic_load_explicit(&s_impl, memory_order_relaxed);
if (!impl) {
impl = has_avx2() ? is_valid_utf8_avx2 : (has_ssse3() ? is_valid_utf8_ssse3 : (has_sse2() ? is_valid_utf8_sse2 : is_valid_utf8_fallback));
impl = has_avx2() ? is_valid_utf8_avx2
: (has_ssse3() ? is_valid_utf8_ssse3
: (has_sse2() ? is_valid_utf8_sse2
: is_valid_utf8_fallback));
atomic_store_explicit(&s_impl, impl, memory_order_relaxed);
}
return (*impl)(src, len);
Expand Down
Loading