Skip to content

chore(gpu): add back async entry points #2004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ pub unsafe fn decompress_integer_radix_async<T: UnsignedInteger, B: Numeric>(
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_add_integer_radix_assign(
pub unsafe fn unchecked_add_integer_radix_assign_async(
streams: &CudaStreams,
radix_lwe_left: &mut CudaRadixCiphertext,
radix_lwe_right: &CudaRadixCiphertext,
Expand Down Expand Up @@ -589,7 +589,6 @@ pub unsafe fn unchecked_add_integer_radix_assign(
&radix_lwe_right_data,
);
update_noise_degree(radix_lwe_left, &radix_lwe_left_data);
streams.synchronize();
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -2163,7 +2162,7 @@ pub unsafe fn unchecked_rotate_left_integer_radix_kb_assign_async<
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_cmux_integer_radix_kb<T: UnsignedInteger, B: Numeric>(
pub unsafe fn unchecked_cmux_integer_radix_kb_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
radix_lwe_out: &mut CudaRadixCiphertext,
radix_lwe_condition: &CudaBooleanBlock,
Expand Down Expand Up @@ -2347,7 +2346,6 @@ pub unsafe fn unchecked_cmux_integer_radix_kb<T: UnsignedInteger, B: Numeric>(
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
streams.synchronize()
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -3302,7 +3300,7 @@ pub(crate) unsafe fn unchecked_unsigned_overflowing_sub_integer_radix_kb_assign_
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_signed_abs_radix_kb_assign<T: UnsignedInteger, B: Numeric>(
pub unsafe fn unchecked_signed_abs_radix_kb_assign_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
ct: &mut CudaRadixCiphertext,
bootstrapping_key: &CudaVec<B>,
Expand Down Expand Up @@ -3382,7 +3380,6 @@ pub unsafe fn unchecked_signed_abs_radix_kb_assign<T: UnsignedInteger, B: Numeri
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
streams.synchronize()
}

#[allow(clippy::too_many_arguments)]
Expand Down
20 changes: 11 additions & 9 deletions tfhe/src/integer/gpu/server_key/radix/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{unchecked_signed_abs_radix_kb_assign, PBSType};
use crate::integer::gpu::{unchecked_signed_abs_radix_kb_assign_async, PBSType};

impl CudaServerKey {
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
pub fn unchecked_abs_assign<T>(&self, ct: &mut T, streams: &CudaStreams)
pub unsafe fn unchecked_abs_assign_async<T>(&self, ct: &mut T, streams: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
{
Expand All @@ -18,7 +18,7 @@ impl CudaServerKey {
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
unchecked_signed_abs_radix_kb_assign(
unchecked_signed_abs_radix_kb_assign_async(
streams,
ct.as_mut(),
&d_bsk.d_vec,
Expand All @@ -43,7 +43,7 @@ impl CudaServerKey {
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
unchecked_signed_abs_radix_kb_assign(
unchecked_signed_abs_radix_kb_assign_async(
streams,
ct.as_mut(),
&d_multibit_bsk.d_vec,
Expand Down Expand Up @@ -74,10 +74,11 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut res = ct.duplicate(streams);
let mut res = unsafe { ct.duplicate_async(streams) };
if T::IS_SIGNED {
self.unchecked_abs_assign(&mut res, streams);
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
}
streams.synchronize();
res
}

Expand Down Expand Up @@ -131,13 +132,14 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut res = ct.duplicate(streams);
let mut res = unsafe { ct.duplicate_async(streams) };
if !ct.block_carries_are_empty() {
self.full_propagate_assign(&mut res, streams);
unsafe { self.full_propagate_assign_async(&mut res, streams) };
};
if T::IS_SIGNED {
self.unchecked_abs_assign(&mut res, streams);
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
}
streams.synchronize();
res
}
}
45 changes: 28 additions & 17 deletions tfhe/src/integer/gpu/server_key/radix/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::integer::gpu::ciphertext::{
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{
unchecked_add_integer_radix_assign,
unchecked_add_integer_radix_assign_async,
unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async, PBSType,
};
use crate::integer::server_key::radix_parallel::OutputFlag;
Expand Down Expand Up @@ -70,7 +70,7 @@ impl CudaServerKey {
ct_right: &T,
streams: &CudaStreams,
) -> T {
let mut result = ct_left.duplicate(streams);
let mut result = unsafe { ct_left.duplicate_async(streams) };
self.add_assign(&mut result, ct_right, streams);
result
}
Expand All @@ -94,18 +94,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
Expand Down Expand Up @@ -179,7 +179,7 @@ impl CudaServerKey {
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
pub unsafe fn unchecked_add_assign_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
Expand All @@ -204,10 +204,21 @@ impl CudaServerKey {
);

unsafe {
unchecked_add_integer_radix_assign(streams, ciphertext_left, ciphertext_right);
unchecked_add_integer_radix_assign_async(streams, ciphertext_left, ciphertext_right);
}
}

/// Blocking counterpart of `unchecked_add_assign_async`.
///
/// Enqueues the in-place addition of `ct_right` into `ct_left` on `streams` and then
/// waits on `streams` before returning, so callers do not have to synchronize
/// themselves.
pub fn unchecked_add_assign<T>(&self, ct_left: &mut T, ct_right: &T, streams: &CudaStreams)
where
    T: CudaIntegerRadixCiphertext,
{
    // SAFETY: the async entry point requires the streams to be synchronized before the
    // operands may be dropped; the `synchronize` call right below does that while both
    // borrows are still alive.
    unsafe { self.unchecked_add_assign_async(ct_left, ct_right, streams) };

    streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
Expand Down Expand Up @@ -396,7 +407,7 @@ impl CudaServerKey {
.iter_mut()
.filter(|ct| !ct.block_carries_are_empty())
.for_each(|ct| {
self.full_propagate_assign(&mut *ct, streams);
self.full_propagate_assign_async(&mut *ct, streams);
});

Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
Expand Down Expand Up @@ -456,14 +467,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
Expand All @@ -472,8 +483,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
Expand Down Expand Up @@ -643,14 +654,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
Expand All @@ -659,8 +670,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
Expand Down
26 changes: 13 additions & 13 deletions tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,18 +465,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
}
Expand Down Expand Up @@ -570,18 +570,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
Expand Down Expand Up @@ -675,18 +675,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
Expand Down Expand Up @@ -764,7 +764,7 @@ impl CudaServerKey {
streams: &CudaStreams,
) {
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}

self.unchecked_bitnot_assign_async(ct, streams);
Expand Down
Loading
Loading