fix(hlapi): ciphertext list decompress after safe_deser
After a safe_serialize/safe_deserialize round trip, the CompressedCiphertextList
was always on the Cpu. As the `get` method looked at the device of the data,
and not the device of the server_key, to decide where computation
needs to happen, decompressing on the Gpu was impossible in this case:
only the Cpu was usable (as the data was always only on the Cpu).

The fix is twofold:
* First, when deserializing, the data uses the current server_key
  (if any) as a hint on where the data should be placed
* Second, the `get` method now uses the server_key to know where computations
  need to be done, which may incur a temporary copy/transfer on every
  call to `get` if the data is not already on that device.
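
The two points above can be illustrated with a minimal sketch. It does not use the tfhe-rs types: `Device`, `ToyCompressedList`, `GetOutcome`, `from_values` and `move_to_device` are hypothetical names invented for this example, and the "transfer" is only recorded as a flag rather than performed.

```rust
// Hypothetical stand-ins for illustration only; these are NOT the tfhe-rs types.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Gpu,
}

/// A toy compressed list that only remembers where its data lives.
#[derive(Clone, Debug)]
pub struct ToyCompressedList {
    pub data_device: Device,
    pub values: Vec<u64>,
}

/// What a `get` produced: the value, the device the work ran on,
/// and whether a temporary copy/transfer was needed.
#[derive(Debug, PartialEq, Eq)]
pub struct GetOutcome {
    pub value: u64,
    pub computed_on: Device,
    pub temporary_copy: bool,
}

impl ToyCompressedList {
    /// Models deserialization: the current key's device (if any) is used
    /// as a placement hint; without a key the data stays on the Cpu.
    pub fn from_values(values: Vec<u64>, current_key_device: Option<Device>) -> Self {
        Self {
            data_device: current_key_device.unwrap_or(Device::Cpu),
            values,
        }
    }

    /// Models an explicit "move data" call (assumed shape, not the real API).
    pub fn move_to_device(&mut self, device: Device) {
        self.data_device = device;
    }

    /// The key's device decides where computation happens. If the data is
    /// elsewhere, a temporary copy is needed on every call.
    pub fn get(&self, index: usize, key_device: Device) -> Option<GetOutcome> {
        let value = *self.values.get(index)?;
        Some(GetOutcome {
            value,
            computed_on: key_device,
            temporary_copy: self.data_device != key_device,
        })
    }
}
```

In this sketch, a list deserialized with no key set lands on the Cpu, and each `get` with a Gpu key reports a temporary copy until the data is moved once with `move_to_device`.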

An API to move data between devices has also been added.

Note that this was not the case when using regular serialize/deserialize,
as this would store the device, so that deserialize was able to restore
the data onto the same device (which is why the tests that use serialize/deserialize
did not fail). In hindsight, the ser/de impl should not save which
device the data originated from.
tmontaigu committed Jan 30, 2025
1 parent 298640a commit 6b3a320
Showing 5 changed files with 287 additions and 57 deletions.
4 changes: 4 additions & 0 deletions tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs
@@ -137,4 +137,8 @@ impl<T: UnsignedInteger> CudaGlweCiphertextList<T> {
    pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus<T> {
        self.0.ciphertext_modulus
    }

    pub fn duplicate(&self, streams: &CudaStreams) -> Self {
        Self(self.0.duplicate(streams))
    }
}
18 changes: 18 additions & 0 deletions tfhe/src/core_crypto/gpu/mod.rs
@@ -729,6 +729,24 @@ pub struct CudaGlweList<T: UnsignedInteger> {
    pub ciphertext_modulus: CiphertextModulus<T>,
}

impl<T: UnsignedInteger> CudaGlweList<T> {
    pub fn duplicate(&self, streams: &CudaStreams) -> Self {
        let d_vec = unsafe {
            let mut d_vec = CudaVec::new_async(self.d_vec.len(), streams, 0);
            d_vec.copy_from_gpu_async(&self.d_vec, streams, 0);
            d_vec
        };
        streams.synchronize();

        Self {
            d_vec,
            glwe_ciphertext_count: self.glwe_ciphertext_count,
            glwe_dimension: self.glwe_dimension,
            polynomial_size: self.polynomial_size,
            ciphertext_modulus: self.ciphertext_modulus,
        }
    }
}
/// Get the number of GPUs on the machine
pub fn get_number_of_gpus() -> i32 {
    unsafe { cuda_get_number_of_gpus() }