Skip to content

Commit

Permalink
Add a few more size-overflow-related checks (#599)
Browse files Browse the repository at this point in the history
  • Loading branch information
clyring committed Jul 6, 2023
1 parent 88f16dc commit 470b6e3
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 83 deletions.
11 changes: 4 additions & 7 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[0.12.0.0]Unreleased
[0.12.0.0]July 2023

* __Breaking Changes__:
* [`readInt` returns `Nothing`, if the sequence of digits cannot be represented by an `Int`, instead of overflowing silently](https://github.com/haskell/bytestring/pull/309)
Expand All @@ -14,11 +14,8 @@
* [`stimes @StrictByteString`](https://github.com/haskell/bytestring/pull/443)
* [`Data.ByteString.Short.concat`](https://github.com/haskell/bytestring/pull/443)
* [`Data.ByteString.Short.append`](https://github.com/haskell/bytestring/pull/443)
<!-- TODO: Some other `ShortByteString` functions are probably still
susceptible to bad behavior on `Int` overflow in edge cases;
`D.B.Short.Internal.create` does not check for negative size,
unlike its `StrictByteString` counterpart.
-->
* [`Data.ByteString.Short.snoc`](https://github.com/haskell/bytestring/pull/599)
* [`Data.ByteString.Short.cons`](https://github.com/haskell/bytestring/pull/599)
* API additions:
* [New sized and/or unsigned variants of `readInt` and `readInteger`](https://github.com/haskell/bytestring/pull/438)
* [`Data.ByteString.Internal` now provides `SizeOverflowException`, `overflowError`, and `checkedMultiply`](https://github.com/haskell/bytestring/pull/443)
Expand All @@ -34,7 +31,7 @@

[0.12.0.0]: https://github.com/haskell/bytestring/compare/0.11.5.0...0.12.0.0

[0.11.5.0]Unreleased
[0.11.5.0]July 2023

* Bug fixes:
* [Fix multiple bugs with ASCII blocks in the SIMD implementations for `isValidUtf8`](https://github.com/haskell/bytestring/pull/582)
Expand Down
14 changes: 7 additions & 7 deletions Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -380,16 +380,16 @@ infixl 5 `snoc`
-- | /O(n)/ 'cons' is analogous to (:) for lists, but of different
-- complexity, as it requires making a copy.
cons :: Word8 -> ByteString -> ByteString
cons c (BS x l) = unsafeCreateFp (l+1) $ \p -> do
cons c (BS x len) = unsafeCreateFp (checkedAdd "cons" len 1) $ \p -> do
pokeFp p c
memcpyFp (p `plusForeignPtr` 1) x l
memcpyFp (p `plusForeignPtr` 1) x len
{-# INLINE cons #-}

-- | /O(n)/ Append a byte to the end of a 'ByteString'
snoc :: ByteString -> Word8 -> ByteString
snoc (BS x l) c = unsafeCreateFp (l+1) $ \p -> do
memcpyFp p x l
pokeFp (p `plusForeignPtr` l) c
snoc (BS x len) c = unsafeCreateFp (checkedAdd "snoc" len 1) $ \p -> do
memcpyFp p x len
pokeFp (p `plusForeignPtr` len) c
{-# INLINE snoc #-}

-- | /O(1)/ Extract the first element of a ByteString, which must be non-empty.
Expand Down Expand Up @@ -773,7 +773,7 @@ scanl
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanl f v = \(BS a len) -> unsafeCreateFp (len+1) $ \q -> do
scanl f v = \(BS a len) -> unsafeCreateFp (checkedAdd "scanl" len 1) $ \q -> do
-- see fold inlining
pokeFp q v
let
Expand Down Expand Up @@ -817,7 +817,7 @@ scanr
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanr f v = \(BS a len) -> unsafeCreateFp (len+1) $ \b -> do
scanr f v = \(BS a len) -> unsafeCreateFp (checkedAdd "scanr" len 1) $ \b -> do
-- see fold inlining
pokeFpByteOff b len v
let
Expand Down
60 changes: 33 additions & 27 deletions Data/ByteString/Internal/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -653,30 +653,30 @@ unsafeCreateFpUptoN' l f = unsafeDupablePerformIO (createFpUptoN' l f)

-- | Create ByteString of size @l@ and use action @f@ to fill its contents.
createFp :: Int -> (ForeignPtr Word8 -> IO ()) -> IO ByteString
createFp l action = do
fp <- mallocByteString l
createFp len action = assert (len >= 0) $ do
fp <- mallocByteString len
action fp
mkDeferredByteString fp l
mkDeferredByteString fp len
{-# INLINE createFp #-}

-- | Given a maximum size @l@ and an action @f@ that fills the 'ByteString'
-- starting at the given 'Ptr' and returns the actual utilized length,
-- @`createFpUptoN'` l f@ returns the filled 'ByteString'.
createFpUptoN :: Int -> (ForeignPtr Word8 -> IO Int) -> IO ByteString
createFpUptoN l action = do
fp <- mallocByteString l
l' <- action fp
assert (l' <= l) $ mkDeferredByteString fp l'
createFpUptoN maxLen action = assert (maxLen >= 0) $ do
fp <- mallocByteString maxLen
len <- action fp
assert (0 <= len && len <= maxLen) $ mkDeferredByteString fp len
{-# INLINE createFpUptoN #-}

-- | Like 'createFpUptoN', but also returns an additional value created by the
-- action.
createFpUptoN' :: Int -> (ForeignPtr Word8 -> IO (Int, a)) -> IO (ByteString, a)
createFpUptoN' l action = do
fp <- mallocByteString l
(l', res) <- action fp
bs <- mkDeferredByteString fp l'
assert (l' <= l) $ pure (bs, res)
createFpUptoN' maxLen action = assert (maxLen >= 0) $ do
fp <- mallocByteString maxLen
(len, res) <- action fp
bs <- mkDeferredByteString fp len
assert (0 <= len && len <= maxLen) $ pure (bs, res)
{-# INLINE createFpUptoN' #-}

-- | Given the maximum size needed and a function to make the contents
Expand All @@ -688,22 +688,26 @@ createFpUptoN' l action = do
-- ByteString functions, using Haskell or C functions to fill the space.
--
createFpAndTrim :: Int -> (ForeignPtr Word8 -> IO Int) -> IO ByteString
createFpAndTrim l action = do
fp <- mallocByteString l
l' <- action fp
if assert (0 <= l' && l' <= l) $ l' >= l
then mkDeferredByteString fp l
else createFp l' $ \dest -> memcpyFp dest fp l'
createFpAndTrim maxLen action = assert (maxLen >= 0) $ do
fp <- mallocByteString maxLen
len <- action fp
if assert (0 <= len && len <= maxLen) $ len >= maxLen
then mkDeferredByteString fp maxLen
else createFp len $ \dest -> memcpyFp dest fp len
{-# INLINE createFpAndTrim #-}

createFpAndTrim' :: Int -> (ForeignPtr Word8 -> IO (Int, Int, a)) -> IO (ByteString, a)
createFpAndTrim' l action = do
fp <- mallocByteString l
(off, l', res) <- action fp
bs <- if assert (0 <= l' && l' <= l) $ l' >= l
then mkDeferredByteString fp l -- entire buffer used => offset is zero
else createFp l' $ \dest ->
memcpyFp dest (fp `plusForeignPtr` off) l'
createFpAndTrim' maxLen action = assert (maxLen >= 0) $ do
fp <- mallocByteString maxLen
(off, len, res) <- action fp
assert (
0 <= len && len <= maxLen && -- length OK
(len == 0 || (0 <= off && off <= maxLen - len)) -- offset OK
) $ pure ()
bs <- if len >= maxLen
then mkDeferredByteString fp maxLen -- entire buffer used => offset is zero
else createFp len $ \dest ->
memcpyFp dest (fp `plusForeignPtr` off) len
return (bs, res)
{-# INLINE createFpAndTrim' #-}

Expand Down Expand Up @@ -971,8 +975,10 @@ overflowError fun = throw $ SizeOverflowException msg
checkedAdd :: String -> Int -> Int -> Int
{-# INLINE checkedAdd #-}
checkedAdd fun x y
| r >= 0 = r
| otherwise = overflowError fun
-- checking "r < 0" here matches the condition in mallocPlainForeignPtrBytes,
-- helping the compiler see the latter is redundant in some places
| r < 0 = overflowError fun
| otherwise = r
where r = assert (min x y >= 0) $ x + y

-- | Multiplies two non-negative numbers.
Expand Down
87 changes: 45 additions & 42 deletions Data/ByteString/Short/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ unSBS (ShortByteString (ByteArray ba#)) = ba#

create :: Int -> (forall s. MBA s -> ST s ()) -> ShortByteString
create len fill =
runST $ do
assert (len >= 0) $ runST $ do
mba <- newByteArray len
fill mba
BA# ba# <- unsafeFreezeByteArray mba
Expand All @@ -413,59 +413,60 @@ create len fill =
-- (<= the maximum size) and the result value. The resulting byte array
-- is realloced to this size.
createAndTrim :: Int -> (forall s. MBA s -> ST s (Int, a)) -> (ShortByteString, a)
createAndTrim l fill =
runST $ do
mba <- newByteArray l
(l', res) <- fill mba
if assert (l' <= l) $ l' >= l
createAndTrim maxLen fill =
assert (maxLen >= 0) $ runST $ do
mba <- newByteArray maxLen
(len, res) <- fill mba
if assert (0 <= len && len <= maxLen) $ len >= maxLen
then do
BA# ba# <- unsafeFreezeByteArray mba
return (SBS ba#, res)
else do
mba2 <- newByteArray l'
copyMutableByteArray mba 0 mba2 0 l'
mba2 <- newByteArray len
copyMutableByteArray mba 0 mba2 0 len
BA# ba# <- unsafeFreezeByteArray mba2
return (SBS ba#, res)
{-# INLINE createAndTrim #-}

createAndTrim' :: Int -> (forall s. MBA s -> ST s Int) -> ShortByteString
createAndTrim' l fill =
runST $ do
mba <- newByteArray l
l' <- fill mba
if assert (l' <= l) $ l' >= l
createAndTrim' maxLen fill =
assert (maxLen >= 0) $ runST $ do
mba <- newByteArray maxLen
len <- fill mba
if assert (0 <= len && len <= maxLen) $ len >= maxLen
then do
BA# ba# <- unsafeFreezeByteArray mba
return (SBS ba#)
else do
mba2 <- newByteArray l'
copyMutableByteArray mba 0 mba2 0 l'
mba2 <- newByteArray len
copyMutableByteArray mba 0 mba2 0 len
BA# ba# <- unsafeFreezeByteArray mba2
return (SBS ba#)
{-# INLINE createAndTrim' #-}

createAndTrim'' :: Int -> (forall s. MBA s -> MBA s -> ST s (Int, Int)) -> (ShortByteString, ShortByteString)
createAndTrim'' l fill =
-- | Like createAndTrim, but with two buffers at once
createAndTrim2 :: Int -> Int -> (forall s. MBA s -> MBA s -> ST s (Int, Int)) -> (ShortByteString, ShortByteString)
createAndTrim2 maxLen1 maxLen2 fill =
runST $ do
mba1 <- newByteArray l
mba2 <- newByteArray l
(l1, l2) <- fill mba1 mba2
sbs1 <- freeze' l1 mba1
sbs2 <- freeze' l2 mba2
mba1 <- newByteArray maxLen1
mba2 <- newByteArray maxLen2
(len1, len2) <- fill mba1 mba2
sbs1 <- freeze' len1 maxLen1 mba1
sbs2 <- freeze' len2 maxLen2 mba2
pure (sbs1, sbs2)
where
freeze' :: Int -> MBA s -> ST s ShortByteString
freeze' l' mba =
if assert (l' <= l) $ l' >= l
freeze' :: Int -> Int -> MBA s -> ST s ShortByteString
freeze' len maxLen mba =
if assert (0 <= len && len <= maxLen) $ len >= maxLen
then do
BA# ba# <- unsafeFreezeByteArray mba
return (SBS ba#)
else do
mba2 <- newByteArray l'
copyMutableByteArray mba 0 mba2 0 l'
mba2 <- newByteArray len
copyMutableByteArray mba 0 mba2 0 len
BA# ba# <- unsafeFreezeByteArray mba2
return (SBS ba#)
{-# INLINE createAndTrim'' #-}
{-# INLINE createAndTrim2 #-}

isPinned :: ByteArray# -> Bool
#if MIN_VERSION_base(4,10,0)
Expand Down Expand Up @@ -676,23 +677,23 @@ infixl 5 `snoc`
--
-- @since 0.11.3.0
snoc :: ShortByteString -> Word8 -> ShortByteString
snoc = \sbs c -> let l = length sbs
nl = l + 1
in create nl $ \mba -> do
copyByteArray (asBA sbs) 0 mba 0 l
writeWord8Array mba l c
snoc = \sbs c -> let len = length sbs
newLen = checkedAdd "Short.snoc" len 1
in create newLen $ \mba -> do
copyByteArray (asBA sbs) 0 mba 0 len
writeWord8Array mba len c

-- | /O(n)/ 'cons' is analogous to (:) for lists.
--
-- Note: copies the entire byte array
--
-- @since 0.11.3.0
cons :: Word8 -> ShortByteString -> ShortByteString
cons c = \sbs -> let l = length sbs
nl = l + 1
in create nl $ \mba -> do
cons c = \sbs -> let len = length sbs
newLen = checkedAdd "Short.cons" len 1
in create newLen $ \mba -> do
writeWord8Array mba 0 c
copyByteArray (asBA sbs) 0 mba 1 l
copyByteArray (asBA sbs) 0 mba 1 len

-- | /O(1)/ Extract the last element of a ShortByteString, which must be finite and non-empty.
-- An exception will be thrown in the case of an empty ShortByteString.
Expand Down Expand Up @@ -1484,9 +1485,9 @@ find f = \sbs -> case findIndex f sbs of
--
-- @since 0.11.3.0
partition :: (Word8 -> Bool) -> ShortByteString -> (ShortByteString, ShortByteString)
partition k = \sbs -> let l = length sbs
in if | l <= 0 -> (sbs, sbs)
| otherwise -> createAndTrim'' l $ \mba1 mba2 -> go mba1 mba2 (asBA sbs) l
partition k = \sbs -> let len = length sbs
in if | len <= 0 -> (sbs, sbs)
| otherwise -> createAndTrim2 len len $ \mba1 mba2 -> go mba1 mba2 (asBA sbs) len
where
go :: forall s.
MBA s -- mutable output bytestring1
Expand Down Expand Up @@ -1614,12 +1615,14 @@ indexWord8ArrayAsWord64 (BA# ba#) (I# i#) = W64# (indexWord8ArrayAsWord64# ba# i
#endif

newByteArray :: Int -> ST s (MBA s)
newByteArray (I# len#) =
newByteArray len@(I# len#) =
assert (len >= 0) $
ST $ \s -> case newByteArray# len# s of
(# s', mba# #) -> (# s', MBA# mba# #)

newPinnedByteArray :: Int -> ST s (MBA s)
newPinnedByteArray (I# len#) =
newPinnedByteArray len@(I# len#) =
assert (len >= 0) $
ST $ \s -> case newPinnedByteArray# len# s of
(# s', mba# #) -> (# s', MBA# mba# #)

Expand Down

0 comments on commit 470b6e3

Please sign in to comment.