diff --git a/lib/pcg.ml b/lib/pcg.ml index 589d088..4593197 100644 --- a/lib/pcg.ml +++ b/lib/pcg.ml @@ -5,6 +5,35 @@ SPDX-License-Identifier: BSD-3-Clause *) open Stdint +module U128 = struct + type t = { high : uint64; low : uint64 } + + let of_u64 high low = {high; low} + + let one = Uint64.{high = zero; low = one} + + let zero = Uint64.{high = zero; low = zero} + + let ( + ) a b = + match Uint64.{high = a.high + b.high; low = a.low + b.low} with + | x when x.low < b.low -> {x with high = Uint64.(x.high + one)} + | x -> x + + let max32 = Uint32.(max_int |> to_uint64) + let mult64 x y = + let open Uint64 in + let x0 = logand max32 x and y0 = logand max32 y + and x1 = shift_right x 32 and y1 = shift_right y 32 in + let t = shift_right (x0 * y0) 32 + x1 * y0 in + {high = shift_right (logand max32 t + x0 * y1) 32 + (shift_right t 32) + x1 * y1; low = x * y} + + let ( * ) a b = match mult64 a.low b.low with + | {high;low} -> {high = Uint64.(high + a.high * b.low + a.low * b.high); low} + + (* let ( ** ) a b = match mult64 a.low b with + | x -> {x with high = Uint64.(x.high + a.high * b)} *) +end + module PCG64 : sig (** PCG-64 is a 128-bit implementation of O'Neill's permutation congruential @@ -20,7 +49,7 @@ module PCG64 : sig include Common.BITGEN - val advance : int128 -> t -> t + val advance : uint64 * uint64 -> t -> t (** [advance delta] Advances the underlying RNG as if [delta] draws have been made. The returned state is that of the generator [delta] steps forward. *) @@ -29,55 +58,60 @@ module PCG64 : sig (0, bound) as well as the state of the generator advanced one step forward. *) end = struct type t = {s : setseq; ustore : uint32 option} - and setseq = {state : uint128; increment : uint128} + and setseq = {state : U128.t; increment : U128.t} + + + let sixtythree = Uint32.of_int32 63l + let multiplier = U128.of_u64 (Uint64.of_int64 2549297995355413924L) + (Uint64.of_int64 4865540595714422341L) - let multiplier = Uint128.of_string "0x2360ed051fc65da44385df649fccf645" - let sixtythree = Uint32.of_int 63 (* Uses the XSL-RR output function *) - let output state = - let v = Uint128.(shift_right state 64 |> logxor state |> to_uint64) - and r = Uint128.(shift_right state 122 |> to_int) in - let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in - Uint64.(logor (shift_left v nr) (shift_right v r)) - + let output U128.{high; low} = + let v = Uint64.(logxor high low) in + let r = Uint64.(shift_right high 58 |> to_int) in + let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in + Uint64.(logor (shift_left v nr) (shift_right v r)) + let next {state; increment} = - let state' = Uint128.(state * multiplier + increment) in + let state' = U128.(state * multiplier + increment) in output state', {state = state'; increment} let next_uint64 t = match next t.s with | u, s -> u, {t with s} - + let next_uint32 t = match Common.next_uint32 ~next:next t.s t.ustore with - | u, s, ustore -> u, {s; ustore} + | u, s, ustore -> u, {s; ustore} - let next_double t = Common.next_double ~nextu64:next_uint64 t + let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t - let advance delta {s = {state; increment}; _} = - let open Uint128 in - let rec lcg d am ap cm cp = (* advance state using LCG method *) - match d = zero, logand d one = one with - | true, _ -> am * state + ap - | false, true -> lcg (shift_right d 1) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one)) - | false, false -> lcg (shift_right d 1) am ap (cm * cm) (cp * (cm + one)) - in {s = {state = lcg (Uint128.of_int128 delta) one zero multiplier increment; increment}; ustore = None} + let next_double t = Common.next_double ~nextu64:next_uint64 t let set_seed seed = - let open Uint128 in - let s = logor (shift_left (of_uint64 seed.(0)) 64) (of_uint64 seed.(1)) - and i = logor (shift_left (of_uint64 seed.(2)) 64) (of_uint64 seed.(3)) in - let increment = logor (shift_left i 1) one in - {state = (increment + s) * multiplier + increment; increment} - - - let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t + let s2 = Uint64.(logor (shift_left seed.(2) 1) (shift_right seed.(3) 63)) in + let s3 = Uint64.(logor (shift_left seed.(3) 1) one) in + let increment = U128.of_u64 s2 s3 in + let state = U128.(zero * multiplier + increment) in + {state = U128.((of_u64 seed.(0) seed.(1) + state) * multiplier + increment); increment} + + + let advance (d1, d0) {s = {state; increment}; _} = + let open U128 in + let half x = U128.{low = Uint64.(logor (shift_right x.low 1) (shift_left x.high 63)); + high = Uint64.(shift_right x.high 1)} in + let rec lcg d am ap cm cp = + match Uint64.(d.high <= zero && d.low <= zero, logand d.low one = one) with + | true, _ -> am * state + ap + | false, true -> lcg (half d) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one)) + | false, false -> lcg (half d) am ap (cm * cm) (cp * (cm + one)) + in {s = {state = lcg (of_u64 d1 d0) one zero multiplier increment; increment}; ustore = None} let initialize seed = diff --git a/test/test_pcg.ml b/test/test_pcg.ml index 2ae8b38..c8cdfab 100644 --- a/test/test_pcg.ml +++ b/test/test_pcg.ml @@ -7,7 +7,7 @@ let test_advance _ = let t = SeedSequence.initialize [Uint128.of_int 12345] |> PCG64.initialize in let advance n = Seq.(iterate (fun s -> PCG64.next_uint64 s |> snd) t |> drop n |> uncons |> Option.get |> fst) in assert_equal - (PCG64.advance (Int128.of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string) + (PCG64.advance Uint64.(of_int 0, of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string) (advance 100 |> PCG64.next_uint64 |> fst |> Uint64.to_string) ~printer:(fun x -> x)