Skip to content

Commit

Permalink
Implement stimes for Builder and ShortByteString (#611)
Browse files Browse the repository at this point in the history
  • Loading branch information
clyring authored Sep 8, 2023
1 parent 39f4011 commit 09dc954
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 31 deletions.
20 changes: 17 additions & 3 deletions Data/ByteString/Builder/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ module Data.ByteString.Builder.Internal (

import Control.Arrow (second)

#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup (Semigroup((<>)))
#endif
import Data.Semigroup (Semigroup(..))

import qualified Data.ByteString as S
import qualified Data.ByteString.Internal.Type as S
Expand Down Expand Up @@ -382,9 +380,25 @@ empty = Builder ($)
append :: Builder -> Builder -> Builder
append (Builder b1) (Builder b2) = Builder $ b1 . b2

stimesBuilder :: Integral t => t -> Builder -> Builder
{-# INLINABLE stimesBuilder #-}
stimesBuilder n b
| n >= 0 = go n
| otherwise = stimesNegativeErr
where go 0 = empty
go k = b `append` go (k - 1)

stimesNegativeErr :: Builder
-- See Note [Float error calls out of INLINABLE things]
-- in Data.ByteString.Internal.Type
stimesNegativeErr
= errorWithoutStackTrace "stimes @Builder: non-negative multiplier expected"

instance Semigroup Builder where
{-# INLINE (<>) #-}
(<>) = append
{-# INLINE stimes #-}
stimes = stimesBuilder

instance Monoid Builder where
{-# INLINE mempty #-}
Expand Down
29 changes: 26 additions & 3 deletions Data/ByteString/Internal/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -885,13 +885,36 @@ stimesPolymorphic nRaw !bs = case checkedIntegerToInt n of
-- and the likelihood of potentially dangerous mistakes minimized.


{-
Note [Float error calls out of INLINABLE things]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If a function is marked INLINE or INLINABLE, then when ghc inlines or
specializes it, it duplicates the function body exactly as written.
This feature is useful for systems of rewrite rules, but sometimes
comes at a code-size cost. One situation where this cost generally
comes with no compensating up-side is when the function in question
calls `error` or something similar.
Such an `error` call is not meaningfully improved by the extra context
inlining or specialization provides, and if inlining or specialization
happens in a different module from where the function was originally
defined, CSE will not be able to de-duplicate the error call floated
out of the inlined RHS and the error call floated out of the original
RHS. See also https://gitlab.haskell.org/ghc/ghc/-/issues/23823
To mitigate this, we manually float the error calls out of INLINABLE
functions when it is possible to do so.
-}

stimesNegativeErr :: ByteString
-- See Note [Float error calls out of INLINABLE things]
stimesNegativeErr
= error "stimes @ByteString: non-negative multiplier expected"
= errorWithoutStackTrace "stimes @ByteString: non-negative multiplier expected"

stimesOverflowErr :: ByteString
-- Although this only appears once, it is extracted here to prevent it
-- from being duplicated in specializations of 'stimesPolymorphic'
-- See Note [Float error calls out of INLINABLE things]
stimesOverflowErr = overflowError "stimes"

-- | Repeats the given ByteString n times.
Expand Down
3 changes: 2 additions & 1 deletion Data/ByteString/Short/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ import Data.Data
import Data.Monoid
( Monoid(..) )
import Data.Semigroup
( Semigroup((<>)) )
( Semigroup(..), stimesMonoid )
import Data.String
( IsString(..) )
import Control.Applicative
Expand Down Expand Up @@ -313,6 +313,7 @@ instance Ord ShortByteString where

instance Semigroup ShortByteString where
(<>) = append
stimes = stimesMonoid

instance Monoid ShortByteString where
mempty = empty
Expand Down
4 changes: 1 addition & 3 deletions tests/Properties/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,8 @@ tests =
\x y -> B.unpack (mappend x y) === B.unpack x `mappend` B.unpack y
, testProperty "<>" $
\x y -> B.unpack (x <> y) === B.unpack x <> B.unpack y
#ifndef BYTESTRING_SHORT
, testProperty "stimes" $
\(Sqrt (NonNegative n)) (Sqrt x) -> stimes (n :: Int) (x :: BYTESTRING_TYPE) === mtimesDefault n x
#endif
\(Sqrt (NonNegative n)) (Sqrt x) -> stimes (n :: Int) (x :: BYTESTRING_TYPE) === stimesMonoid n x

, testProperty "break" $
\f x -> (B.unpack *** B.unpack) (B.break f x) === break f (B.unpack x)
Expand Down
33 changes: 12 additions & 21 deletions tests/builder/Data/ByteString/Builder/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ import Foreign (minusPtr)
import Data.Char (chr)
import Data.Bits ((.|.), shiftL)
import Data.Foldable
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup
#endif
import Data.Semigroup (Semigroup(..))
import Data.Word

import qualified Data.ByteString as S
Expand All @@ -55,8 +53,11 @@ import System.Posix.Internals (c_unlink)
import Test.Tasty (TestTree, TestName, testGroup)
import Test.Tasty.QuickCheck
( Arbitrary(..), oneof, choose, listOf, elements
, counterexample, ioProperty, UnicodeString(..), Property, testProperty
, (===), (.&&.), conjoin )
, counterexample, ioProperty, Property, testProperty
, (===), (.&&.), conjoin
, UnicodeString(..), NonNegative(..)
)
import QuickCheckUtils


tests :: [TestTree]
Expand All @@ -67,6 +68,7 @@ tests =
, testPut
, testRunBuilder
, testWriteFile
, testStimes
] ++
testsEncodingToBuilder ++
testsBinary ++
Expand Down Expand Up @@ -199,6 +201,11 @@ testWriteFile =
unless success (error msg)
return success

testStimes :: TestTree
testStimes = testProperty "stimes" $
\(Sqrt (NonNegative n)) (Sqrt x) ->
stimes (n :: Int) x === toLazyByteString (stimes n (lazyByteString x))

removeFile :: String -> IO ()
removeFile fn = void $ withCString fn c_unlink

Expand Down Expand Up @@ -319,22 +326,6 @@ recipeComponents (Recipe how firstSize otherSize cont as) =
-- 'Arbitary' instances
-----------------------

instance Arbitrary L.ByteString where
arbitrary = L.fromChunks <$> listOf arbitrary
shrink lbs
| L.null lbs = []
| otherwise = pure $ L.take (L.length lbs `div` 2) lbs

instance Arbitrary S.ByteString where
arbitrary =
trim S.drop =<< trim S.take =<< S.pack <$> listOf arbitrary
where
trim f bs = oneof [pure bs, f <$> choose (0, S.length bs) <*> pure bs]

shrink bs
| S.null bs = []
| otherwise = pure $ S.take (S.length bs `div` 2) bs

instance Arbitrary Mode where
arbitrary = oneof
[Threshold <$> arbitrary, pure Smart, pure Insert, pure Copy, pure Hex]
Expand Down

0 comments on commit 09dc954

Please sign in to comment.