Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do not Merge] Implement PCG64 without using the uint128 type. #36

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 64 additions & 30 deletions lib/pcg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,35 @@
SPDX-License-Identifier: BSD-3-Clause *)
open Stdint

module U128 = struct
type t = { high : uint64; low : uint64 }

let of_u64 high low = {high; low}

let one = Uint64.{high = zero; low = one}

let zero = Uint64.{high = zero; low = zero}

let ( + ) a b =
match Uint64.{high = a.high + b.high; low = a.low + b.low} with
| x when x.low < b.low -> {x with high = Uint64.(x.high + one)}
| x -> x

let max32 = Uint32.(max_int |> to_uint64)
let mult64 x y =
let open Uint64 in
let x0 = logand max32 x and y0 = logand max32 y
and x1 = shift_right x 32 and y1 = shift_right y 32 in
let t = shift_right (x0 * y0) 32 + x1 * y0 in
{high = shift_right (logand max32 t + x0 * y1) 32 + (shift_right t 32) + x1 * y1; low = x * y}

let ( * ) a b = match mult64 a.low b.low with
| {high;low} -> {high = Uint64.(high + a.high * b.low + a.low * b.high); low}

(* let ( ** ) a b = match mult64 a.low b with
| x -> {x with high = Uint64.(x.high + a.high * b)} *)
end


module PCG64 : sig
(** PCG-64 is a 128-bit implementation of O'Neill's permutation congruential
Expand All @@ -20,7 +49,7 @@ module PCG64 : sig

include Common.BITGEN

val advance : int128 -> t -> t
val advance : uint64 * uint64 -> t -> t
(** [advance delta] Advances the underlying RNG as if [delta] draws have been made.
The returned state is that of the generator [delta] steps forward. *)

Expand All @@ -29,55 +58,60 @@ module PCG64 : sig
(0, bound) as well as the state of the generator advanced one step forward. *)
end = struct
type t = {s : setseq; ustore : uint32 option}
and setseq = {state : uint128; increment : uint128}
and setseq = {state : U128.t; increment : U128.t}


let sixtythree = Uint32.of_int32 63l
let multiplier = U128.of_u64 (Uint64.of_int64 2549297995355413924L)
(Uint64.of_int64 4865540595714422341L)

let multiplier = Uint128.of_string "0x2360ed051fc65da44385df649fccf645"
let sixtythree = Uint32.of_int 63

(* Uses the XSL-RR output function *)
let output state =
let v = Uint128.(shift_right state 64 |> logxor state |> to_uint64)
and r = Uint128.(shift_right state 122 |> to_int) in
let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in
Uint64.(logor (shift_left v nr) (shift_right v r))
let output U128.{high; low} =
let v = Uint64.(logxor high low) in
let r = Uint64.(shift_right high 58 |> to_int) in
let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in
Uint64.(logor (shift_left v nr) (shift_right v r))


let next {state; increment} =
let state' = Uint128.(state * multiplier + increment) in
let state' = U128.(state * multiplier + increment) in
output state', {state = state'; increment}


let next_uint64 t = match next t.s with
| u, s -> u, {t with s}


let next_uint32 t =
match Common.next_uint32 ~next:next t.s t.ustore with
| u, s, ustore -> u, {s; ustore}
| u, s, ustore -> u, {s; ustore}


let next_double t = Common.next_double ~nextu64:next_uint64 t
let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t


let advance delta {s = {state; increment}; _} =
let open Uint128 in
let rec lcg d am ap cm cp = (* advance state using LCG method *)
match d = zero, logand d one = one with
| true, _ -> am * state + ap
| false, true -> lcg (shift_right d 1) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one))
| false, false -> lcg (shift_right d 1) am ap (cm * cm) (cp * (cm + one))
in {s = {state = lcg (Uint128.of_int128 delta) one zero multiplier increment; increment}; ustore = None}
let next_double t = Common.next_double ~nextu64:next_uint64 t


let set_seed seed =
let open Uint128 in
let s = logor (shift_left (of_uint64 seed.(0)) 64) (of_uint64 seed.(1))
and i = logor (shift_left (of_uint64 seed.(2)) 64) (of_uint64 seed.(3)) in
let increment = logor (shift_left i 1) one in
{state = (increment + s) * multiplier + increment; increment}


let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t
let s2 = Uint64.(logor (shift_left seed.(2) 1) (shift_right seed.(3) 63)) in
let s3 = Uint64.(logor (shift_left seed.(3) 1) one) in
let increment = U128.of_u64 s2 s3 in
let state = U128.(zero * multiplier + increment) in
{state = U128.((of_u64 seed.(0) seed.(1) + state) * multiplier + increment); increment}


let advance (d1, d0) {s = {state; increment}; _} =
let open U128 in
let half x = U128.{low = Uint64.(logor (shift_right x.low 1) (shift_left x.high 63));
high = Uint64.(shift_right x.high 1)} in
let rec lcg d am ap cm cp =
match Uint64.(d.high <= zero && d.low <= zero, logand d.low one = one) with
| true, _ -> am * state + ap
| false, true -> lcg (half d) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one))
| false, false -> lcg (half d) am ap (cm * cm) (cp * (cm + one))
in {s = {state = lcg (of_u64 d1 d0) one zero multiplier increment; increment}; ustore = None}


let initialize seed =
Expand Down
2 changes: 1 addition & 1 deletion test/test_pcg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ let test_advance _ =
let t = SeedSequence.initialize [Uint128.of_int 12345] |> PCG64.initialize in
let advance n = Seq.(iterate (fun s -> PCG64.next_uint64 s |> snd) t |> drop n |> uncons |> Option.get |> fst) in
assert_equal
(PCG64.advance (Int128.of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string)
(PCG64.advance Uint64.(of_int 0, of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string)
(advance 100 |> PCG64.next_uint64 |> fst |> Uint64.to_string)
~printer:(fun x -> x)

Expand Down
Loading