diff --git a/Cargo.toml b/Cargo.toml index d1e7c492..6557b9c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ members = [ "crates/clmul", "crates/mpz-ole-core", "crates/mpz-ole", + "crates/mpz-zk-core", + "crates/mpz-zk", ] resolver = "2" @@ -43,6 +45,8 @@ mpz-ole = { path = "crates/mpz-ole" } mpz-ole-core = { path = "crates/mpz-ole-core" } clmul = { path = "crates/clmul" } matrix-transpose = { path = "crates/matrix-transpose" } +mpz-zk-core = { path = "crates/mpz-zk-core" } +mpz-zk = { path = "crates/mpz-zk" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 804472ef..7fcb1628 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -18,7 +18,7 @@ struct Buffer { } /// The ideal functionality from the perspective of Alice. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Alice { f: Arc>, buffer: Arc>, @@ -79,7 +79,7 @@ impl Alice { } /// The ideal functionality from the perspective of Bob. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Bob { f: Arc>, buffer: Arc>, diff --git a/crates/mpz-core/src/ggm_tree.rs b/crates/mpz-core/src/ggm_tree.rs index 913fffb6..840efcc6 100644 --- a/crates/mpz-core/src/ggm_tree.rs +++ b/crates/mpz-core/src/ggm_tree.rs @@ -32,33 +32,35 @@ impl GgmTree { assert_eq!(k0.len(), self.depth); assert_eq!(k1.len(), self.depth); let mut buf = [Block::ZERO; 8]; - self.tkprp.expand_1to2(tree, seed); - k0[0] = tree[0]; - k1[0] = tree[1]; + if self.depth > 1 { + self.tkprp.expand_1to2(tree, seed); + k0[0] = tree[0]; + k1[0] = tree[1]; - self.tkprp.expand_2to4(&mut buf, tree); - k0[1] = buf[0] ^ buf[2]; - k1[1] = buf[1] ^ buf[3]; - tree[0..4].copy_from_slice(&buf[0..4]); - - for h in 2..self.depth { - k0[h] = Block::ZERO; - k1[h] = Block::ZERO; - - // How many nodes there are in this layer - let sz = 1 << h; - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - k0[h] ^= buf[0]; - k0[h] ^= buf[2]; - k0[h] ^= buf[4]; - k0[h] ^= buf[6]; - k1[h] ^= buf[1]; - k1[h] ^= buf[3]; - k1[h] ^= buf[5]; - k1[h] ^= buf[7]; + self.tkprp.expand_2to4(&mut buf, tree); + k0[1] = buf[0] ^ buf[2]; + k1[1] = buf[1] ^ buf[3]; + tree[0..4].copy_from_slice(&buf[0..4]); - tree[2 * i..2 * i + 8].copy_from_slice(&buf); + for h in 2..self.depth { + k0[h] = Block::ZERO; + k1[h] = Block::ZERO; + + // How many nodes there are in this layer + let sz = 1 << h; + for i in (0..=sz - 4).rev().step_by(4) { + self.tkprp.expand_4to8(&mut buf, &tree[i..]); + k0[h] ^= buf[0]; + k0[h] ^= buf[2]; + k0[h] ^= buf[4]; + k0[h] ^= buf[6]; + k1[h] ^= buf[1]; + k1[h] ^= buf[3]; + k1[h] ^= buf[5]; + k1[h] ^= buf[7]; + + tree[2 * i..2 * i + 8].copy_from_slice(&buf); + } } } } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 3ad7701e..bbbf264a 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -36,11 +36,12 @@ pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { }; /// The type of Lpn parameters. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, Default)] pub enum LpnType { /// Uniform error distribution. Uniform, /// Regular error distribution. + #[default] Regular, } @@ -48,7 +49,6 @@ pub enum LpnType { mod tests { use super::*; - use msgs::LpnMatrixSeed; use receiver::Receiver; use sender::Sender; @@ -56,7 +56,6 @@ mod tests { use crate::test::assert_cot; use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; use mpz_core::{lpn::LpnParameters, prg::Prg}; - use rand::SeedableRng; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { n: 9600, @@ -66,7 +65,7 @@ mod tests { #[test] fn ferret_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_cot = IdealCOT::default(); let mut ideal_mpcot = IdealMpcot::default(); @@ -101,18 +100,8 @@ mod tests { ) .unwrap(); - let LpnMatrixSeed { - seed: lpn_matrix_seed, - } = seed; - let mut sender = sender - .setup( - delta, - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &v, - ) + .setup(delta, LPN_PARAMETERS_TEST, LpnType::Regular, seed, &v) .unwrap(); // extend once diff --git a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs index e74dc38a..047780d4 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs @@ -16,11 +16,10 @@ mod tests { use crate::ideal::spcot::IdealSpcot; use crate::{SPCOTReceiverOutput, SPCOTSenderOutput}; use mpz_core::prg::Prg; - use rand::SeedableRng; #[test] fn mpcot_general_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); @@ -96,7 +95,7 @@ mod tests { #[test] fn mpcot_regular_test() { - let mut prg = Prg::from_seed([2u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs index 0f8613af..e4d362da 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs @@ -32,11 +32,11 @@ impl Receiver { /// # Argument /// /// * `hash_seed` - Random seed to generate hashes, will be sent to the sender. - pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { + pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); let recv = Receiver { - state: state::PreExtension { + state: state::Extension { counter: 0, hashes: Arc::new(hashes), }, @@ -48,7 +48,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -63,7 +63,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { if alphas.len() as u32 > n { return Err(ReceiverError::InvalidInput( "length of alphas should not exceed n".to_string(), @@ -104,7 +104,7 @@ impl Receiver { } let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, m, n, @@ -117,7 +117,7 @@ impl Receiver { Ok((receiver, p)) } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -128,7 +128,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt.len() != self.state.m { return Err(ReceiverError::InvalidInput( "the length rt should be m".to_string(), @@ -165,7 +165,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, hashes: self.state.hashes, }, @@ -182,8 +182,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -200,20 +200,20 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, /// The hashes to generate Cuckoo hash table. pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state of extension. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter pub(super) counter: usize, /// Current length of Cuckoo hash table, will possibly be changed in each extension. @@ -228,7 +228,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs index 2b226108..e1e7edfe 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs @@ -19,13 +19,13 @@ impl Receiver { } /// Completes the setup phase of the protocol. - pub fn setup(self) -> Receiver { + pub fn setup(self) -> Receiver { Receiver { - state: state::PreExtension { counter: 0 }, + state: state::Extension { counter: 0 }, } } } -impl Receiver { +impl Receiver { /// Performs the prepare procedure in MPCOT extension. /// Outputs the indices for SPCOT. /// @@ -38,7 +38,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { let t = alphas.len() as u32; if t > n { return Err(ReceiverError::InvalidInput( @@ -91,7 +91,7 @@ impl Receiver { .collect(); let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, n, queries_length, @@ -103,7 +103,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// # Arguments. @@ -112,7 +112,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt .iter() .zip(self.state.queries_depth.iter()) @@ -130,7 +130,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, }, }; @@ -145,8 +145,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -162,19 +162,19 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state after the setup phase. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter #[allow(dead_code)] pub(super) counter: usize, @@ -186,7 +186,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs index f1e49105..ad025574 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs @@ -31,12 +31,12 @@ impl Sender { /// /// * `delta` - The sender's global secret. /// * `hash_seed` - The seed for Cuckoo hash sent by the receiver. - pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { + pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { let HashSeed { seed: hash_seed } = hash_seed; let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); Sender { - state: state::PreExtension { + state: state::Extension { delta, counter: 0, hashes: Arc::new(hashes), @@ -45,7 +45,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -59,7 +59,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -86,7 +86,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, m, @@ -101,7 +101,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -112,7 +112,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st.len() != self.state.m { return Err(SenderError::InvalidInput( "the length st should be m".to_string(), @@ -147,7 +147,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, hashes: self.state.hashes, @@ -166,8 +166,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -184,7 +184,7 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -193,13 +193,13 @@ pub mod state { pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state of extension. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -217,7 +217,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs index db0646b6..7afa5106 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs @@ -23,14 +23,14 @@ impl Sender { /// # Argument. /// /// * `delta` - The sender's global secret. - pub fn setup(self, delta: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { - state: state::PreExtension { delta, counter: 0 }, + state: state::Extension { delta, counter: 0 }, } } } -impl Sender { +impl Sender { /// Performs the prepare procedure in MPCOT extension. /// Outputs the information for SPCOT. /// @@ -42,7 +42,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -78,7 +78,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, n, @@ -91,7 +91,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// # Arguments. @@ -100,7 +100,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st .iter() .zip(self.state.queries_depth.iter()) @@ -117,7 +117,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, }, @@ -135,8 +135,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -153,20 +153,20 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state after the setup phase. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -179,7 +179,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 4d08c69b..e5939c60 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -4,7 +4,10 @@ use mpz_core::{ Block, }; -use crate::ferret::{error::ReceiverError, LpnType}; +use crate::{ + ferret::{error::ReceiverError, LpnType}, + TransferId, +}; use super::msgs::LpnMatrixSeed; @@ -59,6 +62,7 @@ impl Receiver { u: u.to_vec(), w: w.to_vec(), e: Vec::default(), + id: TransferId::default(), }, }, LpnMatrixSeed { seed }, @@ -69,10 +73,6 @@ impl Receiver { impl Receiver { /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN parameters. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { match self.state.lpn_type { LpnType::Uniform => { @@ -105,6 +105,8 @@ impl Receiver { return Err(ReceiverError("the length of r should be n".to_string())); } + self.state.id.next(); + // Compute z = A * w + r. let mut z = r.to_vec(); self.state.lpn_encoder.compute(&mut z, &self.state.w); @@ -133,6 +135,11 @@ impl Receiver { Ok((x_, z_)) } + + /// Returns id + pub fn id(&self) -> TransferId { + self.state.id + } } /// The receiver's state. @@ -176,6 +183,9 @@ pub mod state { /// Receiver's lpn error vector. pub(super) e: Vec, + + /// TransferID + pub(super) id: TransferId, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 9e8db180..2af3e4ae 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -4,7 +4,12 @@ use mpz_core::{ Block, }; -use crate::ferret::{error::SenderError, LpnType}; +use crate::{ + ferret::{error::SenderError, LpnType}, + TransferId, +}; + +use super::msgs::LpnMatrixSeed; /// Ferret sender. #[derive(Debug, Default)] @@ -36,7 +41,7 @@ impl Sender { delta: Block, lpn_parameters: LpnParameters, lpn_type: LpnType, - seed: Block, + seed: LpnMatrixSeed, v: &[Block], ) -> Result, SenderError> { if v.len() != lpn_parameters.k { @@ -44,6 +49,7 @@ impl Sender { "the length of v should be equal to k".to_string(), )); } + let LpnMatrixSeed { seed } = seed; let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); Ok(Sender { @@ -54,6 +60,7 @@ impl Sender { lpn_type, lpn_encoder, v: v.to_vec(), + id: TransferId::default(), }, }) } @@ -63,6 +70,7 @@ impl Sender { /// Outputs the information for MPCOT. /// /// See step 3 and 4. + #[inline] pub fn get_mpcot_query(&self) -> (u32, u32) { ( self.state.lpn_parameters.t as u32, @@ -83,6 +91,8 @@ impl Sender { return Err(SenderError("the length of s should be n".to_string())); } + self.state.id.next(); + // Compute y = A * v + s let mut y = s.to_vec(); self.state.lpn_encoder.compute(&mut y, &self.state.v); @@ -97,10 +107,17 @@ impl Sender { Ok(y_) } + + /// Returns id + pub fn id(&self) -> TransferId { + self.state.id + } } /// The sender's state. pub mod state { + use crate::TransferId; + use super::*; mod sealed { @@ -141,6 +158,9 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, + + /// TransferID. + pub(crate) id: TransferId, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/spcot/mod.rs b/crates/mpz-ot-core/src/ferret/spcot/mod.rs index 802efb66..63ebea15 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/mod.rs @@ -7,8 +7,6 @@ pub mod sender; #[cfg(test)] mod tests { - use mpz_core::prg::Prg; - use super::{receiver::Receiver as SpcotReceiver, sender::Sender as SpcotSender}; use crate::{ferret::CSP, ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; @@ -18,49 +16,82 @@ mod tests { let sender = SpcotSender::new(); let receiver = SpcotReceiver::new(); - let mut prg = Prg::new(); - let sender_seed = prg.random_block(); let delta = ideal_cot.delta(); - let mut sender = sender.setup(delta, sender_seed); + let mut sender = sender.setup(delta); let mut receiver = receiver.setup(); - let h1 = 8; - let alpha1 = 3; + let hs = [8, 4, 10]; + let alphas = [3, 2, 4]; - // Extend once - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h1); + let h_sum = hs.iter().sum(); + // batch extension + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h1, alpha1, &rs).unwrap(); - let msg_from_sender = sender.extend(h1, &qs, maskbits).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); + + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); + + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); + + // Check + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); + + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = msg_for_receiver; + + let RCOTSenderOutput { msgs: y_star, .. } = msg_for_sender; + + let check_from_receiver = receiver.check_pre(&x_star).unwrap(); - receiver.extend(h1, alpha1, &ts, msg_from_sender).unwrap(); + let (mut output_sender, check) = sender.check(&y_star, check_from_receiver).unwrap(); - // Extend twice - let h2 = 4; - let alpha2 = 2; + let output_receiver = receiver.check(&z_star, check).unwrap(); - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h2); + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + let h_sum = hs.iter().sum(); + + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h2, alpha2, &rs).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); - let msg_from_sender = sender.extend(h2, &qs, maskbits).unwrap(); + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); - receiver.extend(h2, alpha2, &ts, msg_from_sender).unwrap(); + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); // Check let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs index 5e860f31..baf10ae2 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs @@ -6,6 +6,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -43,71 +47,101 @@ impl Receiver { } impl Receiver { - /// Performs the mask bit step in extension. + /// Performs the mask bit step in batch in extension. /// /// See step 4 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `rs` - The message from COT ideal functionality for the receiver. Only the random bits are used. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `rss` - The message from COT ideal functionality for the receiver for all the tress. Only the random bits are used. pub fn extend_mask_bits( &mut self, - h: usize, - alpha: u32, - rs: &[bool], - ) -> Result { + hs: &[usize], + alphas: &[u32], + rss: &[bool], + ) -> Result, ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( "extension is not allowed".to_string(), )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - if rs.len() != h { + let h_sum = hs.iter().sum(); + + if rss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of r should be h".to_string(), + "the length of r should be the sum of h".to_string(), )); } - // Step 4 in Figure 6 + let mut rs_s = vec![Vec::::new(); hs.len()]; + let mut rss_vec = rss.to_vec(); + for (index, h) in hs.iter().enumerate() { + rs_s[index] = rss_vec.drain(0..*h).collect(); + } - let bs: Vec = alpha - .iter_msb0() - .skip(32 - h) - // Computes alpha_i XOR r_i XOR 1. - .zip(rs.iter()) - .map(|(alpha, &r)| alpha == r) - .collect(); + // Step 4 in Figure 6 + let mut bss = vec![Vec::::new(); hs.len()]; + + let iter = bss + .iter_mut() + .zip(alphas.iter()) + .zip(hs.iter()) + .zip(rs_s.iter()) + .map(|(((bs, alpha), h), rs)| (bs, alpha, h, rs)); + + for (bs, alpha, h, rs) in iter { + *bs = alpha + .iter_msb0() + .skip(32 - h) + // Computes alpha_i XOR r_i XOR 1. + .zip(rs.iter()) + .map(|(alpha, &r)| alpha == r) + .collect(); + } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); + + let res: Vec = bss.into_iter().map(|bs| MaskBits { bs }).collect(); - Ok(MaskBits { bs }) + Ok(res) } - /// Performs the GGM reconstruction step in extension. This function can be called multiple times before checking. + /// Performs the GGM reconstruction step in batch in extension. This function can be called multiple times before checking. /// /// See step 5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `ts` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `extendfs` - The message sent by the sender. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `tss` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. + /// * `extendfss` - The vector of messages sent by the sender. pub fn extend( &mut self, - h: usize, - alpha: u32, - ts: &[Block], - extendfs: ExtendFromSender, + hs: &[usize], + alphas: &[u32], + tss: &[Block], + extendfss: &[ExtendFromSender], ) -> Result<(), ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( @@ -115,61 +149,122 @@ impl Receiver { )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - let ExtendFromSender { ms, sum } = extendfs; - if ts.len() != h { + let h_sum = hs.iter().sum(); + + if tss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of t should be h".to_string(), + "the length of tss should be the sum of h".to_string(), )); } - if ms.len() != h { + let mut ts_s = vec![Vec::::new(); hs.len()]; + let mut tss_vec = tss.to_vec(); + for (index, h) in hs.iter().enumerate() { + ts_s[index] = tss_vec.drain(0..*h).collect(); + } + + if extendfss.len() != hs.len() { return Err(ReceiverError::InvalidLength( - "the length of M should be h".to_string(), + "the length of extendfss should be the length of hs".to_string(), )); } - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); - - // Step 5 in Figure 6. - let k: Vec = ms - .into_iter() - .zip(ts) - .zip(alpha_bar_vec.iter()) - .enumerate() - .map(|(i, (([m0, m1], &t), &b))| { - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - if !b { - // H(t, i|ell) ^ M0 - FIXED_KEY_AES.tccr(tweak, t) ^ m0 - } else { - // H(t, i|ell) ^ M1 - FIXED_KEY_AES.tccr(tweak, t) ^ m1 - } - }) - .collect(); + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; - // Reconstructs GGM tree except `ws[alpha]`. - let ggm_tree = GgmTree::new(h); - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec); + for (index, extendfs) in extendfss.iter().enumerate() { + ms_s[index].clone_from(&extendfs.ms); + sum_s[index] = extendfs.sum; + } + + if ms_s.iter().zip(hs.iter()).any(|(ms, h)| ms.len() != *h) { + return Err(ReceiverError::InvalidLength( + "the length of ms should be h".to_string(), + )); + } + // Updates hasher + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); + + let mut trees = vec![Vec::::new(); hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = alphas + .par_iter() + .zip(ms_s.par_iter()) + .zip(sum_s.par_iter()) + .zip(hs.par_iter()) + .zip(ts_s.par_iter()) + .zip(trees.par_iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + }else{ + let iter = alphas + .iter() + .zip(ms_s.iter()) + .zip(sum_s.iter()) + .zip(hs.iter()) + .zip(ts_s.iter()) + .zip(trees.iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + } + } - // Sets `tree[alpha]`, which is `ws[alpha]`. - tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x); + iter.for_each(|(alpha, ms, sum, h, ts, tree)| { + let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); + + // Step 5 in Figure 6. + let k: Vec = ms + .iter() + .zip(ts) + .zip(alpha_bar_vec.iter()) + .enumerate() + .map(|(i, (([m0, m1], &t), &b))| { + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + if !b { + // H(t, i|ell) ^ M0 + FIXED_KEY_AES.tccr(tweak, t) ^ *m0 + } else { + // H(t, i|ell) ^ M1 + FIXED_KEY_AES.tccr(tweak, t) ^ *m1 + } + }) + .collect(); + + // Reconstructs GGM tree except `ws[alpha]`. + let ggm_tree = GgmTree::new(*h); + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.reconstruct(tree, &k, &alpha_bar_vec); + + // Sets `tree[alpha]`, which is `ws[alpha]`. + tree[(*alpha) as usize] = tree.iter().fold(*sum, |acc, &x| acc ^ x); + }); + + for tree in trees { + self.state.unchecked_ws.extend_from_slice(&tree); + } - self.state.unchecked_ws.extend_from_slice(&tree); - self.state.alphas_and_length.push((alpha, 1 << h)); + for (alpha, h) in alphas.iter().zip(hs.iter()) { + self.state.alphas_and_length.push((*alpha, 1 << h)); + } - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); Ok(()) } @@ -248,7 +343,6 @@ impl Receiver { } self.state.cot_counter += self.state.unchecked_ws.len(); - self.state.extended = true; let mut res = Vec::new(); for (alpha, n) in &self.state.alphas_and_length { @@ -256,8 +350,19 @@ impl Receiver { res.push((tmp, *alpha)); } + self.state.hasher = blake3::Hasher::new(); + self.state.alphas_and_length.clear(); + self.state.chis.clear(); + self.state.unchecked_ws.clear(); + Ok(res) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The receiver's state. diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs index fef1327e..a62ad3bb 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/sender.rs @@ -5,6 +5,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -29,8 +33,7 @@ impl Sender { /// # Arguments /// /// * `delta` - The sender's global secret. - /// * `seed` - The random seed to generate PRG. - pub fn setup(self, delta: Block, seed: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { state: state::Extension { delta, @@ -39,7 +42,6 @@ impl Sender { cot_counter: 0, exec_counter: 0, extended: false, - prg: Prg::from_seed(seed), hasher: blake3::Hasher::new(), }, } @@ -47,85 +49,137 @@ impl Sender { } impl Sender { - /// Performs the SPCOT extension. + /// Performs batch SPCOT extension. /// /// See Step 1-5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `qs`- The blocks received by calling the COT functionality. - /// * `mask`- The mask bits sent by the receiver. + /// * `hs` - The depths of the GGM trees. + /// * `qss`- The blocks received by calling the COT functionality for hs trees. + /// * `masks`- The vector of mask bits sent by the receiver. pub fn extend( &mut self, - h: usize, - qs: &[Block], - mask: MaskBits, - ) -> Result { + hs: &[usize], + qss: &[Block], + masks: &[MaskBits], + ) -> Result, SenderError> { if self.state.extended { return Err(SenderError::InvalidState( "extension is not allowed".to_string(), )); } - if qs.len() != h { + let h_sum = hs.iter().sum(); + + if qss.len() != h_sum { return Err(SenderError::InvalidLength( - "the length of q should be h".to_string(), + "the length of qss should be the sum of h".to_string(), )); } - let MaskBits { bs } = mask; + let mut qs_s = vec![Vec::::new(); hs.len()]; + let mut qss_vec = qss.to_vec(); + for (index, h) in hs.iter().enumerate() { + qs_s[index] = qss_vec.drain(0..*h).collect(); + } - if bs.len() != h { + if masks.len() != hs.len() { + return Err(SenderError::InvalidLength( + "the length of masks should be the length of hs".to_string(), + )); + } + + let bss: Vec> = masks.iter().map(|m| m.clone().bs).collect(); + + if bss.iter().zip(hs.iter()).any(|(b, h)| b.len() != *h) { return Err(SenderError::InvalidLength( "the length of b should be h".to_string(), )); } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); // Step 3-4, Figure 6. // Generates a GGM tree with depth h and seed s. - let s = self.state.prg.random_block(); - let ggm_tree = GgmTree::new(h); - let mut k0 = vec![Block::ZERO; h]; - let mut k1 = vec![Block::ZERO; h]; - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.gen(s, &mut tree, &mut k0, &mut k1); + let mut trees = vec![Vec::::new(); hs.len()]; + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = trees + .par_iter_mut().zip(hs.par_iter()) + .zip(qs_s.par_iter()) + .zip(bss.par_iter()) + .zip(ms_s.par_iter_mut()) + .zip(sum_s.par_iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + }else{ + let iter = trees + .iter_mut() + .zip(hs.iter()) + .zip(qs_s.iter()) + .zip(bss.iter()) + .zip(ms_s.iter_mut()) + .zip(sum_s.iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + } + } + + iter.for_each(|(tree, h, qs, bs, ms, sum)| { + let s = Prg::new().random_block(); + let ggm_tree = GgmTree::new(*h); + let mut k0 = vec![Block::ZERO; *h]; + let mut k1 = vec![Block::ZERO; *h]; + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.gen(s, tree, &mut k0, &mut k1); + + // Computes the sum of the leaves and delta. + *sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); + + // Computes M0 and M1. + for (((i, &q), b), (k0, k1)) in + qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) + { + let mut m = if *b { + [q ^ self.state.delta, q] + } else { + [q, q ^ self.state.delta] + }; + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); + m[0] ^= k0; + m[1] ^= k1; + ms.push(m); + } + }); // Stores the tree, i.e., the possible output of sender. - self.state.unchecked_vs.extend_from_slice(&tree); + for tree in trees { + self.state.unchecked_vs.extend_from_slice(&tree); + } // Stores the length of this extension. - self.state.vs_length.push(1 << h); - - // Computes the sum of the leaves and delta. - let sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); - - // Computes M0 and M1. - let mut ms: Vec<[Block; 2]> = Vec::with_capacity(qs.len()); - for (((i, &q), b), (k0, k1)) in qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) { - let mut m = if b { - [q ^ self.state.delta, q] - } else { - [q, q ^ self.state.delta] - }; - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); - m[0] ^= k0; - m[1] ^= k1; - ms.push(m); + for h in hs { + self.state.vs_length.push(1 << h); } // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); + + let res: Vec = ms_s + .into_iter() + .zip(sum_s.iter()) + .map(|(ms, &sum)| ExtendFromSender { ms, sum }) + .collect(); - Ok(ExtendFromSender { ms, sum }) + Ok(res) } /// Performs the consistency check for the resulting COTs. @@ -193,10 +247,18 @@ impl Sender { res.push(tmp); } - self.state.extended = true; + self.state.hasher = blake3::Hasher::new(); + self.state.unchecked_vs.clear(); + self.state.vs_length.clear(); Ok((res, CheckFromSender { hashed_v })) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The sender's state. @@ -239,8 +301,6 @@ pub mod state { /// This is to prevent the receiver from extending twice pub(super) extended: bool, - /// A PRG to generate random strings. - pub(super) prg: Prg, /// A hasher to generate chi seed. pub(super) hasher: blake3::Hasher, } diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs new file mode 100644 index 00000000..6952f0ec --- /dev/null +++ b/crates/mpz-ot/src/ferret/error.rs @@ -0,0 +1,67 @@ +use crate::OTError; + +/// A Ferret sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::error::SenderError), + #[error(transparent)] + MPCOTSenderError(#[from] crate::ferret::mpcot::SenderError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), + #[error("{0}")] + MPCOTSenderTypeError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A Ferret receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::error::ReceiverError), + #[error(transparent)] + MPCOTReceiverError(#[from] crate::ferret::mpcot::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), + #[error("{0}")] + MPCOTReceiverTypeError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs new file mode 100644 index 00000000..2b2047b9 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -0,0 +1,175 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. +mod error; +mod mpcot; +mod receiver; +mod sender; +mod spcot; + +pub use error::{ReceiverError, SenderError}; +pub use receiver::Receiver; +pub use sender::Sender; + +use mpz_core::lpn::LpnParameters; +use mpz_ot_core::ferret::LpnType; + +/// Configuration of Ferret. +#[derive(Debug)] +pub struct FerretConfig { + rcot: RandomCOT, + setup_rcot: SetupRandomCOT, + lpn_parameters: LpnParameters, + lpn_type: LpnType, +} + +impl FerretConfig { + /// Create a new instance. + /// + /// # Arguments. + /// + /// * `rcot` - The rcot for MPCOT. + /// * `setup_rcot` - The rcot for setup. + /// * `lpn_parameters` - The parameters of LPN. + /// * `lpn_type` - The type of LPN. + pub fn new( + rcot: RandomCOT, + setup_rcot: SetupRandomCOT, + lpn_parameters: LpnParameters, + lpn_type: LpnType, + ) -> Self { + Self { + rcot, + setup_rcot, + lpn_parameters, + lpn_type, + } + } + + /// Get rcot + pub fn rcot(&self) -> RandomCOT { + self.rcot.clone() + } + + /// Get the setup rcot + pub fn setup_rcot(&mut self) -> &mut SetupRandomCOT { + &mut self.setup_rcot + } + + /// Get the lpn type + pub fn lpn_type(&self) -> LpnType { + self.lpn_type + } + + /// Get the lpn parameters + pub fn lpn_parameters(&self) -> LpnParameters { + self.lpn_parameters + } +} + +#[cfg(test)] +mod tests { + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::{lpn::LpnParameters, Block}; + use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, RandomCOTReceiver, RandomCOTSender, + }; + + use super::*; + + // l = n - k = 8380 + const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { + n: 9600, + k: 1220, + t: 600, + }; + + fn setup() -> ( + Sender, + Receiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let sender_config = FerretConfig::new( + rcot_sender.clone(), + rcot_sender.clone(), + LPN_PARAMETERS_TEST, + LpnType::Regular, + ); + + let receiver_config = FerretConfig::new( + rcot_receiver.clone(), + rcot_receiver, + LPN_PARAMETERS_TEST, + LpnType::Regular, + ); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(sender_config); + + let receiver = Receiver::new(receiver_config); + + (sender, receiver, delta) + } + + #[tokio::test] + async fn test_ferret() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, delta) = setup(); + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta) + .map_err(OTError::from), + receiver.setup(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + // extend once. + let count = 8000; + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &b, &u, &w); + + // extend twice + let count = 9000; + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(delta, &b, &u, &w); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/error.rs b/crates/mpz-ot/src/ferret/mpcot/error.rs new file mode 100644 index 00000000..e300bf0d --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/error.rs @@ -0,0 +1,59 @@ +use crate::OTError; + +/// A MPCOT sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs, clippy::enum_variant_names)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::mpcot::error::SenderError), + #[error(transparent)] + SPCOTSenderError(#[from] crate::ferret::spcot::SenderError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::mpcot::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A MPCOT receiver error +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs, clippy::enum_variant_names)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::mpcot::error::ReceiverError), + #[error(transparent)] + SpcotReceiverError(#[from] crate::ferret::spcot::ReceiverError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::mpcot::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/mod.rs b/crates/mpz-ot/src/ferret/mpcot/mod.rs new file mode 100644 index 00000000..598b5734 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/mod.rs @@ -0,0 +1,165 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod error; +mod receiver; +mod sender; + +pub(crate) use error::{ReceiverError, SenderError}; +pub(crate) use receiver::Receiver; +pub(crate) use sender::Sender; + +#[cfg(test)] +mod tests { + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + use mpz_ot_core::ferret::LpnType; + + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, + }; + + use receiver::Receiver; + use sender::Sender; + + use super::*; + + fn setup( + lpn_type: LpnType, + ) -> ( + Sender, + Receiver, + IdealCOTSender, + IdealCOTReceiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(lpn_type); + + let receiver = Receiver::new(lpn_type); + + (sender, receiver, rcot_sender, rcot_receiver, delta) + } + + #[tokio::test] + async fn test_mpcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Uniform); + + let alphas = [0, 1, 3, 4, 2]; + let t = alphas.len(); + let n = 10; + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta, rcot_sender) + .map_err(OTError::from), + receiver + .setup(&mut ctx_receiver, rcot_receiver) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + // extend twice + let alphas = [5, 1, 7, 2]; + let t = alphas.len(); + let n = 16; + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(LpnType::Regular); + + // extend once. + let alphas = [0, 3, 4, 7, 9]; + let t = alphas.len(); + let n = 10; + + tokio::try_join!( + sender + .setup_with_delta(&mut ctx_sender, delta, rcot_sender) + .map_err(OTError::from), + receiver + .setup(&mut ctx_receiver, rcot_receiver) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + // extend twice. + let alphas = [0, 3, 7, 9, 14, 15]; + let t = alphas.len(); + let n = 16; + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender + .extend(&mut ctx_sender, t as u32, n) + .map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, n) + .map_err(OTError::from) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/receiver.rs b/crates/mpz-ot/src/ferret/mpcot/receiver.rs new file mode 100644 index 00000000..e2553efd --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/receiver.rs @@ -0,0 +1,192 @@ +use crate::{ + ferret::{mpcot::error::ReceiverError, spcot::Receiver as SpcotReceiver}, + RandomCOTReceiver, +}; +use enum_try_as_inner::EnumTryAsInner; + +use mpz_common::Context; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + receiver::{state as uniform_state, Receiver as UniformReceiverCore}, + receiver_regular::{state as regular_state, Receiver as RegularReceiverCore}, + }, + LpnType, +}; +use serio::SinkExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + UniformInitialized(UniformReceiverCore), + UniformExtension(UniformReceiverCore), + RegularInitialized(RegularReceiverCore), + RegularExtension(RegularReceiverCore), + Complete, + Error, +} + +/// MPCOT receiver. +#[derive(Debug)] +pub(crate) struct Receiver { + state: State, + spcot: SpcotReceiver, + lpn_type: LpnType, +} + +impl Receiver { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `lpn_type` - The type of LPN. + pub(crate) fn new(lpn_type: LpnType) -> Self { + match lpn_type { + LpnType::Uniform => Self { + state: State::UniformInitialized(UniformReceiverCore::new()), + spcot: crate::ferret::spcot::Receiver::new(), + lpn_type, + }, + LpnType::Regular => Self { + state: State::RegularInitialized(RegularReceiverCore::new()), + spcot: crate::ferret::spcot::Receiver::new(), + lpn_type, + }, + } + } + + /// Performs setup for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `rcot` - The random COT used by Receiver. + pub(crate) async fn setup( + &mut self, + ctx: &mut Ctx, + rcot: RandomCOT, + ) -> Result<(), ReceiverError> { + match self.lpn_type { + LpnType::Uniform => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_initialized()?; + + let hash_seed = Prg::new().random_block(); + + let (ext_receiver, hash_seed) = ext_receiver.setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + self.state = State::UniformExtension(ext_receiver); + } + LpnType::Regular => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_initialized()?; + + let ext_receiver = ext_receiver.setup(); + + self.state = State::RegularExtension(ext_receiver); + } + } + + self.spcot.setup(rcot)?; + + Ok(()) + } + + /// Performs MPCOT extension. + /// + /// + /// # Arguments + /// + /// * `ctx` - The context, + /// * `alphas` - The queried indices. + /// * `n` - The total number of indices. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + alphas: &[u32], + n: u32, + ) -> Result, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let alphas_vec = alphas.to_vec(); + + match self.lpn_type { + LpnType::Uniform => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_extension()?; + + let (ext_receiver, h_and_pos) = + Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + self.spcot.extend(ctx, &pos, &hs).await?; + + let rt = self.spcot.check(ctx).await?; + + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (ext_receiver, output) = + Backend::spawn(move || ext_receiver.extend(&rt)).await?; + + self.state = State::UniformExtension(ext_receiver); + + Ok(output) + } + + LpnType::Regular => { + let ext_receiver = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_extension()?; + + let (ext_receiver, h_and_pos) = + Backend::spawn(move || ext_receiver.pre_extend(&alphas_vec, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + self.spcot.extend(ctx, &pos, &hs).await?; + + let rt = self.spcot.check(ctx).await?; + + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (ext_receiver, output) = + Backend::spawn(move || ext_receiver.extend(&rt)).await?; + + self.state = State::RegularExtension(ext_receiver); + + Ok(output) + } + } + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { + match self.lpn_type { + LpnType::Uniform => { + std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; + } + LpnType::Regular => { + std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; + } + } + + self.spcot.finalize()?; + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot/sender.rs b/crates/mpz-ot/src/ferret/mpcot/sender.rs new file mode 100644 index 00000000..a0256276 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot/sender.rs @@ -0,0 +1,166 @@ +use crate::{ + ferret::{mpcot::error::SenderError, spcot::Sender as SpcotSender}, + RandomCOTSender, +}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, + sender::{state as uniform_state, Sender as UniformSenderCore}, + sender_regular::{state as regular_state, Sender as RegularSenderCore}, + }, + LpnType, +}; +use serio::stream::IoStreamExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + UniformInitialized(UniformSenderCore), + UniformExtension(UniformSenderCore), + RegularInitialized(RegularSenderCore), + RegularExtension(RegularSenderCore), + Complete, + Error, +} + +/// MPCOT sender. +#[derive(Debug)] +pub(crate) struct Sender { + state: State, + spcot: SpcotSender, + lpn_type: LpnType, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Arguments. + /// + /// * `lpn_type` - The type of LPN. + pub(crate) fn new(lpn_type: LpnType) -> Self { + match lpn_type { + LpnType::Uniform => Self { + state: State::UniformInitialized(UniformSenderCore::new()), + spcot: crate::ferret::spcot::Sender::new(), + lpn_type, + }, + LpnType::Regular => Self { + state: State::RegularInitialized(RegularSenderCore::new()), + spcot: crate::ferret::spcot::Sender::new(), + lpn_type, + }, + } + } + + /// Performs setup with provided delta. + /// + /// # Arguments + /// + /// * `ctx` - The channel. + /// * `delta` - The delta value to use for OT extension. + /// * `rcot` - The random COT used by Sender. + pub(crate) async fn setup_with_delta( + &mut self, + ctx: &mut Ctx, + delta: Block, + rcot: RandomCOT, + ) -> Result<(), SenderError> { + match self.lpn_type { + LpnType::Uniform => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_initialized()?; + + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let ext_sender = ext_sender.setup(delta, hash_seed); + + self.state = State::UniformExtension(ext_sender); + } + + LpnType::Regular => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::RegularExtension(ext_sender); + } + } + + self.spcot.setup_with_delta(delta, rcot)?; + + Ok(()) + } + + /// Performs MPCOT extension. + /// + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `t` - The number of queried indices. + /// * `n` - The total number of indices. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + t: u32, + n: u32, + ) -> Result, SenderError> + where + RandomCOT: RandomCOTSender, + { + match self.lpn_type { + LpnType::Uniform => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_uniform_extension()?; + + let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; + + self.spcot.extend(ctx, &hs).await?; + + let st = self.spcot.check(ctx).await?; + + let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; + + self.state = State::UniformExtension(ext_sender); + Ok(output) + } + LpnType::Regular => { + let ext_sender = std::mem::replace(&mut self.state, State::Error) + .try_into_regular_extension()?; + + let (ext_sender, hs) = Backend::spawn(move || ext_sender.pre_extend(t, n)).await?; + + self.spcot.extend(ctx, &hs).await?; + + let st = self.spcot.check(ctx).await?; + + let (ext_sender, output) = Backend::spawn(move || ext_sender.extend(&st)).await?; + + self.state = State::RegularExtension(ext_sender); + Ok(output) + } + } + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { + match self.lpn_type { + LpnType::Uniform => { + std::mem::replace(&mut self.state, State::Error).try_into_uniform_extension()?; + } + LpnType::Regular => { + std::mem::replace(&mut self.state, State::Error).try_into_regular_extension()?; + } + } + + self.spcot.finalize()?; + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs new file mode 100644 index 00000000..520506e8 --- /dev/null +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -0,0 +1,192 @@ +use crate::{ + ferret::{mpcot::Receiver as MpcotReceiver, ReceiverError}, + RandomCOTReceiver, +}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::{ + ferret::receiver::{state, Receiver as ReceiverCore}, + RCOTReceiverOutput, +}; +use serio::SinkExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +use super::FerretConfig; +use crate::{async_trait, OTError}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(ReceiverCore), + Extension(ReceiverCore), + Complete, + Error, +} + +/// Ferret Receiver. +#[derive(Debug)] +pub struct Receiver { + state: State, + mpcot: MpcotReceiver, + config: FerretConfig, +} + +impl Receiver +where + RandomCOT: Send + Default + Clone, + SetupRandomCOT: Send, +{ + /// Creates a new Receiver. + /// + /// # Arguments. + /// + /// * `config` - Ferret configuration. + pub fn new(config: FerretConfig) -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + mpcot: MpcotReceiver::new(config.lpn_type()), + config, + } + } + + /// Setup for receiver. + /// + /// # Arguments. + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> + where + Ctx: Context, + SetupRandomCOT: RandomCOTReceiver, + { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let rcot = self.config.rcot(); + self.mpcot.setup(ctx, rcot).await?; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Get random blocks from ideal Random COT. + + let RCOTReceiverOutput { + choices: u, + msgs: w, + .. + } = self + .config + .setup_rcot() + .receive_random_correlated(ctx, params.k) + .await?; + + let seed = Prg::new().random_block(); + + let (ext_receiver, seed) = ext_receiver.setup(params, lpn_type, seed, &u, &w)?; + + ctx.io_mut().send(seed).await?; + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `ctx` - The channel context. + async fn extend(&mut self, ctx: &mut Ctx) -> Result<(Vec, Vec), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let (alphas, n) = ext_receiver.get_mpcot_query(); + + let r = self.mpcot.extend(ctx, &alphas, n as u32).await?; + + let (ext_receiver, choices, msgs) = Backend::spawn(move || { + ext_receiver + .extend(&r) + .map(|(choices, msgs)| (ext_receiver, choices, msgs)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok((choices, msgs)) + } + + /// Complete extension + pub fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + self.state = State::Complete; + self.mpcot.finalize()?; + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTReceiver + for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send + Clone + Default + 'static, + SetupRandomCOT: Send + 'static, +{ + async fn receive_random_correlated( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let (mut choices_buffer, mut msgs_buffer) = self.extend(ctx).await?; + + assert_eq!(choices_buffer.len(), msgs_buffer.len()); + + let l = choices_buffer.len(); + + let id = self + .state + .try_as_extension() + .map_err(ReceiverError::from)? + .id(); + + if count <= l { + let choices_res = choices_buffer.drain(..count).collect(); + + let msgs_res = msgs_buffer.drain(..count).collect(); + + return Ok(RCOTReceiverOutput { + id, + choices: choices_res, + msgs: msgs_res, + }); + } else { + let mut choices_res = choices_buffer; + let mut msgs_res = msgs_buffer; + + for _ in 0..count / l - 1 { + (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + + choices_res.extend_from_slice(&choices_buffer); + msgs_res.extend_from_slice(&msgs_buffer); + } + + (choices_buffer, msgs_buffer) = self.extend(ctx).await?; + + choices_res.extend_from_slice(&choices_buffer[0..count % l]); + msgs_res.extend_from_slice(&msgs_buffer[0..count % l]); + + return Ok(RCOTReceiverOutput { + id, + choices: choices_res, + msgs: msgs_res, + }); + } + } +} diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs new file mode 100644 index 00000000..709ff8e2 --- /dev/null +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -0,0 +1,160 @@ +use crate::{ferret::mpcot::Sender as MpcotSender, RandomCOTSender}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::sender::{state, Sender as SenderCore}, + RCOTSenderOutput, +}; +use serio::stream::IoStreamExt; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +use super::{FerretConfig, SenderError}; +use crate::{async_trait, OTError}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(SenderCore), + Complete, + Error, +} + +/// Ferret Sender. +#[derive(Debug)] +pub struct Sender { + state: State, + mpcot: MpcotSender, + config: FerretConfig, +} + +impl Sender +where + RandomCOT: Send + Default + Clone, + SetupRandomCOT: Send, +{ + /// Creates a new Sender. + pub fn new(config: FerretConfig) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + mpcot: MpcotSender::new(config.lpn_type()), + config, + } + } + + /// Setup with provided delta. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + /// * `delta` - The provided delta used for sender. + pub async fn setup_with_delta( + &mut self, + ctx: &mut Ctx, + delta: Block, + ) -> Result<(), SenderError> + where + Ctx: Context, + SetupRandomCOT: RandomCOTSender, + { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let rcot = self.config.rcot(); + + self.mpcot.setup_with_delta(ctx, delta, rcot).await?; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Get random blocks from ideal Random COT. + let RCOTSenderOutput { msgs: v, .. } = self + .config + .setup_rcot() + .send_random_correlated(ctx, params.k) + .await?; + + // Get seed for LPN matrix from receiver. + let seed = ctx.io_mut().expect_next().await?; + + // Ferret core setup. + let ext_sender = ext_sender.setup(delta, params, lpn_type, seed, &v)?; + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs extension. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + async fn extend(&mut self, ctx: &mut Ctx) -> Result, SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let (t, n) = ext_sender.get_mpcot_query(); + + let s = self.mpcot.extend(ctx, t, n).await?; + + let (ext_sender, output) = + Backend::spawn(move || ext_sender.extend(&s).map(|output| (ext_sender, output))) + .await?; + self.state = State::Extension(ext_sender); + + Ok(output) + } + + /// Complete extension + pub fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + self.state = State::Complete; + self.mpcot.finalize()?; + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTSender + for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send + Default + Clone + 'static, + SetupRandomCOT: Send + 'static, +{ + async fn send_random_correlated( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let mut buffer = self.extend(ctx).await?; + let l = buffer.len(); + + let id = self + .state + .try_as_extension() + .map_err(SenderError::from)? + .id(); + + if count <= l { + let res = buffer.drain(..count).collect(); + return Ok(RCOTSenderOutput { id, msgs: res }); + } else { + let mut res = buffer; + for _ in 0..count / l - 1 { + buffer = self.extend(ctx).await?; + res.extend_from_slice(&buffer); + } + + buffer = self.extend(ctx).await?; + res.extend_from_slice(&buffer[0..count % l]); + + return Ok(RCOTSenderOutput { id, msgs: res }); + } + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/error.rs b/crates/mpz-ot/src/ferret/spcot/error.rs new file mode 100644 index 00000000..5f23f466 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/error.rs @@ -0,0 +1,59 @@ +use crate::OTError; + +/// A SPCOT sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs, clippy::enum_variant_names)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::spcot::error::SenderError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => OTError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::ferret::spcot::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +/// A SPCOT receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs, clippy::enum_variant_names)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_ot_core::ferret::spcot::error::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] OTError), + #[error("{0}")] + StateError(String), +} + +impl From for OTError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => OTError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::ferret::spcot::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/mod.rs b/crates/mpz-ot/src/ferret/spcot/mod.rs new file mode 100644 index 00000000..6e53fd28 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/mod.rs @@ -0,0 +1,103 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +mod error; +mod receiver; +mod sender; + +pub(crate) use error::{ReceiverError, SenderError}; +pub(crate) use receiver::Receiver; +pub(crate) use sender::Sender; + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}, + OTError, + }; + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + + fn setup() -> ( + Sender, + Receiver, + IdealCOTSender, + IdealCOTReceiver, + Block, + ) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(); + let receiver = Receiver::new(); + + (sender, receiver, rcot_sender, rcot_receiver, delta) + } + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, rcot_sender, rcot_receiver, delta) = setup(); + + // shold set the same delta as in RCOT. + sender.setup_with_delta(delta, rcot_sender).unwrap(); + receiver.setup(rcot_receiver).unwrap(); + + let hs = [8, 4]; + let alphas = [4, 2]; + + tokio::try_join!( + sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, &hs) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender.check(&mut ctx_sender).map_err(OTError::from), + receiver.check(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice. + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + tokio::try_join!( + sender.extend(&mut ctx_sender, &hs).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, &alphas, &hs) + .map_err(OTError::from) + ) + .unwrap(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + sender.check(&mut ctx_sender).map_err(OTError::from), + receiver.check(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + sender.finalize().unwrap(); + receiver.finalize().unwrap(); + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/receiver.rs b/crates/mpz-ot/src/ferret/spcot/receiver.rs new file mode 100644 index 00000000..3c48bfad --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/receiver.rs @@ -0,0 +1,164 @@ +use crate::{ferret::spcot::error::ReceiverError, RandomCOTReceiver}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::ExtendFromSender, + receiver::{state, Receiver as ReceiverCore}, + }, + CSP, + }, + RCOTReceiverOutput, +}; +use serio::{stream::IoStreamExt, SinkExt}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(ReceiverCore), + Extension(Box>), + Complete, + Error, +} + +/// SPCOT Receiver. +#[derive(Debug)] +pub(crate) struct Receiver { + state: State, + rcot: RandomCOT, +} + +impl Receiver { + /// Creates a new Receiver. + pub(crate) fn new() -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + rcot: Default::default(), + } + } + + /// Performs setup for receiver. + /// + /// # Arguments. + /// + /// * `rcot` - The random COT used by the receiver. + pub(crate) fn setup(&mut self, rcot: RandomCOT) -> Result<(), ReceiverError> { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_receiver = ext_receiver.setup(); + self.state = State::Extension(Box::new(ext_receiver)); + self.rcot = rcot; + Ok(()) + } + + /// Performs spcot extension for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `alphas`` - The vector of chosen positions. + /// * `h` - The depth of GGM tree. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + alphas: &[u32], + hs: &[usize], + ) -> Result<(), ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = self.rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut ext_receiver, masks) = Backend::spawn(move || { + ext_receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (ext_receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let ext_receiver = Backend::spawn(move || { + ext_receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| ext_receiver) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs batch check for SPCOT extension. + /// + /// # Arguments + /// + /// * `ctx` - The context. + pub(crate) async fn check( + &mut self, + ctx: &mut Ctx, + ) -> Result, u32)>, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = self.rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut ext_receiver, checkfr) = Backend::spawn(move || { + ext_receiver + .check_pre(&x_star) + .map(|checkfr| (ext_receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let (ext_receiver, output) = Backend::spawn(move || { + ext_receiver + .check(&z_star, check) + .map(|output| (ext_receiver, output)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(output) + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot/sender.rs b/crates/mpz-ot/src/ferret/spcot/sender.rs new file mode 100644 index 00000000..9178b787 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot/sender.rs @@ -0,0 +1,144 @@ +use crate::{ferret::spcot::error::SenderError, RandomCOTSender}; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::MaskBits, + sender::{state, Sender as SenderCore}, + }, + CSP, + }, + RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt, SinkExt}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(Box>), + Complete, + Error, +} + +/// SPCOT sender. +#[derive(Debug)] +pub(crate) struct Sender { + state: State, + rcot: RandomCOT, +} + +impl Sender { + /// Creates a new Sender. + pub(crate) fn new() -> Self { + Self { + state: State::Initialized(SenderCore::new()), + rcot: Default::default(), + } + } + + /// Performs setup with the provided delta. + /// + /// # Arguments + /// + /// * `delta` - The delta value to use for OT extension. + /// * `rcot` - The random COT used by the sender. + pub(crate) fn setup_with_delta( + &mut self, + delta: Block, + rcot: RandomCOT, + ) -> Result<(), SenderError> { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::Extension(Box::new(ext_sender)); + self.rcot = rcot; + Ok(()) + } + + /// Performs spcot extension for sender. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `hs` - The depths of GGM trees. + pub(crate) async fn extend( + &mut self, + ctx: &mut Ctx, + hs: &[usize], + ) -> Result<(), SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = self.rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (ext_sender, extend_msg) = Backend::spawn(move || { + ext_sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (ext_sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs batch check for SPCOT extension. + /// + /// # Arguments + /// + /// * `ctx` - The context. + pub(crate) async fn check( + &mut self, + ctx: &mut Ctx, + ) -> Result>, SenderError> + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = + self.rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (ext_sender, output, check_msg) = Backend::spawn(move || { + ext_sender + .check(&y_star, checkfr) + .map(|(output, check_msg)| (ext_sender, output, check_msg)) + }) + .await?; + + ctx.io_mut().send(check_msg).await?; + + self.state = State::Extension(ext_sender); + + Ok(output) + } + + /// Complete extension. + pub(crate) fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index b0084957..18233dfe 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -46,9 +46,16 @@ pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { } /// Ideal COT sender. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTSender(Alice); +impl IdealCOTSender { + /// Returns Alice. + pub fn alice(&mut self) -> &mut Alice { + &mut self.0 + } +} + #[async_trait] impl OTSetup for IdealCOTSender where @@ -98,7 +105,7 @@ impl RandomCOTSender for IdealCOTSender { } /// Ideal COT receiver. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTReceiver(Bob); #[async_trait] diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index b9871eab..d53e322b 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -10,7 +10,7 @@ )] pub mod chou_orlandi; -#[cfg(any(test, feature = "ideal"))] +pub mod ferret; pub mod ideal; pub mod kos; diff --git a/crates/mpz-zk-core/Cargo.toml b/crates/mpz-zk-core/Cargo.toml new file mode 100644 index 00000000..e390db57 --- /dev/null +++ b/crates/mpz-zk-core/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "mpz-zk-core" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lints] +workspace = true + +[lib] +name = "mpz_zk_core" + +[features] +default = ["rayon", "test-utils"] +rayon = ["dep:rayon", "itybity/rayon"] +test-utils = [] + +[dependencies] +mpz-core.workspace = true +mpz-ot-core.workspace = true +clmul.workspace = true +matrix-transpose.workspace = true + +tlsn-utils.workspace = true + +rayon = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } +thiserror.workspace = true +derive_builder.workspace = true +itybity.workspace = true +opaque-debug.workspace = true +cfg-if.workspace = true +bytemuck = { workspace = true, features = ["derive"] } +enum-try-as-inner.workspace = true diff --git a/crates/mpz-zk-core/src/lib.rs b/crates/mpz-zk-core/src/lib.rs new file mode 100644 index 00000000..92ca8ef1 --- /dev/null +++ b/crates/mpz-zk-core/src/lib.rs @@ -0,0 +1,23 @@ +//! Low-level crate containing core functionalities for zero-knowledge protocols. +//! +//! This crate is not intended to be used directly. Instead, use the higher-level APIs provided by +//! the `mpz-zk` crate. +//! +//! # ⚠️ Warning ⚠️ +//! +//! Some implementations make assumptions about invariants which may not be checked if using these +//! low-level APIs naively. Failing to uphold these invariants may result in security vulnerabilities. +//! +//! USE AT YOUR OWN RISK. + +#![deny( + unsafe_code, + missing_docs, + unused_imports, + unused_must_use, + unreachable_pub, + clippy::all +)] + +pub mod test; +pub mod vope; diff --git a/crates/mpz-zk-core/src/test.rs b/crates/mpz-zk-core/src/test.rs new file mode 100644 index 00000000..b22369d3 --- /dev/null +++ b/crates/mpz-zk-core/src/test.rs @@ -0,0 +1,11 @@ +//! test functions. + +use mpz_core::Block; + +/// Check polynomial relation. +pub fn poly_check(a: &[Block], b: Block, delta: Block) -> bool { + b == a + .iter() + .rev() + .fold(Block::ZERO, |acc, &x| x ^ (delta.gfmul(acc))) +} diff --git a/crates/mpz-zk-core/src/vope/error.rs b/crates/mpz-zk-core/src/vope/error.rs new file mode 100644 index 00000000..101a11b3 --- /dev/null +++ b/crates/mpz-zk-core/src/vope/error.rs @@ -0,0 +1,21 @@ +//! Errors that can occur when using VOPE. + +/// Errors that can occur when using VOPE sender (verifier). +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error("invalid input: expected {0}")] + InvalidInput(String), + #[error("invalid length: expected {0}")] + InvalidLength(String), +} + +/// Errors that can occur when using VOPE receiver (prover). +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error("invalid input: expected {0}")] + InvalidInput(String), + #[error("invalid length: expected {0}")] + InvalidLength(String), +} diff --git a/crates/mpz-zk-core/src/vope/mod.rs b/crates/mpz-zk-core/src/vope/mod.rs new file mode 100644 index 00000000..3da511be --- /dev/null +++ b/crates/mpz-zk-core/src/vope/mod.rs @@ -0,0 +1,65 @@ +//! This is the implementation of vector oblivious polynomial evaluation (VOPE) based on Figure 4 in https://eprint.iacr.org/2021/076.pdf + +pub mod error; +pub mod receiver; +pub mod sender; + +/// Security parameter +pub const CSP: usize = 128; + +#[cfg(test)] +mod tests { + use mpz_core::prg::Prg; + use mpz_ot_core::{ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; + + use crate::test::poly_check; + + use super::{receiver::Receiver, sender::Sender, CSP}; + + #[test] + fn vope_test() { + let mut prg = Prg::new(); + let delta = prg.random_block(); + + let mut ideal_cot = IdealCOT::default(); + ideal_cot.set_delta(delta); + + let sender = Sender::new(); + let receiver = Receiver::new(); + + let mut sender = sender.setup(delta); + let mut receiver = receiver.setup(); + + let d = 1; + + let (sender_cot, receiver_cot) = ideal_cot.random_correlated((2 * d - 1) * CSP); + + let RCOTSenderOutput { msgs: ks, .. } = sender_cot; + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = receiver_cot; + + let sender_out = sender.extend(&ks, d).unwrap(); + let receiver_out = receiver.extend(&ms, &us, d).unwrap(); + + assert!(poly_check(&receiver_out, sender_out, delta)); + + let d = 5; + + let (sender_cot, receiver_cot) = ideal_cot.random_correlated((2 * d - 1) * CSP); + + let RCOTSenderOutput { msgs: ks, .. } = sender_cot; + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = receiver_cot; + + let sender_out = sender.extend(&ks, d).unwrap(); + let receiver_out = receiver.extend(&ms, &us, d).unwrap(); + + assert!(poly_check(&receiver_out, sender_out, delta)); + } +} diff --git a/crates/mpz-zk-core/src/vope/receiver.rs b/crates/mpz-zk-core/src/vope/receiver.rs new file mode 100644 index 00000000..9a0e4fce --- /dev/null +++ b/crates/mpz-zk-core/src/vope/receiver.rs @@ -0,0 +1,162 @@ +//! VOPE receiver. +use mpz_core::Block; + +use crate::vope::CSP; + +use super::error::ReceiverError; + +/// VOPE receiver +/// This is the prover in Figure 4. +#[derive(Debug, Default)] +pub struct Receiver { + state: T, +} + +impl Receiver { + /// Create a new receiver. + pub fn new() -> Self { + Receiver { + state: state::Initialized::default(), + } + } + + /// Completes the setup phase of the protocol. + /// + /// See Initialize in Figure 4. + pub fn setup(self) -> Receiver { + Receiver { + state: state::Extension { + vope_counter: 0, + exec_counter: 0, + }, + } + } +} + +impl Receiver { + /// Performs VOPE extension. + /// + /// See step 1-3 in Figure 4. + /// + /// # Arguments + /// + /// * `ms` - The blocks received by calling the COT ideal functionality. + /// * `us` - The bits received by calling the COT ideal functionality. + /// * `d` - The degree of the polynomial. + /// + /// Note that this functionality is only suitable for small d. + pub fn extend( + &mut self, + ms: &[Block], + us: &[bool], + d: usize, + ) -> Result, ReceiverError> { + if d == 0 { + return Err(ReceiverError::InvalidInput( + "the degree d should not be 0".to_string(), + )); + } + + if ms.len() != us.len() { + return Err(ReceiverError::InvalidLength( + "the length of ms and us should be equal".to_string(), + )); + } + + if ms.len() != (2 * d - 1) * CSP { + return Err(ReceiverError::InvalidLength( + "the length of ms and us should be (2 * d -1) * CSP".to_string(), + )); + } + + let mut h_ms = ms.to_vec(); + let mut h_us = us.to_vec(); + + let mut mi = vec![Block::ZERO; 2 * d - 1]; + let mut ui = vec![Block::ZERO; 2 * d - 1]; + + let base: Vec = (0..CSP) + .map(|x| bytemuck::cast((1_u128) << (CSP - 1 - x))) + .collect(); + + for i in 0..(2 * d - 1) { + let m = h_ms.split_off(CSP); + let u = h_us.split_off(CSP); + + mi[i] = Block::inn_prdt_red(&h_ms, &base); + + ui[i] = + h_us.iter().zip(base.iter()).fold( + Block::ZERO, + |acc, (b, base)| { + if *b { + acc ^ *base + } else { + acc + } + }, + ); + h_ms = m; + h_us = u; + } + + let mut gi = vec![Block::ZERO; d + 1]; + gi[0] = mi[0]; + gi[1] = ui[0]; + + for i in 0..d - 1 { + poly_update(&mut gi, mi[i + 1], ui[i + 1], i + 2); + gi[0] ^= mi[d + i]; + gi[1] ^= ui[d + i]; + } + + self.state.exec_counter += 1; + self.state.vope_counter += 1; + + Ok(gi) + } +} + +fn poly_update(g: &mut [Block], m: Block, u: Block, length: usize) { + let mut buffer = vec![Block::ZERO; length + 1]; + for i in 0..length { + buffer[i + 1] = g[i].gfmul(u); + g[i] = g[i].gfmul(m); + + g[i] ^= buffer[i]; + } + g[length] = buffer[length]; +} + +/// The receiver's state. +pub mod state { + mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + /// The receiver's state. + pub trait State: sealed::Sealed {} + + /// The receiver's initial state. + #[derive(Default)] + pub struct Initialized {} + + impl State for Initialized {} + opaque_debug::implement!(Initialized); + + /// The receiver's state after the setup phase. + /// + /// In this state the sender performs VOPE extension. + pub struct Extension { + /// Current VOPE counter + pub(super) vope_counter: usize, + /// Current execution counter + pub(super) exec_counter: usize, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} diff --git a/crates/mpz-zk-core/src/vope/sender.rs b/crates/mpz-zk-core/src/vope/sender.rs new file mode 100644 index 00000000..2bca3fd6 --- /dev/null +++ b/crates/mpz-zk-core/src/vope/sender.rs @@ -0,0 +1,128 @@ +//! VOPE sender. +use mpz_core::Block; + +use crate::vope::CSP; + +use super::error::SenderError; + +/// VOPE sender +/// This is the verifier in Figure 4. +#[derive(Debug, Default)] +pub struct Sender { + state: T, +} + +impl Sender { + /// Creates a new sender. + pub fn new() -> Self { + Sender { + state: state::Initialized::default(), + } + } + + /// Completes the setup phase of the protocol. + /// + /// See Initialize in Figure 4. + /// + /// # Arguments. + /// + /// * `delta` - The sender's global secret. + pub fn setup(self, delta: Block) -> Sender { + Sender { + state: state::Extension { + delta, + vope_counter: 0, + exec_counter: 0, + }, + } + } +} + +impl Sender { + /// Performs VOPE extension. + /// + /// See step 1-3 in Figure 4. + /// + /// # Arguments + /// + /// * `ks` - The blocks received by calling the COT ideal functionality. + /// * `d` - The degree of the polynomial. + /// + /// Note that this functionality is only suitable for small d. + pub fn extend(&mut self, ks: &[Block], d: usize) -> Result { + if d == 0 { + return Err(SenderError::InvalidInput( + "the degree d should not be 0".to_string(), + )); + } + + if ks.len() != (2 * d - 1) * CSP { + return Err(SenderError::InvalidLength( + "the length of ks should be (2 * d -1) * CSP".to_string(), + )); + } + + let mut ki = vec![Block::ZERO; 2 * d - 1]; + + let base: Vec = (0..CSP) + .map(|x| bytemuck::cast((1_u128) << (CSP - 1 - x))) + .collect(); + + let mut h_ks = ks.to_vec(); + + for k in ki.iter_mut().take(2 * d - 1) { + let buf = h_ks.split_off(CSP); + *k = Block::inn_prdt_red(&h_ks, &base); + h_ks = buf; + } + + let mut b = ki[0]; + + for i in 0..d - 1 { + b = b.gfmul(ki[i + 1]) ^ ki[d + i] + } + + self.state.exec_counter += 1; + self.state.vope_counter += 1; + + Ok(b) + } +} +/// The sender's state. +pub mod state { + use super::*; + + mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + /// The sender's state. + pub trait State: sealed::Sealed {} + + /// The sender's initial state. + #[derive(Default)] + pub struct Initialized {} + + impl State for Initialized {} + opaque_debug::implement!(Initialized); + + /// The sender's state after the setup phase. + /// + /// In this state the sender performs VOPE extension. + pub struct Extension { + /// Sender's global secret. + #[allow(dead_code)] + pub(crate) delta: Block, + + /// Current VOPE counter + pub(super) vope_counter: usize, + /// Current execution counter + pub(super) exec_counter: usize, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} diff --git a/crates/mpz-zk/Cargo.toml b/crates/mpz-zk/Cargo.toml new file mode 100644 index 00000000..54b42ced --- /dev/null +++ b/crates/mpz-zk/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "mpz-zk" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lints] +workspace = true + +[lib] +name = "mpz_zk" + +[features] +default = ["rayon"] +rayon = ["mpz-ot-core/rayon"] +ideal = ["mpz-common/ideal"] + +[dependencies] +mpz-core.workspace = true +mpz-zk-core.workspace = true +mpz-common.workspace = true +mpz-cointoss.workspace = true +mpz-ot-core.workspace = true +mpz-ot.workspace = true + +tlsn-utils-aio.workspace = true + +async-trait.workspace = true +futures.workspace = true +rand.workspace = true +rand_core.workspace = true +rand_chacha.workspace = true +thiserror.workspace = true +rayon = { workspace = true } +itybity.workspace = true +enum-try-as-inner.workspace = true +opaque-debug.workspace = true +serde = { workspace = true, optional = true } +serio.workspace = true +cfg-if.workspace = true + +[dev-dependencies] +mpz-common = { workspace = true, features = ["test-utils", "ideal"] } +mpz-ot-core = { workspace = true, features = ["test-utils"] } +rstest = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } +tokio = { workspace = true, features = [ + "net", + "macros", + "rt", + "rt-multi-thread", +] } diff --git a/crates/mpz-zk/src/lib.rs b/crates/mpz-zk/src/lib.rs new file mode 100644 index 00000000..4ae2e10f --- /dev/null +++ b/crates/mpz-zk/src/lib.rs @@ -0,0 +1,24 @@ +//! Implementations of zero-knowledge protocols. + +#![deny( + unsafe_code, + missing_docs, + unused_imports, + unused_must_use, + unreachable_pub, + clippy::all +)] + +pub mod vope; + +/// An oblivious transfer error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum VOPEError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error("sender error: {0}")] + SenderError(Box), + #[error("receiver error: {0}")] + ReceiverError(Box), +} diff --git a/crates/mpz-zk/src/vope/error.rs b/crates/mpz-zk/src/vope/error.rs new file mode 100644 index 00000000..912efda1 --- /dev/null +++ b/crates/mpz-zk/src/vope/error.rs @@ -0,0 +1,61 @@ +//! Errors in VOPE + +use crate::VOPEError; + +/// A VOPE Sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::vope::error::SenderError), + #[error(transparent)] + RandomCOTError(#[from] mpz_ot::OTError), + #[error("{0}")] + StateError(String), +} + +/// A VOPE Receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::vope::error::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] mpz_ot::OTError), + #[error("{0}")] + StateError(String), +} + +impl From for VOPEError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => VOPEError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::vope::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +impl From for VOPEError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => VOPEError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::vope::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-zk/src/vope/mod.rs b/crates/mpz-zk/src/vope/mod.rs new file mode 100644 index 00000000..1b34799d --- /dev/null +++ b/crates/mpz-zk/src/vope/mod.rs @@ -0,0 +1,63 @@ +//! This is the implementation of vector oblivious polynomial evaluation (VOPE) based on Figure 4 in https://eprint.iacr.org/2021/076.pdf + +pub mod error; +pub mod receiver; +pub mod sender; + +#[cfg(test)] +mod tests { + use crate::{ + vope::{receiver::Receiver, sender::Sender}, + VOPEError, + }; + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_core::Block; + use mpz_ot::ideal::cot::{ideal_rcot, IdealCOTReceiver, IdealCOTSender}; + use mpz_zk_core::test::poly_check; + + fn setup() -> (Sender, Receiver, Block) { + let (mut rcot_sender, rcot_receiver) = ideal_rcot(); + + let delta = rcot_sender.alice().get_mut().delta(); + + let sender = Sender::new(rcot_sender); + let receiver = Receiver::new(rcot_receiver); + + (sender, receiver, delta) + } + + #[tokio::test] + async fn test_vope() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut sender, mut receiver, delta) = setup(); + + sender.setup_with_delta(delta).unwrap(); + receiver.setup().unwrap(); + + let d = 1; + + let (output_sender, output_receiver) = tokio::try_join!( + sender.extend(&mut ctx_sender, d).map_err(VOPEError::from), + receiver + .extend(&mut ctx_receiver, d) + .map_err(VOPEError::from) + ) + .unwrap(); + + assert!(poly_check(&output_receiver, output_sender, delta)); + + let d = 5; + + let (output_sender, output_receiver) = tokio::try_join!( + sender.extend(&mut ctx_sender, d).map_err(VOPEError::from), + receiver + .extend(&mut ctx_receiver, d) + .map_err(VOPEError::from) + ) + .unwrap(); + + assert!(poly_check(&output_receiver, output_sender, delta)); + } +} diff --git a/crates/mpz-zk/src/vope/receiver.rs b/crates/mpz-zk/src/vope/receiver.rs new file mode 100644 index 00000000..d1fb3cdd --- /dev/null +++ b/crates/mpz-zk/src/vope/receiver.rs @@ -0,0 +1,105 @@ +//! Implementation of VOPE receiver. + +use crate::vope::error::ReceiverError; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver}; +use mpz_zk_core::vope::{ + receiver::{state, Receiver as ReceiverCore}, + CSP, +}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum State { + Initialized(ReceiverCore), + Extension(ReceiverCore), + Complete, + Error, +} + +/// VOPE receiver (prover) +#[derive(Debug)] +pub struct Receiver { + state: State, + rcot: RandomCOT, +} + +impl Receiver { + /// Creates a new receiver. + /// + /// # Arguments + /// + /// * `rcot` - The random COT used by the receiver. + pub fn new(rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + rcot, + } + } + + /// Performs setup for receiver. + pub fn setup(&mut self) -> Result<(), ReceiverError> { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_receiver = ext_receiver.setup(); + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs VOPE extension for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `d` - The polynomial degree. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + d: usize, + ) -> Result, ReceiverError> + where + RandomCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + assert!(d > 0); + + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = self + .rcot + .receive_random_correlated(ctx, (2 * d - 1) * CSP) + .await?; + + // extend + let (ext_receiver, res) = Backend::spawn(move || { + ext_receiver + .extend(&ms, &us, d) + .map(|res| (ext_receiver, res)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(res) + } + + /// Complete extension. + pub fn finalize(&mut self) -> Result<(), ReceiverError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +} diff --git a/crates/mpz-zk/src/vope/sender.rs b/crates/mpz-zk/src/vope/sender.rs new file mode 100644 index 00000000..447dcb6b --- /dev/null +++ b/crates/mpz-zk/src/vope/sender.rs @@ -0,0 +1,99 @@ +//! Implementation of VOPE sender + +use crate::vope::error::SenderError; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot::{RCOTSenderOutput, RandomCOTSender}; +use mpz_zk_core::vope::{ + sender::{state, Sender as SenderCore}, + CSP, +}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum State { + Initialized(SenderCore), + Extension(SenderCore), + Complete, + Error, +} + +/// VOPE sender (verifier) +#[derive(Debug)] +pub struct Sender { + state: State, + rcot: RandomCOT, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Arguments + /// + /// * `rcot` - The random COT used by the sender. + pub fn new(rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + rcot, + } + } + + /// Performs setup with the provided delta. + /// + /// # Arguments + /// + /// * `delta` - The delta value to use for VOPE extension. + pub fn setup_with_delta(&mut self, delta: Block) -> Result<(), SenderError> { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs VOPE extension for sender. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `d` - The polynomial degree. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + d: usize, + ) -> Result + where + RandomCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + assert!(d > 0); + + let RCOTSenderOutput { msgs: ks, .. } = self + .rcot + .send_random_correlated(ctx, (2 * d - 1) * CSP) + .await?; + + let (ext_sender, res) = + Backend::spawn(move || ext_sender.extend(&ks, d).map(|res| (ext_sender, res))).await?; + + self.state = State::Extension(ext_sender); + + Ok(res) + } + + /// Complete extension. + pub fn finalize(&mut self) -> Result<(), SenderError> { + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + self.state = State::Complete; + + Ok(()) + } +}