diff --git a/tfhe/src/integer/server_key/radix/add.rs b/tfhe/src/integer/server_key/radix/add.rs
index 329df0377c..43985b16ae 100644
--- a/tfhe/src/integer/server_key/radix/add.rs
+++ b/tfhe/src/integer/server_key/radix/add.rs
@@ -1,5 +1,5 @@
 use crate::integer::ciphertext::IntegerRadixCiphertext;
-use crate::integer::server_key::radix_parallel::sub::SignedOperation;
+use crate::integer::server_key::radix_parallel::ComputationFlags;
 use crate::integer::server_key::CheckError;
 use crate::integer::{BooleanBlock, ServerKey, SignedRadixCiphertext};
 use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel};
@@ -267,6 +267,16 @@ impl ServerKey {
         lhs: &SignedRadixCiphertext,
         rhs: &SignedRadixCiphertext,
     ) -> (SignedRadixCiphertext, BooleanBlock) {
-        self.unchecked_signed_overflowing_add_or_sub(lhs, rhs, SignedOperation::Addition)
+        let mut result = lhs.clone();
+        let overflowed = self
+            .advanced_add_assign_with_carry_sequential(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                ComputationFlags::from_signedness(true),
+            )
+            .expect("overflow flag was requested");
+
+        (result, overflowed)
     }
 }
diff --git a/tfhe/src/integer/server_key/radix/scalar_sub.rs b/tfhe/src/integer/server_key/radix/scalar_sub.rs
index c844941ecf..0023bca31c 100644
--- a/tfhe/src/integer/server_key/radix/scalar_sub.rs
+++ b/tfhe/src/integer/server_key/radix/scalar_sub.rs
@@ -65,7 +65,7 @@ impl ServerKey {
     // - `None` if scalar is zero
     // - `Some` if scalar is non-zero
     //
-    fn create_negated_block_decomposer(
+    pub(crate) fn create_negated_block_decomposer(
         &self,
         scalar: Scalar,
     ) -> Option>
diff --git a/tfhe/src/integer/server_key/radix/sub.rs b/tfhe/src/integer/server_key/radix/sub.rs
index 0235bb3a9d..f2df889b5c 100644
--- a/tfhe/src/integer/server_key/radix/sub.rs
+++ b/tfhe/src/integer/server_key/radix/sub.rs
@@ -1,9 +1,8 @@
 use crate::integer::ciphertext::IntegerRadixCiphertext;
-use crate::integer::server_key::radix_parallel::sub::SignedOperation;
+use crate::integer::server_key::radix_parallel::ComputationFlags;
 use crate::integer::server_key::CheckError;
 use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext};
 use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel};
-use crate::shortint::Ciphertext;
 
 impl ServerKey {
     /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values.
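The hunk below rewrites `unchecked_signed_overflowing_sub` in terms of plain two's-complement arithmetic: `lhs - rhs` becomes `lhs + !rhs` with an input carry of 1, and overflow is reported when the carry into the last bit differs from the carry out of the last bit, as the code's own comments note. A minimal cleartext sketch of that rule, for illustration only (the `overflowing_sub_i8` helper is hypothetical and not part of the patch):

    fn overflowing_sub_i8(lhs: i8, rhs: i8) -> (i8, bool) {
        let (a, b) = (lhs as u8, rhs as u8);
        // Two's complement: a - b == a + !b + 1
        let full = u16::from(a) + u16::from(!b) + 1;
        let result = (full & 0xFF) as u8;
        // Carry out of the sign bit
        let carry_out = (full >> 8) & 1;
        // Carry into the sign bit: add everything below it
        let carry_in = ((u16::from(a & 0x7F) + u16::from(!b & 0x7F) + 1) >> 7) & 1;
        // Overflow iff the two carries disagree
        (result as i8, carry_in != carry_out)
    }

    fn main() {
        for (x, y) in [(i8::MIN, 1), (100, -50), (-1, i8::MIN)] {
            assert_eq!(overflowing_sub_i8(x, y), x.overflowing_sub(y));
        }
    }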
@@ -420,124 +419,23 @@ impl ServerKey { (result, overflowed) } - pub(crate) fn unchecked_signed_overflowing_add_or_sub( - &self, - lhs: &SignedRadixCiphertext, - rhs: &SignedRadixCiphertext, - signed_operation: SignedOperation, - ) -> (SignedRadixCiphertext, BooleanBlock) { - let mut result = lhs.clone(); - - let num_blocks = result.blocks.len(); - if num_blocks == 0 { - return (result, self.create_trivial_boolean_block(false)); - } - - fn block_add_assign_returning_carry( - sks: &ServerKey, - lhs: &mut Ciphertext, - rhs: &Ciphertext, - ) -> Ciphertext { - sks.key.unchecked_add_assign(lhs, rhs); - let (carry, message) = rayon::join( - || sks.key.carry_extract(lhs), - || sks.key.message_extract(lhs), - ); - - *lhs = message; - - carry - } - - // 2_2, 3_3, 4_4 - // If we have at least 2 bits and at least as much carries - if self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= self.key.message_modulus.0 - { - if signed_operation == SignedOperation::Subtraction { - self.unchecked_sub_assign(&mut result, rhs); - } else { - self.unchecked_add_assign(&mut result, rhs); - } - - let mut input_carry = self.key.create_trivial(0); - - // For the first block do the first step of overflow computation in parallel - let (_, last_block_inner_propagation) = rayon::join( - || { - input_carry = - block_add_assign_returning_carry(self, &mut result.blocks[0], &input_carry); - }, - || { - self.generate_last_block_inner_propagation( - &lhs.blocks[num_blocks - 1], - &rhs.blocks[num_blocks - 1], - signed_operation, - ) - }, - ); - - for block in result.blocks[1..num_blocks - 1].iter_mut() { - input_carry = block_add_assign_returning_carry(self, block, &input_carry); - } - - // Treat the last block separately to handle last step of overflow behavior - let output_carry = block_add_assign_returning_carry( - self, - &mut result.blocks[num_blocks - 1], - &input_carry, - ); - let overflowed = self.resolve_signed_overflow( - last_block_inner_propagation, - &BooleanBlock::new_unchecked(input_carry), - &BooleanBlock::new_unchecked(output_carry), - ); - - return (result, overflowed); - } - - // 1_X parameters - // - // Same idea as other algorithms, however since we have 1 bit per block - // we do not have to resolve any inner propagation but it adds one more - // sequential PBS - if self.key.message_modulus.0 == 2 { - if signed_operation == SignedOperation::Subtraction { - self.unchecked_sub_assign(&mut result, rhs); - } else { - self.unchecked_add_assign(&mut result, rhs); - } - - let mut input_carry = self.key.create_trivial(0); - for block in result.blocks[..num_blocks - 1].iter_mut() { - input_carry = block_add_assign_returning_carry(self, block, &input_carry); - } - - let output_carry = block_add_assign_returning_carry( - self, - &mut result.blocks[num_blocks - 1], - &input_carry, - ); - - // Encode the rule - // "Overflow occurred if the carry into the last bit is different than the carry out - // of the last bit" - let overflowed = self.key.not_equal(&output_carry, &input_carry); - return (result, BooleanBlock::new_unchecked(overflowed)); - } - - panic!( - "Invalid combo of message modulus ({}) and carry modulus ({}) \n\ - This function requires the message modulus >= 2 and carry modulus >= message_modulus \n\ - I.e. 
PARAM_MESSAGE_X_CARRY_Y where X >= 1 and Y >= X.", - self.key.message_modulus.0, self.key.carry_modulus.0 - ); - } pub fn unchecked_signed_overflowing_sub( &self, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext, ) -> (SignedRadixCiphertext, BooleanBlock) { - self.unchecked_signed_overflowing_add_or_sub(lhs, rhs, SignedOperation::Subtraction) + let flipped_rhs = self.bitnot(rhs); + let carry = self.create_trivial_boolean_block(true); + let mut result = lhs.clone(); + let overflowed = self + .advanced_add_assign_with_carry_sequential( + &mut result.blocks, + &flipped_rhs.blocks, + Some(&carry), + ComputationFlags::from_signedness(true), + ) + .expect("overflow flat was requested"); + (result, overflowed) } pub fn signed_overflowing_sub( diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index bf76103fc5..7815bf5dae 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -1,41 +1,24 @@ use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::server_key::radix_parallel::sub::SignedOperation; use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; use crate::shortint::ciphertext::Degree; use crate::shortint::Ciphertext; use rayon::prelude::*; -#[repr(u64)] -#[derive(PartialEq, Eq)] -pub(crate) enum OutputCarry { - /// The block does not generate nor propagate a carry - None = 0, - /// The block generates a carry - Generated = 1, - /// The block will propagate a carry if it ever - /// receives one - Propagated = 2, +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub(crate) enum ComputationFlags { + None, + Overflow, + Carry, } -/// Function to create the LUT used in parallel prefix sum -/// to compute carry propagation -/// -/// If msb propagates it take the value of lsb, -/// this means: -/// - if lsb propagates, msb will propagate (but we don't know yet if there will actually be a carry -/// to propagate), -/// - if lsb generates a carry, as msb propagates it, lsb will generate a carry. 
Note that this lsb -/// generates might be due to x propagating ('resolved' by an earlier iteration of the loop) -/// - if lsb does not output a carry, msb will have nothing to propagate -/// -/// Otherwise, msb either does not generate, or it does generate, -/// but it means it won't propagate -fn prefix_sum_carry_propagation(msb: u64, lsb: u64) -> u64 { - if msb == OutputCarry::Propagated as u64 { - lsb - } else { - msb +impl ComputationFlags { + pub(crate) const fn from_signedness(is_signed: bool) -> Self { + if is_signed { + Self::Overflow + } else { + Self::Carry + } } } @@ -231,14 +214,10 @@ impl ServerKey { } }; - if self.is_eligible_for_parallel_single_carry_propagation(lhs) { - let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, rhs); - } else { - self.unchecked_add_assign(lhs, rhs); - self.full_propagate_parallelized(lhs); - } + self.add_assign_with_carry(lhs, rhs, None); } - /// Computes the addition of two unsigned ciphertexts and returns the overflow flag + + /// Computes the addition of two ciphertexts and returns the overflow flag /// /// # Example /// @@ -265,23 +244,25 @@ impl ServerKey { /// assert_eq!(dec_result, expected_result); /// assert_eq!(dec_overflowed, expected_overflow); /// ``` - pub fn unsigned_overflowing_add_parallelized( - &self, - ct_left: &RadixCiphertext, - ct_right: &RadixCiphertext, - ) -> (RadixCiphertext, BooleanBlock) { + pub fn overflowing_add_parallelized(&self, ct_left: &T, ct_right: &T) -> (T, BooleanBlock) + where + T: IntegerRadixCiphertext, + { let mut ct_res = ct_left.clone(); - let overflowed = self.unsigned_overflowing_add_assign_parallelized(&mut ct_res, ct_right); + let overflowed = self.overflowing_add_assign_parallelized(&mut ct_res, ct_right); (ct_res, overflowed) } - pub fn unsigned_overflowing_add_assign_parallelized( + pub fn overflowing_add_assign_parallelized( &self, - ct_left: &mut RadixCiphertext, - ct_right: &RadixCiphertext, - ) -> BooleanBlock { - let mut tmp_rhs: RadixCiphertext; - if ct_left.blocks.is_empty() || ct_right.blocks.is_empty() { + ct_left: &mut T, + ct_right: &T, + ) -> BooleanBlock + where + T: IntegerRadixCiphertext, + { + let mut tmp_rhs: T; + if ct_left.blocks().is_empty() || ct_right.blocks().is_empty() { return self.create_trivial_boolean_block(false); } @@ -309,30 +290,50 @@ impl ServerKey { } }; - self.unchecked_add_assign_parallelized(lhs, rhs); - self.unsigned_overflowing_propagate_addition_carry(lhs) + self.overflowing_add_assign_with_carry(lhs, rhs, None) } - /// This function takes a ciphertext resulting from an addition of 2 clean ciphertexts + /// Computes the addition of two unsigned ciphertexts and returns the overflow flag /// - /// It propagates the carries in-place, making the ciphertext clean and returns - /// the boolean indicating overflow - pub(in crate::integer) fn unsigned_overflowing_propagate_addition_carry( + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg1 = u8::MAX; + /// let msg2 = 1; + /// + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// let (ct_res, overflowed) = sks.unsigned_overflowing_add_parallelized(&ct1, &ct2); + /// + /// // Decrypt: + /// let dec_result: u8 = cks.decrypt(&ct_res); + /// let dec_overflowed = cks.decrypt_bool(&overflowed); + 
/// let (expected_result, expected_overflow) = msg1.overflowing_add(msg2); + /// assert_eq!(dec_result, expected_result); + /// assert_eq!(dec_overflowed, expected_overflow); + /// ``` + pub fn unsigned_overflowing_add_parallelized( + &self, + ct_left: &RadixCiphertext, + ct_right: &RadixCiphertext, + ) -> (RadixCiphertext, BooleanBlock) { + self.overflowing_add_parallelized(ct_left, ct_right) + } + + pub fn unsigned_overflowing_add_assign_parallelized( &self, - ct: &mut RadixCiphertext, + ct_left: &mut RadixCiphertext, + ct_right: &RadixCiphertext, ) -> BooleanBlock { - if self.is_eligible_for_parallel_single_carry_propagation(ct) { - let carry = self.propagate_single_carry_parallelized_low_latency(&mut ct.blocks); - BooleanBlock::new_unchecked(carry) - } else { - let len = ct.blocks.len(); - for i in 0..len - 1 { - let _ = self.propagate_parallelized(ct, i); - } - let mut carry = self.propagate_parallelized(ct, len - 1); - carry.degree = Degree::new(1); - BooleanBlock::new_unchecked(carry) - } + self.overflowing_add_assign_parallelized(ct_left, ct_right) } pub fn signed_overflowing_add_parallelized( @@ -340,36 +341,7 @@ impl ServerKey { ct_left: &SignedRadixCiphertext, ct_right: &SignedRadixCiphertext, ) -> (SignedRadixCiphertext, BooleanBlock) { - let mut tmp_lhs: SignedRadixCiphertext; - let mut tmp_rhs: SignedRadixCiphertext; - - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.clone(); - self.full_propagate_parallelized(&mut tmp_rhs); - (ct_left, &tmp_rhs) - } - (false, true) => { - tmp_lhs = ct_left.clone(); - self.full_propagate_parallelized(&mut tmp_lhs); - (&tmp_lhs, ct_right) - } - (false, false) => { - tmp_lhs = ct_left.clone(); - tmp_rhs = ct_right.clone(); - rayon::join( - || self.full_propagate_parallelized(&mut tmp_lhs), - || self.full_propagate_parallelized(&mut tmp_rhs), - ); - (&tmp_lhs, &tmp_rhs) - } - }; - - self.unchecked_signed_overflowing_add_parallelized(lhs, rhs) + self.overflowing_add_parallelized(ct_left, ct_right) } pub fn unchecked_signed_overflowing_add_parallelized( @@ -386,61 +358,9 @@ impl ServerKey { ); assert!(!ct_left.blocks.is_empty(), "inputs cannot be empty"); - if self.is_eligible_for_parallel_single_carry_propagation(ct_left) { - self.unchecked_signed_overflowing_add_or_sub_parallelized_impl( - ct_left, - ct_right, - SignedOperation::Addition, - ) - } else { - self.unchecked_signed_overflowing_add_or_sub( - ct_left, - ct_right, - SignedOperation::Addition, - ) - } - } - - pub fn add_parallelized_work_efficient(&self, ct_left: &T, ct_right: &T) -> T - where - T: IntegerRadixCiphertext, - { - let mut ct_res = ct_left.clone(); - self.add_assign_parallelized_work_efficient(&mut ct_res, ct_right); - ct_res - } - - pub fn add_assign_parallelized_work_efficient(&self, ct_left: &mut T, ct_right: &T) - where - T: IntegerRadixCiphertext, - { - let mut tmp_rhs: T; - - let (lhs, rhs) = match ( - ct_left.block_carries_are_empty(), - ct_right.block_carries_are_empty(), - ) { - (true, true) => (ct_left, ct_right), - (true, false) => { - tmp_rhs = ct_right.clone(); - self.full_propagate_parallelized(&mut tmp_rhs); - (ct_left, &tmp_rhs) - } - (false, true) => { - self.full_propagate_parallelized(ct_left); - (ct_left, ct_right) - } - (false, false) => { - tmp_rhs = ct_right.clone(); - rayon::join( - || self.full_propagate_parallelized(ct_left), - || self.full_propagate_parallelized(&mut tmp_rhs), - ); - (ct_left, &tmp_rhs) - } - 
}; - - self.unchecked_add_assign_parallelized_work_efficient(lhs, rhs); + let mut result = ct_left.clone(); + let overflowed = self.overflowing_add_assign_with_carry(&mut result, ct_right, None); + (result, overflowed) } pub(crate) fn is_eligible_for_parallel_single_carry_propagation(&self, ct: &T) -> bool @@ -459,115 +379,870 @@ impl ServerKey { should_hillis_steele_propagation_be_faster(ct.blocks().len(), rayon::current_num_threads()) } - /// This add_assign two numbers - /// - /// It uses the Hillis and Steele algorithm to do - /// prefix sum / cumulative sum in parallel. - /// - /// It it not "work efficient" as in, it adds a lot - /// of work compared to the single threaded approach, - /// however it is highly parallelized and so is the fastest - /// assuming enough threads are available. - /// - /// At most num_block - 1 threads are used - /// - /// Returns the output carry that can be used to check for unsigned addition - /// overflow. - /// - /// # Requirements - /// - /// - The parameters have 4 bits in total - /// - Adding rhs to lhs must not consume more than one carry - /// - /// # Output - /// - /// - lhs will have its carries empty - pub(crate) fn unchecked_add_assign_parallelized_low_latency( + /// Does lhs += (rhs + carry) + pub fn add_assign_with_carry(&self, lhs: &mut T, rhs: &T, input_carry: Option<&BooleanBlock>) + where + T: IntegerRadixCiphertext, + { + self.advanced_add_assign_with_carry( + lhs.blocks_mut(), + rhs.blocks(), + input_carry, + ComputationFlags::None, + ); + } + + /// Does lhs += (rhs + carry) + pub fn overflowing_add_assign_with_carry( &self, lhs: &mut T, rhs: &T, - ) -> Ciphertext + input_carry: Option<&BooleanBlock>, + ) -> BooleanBlock where T: IntegerRadixCiphertext, { - let degree_after_add_does_not_go_beyond_first_carry = lhs - .blocks() - .iter() - .zip(rhs.blocks().iter()) - .all(|(bl, br)| { - let degree_after_add = bl.degree.get() + br.degree.get(); - degree_after_add < (self.key.message_modulus.0 * 2) - }); - assert!(degree_after_add_does_not_go_beyond_first_carry); + self.advanced_add_assign_with_carry( + lhs.blocks_mut(), + rhs.blocks(), + input_carry, + ComputationFlags::from_signedness(T::IS_SIGNED), + ) + .expect("internal error, overflow computation was not returned as was requested") + } - self.unchecked_add_assign_parallelized(lhs, rhs); - self.propagate_single_carry_parallelized_low_latency(lhs.blocks_mut()) + pub(crate) fn propagate_single_carry_parallelized(&self, radix: &mut [Ciphertext]) { + self.advanced_add_assign_with_carry_at_least_4_bits( + radix, + &[], + None, + ComputationFlags::None, + ); } - /// This function takes an input slice of shortint ciphertext (aka blocks) - /// for which at most one bit of carry is consumed in each block, and - /// it does the carry propagation in place. 
- /// - /// It returns the output carry of the last block - /// - /// Used in (among other) 'default' addition: - /// - first unchecked_add - /// - at this point at most on bit of carry is taken - /// - use this function to propagate them in parallel - pub(crate) fn propagate_single_carry_parallelized_low_latency( + pub(crate) fn advanced_add_assign_with_carry( &self, - blocks: &mut [Ciphertext], - ) -> Ciphertext { - let generates_or_propagates = self.generate_init_carry_array(blocks); - let (input_carries, output_carry) = - self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates); + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + input_carry: Option<&BooleanBlock>, + requested_flag: ComputationFlags, + ) -> Option { + // TODO: estimate by thread count + if self.message_modulus().0 * self.carry_modulus().0 >= 16 { + self.advanced_add_assign_with_carry_at_least_4_bits( + lhs, + rhs, + input_carry, + requested_flag, + ) + } else { + self.advanced_add_assign_with_carry_sequential(lhs, rhs, input_carry, requested_flag) + } + } - blocks - .par_iter_mut() - .zip(input_carries.par_iter()) - .for_each(|(block, input_carry)| { - self.key.unchecked_add_assign(block, input_carry); - self.key.message_extract_assign(block); + pub(crate) fn advanced_add_assign_with_carry_sequential( + &self, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + input_carry: Option<&BooleanBlock>, + requested_flag: ComputationFlags, + ) -> Option { + assert_eq!( + lhs.len(), + rhs.len(), + "Both operands must have the same number of blocks" + ); + + if lhs.is_empty() { + return if requested_flag == ComputationFlags::None { + None + } else { + Some(self.create_trivial_boolean_block(false)) + }; + } + + let mut carry = input_carry.map_or_else( + || self.key.create_trivial(0), + |boolean_block| boolean_block.0.clone(), + ); + + // 2_2, 3_3, 4_4 + // If we have at least 2 bits and at least as much carries + if self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= self.key.message_modulus.0 + { + let mut overflow_flag = if requested_flag == ComputationFlags::Overflow { + let mut block = self.key.unchecked_scalar_mul( + lhs.last().as_ref().unwrap(), + self.message_modulus().0 as u8, + ); + self.key + .unchecked_add_assign(&mut block, rhs.last().as_ref().unwrap()); + Some(block) + } else { + None + }; + // Handle the first block + self.key.unchecked_add_assign(&mut lhs[0], &rhs[0]); + self.key.unchecked_add_assign(&mut lhs[0], &carry); + + // To be able to use carry_extract_assign in it + carry.clone_from(&lhs[0]); + rayon::scope(|s| { + s.spawn(|_| { + self.key.message_extract_assign(&mut lhs[0]); + }); + + s.spawn(|_| { + self.key.carry_extract_assign(&mut carry); + }); + + if requested_flag == ComputationFlags::Overflow { + s.spawn(|_| { + // Computing the overflow flag requires and extra step for the first block + + let overflow_flag = overflow_flag.as_mut().unwrap(); + let num_bits_in_message = self.message_modulus().0.ilog2() as u64; + let lut = self.key.generate_lookup_table(|lhs_rhs| { + let lhs = lhs_rhs / self.message_modulus().0 as u64; + let rhs = lhs_rhs % self.message_modulus().0 as u64; + let mask = (1 << (num_bits_in_message - 1)) - 1; + let lhs_except_last_bit = lhs & mask; + let rhs_except_last_bit = rhs & mask; + + let overflows_with_given_input_carry = |input_carry| { + let output_carry = + ((lhs + rhs + input_carry) >> num_bits_in_message) & 1; + + let input_carry_to_last_bit = + ((lhs_except_last_bit + rhs_except_last_bit + input_carry) + >> (num_bits_in_message - 1)) + & 1; + 
+ u64::from(input_carry_to_last_bit != output_carry) + }; + + (overflows_with_given_input_carry(1) << 3) + | (overflows_with_given_input_carry(0) << 2) + }); + self.key.apply_lookup_table_assign(overflow_flag, &lut); + }); + } }); - output_carry + + let num_blocks = lhs.len(); + for (lhs_b, rhs_b) in lhs[1..num_blocks - 1] + .iter_mut() + .zip(rhs[1..num_blocks - 1].iter()) + { + self.key.unchecked_add_assign(lhs_b, rhs_b); + self.key.unchecked_add_assign(lhs_b, &carry); + + carry.clone_from(lhs_b); + rayon::join( + || self.key.message_extract_assign(lhs_b), + || self.key.carry_extract_assign(&mut carry), + ); + } + + // Handle the last block + self.key.unchecked_add_assign(&mut lhs[0], &rhs[0]); + self.key.unchecked_add_assign(&mut lhs[0], &carry); + + if let Some(block) = overflow_flag.as_mut() { + self.key.unchecked_add_assign(block, &carry); + } + + // To be able to use carry_extract_assign in it + carry.clone_from(&lhs[0]); + + rayon::scope(|s| { + s.spawn(|_| { + self.key.message_extract_assign(&mut lhs[0]); + }); + + s.spawn(|_| { + self.key.carry_extract_assign(&mut carry); + }); + + if requested_flag == ComputationFlags::Overflow { + s.spawn(|_| { + let overflow_flag_block = overflow_flag.as_mut().unwrap(); + //let shifted_carry = self.key.unchecked_scalar_mul(&carry, 2); + // Computing the overflow flag requires and extra step for the first block + let overflow_flag_lut = self.key.generate_lookup_table(|block| { + let input_carry = block & 1; + if input_carry == 1 { + (block >> 3) & 1 + } else { + (block >> 2) & 1 + } + }); + + self.key + .apply_lookup_table_assign(overflow_flag_block, &overflow_flag_lut); + }); + } + }); + + return match requested_flag { + ComputationFlags::None => None, + ComputationFlags::Overflow => { + assert!( + overflow_flag.is_some(), + "internal error, overflow_flag should exist" + ); + overflow_flag.map(BooleanBlock::new_unchecked) + } + ComputationFlags::Carry => { + carry.degree = Degree::new(1); + Some(BooleanBlock::new_unchecked(carry)) + } + }; + } + + // 1_X parameters + // + // Same idea as other algorithms, however since we have 1 bit per block + // we do not have to resolve any inner propagation but it adds one more + // sequential PBS + if self.key.message_modulus.0 == 2 { + fn block_add_assign_returning_carry( + sks: &ServerKey, + lhs: &mut Ciphertext, + rhs: &Ciphertext, + carry: &Ciphertext, + ) -> Ciphertext { + sks.key.unchecked_add_assign(lhs, rhs); + sks.key.unchecked_add_assign(lhs, carry); + let (carry, message) = rayon::join( + || sks.key.carry_extract(lhs), + || sks.key.message_extract(lhs), + ); + + *lhs = message; + + carry + } + let num_blocks = lhs.len(); + for (lhs_b, rhs_b) in lhs[..num_blocks - 1] + .iter_mut() + .zip(rhs[..num_blocks - 1].iter()) + { + carry = block_add_assign_returning_carry(self, lhs_b, rhs_b, &carry); + } + + let mut output_carry = block_add_assign_returning_carry( + self, + &mut lhs[num_blocks - 1], + &rhs[num_blocks - 1], + &carry, + ); + + return match requested_flag { + ComputationFlags::None => None, + ComputationFlags::Overflow => { + let overflowed = self.key.not_equal(&output_carry, &carry); + Some(BooleanBlock::new_unchecked(overflowed)) + } + ComputationFlags::Carry => { + output_carry.degree = Degree::new(1); + Some(BooleanBlock::new_unchecked(output_carry)) + } + }; + } + + panic!( + "Invalid combo of message modulus ({}) and carry modulus ({}) \n\ + This function requires the message modulus >= 2 and carry modulus >= message_modulus \n\ + I.e. 
PARAM_MESSAGE_X_CARRY_Y where X >= 1 and Y >= X.", + self.key.message_modulus.0, self.key.carry_modulus.0 + ); } - /// Backbone algorithm of parallel carry (only one bit) propagation - /// - /// Uses the Hillis and Steele prefix scan - /// - /// Requires the blocks to have at least 4 bits - pub(crate) fn compute_carry_propagation_parallelized_low_latency( + /// Does lhs += (rhs + carry) + /// acts like the ADC assemby op, expect, the flags have to be explicitely requested + /// as they incur additional PBS + fn advanced_add_assign_with_carry_at_least_4_bits( &self, - generates_or_propagates: Vec, - ) -> (Vec, Ciphertext) { - if generates_or_propagates.is_empty() { - return (vec![], self.key.create_trivial(0)); + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + input_carry: Option<&BooleanBlock>, + requested_flag: ComputationFlags, + ) -> Option { + // Empty rhs is a specially allowed 'weird' case to have + // act like a 'propagate single carry' function + if rhs.is_empty() { + // Techinically, CarryFlag is computable, but OverflowFlag is not + assert_eq!(requested_flag, ComputationFlags::None); + } else { + assert_eq!( + lhs.len(), + rhs.len(), + "Both operands must have the same number of blocks" + ); + } + + if lhs.is_empty() { + // Then both are empty + if requested_flag == ComputationFlags::None { + return None; + } + return Some(self.create_trivial_boolean_block(false)); + } + + let saved_last_blocks = if requested_flag == ComputationFlags::Overflow { + Some((lhs.last().cloned().unwrap(), rhs.last().cloned().unwrap())) + } else { + None + }; + + // Perform the block additions + for (lhs_b, rhs_b) in lhs.iter_mut().zip(rhs.iter()) { + self.key.unchecked_add_assign(lhs_b, rhs_b); + } + if let Some(carry) = input_carry { + self.key.unchecked_add_assign(&mut lhs[0], &carry.0); + } + + let blocks = lhs; + let num_blocks = blocks.len(); + + let message_modulus = self.message_modulus().0 as u64; + let num_bits_in_message = message_modulus.ilog2() as u64; + + let block_modulus = self.message_modulus().0 * self.carry_modulus().0; + let num_bits_in_block = block_modulus.ilog2(); + + let grouping_size = num_bits_in_block as usize; + + let num_groupings = num_blocks.div_ceil(grouping_size); + assert!(self.key.max_noise_level.get() >= grouping_size); + + let num_carry_to_resolve = num_groupings - 1; + + let sequential_depth = (num_carry_to_resolve as u32 - 1) / (grouping_size as u32 - 1); + let hillis_steel_depth = if num_carry_to_resolve == 0 { + 0 + } else { + num_carry_to_resolve.ceil_ilog2() + }; + + let shift_grouping_pgn = sequential_depth <= hillis_steel_depth; + + let mut output_flag = None; + + // First step + let (shifted_blocks, block_states) = match requested_flag { + ComputationFlags::None => { + let (shifted_blocks, mut block_states) = + self.compute_shifted_blocks_and_block_states(blocks); + let _ = block_states.pop().unwrap(); + (shifted_blocks, block_states) + } + ComputationFlags::Overflow => { + let (block, (shifted_blocks, block_states)) = rayon::join( + || { + let lut = self.key.generate_lookup_table_bivariate(|lhs, rhs| { + let mask = (1 << (num_bits_in_message - 1)) - 1; + let lhs_except_last_bit = lhs & mask; + let rhs_except_last_bit = rhs & mask; + + let overflows_with_given_input_carry = |input_carry| { + let output_carry = + ((lhs + rhs + input_carry) >> num_bits_in_message) & 1; + + let input_carry_to_last_bit = + ((lhs_except_last_bit + rhs_except_last_bit + input_carry) + >> (num_bits_in_message - 1)) + & 1; + + u64::from(input_carry_to_last_bit != output_carry) + }; 
+ + (overflows_with_given_input_carry(1) << 3) + | (overflows_with_given_input_carry(0) << 2) + }); + let (last_lhs_block, last_rhs_block) = saved_last_blocks.as_ref().unwrap(); + self.key.unchecked_apply_lookup_table_bivariate( + last_lhs_block, + last_rhs_block, + &lut, + ) + }, + || { + let (shifted_blocks, mut block_states) = + self.compute_shifted_blocks_and_block_states(blocks); + let _ = block_states.pop().unwrap(); + (shifted_blocks, block_states) + }, + ); + + output_flag = Some(block); + (shifted_blocks, block_states) + } + ComputationFlags::Carry => { + let (shifted_blocks, mut block_states) = + self.compute_shifted_blocks_and_block_states(blocks); + let last_block_state = block_states.pop().unwrap(); + output_flag = Some(last_block_state); + (shifted_blocks, block_states) + } + }; + + // Second step + let (mut prepared_blocks, mut groupings_pgns) = { + // This stores, the LUTs that given a cum sum block in the first grouping + // tells if a carry is generated or not + let first_grouping_inner_propagation_luts = (0..grouping_size - 1) + .map(|index| { + self.key.generate_lookup_table(|propa_cum_sum_block| { + let carry = propa_cum_sum_block & (1 << index); + if carry != 0 { + 2 // Generates + } else { + 0 // Nothing + } + }) + }) + .collect::>(); + + // This stores, the LUTs that given a cum sum in non first grouping + // tells if a carry is generated or propagated or neither of these + let other_groupings_inner_propagation_luts = (0..grouping_size) + .map(|index| { + self.key.generate_lookup_table(|propa_cum_sum_block| { + let mask = (2 << index) - 1; + if propa_cum_sum_block >= (2 << index) { + 2 // Generates + } else if (propa_cum_sum_block & mask) == mask { + 1 // Propagate + } else { + 0 + } + }) + }) + .collect::>(); + + // This stores the LUT that outputs the propagation result of the first grouping + let first_grouping_outer_propagation_lut = self.key.generate_lookup_table(|block| { + // Check if the last bit of the block is set + (block >> (num_bits_in_block - 1)) & 1 + }); + + // This stores the LUTs that output the propagation result of the other groupings + let grouping_chunk_pgn_luts = if shift_grouping_pgn { + // When using the sequential algorithm for the propagation of one grouping to the + // other we need to shift the PGN state to the correct position, so we later, when + // using them only lwe_add is needed and so noise management is easy + // + // Also, these LUTs are 'negacylic', they are made to exploit the padding bit + // resulting blocks from these LUTs must be added the constant `1 << index`. + (0..grouping_size - 1) + .map(|i| { + self.key.generate_lookup_table(|block| { + // All bits set to 1 (e.g. 0b1111), means propagate + if block == (block_modulus - 1) as u64 { + 0 + } else { + // u64::MAX is -1 in tow's complement + // We apply the modulus including the padding bit + (u64::MAX << i) % (1 << (num_bits_in_block + 1)) + } + }) + }) + .collect::>() + } else { + // This LUT is for when we are using Hillis-Steele prefix-scan to propagate carries + // between groupings. When using this propagation, the encoding of the states + // are a bit different. + // + // Also, these LUTs are 'negacylic', they are made to exploit the padding bit + // resulting blocks from these LUTs must be added the constant `1`. + vec![self.key.generate_lookup_table(|block| { + if block == (block_modulus - 1) as u64 { + // All bits set to 1 (e.g. 
0b1111), means propagate + 2 + } else { + // u64::MAX is -1 in tow's complement + // We apply the modulus including the padding bit + u64::MAX % (1 << (block_modulus + 1)) + } + })] + }; + + let mut propagation_cum_sums = Vec::with_capacity(num_blocks); + block_states.chunks(grouping_size).for_each(|grouping| { + propagation_cum_sums.push(grouping[0].clone()); + for other in &grouping[1..] { + let mut result = other.clone(); + self.key + .unchecked_add_assign(&mut result, propagation_cum_sums.last().unwrap()); + + propagation_cum_sums.push(result); + } + }); + + let len = propagation_cum_sums.len(); + propagation_cum_sums + .par_iter_mut() + .enumerate() + .for_each(|(i, cum_sum_block)| { + let grouping_index = i / grouping_size; + let is_in_first_grouping = grouping_index == 0; + let index_in_grouping = i % (grouping_size); + + let lut = if is_in_first_grouping { + //println!("drjredd"); + if index_in_grouping == grouping_size - 1 { + //println!("First Grouping PGN"); + &first_grouping_outer_propagation_lut + } else { + &first_grouping_inner_propagation_luts[index_in_grouping] + } + } else if index_in_grouping == grouping_size - 1 { + if shift_grouping_pgn { + //println!("Grouping PGN for sequential"); + &grouping_chunk_pgn_luts[(grouping_index - 1) % (grouping_size - 1)] + } else { + //println!("Grouping PGN for hillis"); + &grouping_chunk_pgn_luts[0] + } + } else { + &other_groupings_inner_propagation_luts[index_in_grouping] + }; + + self.key.apply_lookup_table_assign(cum_sum_block, lut); + + let may_have_its_padding_bit_set = + !is_in_first_grouping && index_in_grouping == grouping_size - 1; + if may_have_its_padding_bit_set { + if shift_grouping_pgn { + self.key.unchecked_scalar_add_assign( + cum_sum_block, + 1 << ((grouping_index - 1) % (grouping_size - 1)), + ); + } else { + self.key.unchecked_scalar_add_assign(cum_sum_block, 1); + } + cum_sum_block.degree = Degree::new(message_modulus as usize - 1); + } + //("cumsum out ", &[cum_sum_block.clone()]); + }); + + let num_groupings = num_blocks / grouping_size; + let mut groupings_pgns = Vec::with_capacity(num_groupings); + let mut propagation_simulators = Vec::with_capacity(num_blocks); + + // First block does not get borrowed from + propagation_simulators.push(self.key.create_trivial(0)); + for (i, block) in propagation_cum_sums + // .drain(..propagation_cum_sums.len().saturating_sub(1)) + .drain(..) 
+ .enumerate() + { + if propagation_simulators.len() % grouping_size == 0 { + groupings_pgns.push(block); + if i != len - 1 { + // The first block in each grouping has its simulator set to 0 + // because it always receives any input borrow that may be generated from + // previous grouping + propagation_simulators.push(self.key.create_trivial(1)); + } + } else { + propagation_simulators.push(block); + } + } + + let mut prepared_blocks = shifted_blocks; + prepared_blocks + .iter_mut() + .zip(propagation_simulators.iter()) + .for_each(|(block, simulator)| { + crate::core_crypto::algorithms::lwe_ciphertext_add_assign( + &mut block.ct, + &simulator.ct, + ); + }); + + match requested_flag { + ComputationFlags::None => {} + ComputationFlags::Overflow => { + let block = output_flag.as_mut().unwrap(); + self.key + .unchecked_add_assign(block, &propagation_simulators[num_blocks - 1]); + } + ComputationFlags::Carry => { + let block = output_flag.as_mut().unwrap(); + self.key + .unchecked_add_assign(block, &propagation_simulators[num_blocks - 1]); + } + } + + (prepared_blocks, groupings_pgns) + }; + + // Third step: resolving carry propagation between the groups + let resolved_carries = if groupings_pgns.is_empty() { + vec![self.key.create_trivial(0)] + } else if shift_grouping_pgn { + let luts = (0..grouping_size - 1) + .map(|index| { + self.key.generate_lookup_table(|propa_cum_sum_block| { + let carry = propa_cum_sum_block & (1 << (index + 1)); + u64::from(carry != 0) + }) + }) + .collect::>(); + + groupings_pgns.rotate_left(1); + let mut resolved_carries = + vec![self.key.create_trivial(0), groupings_pgns.pop().unwrap()]; + for chunk in groupings_pgns.chunks(grouping_size - 1) { + //println!("chunk size: {}", chunk.len()); + let mut cum_sums = chunk.to_vec(); + self.key + .unchecked_add_assign(&mut cum_sums[0], resolved_carries.last().unwrap()); + + for i in [1, 2] { + if i == 1 && cum_sums.len() < 2 { + continue; + } + if i == 2 && cum_sums.len() < 3 { + continue; + } + // All this just to do add_assign(&mut cum_sum[i], &cum_sum[i-1]) + let (l, r) = cum_sums.split_at_mut(i); + let llen = l.len(); + self.key.unchecked_add_assign(&mut r[0], &l[llen - 1]); + } + + cum_sums + .par_iter_mut() + .zip(luts.par_iter()) + .for_each(|(cum_sum_block, lut)| { + self.key.apply_lookup_table_assign(cum_sum_block, lut); + }); + + // Cum sums now contains the output carries + resolved_carries.append(&mut cum_sums); + } + + resolved_carries + } else { + let lut_carry_propagation_sum = + self.key + .generate_lookup_table_bivariate(|msb: u64, lsb: u64| -> u64 { + if msb == 2 { + 1 // Remap Generate to 1 + } else if msb == 3 { + // MSB propagates + if lsb == 2 { + 1 + } else { + lsb + } // also remap here + } else { + msb + } + }); + let sum_function = |block_borrow: &mut Ciphertext, + previous_block_borrow: &Ciphertext| { + self.key.unchecked_apply_lookup_table_bivariate_assign( + block_borrow, + previous_block_borrow, + &lut_carry_propagation_sum, + ); + }; + let mut resolved_carries = + self.compute_prefix_sum_hillis_steele(groupings_pgns, sum_function); + resolved_carries.insert(0, self.key.create_trivial(0)); + resolved_carries + }; + + // Final step: adding resolved carries and cleaning result + let mut add_carries_and_cleanup = || { + let message_extract_lut = self + .key + .generate_lookup_table(|block| (block >> 1) % message_modulus); + + prepared_blocks + .par_iter_mut() + .enumerate() + .for_each(|(i, block)| { + let grouping_index = i / grouping_size; + let borrow = &resolved_carries[grouping_index]; + 
crate::core_crypto::algorithms::lwe_ciphertext_add_assign( + &mut block.ct, + &borrow.ct, + ); + + self.key + .apply_lookup_table_assign(block, &message_extract_lut) + }); + }; + + match requested_flag { + ComputationFlags::None => { + add_carries_and_cleanup(); + } + ComputationFlags::Overflow => { + let overflow_flag_lut = self.key.generate_lookup_table(|block| { + let input_carry = (block >> 1) & 1; + if input_carry == 1 { + (block >> 3) & 1 + } else { + (block >> 2) & 1 + } + }); + rayon::join( + || { + let block = output_flag.as_mut().unwrap(); + self.key.unchecked_add_assign( + block, + &resolved_carries[resolved_carries.len() - 1], + ); + self.key + .apply_lookup_table_assign(block, &overflow_flag_lut); + }, + add_carries_and_cleanup, + ); + } + ComputationFlags::Carry => { + let carry_flag_lut = self.key.generate_lookup_table(|block| (block >> 2) & 1); + + rayon::join( + || { + let block = output_flag.as_mut().unwrap(); + self.key.unchecked_add_assign( + block, + &resolved_carries[resolved_carries.len() - 1], + ); + self.key.apply_lookup_table_assign(block, &carry_flag_lut); + }, + add_carries_and_cleanup, + ); + } } - let lut_carry_propagation_sum = self - .key - .generate_lookup_table_bivariate(prefix_sum_carry_propagation); - // Type annotations are required, otherwise we get confusing errors - // "implementation of `FnOnce` is not general enough" - let sum_function = |block_carry: &mut Ciphertext, previous_block_carry: &Ciphertext| { - self.key.unchecked_apply_lookup_table_bivariate_assign( - block_carry, - previous_block_carry, - &lut_carry_propagation_sum, + blocks.clone_from_slice(&prepared_blocks); + + match requested_flag { + ComputationFlags::None => None, + ComputationFlags::Overflow | ComputationFlags::Carry => { + output_flag.map(BooleanBlock::new_unchecked) + } + } + } + + fn compute_shifted_blocks_and_block_states( + &self, + blocks: &[Ciphertext], + ) -> (Vec, Vec) { + let num_blocks = blocks.len(); + + let message_modulus = self.message_modulus().0 as u64; + + let block_modulus = self.message_modulus().0 * self.carry_modulus().0; + let num_bits_in_block = block_modulus.ilog2(); + + let grouping_size = num_bits_in_block as usize; + + let shift_block_fn = |block| (block % message_modulus) << 1; + let mut first_grouping_luts = vec![{ + let first_block_state_fn = |block| { + if block >= message_modulus { + 1 // Generates + } else { + 0 // Nothing + } + }; + self.key + .generate_many_lookup_table(&[&first_block_state_fn, &shift_block_fn]) + }]; + for i in 1..grouping_size { + let state_fn = |block| { + let r = if block >= message_modulus { + 2 // Generates Carry + } else if block == message_modulus - 1 { + 1 // Propagates a carry + } else { + 0 // Does not borrow + }; + + r << (i - 1) + }; + first_grouping_luts.push( + self.key + .generate_many_lookup_table(&[&state_fn, &shift_block_fn]), ); + } + + let other_block_state_luts = (0..grouping_size) + .map(|i| { + let state_fn = |block| { + let r = if block >= message_modulus { + 2 // Generates Carry + } else if block == message_modulus - 1 { + 1 // Propagates a carry + } else { + 0 // Does not borrow + }; + + r << i + }; + self.key + .generate_many_lookup_table(&[&state_fn, &shift_block_fn]) + }) + .collect::>(); + + let last_block_luts = { + if blocks.len() == 1 { + let first_block_state_fn = |block| { + if block >= message_modulus { + 2 << 1 // Generates + } else { + 0 // Nothing + } + }; + self.key + .generate_many_lookup_table(&[&first_block_state_fn, &shift_block_fn]) + } else if (blocks.len() - 1) <= grouping_size { 
+ // The last block is in the first grouping + first_grouping_luts[2].clone() + } else { + first_grouping_luts[2].clone() + } }; - let num_blocks = generates_or_propagates.len(); - let mut carries_out = - self.compute_prefix_sum_hillis_steele(generates_or_propagates, sum_function); - let mut last_block_out_carry = self.key.create_trivial(0); - std::mem::swap(&mut carries_out[num_blocks - 1], &mut last_block_out_carry); - last_block_out_carry.degree = Degree::new(1); - // The output carry of block i-1 becomes the input - // carry of block i - carries_out.rotate_right(1); - (carries_out, last_block_out_carry) + let tmp = blocks + .par_iter() + .enumerate() + .map(|(index, block)| { + let grouping_index = index / grouping_size; + let is_in_first_grouping = grouping_index == 0; + let index_in_grouping = index % (grouping_size); + let is_last_index = index == blocks.len() - 1; + + let luts = if is_last_index { + &last_block_luts + } else if is_in_first_grouping { + &first_grouping_luts[index_in_grouping] + } else { + &other_block_state_luts[index_in_grouping] + }; + self.key.apply_many_lookup_table(block, luts) + }) + .collect::>(); + + let mut shifted_blocks = Vec::with_capacity(num_blocks); + let mut block_states = Vec::with_capacity(num_blocks); + for mut blocks in tmp { + assert_eq!(blocks.len(), 2); + shifted_blocks.push(blocks.pop().unwrap()); + block_states.push(blocks.pop().unwrap()); + } + + (shifted_blocks, block_states) } /// Computes a prefix sum/scan in parallel using Hillis & Steel algorithm @@ -581,8 +1256,8 @@ impl ServerKey { { debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 4)); - if blocks.is_empty() { - return vec![]; + if blocks.is_empty() || blocks.len() == 1 { + return blocks; } let num_blocks = blocks.len(); @@ -607,167 +1282,6 @@ impl ServerKey { blocks } - - /// This add_assign two numbers - /// - /// It is after the Blelloch algorithm to do - /// prefix sum / cumulative sum in parallel. - /// - /// It is not "work efficient" as in, it does not adds - /// that much work compared to other parallel algorithm, - /// thus requiring less threads. - /// - /// However it is slower. 
- /// - /// At most num_block / 2 threads are used - /// - /// # Requirements - /// - /// - The parameters have 4 bits in total - /// - Adding rhs to lhs must not consume more than one carry - /// - /// # Output - /// - /// - lhs will have its carries empty - pub(crate) fn unchecked_add_assign_parallelized_work_efficient(&self, lhs: &mut T, rhs: &T) - where - T: IntegerRadixCiphertext, - { - let degree_after_add_does_not_go_beyond_first_carry = lhs - .blocks() - .iter() - .zip(rhs.blocks().iter()) - .all(|(bl, br)| { - let degree_after_add = bl.degree.get() + br.degree.get(); - degree_after_add < (self.key.message_modulus.0 * 2) - }); - assert!(degree_after_add_does_not_go_beyond_first_carry); - debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 3)); - - self.unchecked_add_assign_parallelized(lhs, rhs); - let generates_or_propagates = self.generate_init_carry_array(lhs.blocks()); - let carry_out = - self.compute_carry_propagation_parallelized_work_efficient(generates_or_propagates); - - lhs.blocks_mut() - .par_iter_mut() - .zip(carry_out.par_iter()) - .for_each(|(block, carry_in)| { - self.key.unchecked_add_assign(block, carry_in); - self.key.message_extract_assign(block); - }); - } - - pub(crate) fn compute_carry_propagation_parallelized_work_efficient( - &self, - mut carry_out: Vec, - ) -> Vec { - debug_assert!(self.key.message_modulus.0 * self.key.carry_modulus.0 >= (1 << 3)); - - let num_blocks = carry_out.len(); - let num_steps = carry_out.len().ilog2() as usize; - - let lut_carry_propagation_sum = self - .key - .generate_lookup_table_bivariate(prefix_sum_carry_propagation); - - for i in 0..num_steps { - let two_pow_i_plus_1 = 2usize.checked_pow((i + 1) as u32).unwrap(); - let two_pow_i = 2usize.checked_pow(i as u32).unwrap(); - - carry_out - .par_chunks_exact_mut(two_pow_i_plus_1) - .for_each(|carry_out| { - let (last, head) = carry_out.split_last_mut().unwrap(); - let current_block = last; - let previous_block = &head[two_pow_i - 1]; - - self.key.unchecked_apply_lookup_table_bivariate_assign( - current_block, - previous_block, - &lut_carry_propagation_sum, - ); - }); - } - - // Down-Sweep phase - let mut buffer = Vec::with_capacity(num_blocks / 2); - self.key - .create_trivial_assign(&mut carry_out[num_blocks - 1], 0); - for i in (0..num_steps).rev() { - let two_pow_i_plus_1 = 2usize.checked_pow((i + 1) as u32).unwrap(); - let two_pow_i = 2usize.checked_pow(i as u32).unwrap(); - - (0..num_blocks) - .into_par_iter() - .step_by(two_pow_i_plus_1) - .map(|k| { - // Since our carry_propagation LUT ie sum function - // is not commutative we have to reverse operands - self.key.unchecked_apply_lookup_table_bivariate( - &carry_out[k + two_pow_i - 1], - &carry_out[k + two_pow_i_plus_1 - 1], - &lut_carry_propagation_sum, - ) - }) - .collect_into_vec(&mut buffer); - - let mut drainer = buffer.drain(..); - for k in (0..num_blocks).step_by(two_pow_i_plus_1) { - let b = drainer.next().unwrap(); - carry_out.swap(k + two_pow_i - 1, k + two_pow_i_plus_1 - 1); - carry_out[k + two_pow_i_plus_1 - 1] = b; - } - drop(drainer); - assert!(buffer.is_empty()); - } - - // The first step of the Down-Sweep phase sets the - // first block to 0, so no need to re-do it - carry_out - } - - pub(super) fn generate_init_carry_array(&self, sum_blocks: &[Ciphertext]) -> Vec { - let modulus = self.key.message_modulus.0 as u64; - - // This is used for the first pair of blocks - // as this pair can either generate or not, but never propagate - let lut_does_block_generate_carry = 
self.key.generate_lookup_table(|x| { - if x >= modulus { - OutputCarry::Generated as u64 - } else { - OutputCarry::None as u64 - } - }); - - let lut_does_block_generate_or_propagate = self.key.generate_lookup_table(|x| { - if x >= modulus { - OutputCarry::Generated as u64 - } else if x == (modulus - 1) { - OutputCarry::Propagated as u64 - } else { - OutputCarry::None as u64 - } - }); - - let mut generates_or_propagates = Vec::with_capacity(sum_blocks.len()); - sum_blocks - .par_iter() - .enumerate() - .map(|(i, block)| { - if i == 0 { - // The first block can only output a carry - self.key - .apply_lookup_table(block, &lut_does_block_generate_carry) - } else { - self.key - .apply_lookup_table(block, &lut_does_block_generate_or_propagate) - } - }) - .collect_into_vec(&mut generates_or_propagates); - - generates_or_propagates - } } #[cfg(test)] @@ -781,11 +1295,8 @@ mod tests { // Parameters and num blocks do not matter here let (_, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, 4); - let carry = sks.propagate_single_carry_parallelized_low_latency([].as_mut_slice()); - + sks.propagate_single_carry_parallelized(&mut []); // The most interesting part we test is that the code does not panic - assert!(carry.is_trivial()); - assert_eq!(carry.decrypt_trivial().unwrap(), 0u64); } #[test] diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index f1f354b32e..9e616f0819 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -33,6 +33,7 @@ mod vector_find; use super::ServerKey; use crate::integer::ciphertext::IntegerRadixCiphertext; +pub(crate) use add::ComputationFlags; use rayon::prelude::*; pub use scalar_div_mod::{MiniUnsignedInteger, Reciprocable}; pub use vector_find::MatchValues; @@ -129,27 +130,20 @@ impl ServerKey { .max_by(|block_a, block_b| block_a.degree.get().cmp(&block_b.degree.get())) .map(|block| block.degree.get()) .unwrap(); // We checked for emptiness earlier - if highest_degree <= (self.key.message_modulus.0 - 1) * 2 { - let _ = self.propagate_single_carry_parallelized_low_latency( - &mut ctxt.blocks_mut()[start_index..], - ); - } else { + + if highest_degree >= (self.key.message_modulus.0 - 1) * 2 { // At least one of the blocks has more than one carry, // we need to extract message and carries, then add + propagate let (mut message_blocks, carry_blocks) = extract_message_and_carry_blocks(&ctxt.blocks()[start_index..]); - ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks); - for (block, carry) in ctxt.blocks_mut()[start_index + 1..] 
- .iter_mut() - .zip(carry_blocks.iter()) - { - self.key.unchecked_add_assign(block, carry); - } - // We can start propagation one index later as we already did the first block - let _ = self.propagate_single_carry_parallelized_low_latency( - &mut ctxt.blocks_mut()[start_index + 1..], - ); + ctxt.blocks_mut()[start_index] = message_blocks.remove(0); + let mut lhs = T::from(message_blocks); + let rhs = T::from(carry_blocks); + self.add_assign_with_carry(&mut lhs, &rhs, None); + ctxt.blocks_mut()[start_index + 1..].clone_from_slice(lhs.blocks()); + } else { + self.propagate_single_carry_parallelized(&mut ctxt.blocks_mut()[start_index..]); } } else { let maybe_highest_degree = ctxt diff --git a/tfhe/src/integer/server_key/radix_parallel/neg.rs b/tfhe/src/integer/server_key/radix_parallel/neg.rs index f31e7f1948..d4ce9bbcca 100644 --- a/tfhe/src/integer/server_key/radix_parallel/neg.rs +++ b/tfhe/src/integer/server_key/radix_parallel/neg.rs @@ -88,14 +88,8 @@ impl ServerKey { &tmp_ctxt }; - if self.is_eligible_for_parallel_single_carry_propagation(ct) { - let mut ct = self.unchecked_neg(ct); - let _carry = self.propagate_single_carry_parallelized_low_latency(ct.blocks_mut()); - ct - } else { - let mut ct = self.unchecked_neg(ct); - self.full_propagate_parallelized(&mut ct); - ct - } + let mut ct = self.unchecked_neg(ct); + self.full_propagate_parallelized(&mut ct); + ct } } diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs index c900994b54..151e22e401 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs @@ -16,14 +16,28 @@ impl ServerKey { self.full_propagate_parallelized(lhs); } - self.unchecked_scalar_add_assign(lhs, scalar); - let overflowed = self.unsigned_overflowing_propagate_addition_carry(lhs); + let bits_in_message = self.key.message_modulus.0.ilog2(); + let mut scalar_blocks = BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message) + .iter_as::() + .map(|v| self.key.create_trivial(u64::from(v))) + .collect::>(); - let num_scalar_block = - BlockDecomposer::with_early_stop_at_zero(scalar, self.key.message_modulus.0.ilog2()) - .count(); + let trivially_overflowed = match scalar_blocks.len().cmp(&lhs.blocks.len()) { + std::cmp::Ordering::Less => { + scalar_blocks.resize_with(lhs.blocks.len(), || self.key.create_trivial(0)); + false + } + std::cmp::Ordering::Equal => false, + std::cmp::Ordering::Greater => { + scalar_blocks.truncate(lhs.blocks.len()); + true + } + }; + + let rhs = RadixCiphertext::from(scalar_blocks); + let overflowed = self.overflowing_add_assign_with_carry(lhs, &rhs, None); - if num_scalar_block > lhs.blocks.len() { + if trivially_overflowed { // Scalar has more blocks so addition counts as overflowing BooleanBlock::new_unchecked(self.key.create_trivial(1)) } else { @@ -257,7 +271,7 @@ impl ServerKey { if self.is_eligible_for_parallel_single_carry_propagation(ct) { self.unchecked_scalar_add_assign(ct, scalar); - let _carry = self.propagate_single_carry_parallelized_low_latency(ct.blocks_mut()); + self.propagate_single_carry_parallelized(ct.blocks_mut()) } else { self.unchecked_scalar_add_assign(ct, scalar); self.full_propagate_parallelized(ct); diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs index a0deda93ca..f68cffaee0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs +++ 
b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs @@ -105,13 +105,18 @@ impl ServerKey { self.full_propagate_parallelized(ct); }; - self.unchecked_scalar_sub_assign(ct, scalar); + let Some(decomposer) = self.create_negated_block_decomposer(scalar) else { + // subtraction by zero + return; + }; - if self.is_eligible_for_parallel_single_carry_propagation(ct) { - let _carry = self.propagate_single_carry_parallelized_low_latency(ct.blocks_mut()); - } else { - self.full_propagate_parallelized(ct); - } + let blocks = decomposer + .take(ct.blocks().len()) + .map(|v| self.key.create_trivial(u64::from(v))) + .collect::>(); + let rhs = T::from_blocks(blocks); + + self.add_assign_with_carry(ct, &rhs, None); } pub fn unsigned_overflowing_scalar_sub_assign_parallelized( diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs index 83809612d2..c807e1b2b6 100644 --- a/tfhe/src/integer/server_key/radix_parallel/sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs @@ -1,8 +1,5 @@ -use super::add::OutputCarry; use crate::integer::ciphertext::IntegerRadixCiphertext; -use crate::integer::{ - BooleanBlock, IntegerCiphertext, RadixCiphertext, ServerKey, SignedRadixCiphertext, -}; +use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; use crate::shortint::ciphertext::Degree; use crate::shortint::Ciphertext; use rayon::prelude::*; @@ -20,13 +17,6 @@ enum BorrowGeneration { Propagated = 2, } -// see [ServerKey::generate_last_block_inner_propagation] -#[derive(Copy, Clone, PartialEq, Eq)] -pub(crate) enum SignedOperation { - Addition, - Subtraction, -} - impl ServerKey { /// Computes homomorphically the subtraction between ct_left and ct_right. /// @@ -235,56 +225,8 @@ impl ServerKey { } }; - if self.is_eligible_for_parallel_single_carry_propagation(lhs) { - let neg = self.unchecked_neg(rhs); - let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, &neg); - } else { - self.unchecked_sub_assign(lhs, rhs); - self.full_propagate_parallelized(lhs); - } - } - - pub fn sub_parallelized_work_efficient(&self, ctxt_left: &T, ctxt_right: &T) -> T - where - T: IntegerRadixCiphertext, - { - let mut ct_res = ctxt_left.clone(); - self.sub_assign_parallelized_work_efficient(&mut ct_res, ctxt_right); - ct_res - } - - pub fn sub_assign_parallelized_work_efficient(&self, ctxt_left: &mut T, ctxt_right: &T) - where - T: IntegerRadixCiphertext, - { - let mut tmp_rhs; - - let (lhs, rhs) = match ( - ctxt_left.block_carries_are_empty(), - ctxt_right.block_carries_are_empty(), - ) { - (true, true) => (ctxt_left, ctxt_right), - (true, false) => { - tmp_rhs = ctxt_right.clone(); - self.full_propagate_parallelized(&mut tmp_rhs); - (ctxt_left, &tmp_rhs) - } - (false, true) => { - self.full_propagate_parallelized(ctxt_left); - (ctxt_left, ctxt_right) - } - (false, false) => { - tmp_rhs = ctxt_right.clone(); - rayon::join( - || self.full_propagate_parallelized(ctxt_left), - || self.full_propagate_parallelized(&mut tmp_rhs), - ); - (ctxt_left, &tmp_rhs) - } - }; - let neg = self.unchecked_neg(rhs); - self.unchecked_add_assign_parallelized_work_efficient(lhs, &neg); + self.add_assign_with_carry(lhs, &neg, None); } /// Computes the subtraction and returns an indicator of overflow @@ -351,7 +293,7 @@ impl ServerKey { } }; - self.unchecked_unsigned_overflowing_sub_parallelized(lhs, rhs) + self.unchecked_unsigned_overflowing_sub(lhs, rhs) } pub fn unchecked_unsigned_overflowing_sub_parallelized( @@ -369,7 +311,7 @@ impl 
ServerKey { ); // Here we have to use manual unchecked_sub on shortint blocks // rather than calling integer's unchecked_sub as we need each subtraction - // to be independent from other blocks. And we don't want to do subtraction by + // to be independent of other blocks. And we don't want to do subtraction by // adding negation let ct = lhs .blocks @@ -454,217 +396,6 @@ impl ServerKey { } } - // This is used in signed overflow detection - // see [unchecked_signed_overflowing_sub_parallelized] for more context - // - // This is to share the logic between the fully parallelized and - // semi parallelized algorithms. - // - // - last_lhs_block: last block of the lhs used in signed subtraction - // - last_rhs_block: last block the rhs used in signed subtraction - // - // Returns a block to be used as one of the inputs of [resolve_signed_overflow] - pub(crate) fn generate_last_block_inner_propagation( - &self, - last_lhs_block: &Ciphertext, - last_rhs_block: &Ciphertext, - op: SignedOperation, - ) -> Ciphertext { - let bits_of_message = self.key.message_modulus.0.ilog2(); - let message_bit_mask = (1 << bits_of_message) - 1; - - // This lut will generate a block that contains the information - // of how carry propagation happens in the last block, until the last bit. - let last_block_inner_propagation_lut = - self.key - .generate_lookup_table_bivariate(|lhs_block, rhs_block| { - let rhs_block = if op == SignedOperation::Subtraction { - // subtraction is done by doing addition of negation - // negation(x) = bit_flip(x) + 1 - // We only add the flipped value, the + 1 will be resolved by - // carry propagation computation - let flipped_rhs = !rhs_block; - - // We remove the last bit, its not interesting in this step - (flipped_rhs << 1) & message_bit_mask - } else { - (rhs_block << 1) & message_bit_mask - }; - - let lhs_block = (lhs_block << 1) & message_bit_mask; - - // whole_result contains the result of addition with - // the carry being in the first bit of carry space - // the message space contains the message, but with one 0 - // on the right (lsb) - let whole_result = lhs_block + rhs_block; - let carry = whole_result >> bits_of_message; - let result = (whole_result & message_bit_mask) >> 1; - let propagation_result = if carry == 1 { - // Addition of bits before last one generates a carry - OutputCarry::Generated - } else if result == ((self.key.message_modulus.0 as u64 - 1) >> 1) { - // Addition of bits before last one puts the bits - // in a state that makes it so that an input carry into last block - // gets propagated to last bit. 
- OutputCarry::Propagated - } else { - OutputCarry::None - }; - - // Shift the propagation result in carry part - // to have less noise growth later - (propagation_result as u64) << bits_of_message - }); - self.key.unchecked_apply_lookup_table_bivariate( - last_lhs_block, - last_rhs_block, - &last_block_inner_propagation_lut, - ) - } - - // - last_block_inner_propagation must be the result of generate_last_block_inner_propagation - // - last_block_input_carry: carry that the last pair of blocks (lhs, rhs) receives as input - // - last_block_output_carry: carry that the last pair of blocks (lhs, rhs) output - // - // Returns whether the subtraction overflowed - // - // See [unchecked_signed_overflowing_sub_parallelized] for more context - pub(crate) fn resolve_signed_overflow( - &self, - mut last_block_inner_propagation: Ciphertext, - last_block_input_carry: &BooleanBlock, - last_block_output_carry: &BooleanBlock, - ) -> BooleanBlock { - let bits_of_message = self.key.message_modulus.0.ilog2(); - - let resolve_overflow_lut = self.key.generate_lookup_table(|x| { - let carry_propagation = x >> bits_of_message; - let output_carry_of_block = (x >> 1) & 1; - let input_carry_of_block = x & 1; - - // Resolve the carry that the last bit actually receives as input - let input_carry_to_last_bit = if carry_propagation == OutputCarry::Propagated as u64 { - input_carry_of_block - } else if carry_propagation == OutputCarry::Generated as u64 { - 1 - } else { - 0 - }; - - u64::from(input_carry_to_last_bit != output_carry_of_block) - }); - - let x = self - .key - .unchecked_scalar_mul(last_block_output_carry.as_ref(), 2); - self.key - .unchecked_add_assign(&mut last_block_inner_propagation, &x); - self.key.unchecked_add_assign( - &mut last_block_inner_propagation, - last_block_input_carry.as_ref(), - ); - let result = self - .key - .apply_lookup_table(&last_block_inner_propagation, &resolve_overflow_lut); - BooleanBlock::new_unchecked(result) - } - - // This is the implementation of overflowing add/sub when we can use parallel carry - // propagation, as only a few things change between the two. - pub(crate) fn unchecked_signed_overflowing_add_or_sub_parallelized_impl( - &self, - lhs: &SignedRadixCiphertext, - rhs: &SignedRadixCiphertext, - signed_operation: SignedOperation, - ) -> (SignedRadixCiphertext, BooleanBlock) { - // This assert is here because this overflow computation requires these preconditions - // which is_eligible_for_parallel_single_carry_propagation, but it could change in the - // future - assert!(self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= 4); - - // In Two's complement arithmetic, overflow occurs when the output carry of the - // last bit is not the same as the input carry of the last bit. - // - // Here we have blocks, and we cannot just compare input and output carries of the last - // block as its not equivalent to checking what happens on the last bit. - // So we have to resolve that carry propagation that happens in the last block. - // - // So the carry propagation is done in 2 steps, first we compute the carry propagation - // in the last block to be able at the second step, to know the actual carry that - // the last bit receives. - // - // These are done in parallel to other stuff, and so no additional 'latency cost' - // should occur. 
- - let mut result = lhs.clone(); - - // Using parallel algorithms for unchecked_add/sub does not seem to bring - // measurable improvements - if signed_operation == SignedOperation::Subtraction { - self.unchecked_sub_assign(&mut result, rhs); - } else { - self.unchecked_add_assign(&mut result, rhs); - } - - let ((input_carries, output_carry), last_block_inner_propagation) = rayon::join( - || { - let generates_or_propagates = self.generate_init_carry_array(result.blocks()); - self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates) - }, - || { - self.generate_last_block_inner_propagation( - lhs.blocks.last().as_ref().unwrap(), - rhs.blocks.last().as_ref().unwrap(), - signed_operation, - ) - }, - ); - - let (_, overflowed) = rayon::join( - || { - result - .blocks - .par_iter_mut() - .zip(input_carries.par_iter()) - .for_each(|(block, input_carry)| { - self.key.unchecked_add_assign(block, input_carry); - self.key.message_extract_assign(block); - }); - }, - || { - let input_carry = input_carries - .last() - .cloned() - .map(BooleanBlock::new_unchecked) - .unwrap(); - let output_carry = BooleanBlock::new_unchecked(output_carry); - self.resolve_signed_overflow( - last_block_inner_propagation, - &input_carry, - &output_carry, - ) - }, - ); - - (result, overflowed) - } - - // It is in its own function so that it can be tested, as the main entry point - // unchecked_signed_overflowing_sub may select non parallel version if lhs - // does not have enough block. - pub(crate) fn unchecked_signed_overflowing_sub_parallelized_impl( - &self, - lhs: &SignedRadixCiphertext, - rhs: &SignedRadixCiphertext, - ) -> (SignedRadixCiphertext, BooleanBlock) { - self.unchecked_signed_overflowing_add_or_sub_parallelized_impl( - lhs, - rhs, - SignedOperation::Subtraction, - ) - } - pub fn unchecked_signed_overflowing_sub_parallelized( &self, lhs: &SignedRadixCiphertext, @@ -679,11 +410,12 @@ impl ServerKey { rhs.blocks.len() ); - if self.is_eligible_for_parallel_single_carry_propagation(lhs) { - self.unchecked_signed_overflowing_sub_parallelized_impl(lhs, rhs) - } else { - self.unchecked_signed_overflowing_sub(lhs, rhs) - } + let flipped_rhs = self.bitnot(rhs); + let input_carry = self.create_trivial_boolean_block(true); + let mut result = lhs.clone(); + let overflowed = + self.overflowing_add_assign_with_carry(&mut result, &flipped_rhs, Some(&input_carry)); + (result, overflowed) } pub(super) fn generate_init_borrow_array(&self, sum_ct: &RadixCiphertext) -> Vec<Ciphertext> {
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 3d31198a09..9847cb44aa 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -1881,6 +1881,9 @@ where clear = (clear_0 + clear_1) % modulus; + let dec_res: u64 = cks.decrypt(&ct_res); + assert_eq!(clear, dec_res); + // Add multiple times to raise the degree for _ in 0..nb_tests_smaller { let tmp = executor.execute((&ct_res, clear_1)); @@ -2346,7 +2349,11 @@ where let ct_res = executor.execute((&ct, scalar)); let dec_res: u128 = cks.decrypt(&ct_res); - assert_eq!(clear.wrapping_mul(scalar as u128), dec_res); + assert_eq!( + clear.wrapping_mul(scalar as u128), + dec_res, + "Invalid result {clear} * {scalar}" + ); } pub(crate) fn default_scalar_bitand_test<P, T>(param: P, mut executor: T)
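The rewritten unchecked_signed_overflowing_sub_parallelized above leans on two classical facts: lhs - rhs can be computed as lhs + !rhs with an input carry of 1, and, in two's complement, the subtraction overflows exactly when the carry into the sign bit differs from the carry out of it (the rule the removed comments spell out). A plain-integer sketch of that reasoning, using a hypothetical helper on i8 rather than any tfhe-rs API, could look like this:

// Illustration only: signed subtraction done as lhs + !rhs + 1, with overflow detected
// by comparing the carry into and out of the most significant (sign) bit.
fn overflowing_sub_via_add_not(lhs: i8, rhs: i8) -> (i8, bool) {
    let a = lhs as u8;
    let b = !(rhs as u8); // bitwise NOT of rhs
    let mut carry = 1u8; // the "+ 1" of two's complement negation, injected as the input carry

    let mut result = 0u8;
    let mut carry_into_msb = 0u8;
    for i in 0..8 {
        let sum = ((a >> i) & 1) + ((b >> i) & 1) + carry;
        result |= (sum & 1) << i;
        if i == 7 {
            carry_into_msb = carry; // carry entering the sign bit
        }
        carry = sum >> 1;
    }
    let carry_out_of_msb = carry;
    // Overflow iff the carry entering the sign bit differs from the carry leaving it.
    (result as i8, carry_into_msb != carry_out_of_msb)
}

fn main() {
    for (l, r) in [(5i8, 3), (-128, 1), (100, -100), (-1, 127)] {
        assert_eq!(overflowing_sub_via_add_not(l, r), l.overflowing_sub(r));
    }
}

The homomorphic version does the same thing block by block: bitnot(rhs) plays the role of !rhs, the trivial `true` boolean block stands in for the "+ 1", and the overflow flag returned by overflowing_add_assign_with_carry encodes the same carry-in versus carry-out condition.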
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 341189592c..074c4971b6 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -813,9 +813,12 @@ fn integer_signed_default_scalar_div_rem(param: impl Into<PBSParameters>) { // Make the degree non-fresh let offset = random_non_zero_value(&mut rng, modulus); + println!("offset: {offset}"); sks.unchecked_scalar_add_assign(&mut ctxt_0, offset); clear_lhs = signed_add_under_modulus(clear_lhs, offset, modulus); assert!(!ctxt_0.block_carries_are_empty()); + let sanity_decryption: i64 = cks.decrypt_signed_radix(&ctxt_0); + assert_eq!(sanity_decryption, clear_lhs); let (q_res, r_res) = sks.signed_scalar_div_rem_parallelized(&ctxt_0, clear_rhs); let q: i64 = cks.decrypt_signed_radix(&q_res);
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs index 7de7ac574b..d44e9b1433 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs @@ -20,25 +20,6 @@ use std::sync::Arc; create_parametrized_test!(integer_signed_unchecked_sub); create_parametrized_test!(integer_signed_unchecked_overflowing_sub); -create_parametrized_test!( - integer_signed_unchecked_overflowing_sub_parallelized { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // Requires 4 bits, so 1_1 parameters are not supported - // until they get their own version of the algorithm - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS, - } - } -); create_parametrized_test!(integer_signed_default_sub); create_parametrized_test!(integer_signed_default_overflowing_sub); @@ -58,17 +39,6 @@ where signed_unchecked_overflowing_sub_test(param, executor); } -fn integer_signed_unchecked_overflowing_sub_parallelized<P>(param: P) -where - P: Into<PBSParameters>, -{ - // Call _impl so we are sure the parallel version is tested - // However this only supports param X_X where X >= 4 - let executor = - CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_sub_parallelized_impl); - signed_unchecked_overflowing_sub_test(param, executor); -} - fn integer_signed_default_sub<P>(param: P) where P: Into<PBSParameters>,
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs index 3f048b7844..f70269e3e6 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs @@ -19,24 +19,6 @@ create_parametrized_test!(integer_default_add); create_parametrized_test!(integer_default_overflowing_add); create_parametrized_test!(integer_unchecked_add); create_parametrized_test!(integer_unchecked_add_assign); -create_parametrized_test!( - integer_default_add_work_efficient { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // This algorithm requires 3 bits - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS, - } - } -); fn integer_unchecked_add<P>(param: P) where @@ -70,14 +52,6 @@ where default_add_test(param, executor); } -fn integer_default_add_work_efficient<P>(param: P) -where - P: Into<PBSParameters>, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized_work_efficient); - default_add_test(param, executor); -} - fn integer_default_overflowing_add<P>(param: P) where P: Into<PBSParameters>,
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs index 1a0f568a7e..0034d68a6a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_sub.rs @@ -18,24 +18,6 @@ use std::sync::Arc; create_parametrized_test!(integer_unchecked_sub); create_parametrized_test!(integer_smart_sub); create_parametrized_test!(integer_default_sub); -create_parametrized_test!( - integer_default_sub_work_efficient { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - }, - no_coverage => { - // This algorithm requires 3 bits - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS, - } - } -); create_parametrized_test!(integer_default_overflowing_sub); fn integer_unchecked_sub<P>(param: P) @@ -62,14 +44,6 @@ where default_sub_test(param, executor); } -fn integer_default_sub_work_efficient<P>(param: P) -where - P: Into<PBSParameters>, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized_work_efficient); - default_sub_test(param, executor); -} - fn integer_default_overflowing_sub<P>(param: P) where P: Into<PBSParameters>,
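The scalar_sub.rs change at the top of this diff follows the same philosophy: instead of a dedicated scalar-subtraction path, the scalar is wrapping-negated, decomposed into per-block digits, lifted to trivial ciphertexts, and pushed through the shared add-with-carry routine. On clear integers this is just x - s ≡ x + (2^n - s) (mod 2^n); the sketch below (plain Rust with hypothetical names, not the actual block-decomposer API) shows the digit-by-digit version of that identity:

// Illustration only: subtract a scalar by adding its wrapping negation, decomposed
// into radix blocks of `bits_per_block` bits, propagating carries block by block.
fn scalar_sub_via_negated_blocks(x: u64, s: u64, bits_per_block: u32, num_blocks: u32) -> u64 {
    let total_bits = bits_per_block * num_blocks;
    let mask = if total_bits >= 64 { u64::MAX } else { (1u64 << total_bits) - 1 };
    let block_mask = (1u64 << bits_per_block) - 1;

    // The "negated block decomposer": digits of -s mod 2^total_bits, low block first.
    let neg_s = s.wrapping_neg() & mask;

    let mut acc = 0u64;
    let mut carry = 0u64;
    for i in 0..num_blocks {
        let lhs_block = (x >> (i * bits_per_block)) & block_mask;
        let rhs_block = (neg_s >> (i * bits_per_block)) & block_mask;
        let sum = lhs_block + rhs_block + carry;
        acc |= (sum & block_mask) << (i * bits_per_block);
        carry = sum >> bits_per_block;
    }
    acc
}

fn main() {
    // Eight 2-bit blocks model a 16-bit radix integer (message_modulus = 4).
    for (x, s) in [(1234u64, 99u64), (0, 1), (65535, 65535), (40000, 50000)] {
        let expected = x.wrapping_sub(s) & 0xFFFF;
        assert_eq!(scalar_sub_via_negated_blocks(x, s, 2, 8), expected);
    }
}

With plain subtraction, scalar subtraction and signed overflowing subtraction all funnelled into the add-with-carry implementation, the subtraction variants now exercise the same carry-propagation code that the addition tests above already cover.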