diff --git a/bytestring.cabal b/bytestring.cabal
index 510f12771..dc6c3e5e7 100644
--- a/bytestring.cabal
+++ b/bytestring.cabal
@@ -133,6 +133,7 @@ library

   -- DNDEBUG disables asserts in cbits/
   cc-options: -std=c11 -DNDEBUG=1
+              -fno-strict-aliasing

   -- No need to link to libgcc on ghc-9.4 and later which uses a clang-based
   -- toolchain.
diff --git a/cbits/aarch64/is-valid-utf8.c b/cbits/aarch64/is-valid-utf8.c
index e79c8b94b..afb446b12 100644
--- a/cbits/aarch64/is-valid-utf8.c
+++ b/cbits/aarch64/is-valid-utf8.c
@@ -29,10 +29,10 @@ SUCH DAMAGE.
 */
 #pragma GCC push_options
 #pragma GCC optimize("-O2")
+#include <arm_neon.h>
 #include <stdbool.h>
-#include <stdint.h>
 #include <stddef.h>
-#include <arm_neon.h>
+#include <stdint.h>

 // Fallback (for tails).
 static inline int is_valid_utf8_fallback(uint8_t const *const src,
@@ -102,51 +102,60 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src,
 }

 static uint8_t const first_len_lookup[16] = {
-  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3,
 };

 static uint8_t const first_range_lookup[16] = {
-  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8,
 };

 static uint8_t const range_min_lookup[16] = {
-  0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
-  0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+    0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
+    0xC2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 };

 static uint8_t const range_max_lookup[16] = {
-  0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
-  0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+    0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
+    0xF4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
 };

 static uint8_t const range_adjust_lookup[32] = {
-  2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
-  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
+    2, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
 };

-static bool is_ascii (uint8x16_t const * const inputs) {
-  uint8x16_t const all_80 = vdupq_n_u8(0x80);
-  // A non-ASCII byte will have its highest-order bit set. Since this is
-  // preserved by OR, we can OR everything together.
-  uint8x16_t ored = vorrq_u8(vorrq_u8(inputs[0], inputs[1]),
-                             vorrq_u8(inputs[2], inputs[3]));
-  // ANDing with 0x80 retains any set high-order bits. We then check for zeroes.
-  uint64x2_t result = vreinterpretq_u64_u8(vandq_u8(ored, all_80));
+static bool is_ascii(uint8x16_t const *const inputs,
+                     uint8x16_t const prev_first_len) {
+  // Check if we have ASCII, and also that we don't have to treat the prior
+  // block as special.
+  // First, verify that we didn't see any non-ASCII bytes in the first half of
+  // the stride.
+  uint8x16_t const first_half_clean = vorrq_u8(inputs[0], inputs[1]);
+  // Then we do the same for the second half of the stride.
+  uint8x16_t const second_half_clean = vorrq_u8(inputs[2], inputs[3]);
+  // Check cleanliness of the entire stride.
+  uint8x16_t const stride_clean = vorrq_u8(first_half_clean, second_half_clean);
+  // Leave only the high-order set bits.
+  uint8x16_t const masked = vandq_u8(stride_clean, vdupq_n_u8(0x80));
+  // Finally, check that we didn't have any leftover marker bytes in the
+  // previous block: these are indicated by non-zeroes in prev_first_len. In
+  // order to trigger a failure, we have to have non-zeroes set in the high bit
+  // of the lane: we do this by doing a greater-than comparison with a block of
+  // zeroes.
+  uint8x16_t const no_prior_dirt = vcgtq_u8(prev_first_len, vdupq_n_u8(0x00));
+  // Check for all-zero.
+  uint64x2_t const result =
+      vreinterpretq_u64_u8(vorrq_u8(masked, no_prior_dirt));
   return !(vgetq_lane_u64(result, 0) || vgetq_lane_u64(result, 1));
 }

-static void check_block_neon(uint8x16_t const prev_input,
-                             uint8x16_t const prev_first_len,
-                             uint8x16_t* errors,
-                             uint8x16_t const first_range_tbl,
-                             uint8x16_t const range_min_tbl,
-                             uint8x16_t const range_max_tbl,
-                             uint8x16x2_t const range_adjust_tbl,
-                             uint8x16_t const all_ones,
-                             uint8x16_t const all_twos,
-                             uint8x16_t const all_e0s,
-                             uint8x16_t const input,
-                             uint8x16_t const first_len) {
+static void
+check_block_neon(uint8x16_t const prev_input, uint8x16_t const prev_first_len,
+                 uint8x16_t *errors, uint8x16_t const first_range_tbl,
+                 uint8x16_t const range_min_tbl, uint8x16_t const range_max_tbl,
+                 uint8x16x2_t const range_adjust_tbl, uint8x16_t const all_ones,
+                 uint8x16_t const all_twos, uint8x16_t const all_e0s,
+                 uint8x16_t const input, uint8x16_t const first_len) {
   // Get the high 4-bits of the input.
   uint8x16_t const high_nibbles = vshrq_n_u8(input, 4);
   // Set range index to 8 for bytes in [C0, FF] by lookup (first byte).
@@ -182,20 +191,20 @@ static void check_block_neon(uint8x16_t const prev_input,
   errors[1] = vorrq_u8(errors[1], vcgtq_u8(input, maxv));
 }

-int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
+int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
   if (len == 0) {
     return 1;
   }
   // We step 64 bytes at a time.
   size_t const big_strides = len / 64;
   size_t const remaining = len % 64;
-  uint8_t const * ptr = (uint8_t const *)src;
+  uint8_t const *ptr = (uint8_t const *)src;
   // Tracking state
   uint8x16_t prev_input = vdupq_n_u8(0);
   uint8x16_t prev_first_len = vdupq_n_u8(0);
   uint8x16_t errors[2] = {
-    vdupq_n_u8(0),
-    vdupq_n_u8(0),
+      vdupq_n_u8(0),
+      vdupq_n_u8(0),
   };
   // Load our lookup tables.
   uint8x16_t const first_len_tbl = vld1q_u8(first_len_lookup);
@@ -209,40 +218,33 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
   uint8x16_t const all_e0s = vdupq_n_u8(0xE0);
   for (size_t i = 0; i < big_strides; i++) {
     // Load 64 bytes
-    uint8x16_t const inputs[4] = {
-      vld1q_u8(ptr),
-      vld1q_u8(ptr + 16),
-      vld1q_u8(ptr + 32),
-      vld1q_u8(ptr + 48)
-    };
+    uint8x16_t const inputs[4] = {vld1q_u8(ptr), vld1q_u8(ptr + 16),
+                                  vld1q_u8(ptr + 32), vld1q_u8(ptr + 48)};
     // Check if we have ASCII
-    if (is_ascii(inputs)) {
+    if (is_ascii(inputs, prev_first_len)) {
       // Prev_first_len cheaply.
       prev_first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
     } else {
-      uint8x16_t first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
-      check_block_neon(prev_input, prev_first_len, errors,
-                       first_range_tbl, range_min_tbl, range_max_tbl,
-                       range_adjust_tbl, all_ones, all_twos, all_e0s,
-                       inputs[0], first_len);
+      uint8x16_t first_len =
+          vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[0], 4));
+      check_block_neon(prev_input, prev_first_len, errors, first_range_tbl,
+                       range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
+                       all_twos, all_e0s, inputs[0], first_len);
       prev_first_len = first_len;
       first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[1], 4));
-      check_block_neon(inputs[0], prev_first_len, errors,
-                       first_range_tbl, range_min_tbl, range_max_tbl,
-                       range_adjust_tbl, all_ones, all_twos, all_e0s,
-                       inputs[1], first_len);
+      check_block_neon(inputs[0], prev_first_len, errors, first_range_tbl,
+                       range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
+                       all_twos, all_e0s, inputs[1], first_len);
       prev_first_len = first_len;
       first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[2], 4));
-      check_block_neon(inputs[1], prev_first_len, errors,
-                       first_range_tbl, range_min_tbl, range_max_tbl,
-                       range_adjust_tbl, all_ones, all_twos, all_e0s,
-                       inputs[2], first_len);
+      check_block_neon(inputs[1], prev_first_len, errors, first_range_tbl,
+                       range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
+                       all_twos, all_e0s, inputs[2], first_len);
       prev_first_len = first_len;
       first_len = vqtbl1q_u8(first_len_tbl, vshrq_n_u8(inputs[3], 4));
-      check_block_neon(inputs[2], prev_first_len, errors,
-                       first_range_tbl, range_min_tbl, range_max_tbl,
-                       range_adjust_tbl, all_ones, all_twos, all_e0s,
-                       inputs[3], first_len);
+      check_block_neon(inputs[2], prev_first_len, errors, first_range_tbl,
+                       range_min_tbl, range_max_tbl, range_adjust_tbl, all_ones,
+                       all_twos, all_e0s, inputs[3], first_len);
       prev_first_len = first_len;
     }
     // Set prev_input based on last block.
@@ -260,19 +262,17 @@ int bytestring_is_valid_utf8(uint8_t const * const src, size_t const len) {
   vst1q_lane_u32(&token, vreinterpretq_u32_u8(prev_input), 3);
   // We cast this pointer to avoid a redundant check against < 127, as any such
   // value would be negative in signed form.
-  int8_t const * token_ptr = (int8_t const *)&token;
+  int8_t const *token_ptr = (int8_t const *)&token;
   ptrdiff_t lookahead = 0;
   if (token_ptr[3] > (int8_t)0xBF) {
     lookahead = 1;
-  }
-  else if (token_ptr[2] > (int8_t)0xBF) {
+  } else if (token_ptr[2] > (int8_t)0xBF) {
     lookahead = 2;
-  }
-  else if (token_ptr[1] > (int8_t)0xBF) {
+  } else if (token_ptr[1] > (int8_t)0xBF) {
     lookahead = 3;
   }
   // Finish the job.
-  uint8_t const * const small_ptr = ptr - lookahead;
+  uint8_t const *const small_ptr = ptr - lookahead;
   size_t const small_len = remaining + lookahead;
   return is_valid_utf8_fallback(small_ptr, small_len);
 }
diff --git a/cbits/is-valid-utf8.c b/cbits/is-valid-utf8.c
index f8862bf3c..a9cc0fe8a 100644
--- a/cbits/is-valid-utf8.c
+++ b/cbits/is-valid-utf8.c
@@ -35,12 +35,14 @@ SUCH DAMAGE.
 #include <string.h>

 #ifdef __x86_64__
+#include <cpuid.h>
 #include <emmintrin.h>
 #include <tmmintrin.h>
-#include <cpuid.h>
-#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 || defined(__clang_major__)) && !defined(__STDC_NO_ATOMICS__)
-#include <stdatomic.h>
+#if (__GNUC__ >= 7 || __GNUC__ == 6 && __GNUC_MINOR__ >= 3 ||                 \
+     defined(__clang_major__)) &&                                             \
+    !defined(__STDC_NO_ATOMICS__)
 #include <immintrin.h>
+#include <stdatomic.h>
 #else
 // This is needed to support CentOS 7, which has a very old GCC.
 #define CRUFTY_GCC
@@ -64,7 +66,8 @@ static inline uint64_t read_uint64(const uint64_t *p) {
   return r;
 }

-static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const len) {
+static inline int is_valid_utf8_fallback(uint8_t const *const src,
+                                         size_t const len) {
   uint8_t const *ptr = (uint8_t const *)src;
   // This is 'one past the end' to make loop termination and bounds checks
   // easier.
@@ -83,10 +86,11 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
       // Non-ASCII bytes have a set MSB. Thus, if we AND with 0x80 in every
       // 'lane', we will get 0 if everything is ASCII, and something else
       // otherwise.
-      uint64_t results[4] = {to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
-                             to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
-                             to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
-                             to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
+      uint64_t results[4] = {
+          to_little_endian(read_uint64(big_ptr)) & high_bits_mask,
+          to_little_endian(read_uint64((big_ptr + 1))) & high_bits_mask,
+          to_little_endian(read_uint64((big_ptr + 2))) & high_bits_mask,
+          to_little_endian(read_uint64((big_ptr + 3))) & high_bits_mask};
       if (results[0] == 0) {
         ptr += 8;
         if (results[1] == 0) {
@@ -96,16 +100,16 @@ static inline int is_valid_utf8_fallback(uint8_t const *const src, size_t const
           if (results[3] == 0) {
             ptr += 8;
           } else {
-            ptr += (__builtin_ctzl(results[3]) / 8);
+            ptr += (__builtin_ctzll(results[3]) / 8);
           }
         } else {
-          ptr += (__builtin_ctzl(results[2]) / 8);
+          ptr += (__builtin_ctzll(results[2]) / 8);
         }
       } else {
-        ptr += (__builtin_ctzl(results[1]) / 8);
+        ptr += (__builtin_ctzll(results[1]) / 8);
       }
     } else {
-      ptr += (__builtin_ctzl(results[0]) / 8);
+      ptr += (__builtin_ctzll(results[0]) / 8);
     }
   }
 }
@@ -203,16 +207,16 @@ static inline int is_valid_utf8_sse2(uint8_t const *const src,
           if (result == 0) {
             ptr += 16;
           } else {
-            ptr += __builtin_ctz(result);
+            ptr += __builtin_ctzll(result);
          }
        } else {
-          ptr += __builtin_ctz(result);
+          ptr += __builtin_ctzll(result);
        }
      } else {
-        ptr += __builtin_ctz(result);
+        ptr += __builtin_ctzll(result);
      }
    } else {
-      ptr += __builtin_ctz(result);
+      ptr += __builtin_ctzll(result);
    }
  }
 }
@@ -331,10 +335,26 @@ static int8_t const ef_fe_lookup[16] = {
 };

 __attribute__((target("ssse3"))) static inline bool
-is_ascii_sse2(__m128i const *src) {
+is_ascii_sse2(__m128i const *src, __m128i const prev_first_len) {
+  // Check if we have ASCII, and also that we don't have to treat the prior
+  // block as special.
+  // First, verify that we didn't see any non-ASCII bytes in the first half of
+  // the stride.
+  __m128i const first_half_clean = _mm_or_si128(src[0], src[1]);
+  // Then do the same for the second half of the stride.
+  __m128i const second_half_clean = _mm_or_si128(src[2], src[3]);
+  // Check cleanliness of the entire stride.
+  __m128i const stride_clean =
+      _mm_or_si128(first_half_clean, second_half_clean);
+  // Finally, check that we didn't have any leftover marker bytes in the
+  // previous block: these are indicated by non-zeroes in prev_first_len. In
+  // order to trigger a failure, we have to have non-zeros set the high bit of
+  // the lane: we do this by doing a greater-than comparison with a block of
+  // zeroes.
+  __m128i const no_prior_dirt =
+      _mm_cmpgt_epi8(prev_first_len, _mm_setzero_si128());
   // OR together everything, then check for a high bit anywhere.
-  __m128i const ored =
-      _mm_or_si128(_mm_or_si128(src[0], src[1]), _mm_or_si128(src[2], src[3]));
+  __m128i const ored = _mm_or_si128(stride_clean, no_prior_dirt);
   return (_mm_movemask_epi8(ored) == 0);
 }

@@ -415,7 +435,7 @@ is_valid_utf8_ssse3(uint8_t const *const src, size_t const len) {
         _mm_loadu_si128(big_ptr), _mm_loadu_si128(big_ptr + 1),
         _mm_loadu_si128(big_ptr + 2), _mm_loadu_si128(big_ptr + 3)};
     // Check if we have ASCII.
-    if (is_ascii_sse2(inputs)) {
+    if (is_ascii_sse2(inputs, prev_first_len)) {
       // Prev_first_len cheaply.
       prev_first_len =
           _mm_shuffle_epi8(first_len_tbl, high_nibbles_of(inputs[3]));
@@ -598,10 +618,26 @@ is_valid_utf8_avx2(uint8_t const *const src, size_t const len) {
     __m256i const inputs[4] = {
         _mm256_loadu_si256(big_ptr), _mm256_loadu_si256(big_ptr + 1),
         _mm256_loadu_si256(big_ptr + 2), _mm256_loadu_si256(big_ptr + 3)};
-    // Check if we have ASCII.
-    bool is_ascii = _mm256_movemask_epi8(_mm256_or_si256(
-                        _mm256_or_si256(inputs[0], inputs[1]),
-                        _mm256_or_si256(inputs[2], inputs[3]))) == 0;
+    // Check if we have ASCII, and also that we don't have to treat the prior
+    // block as special.
+    // First, verify that we didn't see any non-ASCII bytes in the first half of
+    // the stride.
+    __m256i const first_half_clean = _mm256_or_si256(inputs[0], inputs[1]);
+    // Then do the same for the second half of the stride.
+    __m256i const second_half_clean = _mm256_or_si256(inputs[2], inputs[3]);
+    // Check cleanliness of the entire stride.
+    __m256i const stride_clean =
+        _mm256_or_si256(first_half_clean, second_half_clean);
+    // Finally, check that we didn't have any leftover marker bytes in the
+    // previous block: these are indicated by non-zeroes in prev_first_len.
+    // In order to trigger a failure, we have to have non-zeros set the high bit
+    // of the lane: we do this by doing a greater-than comparison with a block
+    // of zeroes.
+    __m256i const no_prior_dirt =
+        _mm256_cmpgt_epi8(prev_first_len, _mm256_setzero_si256());
+    // Combine all checks together, and check if any high bits are set.
+    bool is_ascii =
+        _mm256_movemask_epi8(_mm256_or_si256(stride_clean, no_prior_dirt)) == 0;
     if (is_ascii) {
       // Prev_first_len cheaply
       prev_first_len =
@@ -683,7 +719,7 @@ static inline bool has_avx2() {
 }
 #endif

-typedef int (*is_valid_utf8_t) (uint8_t const *const, size_t const);
+typedef int (*is_valid_utf8_t)(uint8_t const *const, size_t const);

 int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
   if (len == 0) {
@@ -693,7 +729,10 @@ int bytestring_is_valid_utf8(uint8_t const *const src, size_t const len) {
   static _Atomic is_valid_utf8_t s_impl = (is_valid_utf8_t)NULL;
   is_valid_utf8_t impl = atomic_load_explicit(&s_impl, memory_order_relaxed);
   if (!impl) {
-    impl = has_avx2() ? is_valid_utf8_avx2 : (has_ssse3() ? is_valid_utf8_ssse3 : (has_sse2() ? is_valid_utf8_sse2 : is_valid_utf8_fallback));
+    impl = has_avx2() ? is_valid_utf8_avx2
+           : (has_ssse3() ? is_valid_utf8_ssse3
+                          : (has_sse2() ? is_valid_utf8_sse2
+                                        : is_valid_utf8_fallback));
     atomic_store_explicit(&s_impl, impl, memory_order_relaxed);
   }
   return (*impl)(src, len);
diff --git a/tests/IsValidUtf8.hs b/tests/IsValidUtf8.hs
index a09f587dc..7901a2d58 100644
--- a/tests/IsValidUtf8.hs
+++ b/tests/IsValidUtf8.hs
@@ -8,16 +8,18 @@ import qualified Data.ByteString.Short as SBS
 import qualified Data.ByteString as B
 import Data.Char (chr, ord)
 import Data.Word (Word8)
-import GHC.Exts (fromList)
-import Test.QuickCheck (Property, forAll, (===))
+import Control.Monad (guard)
+import Numeric (showHex)
+import GHC.Exts (fromList, fromListN, toList)
+import Test.QuickCheck (Property, forAll, (===), forAllShrinkShow)
 import Test.QuickCheck.Arbitrary (Arbitrary (arbitrary, shrink))
 import Test.QuickCheck.Gen (oneof, Gen, choose, vectorOf, listOf1, sized, resize,
-                            elements)
+                            elements, choose)
 import Test.Tasty (testGroup, adjustOption, TestTree)
 import Test.Tasty.QuickCheck (testProperty, QuickCheckTests)

 testSuite :: TestTree
-testSuite = testGroup "UTF-8 validation" $ [
+testSuite = testGroup "UTF-8 validation" [
   adjustOption (max testCount) . testProperty "Valid UTF-8 ByteString" $ goValidBS,
   adjustOption (max testCount) . testProperty "Invalid UTF-8 ByteString" $ goInvalidBS,
   adjustOption (max testCount) . testProperty "Valid UTF-8 ShortByteString" $ goValidSBS,
@@ -25,24 +27,30 @@ testSuite = testGroup "UTF-8 validation" $ [
   testGroup "Regressions" checkRegressions
   ]
   where
-    goValidBS :: Property
-    goValidBS = forAll arbitrary $
-      \(ValidUtf8 ss) -> (B.isValidUtf8 . foldMap sequenceToBS $ ss) === True
-    goInvalidBS :: Property
-    goInvalidBS = forAll arbitrary $
-      \inv -> (B.isValidUtf8 . toByteString $ inv) === False
-    goValidSBS :: Property
-    goValidSBS = forAll arbitrary $
-      \(ValidUtf8 ss) -> (SBS.isValidUtf8 . SBS.toShort . foldMap sequenceToBS $ ss) === True
-    goInvalidSBS :: Property
-    goInvalidSBS = forAll arbitrary $
-      \inv -> (SBS.isValidUtf8 . SBS.toShort . toByteString $ inv) === False
+    goValidBS :: ValidUtf8 -> Bool
+    goValidBS = B.isValidUtf8 . foldMap sequenceToBS . unValidUtf8
+    goInvalidBS :: InvalidUtf8 -> Bool
+    goInvalidBS = not . B.isValidUtf8 . toByteString
+    goValidSBS :: ValidUtf8 -> Bool
+    goValidSBS = SBS.isValidUtf8 . SBS.toShort . foldMap sequenceToBS . unValidUtf8
+    goInvalidSBS :: InvalidUtf8 -> Bool
+    goInvalidSBS = not . SBS.isValidUtf8 . SBS.toShort . toByteString
     testCount :: QuickCheckTests
     testCount = 1000

 checkRegressions :: [TestTree]
 checkRegressions = [
-  testProperty "Too high code point" $ not $ B.isValidUtf8 tooHigh
+  testProperty "Too high code point" $
+    not $ B.isValidUtf8 tooHigh,
+  testProperty "Invalid byte at end of ASCII block" badBlockEnd,
+  testProperty "Invalid byte between spaces" $
+    not $ B.isValidUtf8 byteBetweenSpaces,
+  testProperty "Two invalid bytes between spaces" $
+    not $ B.isValidUtf8 twoBytesBetweenSpaces,
+  testProperty "Three invalid bytes between spaces" $
+    not $ B.isValidUtf8 threeBytesBetweenSpaces,
+  testProperty "ASCII stride and invalid multibyte sequence" $
+    not $ B.isValidUtf8 asciiAndInvalidMultiByte
   ]
   where
     tooHigh :: ByteString
@@ -50,8 +58,46 @@ checkRegressions = [
       [244, 176, 181, 139] ++ -- our invalid sequence too high to be valid
       (take 68 . cycle $ [194, 162]) -- 68 cent symbols
+    byteBetweenSpaces :: ByteString
+    byteBetweenSpaces = fromList $ replicate 127 32 ++ [216] ++ replicate 128 32
+
+    twoBytesBetweenSpaces :: ByteString
+    twoBytesBetweenSpaces = fromList $ replicate 126 32 ++ [235, 167] ++ replicate 128 32
+
+    threeBytesBetweenSpaces :: ByteString
+    threeBytesBetweenSpaces = fromList $ replicate 125 32 ++ [242, 134, 159] ++ replicate 128 32
+
+    badBlockEnd :: Property
+    badBlockEnd =
+      forAllShrinkShow genBadBlock shrinkBadBlock showBadBlock $ \(BadBlock bs) ->
+        not . B.isValidUtf8 $ bs
+
+    asciiAndInvalidMultiByte :: ByteString
+    asciiAndInvalidMultiByte = fromList $ replicate 32 48 ++ [235, 185]
+
 -- Helpers

+-- A 128-byte sequence with a single bad byte at the end, with the rest being
+-- ASCII
+newtype BadBlock = BadBlock ByteString
+
+genBadBlock :: Gen BadBlock
+genBadBlock = do
+  asciiBytes <- vectorOf 127 $ choose (0, 127)
+  pure . BadBlock . fromListN 128 $ asciiBytes ++ [216]
+
+shrinkBadBlock :: BadBlock -> [BadBlock]
+shrinkBadBlock (BadBlock bs) = BadBlock <$> do
+  let asList = init . toList $ bs
+  init' <- fromList <$> traverse shrink asList
+  guard (B.length init' == 127)
+  pure . B.append init' . B.singleton $ 216
+
+-- Display as hex instead of ASCII-ish
+showBadBlock :: BadBlock -> String
+showBadBlock (BadBlock bs) = let asList = toList bs in
+  foldr showHex "" asList
+
 data Utf8Sequence =
   One Word8 |
   Two Word8 Word8 |
@@ -155,7 +201,7 @@ sequenceToBS = B.pack . \case
   Three w1 w2 w3 -> [w1, w2, w3]
   Four w1 w2 w3 w4 -> [w1, w2, w3, w4]

-newtype ValidUtf8 = ValidUtf8 [Utf8Sequence]
+newtype ValidUtf8 = ValidUtf8 { unValidUtf8 :: [Utf8Sequence] }
   deriving (Eq)

 instance Show ValidUtf8 where
@@ -188,8 +234,8 @@ instance Arbitrary InvalidUtf8 where
                 , InvalidUtf8 <$> genValidUtf8 <*> genInvalidUtf8 <*> genValidUtf8
                 ]
   shrink (InvalidUtf8 p i s) =
-    (InvalidUtf8 p i <$> shrinkBS s) ++
-    ((\p' -> InvalidUtf8 p' i s) <$> shrinkBS p)
+    (InvalidUtf8 p i <$> shrinkValidBS s) ++
+    ((\p' -> InvalidUtf8 p' i s) <$> shrinkValidBS p)

 toByteString :: InvalidUtf8 -> ByteString
 toByteString (InvalidUtf8 p i s) = p `B.append` i `B.append` s
@@ -240,7 +286,8 @@ genValidUtf8 = sized $ \size ->
     B.append <$> genAscii <*> resize (size `div` 2) genValidUtf8,
     B.append <$> gen2Byte <*> resize (size `div` 2) genValidUtf8,
     B.append <$> gen3Byte <*> resize (size `div` 2) genValidUtf8,
-    B.append <$> gen4Byte <*> resize (size `div` 2) genValidUtf8
+    B.append <$> gen4Byte <*> resize (size `div` 2) genValidUtf8,
+    B.replicate <$> resize (size * 16) arbitrary <*> elements [0x00 .. 0x7F]
     ]
   where
     genAscii :: Gen ByteString
@@ -270,8 +317,8 @@ genValidUtf8 = sized $ \size ->
      b4 <- elements [0x80 .. 0xBF]
      pure . B.pack $ [b1, b2, b3, b4]

-shrinkBS :: ByteString -> [ByteString]
-shrinkBS bs = B.pack <$> (shrink . B.unpack $ bs)
+shrinkValidBS :: ByteString -> [ByteString]
+shrinkValidBS bs = filter B.isValidUtf8 (map B.pack (shrink (B.unpack bs)))

 ord2 :: Char -> (Word8, Word8)
 ord2 c = (x, y)
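
For reference (this note and sketch are not part of the patch): the substantive change repeated across the NEON, SSSE3, and AVX2 paths above is that the ASCII fast path now also requires prev_first_len to be all zeroes, so a pure-ASCII stride can no longer be skipped while the previous block still has an unfinished multi-byte sequence pending; the new byteBetweenSpaces-style regression tests appear designed to hit exactly that case. A rough scalar sketch of the condition being checked, with a made-up helper name and prev_first_len treated as a plain byte array rather than a vector register:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Illustrative only; the patched code computes this with SIMD vectors. */
static bool can_take_ascii_fast_path(uint8_t const *stride, size_t stride_len,
                                     uint8_t const *prev_first_len,
                                     size_t lanes) {
  /* Any byte with its high bit set means the stride is not pure ASCII. */
  for (size_t i = 0; i < stride_len; i++) {
    if (stride[i] & 0x80) {
      return false;
    }
  }
  /* A non-zero lane means the previous block started a multi-byte sequence
     whose continuation bytes still have to be validated by the full check. */
  for (size_t i = 0; i < lanes; i++) {
    if (prev_first_len[i] != 0) {
      return false;
    }
  }
  return true;
}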