Skip to content

Commit

Permalink
Merge pull request #104 from mstksg/bitvec-fix
Browse files Browse the repository at this point in the history
[BUGFIX] Fix shiftL/shiftR bug (#99)
  • Loading branch information
expipiplus1 authored Dec 13, 2020
2 parents bd3fbe5 + f375780 commit 8f0d2af
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/Data/Vector/Generic/Sized.hs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ module Data.Vector.Generic.Sized

import Data.Vector.Generic.Sized.Internal
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector as Boxed
import qualified Data.Vector.Storable as Storable
import qualified Data.Vector.Unboxed as Unboxed
Expand All @@ -260,6 +261,7 @@ import Data.Coerce
import Data.Finite hiding (shift)
import Data.Finite.Internal
import Data.Proxy
import qualified Control.Exception as Exception
import Control.Monad (mzero)
import Control.Monad.Primitive
import Control.Monad.ST
Expand Down Expand Up @@ -1950,11 +1952,6 @@ instance (VG.Vector v a, Bits (v a), Bits a, KnownNat n) => Bits (Vector v n a)
(.|.) = coerce ((.|.) @(v a))
xor = coerce (xor @(v a))
complement = coerce (complement @(v a))
shiftL = coerce (shiftL @(v a))
unsafeShiftL = coerce (unsafeShiftL @(v a))
shiftR = coerce (shiftR @(v a))
unsafeShiftR = coerce (unsafeShiftR @(v a))
shift = coerce (shift @(v a))
rotate = coerce (rotate @(v a))
rotateL = coerce (rotateL @(v a))
rotateR = coerce (rotateR @(v a))
Expand All @@ -1974,6 +1971,33 @@ instance (VG.Vector v a, Bits (v a), Bits a, KnownNat n) => Bits (Vector v n a)
SVGM.write v i (complement zeroBits)
freeze v
zeroBits = replicate zeroBits
-- need to do special stuff because they have to preserve vector size
shiftL xs i
| i < 0 = Exception.throw Exception.Overflow
| otherwise = unsafeShiftL xs i
unsafeShiftL (Vector xs) i
| i == 0 = Vector xs
| i' == n = replicate zeroBits
| otherwise = Vector $ runST $ do
u <- VGM.replicate n zeroBits
VG.unsafeCopy (VGM.unsafeDrop i' u) (VG.unsafeTake (n - i') xs)
VG.unsafeFreeze u
where
n = fromInteger $ natVal (Proxy @n)
i' = min n i
shiftR xs i
| i < 0 = Exception.throw Exception.Overflow
| otherwise = unsafeShiftR xs i
unsafeShiftR (Vector xs) i
| i == 0 = Vector xs
| i' == n = replicate zeroBits
| otherwise = Vector $ runST $ do
u <- VGM.replicate n zeroBits
VG.unsafeCopy (VGM.unsafeTake (n - i') u) (VG.unsafeDrop i' xs)
VG.unsafeFreeze u
where
n = fromInteger $ natVal (Proxy @n)
i' = min n i

-- | Treats a bit vector as n times the size of the stored bits, reflecting
-- the 'Bits' instance; does not necessarily reflect exact in-memory
Expand Down

0 comments on commit 8f0d2af

Please sign in to comment.