From 56f9b221eb657cb8011710663b14833e15b11584 Mon Sep 17 00:00:00 2001
From: Beka Barbakadze
Date: Wed, 7 Feb 2024 18:49:57 +0400
Subject: [PATCH] feat(gpu): scalar shifts with one wave of pbs

---
 .../tfhe-cuda-backend/cuda/include/integer.h | 33 ++-----
 .../cuda/src/integer/scalar_shifts.cuh       | 47 ++++++-------
 2 files changed, 19 insertions(+), 61 deletions(-)

diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h
index b7944762b0..7a4b867c11 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -646,7 +646,10 @@ template <typename Torus> struct int_shift_buffer {
       uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);
 
       tmp_rotated = (Torus *)cuda_malloc_async(
-          max_amount_of_pbs * big_lwe_size_bytes, stream);
+          (max_amount_of_pbs + 2) * big_lwe_size_bytes, stream);
+
+      cuda_memset_async(tmp_rotated, 0,
+                        (max_amount_of_pbs + 2) * big_lwe_size_bytes, stream);
 
       uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);
 
@@ -709,34 +712,6 @@ template <typename Torus> struct int_shift_buffer {
 
         lut_buffers_bivariate.push_back(cur_lut_bivariate);
       }
-
-      // here we generate 'message_modulus' times lut
-      // one for each 'shift'
-      // lut_indexes will have indexes for single lut only and those indexes
-      // will be 0 it means for pbs corresponding lut should be selected and
-      // pass along lut_indexes filled with zeros
-
-      // calculate lut for each 'shift'
-      for (int shift = 0; shift < params.message_modulus; shift++) {
-        auto cur_lut =
-            new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);
-
-        std::function<Torus(Torus)> shift_lut_f;
-        if (shift_type == LEFT_SHIFT)
-          shift_lut_f = [shift, params](Torus x) -> Torus {
-            return (x << shift) % params.message_modulus;
-          };
-        else
-          shift_lut_f = [shift, params](Torus x) -> Torus {
-            return (x >> shift) % params.message_modulus;
-          };
-
-        generate_device_accumulator<Torus>(
-            stream, cur_lut->lut, params.glwe_dimension, params.polynomial_size,
-            params.message_modulus, params.carry_modulus, shift_lut_f);
-
-        lut_buffers_univariate.push_back(cur_lut);
-      }
     }
   }
 
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
index 098aa901a5..a2f3f9c652 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
@@ -44,10 +44,10 @@ __host__ void host_integer_radix_scalar_shift_kb_inplace(
   size_t rotations = std::min(shift / num_bits_in_block, (size_t)num_blocks);
   size_t shift_within_block = shift % num_bits_in_block;
 
-  Torus *rotated_buffer = mem->tmp_rotated;
+  Torus *full_rotated_buffer = mem->tmp_rotated;
+  Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size];
 
   auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
-  auto lut_univariate = mem->lut_buffers_univariate[shift_within_block];
 
   // rotate right all the blocks in radix ciphertext
   // copy result in new buffer
@@ -68,23 +68,15 @@
       return;
     }
 
-    // check if we have enough blocks for partial processing
-    if (rotations < num_blocks - 1) {
-      auto partial_current_blocks = &lwe_array[(rotations + 1) * big_lwe_size];
-      auto partial_previous_blocks = &lwe_array[rotations * big_lwe_size];
+    auto partial_current_blocks = &lwe_array[rotations * big_lwe_size];
+    auto partial_previous_blocks =
+        &full_rotated_buffer[rotations * big_lwe_size];
 
-      size_t partial_block_count = num_blocks - rotations - 1;
+    size_t partial_block_count = num_blocks - rotations;
 
-      integer_radix_apply_bivariate_lookup_table_kb<Torus>(
-          stream, partial_current_blocks, partial_current_blocks,
-          partial_previous_blocks, bsk, ksk, partial_block_count,
-          lut_bivariate);
-    }
-
-    auto rest = &lwe_array[rotations * big_lwe_size];
-
-    integer_radix_apply_univariate_lookup_table_kb<Torus>(
-        stream, rest, rest, bsk, ksk, 1, lut_univariate);
+    integer_radix_apply_bivariate_lookup_table_kb<Torus>(
+        stream, partial_current_blocks, partial_current_blocks,
+        partial_previous_blocks, bsk, ksk, partial_block_count, lut_bivariate);
 
   } else {
     // right shift
@@ -102,23 +94,14 @@
       return;
     }
 
-    // check if we have enough blocks for partial processing
-    if (rotations < num_blocks - 1) {
-      auto partial_current_blocks = lwe_array;
-      auto partial_next_blocks = &lwe_array[big_lwe_size];
+    auto partial_current_blocks = lwe_array;
+    auto partial_next_blocks = &rotated_buffer[big_lwe_size];
 
-      size_t partial_block_count = num_blocks - rotations - 1;
-
-      integer_radix_apply_bivariate_lookup_table_kb<Torus>(
-          stream, partial_current_blocks, partial_current_blocks,
-          partial_next_blocks, bsk, ksk, partial_block_count, lut_bivariate);
-    }
+    size_t partial_block_count = num_blocks - rotations;
 
-    // The right-most block is done separately as it does not
-    // need to recuperate the shifted bits from its right neighbour.
-    auto last_block = &lwe_array[(num_blocks - rotations - 1) * big_lwe_size];
-    integer_radix_apply_univariate_lookup_table_kb<Torus>(
-        stream, last_block, last_block, bsk, ksk, 1, lut_univariate);
+    integer_radix_apply_bivariate_lookup_table_kb<Torus>(
+        stream, partial_current_blocks, partial_current_blocks,
+        partial_next_blocks, bsk, ksk, partial_block_count, lut_bivariate);
   }
 }