Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for bounded random integer generation. #20

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/chacha.ml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ end = struct
let next_double t = Common.next_double ~nextu64:next_uint64 t


let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t


let advance d t =
(* Split 128bit [d] into [lower, high] 64-bit integers. *)
let d0, d1 = Uint128.(rem d (of_uint64 Uint64.max_int) |> to_uint64,
Expand Down
22 changes: 22 additions & 0 deletions lib/common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ let next_double ~nextu64 t = match nextu64 t with
|> ( *. ) (1.0 /. 9007199254740992.0), t'


(* Generate a random uint64 integer in the range [0, b) where [b] is the upper bound.
This implementation uses Lemire's method. See: https://arxiv.org/abs/1805.10941 *)
let next_bounded_uint64 b ~nextu64 t =
let b' = Uint128.of_uint64 b
and r = Uint64.(rem (neg b) b) in
let rec loop = function
| m, s when Uint128.to_uint64 m >= r -> Uint128.(shift_right m 64 |> to_uint64), s
| _, s -> match nextu64 s with
| x, s' -> loop Uint128.(of_uint64 x * b', s')
in let x, t' = nextu64 t in
match Uint128.(of_uint64 x * b') with
| m when Uint128.(to_uint64 m < b) -> loop (m, t')
| m -> Uint128.(shift_right m 64 |> to_uint64), t'


module type BITGEN = sig
type t
(** [t] is the state of the bitgenerator. *)
Expand All @@ -36,6 +51,13 @@ module type BITGEN = sig
(** [next_uint32 t] Generates a random unsigned 32-bit integer and a state
of the generator advanced forward by one step. *)

val next_bounded_uint64 : uint64 -> t -> uint64 * t
(** [next_bounded_uint64 b t] Generates a random unsigned 64-bit integer
in the interval {m [0, b)}. It returns the integer as well as the state of the
generator advanced forward. To generate an integer in the range {m [a, b)},
one should generate an integer in {m [0, b - a)} using [next_bounded_uint64 (b - a) t]
and then add [a] to the resulting integer to get the output in the desired range. *)

val next_double : t -> float * t
(** [next_double t] Generates a random 64 bit float and a state of the
generator advanced forward by one step. *)
Expand Down
8 changes: 1 addition & 7 deletions lib/pcg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,7 @@ end = struct
{state = (increment + s) * multiplier + increment; increment}


let next_bounded_uint64 bound t =
let rec loop threshold = function
| r, s when r >= threshold -> s, r
| _, s -> loop threshold (next s)
in
let s, r' = loop (Uint64.rem (Uint64.neg bound) bound) (next t.s) in
Uint64.rem r' bound, {t with s}
let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t


let initialize seed =
Expand Down
3 changes: 3 additions & 0 deletions lib/philox.ml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ end = struct
let next_double t = Common.next_double ~nextu64:next_uint64 t


let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t


let jump t =
let c2' = Uint64.(t.ctr.(2) + one) in match Uint64.(c2' = zero) with
| true -> {t with ctr = [|t.ctr.(0); t.ctr.(1); c2'; Uint64.(t.ctr.(3) + one)|]}
Expand Down
3 changes: 3 additions & 0 deletions lib/sfc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ end = struct
let next_double t = Common.next_double ~nextu64:next_uint64 t


let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t


let set_seed (w, x, y) =
let rec loop s = function
| 0 -> s
Expand Down
3 changes: 3 additions & 0 deletions lib/xoshiro.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ end = struct
let next_double t = Common.next_double ~nextu64:next_uint64 t


let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t


let jump = Uint64.(
[| of_int 0x180ec6d33cfd0aba; of_string "0xd5a61266f0c9392c";
of_string "0xa9582618e03fc9aa"; of_int 0x39abdc4529b1661c |])
Expand Down
4 changes: 4 additions & 0 deletions test/test_chacha.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ let test_chacha_datasets _ =
(Sys.getcwd () ^ "/../../../test/data/chacha-testset-2.csv")


let test_bounded_u64 _ = Testconf.test_bounded_u64 (module ChaCha)


let test_full_init _ =
let ss = SeedSequence.initialize [Uint128.of_int 12345] in
let full = ChaCha.(initialize_full ss Uint64.(max_int, zero) 4 |> next_uint64)
Expand Down Expand Up @@ -57,4 +60,5 @@ let tests = [
"test ChaCha full initialization consistency" >:: test_full_init;
"test if odd number of rounds raises exception" >:: test_invalid_round;
"test the correctness of ChaCha's advance function" >:: test_advance;
"test bounded random generation of ChaCha" >:: test_bounded_u64;
]
17 changes: 3 additions & 14 deletions test/test_pcg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,6 @@ let test_advance _ =
~printer:(fun x -> x)


let test_bounded_u64 _ =
let open Stdint in
let rec loop i t b acc n = match i >= n with
| true -> List.fold_left ( && ) true (List.rev acc)
| false ->
let u, t' = PCG64.next_bounded_uint64 b t in
loop (i + 1) t' b ((u < b) :: acc) n
in
let t = SeedSequence.initialize [Uint128.of_int 12345] |> PCG64.initialize in
List.iter
(fun b -> assert_equal true (loop 0 t (Uint64.of_int b) [] 1000))
[1; 1000; 4193609425186963870]


let test_pcg_datasets _ =
Testconf.bitgen_groundtruth
(module PCG64)
Expand All @@ -39,6 +25,9 @@ let test_pcg_datasets _ =
(Sys.getcwd () ^ "/../../../test/data/pcg64-testset-2.csv")


let test_bounded_u64 _ = Testconf.test_bounded_u64 (module PCG64)


let tests = [
"test PCG64 and PCG64DXSM ran against groundtruth datasets" >:: test_pcg_datasets;
"test correctness of PCG's advance function" >:: test_advance;
Expand Down
4 changes: 4 additions & 0 deletions test/test_philox.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ let test_philox_datasets _ =
(Sys.getcwd () ^ "/../../../test/data/philox-testset-2.csv")


let test_bounded_u64 _ = Testconf.test_bounded_u64 (module Philox4x64)


let test_counter_init _ =
let open Stdint in
let ss = SeedSequence.initialize [Uint128.of_int 12345] in
Expand Down Expand Up @@ -54,4 +57,5 @@ let tests = [
"test behaviour when counter is set" >:: test_counter_init;
"test Philox jump function consistency" >:: test_jump;
"test Philox advance function correctness" >:: test_advance;
"test bounded random generation of Philox" >:: test_bounded_u64;
]
4 changes: 4 additions & 0 deletions test/test_sfc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ let test_sfc_datasets _ =
(Sys.getcwd () ^ "/../../../test/data/sfc64-testset-2.csv")


let test_bounded_u64 _ = Testconf.test_bounded_u64 (module SFC64)


let tests = [
"test SFC64's next_uint64 against groundtruth data" >:: test_sfc_datasets;
"test bounded random generation of SFC64" >:: test_bounded_u64;
]
4 changes: 4 additions & 0 deletions test/test_xoshiro.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ let test_xoshiro_datasets _ =
(Sys.getcwd () ^ "/../../../test/data/xoshiro256-testset-2.csv")


let test_bounded_u64 _ = Testconf.test_bounded_u64 (module Xoshiro256)


let test_jump _ =
let ss = SeedSequence.initialize [] in
let t = Xoshiro256.initialize ss |> Xoshiro256.jump in
Expand All @@ -21,4 +24,5 @@ let test_jump _ =
let tests = [
"test Xoshiro256** PRNG against groundtruth datasets" >:: test_xoshiro_datasets;
"test Xoshiro256** jump function consistency" >:: test_jump;
"test bounded random generation of Xoshiro256**" >:: test_bounded_u64;
]
12 changes: 12 additions & 0 deletions test/testconf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@ module type S = sig
val next_uint64 : t -> uint64 * t
val next_uint32 : t -> uint32 * t
val next_double : t -> float * t
val next_bounded_uint64: uint64 -> t -> uint64 * t
val initialize : Bitgen.SeedSequence.t -> t
end


let test_bounded_u64 (module M : S) =
let open Stdint in
let is_less bound t = match M.next_bounded_uint64 bound t with
| u, t' -> Some (u < bound, t') in
let t = Bitgen.SeedSequence.initialize [Uint128.of_int 12345] |> M.initialize in
let all_true b = Seq.(unfold (is_less b) t |> take 100 |> fold_left (&&) true) in
List.iter (fun b -> assert_equal true (all_true b)) Uint64.[of_int 1; of_int 4193609425186963870]


(* This tests the correctness of a bitgenerator's implementation against groundtruth data for a given seed.
This function takes the module representing the bitgenerator as well as the path to the CSV file
containing the groundtruth data. The data is sourced from numpy's random module test suite. *)
Expand Down
Loading