From a77c0073fc802cf75171cebe6ae536191e0e78a5 Mon Sep 17 00:00:00 2001 From: Pass Automated Testing Suite Date: Sun, 21 Apr 2024 21:52:35 +0200 Subject: [PATCH] Make ChaCha's `initialize_full` type-safe. --- lib/chacha.ml | 25 ++++++++++++------------- test/test_chacha.ml | 13 ++++++------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/lib/chacha.ml b/lib/chacha.ml index 2a489fb..1f27c81 100644 --- a/lib/chacha.ml +++ b/lib/chacha.ml @@ -31,11 +31,12 @@ module ChaCha128Counter : sig include Common.BITGEN - val initialize_full : key:uint64 array -> counter:uint64 * uint64 -> rounds:int -> t - (** [initialize_full key counter rounds] initializes the state of the ChaCha - bitgenerator; where [key] is a 4-element array, [counter] is a 2-tuple, + val initialize_full : Seed.SeedSequence.t -> uint64 * uint64 -> int -> t + (** [initialize_full seedseq counter rounds] initializes the state of the ChaCha + bitgenerator; where [seedseq] is a {!SeedSequence.t} used to initialize the + PRNG's key array, [counter] is a 2-tuple used to initialize the 128-bit counter, and [rounds] is the number of rounds to use. [rounds] must be non-negative, even - and greater than 2, else an exception is raised. *) + and greater than 2, else an [Invalid_argument] exception is raised. *) val advance : uint128 -> t -> t (** [advance n] Advances the generator forward as if [n] calls to {!ChaCha.next_uint32} @@ -144,15 +145,13 @@ end = struct {block = generate_block ctr' keysetup rounds; ctr; keysetup; rounds} - let initialize_full ~key ~counter ~rounds = - let rounds' = match rounds, rounds mod 2 with - | r, m when r <= 2 || m <> 0 -> failwith "`rounds` must be a positive, even and > 2" - | _ -> rounds - in - set_seed (Array.sub key 0 2) (Array.sub key 2 2) counter rounds' + let initialize_full seed counter = function + | r when r <= 2 || r mod 2 <> 0 -> + raise (Invalid_argument "`rounds` must be a positive, even and > 2") + | r -> + let key = Seed.SeedSequence.generate_64bit_state 4 seed in + set_seed (Array.sub key 0 2) (Array.sub key 2 2) counter r - let initialize seed = - let istate = Seed.SeedSequence.generate_64bit_state 4 seed in - initialize_full ~key:istate ~counter:Uint64.(zero, zero) ~rounds:4 + let initialize seed = initialize_full seed Uint64.(zero, zero) 4 end diff --git a/test/test_chacha.ml b/test/test_chacha.ml index 551a70b..a5290a4 100644 --- a/test/test_chacha.ml +++ b/test/test_chacha.ml @@ -14,20 +14,19 @@ let test_chacha_datasets _ = let test_full_init _ = let ss = SeedSequence.initialize [Uint128.of_int 12345] in - let key = SeedSequence.generate_64bit_state 4 ss in - let full = ChaCha.(initialize_full ~key:key ~counter:Uint64.(max_int, zero) ~rounds:4 - |> next_uint64) |> fst |> Uint64.to_string in + let full = ChaCha.(initialize_full ss Uint64.(max_int, zero) 4 |> next_uint64) + |> fst |> Uint64.to_string in let default = ChaCha.(initialize ss |> next_uint64) |> fst |> Uint64.to_string in assert_bool "stream should not be the same for different init counters" (full <> default) let test_invalid_round _ = - let msg = "`rounds` must be a positive, even and > 2" in - let key = SeedSequence.(initialize [Uint128.of_int 12345] |> generate_64bit_state 4) in let ctr = Uint64.(zero, zero) in - assert_raises (Failure msg) (fun _ -> ChaCha.initialize_full ~key:key ~counter:ctr ~rounds:3); + let msg = "`rounds` must be a positive, even and > 2" in + let ss = SeedSequence.initialize [Uint128.of_int 12345] in + assert_raises (Invalid_argument msg) (fun _ -> ChaCha.initialize_full ss ctr 3); (* test non-positive even rounds *) - assert_raises (Failure msg) (fun _ -> ChaCha.initialize_full ~key:key ~counter:ctr ~rounds:(-10)) + assert_raises (Invalid_argument msg) (fun _ -> ChaCha.initialize_full ss ctr (-10)) let test_advance _ =