diff --git a/tfhe/src/integer/server_key/radix/mod.rs b/tfhe/src/integer/server_key/radix/mod.rs index 1a25bd7659..6e5348a585 100644 --- a/tfhe/src/integer/server_key/radix/mod.rs +++ b/tfhe/src/integer/server_key/radix/mod.rs @@ -7,6 +7,7 @@ mod scalar_add; pub(super) mod scalar_mul; pub(super) mod scalar_sub; mod shift; +pub(super) mod slice; mod sub; use super::ServerKey; diff --git a/tfhe/src/integer/server_key/radix/slice.rs b/tfhe/src/integer/server_key/radix/slice.rs new file mode 100644 index 0000000000..0890a4f01f --- /dev/null +++ b/tfhe/src/integer/server_key/radix/slice.rs @@ -0,0 +1,546 @@ +use std::ops::{Bound, RangeBounds}; + +use crate::integer::{RadixCiphertext, ServerKey}; +use crate::prelude::CastFrom; +use crate::shortint; + +/// Error returned when the provided range for a slice is invalid +#[derive(Debug)] +pub enum InvalidRangeError { + /// The upper bound of the range is greater than the size of the integer + SliceTooBig, + /// The upper gound is smaller than the lower bound + WrongOrder, +} + +/// Normalize a rust bound object, and check that it is valid for the source integer +pub(in crate::integer) fn parse_bounds( + bounds: &B, + nb_bits: usize, +) -> Result<(usize, usize), InvalidRangeError> +where + B: RangeBounds, + T: CastFrom + Copy, + usize: CastFrom, +{ + let start = match bounds.start_bound() { + Bound::Included(inc) => usize::cast_from(*inc), + Bound::Excluded(excl) => usize::cast_from(*excl) - 1, + Bound::Unbounded => 0, + }; + + let end = match bounds.end_bound() { + Bound::Included(inc) => usize::cast_from(*inc) + 1, + Bound::Excluded(excl) => usize::cast_from(*excl), + Bound::Unbounded => nb_bits, + }; + + // TODO: return an error if bounds are wrong + assert!(end <= nb_bits); + assert!(start <= end); + + Ok((start, end)) +} + +/// This is the operation to extract a non-aligned block, on the clear. +/// For example, with a 2x4bits integer: |abcd|efgh|, extracting the block +/// at offset 2 will return |cdef|. This function should be used inside a LUT. +pub(in crate::integer) fn slice_oneblock_clear_unaligned( + cur_block: u64, + next_block: u64, + offset: usize, + block_size: usize, +) -> u64 { + cur_block >> (offset) | ((next_block << (block_size - offset)) & (1 << block_size)) +} + +impl ServerKey { + /// Extract a slice of blocks from a ciphertext. + /// + /// The result is returned as a new ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg = 225; + /// let start_block = 1; + /// let end_block = 2; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.blockslice(&ct, start_block, end_block); + /// + /// let blocksize = cks.parameters().message_modulus().0.ilog2() as u64; + /// let start_bit = (start_block as u64) * blocksize; + /// let end_bit = (end_block as u64) * blocksize; + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn blockslice( + &self, + ctxt: &RadixCiphertext, + start_block: usize, + end_block: usize, + ) -> RadixCiphertext { + let limit = end_block - start_block; + + let mut result: RadixCiphertext = self.create_trivial_zero_radix(limit); + + for (res_i, c_i) in result.blocks[..limit] + .iter_mut() + .zip(ctxt.blocks[start_block..].iter()) + { + res_i.clone_from(c_i); + } + + result + } + + /// Extract a slice of blocks from a ciphertext. + /// + /// The result is assigned in the input ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg = 225; + /// let start_block = 1; + /// let end_block = 2; + /// + /// // Encrypt the message: + /// let mut ct = cks.encrypt(msg); + /// + /// sks.blockslice_assign(&mut ct, start_block, end_block); + /// + /// let blocksize = cks.parameters().message_modulus().0.ilog2() as u64; + /// let start_bit = (start_block as u64) * blocksize; + /// let end_bit = (end_block as u64) * blocksize; + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn blockslice_assign( + &self, + ctxt: &mut RadixCiphertext, + start_block: usize, + end_block: usize, + ) { + *ctxt = self.blockslice(ctxt, start_block, end_block); + } + + /// Return the unaligned remainder of a slice after all the unaligned full blocks have been + /// extracted. This is similar to what [`slice_interblock`] does on each block except that the + /// remainder is not a full block, so it will be truncated to `count` bits. + pub(in crate::integer) fn bitslice_remainder_unaligned( + &self, + ctxt: &RadixCiphertext, + block_idx: usize, + offset: usize, + count: usize, + ) -> shortint::Ciphertext { + let lut = self + .key + .generate_lookup_table_bivariate(|current_block, next_block| { + slice_oneblock_clear_unaligned( + current_block, + next_block, + offset, + self.message_modulus().0.ilog2() as usize, + ) % (1 << count) + }); + + self.key.apply_lookup_table_bivariate( + &ctxt.blocks[block_idx], + &ctxt + .blocks + .get(block_idx + 1) + .cloned() + .unwrap_or_else(|| self.key.create_trivial(0)), + &lut, + ) + } + + /// Returnsthe remainder of a slice after all the full blocks have been extracted. This will + /// simply truncate the block value to `count` bits. + pub(in crate::integer) fn bitslice_remainder( + &self, + ctxt: &RadixCiphertext, + block_idx: usize, + count: usize, + ) -> shortint::Ciphertext { + let lut = self.key.generate_lookup_table(|block| block % (1 << count)); + + self.key.apply_lookup_table(&ctxt.blocks[block_idx], &lut) + } + + /// Extract a slice from a ciphertext. The size of the slice is a multiple of the block + /// size but it is not aligned on block boundaries, so we need to mix block n and (n+1) toG + /// create a new block, using the lut function `slice_oneblock_clear_unaligned`. + pub fn blockslice_unaligned( + &self, + ctxt: &RadixCiphertext, + start_block: usize, + block_count: usize, + offset: usize, + ) -> RadixCiphertext { + let mut blocks = Vec::new(); + + let lut = self + .key + .generate_lookup_table_bivariate(|current_block, next_block| { + slice_oneblock_clear_unaligned( + current_block, + next_block, + offset, + self.message_modulus().0.ilog2() as usize, + ) + }); + + for idx in 0..block_count { + let block = self.key.apply_lookup_table_bivariate( + &ctxt.blocks[idx + start_block], + &ctxt.blocks[idx + start_block + 1], + &lut, + ); + + blocks.push(block); + } + + RadixCiphertext::from(blocks) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks + /// .unchecked_scalar_bitslice(&ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn unchecked_scalar_bitslice( + &self, + ctxt: &RadixCiphertext, + bounds: B, + ) -> Result + where + B: RangeBounds + std::fmt::Debug, + T: CastFrom + Copy, + usize: CastFrom, + { + let block_width = self.message_modulus().0.ilog2() as usize; + let (start, end) = parse_bounds(&bounds, self.message_modulus().0 * ctxt.blocks.len())?; + + let slice_width = end - start; + + // If the starting bit is block aligned, we can do most of the slicing with block copies. + // If it's not we must extract the bits with PBS. In either cases, we must extract the last + // bits with a PBS if the slice size is not a multiple of the block size. + let mut sliced = if start % block_width != 0 { + let mut sliced = self.blockslice_unaligned( + ctxt, + start / block_width, + slice_width / block_width, + start % block_width, + ); + + if slice_width % block_width != 0 { + let last_block = self.bitslice_remainder_unaligned( + &ctxt, + start / block_width + slice_width / block_width, + start % block_width, + slice_width % block_width, + ); + sliced.blocks.push(last_block); + } + + sliced + } else { + let mut sliced = self.blockslice(ctxt, start / block_width, end / block_width); + if slice_width % block_width != 0 { + let last_block = + self.bitslice_remainder(&ctxt, end / block_width, slice_width % block_width); + sliced.blocks.push(last_block); + } + + sliced + }; + + // Extend with trivial zeroes to return an integer of the same size as the input one. + self.extend_radix_with_trivial_zero_blocks_msb_assign(&mut sliced, ctxt.blocks.len()); + Ok(sliced) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// sks.unchecked_scalar_bitslice(&mut ct, start_bit..end_bit) + /// .unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn unchecked_scalar_bitslice_assign( + &self, + ctxt: &mut RadixCiphertext, + bounds: B, + ) -> Result<(), InvalidRangeError> + where + B: RangeBounds + std::fmt::Debug, + T: CastFrom + Copy, + usize: CastFrom, + { + *ctxt = self.unchecked_scalar_bitslice(ctxt, bounds)?; + Ok(()) + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn scalar_bitslice( + &self, + ctxt: &RadixCiphertext, + bounds: B, + ) -> Result + where + B: RangeBounds + std::fmt::Debug, + T: CastFrom + Copy, + usize: CastFrom, + { + if !ctxt.block_carries_are_empty() { + let mut ctxt = ctxt.clone(); + self.full_propagate(&mut ctxt); + self.unchecked_scalar_bitslice(&ctxt, bounds) + } else { + self.unchecked_scalar_bitslice(ctxt, bounds) + } + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn scalar_bitslice_assign( + &self, + ctxt: &mut RadixCiphertext, + bounds: B, + ) -> Result<(), InvalidRangeError> + where + B: RangeBounds + std::fmt::Debug, + T: CastFrom + Copy, + usize: CastFrom, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + self.unchecked_scalar_bitslice_assign(ctxt, bounds) + } else { + self.unchecked_scalar_bitslice_assign(ctxt, bounds) + } + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is returned as a new ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn smart_scalar_bitslice( + &self, + ctxt: &mut RadixCiphertext, + bounds: B, + ) -> Result + where + B: RangeBounds + std::fmt::Debug, + T: CastFrom + Copy, + usize: CastFrom, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + self.unchecked_scalar_bitslice(&ctxt, bounds) + } else { + self.unchecked_scalar_bitslice(ctxt, bounds) + } + } + + /// Extract a slice of bits from a ciphertext. + /// + /// The result is assigned to the input ciphertext. This function is more efficient + /// if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::integer::gen_keys_radix; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// // We have 4 * 2 = 8 bits of message + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg: u64 = 225; + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// // Encrypt the message: + /// let ct = cks.encrypt(msg); + /// + /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap(); + /// + /// // Decrypt: + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!((msg >> start_bit) & (end_bit - start_bit), clear); + /// ``` + pub fn smart_scalar_bitslice_assign( + &mut self, + ctxt: &mut RadixCiphertext, + bounds: B, + ) -> Result<(), InvalidRangeError> + where + B: RangeBounds + std::fmt::Debug, + T: CastFrom + Copy, + usize: CastFrom, + { + if !ctxt.block_carries_are_empty() { + self.full_propagate(ctxt); + self.unchecked_scalar_bitslice_assign(ctxt, bounds) + } else { + self.unchecked_scalar_bitslice_assign(ctxt, bounds) + } + } +} diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs index 9192f3c6e1..41f98583d9 100644 --- a/tfhe/src/integer/server_key/radix/tests.rs +++ b/tfhe/src/integer/server_key/radix/tests.rs @@ -1,3 +1,5 @@ +use std::ops::RangeBounds; + use crate::integer::keycache::KEY_CACHE; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::*; use crate::integer::server_key::radix_parallel::tests_unsigned::test_add::smart_add_test; @@ -8,11 +10,14 @@ use crate::integer::server_key::radix_parallel::tests_unsigned::test_sub::{ use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; use crate::integer::tests::{create_parametrized_test, create_parametrized_test_classical_params}; use crate::integer::{IntegerKeyKind, RadixCiphertext, ServerKey, SignedRadixCiphertext, U256}; +use crate::prelude::CastFrom; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::*; use rand::Rng; +use super::slice::parse_bounds; + /// Number of loop iteration within randomized tests #[cfg(not(tarpaulin))] pub(crate) const NB_TESTS: usize = 30; @@ -103,6 +108,7 @@ create_parametrized_test!( create_parametrized_test_classical_params!(integer_create_trivial_min_max); create_parametrized_test_classical_params!(integer_signed_decryption_correctly_sign_extend); +create_parametrized_test_classical_params!(integer_unchecked_scalar_slice); fn integer_encrypt_decrypt(param: ClassicPBSParameters) { let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); @@ -1068,3 +1074,63 @@ fn integer_signed_decryption_correctly_sign_extend(param: impl Into().unwrap(), value as i128); } + +// Reference implementation of the slice using a conversion into a string of 0/1 to do the slicing. +fn slice_reference_impl(value: u64, bounds: B, modulus: u64) -> u64 +where + B: RangeBounds, + T: CastFrom + Copy, + u64: CastFrom, +{ + let (start, end) = parse_bounds(&bounds, modulus)?; + + let bin: String = format!("{value:064b}").chars().rev().collect(); + + let out_bin: String = bin[start..end].chars().rev().collect(); + u64::from_str_radix(&out_bin, 2).unwrap_or_default() +} + +// Todo: use fc executor ? +fn integer_unchecked_scalar_slice(param: ClassicPBSParameters) { + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + + for _ in 0..NB_TESTS { + let clear = rng.gen::() % modulus; + + let range_a = rng.gen::() % (modulus.ilog2() as u32); + let range_b = rng.gen::() % (modulus.ilog2() as u32); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let mut ct = cks.encrypt_radix(clear, NB_CTXT); + + // check exclusive bound + { + let ct_res = sks.unchecked_scalar_bitslice(&mut ct, range_start..range_end); + let dec_res: u64 = cks.decrypt_radix(&ct_res); + assert_eq!( + slice_reference_impl(clear, range_start..range_end, modulus), + dec_res, + ); + } + + // check inclusive bound + { + let ct_res = sks.unchecked_scalar_bitslice(&mut ct, range_start..=range_end); + let dec_res: u64 = cks.decrypt_radix(&ct_res); + assert_eq!( + slice_reference_impl(clear, range_start..=range_end, modulus), + dec_res, + ); + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs index 621a21fbba..2540b1868a 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -23,6 +23,7 @@ mod sum; mod ilog2; mod reverse_bits; +mod slice; #[cfg(test)] pub(crate) mod tests_cases_unsigned; #[cfg(test)]