feat(gpu): signed scalar add
agnesLeroy committed Mar 11, 2024
1 parent cc905a0 commit f84c34c
Showing 12 changed files with 582 additions and 344 deletions.
118 changes: 116 additions & 2 deletions tfhe/benches/integer/signed_bench.rs
@@ -1382,6 +1382,100 @@ mod cuda {
}
);

fn bench_cuda_server_key_binary_scalar_signed_function_clean_inputs<F, G>(
c: &mut Criterion,
bench_name: &str,
display_name: &str,
binary_op: F,
rng_func: G,
) where
F: Fn(&CudaServerKey, &mut CudaSignedRadixCiphertext, ScalarType, &CudaStream),
G: Fn(&mut ThreadRng, usize) -> ScalarType,
{
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(60));
let mut rng = rand::thread_rng();

let gpu_index = 0;
let device = CudaDevice::new(gpu_index);
let stream = CudaStream::new_unchecked(device);

for (param, num_block, bit_size) in ParamsAndNumBlocksIter::default() {
if bit_size > ScalarType::BITS as usize {
break;
}
let param_name = param.name();

let max_value_for_bit_size = ScalarType::MAX >> (ScalarType::BITS as usize - bit_size);

let bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits_scalar_{bit_size}");
bench_group.bench_function(&bench_id, |b| {
let (cks, _cpu_sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let gpu_sks = CudaServerKey::new(&cks, &stream);

let encrypt_one_value = || {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
let clear_0 = tfhe::integer::I256::from((clearlow, clearhigh));
let ct_0 = cks.encrypt_signed_radix(clear_0, num_block);
let d_ct_0 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct_0, &stream);

let clear_1 = rng_func(&mut rng, bit_size) & max_value_for_bit_size;

(d_ct_0, clear_1)
};

b.iter_batched(
encrypt_one_value,
|(mut ct_0, clear_1)| {
binary_op(&gpu_sks, &mut ct_0, clear_1, &stream);
},
criterion::BatchSize::SmallInput,
)
});

write_to_json::<u64, _>(
&bench_id,
param,
param.name(),
display_name,
&OperatorType::Atomic,
bit_size as u32,
vec![param.message_modulus().0.ilog2(); num_block],
);
}

bench_group.finish()
}

macro_rules! define_cuda_server_key_bench_clean_input_scalar_signed_fn (
(method_name: $server_key_method:ident, display_name:$name:ident, rng_func:$($rng_fn:tt)*) => {
::paste::paste!{
fn [<cuda_ $server_key_method>](c: &mut Criterion) {
bench_cuda_server_key_binary_scalar_signed_function_clean_inputs(
c,
concat!("integer::cuda::signed::", stringify!($server_key_method)),
stringify!($name),
|server_key, lhs, rhs, stream| {
server_key.$server_key_method(lhs, rhs, stream);
},
$($rng_fn)*
)
}
}
}
);

// Functions used to apply different ways of selecting a scalar depending on the context.
fn default_signed_scalar(rng: &mut ThreadRng, _clear_bit_size: usize) -> ScalarType {
let clearlow = rng.gen::<u128>();
let clearhigh = rng.gen::<u128>();
tfhe::integer::I256::from((clearlow, clearhigh))
}

define_cuda_server_key_bench_clean_input_signed_fn!(
method_name: unchecked_add,
display_name: add
@@ -1402,6 +1496,12 @@
display_name: mul
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: unchecked_scalar_add,
display_name: add,
rng_func: default_signed_scalar
);

//===========================================
// Default
//===========================================
@@ -1426,28 +1526,42 @@
display_name: mul
);

define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
method_name: scalar_add,
display_name: add,
rng_func: default_signed_scalar
);

criterion_group!(
unchecked_cuda_ops,
cuda_unchecked_add,
cuda_unchecked_sub,
cuda_unchecked_neg,
-    cuda_unchecked_mul
+    cuda_unchecked_mul,
);

criterion_group!(unchecked_scalar_cuda_ops, cuda_unchecked_scalar_add,);

criterion_group!(default_cuda_ops, cuda_add, cuda_sub, cuda_neg, cuda_mul);

criterion_group!(default_scalar_cuda_ops, cuda_scalar_add);
}

#[cfg(feature = "gpu")]
-use cuda::{default_cuda_ops, unchecked_cuda_ops};
+use cuda::{
+    default_cuda_ops, default_scalar_cuda_ops, unchecked_cuda_ops, unchecked_scalar_cuda_ops,
+};

#[cfg(feature = "gpu")]
fn go_through_gpu_bench_groups(val: &str) {
match val.to_lowercase().as_str() {
"default" => {
default_cuda_ops();
default_scalar_cuda_ops();
}
"unchecked" => {
unchecked_cuda_ops();
unchecked_scalar_cuda_ops();
}
_ => panic!("unknown benchmark operations flavor"),
};
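Two details of the new benchmark are worth spelling out. First, max_value_for_bit_size is ScalarType::MAX >> (ScalarType::BITS - bit_size); for a signed ScalarType this mask keeps only the low bit_size - 1 bits, so the random scalar fits the ciphertext width and is always non-negative. Second, the define_cuda_server_key_bench_clean_input_scalar_signed_fn! macro only stamps out a cuda_-prefixed Criterion entry point. A hand-expanded sketch of the unchecked_scalar_add invocation, derived by hand from the macro body above rather than from cargo expand:

// Hand expansion of:
//   define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
//       method_name: unchecked_scalar_add, display_name: add,
//       rng_func: default_signed_scalar
//   );
fn cuda_unchecked_scalar_add(c: &mut Criterion) {
    bench_cuda_server_key_binary_scalar_signed_function_clean_inputs(
        c,
        "integer::cuda::signed::unchecked_scalar_add", // concat!(...) result
        "add",                                         // stringify!($name)
        |server_key, lhs, rhs, stream| {
            server_key.unchecked_scalar_add(lhs, rhs, stream);
        },
        default_signed_scalar, // the rng_func tokens, passed through verbatim
    )
}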
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
@@ -465,7 +465,7 @@ generic_integer_impl_scalar_operation!(
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_stream(|stream| {
cuda_key.key.scalar_add(
-                        &lhs.ciphertext.on_gpu(), rhs, stream
+                        &*lhs.ciphertext.on_gpu(), rhs, stream
)
});
RadixCiphertext::Cuda(inner_result)
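The one-character change above (& to &*) matters because scalar_add is now generic: on_gpu() returns a wrapper that merely derefs to the GPU ciphertext, and without the explicit reborrow the compiler would unify the generic parameter with the wrapper type instead of the ciphertext. A minimal, self-contained illustration of that inference issue, using hypothetical stand-in types rather than the tfhe-rs ones:

use std::ops::Deref;

struct Ciphertext;
struct Guard(Ciphertext);

impl Deref for Guard {
    type Target = Ciphertext;
    fn deref(&self) -> &Ciphertext {
        &self.0
    }
}

// Stand-in for the now-generic scalar_add: accepts any &T.
fn scalar_add_like<T>(_ct: &T) {}

fn main() {
    let guard = Guard(Ciphertext);
    scalar_add_like(&*guard); // reborrow through Deref: T = Ciphertext
    scalar_add_like(&guard); // no reborrow: T = Guard, the wrong type
}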
66 changes: 29 additions & 37 deletions tfhe/src/integer/gpu/server_key/radix/scalar_add.rs
@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::CudaStream;
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
-use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
+use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaServerKey;
use itertools::Itertools;

@@ -43,14 +43,10 @@ impl CudaServerKey {
/// let dec: u64 = cks.decrypt(&ct_res);
/// assert_eq!(msg + scalar, dec);
/// ```
-    pub fn unchecked_scalar_add<T>(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    pub fn unchecked_scalar_add<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
-        T: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8>,
+        T: CudaIntegerRadixCiphertext,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.unchecked_scalar_add_assign(&mut result, scalar, stream);
@@ -61,15 +57,16 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
-    pub unsafe fn unchecked_scalar_add_assign_async<T>(
+    pub unsafe fn unchecked_scalar_add_assign_async<Scalar, T>(
         &self,
-        ct: &mut CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &mut T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) where
-        T: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8>,
+        T: CudaIntegerRadixCiphertext,
     {
-        if scalar > T::ZERO {
+        if scalar != Scalar::ZERO {
let bits_in_message = self.message_modulus.0.ilog2();
let decomposer =
BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message).iter_as::<u8>();
@@ -95,18 +92,19 @@
self.message_modulus.0 as u32,
self.carry_modulus.0 as u32,
);
-        }

-        ct.as_mut().info = ct.as_ref().info.after_scalar_add(scalar);
+            ct.as_mut().info = ct.as_ref().info.after_scalar_add(scalar);
+        }
}

-    pub fn unchecked_scalar_add_assign<T>(
+    pub fn unchecked_scalar_add_assign<Scalar, T>(
         &self,
-        ct: &mut CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &mut T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) where
-        T: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8>,
+        T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_add_assign_async(ct, scalar, stream);
@@ -151,14 +149,10 @@ impl CudaServerKey {
/// let dec: u64 = cks.decrypt(&ct_res);
/// assert_eq!(msg + scalar, dec);
/// ```
-    pub fn scalar_add<T>(
-        &self,
-        ct: &CudaUnsignedRadixCiphertext,
-        scalar: T,
-        stream: &CudaStream,
-    ) -> CudaUnsignedRadixCiphertext
+    pub fn scalar_add<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
     where
-        T: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8>,
+        T: CudaIntegerRadixCiphertext,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.scalar_add_assign(&mut result, scalar, stream);
@@ -169,13 +163,14 @@
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
-    pub unsafe fn scalar_add_assign_async<T>(
+    pub unsafe fn scalar_add_assign_async<Scalar, T>(
         &self,
-        ct: &mut CudaUnsignedRadixCiphertext,
-        scalar: T,
+        ct: &mut T,
+        scalar: Scalar,
         stream: &CudaStream,
     ) where
-        T: DecomposableInto<u8>,
+        Scalar: DecomposableInto<u8>,
+        T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign_async(ct, stream);
@@ -185,13 +180,10 @@
self.full_propagate_assign_async(ct, stream);
}

-    pub fn scalar_add_assign<T>(
-        &self,
-        ct: &mut CudaUnsignedRadixCiphertext,
-        scalar: T,
-        stream: &CudaStream,
-    ) where
-        T: DecomposableInto<u8>,
+    pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream)
+    where
+        Scalar: DecomposableInto<u8>,
+        T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_add_assign_async(ct, scalar, stream);
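Net effect of this file: every scalar-add entry point drops its hard-coded CudaUnsignedRadixCiphertext parameter in favour of a generic T: CudaIntegerRadixCiphertext, so signed and unsigned radix ciphertexts now share one implementation. Two related changes follow from signedness: the early-out guard becomes scalar != Scalar::ZERO, because a negative scalar must still be decomposed and added (the old scalar > T::ZERO would have skipped it), and the ct.as_mut().info update moves inside that branch, since a zero scalar leaves the ciphertext untouched. A usage sketch for the signed side (key and stream setup elided, mirroring the doc-tests above; illustrative only, not verified end-to-end):

// Assumes cks, sks (CudaServerKey), stream and num_blocks are set up as in
// the doc-tests above.
let msg: i64 = -4;
let scalar: i64 = 40;
let ct = cks.encrypt_signed_radix(msg, num_blocks);
let d_ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct, &stream);
// Same generic entry point as the unsigned case; T is inferred as
// CudaSignedRadixCiphertext and Scalar as i64.
let d_res = sks.scalar_add(&d_ct, scalar, &stream);
let res = d_res.to_signed_radix_ciphertext(&stream);
let dec: i64 = cks.decrypt_signed_radix(&res);
assert_eq!(msg + scalar, dec);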
37 changes: 34 additions & 3 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs
@@ -1,6 +1,7 @@
pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_sub;

use crate::core_crypto::gpu::CudaStream;
@@ -100,21 +101,21 @@
}

/// For unchecked/default binary functions with one scalar input
-impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>
+impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>
for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
-        u64,
+        i64,
&CudaStream,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

-    fn execute(&mut self, input: (&'a SignedRadixCiphertext, u64)) -> SignedRadixCiphertext {
+    fn execute(&mut self, input: (&'a SignedRadixCiphertext, i64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
@@ -128,3 +129,33 @@ where
gpu_result.to_signed_radix_ciphertext(&context.stream)
}
}

/// For unchecked/default binary functions with one scalar input
impl<F> FunctionExecutor<(SignedRadixCiphertext, i64), SignedRadixCiphertext>
for GpuFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStream,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: (SignedRadixCiphertext, i64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&input.0, &context.stream);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.stream);

gpu_result.to_signed_radix_ciphertext(&context.stream)
}
}
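Note that the scalar case gets two executor impls: one whose input tuple borrows the ciphertext (&'a SignedRadixCiphertext) and one that takes it by value, presumably because the shared signed test cases hand their inputs over in both shapes. Both perform the same round trip: upload the ciphertext to the GPU, run the function under test, and download the result.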
27 changes: 27 additions & 0 deletions tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs
@@ -0,0 +1,27 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_signed::{
signed_default_scalar_add_test, signed_unchecked_scalar_add_test,
};
use crate::shortint::parameters::*;

create_gpu_parametrized_test!(integer_signed_unchecked_scalar_add);
create_gpu_parametrized_test!(integer_signed_scalar_add);

fn integer_signed_unchecked_scalar_add<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_add);
signed_unchecked_scalar_add_test(param, executor);
}

fn integer_signed_scalar_add<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add);
signed_default_scalar_add_test(param, executor);
}
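These tests only exist under the gpu feature (see the #[cfg(feature = "gpu")] gates above), so an invocation along the lines of cargo test --release --features=gpu integer_signed_unchecked_scalar_add should select them, modulo whatever other crate features the integer test build requires.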