diff --git a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h
index df4eb39df0..33d862cd29 100644
--- a/backends/tfhe-cuda-backend/cuda/include/integer.h
+++ b/backends/tfhe-cuda-backend/cuda/include/integer.h
@@ -1755,12 +1755,12 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
     uint32_t big_lwe_size = params.big_lwe_dimension + 1;
     uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);
 
-    tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 2) *
+    tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 3) *
                                                  big_lwe_size_bytes,
                                              streams[0], gpu_indexes[0]);
 
     cuda_memset_async(tmp_rotated, 0,
-                      (num_radix_blocks + 2) * big_lwe_size_bytes, streams[0],
+                      (num_radix_blocks + 3) * big_lwe_size_bytes, streams[0],
                       gpu_indexes[0]);
 
     uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
index e612c9ab2f..192ac2cb48 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
@@ -52,8 +52,6 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
   Torus *full_rotated_buffer = mem->tmp_rotated;
   Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size];
 
-  auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
-
   // rotate right all the blocks in radix ciphertext
   // copy result in new buffer
   // 1024 threads are used in every block
@@ -76,6 +74,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
       return;
     }
 
+    auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
     auto partial_current_blocks = &lwe_array[rotations * big_lwe_size];
     auto partial_previous_blocks =
         &full_rotated_buffer[rotations * big_lwe_size];
@@ -109,6 +108,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
 
     auto partial_current_blocks = lwe_array;
     auto partial_next_blocks = &rotated_buffer[big_lwe_size];
+    auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
 
     size_t partial_block_count = num_blocks - rotations;
 
@@ -139,8 +139,6 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
     int_arithmetic_scalar_shift_buffer<Torus> *mem, void **bsks, Torus **ksks,
     uint32_t num_blocks) {
 
-  cudaSetDevice(gpu_indexes[0]);
-
   auto params = mem->params;
   auto glwe_dimension = params.glwe_dimension;
   auto polynomial_size = params.polynomial_size;
@@ -160,15 +158,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
   size_t shift_within_block = shift % num_bits_in_block;
 
   Torus *rotated_buffer = mem->tmp_rotated;
-  Torus *padding_block = &rotated_buffer[num_blocks * big_lwe_size];
+  Torus *padding_block = &rotated_buffer[(num_blocks + 1) * big_lwe_size];
   Torus *last_block_copy = &padding_block[big_lwe_size];
 
-  auto lut_univariate_shift_last_block =
-      mem->lut_buffers_univariate[shift_within_block - 1];
-  auto lut_univariate_padding_block =
-      mem->lut_buffers_univariate[num_bits_in_block - 1];
-  auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
-
   if (mem->shift_type == RIGHT_SHIFT) {
     host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
                                   rotated_buffer, lwe_array, rotations,
@@ -205,10 +197,12 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
         last_block_copy,
         rotated_buffer + (num_blocks - rotations - 1) * big_lwe_size,
         big_lwe_size_bytes, streams[0], gpu_indexes[0]);
-    auto partial_current_blocks = lwe_array;
-    auto partial_next_blocks = &rotated_buffer[big_lwe_size];
-    size_t partial_block_count = num_blocks - rotations;
     if (shift_within_block != 0 && rotations != num_blocks) {
+      auto partial_current_blocks = lwe_array;
+      auto partial_next_blocks = &rotated_buffer[big_lwe_size];
+      size_t partial_block_count = num_blocks - rotations;
+      auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
+
       integer_radix_apply_bivariate_lookup_table_kb(
           streams, gpu_indexes, gpu_count, partial_current_blocks,
           partial_current_blocks, partial_next_blocks, bsks, ksks,
@@ -225,10 +219,13 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
     // All sections may be executed in parallel
 #pragma omp section
     {
+      auto lut_univariate_padding_block =
+          mem->lut_buffers_univariate[num_bits_in_block - 1];
       integer_radix_apply_univariate_lookup_table_kb(
           mem->local_streams_1, gpu_indexes, gpu_count, padding_block,
           last_block_copy, bsks, ksks, 1, lut_univariate_padding_block);
-      // Replace blocks 'pulled' from the left with the correct padding block
+      // Replace blocks 'pulled' from the left with the correct padding
+      // block
       for (uint i = 0; i < rotations; i++) {
         cuda_memcpy_async_gpu_to_gpu(
             lwe_array + (num_blocks - rotations + i) * big_lwe_size,
@@ -238,7 +235,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
     }
 #pragma omp section
     {
-      if (shift_within_block != 0 && rotations != num_blocks) {
+      if (shift_within_block != 0) {
+        auto lut_univariate_shift_last_block =
+            mem->lut_buffers_univariate[shift_within_block - 1];
         integer_radix_apply_univariate_lookup_table_kb(
             mem->local_streams_2, gpu_indexes, gpu_count, last_block,
             last_block_copy, bsks, ksks, 1, lut_univariate_shift_last_block);
@@ -249,7 +248,6 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
       cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]);
      cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]);
     }
-
   } else {
     PANIC("Cuda error (scalar shift): left scalar shift is never of the "
           "arithmetic type")
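
Reviewer note (not part of the patch): a minimal standalone C++ sketch of the offset arithmetic behind the "+2" to "+3" change. padding_block moves from slot num_blocks to slot num_blocks + 1 of tmp_rotated, so last_block_copy lands in slot num_blocks + 2 and the allocation must cover num_blocks + 3 big-LWE slots. Variable names mirror the diff; the numeric values are invented for illustration.

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_blocks = 4;      // example radix block count
  const uint32_t big_lwe_size = 2049; // example big_lwe_dimension + 1

  // Offsets used by host_integer_radix_arithmetic_scalar_shift_kb_inplace
  // after this change, in Torus elements. Slots [0, num_blocks) hold the
  // rotated radix blocks.
  const size_t padding_off = (num_blocks + 1) * (size_t)big_lwe_size;
  const size_t last_copy_off = padding_off + big_lwe_size;

  // last_block_copy occupies slot (num_blocks + 2), so the buffer needs
  // (num_blocks + 3) slots in total -- matching the enlarged
  // cuda_malloc_async call in integer.h above.
  const size_t slots_needed = last_copy_off / big_lwe_size + 1;
  std::printf("slots needed: %zu, allocated: %u\n", slots_needed,
              num_blocks + 3);
  return (slots_needed <= num_blocks + 3) ? 0 : 1;
}

The remaining hunks defer each mem->lut_buffers_* lookup until after the guards on shift_within_block and the early return, so the index shift_within_block - 1 is never formed when shift_within_block == 0 (shift_within_block is a size_t, so that subtraction would otherwise wrap around to an out-of-bounds index).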