Skip to content

Commit

Permalink
chore(gpu): rework async logic for ilog2
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Sep 18, 2024
1 parent 50019e0 commit b963a67
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 76 deletions.
61 changes: 42 additions & 19 deletions tfhe/src/integer/gpu/server_key/radix/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,25 @@ impl CudaServerKey {
ciphertexts: &[T],
streams: &CudaStreams,
) -> T {
let mut result = unsafe {
self.unchecked_partial_sum_ciphertexts_async(ciphertexts, streams)
.unwrap()
};

unsafe {
self.propagate_single_carry_assign_async(&mut result, streams);
}
let result = unsafe { self.unchecked_sum_ciphertexts_async(ciphertexts, streams) };
streams.synchronize();
result
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
&self,
ciphertexts: &[T],
streams: &CudaStreams,
) -> T {
let mut result = self
.unchecked_partial_sum_ciphertexts_async(ciphertexts, streams)
.unwrap();

self.propagate_single_carry_assign_async(&mut result, streams);
assert!(result.block_carries_are_empty());
result
}
Expand Down Expand Up @@ -380,6 +390,20 @@ impl CudaServerKey {
}

pub fn sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
&self,
ciphertexts: Vec<T>,
streams: &CudaStreams,
) -> Option<T> {
let res = unsafe { self.sum_ciphertexts_async(ciphertexts, streams) };
streams.synchronize();
res
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
&self,
mut ciphertexts: Vec<T>,
streams: &CudaStreams,
Expand All @@ -388,16 +412,14 @@ impl CudaServerKey {
return None;
}

unsafe {
ciphertexts
.iter_mut()
.filter(|ct| !ct.block_carries_are_empty())
.for_each(|ct| {
self.full_propagate_assign_async(&mut *ct, streams);
});
}
ciphertexts
.iter_mut()
.filter(|ct| !ct.block_carries_are_empty())
.for_each(|ct| {
self.full_propagate_assign_async(&mut *ct, streams);
});

Some(self.unchecked_sum_ciphertexts(&ciphertexts, streams))
Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
}

/// ```rust
Expand Down Expand Up @@ -655,7 +677,8 @@ impl CudaServerKey {
unsafe {
result = lhs.duplicate_async(streams);
}
let carry_out: CudaSignedRadixCiphertext = self.create_trivial_zero_radix(1, streams);
let carry_out: CudaSignedRadixCiphertext =
unsafe { self.create_trivial_zero_radix_async(1, streams) };
let mut overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext);

unsafe {
Expand All @@ -666,8 +689,8 @@ impl CudaServerKey {
signed_operation,
streams,
);
streams.synchronize();
}
streams.synchronize();

(result, overflowed)
}
Expand Down
Loading

0 comments on commit b963a67

Please sign in to comment.