From 8f94ae5f858d8f83be78192e39d7baafd1525bcf Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Thu, 30 Jan 2025 12:42:28 +0000 Subject: [PATCH] fix(hlapi): ciphertext list decompress after safe_deser After a safe_serialize/safe_deserialize, the CompressedCiphertextList was on Cpu. As the `get` method looked at the device of the data and not the device of the server_key to know where computation needs to happen, it meant that in this case decompressing using Gpu was impossible, only Cpu was usable (as data was always onlu on Cpu) The fix is twofold: * First, when deserializing, the data will use the current serverkey (if any) as a hint on where data should be placed * the `get` method now uses the server_key to know where computations needs to be done, which may incur a temporary copy/transfer on every call to `get` if the device is not correct. The API to move data has also been added Note that this was not the case when using regular serialize/deserialize as this would store the device, so that deserialize was able to restore into the same device (hence why the test which use serialie/deserialize did not fail). In hindsight, the ser/de impl should not save which device the data originated from --- .../gpu/entities/glwe_ciphertext_list.rs | 4 + tfhe/src/core_crypto/gpu/mod.rs | 18 ++ .../compressed_ciphertext_list.rs | 295 ++++++++++++++---- .../ciphertext/compressed_ciphertext_list.rs | 11 + .../gpu/list_compression/server_keys.rs | 14 + 5 files changed, 285 insertions(+), 57 deletions(-) diff --git a/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs index bb47d57fa5..57d8b10873 100644 --- a/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs @@ -137,4 +137,8 @@ impl CudaGlweCiphertextList { pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus { self.0.ciphertext_modulus } + + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self(self.0.duplicate(streams)) + } } diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 8bed6d2661..b8324810fe 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -729,6 +729,24 @@ pub struct CudaGlweList { pub ciphertext_modulus: CiphertextModulus, } +impl CudaGlweList { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + let d_vec = unsafe { + let mut d_vec = CudaVec::new_async(self.d_vec.len(), streams, 0); + d_vec.copy_from_gpu_async(&self.d_vec, streams, 0); + d_vec + }; + streams.synchronize(); + + Self { + d_vec, + glwe_ciphertext_count: self.glwe_ciphertext_count, + glwe_dimension: self.glwe_dimension, + polynomial_size: self.polynomial_size, + ciphertext_modulus: self.ciphertext_modulus, + } + } +} /// Get the number of GPUs on the machine pub fn get_number_of_gpus() -> i32 { unsafe { cuda_get_number_of_gpus() } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 033918acbe..3ce480eaac 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -1,10 +1,18 @@ use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; +use super::details::MaybeCloned; +#[cfg(feature = "gpu")] +use super::global_state::with_thread_local_cuda_streams_for_gpu_indexes; use super::keys::InternalServerKey; +#[cfg(feature = "gpu")] +use super::GpuIndex; use crate::backward_compatibility::compressed_ciphertext_list::CompressedCiphertextListVersions; use crate::core_crypto::commons::math::random::{Deserialize, Serialize}; +#[cfg(feature = "gpu")] +use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::booleans::InnerBoolean; use crate::high_level_api::errors::UninitializedServerKey; +use crate::high_level_api::global_state::device_of_internal_keys; #[cfg(feature = "gpu")] use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::{FheIntId, FheUintId}; @@ -206,13 +214,119 @@ impl CompressedCiphertextListBuilder { } } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize)] pub(crate) enum InnerCompressedCiphertextList { Cpu(crate::integer::ciphertext::CompressedCiphertextList), #[cfg(feature = "gpu")] Cuda(crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList), } +impl<'de> serde::Deserialize<'de> for InnerCompressedCiphertextList { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + enum Fake { + Cpu(crate::integer::ciphertext::CompressedCiphertextList), + #[cfg(feature = "gpu")] + Cuda(crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList), + } + let mut new = match Fake::deserialize(deserializer)? { + Fake::Cpu(v) => Self::Cpu(v), + #[cfg(feature = "gpu")] + Fake::Cuda(v) => Self::Cuda(v), + }; + + if let Some(device) = device_of_internal_keys() { + new.move_to_device(device); + } + + Ok(new) + } +} + +impl InnerCompressedCiphertextList { + fn current_device(&self) -> crate::Device { + match self { + Self::Cpu(_) => crate::Device::Cpu, + #[cfg(feature = "gpu")] + Self::Cuda(_) => crate::Device::CudaGpu, + } + } + + fn move_to_device(&mut self, device: crate::Device) { + let new_value = match (&self, device) { + (Self::Cpu(_), crate::Device::Cpu) => None, + #[cfg(feature = "gpu")] + (Self::Cuda(cuda_ct), crate::Device::CudaGpu) => { + with_thread_local_cuda_streams(|streams| { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + None + } else { + Some(Self::Cuda(cuda_ct.duplicate(streams))) + } + }) + } + #[cfg(feature = "gpu")] + (Self::Cuda(cuda_ct), crate::Device::Cpu) => { + let cpu_ct = with_thread_local_cuda_streams_for_gpu_indexes( + cuda_ct.gpu_indexes(), + |streams| cuda_ct.to_compressed_ciphertext_list(streams), + ); + Some(Self::Cpu(cpu_ct)) + } + #[cfg(feature = "gpu")] + (Self::Cpu(cpu_ct), crate::Device::CudaGpu) => { + let cuda_ct = with_thread_local_cuda_streams(|streams| { + cpu_ct.to_cuda_compressed_ciphertext_list(streams) + }); + Some(Self::Cuda(cuda_ct)) + } + }; + + if let Some(v) = new_value { + *self = v; + } + } + + fn on_cpu(&self) -> MaybeCloned { + match self { + Self::Cpu(cpu_ct) => MaybeCloned::Borrowed(cpu_ct), + #[cfg(feature = "gpu")] + Self::Cuda(cuda_ct) => { + let cpu_ct = with_thread_local_cuda_streams_for_gpu_indexes( + cuda_ct.gpu_indexes(), + |streams| cuda_ct.to_compressed_ciphertext_list(streams), + ); + MaybeCloned::Cloned(cpu_ct) + } + } + } + + #[cfg(feature = "gpu")] + fn on_gpu( + &self, + streams: &CudaStreams, + ) -> MaybeCloned< + crate::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextList, + > { + match self { + Self::Cpu(cpu_ct) => { + let cuda_ct = cpu_ct.to_cuda_compressed_ciphertext_list(streams); + MaybeCloned::Cloned(cuda_ct) + } + Self::Cuda(cuda_ct) => { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + MaybeCloned::Borrowed(cuda_ct) + } else { + MaybeCloned::Cloned(cuda_ct.duplicate(streams)) + } + } + } + } +} + impl Versionize for InnerCompressedCiphertextList { type Versioned<'vers> = ::VersionedOwned; @@ -315,53 +429,47 @@ impl CiphertextList for CompressedCiphertextList { where T: HlExpandable + Tagged, { - match &self.inner { - InnerCompressedCiphertextList::Cpu(inner) => { - crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { - Some(InternalServerKey::Cpu(cpu_key)) => cpu_key - .key - .decompression_key - .as_ref() - .ok_or_else(|| { - crate::Error::new("Compression key not set in server key".to_owned()) - }) - .and_then(|decompression_key| { - let mut ct = inner.get::(index, decompression_key); - if let Ok(Some(ct_ref)) = &mut ct { - ct_ref.tag_mut().set_data(cpu_key.tag.data()) - } - ct - }), - _ => Err(crate::Error::new( - "A Cpu server key is needed to be set".to_string(), - )), + // We use the server key to know where computation should happen, + // if the data is not on the correct device, a temporary copy (and transfer) will happen + // + // This should be mitigated by the fact that the deserialization uses the current sks as a + // hint on where to move data. + crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { + Some(InternalServerKey::Cpu(cpu_key)) => cpu_key + .key + .decompression_key + .as_ref() + .ok_or_else(|| { + crate::Error::new("Compression key not set in server key".to_owned()) }) - } + .and_then(|decompression_key| { + let mut ct = self.inner.on_cpu().get::(index, decompression_key); + if let Ok(Some(ct_ref)) = &mut ct { + ct_ref.tag_mut().set_data(cpu_key.tag.data()) + } + ct + }), #[cfg(feature = "gpu")] - InnerCompressedCiphertextList::Cuda(inner) => { - crate::high_level_api::global_state::try_with_internal_keys(|keys| match keys { - Some(InternalServerKey::Cuda(cuda_key)) => cuda_key - .key - .decompression_key - .as_ref() - .ok_or_else(|| { - crate::Error::new("Compression key not set in server key".to_owned()) - }) - .and_then(|decompression_key| { - let mut ct = with_thread_local_cuda_streams(|streams| { - inner.get::(index, decompression_key, streams) - }); - if let Ok(Some(ct_ref)) = &mut ct { - ct_ref.tag_mut().set_data(cuda_key.tag.data()) - } - ct - }), - _ => Err(crate::Error::new( - "A Cuda server key is needed to be set".to_string(), - )), + Some(InternalServerKey::Cuda(cuda_key)) => cuda_key + .key + .decompression_key + .as_ref() + .ok_or_else(|| { + crate::Error::new("Compression key not set in server key".to_owned()) }) - } - } + .and_then(|decompression_key| { + let mut ct = with_thread_local_cuda_streams(|streams| { + self.inner + .on_gpu(streams) + .get::(index, decompression_key, streams) + }); + if let Ok(Some(ct_ref)) = &mut ct { + ct_ref.tag_mut().set_data(cuda_key.tag.data()) + } + ct + }), + None => Err(UninitializedServerKey.into()), + }) } } @@ -389,6 +497,23 @@ impl CompressedCiphertextList { tag, } } + + pub fn current_device(&self) -> crate::Device { + self.inner.current_device() + } + + pub fn move_to_current_device(&mut self) { + if let Some(device) = device_of_internal_keys() { + self.inner.move_to_device(device); + } + } + #[cfg(feature = "gpu")] + pub fn gpu_indexes(&self) -> &[GpuIndex] { + match &self.inner { + InnerCompressedCiphertextList::Cpu(_) => &[], + InnerCompressedCiphertextList::Cuda(cuda_ct) => cuda_ct.gpu_indexes(), + } + } } #[cfg(feature = "gpu")] @@ -563,13 +688,14 @@ pub mod gpu { #[cfg(test)] mod tests { use crate::prelude::*; + use crate::safe_serialization::{safe_deserialize, safe_serialize}; use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::multi_bit::tuniform::p_fail_2_minus_64::ks_pbs::V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::PBSParameters; use crate::{ - set_server_key, ClientKey, CompressedCiphertextList, CompressedCiphertextListBuilder, - FheBool, FheInt64, FheUint16, FheUint2, FheUint32, + set_server_key, unset_server_key, ClientKey, CompressedCiphertextList, + CompressedCiphertextListBuilder, FheBool, FheInt64, FheUint16, FheUint2, FheUint32, }; #[test] @@ -600,12 +726,43 @@ mod tests { .push(ct4); set_server_key(sk.decompress()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + let compressed_list = compressed_list_builder.build().unwrap(); + + // Add a serialize-deserialize round trip as it will generally be + // how compressed list are use as its meant for data exchange + let mut serialized = vec![]; + safe_serialize(&compressed_list, &mut serialized, 1024 * 1024 * 16).unwrap(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + + check_is_correct(&compressed_list, &ck); #[cfg(feature = "gpu")] { set_server_key(sk.decompress_to_gpu()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + check_is_correct(&compressed_list, &ck); + } + + // Now redo the tests, but with the server_key not being set when deserializing + // meaning, the deserialization process could not use that as a hint on where to put + // the data + { + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress()); + check_is_correct(&compressed_list, &ck); + + #[cfg(feature = "gpu")] + { + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress_to_gpu()); + check_is_correct(&compressed_list, &ck); + } } } @@ -623,24 +780,48 @@ mod tests { ct4.move_to_device(crate::Device::Cpu); let mut compressed_list_builder = CompressedCiphertextListBuilder::new(); - compressed_list_builder + let compressed_list = compressed_list_builder .push(ct1) .push(ct2) .push(ct3) - .push(ct4); + .push(ct4) + .build() + .unwrap(); + + // Add a serialize-deserialize round trip as it will generally be + // how compressed list are use as its meant for data exchange + let mut serialized = vec![]; + safe_serialize(&compressed_list, &mut serialized, 1024 * 1024 * 16).unwrap(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); set_server_key(sk.decompress()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + check_is_correct(&compressed_list, &ck); set_server_key(sk.decompress_to_gpu()); - check_is_correct(&compressed_list_builder.build().unwrap(), &ck); + check_is_correct(&compressed_list, &ck); + + // Now redo the tests, but with the server_key not being set when deserializing + // meaning, the deserialization process could not use that as a hint on where to put + // the data + { + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress()); + check_is_correct(&compressed_list, &ck); + + unset_server_key(); + let compressed_list: CompressedCiphertextList = + safe_deserialize(serialized.as_slice(), 1024 * 1024 * 16).unwrap(); + assert_eq!(compressed_list.current_device(), crate::Device::Cpu); + set_server_key(sk.decompress_to_gpu()); + check_is_correct(&compressed_list, &ck); + } } fn check_is_correct(compressed_list: &CompressedCiphertextList, ck: &ClientKey) { - let serialized = bincode::serialize(&compressed_list).unwrap(); - - let compressed_list: CompressedCiphertextList = - bincode::deserialize(&serialized).unwrap(); { let a: FheUint32 = compressed_list.get(0).unwrap().unwrap(); let b: FheInt64 = compressed_list.get(1).unwrap().unwrap(); diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index b48c3a189f..756f0ab803 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -1,6 +1,7 @@ use crate::core_crypto::entities::packed_integers::PackedIntegers; use crate::core_crypto::entities::GlweCiphertextList; use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList; +use crate::core_crypto::gpu::vec::GpuIndex; use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext; use crate::core_crypto::prelude::{ @@ -66,6 +67,9 @@ pub struct CudaCompressedCiphertextList { } impl CudaCompressedCiphertextList { + pub fn gpu_indexes(&self) -> &[GpuIndex] { + &self.packed_list.glwe_ciphertext_list.0.d_vec.gpu_indexes + } pub fn len(&self) -> usize { self.info.len() } @@ -229,6 +233,13 @@ impl CudaCompressedCiphertextList { info: self.info.clone(), } } + + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + packed_list: self.packed_list.duplicate(streams), + info: self.info.clone(), + } + } } impl CompressedCiphertextList { diff --git a/tfhe/src/integer/gpu/list_compression/server_keys.rs b/tfhe/src/integer/gpu/list_compression/server_keys.rs index cafef0da98..da714593a3 100644 --- a/tfhe/src/integer/gpu/list_compression/server_keys.rs +++ b/tfhe/src/integer/gpu/list_compression/server_keys.rs @@ -47,6 +47,20 @@ pub struct CudaPackedGlweCiphertext { pub initial_len: usize, } +impl CudaPackedGlweCiphertext { + pub fn duplicate(&self, streams: &CudaStreams) -> Self { + Self { + glwe_ciphertext_list: self.glwe_ciphertext_list.duplicate(streams), + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + bodies_count: self.bodies_count, + storage_log_modulus: self.storage_log_modulus, + lwe_per_glwe: self.lwe_per_glwe, + initial_len: self.initial_len, + } + } +} + impl Clone for CudaPackedGlweCiphertext { fn clone(&self) -> Self { Self {