From da2fb25334d9e5c847e1339208cfafb44fab5484 Mon Sep 17 00:00:00 2001 From: Pass Automated Testing Suite Date: Fri, 26 Apr 2024 23:50:39 +0200 Subject: [PATCH] Add support for bounded random integer generation. This implementation uses Daniel Lemire's method, as shown [here](https://arxiv.org/abs/1805.10941). The functionality is added to the main bitgenerator signature and thus supported by all implemented PRNG bitgenerators. --- lib/chacha.ml | 3 +++ lib/common.ml | 22 ++++++++++++++++++++++ lib/pcg.ml | 8 +------- lib/philox.ml | 3 +++ lib/sfc.ml | 3 +++ lib/xoshiro.ml | 3 +++ test/test_chacha.ml | 4 ++++ test/test_pcg.ml | 17 +++-------------- test/test_philox.ml | 4 ++++ test/test_sfc.ml | 4 ++++ test/test_xoshiro.ml | 4 ++++ test/testconf.ml | 16 ++++++++++++++++ 12 files changed, 70 insertions(+), 21 deletions(-) diff --git a/lib/chacha.ml b/lib/chacha.ml index 93de140..c38ce83 100644 --- a/lib/chacha.ml +++ b/lib/chacha.ml @@ -113,6 +113,9 @@ end = struct let next_double t = Common.next_double ~nextu64:next_uint64 t + let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t + + let advance d t = (* Split 128bit [d] into [lower, high] 64-bit integers. *) let d0, d1 = Uint128.(rem d (of_uint64 Uint64.max_int) |> to_uint64, diff --git a/lib/common.ml b/lib/common.ml index c83bcd3..023aa1d 100644 --- a/lib/common.ml +++ b/lib/common.ml @@ -29,6 +29,21 @@ let next_double ~nextu64 t = match nextu64 t with |> ( *. ) (1.0 /. 9007199254740992.0), t' +(* [next_bounded_uint64 b ~nextu64 t] draws a random uint64 integer in the range [0, b), where [b] is the exclusive upper bound, using Lemire's multiply-and-reject method. [nextu64] is the generator step that yields each raw 64-bit draw together with the next state; the result is paired with the state returned by the last draw. Each draw is multiplied by [b] as a 128-bit product whose upper 64 bits are the candidate result. When the lower 64 bits of the first product are below [b], the candidate is accepted only once the lower 64 bits are at least the threshold [rem (neg b) b], redrawing otherwise. NOTE(review): for [b = 0] the product is always zero, so this appears to return 0 instead of failing - confirm that callers never pass 0. 
 See: https://arxiv.org/abs/1805.10941 *) +let next_bounded_uint64 b ~nextu64 t = + let b' = Uint128.of_uint64 b in + let rec loop threshold = function + | m, s when Uint128.to_uint64 m >= threshold -> Uint128.(shift_right m 64 |> to_uint64), s + | _, s -> + let x, s' = nextu64 s in + loop threshold Uint128.(of_uint64 x * b', s') in + let x, t' = nextu64 t in + let m = Uint128.(of_uint64 x * b') in + if Uint128.(to_uint64 m < b) then loop Uint64.(rem (neg b) b) (m, t') + else Uint128.(shift_right m 64 |> to_uint64), t' + + module type BITGEN = sig type t (** [t] is the state of the bitgenerator. *) @@ -41,6 +56,13 @@ module type BITGEN = sig (** [next_uint32 t] Generates a random unsigned 32-bit integer and a state of the generator advanced forward by one step. *) + val next_bounded_uint64 : uint64 -> t -> uint64 * t + (** [next_bounded_uint64 b t] Generates a random unsigned 64-bit integer + in the interval {m [0, b)}. It returns the integer as well as the state of the + generator advanced forward. To generate an integer in the range {m [a, b)}, + one should generate an integer in {m [0, b - a)} using [next_bounded_uint64 (b - a) t] + and then add [a] to the resulting integer to get the output in the desired range. *) + val next_double : t -> float * t (** [next_double t] Generates a random 64 bit float and a state of the generator advanced forward by one step. 
 *) diff --git a/lib/pcg.ml b/lib/pcg.ml index 50a2bb5..38525c0 100644 --- a/lib/pcg.ml +++ b/lib/pcg.ml @@ -77,13 +77,7 @@ end = struct {state = (increment + s) * multiplier + increment; increment} - let next_bounded_uint64 bound t = - let rec loop threshold = function - | r, s when r >= threshold -> s, r - | _, s -> loop threshold (next s) - in - let s, r' = loop (Uint64.rem (Uint64.neg bound) bound) (next t.s) in - Uint64.rem r' bound, {t with s} + let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t let initialize seed = diff --git a/lib/philox.ml b/lib/philox.ml index 97d68c4..96739a2 100644 --- a/lib/philox.ml +++ b/lib/philox.ml @@ -93,6 +93,9 @@ end = struct let next_double t = Common.next_double ~nextu64:next_uint64 t + let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t + + let jump t = let c2' = Uint64.(t.ctr.(2) + one) in match Uint64.(c2' = zero) with | true -> {t with ctr = [|t.ctr.(0); t.ctr.(1); c2'; Uint64.(t.ctr.(3) + one)|]} diff --git a/lib/sfc.ml b/lib/sfc.ml index cf70baf..19b88f3 100644 --- a/lib/sfc.ml +++ b/lib/sfc.ml @@ -40,6 +40,9 @@ end = struct let next_double t = Common.next_double ~nextu64:next_uint64 t + let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t + + let set_seed (w, x, y) = let rec loop s = function | 0 -> s diff --git a/lib/xoshiro.ml b/lib/xoshiro.ml index 983c249..39f70d0 100644 --- a/lib/xoshiro.ml +++ b/lib/xoshiro.ml @@ -55,6 +55,9 @@ end = struct let next_double t = Common.next_double ~nextu64:next_uint64 t + let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t + + let jump = Uint64.( [| of_int 0x180ec6d33cfd0aba; of_string "0xd5a61266f0c9392c"; of_string "0xa9582618e03fc9aa"; of_int 0x39abdc4529b1661c |]) diff --git a/test/test_chacha.ml b/test/test_chacha.ml index 2021919..e79282b 100644 --- a/test/test_chacha.ml +++ b/test/test_chacha.ml @@ -12,6 +12,9 @@ let 
test_chacha_datasets _ = (Sys.getcwd () ^ "/../../../test/data/chacha-testset-2.csv") +let test_bounded_u64 _ = Testconf.test_bounded_u64 (module ChaCha) + + let test_full_init _ = let ss = SeedSequence.initialize [Uint128.of_int 12345] in let full = ChaCha.(initialize_full ss Uint64.(max_int, zero) 4 |> next_uint64) @@ -57,4 +60,5 @@ let tests = [ "test ChaCha full initialization consistency" >:: test_full_init; "test if odd number of rounds raises exception" >:: test_invalid_round; "test the correctness of ChaCha's advance function" >:: test_advance; + "test bounded random generation of ChaCha" >:: test_bounded_u64; ] diff --git a/test/test_pcg.ml b/test/test_pcg.ml index 97ab2ec..114cb64 100644 --- a/test/test_pcg.ml +++ b/test/test_pcg.ml @@ -16,20 +16,6 @@ let test_advance _ = ~printer:(fun x -> x) -let test_bounded_u64 _ = - let open Stdint in - let rec loop i t b acc n = match i >= n with - | true -> List.fold_left ( && ) true (List.rev acc) - | false -> - let u, t' = PCG64.next_bounded_uint64 b t in - loop (i + 1) t' b ((u < b) :: acc) n - in - let t = SeedSequence.initialize [Uint128.of_int 12345] |> PCG64.initialize in - List.iter - (fun b -> assert_equal true (loop 0 t (Uint64.of_int b) [] 1000)) - [1; 1000; 4193609425186963870] - - let test_pcg_datasets _ = Testconf.bitgen_groundtruth (module PCG64) @@ -39,6 +25,9 @@ let test_pcg_datasets _ = (Sys.getcwd () ^ "/../../../test/data/pcg64-testset-2.csv") +let test_bounded_u64 _ = Testconf.test_bounded_u64 (module PCG64) + + let tests = [ "test PCG64 and PCG64DXSM ran against groundtruth datasets" >:: test_pcg_datasets; "test correctness of PCG's advance function" >:: test_advance; diff --git a/test/test_philox.ml b/test/test_philox.ml index 623e7ae..cc4fb04 100644 --- a/test/test_philox.ml +++ b/test/test_philox.ml @@ -11,6 +11,9 @@ let test_philox_datasets _ = (Sys.getcwd () ^ "/../../../test/data/philox-testset-2.csv") +let test_bounded_u64 _ = Testconf.test_bounded_u64 (module Philox64) + + let 
test_counter_init _ = let open Stdint in let ss = SeedSequence.initialize [Uint128.of_int 12345] in @@ -37,4 +40,5 @@ let tests = [ "test Philox PNRG against groundtruth data" >:: test_philox_datasets; "test behaviour when counter is set" >:: test_counter_init; "test Philox jump function consistency" >:: test_jump; + "test bounded random generation of Philox" >:: test_bounded_u64; ] diff --git a/test/test_sfc.ml b/test/test_sfc.ml index a67baf0..fbb3f95 100644 --- a/test/test_sfc.ml +++ b/test/test_sfc.ml @@ -11,6 +11,10 @@ let test_sfc_datasets _ = (Sys.getcwd () ^ "/../../../test/data/sfc64-testset-2.csv") +let test_bounded_u64 _ = Testconf.test_bounded_u64 (module SFC64) + + let tests = [ "test SFC64's next_uint64 against groundtruth data" >:: test_sfc_datasets; + "test bounded random generation of SFC64" >:: test_bounded_u64; ] diff --git a/test/test_xoshiro.ml b/test/test_xoshiro.ml index 2b498b7..277c7a0 100644 --- a/test/test_xoshiro.ml +++ b/test/test_xoshiro.ml @@ -11,6 +11,9 @@ let test_xoshiro_datasets _ = (Sys.getcwd () ^ "/../../../test/data/xoshiro256-testset-2.csv") +let test_bounded_u64 _ = Testconf.test_bounded_u64 (module Xoshiro256) + + let test_jump _ = let ss = SeedSequence.initialize [] in let t = Xoshiro256.initialize ss |> Xoshiro256.jump in @@ -21,4 +24,5 @@ let test_jump _ = let tests = [ "test Xoshiro256** PRNG against groundtruth datasets" >:: test_xoshiro_datasets; "test Xoshiro256** jump function consistency" >:: test_jump; + "test bounded random generation of Xoshiro256**" >:: test_bounded_u64; ] diff --git a/test/testconf.ml b/test/testconf.ml index d87f279..062fc9d 100644 --- a/test/testconf.ml +++ b/test/testconf.ml @@ -17,8 +17,24 @@ module type S = sig val next_uint64 : t -> uint64 * t val next_uint32 : t -> uint32 * t val next_double : t -> float * t + val next_bounded_uint64: uint64 -> t -> uint64 * t val initialize : Bitgen.SeedSequence.t -> t end + + +let test_bounded_u64 (module M : S) = (* Checks, for a generator [M] seeded from 12345, that 1000 successive bounded draws stay strictly below each of the bounds 1, 1000 and 4193609425186963870; every bound starts from the same initial state. *) + let open Stdint in + let rec loop t 
b acc = function + | 0 -> List.fold_left ( && ) true acc + | n -> + let u, t' = M.next_bounded_uint64 b t in + loop t' b ((u < b) :: acc) (n - 1) in + let t = Bitgen.SeedSequence.initialize [Uint128.of_int 12345] |> M.initialize in + List.iter + (fun b -> assert_equal true (loop t (Uint64.of_int b) [] 1000)) + [1; 1000; 4193609425186963870] + + (* This tests the correctness of a bitgenerator's implementation against groundtruth data for a given seed. This function takes the module representing the bitgenerator as well as the path to the CSV file containing the groundtruth data. The data is sourced from numpy's random module test suite. *)