From de7b9d135473980b093a9ea5db73a49e34eb4ddd Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Thu, 25 Jul 2024 15:34:55 +0200 Subject: [PATCH] chore(pr): comments 3 --- .../integer/server_key/radix_parallel/add.rs | 366 ++++++++++-------- .../radix_parallel/tests_unsigned/test_add.rs | 6 +- 2 files changed, 203 insertions(+), 169 deletions(-) diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index 5b9c81dd95..fcfdceede2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -541,10 +541,8 @@ impl ServerKey { }; } - let mut carry = input_carry.map_or_else( - || self.key.create_trivial(0), - |boolean_block| boolean_block.0.clone(), - ); + let carry = + input_carry.map_or_else(|| self.create_trivial_boolean_block(false), Clone::clone); // 2_2, 3_3, 4_4 // If we have at least 2 bits and at least as much carries @@ -560,194 +558,230 @@ impl ServerKey { // the overflow computation adds additional layer of PBS. if self.key.message_modulus.0 >= 4 && self.key.carry_modulus.0 >= self.key.message_modulus.0 { - let mut overflow_flag = if requested_flag == OutputFlag::Overflow { - let mut block = self.key.unchecked_scalar_mul( - lhs.last().as_ref().unwrap(), - self.message_modulus().0 as u8, - ); - self.key - .unchecked_add_assign(&mut block, rhs.last().as_ref().unwrap()); - Some(block) - } else { - None - }; + self.advanced_add_assign_sequential_at_least_4_bits( + requested_flag, + lhs, + rhs, + carry, + input_carry, + ) + } else if self.key.message_modulus.0 == 2 { + self.advanced_add_assign_sequential_at_least_2_bits(lhs, rhs, carry, requested_flag) + } else { + panic!( + "Invalid combo of message modulus ({}) and carry modulus ({}) \n\ + This function requires the message modulus >= 2 and carry modulus >= message_modulus \n\ + I.e. PARAM_MESSAGE_X_CARRY_Y where X >= 1 and Y >= X.", + self.key.message_modulus.0, self.key.carry_modulus.0 + ); + } + } - // Handle the first block - self.key.unchecked_add_assign(&mut lhs[0], &rhs[0]); - self.key.unchecked_add_assign(&mut lhs[0], &carry); + /// Computes lhs += (rhs + carry) using the sequential propagation of carries + /// + /// parameters of blocks must have 4 bits, parameters in the form X_Y where X >= 2 && Y >= X + fn advanced_add_assign_sequential_at_least_4_bits( + &self, + requested_flag: OutputFlag, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + carry: BooleanBlock, + input_carry: Option<&BooleanBlock>, + ) -> Option { + let mut carry = carry.0; - // To be able to use carry_extract_assign in it - carry.clone_from(&lhs[0]); - rayon::scope(|s| { - s.spawn(|_| { - self.key.message_extract_assign(&mut lhs[0]); - }); + let mut overflow_flag = if requested_flag == OutputFlag::Overflow { + let mut block = self + .key + .unchecked_scalar_mul(lhs.last().as_ref().unwrap(), self.message_modulus().0 as u8); + self.key + .unchecked_add_assign(&mut block, rhs.last().as_ref().unwrap()); + Some(block) + } else { + None + }; - s.spawn(|_| { - self.key.carry_extract_assign(&mut carry); - }); + // Handle the first block + self.key.unchecked_add_assign(&mut lhs[0], &rhs[0]); + self.key.unchecked_add_assign(&mut lhs[0], &carry); - if requested_flag == OutputFlag::Overflow { - s.spawn(|_| { - // Computing the overflow flag requires an extra step for the first block + // To be able to use carry_extract_assign in it + carry.clone_from(&lhs[0]); + rayon::scope(|s| { + s.spawn(|_| { + self.key.message_extract_assign(&mut lhs[0]); + }); - let overflow_flag = overflow_flag.as_mut().unwrap(); - let num_bits_in_message = self.message_modulus().0.ilog2() as u64; - let lut = self.key.generate_lookup_table(|lhs_rhs| { - let lhs = lhs_rhs / self.message_modulus().0 as u64; - let rhs = lhs_rhs % self.message_modulus().0 as u64; - overflow_flag_preparation_lut(lhs, rhs, num_bits_in_message) - }); - self.key.apply_lookup_table_assign(overflow_flag, &lut); - }); - } + s.spawn(|_| { + self.key.carry_extract_assign(&mut carry); }); - let num_blocks = lhs.len(); - - // We did the first block before, the last block is done after this if, - // so we need 3 blocks at least to enter this - if num_blocks >= 3 { - for (lhs_b, rhs_b) in lhs[1..num_blocks - 1] - .iter_mut() - .zip(rhs[1..num_blocks - 1].iter()) - { - self.key.unchecked_add_assign(lhs_b, rhs_b); - self.key.unchecked_add_assign(lhs_b, &carry); - - carry.clone_from(lhs_b); - rayon::join( - || self.key.message_extract_assign(lhs_b), - || self.key.carry_extract_assign(&mut carry), - ); - } + if requested_flag == OutputFlag::Overflow { + s.spawn(|_| { + // Computing the overflow flag requires an extra step for the first block + + let overflow_flag = overflow_flag.as_mut().unwrap(); + let num_bits_in_message = self.message_modulus().0.ilog2() as u64; + let lut = self.key.generate_lookup_table(|lhs_rhs| { + let lhs = lhs_rhs / self.message_modulus().0 as u64; + let rhs = lhs_rhs % self.message_modulus().0 as u64; + overflow_flag_preparation_lut(lhs, rhs, num_bits_in_message) + }); + self.key.apply_lookup_table_assign(overflow_flag, &lut); + }); } + }); - if num_blocks >= 2 { - // Handle the last block - self.key - .unchecked_add_assign(&mut lhs[num_blocks - 1], &rhs[num_blocks - 1]); - self.key - .unchecked_add_assign(&mut lhs[num_blocks - 1], &carry); - } + let num_blocks = lhs.len(); - if let Some(block) = overflow_flag.as_mut() { - if num_blocks == 1 && input_carry.is_some() { - self.key - .unchecked_add_assign(block, input_carry.map(|b| &b.0).unwrap()); - } else { - self.key.unchecked_add_assign(block, &carry); - } + // We did the first block before, the last block is done after this if, + // so we need 3 blocks at least to enter this + if num_blocks >= 3 { + for (lhs_b, rhs_b) in lhs[1..num_blocks - 1] + .iter_mut() + .zip(rhs[1..num_blocks - 1].iter()) + { + self.key.unchecked_add_assign(lhs_b, rhs_b); + self.key.unchecked_add_assign(lhs_b, &carry); + + carry.clone_from(lhs_b); + rayon::join( + || self.key.message_extract_assign(lhs_b), + || self.key.carry_extract_assign(&mut carry), + ); } + } - // To be able to use carry_extract_assign in it - carry.clone_from(&lhs[num_blocks - 1]); + if num_blocks >= 2 { + // Handle the last block + self.key + .unchecked_add_assign(&mut lhs[num_blocks - 1], &rhs[num_blocks - 1]); + self.key + .unchecked_add_assign(&mut lhs[num_blocks - 1], &carry); + } - // Note that here when num_blocks == 1 && requested_flag != Overflow nothing - // will actually be spawned. - rayon::scope(|s| { - if num_blocks >= 2 { - // These would already have been done when the first block was processed - s.spawn(|_| { - self.key.message_extract_assign(&mut lhs[num_blocks - 1]); - }); + if let Some(block) = overflow_flag.as_mut() { + if num_blocks == 1 && input_carry.is_some() { + self.key + .unchecked_add_assign(block, input_carry.map(|b| &b.0).unwrap()); + } else { + self.key.unchecked_add_assign(block, &carry); + } + } - s.spawn(|_| { - self.key.carry_extract_assign(&mut carry); - }); - } + // To be able to use carry_extract_assign in it + carry.clone_from(&lhs[num_blocks - 1]); - if requested_flag == OutputFlag::Overflow { - s.spawn(|_| { - let overflow_flag_block = overflow_flag.as_mut().unwrap(); - // Computing the overflow flag requires and extra step for the first block - let overflow_flag_lut = self.key.generate_lookup_table(|block| { - let input_carry = block & 1; - let does_overflow_if_carry_is_1 = (block >> 3) & 1; - let does_overflow_if_carry_is_0 = (block >> 2) & 1; - if input_carry == 1 { - does_overflow_if_carry_is_1 - } else { - does_overflow_if_carry_is_0 - } - }); + // Note that here when num_blocks == 1 && requested_flag != Overflow nothing + // will actually be spawned. + rayon::scope(|s| { + if num_blocks >= 2 { + // These would already have been done when the first block was processed + s.spawn(|_| { + self.key.message_extract_assign(&mut lhs[num_blocks - 1]); + }); - self.key - .apply_lookup_table_assign(overflow_flag_block, &overflow_flag_lut); + s.spawn(|_| { + self.key.carry_extract_assign(&mut carry); + }); + } + + if requested_flag == OutputFlag::Overflow { + s.spawn(|_| { + let overflow_flag_block = overflow_flag.as_mut().unwrap(); + // Computing the overflow flag requires and extra step for the first block + let overflow_flag_lut = self.key.generate_lookup_table(|block| { + let input_carry = block & 1; + let does_overflow_if_carry_is_1 = (block >> 3) & 1; + let does_overflow_if_carry_is_0 = (block >> 2) & 1; + if input_carry == 1 { + does_overflow_if_carry_is_1 + } else { + does_overflow_if_carry_is_0 + } }); - } - }); - match requested_flag { - OutputFlag::None => None, - OutputFlag::Overflow => { - assert!( - overflow_flag.is_some(), - "internal error, overflow_flag should exist" - ); - overflow_flag.map(BooleanBlock::new_unchecked) - } - OutputFlag::Carry => { - carry.degree = Degree::new(1); - Some(BooleanBlock::new_unchecked(carry)) - } + self.key + .apply_lookup_table_assign(overflow_flag_block, &overflow_flag_lut); + }); } - } else if self.key.message_modulus.0 == 2 { - // 1_X parameters - // - // Same idea as other algorithms, however since we have 1 bit per block - // we do not have to resolve any inner propagation but it adds one more - // sequential PBS when we are interested in the OverflowFlag - fn block_add_assign_returning_carry( - sks: &ServerKey, - lhs: &mut Ciphertext, - rhs: &Ciphertext, - carry: &Ciphertext, - ) -> Ciphertext { - sks.key.unchecked_add_assign(lhs, rhs); - sks.key.unchecked_add_assign(lhs, carry); - let (carry, message) = rayon::join( - || sks.key.carry_extract(lhs), - || sks.key.message_extract(lhs), - ); + }); - *lhs = message; - - carry + match requested_flag { + OutputFlag::None => None, + OutputFlag::Overflow => { + assert!( + overflow_flag.is_some(), + "internal error, overflow_flag should exist" + ); + overflow_flag.map(BooleanBlock::new_unchecked) } - let num_blocks = lhs.len(); - for (lhs_b, rhs_b) in lhs[..num_blocks - 1] - .iter_mut() - .zip(rhs[..num_blocks - 1].iter()) - { - carry = block_add_assign_returning_carry(self, lhs_b, rhs_b, &carry); + OutputFlag::Carry => { + carry.degree = Degree::new(1); + Some(BooleanBlock::new_unchecked(carry)) } + } + } - let mut output_carry = block_add_assign_returning_carry( - self, - &mut lhs[num_blocks - 1], - &rhs[num_blocks - 1], - &carry, + /// Computes lhs += (rhs + carry) using the sequential propagation of carries + /// + /// parameters of blocks must have 2 bits, parameters in the form X_Y where X >= 1 && Y >= X + // so 1_X parameters + // + // Same idea as other algorithms, however since we have 1 bit per block + // we do not have to resolve any inner propagation but it adds one more + // sequential PBS when we are interested in the OverflowFlag + fn advanced_add_assign_sequential_at_least_2_bits( + &self, + lhs: &mut [Ciphertext], + rhs: &[Ciphertext], + carry: BooleanBlock, + requested_flag: OutputFlag, + ) -> Option { + let mut carry = carry.0; + + fn block_add_assign_returning_carry( + sks: &ServerKey, + lhs: &mut Ciphertext, + rhs: &Ciphertext, + carry: &Ciphertext, + ) -> Ciphertext { + sks.key.unchecked_add_assign(lhs, rhs); + sks.key.unchecked_add_assign(lhs, carry); + let (carry, message) = rayon::join( + || sks.key.carry_extract(lhs), + || sks.key.message_extract(lhs), ); - match requested_flag { - OutputFlag::None => None, - OutputFlag::Overflow => { - let overflowed = self.key.not_equal(&output_carry, &carry); - Some(BooleanBlock::new_unchecked(overflowed)) - } - OutputFlag::Carry => { - output_carry.degree = Degree::new(1); - Some(BooleanBlock::new_unchecked(output_carry)) - } + *lhs = message; + + carry + } + let num_blocks = lhs.len(); + for (lhs_b, rhs_b) in lhs[..num_blocks - 1] + .iter_mut() + .zip(rhs[..num_blocks - 1].iter()) + { + carry = block_add_assign_returning_carry(self, lhs_b, rhs_b, &carry); + } + + let mut output_carry = block_add_assign_returning_carry( + self, + &mut lhs[num_blocks - 1], + &rhs[num_blocks - 1], + &carry, + ); + + match requested_flag { + OutputFlag::None => None, + OutputFlag::Overflow => { + let overflowed = self.key.not_equal(&output_carry, &carry); + Some(BooleanBlock::new_unchecked(overflowed)) + } + OutputFlag::Carry => { + output_carry.degree = Degree::new(1); + Some(BooleanBlock::new_unchecked(output_carry)) } - } else { - panic!( - "Invalid combo of message modulus ({}) and carry modulus ({}) \n\ - This function requires the message modulus >= 2 and carry modulus >= message_modulus \n\ - I.e. PARAM_MESSAGE_X_CARRY_Y where X >= 1 and Y >= X.", - self.key.message_modulus.0, self.key.carry_modulus.0 - ); } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs index f75a48dd41..f4b39e2494 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs @@ -339,8 +339,8 @@ where for num_blocks in 1..MAX_NB_CTXT { let modulus = unsigned_modulus(cks.parameters().message_modulus(), num_blocks as u32); - let clear_1 = rng.gen::() % modulus; let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; let ctxt_0 = cks.as_ref().encrypt_radix(clear_0, num_blocks); let ctxt_1 = cks.as_ref().encrypt_radix(clear_1, num_blocks); @@ -348,7 +348,7 @@ where let mut ct_res = executor.execute((&ctxt_0, &ctxt_1)); let tmp_ct = executor.execute((&ctxt_0, &ctxt_1)); - // panic_if_any_block_is_not_clean(&ct_res, &cks); + panic_if_any_block_is_not_clean(&ct_res, &cks); assert_eq!(ct_res, tmp_ct); clear = clear_0.wrapping_add(clear_1) % modulus; @@ -361,7 +361,7 @@ where for _ in 0..nb_tests_smaller { ct_res = executor.execute((&ct_res, &ctxt_0)); - // panic_if_any_block_is_not_clean(&ct_res, &cks); + panic_if_any_block_is_not_clean(&ct_res, &cks); let result = (clear + clear_0) % modulus;