Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add advance function to ChaCha bitgenerator. #16

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions lib/chacha.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,17 @@ module ChaCha128Counter : sig

include Common.BITGEN

val initialize_full : key:uint64 array -> counter:uint64 * uint64 -> rounds:int -> t
(** [initialize_full key counter rounds] initializes the state of the ChaCha
bitgenerator; where [key] is a 4-element array, [counter] is a 2-tuple,
val initialize_full : Seed.SeedSequence.t -> uint64 * uint64 -> int -> t
(** [initialize_full seedseq counter rounds] initializes the state of the ChaCha
bitgenerator; where [seedseq] is a {!SeedSequence.t} used to initialize the
PRNG's key array, [counter] is a 2-tuple used to initialize the 128-bit counter,
and [rounds] is the number of rounds to use. [rounds] must be non-negative, even
and greater than 2, else an exception is raised. *)
and greater than 2, else an [Invalid_argument] exception is raised. *)

val advance : uint128 -> t -> t
(** [advance n] Advances the generator forward as if [n] calls to {!ChaCha.next_uint32}
have been made, and returns the new advanced state. *)

end = struct
type t = {rounds: int; block : uint32 array; keysetup : uint32 array; ctr : uint64 * uint64}

Expand Down Expand Up @@ -113,6 +119,22 @@ end = struct
let next_double t = Common.next_double ~nextu64:next_uint64 t


let advance d t =
(* Split 128bit [d] into [lower, high] 64-bit integers. *)
let d0, d1 = Uint128.(rem d (of_uint64 Uint64.max_int) |> to_uint64,
shift_right d 64 |> to_uint64) in
let open Uint64 in
let ctr0, ctr1 = t.ctr in
let idx = rem ctr0 sixteen64 in
let ctr' = match ctr0 + d0 with
| v when v < ctr0 -> v, ctr1 + d1 + one
| v -> v, ctr1 + d1
in
match (idx + d0 >= sixteen64 || d1 > zero) && (rem (fst ctr') sixteen64 > zero) with
| true -> {t with block = generate_block ctr' t.keysetup t.rounds; ctr = ctr'}
| false -> {t with ctr = ctr'}


let set_seed seed stream ctr rounds =
let open Uint64 in
let f x = logand x mask |> to_uint32
Expand All @@ -123,15 +145,13 @@ end = struct
{block = generate_block ctr' keysetup rounds; ctr; keysetup; rounds}


let initialize_full ~key ~counter ~rounds =
let rounds' = match rounds, rounds mod 2 with
| r, m when r <= 2 || m <> 0 -> failwith "`rounds` must be a positive, even and > 2"
| _ -> rounds
in
set_seed (Array.sub key 0 2) (Array.sub key 2 2) counter rounds'
let initialize_full seed counter = function
| r when r <= 2 || r mod 2 <> 0 ->
raise (Invalid_argument "`rounds` must be a positive, even and > 2")
| r ->
let key = Seed.SeedSequence.generate_64bit_state 4 seed in
set_seed (Array.sub key 0 2) (Array.sub key 2 2) counter r


let initialize seed =
let istate = Seed.SeedSequence.generate_64bit_state 4 seed in
initialize_full ~key:istate ~counter:Uint64.(zero, zero) ~rounds:4
let initialize seed = initialize_full seed Uint64.(zero, zero) 4
end
37 changes: 30 additions & 7 deletions test/test_chacha.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,47 @@ let test_chacha_datasets _ =

let test_full_init _ =
let ss = SeedSequence.initialize [Uint128.of_int 12345] in
let key = SeedSequence.generate_64bit_state 4 ss in
let full = ChaCha.(initialize_full ~key:key ~counter:Uint64.(max_int, zero) ~rounds:4
|> next_uint64) |> fst |> Uint64.to_string in
let full = ChaCha.(initialize_full ss Uint64.(max_int, zero) 4 |> next_uint64)
|> fst |> Uint64.to_string in
let default = ChaCha.(initialize ss |> next_uint64) |> fst |> Uint64.to_string in
assert_bool "stream should not be the same for different init counters" (full <> default)


let test_invalid_round _ =
let msg = "`rounds` must be a positive, even and > 2" in
let key = SeedSequence.(initialize [Uint128.of_int 12345] |> generate_64bit_state 4) in
let ctr = Uint64.(zero, zero) in
assert_raises (Failure msg) (fun _ -> ChaCha.initialize_full ~key:key ~counter:ctr ~rounds:3);
let msg = "`rounds` must be a positive, even and > 2" in
let ss = SeedSequence.initialize [Uint128.of_int 12345] in
assert_raises (Invalid_argument msg) (fun _ -> ChaCha.initialize_full ss ctr 3);
(* test non-positive even rounds *)
assert_raises (Failure msg) (fun _ -> ChaCha.initialize_full ~key:key ~counter:ctr ~rounds:(-10))
assert_raises (Invalid_argument msg) (fun _ -> ChaCha.initialize_full ss ctr (-10))


let test_advance _ =
let open Stdint in
(* manually advance ChaCha n times. Since Chacha generates 32bit ints,
we advance manually using next_uint32, else if we used next_uint64 we
would need to use n/2 steps.*)
let rec advance_n t = function
| 0 -> t
| i -> advance_n (ChaCha.next_uint32 t |> snd) (i - 1)
in
let t = ChaCha.initialize_full (SeedSequence.initialize [Uint128.of_int 12345]) Uint64.(max_int, zero) 4 in
assert_equal
(ChaCha.advance (Uint128.of_int 1000) t |> ChaCha.next_uint32 |> fst |> Uint32.to_string)
(advance_n t 1000 |> ChaCha.next_uint32 |> fst |> Uint32.to_string)
~printer:(fun x -> x);
(* Test zero advancing *)
assert_equal
(ChaCha.advance Uint128.zero t |> ChaCha.next_uint32 |> fst |> Uint32.to_string)
(advance_n t 0 |> ChaCha.next_uint32 |> fst |> Uint32.to_string)
~printer:(fun x -> x);
(* Advancing with the largest 128bit integer should not fail *)
ignore (ChaCha.advance Uint128.max_int t)


let tests = [
"test ChaCha PRNG against groundtruth data" >:: test_chacha_datasets;
"test ChaCha full initialization consistency" >:: test_full_init;
"test if odd number of rounds raises exception" >:: test_invalid_round;
"test the correctness of ChaCha's advance function" >:: test_advance;
]
Loading