fix(gpu): make all async functions unsafe, fix cuda_drop binding, add missing sync
agnesLeroy committed Jan 24, 2024
1 parent ae8d481 commit 11db96d
Showing 15 changed files with 1,614 additions and 1,466 deletions.
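
Every hunk below applies the same convention: the `*_async` methods on `CudaStream` become `unsafe fn`s, call sites wrap them in `unsafe` blocks, and a `stream.synchronize()` runs before device results are handed back to safe code. A minimal caller-side sketch of that convention, using only `CudaStream` methods that appear in the hunks (the `stream`, `num_blocks`, and `lwe_indexes` bindings are illustrative):

    // Sketch, not verbatim from the diff: assumes a constructed CudaStream
    // `stream` and a host buffer `lwe_indexes: Vec<u64>` of `num_blocks` items.
    let mut d_indexes = unsafe { stream.malloc_async::<u64>(num_blocks as u32) };
    unsafe {
        // Asynchronous host-to-device copy: the host buffer must stay alive and
        // unmodified until the stream is synchronized, which is what makes the
        // call unsafe.
        stream.copy_to_gpu_async(&mut d_indexes, &lwe_indexes);
    }
    // Block until all queued work completes; afterwards `d_indexes` is fully
    // initialized and `lwe_indexes` may be dropped or reused.
    stream.synchronize();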

2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/src/cuda_bind.rs

@@ -61,7 +61,7 @@ extern "C" {
     pub fn cuda_drop_async(ptr: *mut c_void, v_stream: *const c_void) -> i32;
 
     /// Free memory for pointer `ptr` on GPU `gpu_index` synchronously
-    pub fn cuda_drop(ptr: *mut c_void) -> i32;
+    pub fn cuda_drop(ptr: *mut c_void, gpu_index: u32) -> i32;
 
     /// Get the maximum amount of shared memory on GPU `gpu_index`
     pub fn cuda_get_max_shared_memory(gpu_index: u32) -> i32;
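
The corrected binding reflects that the C side needs the device index to select the right GPU before freeing. A hypothetical call-site sketch (the `ptr` and `gpu_index` names are illustrative; in the crate this path is reached when device vectors are dropped):

    use std::ffi::c_void;

    // Safety: `ptr` must have been returned by the backend's CUDA allocator on
    // GPU `gpu_index` and must not be used after this call. The i32 return is
    // a status code; 0 is assumed to mean success.
    unsafe {
        let rc = cuda_drop(ptr as *mut c_void, gpu_index);
        debug_assert_eq!(rc, 0);
    }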

@@ -120,4 +120,5 @@ pub fn cuda_multi_bit_programmable_bootstrap_lwe_ciphertext<Scalar>(
             stream,
         );
     }
+    stream.synchronize();
 }

@@ -78,4 +78,5 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext<Scalar>(
             stream,
         );
     }
+    stream.synchronize();
 }
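
Both bootstrap entry points now share the same blocking-wrapper shape: the unsafe asynchronous launches stay inside the function, and the added `stream.synchronize()` guarantees the output ciphertext is complete before the safe wrapper returns. Schematically (a hypothetical, heavily trimmed signature; the real wrappers also take keys, accumulators, and index vectors):

    pub fn cuda_bootstrap_blocking(stream: &CudaStream /* , inputs, outputs, keys */) {
        unsafe {
            // enqueue the bootstrap kernels on `stream` (asynchronous, hence unsafe)
        }
        // Make the safe wrapper blocking: everything queued above finishes before
        // returning, so callers never observe a half-written ciphertext.
        stream.synchronize();
    }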

8 changes: 4 additions & 4 deletions tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs

@@ -90,10 +90,10 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + CastFrom<usize>>(
         .iter()
         .map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
         .collect_vec();
-    let mut d_input_indexes = stream.malloc_async::<Scalar>(num_blocks as u32);
-    let mut d_output_indexes = stream.malloc_async::<Scalar>(num_blocks as u32);
-    stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes);
-    stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes);
+    let mut d_input_indexes = unsafe { stream.malloc_async::<Scalar>(num_blocks as u32) };
+    let mut d_output_indexes = unsafe { stream.malloc_async::<Scalar>(num_blocks as u32) };
+    unsafe { stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes) };
+    unsafe { stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes) };
 
     cuda_keyswitch_lwe_ciphertext(
         &d_ksk_big_to_small,

@@ -145,19 +145,21 @@ fn lwe_encrypt_multi_bit_pbs_decrypt_custom_mod<
     }
 
     let mut d_test_vector_indexes =
-        stream.malloc_async::<Scalar>(number_of_messages as u32);
-    stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes);
+        unsafe { stream.malloc_async::<Scalar>(number_of_messages as u32) };
+    unsafe { stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes) };
 
     let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0;
     let lwe_indexes_usize: Vec<usize> = (0..num_blocks).collect_vec();
     let lwe_indexes = lwe_indexes_usize
         .iter()
         .map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
         .collect_vec();
-    let mut d_output_indexes = stream.malloc_async::<Scalar>(num_blocks as u32);
-    let mut d_input_indexes = stream.malloc_async::<Scalar>(num_blocks as u32);
-    stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes);
-    stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes);
+    let mut d_output_indexes = unsafe { stream.malloc_async::<Scalar>(num_blocks as u32) };
+    let mut d_input_indexes = unsafe { stream.malloc_async::<Scalar>(num_blocks as u32) };
+    unsafe {
+        stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes);
+        stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes);
+    }
 
     cuda_multi_bit_programmable_bootstrap_lwe_ciphertext(
         &d_lwe_ciphertext_in,

@@ -127,19 +127,21 @@ fn lwe_encrypt_pbs_decrypt<
     }
 
     let mut d_test_vector_indexes =
-        stream.malloc_async::<Scalar>(number_of_messages as u32);
-    stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes);
+        unsafe { stream.malloc_async::<Scalar>(number_of_messages as u32) };
+    unsafe { stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes) };
 
     let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0;
     let lwe_indexes_usize: Vec<usize> = (0..num_blocks).collect_vec();
     let lwe_indexes = lwe_indexes_usize
         .iter()
         .map(|&x| <usize as CastInto<Scalar>>::cast_into(x))
         .collect_vec();
-    let mut d_output_indexes = stream.malloc_async::<Scalar>(num_blocks as u32);
-    let mut d_input_indexes = stream.malloc_async::<Scalar>(num_blocks as u32);
-    stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes);
-    stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes);
+    let mut d_output_indexes = unsafe { stream.malloc_async::<Scalar>(num_blocks as u32) };
+    let mut d_input_indexes = unsafe { stream.malloc_async::<Scalar>(num_blocks as u32) };
+    unsafe {
+        stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes);
+        stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes);
+    }
 
     cuda_programmable_bootstrap_lwe_ciphertext(
         &d_lwe_ciphertext_in,

50 changes: 32 additions & 18 deletions tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs

@@ -18,11 +18,13 @@ impl<T: UnsignedInteger> CudaGlweCiphertextList<T> {
         stream: &CudaStream,
     ) -> Self {
         // Allocate memory in the device
-        let d_vec = stream.malloc_async(
-            (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size)
-                * glwe_ciphertext_count.0) as u32,
-        );
-
+        let d_vec = unsafe {
+            stream.malloc_async(
+                (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size)
+                    * glwe_ciphertext_count.0) as u32,
+            )
+        };
+        stream.synchronize();
         let cuda_glwe_list = CudaGlweList {
             d_vec,
             glwe_ciphertext_count,

@@ -43,13 +45,17 @@
         let polynomial_size = h_ct.polynomial_size();
         let ciphertext_modulus = h_ct.ciphertext_modulus();
 
-        let mut d_vec = stream.malloc_async(
-            (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size)
-                * glwe_ciphertext_count.0) as u32,
-        );
-
+        let mut d_vec = unsafe {
+            stream.malloc_async(
+                (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size)
+                    * glwe_ciphertext_count.0) as u32,
+            )
+        };
         // Copy to the GPU
-        stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref());
+        unsafe {
+            stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref());
+            stream.synchronize();
+        }
 
         let cuda_glwe_list = CudaGlweList {
             d_vec,

@@ -70,8 +76,10 @@
                 * glwe_ciphertext_size(self.0.glwe_dimension.to_glwe_size(), self.0.polynomial_size);
         let mut container: Vec<T> = vec![T::ZERO; glwe_ct_size];
 
-        stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec);
-        stream.synchronize();
+        unsafe {
+            stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec);
+            stream.synchronize();
+        }
 
         GlweCiphertextList::from_container(
             container,

@@ -90,14 +98,20 @@
         let polynomial_size = h_ct.polynomial_size();
         let ciphertext_modulus = h_ct.ciphertext_modulus();
 
-        let mut d_vec = stream.malloc_async(
-            (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size)
-                * glwe_ciphertext_count.0) as u32,
-        );
+        let mut d_vec = unsafe {
+            stream.malloc_async(
+                (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size)
+                    * glwe_ciphertext_count.0) as u32,
+            )
+        };
 
         // Copy to the GPU
         let h_input = h_ct.as_view().into_container();
-        stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref());
-        stream.synchronize();
+        unsafe {
+            stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref());
+        }
+        stream.synchronize();
 
         let cuda_glwe_list = CudaGlweList {
             d_vec,

32 changes: 18 additions & 14 deletions tfhe/src/core_crypto/gpu/entities/lwe_bootstrap_key.rs

@@ -39,21 +39,25 @@ impl CudaLweBootstrapKey {
         let glwe_dimension = bsk.glwe_size().to_glwe_dimension();
 
         // Allocate memory
-        let mut d_vec = stream.malloc_async::<f64>(lwe_bootstrap_key_size(
-            input_lwe_dimension,
-            glwe_dimension.to_glwe_size(),
-            polynomial_size,
-            decomp_level_count,
-        ) as u32);
+        let mut d_vec = unsafe {
+            stream.malloc_async::<f64>(lwe_bootstrap_key_size(
+                input_lwe_dimension,
+                glwe_dimension.to_glwe_size(),
+                polynomial_size,
+                decomp_level_count,
+            ) as u32)
+        };
         // Copy to the GPU
-        stream.convert_lwe_bootstrap_key_async(
-            &mut d_vec,
-            bsk.as_ref(),
-            input_lwe_dimension,
-            glwe_dimension,
-            decomp_level_count,
-            polynomial_size,
-        );
+        unsafe {
+            stream.convert_lwe_bootstrap_key_async(
+                &mut d_vec,
+                bsk.as_ref(),
+                input_lwe_dimension,
+                glwe_dimension,
+                decomp_level_count,
+                polynomial_size,
+            );
+        }
         stream.synchronize();
         Self {
             d_vec,

41 changes: 28 additions & 13 deletions tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

@@ -18,8 +18,10 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
         stream: &CudaStream,
     ) -> Self {
         // Allocate memory in the device
-        let d_vec =
-            stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32);
+        let d_vec = unsafe {
+            stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32)
+        };
+        stream.synchronize();
 
         let cuda_lwe_list = CudaLweList {
             d_vec,

@@ -41,10 +43,13 @@
 
         // Copy to the GPU
         let h_input = h_ct.as_view().into_container();
-        let mut d_vec =
-            stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32);
-        stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref());
-        stream.synchronize();
+        let mut d_vec = unsafe {
+            stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32)
+        };
+        unsafe {
+            stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref());
+            stream.synchronize();
+        }
         let cuda_lwe_list = CudaLweList {
             d_vec,
             lwe_ciphertext_count,

@@ -73,8 +78,10 @@
         let lwe_ct_size = self.0.lwe_ciphertext_count.0 * self.0.lwe_dimension.to_lwe_size().0;
         let mut container: Vec<T> = vec![T::ZERO; lwe_ct_size];
 
-        stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec);
-        stream.synchronize();
+        unsafe {
+            stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec);
+            stream.synchronize();
+        }
 
         LweCiphertextList::from_container(
             container,

@@ -92,8 +99,11 @@
         let ciphertext_modulus = h_ct.ciphertext_modulus();
 
         // Copy to the GPU
-        let mut d_vec = stream.malloc_async((lwe_dimension.to_lwe_size().0) as u32);
-        stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref());
+        let mut d_vec = unsafe { stream.malloc_async((lwe_dimension.to_lwe_size().0) as u32) };
+        unsafe {
+            stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref());
+        }
+        stream.synchronize();
 
         let cuda_lwe_list = CudaLweList {
             d_vec,

@@ -108,7 +118,9 @@
         let lwe_ct_size = self.0.lwe_dimension.to_lwe_size().0;
         let mut container: Vec<T> = vec![T::ZERO; lwe_ct_size];
 
-        stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec);
+        unsafe {
+            stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec);
+        }
         stream.synchronize();
 
         LweCiphertext::from_container(container, self.ciphertext_modulus())

@@ -148,8 +160,11 @@
         let ciphertext_modulus = self.ciphertext_modulus();
 
         // Copy to the GPU
-        let mut d_vec = stream.malloc_async(self.0.d_vec.len() as u32);
-        stream.copy_gpu_to_gpu_async(&mut d_vec, &self.0.d_vec);
+        let mut d_vec = unsafe { stream.malloc_async(self.0.d_vec.len() as u32) };
+        unsafe {
+            stream.copy_gpu_to_gpu_async(&mut d_vec, &self.0.d_vec);
+        }
+        stream.synchronize();
 
         let cuda_lwe_list = CudaLweList {
             d_vec,
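
Since these constructors and conversion methods synchronize internally, a host-to-device round trip stays entirely in safe code. A hypothetical usage sketch (assuming an existing `stream` and a host list `h_list`; the method names follow the pattern of this entity API but are not shown verbatim in the diff):

    let d_list = CudaLweCiphertextList::from_lwe_ciphertext_list(&h_list, &stream);
    let h_roundtrip = d_list.to_lwe_ciphertext_list(&stream);
    assert_eq!(h_list, h_roundtrip);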

20 changes: 12 additions & 8 deletions tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs

@@ -28,15 +28,19 @@ impl<T: UnsignedInteger> CudaLweKeyswitchKey<T> {
         let ciphertext_modulus = h_ksk.ciphertext_modulus();
 
         // Allocate memory
-        let mut d_vec = stream.malloc_async::<T>(
-            (input_lwe_size.to_lwe_dimension().0
-                * lwe_keyswitch_key_input_key_element_encrypted_size(
-                    decomp_level_count,
-                    output_lwe_size,
-                )) as u32,
-        );
+        let mut d_vec = unsafe {
+            stream.malloc_async::<T>(
+                (input_lwe_size.to_lwe_dimension().0
+                    * lwe_keyswitch_key_input_key_element_encrypted_size(
+                        decomp_level_count,
+                        output_lwe_size,
+                    )) as u32,
+            )
+        };
 
-        stream.convert_lwe_keyswitch_key_async(&mut d_vec, h_ksk.as_ref());
+        unsafe {
+            stream.convert_lwe_keyswitch_key_async(&mut d_vec, h_ksk.as_ref());
+        }
 
         stream.synchronize();

38 changes: 21 additions & 17 deletions tfhe/src/core_crypto/gpu/entities/lwe_multi_bit_bootstrap_key.rs

@@ -41,26 +41,30 @@ impl CudaLweMultiBitBootstrapKey {
         let grouping_factor = bsk.grouping_factor();
 
         // Allocate memory
-        let mut d_vec = stream.malloc_async::<u64>(
-            lwe_multi_bit_bootstrap_key_size(
-                input_lwe_dimension,
-                glwe_dimension.to_glwe_size(),
-                polynomial_size,
-                decomp_level_count,
-                grouping_factor,
-            )
-            .unwrap() as u32,
-        );
+        let mut d_vec = unsafe {
+            stream.malloc_async::<u64>(
+                lwe_multi_bit_bootstrap_key_size(
+                    input_lwe_dimension,
+                    glwe_dimension.to_glwe_size(),
+                    polynomial_size,
+                    decomp_level_count,
+                    grouping_factor,
+                )
+                .unwrap() as u32,
+            )
+        };
         // Copy to the GPU
-        stream.convert_lwe_multi_bit_bootstrap_key_async(
-            &mut d_vec,
-            bsk.as_ref(),
-            input_lwe_dimension,
-            glwe_dimension,
-            decomp_level_count,
-            polynomial_size,
-            grouping_factor,
-        );
+        unsafe {
+            stream.convert_lwe_multi_bit_bootstrap_key_async(
+                &mut d_vec,
+                bsk.as_ref(),
+                input_lwe_dimension,
+                glwe_dimension,
+                decomp_level_count,
+                polynomial_size,
+                grouping_factor,
+            );
+        }
         stream.synchronize();
         Self {
             d_vec,