diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh index 26bd6befed..a71c2c5beb 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh @@ -37,8 +37,8 @@ __global__ void pack(Torus *array_out, Torus *array_in, uint32_t log_modulus, template __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index, - Torus *array_out, Torus *array_in, uint32_t num_inputs, - uint32_t body_count, int_compression *mem_ptr) { + Torus *array_out, Torus *array_in, uint32_t body_count, + int_compression *mem_ptr) { cudaSetDevice(gpu_index); auto params = mem_ptr->compression_params; @@ -105,7 +105,7 @@ __host__ void host_integer_compress(cudaStream_t *streams, check_cuda_error(cudaGetLastError()); host_pack(streams[0], gpu_indexes[0], glwe_array_out, tmp_glwe_array_out, - num_glwes, body_count, mem_ptr); + body_count, mem_ptr); } template diff --git a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs index d41784dae7..963af08045 100644 --- a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs @@ -80,11 +80,11 @@ use crate::core_crypto::prelude::*; #[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedModulusSwitchedGlweCiphertextVersions)] pub struct CompressedModulusSwitchedGlweCiphertext { - packed_integers: PackedIntegers, - glwe_dimension: GlweDimension, - polynomial_size: PolynomialSize, - bodies_count: LweCiphertextCount, - uncompressed_ciphertext_modulus: CiphertextModulus, + pub(crate) packed_integers: PackedIntegers, + pub(crate) glwe_dimension: GlweDimension, + pub(crate) polynomial_size: PolynomialSize, + pub(crate) bodies_count: LweCiphertextCount, + pub(crate) uncompressed_ciphertext_modulus: CiphertextModulus, } impl CompressedModulusSwitchedGlweCiphertext { diff --git a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs index 72b81e7926..68231dceb1 100644 --- a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs @@ -97,7 +97,7 @@ impl CompressedCiphertextListBuilder { #[versionize(CompressedCiphertextListVersions)] pub struct CompressedCiphertextList { pub(crate) packed_list: ShortintCompressedCiphertextList, - info: Vec, + pub(crate) info: Vec, } impl CompressedCiphertextList { diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index 8cc1fab697..6e95216f88 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -1,5 +1,8 @@ +use crate::core_crypto::entities::packed_integers::PackedIntegers; use crate::core_crypto::gpu::CudaStreams; -use crate::integer::ciphertext::DataKind; +use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext; +use crate::core_crypto::prelude::{CiphertextCount, ContiguousEntityContainer, LweCiphertextCount}; +use crate::integer::ciphertext::{CompressedCiphertextList, DataKind}; use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; use crate::integer::gpu::ciphertext::{ CudaRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext, @@ -7,6 +10,8 @@ use crate::integer::gpu::ciphertext::{ use crate::integer::gpu::list_compression::server_keys::{ CudaCompressionKey, CudaDecompressionKey, CudaPackedGlweCiphertext, }; +use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList; +use itertools::Itertools; pub struct CudaCompressedCiphertextList { pub(crate) packed_list: CudaPackedGlweCiphertext, @@ -45,6 +50,60 @@ impl CudaCompressedCiphertextList { streams, ) } + + pub fn to_compressed_ciphertext_list(&self, streams: &CudaStreams) -> CompressedCiphertextList { + let glwe_list = self + .packed_list + .glwe_ciphertext_list + .to_glwe_ciphertext_list(streams); + let ciphertext_modulus = self.packed_list.glwe_ciphertext_list.ciphertext_modulus(); + + let first_element = self.packed_list.block_info.first().unwrap(); + let message_modulus = first_element.message_modulus; + let carry_modulus = first_element.carry_modulus; + let pbs_order = first_element.pbs_order; + let lwe_per_glwe = self.packed_list.lwe_per_glwe; + let log_modulus = self.packed_list.storage_log_modulus; + + let initial_len = self.packed_list.initial_len; + let number_bits_to_pack = initial_len * log_modulus.0; + let len = number_bits_to_pack.div_ceil(u64::BITS as usize); + + let modulus_switched_glwe_ciphertext_list = glwe_list + .iter() + .map(|x| { + let glwe_dimension = x.glwe_size().to_glwe_dimension(); + let polynomial_size = x.polynomial_size(); + CompressedModulusSwitchedGlweCiphertext { + packed_integers: PackedIntegers { + packed_coeffs: x.into_container()[0..len].to_vec(), + log_modulus: self.packed_list.storage_log_modulus, + initial_len, + }, + glwe_dimension, + polynomial_size, + bodies_count: LweCiphertextCount(self.packed_list.bodies_count), + uncompressed_ciphertext_modulus: ciphertext_modulus, + } + }) + .collect_vec(); + + let count = CiphertextCount(self.packed_list.bodies_count); + let packed_list = ShortintCompressedCiphertextList { + modulus_switched_glwe_ciphertext_list, + ciphertext_modulus, + message_modulus, + carry_modulus, + pbs_order, + lwe_per_glwe, + count, + }; + + CompressedCiphertextList { + packed_list: packed_list, + info: self.info.clone(), + } + } } pub trait CudaCompressible { @@ -136,8 +195,9 @@ impl CudaCompressedCiphertextListBuilder { #[cfg(test)] mod tests { use super::*; + use crate::integer::ciphertext::CompressedCiphertextListBuilder; use crate::integer::gpu::gen_keys_radix_gpu; - use crate::integer::ClientKey; + use crate::integer::{BooleanBlock, ClientKey, RadixCiphertext, SignedRadixCiphertext}; use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64; @@ -198,4 +258,82 @@ mod tests { assert!(decrypted); } } + #[test] + fn test_gpu_compressed_ciphertext_conversion_to_cpu() { + let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64); + + let private_compression_key = + cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64); + + let streams = CudaStreams::new_multi_gpu(); + + let num_blocks = 32; + let (radix_cks, _) = gen_keys_radix_gpu( + PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + num_blocks, + &streams, + ); + let (compressed_compression_key, compressed_decompression_key) = + radix_cks.new_compressed_compression_decompression_keys(&private_compression_key); + + let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams); + + let compression_key = compressed_compression_key.decompress(); + let decompression_key = compressed_decompression_key.decompress(); + + for _ in 0..NB_TESTS { + let ct1 = radix_cks.encrypt(3_u32); + let ct2 = radix_cks.encrypt_signed(-2); + let ct3 = radix_cks.encrypt_bool(true); + + // Copy to GPU + let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams); + let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams); + let d_ct3 = CudaBooleanBlock::from_boolean_block(&ct3, &streams); + + let cuda_compressed = CudaCompressedCiphertextListBuilder::new() + .push(d_ct1, &streams) + .push(d_ct2, &streams) + .push(d_ct3, &streams) + .build(&cuda_compression_key, &streams); + + let reference_compressed = CompressedCiphertextListBuilder::new() + .push(ct1) + .push(ct2) + .push(ct3) + .build(&compression_key); + + let converted_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams); + + let decompressed1: RadixCiphertext = converted_compressed + .get(0, &decompression_key) + .unwrap() + .unwrap(); + let reference_decompressed1 = reference_compressed + .get(0, &decompression_key) + .unwrap() + .unwrap(); + assert_eq!(decompressed1, reference_decompressed1); + + let decompressed2: SignedRadixCiphertext = converted_compressed + .get(1, &decompression_key) + .unwrap() + .unwrap(); + let reference_decompressed2 = reference_compressed + .get(1, &decompression_key) + .unwrap() + .unwrap(); + assert_eq!(decompressed2, reference_decompressed2); + + let decompressed3: BooleanBlock = converted_compressed + .get(2, &decompression_key) + .unwrap() + .unwrap(); + let reference_decompressed3 = reference_compressed + .get(2, &decompression_key) + .unwrap() + .unwrap(); + assert_eq!(decompressed3, reference_decompressed3); + } + } } diff --git a/tfhe/src/integer/gpu/list_compression/server_keys.rs b/tfhe/src/integer/gpu/list_compression/server_keys.rs index bd7d2e12d5..a2ee5e3b4f 100644 --- a/tfhe/src/integer/gpu/list_compression/server_keys.rs +++ b/tfhe/src/integer/gpu/list_compression/server_keys.rs @@ -32,6 +32,8 @@ pub struct CudaPackedGlweCiphertext { pub block_info: Vec, pub bodies_count: usize, pub storage_log_modulus: CiphertextModulusLog, + pub lwe_per_glwe: LweCiphertextCount, + pub initial_len: usize, } impl CudaCompressionKey { @@ -126,10 +128,12 @@ impl CudaCompressionKey { .map(|x| x.d_blocks.lwe_ciphertext_count().0) .sum(); + let num_glwes = num_lwes.div_ceil(self.lwe_per_glwe.0); + let mut output_glwe = CudaGlweCiphertextList::new( compress_glwe_size.to_glwe_dimension(), compress_polynomial_size, - GlweCiphertextCount(ciphertexts.len()), + GlweCiphertextCount(num_glwes), ciphertext_modulus, streams, ); @@ -159,11 +163,16 @@ impl CudaCompressionKey { info }; + let initial_len = + compress_glwe_size.to_glwe_dimension().0 * compress_polynomial_size.0 + num_lwes; + CudaPackedGlweCiphertext { glwe_ciphertext_list: output_glwe, block_info: info, bodies_count: num_lwes, storage_log_modulus: self.storage_log_modulus, + lwe_per_glwe: LweCiphertextCount(compress_polynomial_size.0), + initial_len, } } }