Skip to content

Commit

Permalink
fix(integer): correct degree in small comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Feb 20, 2024
1 parent b708abb commit 42b7c2f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 36 deletions.
53 changes: 35 additions & 18 deletions tfhe/src/integer/server_key/comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,20 +261,33 @@ impl<'a> Comparator<'a> {
std::mem::swap(&mut sign_blocks_2, &mut sign_blocks);
}

let final_lut = self.server_key.key.generate_lookup_table(|x| {
let final_sign = reduce_two_orderings_function(x);
sign_result_handler_fn(final_sign)
});

// We don't use pack_block_assign as the offset '4' does not depend on params
let mut result = self.server_key.key.unchecked_scalar_mul(&sign_blocks[1], 4);
self.server_key
.key
.unchecked_add_assign(&mut result, &sign_blocks[0]);
self.server_key
.key
.apply_lookup_table_assign(&mut result, &final_lut);
result
if sign_blocks.len() == 2 {
let final_lut = self.server_key.key.generate_lookup_table(|x| {
let final_sign = reduce_two_orderings_function(x);
sign_result_handler_fn(final_sign)
});
// We don't use pack_block_assign as the offset '4' does not depend on params
let mut result = self.server_key.key.unchecked_scalar_mul(&sign_blocks[1], 4);
self.server_key
.key
.unchecked_add_assign(&mut result, &sign_blocks[0]);
self.server_key
.key
.apply_lookup_table_assign(&mut result, &final_lut);
result
} else {
let final_lut = self.server_key.key.generate_lookup_table(|x| {
// sign blocks have values in the set {0, 1, 2}
// here we force apply that modulus explicitely
// so that generate_lookup_table has the correct
// degree estimation
let final_sign = x % 3;
sign_result_handler_fn(final_sign)
});
self.server_key
.key
.apply_lookup_table(&sign_blocks[0], &final_lut)
}
}

/// Reduces a vec containing shortint blocks that encrypts a sign
Expand Down Expand Up @@ -322,10 +335,14 @@ impl<'a> Comparator<'a> {
.apply_lookup_table_assign(&mut result, &final_lut);
result
} else {
let final_lut = self
.server_key
.key
.generate_lookup_table(sign_result_handler_fn);
let final_lut = self.server_key.key.generate_lookup_table(|x| {
// sign blocks have values in the set {0, 1, 2}
// here we force apply that modulus explicitely
// so that generate_lookup_table has the correct
// degree estimation
let final_sign = x % 3;
sign_result_handler_fn(final_sign)
});
self.server_key
.key
.apply_lookup_table(&sign_blocks[0], &final_lut)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,36 @@ fn test_unchecked_function<UncheckedFn, ClearF>(
{
let mut rng = rand::thread_rng();

let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
let num_bits_per_block = param.message_modulus.0.ilog2();
let num_block = divide_ceil(256usize, num_bits_per_block as usize);

let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

for _ in 0..num_test {
let clear_a = rng.gen::<U256>();
let clear_b = rng.gen::<U256>();
// Test with low number of blocks, as they take a different branches
// (regression tests)
for num_block in [num_block, 1, 2] {
let max = U256::MAX >> (U256::BITS - (num_block as u32 * num_bits_per_block));
for _ in 0..num_test {
let clear_a = rng.gen::<U256>() & max;
let clear_b = rng.gen::<U256>() & max;

let a = cks.encrypt_radix(clear_a, num_block);
let b = cks.encrypt_radix(clear_b, num_block);
let a = cks.encrypt_radix(clear_a, num_block);
let b = cks.encrypt_radix(clear_b, num_block);

{
let result = unchecked_comparison_method(&sks, &a, &b);
let decrypted: U256 = cks.decrypt_radix(&result);
let expected_result = clear_fn(clear_a, clear_b);
assert_eq!(decrypted, expected_result);
}
{
let result = unchecked_comparison_method(&sks, &a, &b);
let decrypted: U256 = cks.decrypt_radix(&result);
let expected_result = clear_fn(clear_a, clear_b);
assert_eq!(decrypted, expected_result);
}

{
// Force case where lhs == rhs
let result = unchecked_comparison_method(&sks, &a, &a);
let decrypted: U256 = cks.decrypt_radix(&result);
let expected_result = clear_fn(clear_a, clear_a);
assert_eq!(decrypted, expected_result);
{
// Force case where lhs == rhs
let result = unchecked_comparison_method(&sks, &a, &a);
let decrypted: U256 = cks.decrypt_radix(&result);
let expected_result = clear_fn(clear_a, clear_a);
assert_eq!(decrypted, expected_result);
}
}
}
}
Expand Down

0 comments on commit 42b7c2f

Please sign in to comment.