From a7c9357a02109e9279f03747fec29e5eb6866a88 Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 13 Feb 2025 11:57:50 +0100 Subject: [PATCH] fix(gpu): fix memory error in shift and rotate --- .../cuda/include/integer/integer_utilities.h | 3 --- .../tfhe-cuda-backend/cuda/src/integer/integer.cuh | 13 ++++++++++++- .../cuda/src/integer/shift_and_rotate.cuh | 4 ++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h index 7b9a4f1e5a..3254ba4858 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h @@ -2781,9 +2781,6 @@ template struct int_logical_scalar_shift_buffer { tmp_rotated = pre_allocated_buffer; reuse_memory = true; - uint32_t max_amount_of_pbs = num_radix_blocks; - uint32_t big_lwe_size = params.big_lwe_dimension + 1; - uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus); set_zero_radix_ciphertext_slice_async(streams[0], gpu_indexes[0], tmp_rotated, 0, tmp_rotated->num_radix_blocks); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index 2a29f06650..068b8e8f8c 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -2161,10 +2161,21 @@ extract_n_bits(cudaStream_t const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array_out, const CudaRadixCiphertextFFI *lwe_array_in, void *const *bsks, Torus *const *ksks, uint32_t effective_num_radix_blocks, + uint32_t num_radix_blocks, int_bit_extract_luts_buffer *bit_extract) { + copy_radix_ciphertext_slice_async(streams[0], gpu_indexes[0], + lwe_array_out, 0, num_radix_blocks, + lwe_array_in, 0, num_radix_blocks); + if (effective_num_radix_blocks / num_radix_blocks > 0) { + for (uint i = 1; i < effective_num_radix_blocks / num_radix_blocks; i++) { + copy_radix_ciphertext_slice_async( + streams[0], gpu_indexes[0], lwe_array_out, i * num_radix_blocks, + (i + 1) * num_radix_blocks, lwe_array_in, 0, num_radix_blocks); + } + } integer_radix_apply_univariate_lookup_table_kb( - streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_in, bsks, ksks, + streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_out, bsks, ksks, bit_extract->lut, effective_num_radix_blocks); } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh index 8feb24d303..52f7ef44eb 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/shift_and_rotate.cuh @@ -57,7 +57,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace( auto bits = mem->tmp_bits; extract_n_bits(streams, gpu_indexes, gpu_count, bits, lwe_array, bsks, ksks, num_radix_blocks * bits_per_block, - mem->bit_extract_luts); + num_radix_blocks, mem->bit_extract_luts); // Extract shift bits auto shift_bits = mem->tmp_shift_bits; @@ -78,7 +78,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace( // and we reduce noise growth extract_n_bits(streams, gpu_indexes, gpu_count, shift_bits, lwe_shift, bsks, ksks, max_num_bits_that_tell_shift, - mem->bit_extract_luts_with_offset_2); + num_radix_blocks, mem->bit_extract_luts_with_offset_2); // If signed, do an "arithmetic shift" by padding with the sign bit CudaRadixCiphertextFFI last_bit;