From 96292f9aacc6aed1b716f9de951b69586d7512bb Mon Sep 17 00:00:00 2001 From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com> Date: Thu, 13 Jun 2024 10:04:19 +0300 Subject: [PATCH] Poseidon252 channel (#655) --- Cargo.lock | 322 +++++++++++++++++- crates/prover/Cargo.toml | 2 + .../core/{channel.rs => channel/blake2s.rs} | 59 +--- crates/prover/src/core/channel/mod.rs | 47 +++ crates/prover/src/core/channel/poseidon252.rs | 238 +++++++++++++ 5 files changed, 610 insertions(+), 58 deletions(-) rename crates/prover/src/core/{channel.rs => channel/blake2s.rs} (84%) create mode 100644 crates/prover/src/core/channel/mod.rs create mode 100644 crates/prover/src/core/channel/poseidon252.rs diff --git a/Cargo.lock b/Cargo.lock index 5c4dd53a5..f9b33e10d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,70 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" +[[package]] +name = "ark-ff" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "derivative", + "digest", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rustc_version", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-std", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "arrayref" version = "0.3.7" @@ -59,6 +123,18 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +[[package]] +name = "bigdecimal" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + [[package]] name = "blake2" version = "0.10.6" @@ -113,7 +189,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -192,6 +268,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -259,6 +344,17 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -269,6 +365,17 @@ dependencies = [ "typenum", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -289,7 +396,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -315,7 +422,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -347,6 +454,19 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + [[package]] name = "half" version = "2.4.1" @@ -369,6 +489,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "is-terminal" version = "0.4.12" @@ -421,9 +550,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.153" +version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "log" @@ -456,6 +585,25 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -483,6 +631,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -517,6 +671,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.81" @@ -541,6 +701,17 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", "rand_core", ] @@ -614,6 +785,25 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "ryu" version = "1.0.17" @@ -629,6 +819,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "serde" version = "1.0.198" @@ -646,7 +842,7 @@ checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -660,6 +856,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -681,6 +888,60 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "starknet-crypto" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e2c30c01e8eb0fc913c4ee3cf676389fffc1d1182bfe5bb9670e4e72e968064" +dependencies = [ + "crypto-bigint", + "hex", + "hmac", + "num-bigint", + "num-integer", + "num-traits", + "rfc6979", + "sha2", + "starknet-crypto-codegen", + "starknet-curve", + "starknet-ff", + "zeroize", +] + +[[package]] +name = "starknet-crypto-codegen" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbc159a1934c7be9761c237333a57febe060ace2bc9e3b337a59a37af206d19f" +dependencies = [ + "starknet-curve", + "starknet-ff", + "syn 2.0.60", +] + +[[package]] +name = "starknet-curve" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1c383518bb312751e4be80f53e8644034aa99a0afb29d7ac41b89a997db875b" +dependencies = [ + "starknet-ff", +] + +[[package]] +name = "starknet-ff" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abf1b44ec5b18d87c1ae5f54590ca9d0699ef4dd5b2ffa66fc97f24613ec585" +dependencies = [ + "ark-ff", + "bigdecimal", + "crypto-bigint", + "getrandom", + "hex", + "serde", +] + [[package]] name = "stwo-prover" version = "0.1.1" @@ -696,6 +957,8 @@ dependencies = [ "itertools 0.12.1", "num-traits", "rand", + "starknet-crypto", + "starknet-ff", "test-log", "thiserror", "tracing", @@ -708,6 +971,17 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.60" @@ -738,7 +1012,7 @@ checksum = "c8f546451eaa38373f549093fe9fd05e7d2bade739e2ddf834b9968621d60107" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -758,7 +1032,7 @@ checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -800,7 +1074,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", ] [[package]] @@ -876,6 +1150,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasm-bindgen" version = "0.2.92" @@ -897,7 +1177,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 2.0.60", "wasm-bindgen-shared", ] @@ -919,7 +1199,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.60", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1043,3 +1323,23 @@ name = "windows_x86_64_msvc" version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index c056faa00..a0b97bbbc 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -15,6 +15,8 @@ hex.workspace = true itertools.workspace = true num-traits.workspace = true rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } +starknet-crypto = "0.6.2" +starknet-ff = "0.3.7" thiserror.workspace = true tracing.workspace = true diff --git a/crates/prover/src/core/channel.rs b/crates/prover/src/core/channel/blake2s.rs similarity index 84% rename from crates/prover/src/core/channel.rs rename to crates/prover/src/core/channel/blake2s.rs index 68071d144..7c67c2494 100644 --- a/crates/prover/src/core/channel.rs +++ b/crates/prover/src/core/channel/blake2s.rs @@ -1,53 +1,15 @@ use std::iter; -use super::fields::m31::{BaseField, N_BYTES_FELT, P}; -use super::fields::qm31::SecureField; -use super::fields::secure_column::SECURE_EXTENSION_DEGREE; -use super::fields::IntoSlice; +use super::{Channel, ChannelTime}; +use crate::core::fields::m31::{BaseField, N_BYTES_FELT, P}; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::fields::IntoSlice; use crate::core::vcs::blake2_hash::{Blake2sHash, Blake2sHasher}; use crate::core::vcs::hasher::Hasher; pub const BLAKE_BYTES_PER_HASH: usize = 32; pub const FELTS_PER_HASH: usize = 8; -pub const EXTENSION_FELTS_PER_HASH: usize = 2; - -#[derive(Default)] -pub struct ChannelTime { - n_challenges: usize, - n_sent: usize, -} - -impl ChannelTime { - fn inc_sent(&mut self) { - self.n_sent += 1; - } - - fn inc_challenges(&mut self) { - self.n_challenges += 1; - self.n_sent = 0; - } -} - -pub trait Channel { - type Digest; - - const BYTES_PER_HASH: usize; - - fn new(digest: Self::Digest) -> Self; - fn get_digest(&self) -> Self::Digest; - - // Mix functions. - fn mix_digest(&mut self, digest: Self::Digest); - fn mix_felts(&mut self, felts: &[SecureField]); - fn mix_nonce(&mut self, nonce: u64); - - // Draw functions. - fn draw_felt(&mut self) -> SecureField; - /// Generates a uniform random vector of SecureField elements. - fn draw_felts(&mut self, n_felts: usize) -> Vec; - /// Returns a vector of random bytes of length `BYTES_PER_HASH`. - fn draw_random_bytes(&mut self) -> Vec; -} /// A channel that can be used to draw random elements from a [Blake2sHash] digest. pub struct Blake2sChannel { @@ -61,7 +23,7 @@ impl Blake2sChannel { // Repeats hashing with an increasing counter until getting a good result. // Retry probability for each round is ~ 2^(-28). loop { - let random_bytes: [u32; FELTS_PER_HASH] = self + let u32s: [u32; FELTS_PER_HASH] = self .draw_random_bytes() .chunks_exact(N_BYTES_FELT) // 4 bytes per u32. .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) @@ -70,8 +32,8 @@ impl Blake2sChannel { .unwrap(); // Retry if not all the u32 are in the range [0, 2P). - if random_bytes.iter().all(|x| *x < 2 * P) { - return random_bytes + if u32s.iter().all(|x| *x < 2 * P) { + return u32s .into_iter() .map(|x| BaseField::reduce(x as u64)) .collect::>() @@ -149,6 +111,8 @@ impl Channel for Blake2sChannel { hash_input.extend_from_slice(&padded_counter); + // TODO(spapini): Are we worried about this drawing hash colliding with mix_digest? + self.channel_time.inc_sent(); Blake2sHasher::hash(&hash_input).into() } @@ -158,7 +122,8 @@ impl Channel for Blake2sChannel { mod tests { use std::collections::BTreeSet; - use crate::core::channel::{Blake2sChannel, Channel}; + use crate::core::channel::blake2s::Blake2sChannel; + use crate::core::channel::Channel; use crate::core::fields::qm31::SecureField; use crate::core::vcs::blake2_hash::Blake2sHash; use crate::m31; diff --git a/crates/prover/src/core/channel/mod.rs b/crates/prover/src/core/channel/mod.rs new file mode 100644 index 000000000..576001f47 --- /dev/null +++ b/crates/prover/src/core/channel/mod.rs @@ -0,0 +1,47 @@ +use super::fields::qm31::SecureField; + +mod blake2s; +#[cfg(not(target_arch = "wasm32"))] +mod poseidon252; + +pub use blake2s::Blake2sChannel; + +pub const EXTENSION_FELTS_PER_HASH: usize = 2; + +#[derive(Default)] +pub struct ChannelTime { + n_challenges: usize, + n_sent: usize, +} + +impl ChannelTime { + fn inc_sent(&mut self) { + self.n_sent += 1; + } + + fn inc_challenges(&mut self) { + self.n_challenges += 1; + self.n_sent = 0; + } +} + +pub trait Channel { + type Digest; + + const BYTES_PER_HASH: usize; + + fn new(digest: Self::Digest) -> Self; + fn get_digest(&self) -> Self::Digest; + + // Mix functions. + fn mix_digest(&mut self, digest: Self::Digest); + fn mix_felts(&mut self, felts: &[SecureField]); + fn mix_nonce(&mut self, nonce: u64); + + // Draw functions. + fn draw_felt(&mut self) -> SecureField; + /// Generates a uniform random vector of SecureField elements. + fn draw_felts(&mut self, n_felts: usize) -> Vec; + /// Returns a vector of random bytes of length `BYTES_PER_HASH`. + fn draw_random_bytes(&mut self) -> Vec; +} diff --git a/crates/prover/src/core/channel/poseidon252.rs b/crates/prover/src/core/channel/poseidon252.rs new file mode 100644 index 000000000..b65a89e1d --- /dev/null +++ b/crates/prover/src/core/channel/poseidon252.rs @@ -0,0 +1,238 @@ +use std::iter; + +use starknet_crypto::poseidon_hash; +use starknet_ff::FieldElement as FieldElement252; + +use super::{Channel, ChannelTime}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; + +pub const BYTES_PER_FELT252: usize = 31; +pub const FELTS_PER_HASH: usize = 8; + +/// A channel that can be used to draw random elements from a Poseidon252 hash. +pub struct Poseidon252Channel { + digest: FieldElement252, + channel_time: ChannelTime, +} + +impl Poseidon252Channel { + fn draw_felt252(&mut self) -> FieldElement252 { + let res = poseidon_hash(self.digest, self.channel_time.n_sent.into()); + self.channel_time.inc_sent(); + res + } + + // TODO(spapini): Understand if we really need uniformity here. + /// Generates a close-to uniform random vector of BaseField elements. + fn draw_base_felts(&mut self) -> [BaseField; 8] { + let shift = (1u64 << 31).into(); + + let mut cur = self.draw_felt252(); + let u32s: [u32; 8] = std::array::from_fn(|_| { + let next = cur.floor_div(shift); + let res = cur - next * shift; + cur = next; + res.try_into().unwrap() + }); + + u32s.into_iter() + .map(|x| BaseField::reduce(x as u64)) + .collect::>() + .try_into() + .unwrap() + } +} + +impl Channel for Poseidon252Channel { + type Digest = FieldElement252; + const BYTES_PER_HASH: usize = BYTES_PER_FELT252; + + fn new(digest: Self::Digest) -> Self { + Poseidon252Channel { + digest, + channel_time: ChannelTime::default(), + } + } + + fn get_digest(&self) -> Self::Digest { + self.digest + } + + fn mix_digest(&mut self, digest: Self::Digest) { + self.digest = poseidon_hash(self.digest, digest); + self.channel_time.inc_challenges(); + } + + fn mix_felts(&mut self, felts: &[SecureField]) { + let shift = (1u64 << 31).into(); + let mut cur = FieldElement252::default(); + let mut in_chunk = 0; + for x in felts { + for y in x.to_m31_array() { + cur = cur * shift + y.0.into(); + } + in_chunk += 1; + if in_chunk == 2 { + self.digest = poseidon_hash(self.digest, cur); + cur = FieldElement252::default(); + in_chunk = 0; + } + } + if in_chunk > 0 { + self.digest = poseidon_hash(self.digest, cur); + } + + // TODO(spapini): do we need length padding? + self.channel_time.inc_challenges(); + } + + fn mix_nonce(&mut self, nonce: u64) { + self.mix_digest(nonce.into()) + } + + fn draw_felt(&mut self) -> SecureField { + let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); + SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) + } + + fn draw_felts(&mut self, n_felts: usize) -> Vec { + let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); + let secure_felts = iter::from_fn(|| { + Some(SecureField::from_m31_array([ + felts.next()?, + felts.next()?, + felts.next()?, + felts.next()?, + ])) + }); + secure_felts.take(n_felts).collect() + } + + fn draw_random_bytes(&mut self) -> Vec { + let shift = (1u64 << 8).into(); + let mut cur = self.draw_felt252(); + let bytes: [u8; 31] = std::array::from_fn(|_| { + let next = cur.floor_div(shift); + let res = cur - next * shift; + cur = next; + res.try_into().unwrap() + }); + bytes.to_vec() + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use starknet_ff::FieldElement as FieldElement252; + + use crate::core::channel::poseidon252::Poseidon252Channel; + use crate::core::channel::Channel; + use crate::core::fields::qm31::SecureField; + use crate::m31; + + #[test] + fn test_initialize_channel() { + let initial_digest = FieldElement252::default(); + let channel = Poseidon252Channel::new(initial_digest); + + // Assert that the channel is initialized correctly. + assert_eq!(channel.digest, initial_digest); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 0); + } + + #[test] + fn test_channel_time() { + let initial_digest = FieldElement252::default(); + let mut channel = Poseidon252Channel::new(initial_digest); + + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 0); + + channel.draw_random_bytes(); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 1); + + channel.draw_felts(9); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 6); + + channel.mix_digest(FieldElement252::default()); + assert_eq!(channel.channel_time.n_challenges, 1); + assert_eq!(channel.channel_time.n_sent, 0); + + channel.draw_felt(); + assert_eq!(channel.channel_time.n_challenges, 1); + assert_eq!(channel.channel_time.n_sent, 1); + assert_ne!(channel.digest, initial_digest); + } + + #[test] + fn test_draw_random_bytes() { + let initial_digest = FieldElement252::default(); + let mut channel = Poseidon252Channel::new(initial_digest); + + let first_random_bytes = channel.draw_random_bytes(); + + // Assert that next random bytes are different. + assert_ne!(first_random_bytes, channel.draw_random_bytes()); + } + + #[test] + pub fn test_draw_felt() { + let initial_digest = FieldElement252::default(); + let mut channel = Poseidon252Channel::new(initial_digest); + + let first_random_felt = channel.draw_felt(); + + // Assert that next random felt is different. + assert_ne!(first_random_felt, channel.draw_felt()); + } + + #[test] + pub fn test_draw_felts() { + let initial_digest = FieldElement252::default(); + let mut channel = Poseidon252Channel::new(initial_digest); + + let mut random_felts = channel.draw_felts(5); + random_felts.extend(channel.draw_felts(4)); + + // Assert that all the random felts are unique. + assert_eq!( + random_felts.len(), + random_felts.iter().collect::>().len() + ); + } + + #[test] + pub fn test_mix_digest() { + let initial_digest = FieldElement252::default(); + let mut channel = Poseidon252Channel::new(initial_digest); + + for _ in 0..10 { + channel.draw_random_bytes(); + channel.draw_felt(); + } + + // Reseed channel and check the digest was changed. + channel.mix_digest(FieldElement252::default()); + assert_ne!(initial_digest, channel.digest); + } + + #[test] + pub fn test_mix_felts() { + let initial_digest = FieldElement252::default(); + let mut channel = Poseidon252Channel::new(initial_digest); + let felts: Vec = (0..2) + .map(|i| SecureField::from(m31!(i + 1923782))) + .collect(); + + channel.mix_felts(felts.as_slice()); + + assert_ne!(initial_digest, channel.digest); + } +}