Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(gpu): add back async entry points #2004

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
@@ -525,7 +525,7 @@ pub unsafe fn decompress_integer_radix_async<T: UnsignedInteger, B: Numeric>(
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_add_integer_radix_assign(
pub unsafe fn unchecked_add_integer_radix_assign_async(
streams: &CudaStreams,
radix_lwe_left: &mut CudaRadixCiphertext,
radix_lwe_right: &CudaRadixCiphertext,
@@ -589,7 +589,6 @@ pub unsafe fn unchecked_add_integer_radix_assign(
&radix_lwe_right_data,
);
update_noise_degree(radix_lwe_left, &radix_lwe_left_data);
streams.synchronize();
}

#[allow(clippy::too_many_arguments)]
@@ -2163,7 +2162,7 @@ pub unsafe fn unchecked_rotate_left_integer_radix_kb_assign_async<
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_cmux_integer_radix_kb<T: UnsignedInteger, B: Numeric>(
pub unsafe fn unchecked_cmux_integer_radix_kb_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
radix_lwe_out: &mut CudaRadixCiphertext,
radix_lwe_condition: &CudaBooleanBlock,
@@ -2347,7 +2346,6 @@ pub unsafe fn unchecked_cmux_integer_radix_kb<T: UnsignedInteger, B: Numeric>(
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
streams.synchronize()
}

#[allow(clippy::too_many_arguments)]
@@ -3302,7 +3300,7 @@ pub(crate) unsafe fn unchecked_unsigned_overflowing_sub_integer_radix_kb_assign_
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn unchecked_signed_abs_radix_kb_assign<T: UnsignedInteger, B: Numeric>(
pub unsafe fn unchecked_signed_abs_radix_kb_assign_async<T: UnsignedInteger, B: Numeric>(
streams: &CudaStreams,
ct: &mut CudaRadixCiphertext,
bootstrapping_key: &CudaVec<B>,
@@ -3382,7 +3380,6 @@ pub unsafe fn unchecked_signed_abs_radix_kb_assign<T: UnsignedInteger, B: Numeri
streams.len() as u32,
std::ptr::addr_of_mut!(mem_ptr),
);
streams.synchronize()
}

#[allow(clippy::too_many_arguments)]
20 changes: 11 additions & 9 deletions tfhe/src/integer/gpu/server_key/radix/abs.rs
Original file line number Diff line number Diff line change
@@ -2,14 +2,14 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{unchecked_signed_abs_radix_kb_assign, PBSType};
use crate::integer::gpu::{unchecked_signed_abs_radix_kb_assign_async, PBSType};

impl CudaServerKey {
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
pub fn unchecked_abs_assign<T>(&self, ct: &mut T, streams: &CudaStreams)
pub unsafe fn unchecked_abs_assign_async<T>(&self, ct: &mut T, streams: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
{
@@ -18,7 +18,7 @@ impl CudaServerKey {
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
unchecked_signed_abs_radix_kb_assign(
unchecked_signed_abs_radix_kb_assign_async(
streams,
ct.as_mut(),
&d_bsk.d_vec,
@@ -43,7 +43,7 @@ impl CudaServerKey {
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
unchecked_signed_abs_radix_kb_assign(
unchecked_signed_abs_radix_kb_assign_async(
streams,
ct.as_mut(),
&d_multibit_bsk.d_vec,
@@ -74,10 +74,11 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut res = ct.duplicate(streams);
let mut res = unsafe { ct.duplicate_async(streams) };
if T::IS_SIGNED {
self.unchecked_abs_assign(&mut res, streams);
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
}
streams.synchronize();
res
}

@@ -131,13 +132,14 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut res = ct.duplicate(streams);
let mut res = unsafe { ct.duplicate_async(streams) };
if !ct.block_carries_are_empty() {
self.full_propagate_assign(&mut res, streams);
unsafe { self.full_propagate_assign_async(&mut res, streams) };
};
if T::IS_SIGNED {
self.unchecked_abs_assign(&mut res, streams);
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
}
streams.synchronize();
res
}
}
45 changes: 28 additions & 17 deletions tfhe/src/integer/gpu/server_key/radix/add.rs
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ use crate::integer::gpu::ciphertext::{
};
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{
unchecked_add_integer_radix_assign,
unchecked_add_integer_radix_assign_async,
unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async, PBSType,
};
use crate::integer::server_key::radix_parallel::OutputFlag;
@@ -70,7 +70,7 @@ impl CudaServerKey {
ct_right: &T,
streams: &CudaStreams,
) -> T {
let mut result = ct_left.duplicate(streams);
let mut result = unsafe { ct_left.duplicate_async(streams) };
self.add_assign(&mut result, ct_right, streams);
result
}
@@ -94,18 +94,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
@@ -179,7 +179,7 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
pub unsafe fn unchecked_add_assign_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
@@ -204,10 +204,21 @@ impl CudaServerKey {
);

unsafe {
unchecked_add_integer_radix_assign(streams, ciphertext_left, ciphertext_right);
unchecked_add_integer_radix_assign_async(streams, ciphertext_left, ciphertext_right);
}
}

/// Blocking counterpart of [`Self::unchecked_add_assign_async`].
///
/// Forwards `ct_left`, `ct_right` and `streams` unchanged to the async entry point and
/// then calls [`CudaStreams::synchronize`] before returning, so the caller does not have
/// to synchronize the streams itself.
///
/// NOTE(review): the exact arithmetic semantics (presumably an in-place addition of
/// `ct_right` into `ct_left`) live in the async entry point — confirm there.
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
    &self,
    ct_left: &mut T,
    ct_right: &T,
    streams: &CudaStreams,
) {
    // SAFETY: the async entry point requires the streams to be synchronized while the
    // inputs are still alive; both borrows outlive the `synchronize` call just below.
    unsafe { self.unchecked_add_assign_async(ct_left, ct_right, streams) };
    streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
@@ -396,7 +407,7 @@ impl CudaServerKey {
.iter_mut()
.filter(|ct| !ct.block_carries_are_empty())
.for_each(|ct| {
self.full_propagate_assign(&mut *ct, streams);
self.full_propagate_assign_async(&mut *ct, streams);
});

Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
@@ -456,14 +467,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
@@ -472,8 +483,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
@@ -643,14 +654,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
@@ -659,8 +670,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
26 changes: 13 additions & 13 deletions tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs
Original file line number Diff line number Diff line change
@@ -465,18 +465,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
}
@@ -570,18 +570,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
@@ -675,18 +675,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
@@ -764,7 +764,7 @@ impl CudaServerKey {
streams: &CudaStreams,
) {
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}

self.unchecked_bitnot_assign_async(ct, streams);
53 changes: 37 additions & 16 deletions tfhe/src/integer/gpu/server_key/radix/cmux.rs
Original file line number Diff line number Diff line change
@@ -3,10 +3,14 @@ use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{unchecked_cmux_integer_radix_kb, CudaServerKey, PBSType};
use crate::integer::gpu::{unchecked_cmux_integer_radix_kb_async, CudaServerKey, PBSType};

impl CudaServerKey {
pub fn unchecked_if_then_else<T: CudaIntegerRadixCiphertext>(
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_if_then_else_async<T: CudaIntegerRadixCiphertext>(
&self,
condition: &CudaBooleanBlock,
true_ct: &T,
@@ -20,7 +24,7 @@ impl CudaServerKey {
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
unchecked_cmux_integer_radix_kb(
unchecked_cmux_integer_radix_kb_async(
stream,
result.as_mut(),
condition,
@@ -48,7 +52,7 @@ impl CudaServerKey {
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
unchecked_cmux_integer_radix_kb(
unchecked_cmux_integer_radix_kb_async(
stream,
result.as_mut(),
condition,
@@ -82,6 +86,19 @@ impl CudaServerKey {
result
}

/// Blocking counterpart of [`Self::unchecked_if_then_else_async`].
///
/// Forwards `condition`, `true_ct`, `false_ct` and `stream` unchanged to the async entry
/// point, calls [`CudaStreams::synchronize`] on `stream`, and only then hands the
/// ciphertext produced by the async call back to the caller.
///
/// NOTE(review): presumably the result equals `true_ct` when `condition` encrypts true
/// and `false_ct` otherwise — confirm against the async entry point.
pub fn unchecked_if_then_else<T: CudaIntegerRadixCiphertext>(
    &self,
    condition: &CudaBooleanBlock,
    true_ct: &T,
    false_ct: &T,
    stream: &CudaStreams,
) -> T {
    // SAFETY: the async entry point requires the stream to be synchronized before the
    // inputs are dropped; all three borrows outlive the `synchronize` call just below.
    let selected = unsafe {
        self.unchecked_if_then_else_async(condition, true_ct, false_ct, stream)
    };
    stream.synchronize();
    selected
}

pub fn if_then_else<T: CudaIntegerRadixCiphertext>(
&self,
condition: &CudaBooleanBlock,
@@ -92,20 +109,24 @@ impl CudaServerKey {
let mut tmp_true_ct;
let mut tmp_false_ct;

let true_ct = if true_ct.block_carries_are_empty() {
true_ct
} else {
tmp_true_ct = true_ct.duplicate(stream);
self.full_propagate_assign(&mut tmp_true_ct, stream);
&tmp_true_ct
let true_ct = unsafe {
if true_ct.block_carries_are_empty() {
true_ct
} else {
tmp_true_ct = true_ct.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_true_ct, stream);
&tmp_true_ct
}
};

let false_ct = if false_ct.block_carries_are_empty() {
false_ct
} else {
tmp_false_ct = false_ct.duplicate(stream);
self.full_propagate_assign(&mut tmp_false_ct, stream);
&tmp_false_ct
let false_ct = unsafe {
if false_ct.block_carries_are_empty() {
false_ct
} else {
tmp_false_ct = false_ct.duplicate_async(stream);
self.full_propagate_assign_async(&mut tmp_false_ct, stream);
&tmp_false_ct
}
};

self.unchecked_if_then_else(condition, true_ct, false_ct, stream)
64 changes: 32 additions & 32 deletions tfhe/src/integer/gpu/server_key/radix/comparison.rs
Original file line number Diff line number Diff line change
@@ -285,20 +285,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -379,20 +379,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -558,20 +558,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -666,20 +666,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -790,20 +790,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -914,20 +914,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -1161,20 +1161,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -1208,20 +1208,20 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct_left.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, ct_right)
}
(false, false) => {
tmp_lhs = ct_left.duplicate_async(streams);
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
16 changes: 8 additions & 8 deletions tfhe/src/integer/gpu/server_key/radix/div_mod.rs
Original file line number Diff line number Diff line change
@@ -136,19 +136,19 @@ impl CudaServerKey {
(true, true) => (numerator, divisor),
(true, false) => {
tmp_divisor = divisor.duplicate(streams);
self.full_propagate_assign(&mut tmp_divisor, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_divisor, streams) };
(numerator, &tmp_divisor)
}
(false, true) => {
tmp_numerator = numerator.duplicate(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_numerator, streams) };
(&tmp_numerator, divisor)
}
(false, false) => {
tmp_divisor = divisor.duplicate(streams);
tmp_numerator = numerator.duplicate(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign(&mut tmp_divisor, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_numerator, streams) };
unsafe { self.full_propagate_assign_async(&mut tmp_divisor, streams) };
(&tmp_numerator, &tmp_divisor)
}
};
@@ -176,19 +176,19 @@ impl CudaServerKey {
(true, true) => (numerator, divisor),
(true, false) => {
tmp_divisor = divisor.duplicate(streams);
self.full_propagate_assign(&mut tmp_divisor, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_divisor, streams) };
(numerator, &tmp_divisor)
}
(false, true) => {
tmp_numerator = numerator.duplicate(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_numerator, streams) };
(&tmp_numerator, divisor)
}
(false, false) => {
tmp_divisor = divisor.duplicate(streams);
tmp_numerator = numerator.duplicate(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign(&mut tmp_divisor, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_numerator, streams) };
unsafe { self.full_propagate_assign_async(&mut tmp_divisor, streams) };
(&tmp_numerator, &tmp_divisor)
}
};
12 changes: 6 additions & 6 deletions tfhe/src/integer/gpu/server_key/radix/ilog2.rs
Original file line number Diff line number Diff line change
@@ -847,7 +847,7 @@ impl CudaServerKey {
ct
} else {
tmp = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp, streams);
self.full_propagate_assign_async(&mut tmp, streams);
&tmp
};
self.unchecked_trailing_zeros_async(ct, streams)
@@ -920,7 +920,7 @@ impl CudaServerKey {
ct
} else {
tmp = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp, streams);
self.full_propagate_assign_async(&mut tmp, streams);
&tmp
};
self.unchecked_trailing_ones_async(ct, streams)
@@ -993,7 +993,7 @@ impl CudaServerKey {
ct
} else {
tmp = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp, streams);
self.full_propagate_assign_async(&mut tmp, streams);
&tmp
};
self.unchecked_leading_zeros_async(ct, streams)
@@ -1066,7 +1066,7 @@ impl CudaServerKey {
ct
} else {
tmp = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp, streams);
self.full_propagate_assign_async(&mut tmp, streams);
&tmp
};
self.unchecked_leading_ones_async(ct, streams)
@@ -1132,7 +1132,7 @@ impl CudaServerKey {
ct
} else {
tmp = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp, streams);
self.full_propagate_assign_async(&mut tmp, streams);
&tmp
};

@@ -1207,7 +1207,7 @@ impl CudaServerKey {
ct
} else {
tmp = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp, streams);
self.full_propagate_assign_async(&mut tmp, streams);
&tmp
};

7 changes: 3 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
@@ -386,7 +386,7 @@ impl CudaServerKey {
carry_out
}

pub(crate) fn full_propagate_assign<T: CudaIntegerRadixCiphertext>(
pub(crate) unsafe fn full_propagate_assign_async<T: CudaIntegerRadixCiphertext>(
&self,
ct: &mut T,
streams: &CudaStreams,
@@ -445,7 +445,6 @@ impl CudaServerKey {
NoiseLevel::NOMINAL
};
});
streams.synchronize();
}

/// Prepend trivial zero LSB blocks to an existing [`CudaUnsignedRadixCiphertext`] or
@@ -1353,7 +1352,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !source.block_carries_are_empty() {
self.full_propagate_assign(&mut source, streams);
self.full_propagate_assign_async(&mut source, streams);
}
let current_num_blocks = source.as_ref().info.blocks.len();
if T::IS_SIGNED {
@@ -1459,7 +1458,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !source.block_carries_are_empty() {
self.full_propagate_assign(&mut source, streams);
self.full_propagate_assign_async(&mut source, streams);
}

let current_num_blocks = source.as_ref().info.blocks.len();
8 changes: 4 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/mul.rs
Original file line number Diff line number Diff line change
@@ -214,18 +214,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
2 changes: 1 addition & 1 deletion tfhe/src/integer/gpu/server_key/radix/neg.rs
Original file line number Diff line number Diff line change
@@ -142,7 +142,7 @@ impl CudaServerKey {
ctxt
} else {
tmp_ctxt = ctxt.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_ctxt, streams);
self.full_propagate_assign_async(&mut tmp_ctxt, streams);
&mut tmp_ctxt
};

32 changes: 16 additions & 16 deletions tfhe/src/integer/gpu/server_key/radix/rotate.rs
Original file line number Diff line number Diff line change
@@ -271,20 +271,20 @@ impl CudaServerKey {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = rotate.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -316,20 +316,20 @@ impl CudaServerKey {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&mut tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = rotate.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&mut tmp_lhs, &tmp_rhs)
}
};
@@ -425,20 +425,20 @@ impl CudaServerKey {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = rotate.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -470,20 +470,20 @@ impl CudaServerKey {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&mut tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = rotate.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&mut tmp_lhs, &tmp_rhs)
}
};
6 changes: 3 additions & 3 deletions tfhe/src/integer/gpu/server_key/radix/scalar_add.rs
Original file line number Diff line number Diff line change
@@ -185,7 +185,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
};

self.unchecked_scalar_add_assign_async(ct, scalar, streams);
@@ -228,7 +228,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8> + CastInto<u64>,
{
if !ct_left.block_carries_are_empty() {
self.full_propagate_assign(ct_left, stream);
unsafe { self.full_propagate_assign_async(ct_left, stream) };
}
self.unchecked_unsigned_overflowing_scalar_add_assign(ct_left, scalar, stream)
}
@@ -330,7 +330,7 @@ impl CudaServerKey {
let mut tmp_lhs;
tmp_lhs = ct_left.duplicate(streams);
if !tmp_lhs.block_carries_are_empty() {
self.full_propagate_assign(&mut tmp_lhs, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_lhs, streams) };
}

let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix(
6 changes: 3 additions & 3 deletions tfhe/src/integer/gpu/server_key/radix/scalar_bitwise_op.rs
Original file line number Diff line number Diff line change
@@ -193,7 +193,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarAnd, streams);
ct.as_mut().info = ct.as_ref().info.after_scalar_bitand(rhs);
@@ -234,7 +234,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarOr, streams);
ct.as_mut().info = ct.as_ref().info.after_scalar_bitor(rhs);
@@ -275,7 +275,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarXor, streams);
ct.as_mut().info = ct.as_ref().info.after_scalar_bitxor(rhs);
16 changes: 8 additions & 8 deletions tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs
Original file line number Diff line number Diff line change
@@ -606,7 +606,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -688,7 +688,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -929,7 +929,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -970,7 +970,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -1011,7 +1011,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -1051,7 +1051,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -1156,7 +1156,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

@@ -1192,7 +1192,7 @@ impl CudaServerKey {
ct
} else {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
&tmp_lhs
};

12 changes: 6 additions & 6 deletions tfhe/src/integer/gpu/server_key/radix/scalar_div_mod.rs
Original file line number Diff line number Diff line change
@@ -296,7 +296,7 @@ impl CudaServerKey {
numerator
} else {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
&tmp_numerator
};

@@ -425,7 +425,7 @@ impl CudaServerKey {
numerator
} else {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
&tmp_numerator
};

@@ -536,7 +536,7 @@ impl CudaServerKey {
numerator
} else {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
&tmp_numerator
};

@@ -776,7 +776,7 @@ impl CudaServerKey {
numerator
} else {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
&tmp_numerator
};

@@ -894,7 +894,7 @@ impl CudaServerKey {
numerator
} else {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
&tmp_numerator
};

@@ -1006,7 +1006,7 @@ impl CudaServerKey {
numerator
} else {
tmp_numerator = numerator.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_numerator, streams);
self.full_propagate_assign_async(&mut tmp_numerator, streams);
&tmp_numerator
};

2 changes: 1 addition & 1 deletion tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs
Original file line number Diff line number Diff line change
@@ -242,7 +242,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
};

self.unchecked_scalar_mul_assign_async(ct, scalar, streams);
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs
Original file line number Diff line number Diff line change
@@ -231,7 +231,7 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, stream);
unsafe { self.full_propagate_assign_async(ct, stream) };
}

unsafe { self.unchecked_scalar_rotate_left_assign_async(ct, n, stream) };
@@ -245,7 +245,7 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, stream);
unsafe { self.full_propagate_assign_async(ct, stream) };
}

unsafe { self.unchecked_scalar_rotate_right_assign_async(ct, n, stream) };
8 changes: 4 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs
Original file line number Diff line number Diff line change
@@ -371,7 +371,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}

self.unchecked_scalar_right_shift_assign_async(ct, shift, streams);
@@ -459,7 +459,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
}

self.unchecked_scalar_left_shift_assign_async(ct, shift, streams);
@@ -543,7 +543,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
unsafe { self.full_propagate_assign_async(ct, streams) };
}

unsafe {
@@ -563,7 +563,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
unsafe { self.full_propagate_assign_async(ct, streams) };
}

unsafe {
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs
Original file line number Diff line number Diff line change
@@ -155,7 +155,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
self.full_propagate_assign_async(ct, streams);
};

self.unchecked_scalar_sub_assign_async(ct, scalar, streams);
@@ -221,7 +221,7 @@ impl CudaServerKey {
unsafe {
tmp_lhs = ct_left.duplicate_async(streams);
if !tmp_lhs.block_carries_are_empty() {
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
}
}

32 changes: 16 additions & 16 deletions tfhe/src/integer/gpu/server_key/radix/shift.rs
Original file line number Diff line number Diff line change
@@ -267,20 +267,20 @@ impl CudaServerKey {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, shift)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = shift.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -312,20 +312,20 @@ impl CudaServerKey {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&mut tmp_lhs, shift)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = shift.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&mut tmp_lhs, &tmp_rhs)
}
};
@@ -420,20 +420,20 @@ impl CudaServerKey {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&tmp_lhs, shift)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = shift.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
@@ -465,20 +465,20 @@ impl CudaServerKey {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
(&mut tmp_lhs, shift)
}
(false, false) => {
tmp_lhs = ct.duplicate_async(streams);
tmp_rhs = shift.duplicate_async(streams);

self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_lhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(&mut tmp_lhs, &tmp_rhs)
}
};
26 changes: 13 additions & 13 deletions tfhe/src/integer/gpu/server_key/radix/sub.rs
Original file line number Diff line number Diff line change
@@ -94,7 +94,7 @@ impl CudaServerKey {
streams: &CudaStreams,
) {
let neg = self.unchecked_neg_async(ct_right, streams);
self.unchecked_add_assign(ct_left, &neg, streams);
self.unchecked_add_assign_async(ct_left, &neg, streams);
}

/// Computes homomorphically a subtraction between two ciphertexts encrypting integer values.
@@ -257,18 +257,18 @@ impl CudaServerKey {
(true, true) => (ct_left, ct_right),
(true, false) => {
tmp_rhs = ct_right.duplicate_async(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign_async(ct_left, streams);
(ct_left, ct_right)
}
(false, false) => {
tmp_rhs = ct_right.duplicate_async(streams);

self.full_propagate_assign(ct_left, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
self.full_propagate_assign_async(ct_left, streams);
self.full_propagate_assign_async(&mut tmp_rhs, streams);
(ct_left, &tmp_rhs)
}
};
@@ -299,14 +299,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
@@ -315,8 +315,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
@@ -521,14 +521,14 @@ impl CudaServerKey {
(true, false) => {
unsafe {
tmp_rhs = ct_right.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}
(ct_left, &tmp_rhs)
}
(false, true) => {
unsafe {
tmp_lhs = ct_left.duplicate_async(stream);
self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
}
(&tmp_lhs, ct_right)
}
@@ -537,8 +537,8 @@ impl CudaServerKey {
tmp_lhs = ct_left.duplicate_async(stream);
tmp_rhs = ct_right.duplicate_async(stream);

self.full_propagate_assign(&mut tmp_lhs, stream);
self.full_propagate_assign(&mut tmp_rhs, stream);
self.full_propagate_assign_async(&mut tmp_lhs, stream);
self.full_propagate_assign_async(&mut tmp_rhs, stream);
}

(&tmp_lhs, &tmp_rhs)
8 changes: 4 additions & 4 deletions tfhe/src/integer/gpu/server_key/radix/vector_comparisons.rs
Original file line number Diff line number Diff line change
@@ -260,7 +260,7 @@ impl CudaServerKey {
for ct in lhs.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_lhs.push(temp_ct);
}
@@ -275,7 +275,7 @@ impl CudaServerKey {
for ct in rhs.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_rhs.push(temp_ct);
}
@@ -411,7 +411,7 @@ impl CudaServerKey {
for ct in lhs.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_lhs.push(temp_ct);
}
@@ -426,7 +426,7 @@ impl CudaServerKey {
for ct in rhs.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_rhs.push(temp_ct);
}
52 changes: 28 additions & 24 deletions tfhe/src/integer/gpu/server_key/radix/vector_find.rs
Original file line number Diff line number Diff line change
@@ -250,7 +250,7 @@ impl CudaServerKey {
self.unchecked_match_value(ct, matches, streams)
} else {
let mut clone = ct.duplicate(streams);
self.full_propagate_assign(&mut clone, streams);
unsafe { self.full_propagate_assign_async(&mut clone, streams) };
self.unchecked_match_value(&clone, matches, streams)
}
}
@@ -364,7 +364,7 @@ impl CudaServerKey {
self.unchecked_match_value_or(ct, matches, or_value, streams)
} else {
let mut clone = ct.duplicate(streams);
self.full_propagate_assign(&mut clone, streams);
unsafe { self.full_propagate_assign_async(&mut clone, streams) };
self.unchecked_match_value_or(&clone, matches, or_value, streams)
}
}
@@ -444,7 +444,7 @@ impl CudaServerKey {
for ct in cts.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_cts.push(temp_ct);
}
@@ -458,7 +458,7 @@ impl CudaServerKey {
value
} else {
tmp_value = value.duplicate(streams);
self.full_propagate_assign(&mut tmp_value, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_value, streams) };
&tmp_value
};

@@ -544,7 +544,7 @@ impl CudaServerKey {
for ct in cts.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_cts.push(temp_ct);
}
@@ -632,7 +632,7 @@ impl CudaServerKey {
ct
} else {
tmp_ct = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_ct, streams) };
&tmp_ct
};
self.unchecked_is_in_clears(ct, clears, streams)
@@ -732,7 +732,7 @@ impl CudaServerKey {
ct
} else {
tmp_ct = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_ct, streams) };
streams.synchronize();
&tmp_ct
};
@@ -859,7 +859,7 @@ impl CudaServerKey {
ct
} else {
tmp_ct = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_ct, streams) };
streams.synchronize();
&tmp_ct
};
@@ -964,7 +964,7 @@ impl CudaServerKey {
for ct in cts.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_cts.push(temp_ct);
}
@@ -978,7 +978,7 @@ impl CudaServerKey {
value
} else {
tmp_value = value.duplicate(streams);
self.full_propagate_assign(&mut tmp_value, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_value, streams) };
&tmp_value
};
self.unchecked_index_of(cts, value, streams)
@@ -1082,7 +1082,7 @@ impl CudaServerKey {
for ct in cts.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_cts.push(temp_ct);
}
@@ -1211,7 +1211,7 @@ impl CudaServerKey {
for ct in cts.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_cts.push(temp_ct);
}
@@ -1343,7 +1343,7 @@ impl CudaServerKey {
for ct in cts.iter() {
let mut temp_ct = ct.duplicate(streams);
if !temp_ct.block_carries_are_empty() {
self.full_propagate_assign(&mut temp_ct, streams);
unsafe { self.full_propagate_assign_async(&mut temp_ct, streams) };
}
tmp_cts.push(temp_ct);
}
@@ -1357,7 +1357,7 @@ impl CudaServerKey {
value
} else {
tmp_value = value.duplicate(streams);
self.full_propagate_assign(&mut tmp_value, streams);
unsafe { self.full_propagate_assign_async(&mut tmp_value, streams) };
&tmp_value
};
self.unchecked_first_index_of(cts, value, streams)
@@ -1606,11 +1606,13 @@ impl CudaServerKey {
for chunk_idx in 0..(num_chunks - 1) {
for ct_idx in 0..chunk_size {
let one_hot_idx = chunk_idx * chunk_size + ct_idx;
self.unchecked_add_assign(
&mut aggregated_vector,
&one_hot_vector[one_hot_idx],
streams,
);
unsafe {
self.unchecked_add_assign_async(
&mut aggregated_vector,
&one_hot_vector[one_hot_idx],
streams,
)
};
}
let mut temp = aggregated_vector.duplicate(streams);
let mut aggregated_mut_slice = aggregated_vector
@@ -1684,11 +1686,13 @@ impl CudaServerKey {
let last_chunk_size = one_hot_vector.len() - (num_chunks - 1) * chunk_size;
for ct_idx in 0..last_chunk_size {
let one_hot_idx = (num_chunks - 1) * chunk_size + ct_idx;
self.unchecked_add_assign(
&mut aggregated_vector,
&one_hot_vector[one_hot_idx],
streams,
);
unsafe {
self.unchecked_add_assign_async(
&mut aggregated_vector,
&one_hot_vector[one_hot_idx],
streams,
)
};
}

let message_extract_lut =