Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Data.ByteString.Short.Internal.MBA #617

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 36 additions & 38 deletions Data/ByteString/Short/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ import Data.ByteString.Internal.Type
)

import Data.Array.Byte
( ByteArray(..) )
( ByteArray(..), MutableByteArray(..) )
import Data.Bits
( FiniteBits (finiteBitSize)
, shiftL
Expand Down Expand Up @@ -392,7 +392,7 @@ asBA (ShortByteString ba) = ba
unSBS :: ShortByteString -> ByteArray#
unSBS (ShortByteString (ByteArray ba#)) = ba#

create :: Int -> (forall s. MBA s -> ST s ()) -> ShortByteString
create :: Int -> (forall s. MutableByteArray s -> ST s ()) -> ShortByteString
create len fill =
assert (len >= 0) $ runST $ do
mba <- newByteArray len
Expand All @@ -405,7 +405,7 @@ create len fill =
-- The generating function is required to return the actual final size
-- (<= the maximum size) and the result value. The resulting byte array
-- is realloced to this size.
createAndTrim :: Int -> (forall s. MBA s -> ST s (Int, a)) -> (ShortByteString, a)
createAndTrim :: Int -> (forall s. MutableByteArray s -> ST s (Int, a)) -> (ShortByteString, a)
createAndTrim maxLen fill =
assert (maxLen >= 0) $ runST $ do
mba <- newByteArray maxLen
Expand All @@ -421,7 +421,7 @@ createAndTrim maxLen fill =
return (ShortByteString ba, res)
{-# INLINE createAndTrim #-}

createAndTrim' :: Int -> (forall s. MBA s -> ST s Int) -> ShortByteString
createAndTrim' :: Int -> (forall s. MutableByteArray s -> ST s Int) -> ShortByteString
createAndTrim' maxLen fill =
assert (maxLen >= 0) $ runST $ do
mba <- newByteArray maxLen
Expand All @@ -436,7 +436,7 @@ createAndTrim' maxLen fill =
{-# INLINE createAndTrim' #-}

-- | Like createAndTrim, but with two buffers at once
createAndTrim2 :: Int -> Int -> (forall s. MBA s -> MBA s -> ST s (Int, Int)) -> (ShortByteString, ShortByteString)
createAndTrim2 :: Int -> Int -> (forall s. MutableByteArray s -> MutableByteArray s -> ST s (Int, Int)) -> (ShortByteString, ShortByteString)
createAndTrim2 maxLen1 maxLen2 fill =
runST $ do
mba1 <- newByteArray maxLen1
Expand All @@ -446,7 +446,7 @@ createAndTrim2 maxLen1 maxLen2 fill =
sbs2 <- freeze' len2 maxLen2 mba2
pure (sbs1, sbs2)
where
freeze' :: Int -> Int -> MBA s -> ST s ShortByteString
freeze' :: Int -> Int -> MutableByteArray s -> ST s ShortByteString
freeze' len maxLen mba =
if assert (0 <= len && len <= maxLen) $ len >= maxLen
then do
Expand Down Expand Up @@ -496,7 +496,7 @@ fromShort !sbs = unsafeDupablePerformIO (fromShortIO sbs)
fromShortIO :: ShortByteString -> IO ByteString
fromShortIO sbs = do
let len = length sbs
mba@(MBA# mba#) <- stToIO (newPinnedByteArray len)
mba@(MutableByteArray mba#) <- stToIO (newPinnedByteArray len)
stToIO (copyByteArray (asBA sbs) 0 mba 0 len)
let fp = ForeignPtr (byteArrayContents# (unsafeCoerce# mba#))
(PlainPtr mba#)
Expand Down Expand Up @@ -542,7 +542,7 @@ packLenBytes :: Int -> [Word8] -> ShortByteString
packLenBytes len ws0 =
create len (\mba -> go mba 0 ws0)
where
go :: MBA s -> Int -> [Word8] -> ST s ()
go :: MutableByteArray s -> Int -> [Word8] -> ST s ()
go !_ !_ [] = return ()
go !mba !i (w:ws) = do
writeWord8Array mba i w
Expand Down Expand Up @@ -646,7 +646,7 @@ concat = \sbss ->
totalLen !acc (curr : rest)
= totalLen (checkedAdd "Short.concat" acc $ length curr) rest

copy :: MBA s -> Int -> [ShortByteString] -> ST s ()
copy :: MutableByteArray s -> Int -> [ShortByteString] -> ST s ()
copy !_ !_ [] = return ()
copy !dst !off (src : sbss) = do
let !len = length src
Expand Down Expand Up @@ -777,7 +777,7 @@ map f = \sbs ->
ba = asBA sbs
in create l (\mba -> go ba mba 0 l)
where
go :: ByteArray -> MBA s -> Int -> Int -> ST s ()
go :: ByteArray -> MutableByteArray s -> Int -> Int -> ST s ()
go !ba !mba !i !l
| i >= l = return ()
| otherwise = do
Expand All @@ -796,7 +796,7 @@ reverse = \sbs ->
#if HS_UNALIGNED_ByteArray_OPS_OK
in create l (\mba -> go ba mba l)
where
go :: forall s. ByteArray -> MBA s -> Int -> ST s ()
go :: forall s. ByteArray -> MutableByteArray s -> Int -> ST s ()
go !ba !mba !l = do
-- this is equivalent to: (q, r) = l `quotRem` 8
let q = l `shiftR` 3
Expand Down Expand Up @@ -829,7 +829,7 @@ reverse = \sbs ->
#else
in create l (\mba -> go ba mba 0 l)
where
go :: ByteArray -> MBA s -> Int -> Int -> ST s ()
go :: ByteArray -> MutableByteArray s -> Int -> Int -> ST s ()
go !ba !mba !i !l
| i >= l = return ()
| otherwise = do
Expand All @@ -856,7 +856,7 @@ intercalate sep = \case
ba = asBA sep
lba = length sep

go :: MBA s -> Int -> [ShortByteString] -> ST s ()
go :: MutableByteArray s -> Int -> [ShortByteString] -> ST s ()
go _ _ [] = pure ()
go mba !off (chunk:chunks) = do
let lc = length chunk
Expand Down Expand Up @@ -1278,7 +1278,7 @@ unfoldrN i f = \x0 ->
| otherwise -> createAndTrim i $ \mba -> go mba x0 0

where
go :: forall s. MBA s -> a -> Int -> ST s (Int, Maybe a)
go :: forall s. MutableByteArray s -> a -> Int -> ST s (Int, Maybe a)
go !mba !x !n = go' x n
where
go' :: a -> Int -> ST s (Int, Maybe a)
Expand Down Expand Up @@ -1430,7 +1430,7 @@ filter k = \sbs -> let l = length sbs
in if | l <= 0 -> sbs
| otherwise -> createAndTrim' l $ \mba -> go mba (asBA sbs) l
where
go :: forall s. MBA s -- mutable output bytestring
go :: forall s. MutableByteArray s -- mutable output bytestring
-> ByteArray -- input bytestring
-> Int -- length of input bytestring
-> ST s Int
Expand Down Expand Up @@ -1477,8 +1477,8 @@ partition k = \sbs -> let len = length sbs
| otherwise -> createAndTrim2 len len $ \mba1 mba2 -> go mba1 mba2 (asBA sbs) len
where
go :: forall s.
MBA s -- mutable output bytestring1
-> MBA s -- mutable output bytestring2
MutableByteArray s -- mutable output bytestring1
-> MutableByteArray s -- mutable output bytestring2
-> ByteArray -- input bytestring
-> Int -- length of input bytestring
-> ST s (Int, Int) -- (length mba1, length mba2)
Expand Down Expand Up @@ -1586,8 +1586,6 @@ createFromPtr !ptr len =
------------------------------------------------------------------------
-- Primop wrappers

data MBA s = MBA# (MutableByteArray# s)

indexCharArray :: ByteArray -> Int -> Char
indexCharArray (ByteArray ba#) (I# i#) = C# (indexCharArray# ba# i#)

Expand All @@ -1599,37 +1597,37 @@ indexWord8ArrayAsWord64 :: ByteArray -> Int -> Word64
indexWord8ArrayAsWord64 (ByteArray ba#) (I# i#) = W64# (indexWord8ArrayAsWord64# ba# i#)
#endif

newByteArray :: Int -> ST s (MBA s)
newByteArray :: Int -> ST s (MutableByteArray s)
newByteArray len@(I# len#) =
assert (len >= 0) $
ST $ \s -> case newByteArray# len# s of
(# s', mba# #) -> (# s', MBA# mba# #)
(# s', mba# #) -> (# s', MutableByteArray mba# #)

newPinnedByteArray :: Int -> ST s (MBA s)
newPinnedByteArray :: Int -> ST s (MutableByteArray s)
newPinnedByteArray len@(I# len#) =
assert (len >= 0) $
ST $ \s -> case newPinnedByteArray# len# s of
(# s', mba# #) -> (# s', MBA# mba# #)
(# s', mba# #) -> (# s', MutableByteArray mba# #)

unsafeFreezeByteArray :: MBA s -> ST s ByteArray
unsafeFreezeByteArray (MBA# mba#) =
unsafeFreezeByteArray :: MutableByteArray s -> ST s ByteArray
unsafeFreezeByteArray (MutableByteArray mba#) =
ST $ \s -> case unsafeFreezeByteArray# mba# s of
(# s', ba# #) -> (# s', ByteArray ba# #)

writeWord8Array :: MBA s -> Int -> Word8 -> ST s ()
writeWord8Array (MBA# mba#) (I# i#) (W8# w#) =
writeWord8Array :: MutableByteArray s -> Int -> Word8 -> ST s ()
writeWord8Array (MutableByteArray mba#) (I# i#) (W8# w#) =
ST $ \s -> case writeWord8Array# mba# i# w# s of
s' -> (# s', () #)

#if HS_UNALIGNED_ByteArray_OPS_OK
writeWord64Array :: MBA s -> Int -> Word64 -> ST s ()
writeWord64Array (MBA# mba#) (I# i#) (W64# w#) =
writeWord64Array :: MutableByteArray s -> Int -> Word64 -> ST s ()
writeWord64Array (MutableByteArray mba#) (I# i#) (W64# w#) =
ST $ \s -> case writeWord64Array# mba# i# w# s of
s' -> (# s', () #)
#endif

copyAddrToByteArray :: Ptr a -> MBA RealWorld -> Int -> Int -> ST RealWorld ()
copyAddrToByteArray (Ptr src#) (MBA# dst#) (I# dst_off#) (I# len#) =
copyAddrToByteArray :: Ptr a -> MutableByteArray RealWorld -> Int -> Int -> ST RealWorld ()
copyAddrToByteArray (Ptr src#) (MutableByteArray dst#) (I# dst_off#) (I# len#) =
ST $ \s -> case copyAddrToByteArray# src# dst# dst_off# len# s of
s' -> (# s', () #)

Expand All @@ -1638,18 +1636,18 @@ copyByteArrayToAddr (ByteArray src#) (I# src_off#) (Ptr dst#) (I# len#) =
ST $ \s -> case copyByteArrayToAddr# src# src_off# dst# len# s of
s' -> (# s', () #)

copyByteArray :: ByteArray -> Int -> MBA s -> Int -> Int -> ST s ()
copyByteArray (ByteArray src#) (I# src_off#) (MBA# dst#) (I# dst_off#) (I# len#) =
copyByteArray :: ByteArray -> Int -> MutableByteArray s -> Int -> Int -> ST s ()
copyByteArray (ByteArray src#) (I# src_off#) (MutableByteArray dst#) (I# dst_off#) (I# len#) =
ST $ \s -> case copyByteArray# src# src_off# dst# dst_off# len# s of
s' -> (# s', () #)

setByteArray :: MBA s -> Int -> Int -> Int -> ST s ()
setByteArray (MBA# dst#) (I# off#) (I# len#) (I# c#) =
setByteArray :: MutableByteArray s -> Int -> Int -> Int -> ST s ()
setByteArray (MutableByteArray dst#) (I# off#) (I# len#) (I# c#) =
ST $ \s -> case setByteArray# dst# off# len# c# s of
s' -> (# s', () #)

copyMutableByteArray :: MBA s -> Int -> MBA s -> Int -> Int -> ST s ()
copyMutableByteArray (MBA# src#) (I# src_off#) (MBA# dst#) (I# dst_off#) (I# len#) =
copyMutableByteArray :: MutableByteArray s -> Int -> MutableByteArray s -> Int -> Int -> ST s ()
copyMutableByteArray (MutableByteArray src#) (I# src_off#) (MutableByteArray dst#) (I# dst_off#) (I# len#) =
ST $ \s -> case copyMutableByteArray# src# src_off# dst# dst_off# len# s of
s' -> (# s', () #)

Expand Down Expand Up @@ -1834,7 +1832,7 @@ packLenBytesRev :: Int -> [Word8] -> ShortByteString
packLenBytesRev len ws0 =
create len (\mba -> go mba len ws0)
where
go :: MBA s -> Int -> [Word8] -> ST s ()
go :: MutableByteArray s -> Int -> [Word8] -> ST s ()
go !_ !_ [] = return ()
go !mba !i (w:ws) = do
writeWord8Array mba (i - 1) w
Expand Down
Loading