Skip to content

Produce multiple random numbers efficiently #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 118 additions & 26 deletions System/Random/MWC.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE BangPatterns, CPP, DeriveDataTypeable, FlexibleContexts,
MagicHash, Rank2Types, ScopedTypeVariables, TypeFamilies, UnboxedTuples,
ForeignFunctionInterface #-}

-- |
-- Module : System.Random.MWC
-- Copyright : (c) 2009-2012 Bryan O'Sullivan
Expand Down Expand Up @@ -90,6 +91,9 @@ module System.Random.MWC
, save
, restore

-- * Fold
, foldMUniforms

-- * References
-- $references
) where
Expand Down Expand Up @@ -119,7 +123,7 @@ import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Marshal.Array (peekArray)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as I
import qualified Data.Vector.Unboxed.Mutable as M
import Data.Primitive.ByteArray
import System.CPUTime (cpuTimePrecision, getCPUTime)
import System.IO (IOMode(..), hGetBuf, hPutStrLn, stderr, withBinaryFile)
import System.IO.Unsafe (unsafePerformIO)
Expand Down Expand Up @@ -156,7 +160,7 @@ class Variate a where
-- 2**(-33). To do the same with 'Double' variates, subtract
-- 2**(-53).
uniform :: (PrimMonad m) => Gen (PrimState m) -> m a
-- | Generate single uniformly distributed random variable in a
-- | Generate a single uniformly distributed random variable in a
-- given range.
--
-- * For integral types inclusive range is used.
Expand Down Expand Up @@ -313,7 +317,7 @@ wordsToDouble x y = (fromIntegral u * m_inv_32 + (0.5 + m_inv_53) +
-- | State of the pseudo-random number generator. It uses mutable
-- state, so the same generator shouldn't be used from different
-- threads simultaneously.
newtype Gen s = Gen (M.MVector s Word32)
newtype Gen s = Gen (MutableByteArray s)

-- | A shorter name for PRNG state in the 'IO' monad: 'Gen' applied to
-- the state token of 'IO'.
type GenIO = Gen (PrimState IO)
Expand Down Expand Up @@ -362,19 +366,19 @@ create = initialize defaultSeed
initialize :: (PrimMonad m, Vector v Word32) =>
v Word32 -> m (Gen (PrimState m))
initialize seed = do
q <- M.unsafeNew 258
q <- mkAlignedByteArray
fill q
if fini == 258
then do
M.unsafeWrite q ioff $ G.unsafeIndex seed ioff .&. 255
M.unsafeWrite q coff $ G.unsafeIndex seed coff
writeByteArray q ioff $ G.unsafeIndex seed ioff .&. 255
writeByteArray q coff $ G.unsafeIndex seed coff
else do
M.unsafeWrite q ioff 255
M.unsafeWrite q coff 362436
writeByteArray q ioff (255 :: Word32)
writeByteArray q coff (362436 :: Word32)
return (Gen q)
where fill q = go 0 where
go i | i == 256 = return ()
| otherwise = M.unsafeWrite q i s >> go (i+1)
| otherwise = writeByteArray q i s >> go (i+1)
where s | i >= fini = if fini == 0
then G.unsafeIndex defaultSeed i
else G.unsafeIndex defaultSeed i `xor`
Expand All @@ -396,16 +400,43 @@ newtype Seed = Seed {
--
-- > restore (toSeed v) = initialize v
toSeed :: (Vector v Word32) => v Word32 -> Seed
toSeed v = Seed $ I.create $ do { Gen q <- initialize v; return q }
-- Runs 'initialize' on the input, freezes the resulting state array and
-- copies its words into a new unboxed vector, which becomes the 'Seed'.
toSeed v =
Seed $ I.create $ do
Gen q <- initialize v
unsafeFreezeByteArray q >>= I.unsafeThaw . byteArrayToVector

-- | Copy every 'Word32' stored in a 'ByteArray' into a new vector.
--
-- The element count is the byte size of the array divided by the size of a
-- 'Word32'; trailing bytes that do not make up a whole word are ignored.
-- The vector is generated directly from the array, without going through an
-- intermediate list.
byteArrayToVector :: (Vector v Word32) => ByteArray -> v Word32
byteArrayToVector q = G.generate nWord32 (indexByteArray q)
  where
    nWord32 = quot (sizeofByteArray q) SIZEOF_WORD32

-- | Copy the words of a vector into a newly allocated generator state
-- buffer (see 'mkAlignedByteArray').
--
-- 'writeByteArray' performs no bounds checking, so the indices are limited
-- to the capacity of the buffer: a vector holding more words than the
-- buffer can store is truncated instead of being written past the end of
-- the allocation. Slots for which the vector supplies no word are not
-- written by this function.
vectorToByteArray :: (Vector v Word32, PrimMonad m) => v Word32 -> m (MutableByteArray (PrimState m))
vectorToByteArray v = do
  b <- mkAlignedByteArray
  -- Number of 'Word32' slots the buffer actually has room for.
  let capacity = quot (sizeofMutableByteArray b) SIZEOF_WORD32
  mapM_ (uncurry $ writeByteArray b) $ zip [0 .. capacity - 1] $ G.toList v
  return b

-- | Allocate the pinned, aligned buffer backing a generator: room for 258
-- 'Word32' values, aligned on 64 bytes.
--
-- Why the alignment matters: the slots at indexes @ioff@ and @coff@
-- (256, 257) are read and written an order of magnitude more often than any
-- other slot, and always one right after the other, so the memory holding
-- them should sit on a single cache line. The array as a whole should also
-- span as few cache lines as possible. Assuming 64-byte cache lines, a
-- 64-byte alignment meets both of the aforementioned requirements.
mkAlignedByteArray :: PrimMonad m => m (MutableByteArray (PrimState m))
mkAlignedByteArray = newAlignedPinnedByteArray stateBytes cacheLineBytes
  where
    stateBytes, cacheLineBytes :: Int
    stateBytes     = 258 * SIZEOF_WORD32
    cacheLineBytes = 64

-- | Save the state of a 'Gen', for later use by 'restore'.
--
-- The snapshot is taken at the moment 'save' runs: the returned 'Seed' does
-- not change when the generator is used afterwards.
save :: PrimMonad m => Gen (PrimState m) -> m Seed
save (Gen q) = do
  -- 'unsafeFreezeByteArray' does not copy, so 'frozen' aliases the live,
  -- still-mutable generator state. That is OK only because
  -- 'byteArrayToVector' copies the words out and keeps none of that memory.
  frozen <- unsafeFreezeByteArray q
  -- The copy must happen now, not whenever the 'Seed' is first inspected:
  -- a lazily built vector would read the state after later draws have
  -- already modified it. '$!' forces the vector before 'save' returns.
  return $! Seed (byteArrayToVector frozen)
{-# INLINE save #-}

-- | Create a new 'Gen' that mirrors the state of a saved 'Seed'.
--
-- The words of the seed are copied into a newly created state buffer; the
-- 'Seed' itself is not modified.
restore :: PrimMonad m => Seed -> m (Gen (PrimState m))
restore (Seed s) = Gen `liftM` G.thaw s
restore (Seed s) = Gen <$> vectorToByteArray s
{-# INLINE restore #-}


Expand Down Expand Up @@ -520,19 +551,25 @@ aa :: Word64
aa = 1540315826
{-# INLINE aa #-}

-- | Read the 'Word32' at the given element index of a mutable byte array.
-- Fixing the element type here spares call sites a type annotation.
read32 :: PrimMonad m => MutableByteArray (PrimState m) -> Int -> m Word32
read32 arr ix = readByteArray arr ix
{-# INLINE read32 #-}


-- Generate one 'Word32' and advance the generator: the slot selected by
-- applying 'nextIndex' to the stored index is overwritten with the new
-- output, and the index (ioff) and carry (coff) slots are updated to match.
uniformWord32 :: PrimMonad m => Gen (PrimState m) -> m Word32
uniformWord32 (Gen q) = do
i <- nextIndex `liftM` M.unsafeRead q ioff
c <- fromIntegral `liftM` M.unsafeRead q coff
qi <- fromIntegral `liftM` M.unsafeRead q i
i <- nextIndex `liftM` read32 q ioff
c <- fromIntegral `liftM` read32 q coff
qi <- fromIntegral `liftM` read32 q i
-- t is the 64-bit value multiplier * slot + carry; its upper 32 bits (c')
-- are the candidate new carry and its lower 32 bits feed the output.
let t = aa * qi + c
c' = fromIntegral (t `shiftR` 32)
x = fromIntegral t + c'
-- x < c' means that adding c' wrapped around; in that case both the output
-- and the carry are incremented by one.
(# x', c'' #) | x < c' = (# x + 1, c' + 1 #)
| otherwise = (# x, c' #)
M.unsafeWrite q i x'
M.unsafeWrite q ioff (fromIntegral i)
M.unsafeWrite q coff (fromIntegral c'')
writeByteArray q i x'
writeByteArray q ioff (fromIntegral i :: Word32)
writeByteArray q coff c''
return x'
{-# INLINE uniformWord32 #-}

Expand All @@ -544,11 +581,11 @@ uniform1 f gen = do

uniform2 :: PrimMonad m => (Word32 -> Word32 -> a) -> Gen (PrimState m) -> m a
uniform2 f (Gen q) = do
i <- nextIndex `liftM` M.unsafeRead q ioff
i <- nextIndex `liftM` read32 q ioff
let j = nextIndex i
c <- fromIntegral `liftM` M.unsafeRead q coff
qi <- fromIntegral `liftM` M.unsafeRead q i
qj <- fromIntegral `liftM` M.unsafeRead q j
c <- fromIntegral `liftM` read32 q coff
qi <- fromIntegral `liftM` read32 q i
qj <- fromIntegral `liftM` read32 q j
let t = aa * qi + c
c' = fromIntegral (t `shiftR` 32)
x = fromIntegral t + c'
Expand All @@ -559,13 +596,68 @@ uniform2 f (Gen q) = do
y = fromIntegral u + d'
(# y', d'' #) | y < d' = (# y + 1, d' + 1 #)
| otherwise = (# y, d' #)
M.unsafeWrite q i x'
M.unsafeWrite q j y'
M.unsafeWrite q ioff (fromIntegral j)
M.unsafeWrite q coff (fromIntegral d'')
writeByteArray q i x'
writeByteArray q j y'
writeByteArray q ioff (fromIntegral j :: Word32)
writeByteArray q coff d''
return $! f x' y'
{-# INLINE uniform2 #-}

-- Loop state threaded through 'foldMUniforms' between two generated words.
data AccumWithUniforms a = AWM {
-- Value read from, and finally written back to, the @coff@ slot: the
-- carry of the most recent step.
_coeff :: {-# UNPACK #-} !Word32
-- Index of the state slot used by the most recent step.
, _index :: {-# UNPACK #-} !Int
-- The caller's accumulator.
, _accumulator :: !a
}

-- | Fold-like function that consumes random numbers produced with a
-- minimal number of reads and writes to the state vector.
--
-- To generate @n@ numbers, this function does @n + 2@ reads and @n + 2@
-- writes: the index and carry slots are read once before the loop and
-- written back once after it, instead of once per generated number.
--
-- If @n@ is zero or negative, nothing is generated: the initial
-- accumulator is returned and the generator is not touched.
foldMUniforms :: PrimMonad m
              => Int
              -- ^ How many 'Word32' should be generated
              -> (a -> Word32 -> m a)
              -- ^ The accumulating function
              -> a
              -- ^ The accumulator's initial value
              -> Gen (PrimState m)
              -- ^ The RNG
              -> m a
foldMUniforms n f acc0 (Gen q)
  -- Guard against non-positive counts here so that a negative @n@ can never
  -- turn into a runaway loop.
  | n <= 0 = return acc0
  | otherwise = do
      i0 <- fromIntegral <$> read32 q ioff
      c0 <- fromIntegral <$> read32 q coff

      -- One generator step: same arithmetic as 'uniformWord32', but the
      -- index and carry travel in the loop state instead of going through
      -- the state array.
      let accum (AWM cPrev iPrev accPrev) = do
            let i = nextIndex iPrev
            qi <- fromIntegral <$> read32 q i
            let t = aa * qi + fromIntegral cPrev
                c' = fromIntegral (t `shiftR` 32)
                x = fromIntegral t + c'
                -- x < c' means the addition wrapped around.
                (# x', c'' #) | x < c' = (# x + 1, c' + 1 #)
                              | otherwise = (# x, c' #)
            writeByteArray q i x'
            AWM c'' i <$> f accPrev x'

      (AWM cF iF accF) <- iterateNM accum (AWM c0 i0 acc0) n

      -- Persist the final index and carry so that later draws continue
      -- from where this fold stopped.
      writeByteArray q ioff (fromIntegral iF :: Word32)
      writeByteArray q coff cF

      return accF
{-# INLINE foldMUniforms #-}

-- Apply a monadic step function @n@ times, threading the value through.
--
-- Equivalent to @foldM (\a _ -> f a) a0 [0..n-1]@, with the final value
-- forced to weak head normal form before it is returned. In particular, a
-- zero or negative count performs no step and yields the initial value.
iterateNM :: Monad m => (a -> m a) -> a -> Int -> m a
iterateNM f a0 n0 =
    go n0 a0
  where
    go n a
      -- Stop on any non-positive count, so a negative @n@ terminates
      -- instead of counting down without ever reaching the base case.
      | n <= 0    = a `seq` return a
      | otherwise = f a >>= go (n - 1)
{-# INLINE iterateNM #-}


-- Type family for fixed-size integrals. For signed data types it's
-- the unsigned counterpart with the same size, and for unsigned data types
-- it's the same type.
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/mwc-random-benchmarks.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ executable bm
criterion,
mersenne-random,
mwc-random,
random
random,
vector