diff --git a/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs index bb47d57fa5..57d8b10873 100644 --- a/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs @@ -137,4 +137,8 @@ impl CudaGlweCiphertextList { pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus { self.0.ciphertext_modulus } + + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self(self.0.duplicate(streams)) + } } diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 552c3d730a..b833c9741e 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -729,6 +729,24 @@ pub struct CudaGlweList { pub ciphertext_modulus: CiphertextModulus, } +impl CudaGlweList { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + let d_vec = unsafe { + let mut d_vec = CudaVec::new_async(self.d_vec.len(), streams, 0); + d_vec.copy_from_gpu_async(&self.d_vec, streams, 0); + d_vec + }; + streams.synchronize(); + + Self { + d_vec, + glwe_ciphertext_count: self.glwe_ciphertext_count, + glwe_dimension: self.glwe_dimension, + polynomial_size: self.polynomial_size, + ciphertext_modulus: self.ciphertext_modulus, + } + } +} /// Get the number of GPUs on the machine pub fn get_number_of_gpus() -> i32 { unsafe { cuda_get_number_of_gpus() } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 033918acbe..43e9a29b48 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -1,10 +1,18 @@ use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; +use super::details::MaybeCloned; +#[cfg(feature = "gpu")] +use super::global_state::with_thread_local_cuda_streams_for_gpu_indexes; use super::keys::InternalServerKey; +#[cfg(feature = "gpu")] +use super::GpuIndex; use crate::backward_compatibility::compressed_ciphertext_list::CompressedCiphertextListVersions; use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; +#[cfg(feature = "gpu")] +use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::booleans::InnerBoolean; use crate::high_level_api::errors::UninitializedServerKey; +use crate::high_level_api::global_state::device_of_internal_keys; #[cfg(feature = "gpu")] use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::{FheIntId, FheUintId}; @@ -206,13 +214,119 @@ impl CompressedCiphertextListBuilder { } } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize)] pub(crate) enum InnerCompressedCiphertextList { Cpu(crate::integer::ciphertext::CompressedCiphertextList), #[cfg(feature = "gpu")] Cuda(crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList), } +impl<'de> serde::Deserialize<'de> for InnerCompressedCiphertextList { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + enum Fake { + Cpu(crate::integer::ciphertext::CompressedCiphertextList), + #[cfg(feature = "gpu")] + Cuda(crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList), + } + let mut new = match Fake::deserialize(deserializer)? { + Fake::Cpu(v) => Self::Cpu(v), + #[cfg(feature = "gpu")] + Fake::Cuda(v) => Self::Cuda(v), + }; + + if let Some(device) = device_of_internal_keys() { + new.move_to_device(device); + } + + Ok(new) + } +} + +impl InnerCompressedCiphertextList { + fn current_device(&self) -> crate::Device { + match self { + Self::Cpu(_) => crate::Device::Cpu, + #[cfg(feature = "gpu")] + Self::Cuda(_) => crate::Device::CudaGpu, + } + } + + fn move_to_device(&mut self, device: crate::Device) { + let new_value = match (&self, device) { + (Self::Cpu(_), crate::Device::Cpu) => None, + #[cfg(feature = "gpu")] + (Self::Cuda(cuda_ct), crate::Device::CudaGpu) => { + with_thread_local_cuda_streams(|streams| { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + None + } else { + Some(Self::Cuda(cuda_ct.duplicate(streams))) + } + }) + } + #[cfg(feature = "gpu")] + (Self::Cuda(cuda_ct), crate::Device::Cpu) => { + let cpu_ct = with_thread_local_cuda_streams_for_gpu_indexes( + cuda_ct.gpu_indexes(), + |streams| cuda_ct.to_compressed_ciphertext_list(streams), + ); + Some(Self::Cpu(cpu_ct)) + } + #[cfg(feature = "gpu")] + (Self::Cpu(cpu_ct), crate::Device::CudaGpu) => { + let cuda_ct = with_thread_local_cuda_streams(|streams| { + cpu_ct.to_cuda_compressed_ciphertext_list(streams) + }); + Some(Self::Cuda(cuda_ct)) + } + }; + + if let Some(v) = new_value { + *self = v; + } + } + + fn on_cpu(&self) -> MaybeCloned { + match self { + Self::Cpu(cpu_ct) => MaybeCloned::Borrowed(cpu_ct), + #[cfg(feature = "gpu")] + Self::Cuda(cuda_ct) => { + let cpu_ct = with_thread_local_cuda_streams_for_gpu_indexes( + cuda_ct.gpu_indexes(), + |streams| cuda_ct.to_compressed_ciphertext_list(streams), + ); + MaybeCloned::Cloned(cpu_ct) + } + } + } + + #[cfg(feature = "gpu")] + fn on_gpu( + &self, + streams: &CudaStreams, + ) -> MaybeCloned< + crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList, + > { + match self { + Self::Cpu(cpu_ct) => { + let cuda_ct = cpu_ct.to_cuda_compressed_ciphertext_list(streams); + MaybeCloned::Cloned(cuda_ct) + } + Self::Cuda(cuda_ct) => { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + MaybeCloned::Borrowed(cuda_ct) + } else { + MaybeCloned::Cloned(cuda_ct.duplicate(streams)) + } + } + } + } +} + impl Versionize for InnerCompressedCiphertextList { type Versioned<'vers> = ::VersionedOwned; @@ -315,53 +429,47 @@ impl CiphertextList for CompressedCiphertextList { where T: HlExpandable + Tagged, { - match &self.inner { - InnerCompressedCiphertextList::Cpu(inner) => { - crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { - Some(InternalServerKey::Cpu(cpu_key)) => cpu_key - .key - .decompression_key - .as_ref() - .ok_or_else(|| { - crate::Error::new("Compression key not set in server key".to_owned()) - }) - .and_then(|decompression_key| { - let mut ct = inner.get::(index, decompression_key); - if let Ok(Some(ct_ref)) = &mut ct { - ct_ref.tag_mut().set_data(cpu_key.tag.data()) - } - ct - }), - _ => Err(crate::Error::new( - "A Cpu server key is needed to be set".to_string(), - )), + // We use the server key to know where computation should happen, + // if the data is not on the correct device, a temporary copy (and transfer) will happen + // + // This should be mitigated by the fact that the deserialization uses the current sks as a + // hint on where to move data. + crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { + Some(InternalServerKey::Cpu(cpu_key)) => cpu_key + .key + .decompression_key + .as_ref() + .ok_or_else(|| { + crate::Error::new("Compression key not set in server key".to_owned()) }) - } + .and_then(|decompression_key| { + let mut ct = self.inner.on_cpu().get::(index, decompression_key); + if let Ok(Some(ct_ref)) = &mut ct { + ct_ref.tag_mut().set_data(cpu_key.tag.data()) + } + ct + }), #[cfg(feature = "gpu")] - InnerCompressedCiphertextList::Cuda(inner) => { - crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { - Some(InternalServerKey::Cuda(cuda_key)) => cuda_key - .key - .decompression_key - .as_ref() - .ok_or_else(|| { - crate::Error::new("Compression key not set in server key".to_owned()) - }) - .and_then(|decompression_key| { - let mut ct = with_thread_local_cuda_streams(|streams| { - inner.get::(index, decompression_key, streams) - }); - if let Ok(Some(ct_ref)) = &mut ct { - ct_ref.tag_mut().set_data(cuda_key.tag.data()) - } - ct - }), - _ => Err(crate::Error::new( - "A Cuda server key is needed to be set".to_string(), - )), + Some(InternalServerKey::Cuda(cuda_key)) => cuda_key + .key + .decompression_key + .as_ref() + .ok_or_else(|| { + crate::Error::new("Compression key not set in server key".to_owned()) }) - } - } + .and_then(|decompression_key| { + let mut ct = with_thread_local_cuda_streams(|streams| { + self.inner + .on_gpu(streams) + .get::(index, decompression_key, streams) + }); + if let Ok(Some(ct_ref)) = &mut ct { + ct_ref.tag_mut().set_data(cuda_key.tag.data()) + } + ct + }), + None => Err(UninitializedServerKey.into()), + }) } } @@ -389,6 +497,23 @@ impl CompressedCiphertextList { tag, } } + + pub fn current_device(&self) -> crate::Device { + self.inner.current_device() + } + + pub fn move_to_current_device(&mut self) { + if let Some(device) = device_of_internal_keys() { + self.inner.move_to_device(device); + } + } + #[cfg(feature = "gpu")] + pub fn gpu_indexes(&self) -> &[GpuIndex] { + match &self.inner { + InnerCompressedCiphertextList::Cpu(_) => &[], + InnerCompressedCiphertextList::Cuda(cuda_ct) => cuda_ct.gpu_indexes(), + } + } } #[cfg(feature = "gpu")] @@ -563,13 +688,14 @@ pub mod gpu { #[cfg(test)] mod tests { use crate::prelude::*; + use crate::safe_serialization::{safe_deserialize, safe_serialize}; use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::multi_bit::tuniform::p_fail_2_minus_64::ks_pbs::V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::{ - set_server_key, ClientKey, CompressedCiphertextList, CompressedCiphertextListBuilder, - FheBool, FheInt64, FheUint16, FheUint2, FheUint32, + set_server_key, unset_server_key, ClientKey, CompressedCiphertextList, + CompressedCiphertextListBuilder, FheBool, FheInt64, FheUint16, FheUint2, FheUint32, }; #[test] @@ -600,12 +726,44 @@ mod tests { .push(ct4); set_server_key(sk.decompress()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + let compressed_list = compressed_list_builder.build().unwrap(); + + // Add a serialize-deserialize round trip as it will generally be + // how compressed list are use as its meant for data exchange + let mut serialized = vec![]; + safe_serialize(&compressed_list, &mut serialized, 1024 * 1024 * 16) + .expect("safe serialize succeeds"); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + + check_is_correct(&compressed_list, &ck); #[cfg(feature = "gpu")] { set_server_key(sk.decompress_to_gpu()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + check_is_correct(&compressed_list, &ck); + } + + // Now redo the tests, but with the server_key not being set when deserializing + // meaning, the deserialization process could not use that as a hint on where to put + // the data + { + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress()); + check_is_correct(&compressed_list, &ck); + + #[cfg(feature = "gpu")] + { + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress_to_gpu()); + check_is_correct(&compressed_list, &ck); + } } } @@ -623,24 +781,49 @@ mod tests { ct4.move_to_device(crate::Device::Cpu); let mut compressed_list_builder = CompressedCiphertextListBuilder::new(); - compressed_list_builder + let compressed_list = compressed_list_builder .push(ct1) .push(ct2) .push(ct3) - .push(ct4); + .push(ct4) + .build() + .unwrap(); + + // Add a serialize-deserialize round trip as it will generally be + // how compressed list are use as its meant for data exchange + let mut serialized = vec![]; + safe_serialize(&compressed_list, &mut serialized, 1024 * 1024 * 16) + .expect("safe serialize succeeds"); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); set_server_key(sk.decompress()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + check_is_correct(&compressed_list, &ck); set_server_key(sk.decompress_to_gpu()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + check_is_correct(&compressed_list, &ck); + + // Now redo the tests, but with the server_key not being set when deserializing + // meaning, the deserialization process could not use that as a hint on where to put + // the data + { + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress()); + check_is_correct(&compressed_list, &ck); + + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress_to_gpu()); + check_is_correct(&compressed_list, &ck); + } } fn check_is_correct(compressed_list: &CompressedCiphertextList, ck: &ClientKey) { - let serialized = bincode::serialize(&compressed_list).unwrap(); - - let compressed_list: CompressedCiphertextList = - bincode::deserialize(&serialized).unwrap(); { let a: FheUint32 = compressed_list.get(0).unwrap().unwrap(); let b: FheInt64 = compressed_list.get(1).unwrap().unwrap(); diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index b48c3a189f..756f0ab803 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -1,6 +1,7 @@ use crate::core_crypto::entities::packed_integers::PackedIntegers; use crate::core_crypto::entities::GlweCiphertextList; use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList; +use crate::core_crypto::gpu::vec::GpuIndex; use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext; use crate::core_crypto::prelude::{ @@ -66,6 +67,9 @@ pub struct CudaCompressedCiphertextList { } impl CudaCompressedCiphertextList { + pub fn gpu_indexes(&self) -> &[GpuIndex] { + &self.packed_list.glwe_ciphertext_list.0.d_vec.gpu_indexes + } pub fn len(&self) -> usize { self.info.len() } @@ -229,6 +233,13 @@ impl CudaCompressedCiphertextList { info: self.info.clone(), } } + + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + packed_list: self.packed_list.duplicate(streams), + info: self.info.clone(), + } + } } impl CompressedCiphertextList { diff --git a/tfhe/src/integer/gpu/list_compression/server_keys.rs b/tfhe/src/integer/gpu/list_compression/server_keys.rs index cafef0da98..da714593a3 100644 --- a/tfhe/src/integer/gpu/list_compression/server_keys.rs +++ b/tfhe/src/integer/gpu/list_compression/server_keys.rs @@ -47,6 +47,20 @@ pub struct CudaPackedGlweCiphertext { pub initial_len: usize, } +impl CudaPackedGlweCiphertext { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + glwe_ciphertext_list: self.glwe_ciphertext_list.duplicate(streams), + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + bodies_count: self.bodies_count, + storage_log_modulus: self.storage_log_modulus, + lwe_per_glwe: self.lwe_per_glwe, + initial_len: self.initial_len, + } + } +} + impl Clone for CudaPackedGlweCiphertext { fn clone(&self) -> Self { Self {