From 31f325cf1118b7da762cb615ce62de408ced7e90 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Mon, 27 Jan 2025 16:53:14 +0100 Subject: [PATCH] fix(gpu): fix scalar mul with 1 block --- .../cuda/include/integer/integer_utilities.h | 13 ++++++++++--- .../cuda/src/integer/scalar_mul.cuh | 3 --- tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs | 9 ++++++--- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h index b7220e64c4..528e33f7fa 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h @@ -4454,9 +4454,16 @@ template struct int_scalar_mul_buffer { num_ciphertext_bits * num_radix_blocks * lwe_size_bytes, streams[0], gpu_indexes[0]); - logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer( - streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks, - allocate_gpu_memory, all_shifted_buffer); + if (num_ciphertext_bits * num_radix_blocks >= num_radix_blocks + 2) + logical_scalar_shift_buffer = + new int_logical_scalar_shift_buffer( + streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, + num_radix_blocks, allocate_gpu_memory, all_shifted_buffer); + else + logical_scalar_shift_buffer = + new int_logical_scalar_shift_buffer( + streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, + num_radix_blocks, allocate_gpu_memory); sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory( streams, gpu_indexes, gpu_count, params, num_radix_blocks, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh index ef58f738bc..8665c38dce 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh @@ -47,9 +47,6 @@ __host__ void host_integer_scalar_mul_radix( void *const *bsks, T *const *ksks, uint32_t input_lwe_dimension, uint32_t message_modulus, uint32_t num_radix_blocks, uint32_t num_scalars) { - if (num_radix_blocks == 0 | num_scalars == 0) - return; - // lwe_size includes the presence of the body // whereas lwe_dimension is the number of elements in the mask uint32_t lwe_size = input_lwe_dimension + 1; diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs index 0a87134ee0..65ec3e11b7 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs @@ -79,7 +79,9 @@ impl CudaServerKey { return; } - if scalar == Scalar::ONE { + let ciphertext = ct.as_mut(); + let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0; + if scalar == Scalar::ONE || num_blocks == 0 { return; } @@ -89,8 +91,6 @@ impl CudaServerKey { self.unchecked_scalar_left_shift_assign_async(ct, scalar.ilog2() as u64, streams); return; } - let ciphertext = ct.as_mut(); - let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0; let msg_bits = self.message_modulus.0.ilog2() as usize; let decomposer = BlockDecomposer::with_early_stop_at_zero(scalar, 1).iter_as::(); @@ -106,6 +106,9 @@ impl CudaServerKey { let decomposed_scalar = BlockDecomposer::with_early_stop_at_zero(scalar, 1) .iter_as::() .collect::>(); + if decomposed_scalar.is_empty() { + return; + } match &self.bootstrapping_key { CudaBootstrappingKey::Classic(d_bsk) => {