From d39c35c62e797be3330635f1d21c78922599ac23 Mon Sep 17 00:00:00 2001
From: tmontaigu
Date: Tue, 20 Aug 2024 22:39:43 +0200
Subject: [PATCH] chore(integer): addition tests based on trivial inputs

This adds `overflowing_add` and `add` tests that run on trivial inputs.
As these are faster to run, they can be more extensive than tests on
true encryptions.

This also updates the advanced_add_assign function tests to include
overflow computation.

On a standard laptop with 1 test thread it takes ~7 minutes to run
these trivial tests.
---
 .../integer/server_key/radix_parallel/add.rs  |  10 +-
 .../radix_parallel/tests_signed/mod.rs        |  19 +-
 .../radix_parallel/tests_signed/test_add.rs   | 352 ++++++++++++++----
 .../radix_parallel/tests_unsigned/mod.rs      |  19 +-
 .../radix_parallel/tests_unsigned/test_add.rs | 246 +++++++++++-
 5 files changed, 537 insertions(+), 109 deletions(-)

diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs
index 22db4a240f..5c4e4cdd14 100644
--- a/tfhe/src/integer/server_key/radix_parallel/add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/add.rs
@@ -12,7 +12,7 @@ pub(crate) enum OutputFlag {
     /// Request no flag at all
     None,
     /// The overflow flag is the flag that tells whether the input carry bit onto the last bit
-    /// is different than the output bit.
+    /// is different from the output bit.
     ///
     /// This is useful to know if a signed addition overflowed (in 2's complement)
     Overflow,
@@ -666,18 +666,18 @@ impl ServerKey {
             if num_blocks == 1 && input_carry.is_some() {
                 self.key
                     .unchecked_add_assign(block, input_carry.map(|b| &b.0).unwrap());
-            } else {
+            } else if num_blocks > 1 {
                 self.key.unchecked_add_assign(block, &carry);
             }
         }
 
-        // To be able to use carry_extract_assign in it
-        carry.clone_from(&lhs[num_blocks - 1]);
-
         // Note that here when num_blocks == 1 && requested_flag != Overflow nothing
         // will actually be spawned.
         rayon::scope(|s| {
             if num_blocks >= 2 {
+                // To be able to use carry_extract_assign in it
+                carry.clone_from(&lhs[num_blocks - 1]);
+
                 // These would already have been done when the first block was processed
                 s.spawn(|_| {
                     self.key.message_extract_assign(&mut lhs[num_blocks - 1]);
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs
index 64270f3462..16e356ba5b 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs
@@ -18,6 +18,7 @@ pub(crate) mod test_shift;
 pub(crate) mod test_sub;
 pub(crate) mod test_vector_comparisons;
 
+use crate::core_crypto::prelude::SignedInteger;
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
 use crate::integer::server_key::radix_parallel::tests_unsigned::{
@@ -807,7 +808,7 @@ fn integer_signed_default_scalar_div_rem(param: impl Into<PBSParameters>) {
 // Helper functions
 //================================================================================
 
-pub(crate) fn signed_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64 {
+pub(crate) fn signed_add_under_modulus<T: SignedInteger>(lhs: T, rhs: T, modulus: T) -> T {
     signed_overflowing_add_under_modulus(lhs, rhs, modulus).0
 }
 
@@ -816,12 +817,12 @@ pub(crate) fn signed_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> i64
 // This is to 'simulate' i8, i16, ixy using i64 integers
 //
 // lhs and rhs must be in [-modulus..modulus[
-pub(crate) fn signed_overflowing_add_under_modulus(
-    lhs: i64,
-    rhs: i64,
-    modulus: i64,
-) -> (i64, bool) {
-    assert!(modulus > 0);
+pub(crate) fn signed_overflowing_add_under_modulus<T: SignedInteger>(
+    lhs: T,
+    rhs: T,
+    modulus: T,
+) -> (T, bool) {
+    assert!(modulus > T::ZERO);
     assert!((-modulus..modulus).contains(&lhs));
 
     // The code below requires rhs and lhs to be in range -modulus..modulus
@@ -831,14 +832,14 @@ pub(crate) fn signed_overflowing_add_under_modulus(
         (lhs + rhs, false)
     } else {
         // 2*modulus to get all the bits
-        (lhs + (rhs % (2 * modulus)), true)
+        (lhs + (rhs % (T::TWO * modulus)), true)
     };
 
     if res < -modulus {
         // rem_euclid(modulus) would also work
         res = modulus + (res - -modulus);
         overflowed = true;
-    } else if res > modulus - 1 {
+    } else if res > modulus - T::ONE {
         res = -modulus + (res - modulus);
         overflowed = true;
     }
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
index faace022dd..73d3488a91 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_add.rs
@@ -6,8 +6,9 @@ use crate::integer::server_key::radix_parallel::tests_signed::{
 };
 use crate::integer::server_key::radix_parallel::tests_unsigned::{
     nb_tests_for_params, nb_tests_smaller_for_params, nb_unchecked_tests_for_params,
-    CpuFunctionExecutor,
+    CpuFunctionExecutor, MAX_NB_CTXT,
 };
+use crate::integer::server_key::radix_parallel::OutputFlag;
 use crate::integer::tests::create_parametrized_test;
 use crate::integer::{
     BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext,
@@ -41,7 +42,31 @@ create_parametrized_test!(
 );
 create_parametrized_test!(integer_signed_smart_add);
 create_parametrized_test!(integer_signed_default_add);
+create_parametrized_test!(integer_extensive_trivial_signed_default_add);
 create_parametrized_test!(integer_signed_default_overflowing_add);
+create_parametrized_test!(integer_extensive_trivial_signed_overflowing_add);
+create_parametrized_test!(
+    integer_extensive_trivial_signed_advanced_overflowing_add_assign_with_carry_sequential
+);
+create_parametrized_test!(
+    integer_extensive_trivial_signed_overflowing_advanced_add_assign_with_carry_at_least_4_bits {
+        coverage => {
+            COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+            COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
+        },
+        no_coverage => {
+            // Requires 4 bits, so 1_1 parameters are not supported
+            // until they get their own version of the algorithm
+            PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+            PARAM_MESSAGE_3_CARRY_3_KS_PBS,
+            PARAM_MESSAGE_4_CARRY_4_KS_PBS,
+            PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
+            PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
+            PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
+            PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS,
+        }
+    }
+);
 
 fn integer_signed_unchecked_add<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
     let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_add);
     signed_unchecked_add_test(param, executor);
 }
@@ -76,6 +101,14 @@ where
     signed_default_overflowing_add_test(param, executor);
 }
 
+fn integer_extensive_trivial_signed_overflowing_add<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
+    extensive_trivial_signed_default_overflowing_add_test(param, executor);
+}
+
 fn integer_signed_default_add<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
@@ -84,6 +117,58 @@ where
     let executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
     signed_default_add_test(param, executor);
 }
 
+fn integer_extensive_trivial_signed_default_add<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
+    extensive_trivial_signed_default_add_test(param, executor);
+}
+
+fn integer_extensive_trivial_signed_advanced_overflowing_add_assign_with_carry_sequential<P>(
+    param: P,
+) where
+    P: Into<PBSParameters>,
+{
+    let func = |sks: &ServerKey, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext| {
+        let mut result = lhs.clone();
+        let overflowed = sks
+            .advanced_add_assign_with_carry_sequential_parallelized(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                OutputFlag::Overflow,
+            )
+            .unwrap();
+        (result, overflowed)
+    };
+    let executor = CpuFunctionExecutor::new(&func);
+    extensive_trivial_signed_default_overflowing_add_test(param, executor);
+}
+
+fn integer_extensive_trivial_signed_overflowing_advanced_add_assign_with_carry_at_least_4_bits<P>(
+    param: P,
+) where
+    P: Into<PBSParameters>,
+{
+    // We explicitly call the 4 bit function to make sure it's being tested,
+    // no matter the number of blocks / threads available
+    let func = |sks: &ServerKey, lhs: &SignedRadixCiphertext, rhs: &SignedRadixCiphertext| {
+        let mut result = lhs.clone();
+        let overflowed = sks
+            .advanced_add_assign_with_carry_at_least_4_bits(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                OutputFlag::Overflow,
+            )
+            .unwrap();
+        (result, overflowed)
+    };
+    let executor = CpuFunctionExecutor::new(&func);
+    extensive_trivial_signed_default_overflowing_add_test(param, executor);
+}
+
 fn integer_signed_smart_add<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
@@ -226,83 +311,88 @@ where
 
     let mut rng = rand::thread_rng();
 
-    // message_modulus^vec_length
-    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
-
     executor.setup(&cks, sks.clone());
 
-    for _ in 0..nb_tests_smaller {
-        let clear_0 = rng.gen::<i64>() % modulus;
-        let clear_1 = rng.gen::<i64>() % modulus;
-
-        let ctxt_0 = cks.encrypt_signed(clear_0);
-        let ctxt_1 = cks.encrypt_signed(clear_1);
-
-        let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
-        let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
-        assert!(ct_res.block_carries_are_empty());
-        assert_eq!(ct_res, tmp_ct, "Failed determinism check");
-        assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
-
-        let (expected_result, expected_overflowed) =
-            signed_overflowing_add_under_modulus(clear_0, clear_1, modulus);
-
-        let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
-        let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
-        assert_eq!(
-            decrypted_result, expected_result,
-            "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
-            expected {expected_result}, got {decrypted_result}"
-        );
-        assert_eq!(
-            decrypted_overflowed,
-            expected_overflowed,
-            "Invalid overflow flag result for overflowing_suv for ({clear_0} + {clear_1}) % {modulus} \
-            expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
-        );
-        assert_eq!(result_overflowed.0.degree.get(), 1);
-        assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
-
-        for _ in 0..nb_tests_smaller {
-            // Add non zero scalar to have non clean ciphertexts
-            let clear_2 = random_non_zero_value(&mut rng, modulus);
-            let clear_3 = random_non_zero_value(&mut rng, modulus);
-
-            let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
-            let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);
-
-            let clear_lhs = signed_add_under_modulus(clear_0, clear_2, modulus);
-            let clear_rhs = signed_add_under_modulus(clear_1, clear_3, modulus);
-
-            let d0: i64 = cks.decrypt_signed(&ctxt_0);
-            assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
-            let d1: i64 = cks.decrypt_signed(&ctxt_1);
-            assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
-
-            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
-            assert!(ct_res.block_carries_are_empty());
-
-            let (expected_result, expected_overflowed) =
-                signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);
-
-            let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
-            let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
-            assert_eq!(
-                decrypted_result, expected_result,
-                "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
-                expected {expected_result}, got {decrypted_result}"
-            );
-            assert_eq!(
-                decrypted_overflowed,
-                expected_overflowed,
-                "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
-                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
-            );
-            assert_eq!(result_overflowed.0.degree.get(), 1);
-            assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
-        }
-    }
+    for num_blocks in 1..MAX_NB_CTXT {
+        let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
+        if modulus == 1 {
+            // Basically we only have one bit (the sign bit), so there is nothing to really test
+            continue;
+        }
+
+        for _ in 0..nb_tests_smaller {
+            let clear_0 = rng.gen::<i64>() % modulus;
+            let clear_1 = rng.gen::<i64>() % modulus;
+
+            let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
+            let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);
+
+            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
+            let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
+            assert!(ct_res.block_carries_are_empty());
+            assert_eq!(ct_res, tmp_ct, "Failed determinism check");
+            assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
+
+            let (expected_result, expected_overflowed) =
+                signed_overflowing_add_under_modulus(clear_0, clear_1, modulus);
+
+            let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
+            let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
+            assert_eq!(
+                decrypted_result, expected_result,
+                "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
+                expected {expected_result}, got {decrypted_result}"
+            );
+            assert_eq!(
+                decrypted_overflowed,
+                expected_overflowed,
+                "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \
+                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+            );
+            assert_eq!(result_overflowed.0.degree.get(), 1);
+            assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
+
+            for _ in 0..nb_tests_smaller {
+                // Add non zero scalar to have non clean ciphertexts
+                let clear_2 = random_non_zero_value(&mut rng, modulus);
+                let clear_3 = random_non_zero_value(&mut rng, modulus);
+
+                let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
+                let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);
+
+                let clear_lhs = signed_add_under_modulus(clear_0, clear_2, modulus);
+                let clear_rhs = signed_add_under_modulus(clear_1, clear_3, modulus);
+
+                let d0: i64 = cks.decrypt_signed(&ctxt_0);
+                assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
+                let d1: i64 = cks.decrypt_signed(&ctxt_1);
+                assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
+
+                let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
+                assert!(ct_res.block_carries_are_empty());
+
+                let (expected_result, expected_overflowed) =
+                    signed_overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);
+
+                let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
+                let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
+                assert_eq!(
+                    decrypted_result, expected_result,
+                    "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
+                    expected {expected_result}, got {decrypted_result}"
+                );
+                assert_eq!(
+                    decrypted_overflowed,
+                    expected_overflowed,
+                    "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
+                    expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
+                );
+                assert_eq!(result_overflowed.0.degree.get(), 1);
+                assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
+            }
+        }
+    }
 
-
+    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
     // Test with trivial inputs, as it was bugged at some point
     for _ in 0..4 {
         // Reduce maximum value of random number such that at least the last block is a trivial 0
@@ -337,6 +427,64 @@ where
     }
 }
 
+/// Although this uses the executor pattern and could be plugged into other backends,
+/// it is not recommended to do so unless the backend is extremely fast on trivial
+/// ciphertexts (or extremely fast in general), or if it is plugged in just as a one-time thing.
+pub(crate) fn extensive_trivial_signed_default_overflowing_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        (SignedRadixCiphertext, BooleanBlock),
+    >,
+{
+    let param = param.into();
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let message_modulus = cks.parameters().message_modulus();
+    let block_num_bits = message_modulus.0.ilog2();
+    // Contrary to regular add, we step bit_size by the number of bits per block,
+    // otherwise the bit size actually encrypted is not exactly the same,
+    // leading to false test overflow results.
+    for bit_size in (2..=64u32).step_by(block_num_bits as usize) {
+        let num_blocks = bit_size.div_ceil(block_num_bits);
+        let modulus = (cks.parameters().message_modulus().0 as i128).pow(num_blocks) / 2;
+
+        for _ in 0..50 {
+            let clear_0 = rng.gen::<i128>() % modulus;
+            let clear_1 = rng.gen::<i128>() % modulus;
+
+            let ctxt_0 = sks.create_trivial_radix(clear_0, num_blocks as usize);
+            let ctxt_1 = sks.create_trivial_radix(clear_1, num_blocks as usize);
+
+            let (ct_res, ct_overflow) = executor.execute((&ctxt_0, &ctxt_1));
+            let dec_res: i128 = cks.decrypt_signed(&ct_res);
+            let dec_overflow = cks.decrypt_bool(&ct_overflow);
+
+            let (expected_clear, expected_overflow) =
+                signed_overflowing_add_under_modulus(clear_0, clear_1, modulus);
+            assert_eq!(
+                expected_clear, dec_res,
+                "Invalid result for {clear_0} + {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
+                num_blocks={num_blocks}, modulus={modulus}"
+            );
+            assert_eq!(
+                expected_overflow, dec_overflow,
+                "Invalid overflow result for {clear_0} + {clear_1}, expected: {expected_overflow}, got: {dec_overflow}\n\
+                num_blocks={num_blocks}, modulus={modulus}"
+            );
+        }
+    }
+}
+
 pub(crate) fn signed_unchecked_add_test<P, T>(param: P, mut executor: T)
 where
     P: Into<PBSParameters>,
@@ -404,37 +552,91 @@ where
 
     let mut rng = rand::thread_rng();
 
-    let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
-
     executor.setup(&cks, sks);
 
     let mut clear;
 
-    for _ in 0..nb_tests_smaller {
-        let clear_0 = rng.gen::<i64>() % modulus;
-        let clear_1 = rng.gen::<i64>() % modulus;
-
-        let ctxt_0 = cks.encrypt_signed(clear_0);
-        let ctxt_1 = cks.encrypt_signed(clear_1);
-
-        let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
-        let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
-        assert!(ct_res.block_carries_are_empty());
-        assert_eq!(ct_res, tmp_ct);
-
-        clear = signed_add_under_modulus(clear_0, clear_1, modulus);
-
-        // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1);
-        // add multiple times to raise the degree
-        for _ in 0..nb_tests_smaller {
-            ct_res = executor.execute((&ct_res, &ctxt_0));
-            assert!(ct_res.block_carries_are_empty());
-            clear = signed_add_under_modulus(clear, clear_0, modulus);
-
-            let dec_res: i64 = cks.decrypt_signed(&ct_res);
-
-            // println!("clear = {}, dec_res = {}", clear, dec_res);
-            assert_eq!(clear, dec_res);
-        }
-    }
+    for num_blocks in 1..MAX_NB_CTXT {
+        let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
+        if modulus == 1 {
+            // Basically we only have one bit (the sign bit), so there is nothing to really test
+            continue;
+        }
+
+        for _ in 0..nb_tests_smaller {
+            let clear_0 = rng.gen::<i64>() % modulus;
+            let clear_1 = rng.gen::<i64>() % modulus;
+
+            let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
+            let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);
+
+            let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
+            let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
+            assert!(ct_res.block_carries_are_empty());
+            assert_eq!(ct_res, tmp_ct);
+
+            clear = signed_add_under_modulus(clear_0, clear_1, modulus);
+
+            // println!("clear_0 = {}, clear_1 = {}", clear_0, clear_1);
+            // add multiple times to raise the degree
+            for _ in 0..nb_tests_smaller {
+                ct_res = executor.execute((&ct_res, &ctxt_0));
+                assert!(ct_res.block_carries_are_empty());
+                clear = signed_add_under_modulus(clear, clear_0, modulus);
+
+                let dec_res: i64 = cks.decrypt_signed(&ct_res);
+
+                // println!("clear = {}, dec_res = {}", clear, dec_res);
+                assert_eq!(clear, dec_res);
+            }
+        }
+    }
 }
 
+/// Although this uses the executor pattern and could be plugged into other backends,
+/// it is not recommended to do so unless the backend is extremely fast on trivial
+/// ciphertexts (or extremely fast in general), or if it is plugged in just as a one-time thing.
+pub(crate) fn extensive_trivial_signed_default_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
+        SignedRadixCiphertext,
+    >,
+{
+    let param = param.into();
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let message_modulus = cks.parameters().message_modulus();
+    let block_num_bits = message_modulus.0.ilog2();
+    for bit_size in 2..=64u32 {
+        let num_blocks = bit_size.div_ceil(block_num_bits);
+        let modulus = (cks.parameters().message_modulus().0 as i128).pow(num_blocks) / 2;
+
+        for _ in 0..50 {
+            let clear_0 = rng.gen::<i128>() % modulus;
+            let clear_1 = rng.gen::<i128>() % modulus;
+
+            let ctxt_0 = sks.create_trivial_radix(clear_0, num_blocks as usize);
+            let ctxt_1 = sks.create_trivial_radix(clear_1, num_blocks as usize);
+
+            let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+            let dec_res: i128 = cks.decrypt_signed(&ct_res);
+
+            let expected_clear = signed_add_under_modulus(clear_0, clear_1, modulus);
+            assert_eq!(
+                expected_clear, dec_res,
+                "Invalid result for {clear_0} + {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
+                num_blocks={num_blocks}, modulus={modulus}"
+            );
+        }
+    }
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
index 8dff2679c7..00886fd9b2 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
@@ -24,6 +24,7 @@ pub(crate) mod test_vector_comparisons;
 pub(crate) mod test_vector_find;
 
 use super::tests_cases_unsigned::*;
+use crate::core_crypto::prelude::UnsignedInteger;
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::tests::create_parametrized_test;
 use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
@@ -157,7 +158,11 @@ pub(crate) fn overflowing_sub_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
     (result % modulus, overflowed)
 }
 
-pub(crate) fn overflowing_add_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
+pub(crate) fn overflowing_add_under_modulus<T: UnsignedInteger>(
+    lhs: T,
+    rhs: T,
+    modulus: T,
+) -> (T, bool) {
     let (result, overflowed) = lhs.overflowing_add(rhs);
     (result % modulus, overflowed || result >= modulus)
 }
@@ -186,6 +191,18 @@ pub(crate) fn unsigned_modulus(block_modulus: MessageModulus, num_blocks: u32) -> u64 {
         .expect("Modulus exceed u64::MAX")
 }
 
+/// This is just a copy-paste, as it creates less breakage than modifying the u64 one to
+/// return a u128.
+///
+/// Also, it would mean users would do `unsigned_modulus(...) as u64`, which when reading
+/// could create the suspicion of whether the `as` cast is valid and whether `try_into`
+/// should be used instead, but then it becomes more verbose.
+pub(crate) fn unsigned_modulus_u128(block_modulus: MessageModulus, num_blocks: u32) -> u128 {
+    (block_modulus.0 as u128)
+        .checked_pow(num_blocks)
+        .expect("Modulus exceed u128::MAX")
+}
+
 /// Given a radix ciphertext, checks that all the block's decrypted message and carry
 /// do not exceed the block's degree.
 #[track_caller]
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs
index 798ac86460..9c66f2bc4f 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_add.rs
@@ -2,7 +2,8 @@ use super::{
     nb_tests_for_params, nb_tests_smaller_for_params, overflowing_add_under_modulus,
     panic_if_any_block_info_exceeds_max_degree_or_noise, panic_if_any_block_is_not_clean,
     panic_if_any_block_values_exceeds_its_degree, random_non_zero_value, unsigned_modulus,
-    CpuFunctionExecutor, ExpectedDegrees, ExpectedNoiseLevels, MAX_NB_CTXT, NB_CTXT,
+    unsigned_modulus_u128, CpuFunctionExecutor, ExpectedDegrees, ExpectedNoiseLevels, MAX_NB_CTXT,
+    NB_CTXT,
 };
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
@@ -19,8 +20,10 @@ create_parametrized_test!(integer_unchecked_add);
 create_parametrized_test!(integer_unchecked_add_assign);
 create_parametrized_test!(integer_smart_add);
 create_parametrized_test!(integer_default_add);
+create_parametrized_test!(integer_extensive_trivial_default_add);
 create_parametrized_test!(integer_default_overflowing_add);
-create_parametrized_test!(integer_advanced_add_assign_with_carry_at_least_4_bits {
+create_parametrized_test!(integer_extensive_trivial_default_overflowing_add);
+create_parametrized_test!(integer_advanced_overflowing_add_assign_with_carry_at_least_4_bits {
     coverage => {
         COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
         COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
     },
     no_coverage => {
         PARAM_MESSAGE_2_CARRY_2_KS_PBS,
         PARAM_MESSAGE_3_CARRY_3_KS_PBS,
         PARAM_MESSAGE_4_CARRY_4_KS_PBS,
         PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
         PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
         PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
         PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS
@@ -36,6 +39,24 @@ create_parametrized_test!(integer_advanced_add_assign_with_carry_at_least_4_bits
     }
 });
 create_parametrized_test!(integer_advanced_add_assign_with_carry_sequential);
+create_parametrized_test!(integer_extensive_trivial_overflowing_advanced_add_assign_with_carry_at_least_4_bits {
+    coverage => {
+        COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
+    },
+    no_coverage => {
+        PARAM_MESSAGE_2_CARRY_2_KS_PBS,
+        PARAM_MESSAGE_3_CARRY_3_KS_PBS,
+        PARAM_MESSAGE_4_CARRY_4_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
+        PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS
+    }
+});
+create_parametrized_test!(
+    integer_extensive_trivial_advanced_overflowing_add_assign_with_carry_sequential
+);
 
 fn integer_unchecked_add<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
     let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_add);
     unchecked_add_test(param, executor);
 }
@@ -69,7 +90,15 @@ where
     default_add_test(param, executor);
 }
 
-fn integer_advanced_add_assign_with_carry_at_least_4_bits<P>(param: P)
+fn integer_extensive_trivial_default_add<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
+    extensive_trivial_default_add_test(param, executor);
+}
+
+fn integer_advanced_overflowing_add_assign_with_carry_at_least_4_bits<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
@@ -77,16 +106,51 @@ where
     // We explicitly call the 4 bit function to make sure it's being tested,
     // no matter the number of blocks / threads available
     let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
         let mut result = lhs.clone();
-        sks.advanced_add_assign_with_carry_at_least_4_bits(
-            &mut result.blocks,
-            &rhs.blocks,
-            None,
-            OutputFlag::None,
-        );
-        result
+        if !result.block_carries_are_empty() {
+            sks.full_propagate_parallelized(&mut result);
+        }
+        let mut tmp_rhs;
+        let rhs = if rhs.block_carries_are_empty() {
+            rhs
+        } else {
+            tmp_rhs = rhs.clone();
+            sks.full_propagate_parallelized(&mut tmp_rhs);
+            &tmp_rhs
+        };
+        let overflowed = sks
+            .advanced_add_assign_with_carry_at_least_4_bits(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                OutputFlag::Carry,
+            )
+            .unwrap();
+        (result, overflowed)
     };
     let executor = CpuFunctionExecutor::new(&func);
-    default_add_test(param, executor);
+    default_overflowing_add_test(param, executor);
+}
+
+fn integer_extensive_trivial_overflowing_advanced_add_assign_with_carry_at_least_4_bits<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    // We explicitly call the 4 bit function to make sure it's being tested,
+    // no matter the number of blocks / threads available
+    let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
+        let mut result = lhs.clone();
+        let overflowed = sks
+            .advanced_add_assign_with_carry_at_least_4_bits(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                OutputFlag::Carry,
+            )
+            .unwrap();
+        (result, overflowed)
+    };
+    let executor = CpuFunctionExecutor::new(&func);
+    extensive_trivial_default_overflowing_add_test(param, executor);
 }
 
 fn integer_advanced_add_assign_with_carry_sequential<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
@@ -95,16 +159,49 @@ where
     let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
         let mut result = lhs.clone();
-        sks.advanced_add_assign_with_carry_sequential_parallelized(
-            &mut result.blocks,
-            &rhs.blocks,
-            None,
-            OutputFlag::None,
-        );
-        result
+        if !result.block_carries_are_empty() {
+            sks.full_propagate_parallelized(&mut result);
+        }
+        let mut tmp_rhs;
+        let rhs = if rhs.block_carries_are_empty() {
+            rhs
+        } else {
+            tmp_rhs = rhs.clone();
+            sks.full_propagate_parallelized(&mut tmp_rhs);
+            &tmp_rhs
+        };
+        let overflowed = sks
+            .advanced_add_assign_with_carry_sequential_parallelized(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                OutputFlag::Carry,
+            )
+            .unwrap();
+        (result, overflowed)
     };
     let executor = CpuFunctionExecutor::new(&func);
-    default_add_test(param, executor);
+    default_overflowing_add_test(param, executor);
+}
+
+fn integer_extensive_trivial_advanced_overflowing_add_assign_with_carry_sequential<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let func = |sks: &ServerKey, lhs: &RadixCiphertext, rhs: &RadixCiphertext| {
+        let mut result = lhs.clone();
+        let overflowed = sks
+            .advanced_add_assign_with_carry_sequential_parallelized(
+                &mut result.blocks,
+                &rhs.blocks,
+                None,
+                OutputFlag::Carry,
+            )
+            .unwrap();
+        (result, overflowed)
+    };
+    let executor = CpuFunctionExecutor::new(&func);
+    extensive_trivial_default_overflowing_add_test(param, executor);
 }
 
 fn integer_default_overflowing_add<P>(param: P)
 where
     P: Into<PBSParameters>,
 {
@@ -115,6 +212,14 @@ where
     let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
     default_overflowing_add_test(param, executor);
 }
 
+fn integer_extensive_trivial_default_overflowing_add<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
+    extensive_trivial_default_overflowing_add_test(param, executor);
+}
+
 impl ExpectedNoiseLevels {
     fn after_unchecked_add(&mut self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> &Self {
         self.set_with(
@@ -374,6 +479,54 @@ where
     }
 }
 
+/// Although this uses the executor pattern and could be plugged into other backends,
+/// it is not recommended to do so unless the backend is extremely fast on trivial
+/// ciphertexts (or extremely fast in general), or if it is plugged in just as a one-time thing.
+pub(crate) fn extensive_trivial_default_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>,
+{
+    let param = param.into();
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let message_modulus = cks.parameters().message_modulus();
+    let block_num_bits = message_modulus.0.ilog2();
+    // Contrary to regular add, we step bit_size by the number of bits per block,
+    // otherwise the bit size actually encrypted is not exactly the same,
+    // leading to false test overflow results.
+    for bit_size in (1..=64u32).step_by(block_num_bits as usize) {
+        let num_blocks = bit_size.div_ceil(block_num_bits);
+        let modulus = unsigned_modulus_u128(cks.parameters().message_modulus(), num_blocks);
+
+        for _ in 0..50 {
+            let clear_0 = rng.gen::<u128>() % modulus;
+            let clear_1 = rng.gen::<u128>() % modulus;
+
+            let ctxt_0 = sks.create_trivial_radix(clear_0, num_blocks as usize);
+            let ctxt_1 = sks.create_trivial_radix(clear_1, num_blocks as usize);
+
+            let ct_res = executor.execute((&ctxt_0, &ctxt_1));
+            let dec_res: u128 = cks.decrypt(&ct_res);
+
+            let expected_clear = clear_0.wrapping_add(clear_1) % modulus;
+            assert_eq!(
+                expected_clear, dec_res,
+                "Invalid result for {clear_0} + {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
+                num_blocks={num_blocks}, modulus={modulus}"
+            );
+        }
+    }
+}
+
 pub(crate) fn default_overflowing_add_test<P, T>(param: P, mut executor: T)
 where
     P: Into<PBSParameters>,
     T: for<'a> FunctionExecutor<
         (&'a RadixCiphertext, &'a RadixCiphertext),
         (RadixCiphertext, BooleanBlock),
     >,
@@ -498,3 +651,58 @@ where
         assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
     }
 }
+
+/// Although this uses the executor pattern and could be plugged into other backends,
+/// it is not recommended to do so unless the backend is extremely fast on trivial
+/// ciphertexts (or extremely fast in general), or if it is plugged in just as a one-time thing.
+pub(crate) fn extensive_trivial_default_overflowing_add_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, &'a RadixCiphertext),
+        (RadixCiphertext, BooleanBlock),
+    >,
+{
+    let param = param.into();
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let message_modulus = cks.parameters().message_modulus();
+    let block_num_bits = message_modulus.0.ilog2();
+    for bit_size in 1..=64u32 {
+        let num_blocks = bit_size.div_ceil(block_num_bits);
+        let modulus = unsigned_modulus_u128(cks.parameters().message_modulus(), num_blocks);
+
+        for _ in 0..50 {
+            let clear_0 = rng.gen::<u128>() % modulus;
+            let clear_1 = rng.gen::<u128>() % modulus;
+
+            let ctxt_0 = sks.create_trivial_radix(clear_0, num_blocks as usize);
+            let ctxt_1 = sks.create_trivial_radix(clear_1, num_blocks as usize);
+
+            let (ct_res, o_res) = executor.execute((&ctxt_0, &ctxt_1));
+            let dec_res: u128 = cks.decrypt(&ct_res);
+            let dec_overflow = cks.decrypt_bool(&o_res);
+
+            let (expected_clear, expected_overflow) =
+                overflowing_add_under_modulus(clear_0, clear_1, modulus);
+            assert_eq!(
+                expected_clear, dec_res,
+                "Invalid result for {clear_0} + {clear_1}, expected: {expected_clear}, got: {dec_res}\n\
+                num_blocks={num_blocks}, modulus={modulus}"
+            );
+            assert_eq!(
+                expected_overflow, dec_overflow,
+                "Invalid overflow result for {clear_0} + {clear_1}, expected: {expected_overflow}, got: {dec_overflow}\n\
+                num_blocks={num_blocks}, modulus={modulus}"
+            );
+        }
+    }
+}
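
Note for reviewers: the clear-value reference that all of these tests compare against is the
`signed_overflowing_add_under_modulus` helper from tests_signed/mod.rs. Below is a minimal
standalone sketch of its semantics, not part of the patch: it is specialized to i64 instead of
the generic `T: SignedInteger`, and it assumes `lhs + rhs` cannot overflow the carrier type
(the in-tree helper additionally guards that case). With modulus = 2^(N-1) it reproduces N-bit
two's complement wrapping, which the `main` below cross-checks against `i8::overflowing_add`.

fn signed_overflowing_add_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> (i64, bool) {
    // Inputs must be in [-modulus, modulus), i.e. the value range of an N-bit
    // two's complement integer when modulus = 2^(N-1).
    assert!(modulus > 0);
    assert!((-modulus..modulus).contains(&lhs));
    assert!((-modulus..modulus).contains(&rhs));

    let mut res = lhs + rhs;
    let mut overflowed = false;
    if res < -modulus {
        // Wrap around from below, as two's complement hardware would
        res += 2 * modulus;
        overflowed = true;
    } else if res >= modulus {
        // Wrap around from above
        res -= 2 * modulus;
        overflowed = true;
    }
    (res, overflowed)
}

fn main() {
    // With modulus = 128 this matches i8 two's complement semantics.
    assert_eq!(signed_overflowing_add_under_modulus(100, 50, 128), (-106, true));
    assert_eq!((100i8).overflowing_add(50), (-106, true));
    assert_eq!(signed_overflowing_add_under_modulus(-100, -50, 128), (106, true));
    assert_eq!((-100i8).overflowing_add(-50), (106, true));
    assert_eq!(signed_overflowing_add_under_modulus(3, 4, 128), (7, false));
}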
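
The "we step bit_size by the number of bits per block" comment in the extensive trivial tests can
also be checked in isolation. The sketch below is illustrative only and assumes 2-bit message
blocks (as in the 2_2 parameter sets used above); it shows why stepping keeps the encrypted bit
size equal to the tested one, while an unaligned bit_size rounds up to a larger ciphertext whose
extra bits would make an overflow check against 2^bit_size report false results.

fn main() {
    // Assumption: 2 bits of message per block, e.g. MESSAGE_2_CARRY_2 parameters.
    let block_num_bits: u32 = 2;

    // Stepping bit_size by block_num_bits keeps bit_size == num_blocks * block_num_bits,
    // so the modulus actually encrypted matches the modulus the test checks against.
    for bit_size in (2..=8u32).step_by(block_num_bits as usize) {
        let num_blocks = bit_size.div_ceil(block_num_bits);
        let modulus = 4u128.pow(num_blocks); // message_modulus ^ num_blocks
        assert_eq!(num_blocks * block_num_bits, bit_size);
        println!("bit_size={bit_size} -> num_blocks={num_blocks}, modulus={modulus}");
    }

    // Without stepping: bit_size = 3 with 2-bit blocks rounds up to 2 blocks,
    // i.e. 4 bits actually encrypted, so a sum can exceed 2^3 without overflowing
    // the ciphertext itself.
    let bit_size = 3u32;
    let num_blocks = bit_size.div_ceil(block_num_bits);
    assert_eq!(num_blocks, 2);
    assert_eq!(num_blocks * block_num_bits, 4); // != bit_size
}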