Skip to content

Commit

Permalink
fix(gpu): fix scalar shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Jul 25, 2024
1 parent d3f2ecd commit a6dff32
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1755,12 +1755,12 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 2) *
tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 3) *
big_lwe_size_bytes,
streams[0], gpu_indexes[0]);

cuda_memset_async(tmp_rotated, 0,
(num_radix_blocks + 2) * big_lwe_size_bytes, streams[0],
(num_radix_blocks + 3) * big_lwe_size_bytes, streams[0],
gpu_indexes[0]);

uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);
Expand Down
32 changes: 15 additions & 17 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
Torus *full_rotated_buffer = mem->tmp_rotated;
Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size];

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

// rotate right all the blocks in radix ciphertext
// copy result in new buffer
// 1024 threads are used in every block
Expand All @@ -76,6 +74,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(
return;
}

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
auto partial_current_blocks = &lwe_array[rotations * big_lwe_size];
auto partial_previous_blocks =
&full_rotated_buffer[rotations * big_lwe_size];
Expand Down Expand Up @@ -109,6 +108,7 @@ __host__ void host_integer_radix_logical_scalar_shift_kb_inplace(

auto partial_current_blocks = lwe_array;
auto partial_next_blocks = &rotated_buffer[big_lwe_size];
auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

size_t partial_block_count = num_blocks - rotations;

Expand Down Expand Up @@ -139,8 +139,6 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
int_arithmetic_scalar_shift_buffer<Torus> *mem, void **bsks, Torus **ksks,
uint32_t num_blocks) {

cudaSetDevice(gpu_indexes[0]);

auto params = mem->params;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
Expand All @@ -160,15 +158,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
size_t shift_within_block = shift % num_bits_in_block;

Torus *rotated_buffer = mem->tmp_rotated;
Torus *padding_block = &rotated_buffer[num_blocks * big_lwe_size];
Torus *padding_block = &rotated_buffer[(num_blocks + 1) * big_lwe_size];
Torus *last_block_copy = &padding_block[big_lwe_size];

auto lut_univariate_shift_last_block =
mem->lut_buffers_univariate[shift_within_block - 1];
auto lut_univariate_padding_block =
mem->lut_buffers_univariate[num_bits_in_block - 1];
auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

if (mem->shift_type == RIGHT_SHIFT) {
host_radix_blocks_rotate_left(streams, gpu_indexes, gpu_count,
rotated_buffer, lwe_array, rotations,
Expand Down Expand Up @@ -205,10 +197,12 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
last_block_copy,
rotated_buffer + (num_blocks - rotations - 1) * big_lwe_size,
big_lwe_size_bytes, streams[0], gpu_indexes[0]);
auto partial_current_blocks = lwe_array;
auto partial_next_blocks = &rotated_buffer[big_lwe_size];
size_t partial_block_count = num_blocks - rotations;
if (shift_within_block != 0 && rotations != num_blocks) {
auto partial_current_blocks = lwe_array;
auto partial_next_blocks = &rotated_buffer[big_lwe_size];
size_t partial_block_count = num_blocks - rotations;
auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

integer_radix_apply_bivariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, partial_current_blocks,
partial_current_blocks, partial_next_blocks, bsks, ksks,
Expand All @@ -225,10 +219,13 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
// All sections may be executed in parallel
#pragma omp section
{
auto lut_univariate_padding_block =
mem->lut_buffers_univariate[num_bits_in_block - 1];
integer_radix_apply_univariate_lookup_table_kb(
mem->local_streams_1, gpu_indexes, gpu_count, padding_block,
last_block_copy, bsks, ksks, 1, lut_univariate_padding_block);
// Replace blocks 'pulled' from the left with the correct padding block
// Replace blocks 'pulled' from the left with the correct padding
// block
for (uint i = 0; i < rotations; i++) {
cuda_memcpy_async_gpu_to_gpu(
lwe_array + (num_blocks - rotations + i) * big_lwe_size,
Expand All @@ -238,7 +235,9 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
}
#pragma omp section
{
if (shift_within_block != 0 && rotations != num_blocks) {
if (shift_within_block != 0) {
auto lut_univariate_shift_last_block =
mem->lut_buffers_univariate[shift_within_block - 1];
integer_radix_apply_univariate_lookup_table_kb(
mem->local_streams_2, gpu_indexes, gpu_count, last_block,
last_block_copy, bsks, ksks, 1, lut_univariate_shift_last_block);
Expand All @@ -249,7 +248,6 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]);
cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]);
}

} else {
PANIC("Cuda error (scalar shift): left scalar shift is never of the "
"arithmetic type")
Expand Down

0 comments on commit a6dff32

Please sign in to comment.