From 0f44ffdf30e736751490e52c6e7a7719c2aac706 Mon Sep 17 00:00:00 2001 From: Guillermo Oyarzun Date: Fri, 14 Feb 2025 16:55:48 +0100 Subject: [PATCH] fix(gpu): enable larger number of samples in the keyswitch --- .../cuda/src/crypto/keyswitch.cuh | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh index 05af732153..9ffd597eed 100644 --- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh @@ -45,19 +45,19 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes, const Torus *__restrict__ lwe_input_indexes, const Torus *__restrict__ ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) { - const int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int tid = threadIdx.x + blockIdx.y * blockDim.x; const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x; extern __shared__ int8_t sharedmem[]; Torus *lwe_acc_out = (Torus *)sharedmem; auto block_lwe_array_out = get_chunk( - lwe_array_out, lwe_output_indexes[blockIdx.y], lwe_dimension_out + 1); + lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1); if (tid <= lwe_dimension_out) { Torus local_lwe_out = 0; auto block_lwe_array_in = get_chunk( - lwe_array_in, lwe_input_indexes[blockIdx.y], lwe_dimension_in + 1); + lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1); if (tid == lwe_dimension_out && threadIdx.y == 0) { local_lwe_out = block_lwe_array_in[lwe_dimension_in]; @@ -108,13 +108,19 @@ __host__ void host_keyswitch_lwe_ciphertext_vector( cuda_set_device(gpu_index); constexpr int num_threads_y = 32; - int num_blocks, num_threads_x; + int num_blocks_per_sample, num_threads_x; getNumBlocksAndThreads2D(lwe_dimension_out + 1, 512, num_threads_y, - num_blocks, num_threads_x); + num_blocks_per_sample, num_threads_x); int shared_mem = sizeof(Torus) * num_threads_y * num_threads_x; - dim3 grid(num_blocks, num_samples, 1); + if (num_blocks_per_sample > 65536) + PANIC("Cuda error (Keyswith): number of blocks per sample is too large"); + + // In multiplication of large integers (512, 1024, 2048), the number of + // samples can be larger than 65536, so we need to set it in the first + // dimension of the grid + dim3 grid(num_samples, num_blocks_per_sample, 1); dim3 threads(num_threads_x, num_threads_y, 1); keyswitch<<>>(