From 799829eab44a715ee153b2265ae2e4f380a4ab9f Mon Sep 17 00:00:00 2001
From: Agnes Leroy
Date: Mon, 5 Feb 2024 18:00:41 +0100
Subject: [PATCH] feat(gpu): cast between unsigned cuda radix ciphertexts

---
 tfhe/src/core_crypto/gpu/mod.rs              |  61 ++-
 tfhe/src/core_crypto/gpu/vec.rs              |  20 +
 tfhe/src/integer/gpu/ciphertext/mod.rs       |  67 +++
 tfhe/src/integer/gpu/server_key/mod.rs       | 195 +------
 tfhe/src/integer/gpu/server_key/radix/mod.rs | 547 +++++++++++++++++++
 5 files changed, 693 insertions(+), 197 deletions(-)

diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs
index 5502c323f3..584a0a8fc0 100644
--- a/tfhe/src/core_crypto/gpu/mod.rs
+++ b/tfhe/src/core_crypto/gpu/mod.rs
@@ -82,7 +82,6 @@ impl CudaStream {
     ///
     /// # Safety
     ///
-    /// - `dest` __must__ be a valid pointer to the GPU global memory
     /// - [CudaStream::synchronize] __must__ be called after the copy
     ///   as soon as synchronization is required
     pub unsafe fn memset_async<T>(&self, dest: &mut CudaVec<T>, value: T)
@@ -105,7 +104,6 @@ impl CudaStream {
     ///
     /// # Safety
     ///
-    /// - `dest` __must__ be a valid pointer to the GPU global memory
     /// - [CudaStream::synchronize] __must__ be called after the copy
     ///   as soon as synchronization is required
     pub unsafe fn copy_to_gpu_async<T>(&self, dest: &mut CudaVec<T>, src: &[T])
@@ -131,8 +129,6 @@ impl CudaStream {
     ///
     /// # Safety
     ///
-    /// - `src` __must__ be a valid pointer to the GPU global memory
-    /// - `dest` __must__ be a valid pointer to the GPU global memory
     /// - [CudaStream::synchronize] __must__ be called after the copy
     ///   as soon as synchronization is required
     pub unsafe fn copy_gpu_to_gpu_async<T>(&self, dest: &mut CudaVec<T>, src: &CudaVec<T>)
@@ -152,11 +148,66 @@ impl CudaStream {
         }
     }
 
+    /// Copies data between two CudaVec, selecting a range of `src` as the source
+    ///
+    /// # Safety
+    ///
+    /// - [CudaStream::synchronize] __must__ be called after the copy
+    ///   as soon as synchronization is required
+    pub unsafe fn copy_src_range_gpu_to_gpu_async<R, T>(
+        &self,
+        range: R,
+        dest: &mut CudaVec<T>,
+        src: &CudaVec<T>,
+    ) where
+        R: std::ops::RangeBounds<usize>,
+        T: Numeric,
+    {
+        let (start, end) = src.range_bounds_to_start_end(range).into_inner();
+        // size is > 0 thanks to this check
+        if end < start {
+            return;
+        }
+        assert!(end < src.len());
+        assert!(end - start < dest.len());
+
+        let src_ptr = src.as_c_ptr().add(start * std::mem::size_of::<T>());
+        let size = (end - start + 1) * std::mem::size_of::<T>();
+        cuda_memcpy_async_gpu_to_gpu(dest.as_mut_c_ptr(), src_ptr, size as u64, self.as_c_ptr());
+    }
+
+    /// Copies data between two CudaVec, selecting a range of `dest` as the target
+    ///
+    /// # Safety
+    ///
+    /// - [CudaStream::synchronize] __must__ be called after the copy
+    ///   as soon as synchronization is required
+    pub unsafe fn copy_dest_range_gpu_to_gpu_async<R, T>(
+        &self,
+        range: R,
+        dest: &mut CudaVec<T>,
+        src: &CudaVec<T>,
+    ) where
+        R: std::ops::RangeBounds<usize>,
+        T: Numeric,
+    {
+        let (start, end) = dest.range_bounds_to_start_end(range).into_inner();
+        // size is > 0 thanks to this check
+        if end < start {
+            return;
+        }
+        assert!(end < dest.len());
+        assert!(end - start < src.len());
+
+        let dest_ptr = dest.as_mut_c_ptr().add(start * std::mem::size_of::<T>());
+        let size = (end - start + 1) * std::mem::size_of::<T>();
+        cuda_memcpy_async_gpu_to_gpu(dest_ptr, src.as_c_ptr(), size as u64, self.as_c_ptr());
+    }
+
     /// Copies data from GPU pointer into slice
     ///
     /// # Safety
     ///
-    /// - `src` __must__ be a valid pointer to the GPU global memory
     /// - [CudaStream::synchronize] __must__ be called as soon as synchronization is
     ///   required
     pub unsafe fn copy_to_cpu_async<T>(&self, dest: &mut [T], src: &CudaVec<T>)
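For reference, the two `*_range_gpu_to_gpu_async` helpers resolve any Rust range into an inclusive `[start, end]` pair and then issue a single device-to-device memcpy. Below is a minimal plain-Rust sketch of the same arithmetic on CPU slices, with an already-resolved inclusive range (illustration only, not part of the patch; the name `copy_src_range` is hypothetical):

```rust
// Mirrors copy_src_range_gpu_to_gpu_async: read src[start..=end],
// write to the front of dest.
fn copy_src_range<T: Copy>(start: usize, end: usize, dest: &mut [T], src: &[T]) {
    if end < start {
        return; // empty range: the GPU helpers return before issuing a memcpy
    }
    assert!(end < src.len());
    assert!(end - start < dest.len());
    // The GPU version issues one cuda_memcpy_async_gpu_to_gpu of
    // (end - start + 1) * size_of::<T>() bytes starting at src + start.
    dest[..=end - start].copy_from_slice(&src[start..=end]);
}

fn main() {
    let src = [10u64, 11, 12, 13, 14, 15];
    let mut dest = [0u64; 4];
    copy_src_range(2, 4, &mut dest, &src);
    assert_eq!(dest, [12, 13, 14, 0]);
    assert_eq!((4 - 2 + 1) * std::mem::size_of::<u64>(), 24); // bytes copied
}
```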
diff --git a/tfhe/src/core_crypto/gpu/vec.rs b/tfhe/src/core_crypto/gpu/vec.rs
index 50ad3a0c55..d57fe545e7 100644
--- a/tfhe/src/core_crypto/gpu/vec.rs
+++ b/tfhe/src/core_crypto/gpu/vec.rs
@@ -1,5 +1,6 @@
 use crate::core_crypto::gpu::{CudaDevice, CudaPtr, CudaStream};
 use crate::core_crypto::prelude::Numeric;
+use std::collections::Bound::{Excluded, Included, Unbounded};
 use std::ffi::c_void;
 use std::marker::PhantomData;
 
@@ -66,6 +67,25 @@ impl<T: Numeric> CudaVec<T> {
     pub fn is_empty(&self) -> bool {
         self.len == 0
     }
+
+    pub(crate) fn range_bounds_to_start_end<R>(&self, range: R) -> std::ops::RangeInclusive<usize>
+    where
+        R: std::ops::RangeBounds<usize>,
+    {
+        let start = match range.start_bound() {
+            Unbounded => 0usize,
+            Included(start) => *start,
+            Excluded(start) => *start + 1,
+        };
+
+        let end = match range.end_bound() {
+            Unbounded => self.len().saturating_sub(1),
+            Included(end) => *end,
+            Excluded(end) => end.saturating_sub(1),
+        };
+
+        start..=end
+    }
 }
 
 // SAFETY
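The `range_bounds_to_start_end` helper accepts any `RangeBounds<usize>` and normalizes it to an inclusive range over the vector's indices. A standalone sketch of that mapping for a length-8 vector (illustration only; `to_start_end` is a hypothetical stand-in for the crate-private method):

```rust
use std::ops::{Bound, RangeBounds};

// Same conversion as range_bounds_to_start_end, written as a free function.
fn to_start_end<R: RangeBounds<usize>>(range: R, len: usize) -> (usize, usize) {
    let start = match range.start_bound() {
        Bound::Unbounded => 0usize,
        Bound::Included(&s) => s,
        Bound::Excluded(&s) => s + 1,
    };
    let end = match range.end_bound() {
        Bound::Unbounded => len.saturating_sub(1),
        Bound::Included(&e) => e,
        Bound::Excluded(&e) => e.saturating_sub(1),
    };
    (start, end)
}

fn main() {
    let len = 8;
    assert_eq!(to_start_end(.., len), (0, 7));    // full range
    assert_eq!(to_start_end(2..5, len), (2, 4));  // exclusive end
    assert_eq!(to_start_end(2..=5, len), (2, 5)); // inclusive end
    assert_eq!(to_start_end(3.., len), (3, 7));   // open end
    assert_eq!(to_start_end(..4, len), (0, 3));   // open start
}
```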
diff --git a/tfhe/src/integer/gpu/ciphertext/mod.rs b/tfhe/src/integer/gpu/ciphertext/mod.rs
index 070ce830bd..9ace83a1f3 100644
--- a/tfhe/src/integer/gpu/ciphertext/mod.rs
+++ b/tfhe/src/integer/gpu/ciphertext/mod.rs
@@ -389,6 +389,70 @@ impl CudaRadixCiphertextInfo {
             .collect(),
         }
     }
+
+    pub(crate) fn after_extend_radix_with_trivial_zero_blocks_lsb(
+        &self,
+        num_blocks: usize,
+    ) -> Self {
+        let mut new_block_info = Self {
+            blocks: Vec::with_capacity(self.blocks.len() + num_blocks),
+        };
+        for _ in 0..num_blocks {
+            new_block_info.blocks.push(CudaBlockInfo {
+                degree: Degree::new(0),
+                message_modulus: self.blocks.first().unwrap().message_modulus,
+                carry_modulus: self.blocks.first().unwrap().carry_modulus,
+                pbs_order: self.blocks.first().unwrap().pbs_order,
+                noise_level: NoiseLevel::ZERO,
+            });
+        }
+        for &b in self.blocks.iter() {
+            new_block_info.blocks.push(b);
+        }
+        new_block_info
+    }
+
+    pub(crate) fn after_extend_radix_with_trivial_zero_blocks_msb(
+        &self,
+        num_blocks: usize,
+    ) -> Self {
+        let mut new_block_info = Self {
+            blocks: Vec::with_capacity(self.blocks.len() + num_blocks),
+        };
+        for &b in self.blocks.iter() {
+            new_block_info.blocks.push(b);
+        }
+        for _ in 0..num_blocks {
+            new_block_info.blocks.push(CudaBlockInfo {
+                degree: Degree::new(0),
+                message_modulus: self.blocks.first().unwrap().message_modulus,
+                carry_modulus: self.blocks.first().unwrap().carry_modulus,
+                pbs_order: self.blocks.first().unwrap().pbs_order,
+                noise_level: NoiseLevel::ZERO,
+            });
+        }
+        new_block_info
+    }
+
+    pub(crate) fn after_trim_radix_blocks_lsb(&self, num_blocks: usize) -> Self {
+        let mut new_block_info = Self {
+            blocks: Vec::with_capacity(self.blocks.len().saturating_sub(num_blocks)),
+        };
+        new_block_info
+            .blocks
+            .extend(self.blocks[num_blocks..].iter().copied());
+        new_block_info
+    }
+
+    pub(crate) fn after_trim_radix_blocks_msb(&self, num_blocks: usize) -> Self {
+        let mut new_block_info = Self {
+            blocks: Vec::with_capacity(self.blocks.len().saturating_sub(num_blocks)),
+        };
+        new_block_info
+            .blocks
+            .extend(self.blocks[..self.blocks.len() - num_blocks].iter().copied());
+        new_block_info
+    }
 }
 
 // #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
@@ -399,6 +463,9 @@ pub struct CudaRadixCiphertext {
 }
 
 impl CudaRadixCiphertext {
+    pub fn new(d_blocks: CudaLweCiphertextList, info: CudaRadixCiphertextInfo) -> Self {
+        Self { d_blocks, info }
+    }
     /// Copies a RadixCiphertext to the GPU memory
     ///
     /// # Example
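The four `after_*` helpers only touch per-block metadata: trivial zero blocks enter with degree 0 and `NoiseLevel::ZERO`, and trims drop the corresponding entries so the metadata stays in sync with the trimmed LWE list (for an MSB trim of `num_blocks`, the kept slice is the `len - num_blocks` least significant entries, matching the caller). A simplified sketch (illustration only; `BlockInfo` is a hypothetical stand-in for `CudaBlockInfo`):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
struct BlockInfo {
    degree: usize,      // max value this block can currently hold
    noise_level: usize, // 0 for trivial (unencrypted) blocks
}

fn after_extend_lsb(blocks: &[BlockInfo], n: usize) -> Vec<BlockInfo> {
    // Trivial zero blocks are prepended on the LSB side: degree 0, noise 0.
    let trivial = BlockInfo { degree: 0, noise_level: 0 };
    let mut out = vec![trivial; n];
    out.extend_from_slice(blocks);
    out
}

fn after_trim_msb(blocks: &[BlockInfo], n: usize) -> Vec<BlockInfo> {
    // Drop the n most significant blocks; keep the rest unchanged.
    blocks[..blocks.len() - n].to_vec()
}

fn main() {
    let ct = vec![BlockInfo { degree: 3, noise_level: 1 }; 4];
    assert_eq!(after_extend_lsb(&ct, 2).len(), 6);
    assert_eq!(after_trim_msb(&ct, 2).len(), 2);
}
```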
diff --git a/tfhe/src/integer/gpu/server_key/mod.rs b/tfhe/src/integer/gpu/server_key/mod.rs
index 407741c975..36b0e589e4 100644
--- a/tfhe/src/integer/gpu/server_key/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/mod.rs
@@ -1,20 +1,14 @@
-use crate::core_crypto::commons::traits::contiguous_entity_container::ContiguousEntityContainerMut;
 use crate::core_crypto::gpu::lwe_bootstrap_key::CudaLweBootstrapKey;
-use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
 use crate::core_crypto::gpu::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
 use crate::core_crypto::gpu::CudaStream;
 use crate::core_crypto::prelude::{
     allocate_and_generate_new_lwe_keyswitch_key, par_allocate_and_generate_new_lwe_bootstrap_key,
     par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key, LweBootstrapKeyOwned,
-    LweCiphertextCount, LweCiphertextList, LweMultiBitBootstrapKeyOwned,
-};
-use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
-use crate::integer::gpu::ciphertext::{
-    CudaBlockInfo, CudaRadixCiphertext, CudaRadixCiphertextInfo,
+    LweMultiBitBootstrapKeyOwned,
 };
 use crate::integer::ClientKey;
-use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel};
+use crate::shortint::ciphertext::MaxDegree;
 use crate::shortint::engine::ShortintEngine;
 use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, PBSOrder};
 
@@ -45,7 +39,7 @@ pub struct CudaServerKey {
 }
 
 impl CudaServerKey {
-    /// Generates a server key that stores keys in the device memopry.
+    /// Generates a server key that stores keys in the device memory.
     ///
     /// # Example
     ///
@@ -177,187 +171,4 @@ impl CudaServerKey {
     //         }
     //     };
     // }
-
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub(crate) unsafe fn propagate_single_carry_assign_async(
-        &self,
-        ct: &mut CudaRadixCiphertext,
-        stream: &CudaStream,
-    ) {
-        let num_blocks = ct.d_blocks.lwe_ciphertext_count().0 as u32;
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                stream.propagate_single_carry_classic_assign_async(
-                    &mut ct.d_blocks.0.d_vec,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_bsk.input_lwe_dimension(),
-                    d_bsk.glwe_dimension(),
-                    d_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count(),
-                    d_bsk.decomp_base_log(),
-                    num_blocks,
-                    ct.info.blocks.first().unwrap().message_modulus,
-                    ct.info.blocks.first().unwrap().carry_modulus,
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                stream.propagate_single_carry_multibit_assign_async(
-                    &mut ct.d_blocks.0.d_vec,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_multibit_bsk.input_lwe_dimension(),
-                    d_multibit_bsk.glwe_dimension(),
-                    d_multibit_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count(),
-                    d_multibit_bsk.decomp_base_log(),
-                    d_multibit_bsk.grouping_factor,
-                    num_blocks,
-                    ct.info.blocks.first().unwrap().message_modulus,
-                    ct.info.blocks.first().unwrap().carry_modulus,
-                );
-            }
-        };
-        ct.info
-            .blocks
-            .iter_mut()
-            .for_each(|b| b.degree = Degree::new(b.message_modulus.0 - 1));
-    }
-
-    pub(crate) unsafe fn full_propagate_assign_async(
-        &self,
-        ct: &mut CudaRadixCiphertext,
-        stream: &CudaStream,
-    ) {
-        let num_blocks = ct.d_blocks.lwe_ciphertext_count().0 as u32;
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                stream.full_propagate_classic_assign_async(
-                    &mut ct.d_blocks.0.d_vec,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_bsk.input_lwe_dimension(),
-                    d_bsk.glwe_dimension(),
-                    d_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count(),
-                    d_bsk.decomp_base_log(),
-                    num_blocks,
-                    ct.info.blocks.first().unwrap().message_modulus,
-                    ct.info.blocks.first().unwrap().carry_modulus,
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                stream.full_propagate_multibit_assign_async(
-                    &mut ct.d_blocks.0.d_vec,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_multibit_bsk.input_lwe_dimension(),
-                    d_multibit_bsk.glwe_dimension(),
-                    d_multibit_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count(),
-                    d_multibit_bsk.decomp_base_log(),
-                    d_multibit_bsk.grouping_factor,
-                    num_blocks,
-                    ct.info.blocks.first().unwrap().message_modulus,
-                    ct.info.blocks.first().unwrap().carry_modulus,
-                );
-            }
-        };
-        ct.info
-            .blocks
-            .iter_mut()
-            .for_each(|b| b.degree = Degree::new(b.message_modulus.0 - 1));
-    }
-
-    /// Create a ciphertext filled with zeros
-    ///
-    /// # Example
-    ///
-    /// ```rust
-    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
-    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
-    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
-    /// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
-    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
-    ///
-    /// let gpu_index = 0;
-    /// let device = CudaDevice::new(gpu_index);
-    /// let mut stream = CudaStream::new_unchecked(device);
-    ///
-    /// let num_blocks = 4;
-    ///
-    /// // Generate the client key and the server key:
-    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
-    ///
-    /// let d_ctxt: CudaRadixCiphertext = sks.create_trivial_zero_radix(num_blocks, &mut stream);
-    /// let ctxt = d_ctxt.to_radix_ciphertext(&mut stream);
-    ///
-    /// // Decrypt:
-    /// let dec: u64 = cks.decrypt(&ctxt);
-    /// assert_eq!(0, dec);
-    /// ```
-    pub fn create_trivial_zero_radix(
-        &self,
-        num_blocks: usize,
-        stream: &CudaStream,
-    ) -> CudaRadixCiphertext {
-        self.create_trivial_radix(0, num_blocks, stream)
-    }
-
-    pub fn create_trivial_radix<T>(
-        &self,
-        scalar: T,
-        num_blocks: usize,
-        stream: &CudaStream,
-    ) -> CudaRadixCiphertext
-    where
-        T: DecomposableInto<u64>,
-    {
-        let lwe_size = match self.pbs_order {
-            PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_size(),
-            PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
-        };
-
-        let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0) as u64;
-
-        let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
-            .iter_as::<u64>()
-            .chain(std::iter::repeat(0))
-            .take(num_blocks);
-        let mut cpu_lwe_list = LweCiphertextList::new(
-            0,
-            lwe_size,
-            LweCiphertextCount(num_blocks),
-            self.ciphertext_modulus,
-        );
-        let mut info = Vec::with_capacity(num_blocks);
-        for (block_value, mut lwe) in decomposer.zip(cpu_lwe_list.iter_mut()) {
-            *lwe.get_mut_body().data = block_value * delta;
-            info.push(CudaBlockInfo {
-                degree: Degree::new(block_value as usize),
-                message_modulus: self.message_modulus,
-                carry_modulus: self.carry_modulus,
-                pbs_order: self.pbs_order,
-                noise_level: NoiseLevel::ZERO,
-            });
-        }
-
-        let d_blocks = CudaLweCiphertextList::from_lwe_ciphertext_list(&cpu_lwe_list, stream);
-
-        CudaRadixCiphertext {
-            d_blocks,
-            info: CudaRadixCiphertextInfo { blocks: info },
-        }
-    }
 }
diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs
index 4358e261f8..40a9f68a14 100644
--- a/tfhe/src/integer/gpu/server_key/radix/mod.rs
+++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs
@@ -1,3 +1,16 @@
+use crate::core_crypto::entities::LweCiphertextList;
+use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
+use crate::core_crypto::gpu::CudaStream;
+use crate::core_crypto::prelude::{ContiguousEntityContainerMut, LweCiphertextCount};
+use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
+use crate::integer::gpu::ciphertext::{
+    CudaBlockInfo, CudaRadixCiphertext, CudaRadixCiphertextInfo,
+};
+use crate::integer::gpu::server_key::CudaBootstrappingKey;
+use crate::integer::gpu::CudaServerKey;
+use crate::shortint::ciphertext::{Degree, NoiseLevel};
+use crate::shortint::PBSOrder;
+
 mod add;
 mod bitwise_op;
 mod cmux;
@@ -15,3 +28,537 @@ mod sub;
 mod scalar_rotate;
 #[cfg(test)]
 mod tests;
+impl CudaServerKey {
+    /// Create a trivial ciphertext filled with zeros
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// let num_blocks = 4;
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let d_ctxt: CudaRadixCiphertext = sks.create_trivial_zero_radix(num_blocks, &mut stream);
+    /// let ctxt = d_ctxt.to_radix_ciphertext(&mut stream);
+    ///
+    /// // Decrypt:
+    /// let dec: u64 = cks.decrypt(&ctxt);
+    /// assert_eq!(0, dec);
+    /// ```
+    pub fn create_trivial_zero_radix(
+        &self,
+        num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext {
+        self.create_trivial_radix(0, num_blocks, stream)
+    }
+
+    /// Create a trivial ciphertext
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::{gen_keys_radix, RadixCiphertext};
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// let num_blocks = 4;
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let d_ctxt: CudaRadixCiphertext = sks.create_trivial_radix(212u64, num_blocks, &mut stream);
+    /// let ctxt = d_ctxt.to_radix_ciphertext(&mut stream);
+    ///
+    /// // Decrypt:
+    /// let dec: u64 = cks.decrypt(&ctxt);
+    /// assert_eq!(212, dec);
+    /// ```
+    pub fn create_trivial_radix<T>(
+        &self,
+        scalar: T,
+        num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext
+    where
+        T: DecomposableInto<u64>,
+    {
+        let lwe_size = match self.pbs_order {
+            PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_size(),
+            PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(),
+        };
+
+        let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0) as u64;
+
+        let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
+            .iter_as::<u64>()
+            .chain(std::iter::repeat(0))
+            .take(num_blocks);
+        let mut cpu_lwe_list = LweCiphertextList::new(
+            0,
+            lwe_size,
+            LweCiphertextCount(num_blocks),
+            self.ciphertext_modulus,
+        );
+        let mut info = Vec::with_capacity(num_blocks);
+        for (block_value, mut lwe) in decomposer.zip(cpu_lwe_list.iter_mut()) {
+            *lwe.get_mut_body().data = block_value * delta;
+            info.push(CudaBlockInfo {
+                degree: Degree::new(block_value as usize),
+                message_modulus: self.message_modulus,
+                carry_modulus: self.carry_modulus,
+                pbs_order: self.pbs_order,
+                noise_level: NoiseLevel::ZERO,
+            });
+        }
+
+        let d_blocks = CudaLweCiphertextList::from_lwe_ciphertext_list(&cpu_lwe_list, stream);
+
+        CudaRadixCiphertext {
+            d_blocks,
+            info: CudaRadixCiphertextInfo { blocks: info },
+        }
+    }
+
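The trivial encoding above places each base-`message_modulus` digit of the scalar in the top bits of an LWE body, scaled by `delta`. A worked example of that arithmetic for `PARAM_MESSAGE_2_CARRY_2` (message and carry moduli both 4), illustration only, not part of the patch:

```rust
fn main() {
    let message_modulus = 4u64;
    let carry_modulus = 4u64;
    // Each block stores its value in the top bits: delta = 2^63 / (4 * 4) = 2^59.
    let delta = (1u64 << 63) / (message_modulus * carry_modulus);
    assert_eq!(delta, 1u64 << 59);

    // 212 decomposed into base-4 blocks, least significant first.
    let scalar = 212u64;
    let num_blocks = 4u32;
    let blocks: Vec<u64> = (0..num_blocks)
        .map(|i| (scalar >> (2 * i)) % message_modulus)
        .collect();
    assert_eq!(blocks, vec![0, 1, 1, 3]); // 0 + 1*4 + 1*16 + 3*64 = 212

    // A trivial block's LWE body is block_value * delta (the mask is all zeros).
    let bodies: Vec<u64> = blocks.iter().map(|b| b * delta).collect();
    assert_eq!(bodies[3], 3u64 << 59);
}
```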
+    /// # Safety
+    ///
+    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
+    ///   not be dropped until stream is synchronized
+    pub(crate) unsafe fn propagate_single_carry_assign_async(
+        &self,
+        ct: &mut CudaRadixCiphertext,
+        stream: &CudaStream,
+    ) {
+        let num_blocks = ct.d_blocks.lwe_ciphertext_count().0 as u32;
+        match &self.bootstrapping_key {
+            CudaBootstrappingKey::Classic(d_bsk) => {
+                stream.propagate_single_carry_classic_assign_async(
+                    &mut ct.d_blocks.0.d_vec,
+                    &d_bsk.d_vec,
+                    &self.key_switching_key.d_vec,
+                    d_bsk.input_lwe_dimension(),
+                    d_bsk.glwe_dimension(),
+                    d_bsk.polynomial_size(),
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_bsk.decomp_level_count(),
+                    d_bsk.decomp_base_log(),
+                    num_blocks,
+                    ct.info.blocks.first().unwrap().message_modulus,
+                    ct.info.blocks.first().unwrap().carry_modulus,
+                );
+            }
+            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                stream.propagate_single_carry_multibit_assign_async(
+                    &mut ct.d_blocks.0.d_vec,
+                    &d_multibit_bsk.d_vec,
+                    &self.key_switching_key.d_vec,
+                    d_multibit_bsk.input_lwe_dimension(),
+                    d_multibit_bsk.glwe_dimension(),
+                    d_multibit_bsk.polynomial_size(),
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_multibit_bsk.decomp_level_count(),
+                    d_multibit_bsk.decomp_base_log(),
+                    d_multibit_bsk.grouping_factor,
+                    num_blocks,
+                    ct.info.blocks.first().unwrap().message_modulus,
+                    ct.info.blocks.first().unwrap().carry_modulus,
+                );
+            }
+        };
+        ct.info
+            .blocks
+            .iter_mut()
+            .for_each(|b| b.degree = Degree::new(b.message_modulus.0 - 1));
+    }
+
+    /// # Safety
+    ///
+    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
+    ///   not be dropped until stream is synchronized
+    pub(crate) unsafe fn full_propagate_assign_async(
+        &self,
+        ct: &mut CudaRadixCiphertext,
+        stream: &CudaStream,
+    ) {
+        let num_blocks = ct.d_blocks.lwe_ciphertext_count().0 as u32;
+        match &self.bootstrapping_key {
+            CudaBootstrappingKey::Classic(d_bsk) => {
+                stream.full_propagate_classic_assign_async(
+                    &mut ct.d_blocks.0.d_vec,
+                    &d_bsk.d_vec,
+                    &self.key_switching_key.d_vec,
+                    d_bsk.input_lwe_dimension(),
+                    d_bsk.glwe_dimension(),
+                    d_bsk.polynomial_size(),
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_bsk.decomp_level_count(),
+                    d_bsk.decomp_base_log(),
+                    num_blocks,
+                    ct.info.blocks.first().unwrap().message_modulus,
+                    ct.info.blocks.first().unwrap().carry_modulus,
+                );
+            }
+            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                stream.full_propagate_multibit_assign_async(
+                    &mut ct.d_blocks.0.d_vec,
+                    &d_multibit_bsk.d_vec,
+                    &self.key_switching_key.d_vec,
+                    d_multibit_bsk.input_lwe_dimension(),
+                    d_multibit_bsk.glwe_dimension(),
+                    d_multibit_bsk.polynomial_size(),
+                    self.key_switching_key.decomposition_level_count(),
+                    self.key_switching_key.decomposition_base_log(),
+                    d_multibit_bsk.decomp_level_count(),
+                    d_multibit_bsk.decomp_base_log(),
+                    d_multibit_bsk.grouping_factor,
+                    num_blocks,
+                    ct.info.blocks.first().unwrap().message_modulus,
+                    ct.info.blocks.first().unwrap().carry_modulus,
+                );
+            }
+        };
+        ct.info
+            .blocks
+            .iter_mut()
+            .for_each(|b| b.degree = Degree::new(b.message_modulus.0 - 1));
+    }
+
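Both propagation entry points leave every block holding a value below `message_modulus`, folding each block's carry into the next, which is why the code resets every degree to `message_modulus - 1` afterwards. On clear digits the computation looks like this (illustration only, not part of the patch):

```rust
fn propagate_carries(blocks: &mut [u64], message_modulus: u64) {
    let mut carry = 0;
    for b in blocks.iter_mut() {
        let v = *b + carry;
        *b = v % message_modulus; // keep the message part
        carry = v / message_modulus; // fold the carry into the next block
    }
    // A carry left over here would not fit in the radix representation.
}

fn main() {
    // [6, 4, 2] encodes 6 + 4*4 + 2*16 = 54 with unpropagated carries.
    let mut blocks = vec![6u64, 4, 2];
    propagate_carries(&mut blocks, 4);
    // After propagation every block is < 4 and the value is unchanged.
    assert_eq!(blocks, vec![2, 1, 3]); // 2 + 1*4 + 3*16 = 54
}
```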
+    /// Prepend trivial zero LSB blocks to an existing [`CudaRadixCiphertext`] and return the
+    /// result as a new [`CudaRadixCiphertext`]. This can be useful for casting operations.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::IntegerCiphertext;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let num_blocks = 4;
+    ///
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let mut d_ct1: CudaRadixCiphertext = sks.create_trivial_radix(7u64, num_blocks, &mut stream);
+    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct1.blocks().len(), 4);
+    ///
+    /// let added_blocks = 2;
+    /// let d_ct_res =
+    ///     sks.extend_radix_with_trivial_zero_blocks_lsb(&mut d_ct1, added_blocks, &mut stream);
+    /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct_res.blocks().len(), 6);
+    ///
+    /// // Decrypt
+    /// let res: u64 = cks.decrypt(&ct_res);
+    /// assert_eq!(
+    ///     7 * (PARAM_MESSAGE_2_CARRY_2_KS_PBS.message_modulus.0 as u64).pow(added_blocks as u32),
+    ///     res
+    /// );
+    /// ```
+    pub fn extend_radix_with_trivial_zero_blocks_lsb(
+        &self,
+        ct: &CudaRadixCiphertext,
+        num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext {
+        let new_num_blocks = ct.d_blocks.lwe_ciphertext_count().0 + num_blocks;
+        let ciphertext_modulus = ct.d_blocks.ciphertext_modulus();
+        let lwe_size = ct.d_blocks.lwe_dimension().to_lwe_size();
+        let shift = num_blocks * lwe_size.0;
+
+        let mut extended_ct_vec =
+            unsafe { stream.malloc_async((new_num_blocks * lwe_size.0) as u32) };
+        unsafe {
+            stream.memset_async(&mut extended_ct_vec, 0u64);
+            stream.copy_dest_range_gpu_to_gpu_async(
+                shift..,
+                &mut extended_ct_vec,
+                &ct.d_blocks.0.d_vec,
+            );
+        }
+        stream.synchronize();
+        let extended_ct_list = CudaLweCiphertextList::from_cuda_vec(
+            extended_ct_vec,
+            LweCiphertextCount(new_num_blocks),
+            ciphertext_modulus,
+        );
+
+        let extended_ct_info = ct
+            .info
+            .after_extend_radix_with_trivial_zero_blocks_lsb(num_blocks);
+        CudaRadixCiphertext::new(extended_ct_list, extended_ct_info)
+    }
+
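On the cleartext side, prepending `k` trivial zero LSB blocks multiplies the value by `message_modulus^k`, which is exactly what the doctest above asserts (7 becomes 112). A one-screen check of that arithmetic (illustration only):

```rust
fn main() {
    let message_modulus = 4u64;
    let added_blocks = 2u32;
    let value = 7u64;
    // Blocks are base-4 digits, least significant first: 7 = [3, 1].
    // Prepending two zero digits gives [0, 0, 3, 1] = 3*16 + 1*64 = 112.
    let shifted = value * message_modulus.pow(added_blocks);
    assert_eq!(shifted, 112);
}
```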
+    /// Append trivial zero MSB blocks to an existing [`CudaRadixCiphertext`] and return the result
+    /// as a new [`CudaRadixCiphertext`]. This can be useful for casting operations.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::IntegerCiphertext;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let num_blocks = 4;
+    ///
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let mut d_ct1: CudaRadixCiphertext = sks.create_trivial_radix(7u64, num_blocks, &mut stream);
+    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct1.blocks().len(), 4);
+    ///
+    /// let d_ct_res = sks.extend_radix_with_trivial_zero_blocks_msb(&d_ct1, 2, &mut stream);
+    /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct_res.blocks().len(), 6);
+    ///
+    /// // Decrypt
+    /// let res: u64 = cks.decrypt(&ct_res);
+    /// assert_eq!(7, res);
+    /// ```
+    pub fn extend_radix_with_trivial_zero_blocks_msb(
+        &self,
+        ct: &CudaRadixCiphertext,
+        num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext {
+        let new_num_blocks = ct.d_blocks.lwe_ciphertext_count().0 + num_blocks;
+        let ciphertext_modulus = ct.d_blocks.ciphertext_modulus();
+        let lwe_size = ct.d_blocks.lwe_dimension().to_lwe_size();
+
+        let mut extended_ct_vec =
+            unsafe { stream.malloc_async((new_num_blocks * lwe_size.0) as u32) };
+        unsafe {
+            stream.memset_async(&mut extended_ct_vec, 0u64);
+            stream.copy_gpu_to_gpu_async(&mut extended_ct_vec, &ct.d_blocks.0.d_vec);
+        }
+        stream.synchronize();
+        let extended_ct_list = CudaLweCiphertextList::from_cuda_vec(
+            extended_ct_vec,
+            LweCiphertextCount(new_num_blocks),
+            ciphertext_modulus,
+        );
+
+        let extended_ct_info = ct
+            .info
+            .after_extend_radix_with_trivial_zero_blocks_msb(num_blocks);
+        CudaRadixCiphertext::new(extended_ct_list, extended_ct_info)
+    }
+
+    /// Remove LSB blocks from an existing [`CudaRadixCiphertext`] and return the result as a new
+    /// [`CudaRadixCiphertext`]. This can be useful for casting operations.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::IntegerCiphertext;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let num_blocks = 4;
+    ///
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let mut d_ct1: CudaRadixCiphertext = sks.create_trivial_radix(119u64, num_blocks, &mut stream);
+    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct1.blocks().len(), 4);
+    ///
+    /// let d_ct_res = sks.trim_radix_blocks_lsb(&d_ct1, 2, &mut stream);
+    /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct_res.blocks().len(), 2);
+    ///
+    /// // Decrypt
+    /// let res: u64 = cks.decrypt(&ct_res);
+    /// assert_eq!(7, res);
+    /// ```
+    pub fn trim_radix_blocks_lsb(
+        &self,
+        ct: &CudaRadixCiphertext,
+        num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext {
+        let new_num_blocks = ct.d_blocks.lwe_ciphertext_count().0 - num_blocks;
+        let ciphertext_modulus = ct.d_blocks.ciphertext_modulus();
+        let lwe_size = ct.d_blocks.lwe_dimension().to_lwe_size();
+        let shift = num_blocks * lwe_size.0;
+
+        let mut trimmed_ct_vec =
+            unsafe { stream.malloc_async((new_num_blocks * lwe_size.0) as u32) };
+        unsafe {
+            stream.copy_src_range_gpu_to_gpu_async(
+                shift..,
+                &mut trimmed_ct_vec,
+                &ct.d_blocks.0.d_vec,
+            );
+        }
+        stream.synchronize();
+        let trimmed_ct_list = CudaLweCiphertextList::from_cuda_vec(
+            trimmed_ct_vec,
+            LweCiphertextCount(new_num_blocks),
+            ciphertext_modulus,
+        );
+
+        let trimmed_ct_info = ct.info.after_trim_radix_blocks_lsb(num_blocks);
+        CudaRadixCiphertext::new(trimmed_ct_list, trimmed_ct_info)
+    }
+
+    /// Remove MSB blocks from an existing [`CudaRadixCiphertext`] and return the result as a new
+    /// [`CudaRadixCiphertext`]. This can be useful for casting operations.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::IntegerCiphertext;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let num_blocks = 4;
+    ///
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let mut d_ct1: CudaRadixCiphertext = sks.create_trivial_radix(119u64, num_blocks, &mut stream);
+    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct1.blocks().len(), 4);
+    ///
+    /// let d_ct_res = sks.trim_radix_blocks_msb(&d_ct1, 2, &mut stream);
+    /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct_res.blocks().len(), 2);
+    ///
+    /// // Decrypt
+    /// let res: u64 = cks.decrypt(&ct_res);
+    /// assert_eq!(7, res);
+    /// ```
+    pub fn trim_radix_blocks_msb(
+        &self,
+        ct: &CudaRadixCiphertext,
+        num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext {
+        let new_num_blocks = ct.d_blocks.lwe_ciphertext_count().0 - num_blocks;
+        let ciphertext_modulus = ct.d_blocks.ciphertext_modulus();
+        let lwe_size = ct.d_blocks.lwe_dimension().to_lwe_size();
+        let shift = new_num_blocks * lwe_size.0;
+
+        let mut trimmed_ct_vec =
+            unsafe { stream.malloc_async((new_num_blocks * lwe_size.0) as u32) };
+        unsafe {
+            stream.copy_src_range_gpu_to_gpu_async(
+                0..shift,
+                &mut trimmed_ct_vec,
+                &ct.d_blocks.0.d_vec,
+            );
+        }
+        stream.synchronize();
+        let trimmed_ct_list = CudaLweCiphertextList::from_cuda_vec(
+            trimmed_ct_vec,
+            LweCiphertextCount(new_num_blocks),
+            ciphertext_modulus,
+        );
+
+        let trimmed_ct_info = ct.info.after_trim_radix_blocks_msb(num_blocks);
+        CudaRadixCiphertext::new(trimmed_ct_list, trimmed_ct_info)
+    }
+
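On the cleartext side the two trims act as floor division (LSB) and modular reduction (MSB), which is why both doctests above decrypt to 7 (illustration only, not part of the patch):

```rust
fn main() {
    let message_modulus = 4u64;
    let num_blocks = 4u32;
    // 119 = 3 + 1*4 + 3*16 + 1*64, i.e. base-4 blocks [3, 1, 3, 1].
    let value = 119u64;
    let trimmed = 2u32;

    // Trimming LSB blocks floor-divides by message_modulus^k: keeps [3, 1] (MSB side).
    assert_eq!(value / message_modulus.pow(trimmed), 7);

    // Trimming MSB blocks reduces modulo message_modulus^(n-k): keeps [3, 1] (LSB side).
    assert_eq!(value % message_modulus.pow(num_blocks - trimmed), 7);
}
```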
+    /// Cast a CudaRadixCiphertext to a CudaRadixCiphertext
+    /// with a possibly different number of blocks
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::core_crypto::gpu::{CudaDevice, CudaStream};
+    /// use tfhe::integer::gpu::ciphertext::CudaRadixCiphertext;
+    /// use tfhe::integer::gpu::gen_keys_radix_gpu;
+    /// use tfhe::integer::IntegerCiphertext;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// let num_blocks = 4;
+    /// let gpu_index = 0;
+    /// let device = CudaDevice::new(gpu_index);
+    /// let mut stream = CudaStream::new_unchecked(device);
+    ///
+    /// // Generate the client key and the server key:
+    /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &mut stream);
+    ///
+    /// let msg = 2u8;
+    ///
+    /// let mut d_ct1: CudaRadixCiphertext = sks.create_trivial_radix(msg, num_blocks, &mut stream);
+    /// let ct1 = d_ct1.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct1.blocks().len(), 4);
+    ///
+    /// let d_ct_res = sks.cast_to_unsigned(d_ct1, 8, &mut stream);
+    /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream);
+    /// assert_eq!(ct_res.blocks().len(), 8);
+    ///
+    /// // Decrypt
+    /// let res: u16 = cks.decrypt(&ct_res);
+    /// assert_eq!(msg as u16, res);
+    /// ```
+    pub fn cast_to_unsigned(
+        &self,
+        mut source: CudaRadixCiphertext,
+        target_num_blocks: usize,
+        stream: &CudaStream,
+    ) -> CudaRadixCiphertext {
+        if !source.block_carries_are_empty() {
+            unsafe {
+                self.full_propagate_assign_async(&mut source, stream);
+            }
+            stream.synchronize();
+        }
+        let current_num_blocks = source.info.blocks.len();
+        // Casting from unsigned to unsigned, this is just about trimming/extending with zeros
+        if target_num_blocks > current_num_blocks {
+            let num_blocks_to_add = target_num_blocks - current_num_blocks;
+            self.extend_radix_with_trivial_zero_blocks_msb(&source, num_blocks_to_add, stream)
+        } else {
+            let num_blocks_to_remove = current_num_blocks - target_num_blocks;
+            self.trim_radix_blocks_msb(&source, num_blocks_to_remove, stream)
+        }
+    }
+}
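Taken together, `cast_to_unsigned` first flushes any pending carries, then zero-extends or truncates on the MSB side. On cleartexts (with carries already propagated) that is value-preserving when growing and a reduction modulo `message_modulus^target_num_blocks` when shrinking. A sketch of that contract (illustration only; `cast_to_unsigned_clear` is a hypothetical name):

```rust
// Cleartext model of the unsigned-to-unsigned cast above.
fn cast_to_unsigned_clear(value: u64, message_modulus: u64, target_num_blocks: u32) -> u64 {
    value % message_modulus.pow(target_num_blocks)
}

fn main() {
    // 4 blocks -> 8 blocks: value preserved (zero extension on the MSB side).
    assert_eq!(cast_to_unsigned_clear(2, 4, 8), 2);
    // 4 blocks -> 2 blocks: MSB blocks dropped, value reduced mod 4^2.
    assert_eq!(cast_to_unsigned_clear(119, 4, 2), 7);
}
```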