diff --git a/tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs b/tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs index 8f6e44a815..726c69de6d 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs @@ -5,16 +5,16 @@ use crate::core_crypto::gpu::{extract_lwe_samples_from_glwe_ciphertext_list_asyn use crate::core_crypto::prelude::{MonomialDegree, UnsignedTorus}; use itertools::Itertools; -/// For each [`GLWE Ciphertext`] (`CudaGlweCiphertextList`) given as input, extract the nth -/// coefficient from its body as an [`LWE ciphertext`](`CudaLweCiphertextList`). This variant is -/// GPU-accelerated. -pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list( +/// # Safety +/// +/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must not +/// be dropped until stream is synchronised +pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async( input_glwe_list: &CudaGlweCiphertextList, output_lwe_list: &mut CudaLweCiphertextList, vec_nth: &[MonomialDegree], streams: &CudaStreams, ) where - // CastInto required for PBS modulus switch which returns a usize Scalar: UnsignedTorus, { let in_lwe_dim = input_glwe_list @@ -58,3 +58,26 @@ pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list( ); } } + +/// For each [`GLWE Ciphertext`] (`CudaGlweCiphertextList`) given as input, extract the nth +/// coefficient from its body as an [`LWE ciphertext`](`CudaLweCiphertextList`). This variant is +/// GPU-accelerated. +pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list( + input_glwe_list: &CudaGlweCiphertextList, + output_lwe_list: &mut CudaLweCiphertextList, + vec_nth: &[MonomialDegree], + streams: &CudaStreams, +) where + // CastInto required for PBS modulus switch which returns a usize + Scalar: UnsignedTorus, +{ + unsafe { + cuda_extract_lwe_samples_from_glwe_ciphertext_list_async( + input_glwe_list, + output_lwe_list, + vec_nth, + streams, + ); + } + streams.synchronize(); +} diff --git a/tfhe/src/core_crypto/gpu/vec.rs b/tfhe/src/core_crypto/gpu/vec.rs index b18b791b3d..fc0e5d9d76 100644 --- a/tfhe/src/core_crypto/gpu/vec.rs +++ b/tfhe/src/core_crypto/gpu/vec.rs @@ -1,5 +1,5 @@ use crate::core_crypto::gpu::slice::{CudaSlice, CudaSliceMut}; -use crate::core_crypto::gpu::{synchronize_device, CudaStreams}; +use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::Numeric; use std::collections::Bound::{Excluded, Included, Unbounded}; use std::ffi::c_void; @@ -447,8 +447,6 @@ impl Drop for CudaVec { /// Free memory for pointer `ptr` synchronously fn drop(&mut self) { for &gpu_index in self.gpu_indexes.iter() { - // Synchronizes the device to be sure no stream is still using this pointer - synchronize_device(gpu_index); unsafe { cuda_drop(self.get_mut_c_ptr(gpu_index), gpu_index) }; } } diff --git a/tfhe/src/integer/gpu/server_key/radix/add.rs b/tfhe/src/integer/gpu/server_key/radix/add.rs index 06e6e832b2..3766edee91 100644 --- a/tfhe/src/integer/gpu/server_key/radix/add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/add.rs @@ -370,6 +370,7 @@ impl CudaServerKey { let mut result = unsafe { ciphertexts[0].duplicate_async(streams) }; if ciphertexts.len() == 1 { + streams.synchronize(); return Some(result); }