Skip to content

Commit

Permalink
chore(gpu): remove device synchronization in drop for CudaVec
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Sep 2, 2024
1 parent c258d53 commit a7fdde2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
32 changes: 28 additions & 4 deletions tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use crate::core_crypto::gpu::{extract_lwe_samples_from_glwe_ciphertext_list_asyn
use crate::core_crypto::prelude::{MonomialDegree, UnsignedTorus};
use itertools::Itertools;

/// For each [`GLWE Ciphertext`](`CudaGlweCiphertextList`) given as input, extract the nth
/// coefficient from its body as an [`LWE ciphertext`](`CudaLweCiphertextList`). This variant is
/// GPU-accelerated.
pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not
/// be dropped until stream is synchronized.
pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
vec_nth: &[MonomialDegree],
Expand Down Expand Up @@ -58,3 +59,26 @@ pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
);
}
}

/// For each [`GLWE Ciphertext`](`CudaGlweCiphertextList`) given as input, extract the nth
/// coefficient from its body as an [`LWE ciphertext`](`CudaLweCiphertextList`). This variant is
/// GPU-accelerated.
///
/// This is the blocking counterpart of
/// [`cuda_extract_lwe_samples_from_glwe_ciphertext_list_async`]: every argument is forwarded
/// unchanged to the async variant, and `streams` is synchronized before this function returns.
///
/// - `input_glwe_list`: GLWE ciphertext list to extract from.
/// - `output_lwe_list`: LWE ciphertext list receiving the extracted samples.
/// - `vec_nth`: monomial degrees selecting the coefficients to extract.
/// - `streams`: CUDA streams handed to the async variant and then synchronized.
pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
vec_nth: &[MonomialDegree],
streams: &CudaStreams,
) where
// NOTE(review): an earlier comment here said `CastInto` was required for the PBS modulus
// switch (which returns a usize), but `UnsignedTorus` is the only bound declared on this
// function -- confirm whether that remark is stale.
Scalar: UnsignedTorus,
{
// SAFETY: the async variant requires the stream to be synchronized before its inputs are
// dropped. All inputs are borrowed for the whole call and `streams.synchronize()` runs right
// after the launch, before this function returns.
unsafe {
cuda_extract_lwe_samples_from_glwe_ciphertext_list_async(
input_glwe_list,
output_lwe_list,
vec_nth,
streams,
);
}
// Block until the work queued on `streams` has completed so callers can use the output.
streams.synchronize();
}
4 changes: 1 addition & 3 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut};
use crate::core_crypto::gpu::{synchronize_device, CudaStreams};
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::Numeric;
use std::collections::Bound::{Excluded, Included, Unbounded};
use std::ffi::c_void;
Expand Down Expand Up @@ -447,8 +447,6 @@ impl<T: Numeric> Drop for CudaVec<T> {
/// Free memory for pointer `ptr` synchronously
///
/// Iterates over every index in `self.gpu_indexes` and passes the pointer returned by
/// `get_mut_c_ptr` for that index to `cuda_drop`.
fn drop(&mut self) {
for &gpu_index in self.gpu_indexes.iter() {
// Synchronizes the device to be sure no stream is still using this pointer
synchronize_device(gpu_index);
// SAFETY: NOTE(review): presumably the pointer for `gpu_index` is owned by this vector and
// is not freed anywhere else -- confirm against the allocation site.
unsafe { cuda_drop(self.get_mut_c_ptr(gpu_index), gpu_index) };
}
}
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/integer/gpu/server_key/radix/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ impl CudaServerKey {
let mut result = unsafe { ciphertexts[0].duplicate_async(streams) };

if ciphertexts.len() == 1 {
streams.synchronize();
return Some(result);
}

Expand Down

0 comments on commit a7fdde2

Please sign in to comment.