From 42b7c2f403abc3f92cea7f694fe95aadb7c3f224 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Mon, 19 Feb 2024 12:40:31 +0100 Subject: [PATCH] fix(integer): correct degree in small comparisons --- tfhe/src/integer/server_key/comparator.rs | 53 ++++++++++++------- .../radix_parallel/tests_cases_comparisons.rs | 42 ++++++++------- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/tfhe/src/integer/server_key/comparator.rs b/tfhe/src/integer/server_key/comparator.rs index bae8794272..609c9ea44e 100644 --- a/tfhe/src/integer/server_key/comparator.rs +++ b/tfhe/src/integer/server_key/comparator.rs @@ -261,20 +261,33 @@ impl<'a> Comparator<'a> { std::mem::swap(&mut sign_blocks_2, &mut sign_blocks); } - let final_lut = self.server_key.key.generate_lookup_table(|x| { - let final_sign = reduce_two_orderings_function(x); - sign_result_handler_fn(final_sign) - }); - - // We don't use pack_block_assign as the offset '4' does not depend on params - let mut result = self.server_key.key.unchecked_scalar_mul(&sign_blocks[1], 4); - self.server_key - .key - .unchecked_add_assign(&mut result, &sign_blocks[0]); - self.server_key - .key - .apply_lookup_table_assign(&mut result, &final_lut); - result + if sign_blocks.len() == 2 { + let final_lut = self.server_key.key.generate_lookup_table(|x| { + let final_sign = reduce_two_orderings_function(x); + sign_result_handler_fn(final_sign) + }); + // We don't use pack_block_assign as the offset '4' does not depend on params + let mut result = self.server_key.key.unchecked_scalar_mul(&sign_blocks[1], 4); + self.server_key + .key + .unchecked_add_assign(&mut result, &sign_blocks[0]); + self.server_key + .key + .apply_lookup_table_assign(&mut result, &final_lut); + result + } else { + let final_lut = self.server_key.key.generate_lookup_table(|x| { + // sign blocks have values in the set {0, 1, 2} + // here we force apply that modulus explicitely + // so that generate_lookup_table has the correct + // degree estimation + let final_sign = x % 3; + sign_result_handler_fn(final_sign) + }); + self.server_key + .key + .apply_lookup_table(&sign_blocks[0], &final_lut) + } } /// Reduces a vec containing shortint blocks that encrypts a sign @@ -322,10 +335,14 @@ impl<'a> Comparator<'a> { .apply_lookup_table_assign(&mut result, &final_lut); result } else { - let final_lut = self - .server_key - .key - .generate_lookup_table(sign_result_handler_fn); + let final_lut = self.server_key.key.generate_lookup_table(|x| { + // sign blocks have values in the set {0, 1, 2} + // here we force apply that modulus explicitely + // so that generate_lookup_table has the correct + // degree estimation + let final_sign = x % 3; + sign_result_handler_fn(final_sign) + }); self.server_key .key .apply_lookup_table(&sign_blocks[0], &final_lut) diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_comparisons.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_comparisons.rs index d1cc4f982e..3a4794b5c5 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_comparisons.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_comparisons.rs @@ -22,30 +22,36 @@ fn test_unchecked_function( { let mut rng = rand::thread_rng(); - let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize; + let num_bits_per_block = param.message_modulus.0.ilog2(); + let num_block = divide_ceil(256usize, num_bits_per_block as usize); let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - for _ in 0..num_test { - let clear_a = rng.gen::(); - let clear_b = rng.gen::(); + // Test with low number of blocks, as they take a different branches + // (regression tests) + for num_block in [num_block, 1, 2] { + let max = U256::MAX >> (U256::BITS - (num_block as u32 * num_bits_per_block)); + for _ in 0..num_test { + let clear_a = rng.gen::() & max; + let clear_b = rng.gen::() & max; - let a = cks.encrypt_radix(clear_a, num_block); - let b = cks.encrypt_radix(clear_b, num_block); + let a = cks.encrypt_radix(clear_a, num_block); + let b = cks.encrypt_radix(clear_b, num_block); - { - let result = unchecked_comparison_method(&sks, &a, &b); - let decrypted: U256 = cks.decrypt_radix(&result); - let expected_result = clear_fn(clear_a, clear_b); - assert_eq!(decrypted, expected_result); - } + { + let result = unchecked_comparison_method(&sks, &a, &b); + let decrypted: U256 = cks.decrypt_radix(&result); + let expected_result = clear_fn(clear_a, clear_b); + assert_eq!(decrypted, expected_result); + } - { - // Force case where lhs == rhs - let result = unchecked_comparison_method(&sks, &a, &a); - let decrypted: U256 = cks.decrypt_radix(&result); - let expected_result = clear_fn(clear_a, clear_a); - assert_eq!(decrypted, expected_result); + { + // Force case where lhs == rhs + let result = unchecked_comparison_method(&sks, &a, &a); + let decrypted: U256 = cks.decrypt_radix(&result); + let expected_result = clear_fn(clear_a, clear_a); + assert_eq!(decrypted, expected_result); + } } } }