Skip to content

Commit

Permalink
fix(integer): rotations/shifts < 2 blocks
Browse files Browse the repository at this point in the history
This commit fixes a few bugs

* The shift/rotate functions used when blocks encrypt a number of bits
  that is a power of 2 was causing a panic when working on one block.
  - Also, when the number of blocks was low (e.g 2 blocks with 2_2
    params) a noise cleaning step was wrongly skipped

* The function used when blocks encrypt non power of 2 number of bits
  also had a problem

The test have been updated to test with different block sizes and check
the noise level

Overall these bugs only affected low block counts (e.g FheUint2,
FheUint4) ciphertexts
  • Loading branch information
tmontaigu committed Feb 13, 2025
1 parent 53a1f35 commit 37934e4
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 184 deletions.
8 changes: 7 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/block_shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ impl ServerKey {
T: IntegerRadixCiphertext,
{
if d_range.is_empty() {
return ct.clone();
let mut result = ct.clone();
result
.blocks_mut()
.par_iter_mut()
.filter(|b| b.noise_level > NoiseLevel::NOMINAL)
.for_each(|block| self.key.message_extract_assign(block));
return result;
}

assert!(
Expand Down
128 changes: 89 additions & 39 deletions tfhe/src/integer/server_key/radix_parallel/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,21 +370,59 @@ impl ServerKey {
where
T: IntegerRadixCiphertext,
{
if amount.blocks.is_empty() || ct.blocks().is_empty() {
return ct.clone();
}

let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
assert!(carry_bits_per_block >= message_bits_per_block);
assert!(message_bits_per_block.is_power_of_two());

// Extracts bits and put them in the bit index 2 (=> bit number 3)
// so that it is already aligned to the correct position of the cmux input,
// and we reduce noise growth
let mut shift_bit_extractor = BitExtractor::with_final_offset(
&amount.blocks,
self,
message_bits_per_block as usize,
message_bits_per_block as usize,
);
if ct.blocks().len() == 1 {
let lut = self
.key
.generate_lookup_table_bivariate(|input, first_shift_block| {
let shift_within_block = first_shift_block % message_bits_per_block;

assert!(message_bits_per_block.is_power_of_two());
match operation {
BarrelShifterOperation::LeftShift => {
(input << shift_within_block) % self.message_modulus().0
}
BarrelShifterOperation::LeftRotate => {
let shifted = (input << shift_within_block) % self.message_modulus().0;
let wrapped = input >> (shift_within_block);
shifted | wrapped
}
BarrelShifterOperation::RightRotate => {
let shifted = input >> shift_within_block;
let wrapped = (input << shift_within_block) % self.message_modulus().0;
wrapped | shifted
}
BarrelShifterOperation::RightShift => {
if T::IS_SIGNED {
let sign_bit_pos = message_bits_per_block - 1;
let sign_bit = (input >> sign_bit_pos) & 1;
let padding_block = (self.message_modulus().0 - 1) * sign_bit;

// Pad with sign bits to 'simulate' an arithmetic shift
let input = (padding_block << message_bits_per_block) | input;
(input >> shift_within_block) % self.message_modulus().0
} else {
input >> shift_within_block
}
}
}
});

let block = self.key.unchecked_apply_lookup_table_bivariate(
&ct.blocks()[0],
&amount.blocks[0],
&lut,
);

return T::from_blocks(vec![block]);
}

let message_for_block =
self.key
Expand All @@ -408,6 +446,45 @@ impl ServerKey {
b
}
});

// When doing right shift of a signed ciphertext, we do an arithmetic shift
// Thus, we need some special luts to be used on the last block
// (which has the sign bit)
let message_for_block_right_shift_signed =
if T::IS_SIGNED && operation == BarrelShifterOperation::RightShift {
let lut = self
.key
.generate_lookup_table_bivariate(|input, first_shift_block| {
let shift_within_block = first_shift_block % message_bits_per_block;
let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

let sign_bit_pos = message_bits_per_block - 1;
let sign_bit = (input >> sign_bit_pos) & 1;
let padding_block = (self.message_modulus().0 - 1) * sign_bit;

if shift_to_next_block == 1 {
padding_block
} else {
// Pad with sign bits to 'simulate' an arithmetic shift
let input = (padding_block << message_bits_per_block) | input;
(input >> shift_within_block) % self.message_modulus().0
}
});
Some(lut)
} else {
None
};

// Extracts bits and put them in the bit index 2 (=> bit number 3)
// so that it is already aligned to the correct position of the cmux input,
// and we reduce noise growth
let mut shift_bit_extractor = BitExtractor::with_final_offset(
&amount.blocks,
self,
message_bits_per_block as usize,
message_bits_per_block as usize,
);

let message_for_next_block =
self.key
.generate_lookup_table_bivariate(|previous, first_shift_block| {
Expand Down Expand Up @@ -467,34 +544,6 @@ impl ServerKey {
}
});

// When doing right shift of a signed ciphertext, we do an arithmetic shift
// Thus, we need some special luts to be used on the last block
// (which has the sign big)
let message_for_block_right_shift_signed =
if T::IS_SIGNED && operation == BarrelShifterOperation::RightShift {
let lut = self
.key
.generate_lookup_table_bivariate(|input, first_shift_block| {
let shift_within_block = first_shift_block % message_bits_per_block;
let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;

let sign_bit_pos = message_bits_per_block - 1;
let sign_bit = (input >> sign_bit_pos) & 1;
let padding_block = (self.message_modulus().0 - 1) * sign_bit;

if shift_to_next_block == 1 {
padding_block
} else {
// Pad with sign bits to 'simulate' an arithmetic shift
let input = (padding_block << message_bits_per_block) | input;
(input >> shift_within_block) % self.message_modulus().0
}
});
Some(lut)
} else {
None
};

let message_for_next_block_right_shift_signed = if T::IS_SIGNED
&& operation == BarrelShifterOperation::RightShift
{
Expand Down Expand Up @@ -693,7 +742,8 @@ impl ServerKey {
) where
T: IntegerRadixCiphertext,
{
let num_blocks = shift.blocks.len();
// What matters is the len of the ct to shift, not the `shift` len
let num_blocks = ct.blocks().len();
let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
let total_nb_bits = message_bits_per_block * num_blocks as u64;
Expand Down
Loading

0 comments on commit 37934e4

Please sign in to comment.