diff --git a/tfhe/src/integer/server_key/radix/mod.rs b/tfhe/src/integer/server_key/radix/mod.rs index 900a546817..c775e8034d 100644 --- a/tfhe/src/integer/server_key/radix/mod.rs +++ b/tfhe/src/integer/server_key/radix/mod.rs @@ -12,7 +12,7 @@ mod sub; use super::ServerKey; use crate::integer::block_decomposition::DecomposableInto; -use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext}; +use crate::integer::ciphertext::{IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext}; use crate::integer::encryption::encrypt_words_radix_impl; use crate::integer::{BooleanBlock, SignedRadixCiphertext}; @@ -474,6 +474,157 @@ impl ServerKey { result } + /// Cast a RadixCiphertext or SignedRadixCiphertext to a RadixCiphertext + /// with a possibly different number of blocks + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::{gen_keys_radix, IntegerCiphertext}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let num_blocks = 4; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg = -2i8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// assert_eq!(ct1.blocks().len(), 4); + /// + /// let ct_res = sks.cast_to_unsigned(ct1, 8); + /// assert_eq!(ct_res.blocks().len(), 8); + /// + /// // Decrypt + /// let res: u16 = cks.decrypt(&ct_res); + /// assert_eq!(msg as u16, res); + /// ``` + pub fn cast_to_unsigned( + &self, + mut source: T, + target_num_blocks: usize, + ) -> RadixCiphertext { + if !source.block_carries_are_empty() { + self.full_propagate_parallelized(&mut source); + } + + let blocks = source.into_blocks(); + let current_num_blocks = blocks.len(); + + let blocks = if T::IS_SIGNED { + // Casting from signed to unsigned + // We have to trim or sign extend first + if target_num_blocks > current_num_blocks { + let mut ct_as_signed_radix = SignedRadixCiphertext::from_blocks(blocks); + let num_blocks_to_add = target_num_blocks - current_num_blocks; + self.extend_radix_with_sign_msb_assign(&mut ct_as_signed_radix, num_blocks_to_add); + ct_as_signed_radix.blocks + } else { + let mut ct_as_unsigned_radix = crate::integer::RadixCiphertext::from_blocks(blocks); + let num_blocks_to_remove = current_num_blocks - target_num_blocks; + self.trim_radix_blocks_msb_assign(&mut ct_as_unsigned_radix, num_blocks_to_remove); + ct_as_unsigned_radix.blocks + } + } else { + // Casting from unsigned to unsigned, this is just about trimming/extending with zeros + let mut ct_as_unsigned_radix = crate::integer::RadixCiphertext::from_blocks(blocks); + if target_num_blocks > current_num_blocks { + let num_blocks_to_add = target_num_blocks - current_num_blocks; + self.extend_radix_with_trivial_zero_blocks_msb_assign( + &mut ct_as_unsigned_radix, + num_blocks_to_add, + ); + } else { + let num_blocks_to_remove = current_num_blocks - target_num_blocks; + self.trim_radix_blocks_msb_assign(&mut ct_as_unsigned_radix, num_blocks_to_remove); + }; + ct_as_unsigned_radix.blocks + }; + + assert_eq!( + blocks.len(), + target_num_blocks, + "internal error, wrong number of blocks after casting" + ); + crate::integer::RadixCiphertext::from(blocks) + } + + /// Cast a RadixCiphertext or SignedRadixCiphertext to a SignedRadixCiphertext + /// with a possibly different number of blocks + /// + /// # Example + /// + ///```rust + /// use tfhe::integer::{gen_keys_radix, IntegerCiphertext}; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let num_blocks = 8; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks); + /// + /// let msg = u16::MAX; + /// + /// let ct1 = cks.encrypt(msg); + /// assert_eq!(ct1.blocks().len(), num_blocks); + /// + /// let ct_res = sks.cast_to_signed(ct1, 4); + /// assert_eq!(ct_res.blocks().len(), 4); + /// + /// // Decrypt + /// let res: i8 = cks.decrypt_signed(&ct_res); + /// assert_eq!(msg as i8, res); + /// ``` + pub fn cast_to_signed( + &self, + mut source: T, + target_num_blocks: usize, + ) -> SignedRadixCiphertext { + if !source.block_carries_are_empty() { + self.full_propagate_parallelized(&mut source); + } + + let current_num_blocks = source.blocks().len(); + + let blocks = if T::IS_SIGNED { + // Casting from signed to signed + if target_num_blocks > current_num_blocks { + let mut ct_as_signed_radix = + SignedRadixCiphertext::from_blocks(source.into_blocks()); + let num_blocks_to_add = target_num_blocks - current_num_blocks; + self.extend_radix_with_sign_msb_assign(&mut ct_as_signed_radix, num_blocks_to_add); + ct_as_signed_radix.blocks + } else { + let mut ct_as_unsigned_radix = RadixCiphertext::from_blocks(source.into_blocks()); + let num_blocks_to_remove = current_num_blocks - target_num_blocks; + self.trim_radix_blocks_msb_assign(&mut ct_as_unsigned_radix, num_blocks_to_remove); + ct_as_unsigned_radix.blocks + } + } else { + // casting from unsigned to signed + let mut ct_as_unsigned_radix = RadixCiphertext::from_blocks(source.into_blocks()); + if target_num_blocks > current_num_blocks { + let num_blocks_to_add = target_num_blocks - current_num_blocks; + self.extend_radix_with_trivial_zero_blocks_msb_assign( + &mut ct_as_unsigned_radix, + num_blocks_to_add, + ); + } else { + let num_blocks_to_remove = current_num_blocks - target_num_blocks; + self.trim_radix_blocks_msb_assign(&mut ct_as_unsigned_radix, num_blocks_to_remove); + }; + ct_as_unsigned_radix.blocks + }; + + assert_eq!( + blocks.len(), + target_num_blocks, + "internal error, wrong number of blocks after casting" + ); + SignedRadixCiphertext::from_blocks(blocks) + } + /// Propagate the carry of the 'index' block to the next one. /// /// # Example