Skip to content

Commit

Permalink
feat(gpu): scalar shifts with one wave of pbs
Browse files Browse the repository at this point in the history
  • Loading branch information
bbarbakadze committed Feb 15, 2024
1 parent 52f3bab commit 56f9b22
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 61 deletions.
33 changes: 4 additions & 29 deletions backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,10 @@ template <typename Torus> struct int_shift_buffer {
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

tmp_rotated = (Torus *)cuda_malloc_async(
max_amount_of_pbs * big_lwe_size_bytes, stream);
(max_amount_of_pbs + 2) * big_lwe_size_bytes, stream);

cuda_memset_async(tmp_rotated, 0,
(max_amount_of_pbs + 2) * big_lwe_size_bytes, stream);

uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);

Expand Down Expand Up @@ -709,34 +712,6 @@ template <typename Torus> struct int_shift_buffer {

lut_buffers_bivariate.push_back(cur_lut_bivariate);
}

// here we generate 'message_modulus' times lut
// one for each 'shift'
// lut_indexes will have indexes for single lut only and those indexes
// will be 0 it means for pbs corresponding lut should be selected and
// pass along lut_indexes filled with zeros

// calculate lut for each 'shift'
for (int shift = 0; shift < params.message_modulus; shift++) {
auto cur_lut =
new int_radix_lut<Torus>(stream, params, 1, 1, allocate_gpu_memory);

std::function<Torus(Torus)> shift_lut_f;
if (shift_type == LEFT_SHIFT)
shift_lut_f = [shift, params](Torus x) -> Torus {
return (x << shift) % params.message_modulus;
};
else
shift_lut_f = [shift, params](Torus x) -> Torus {
return (x >> shift) % params.message_modulus;
};

generate_device_accumulator<Torus>(
stream, cur_lut->lut, params.glwe_dimension, params.polynomial_size,
params.message_modulus, params.carry_modulus, shift_lut_f);

lut_buffers_univariate.push_back(cur_lut);
}
}
}

Expand Down
47 changes: 15 additions & 32 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ __host__ void host_integer_radix_scalar_shift_kb_inplace(
size_t rotations = std::min(shift / num_bits_in_block, (size_t)num_blocks);
size_t shift_within_block = shift % num_bits_in_block;

Torus *rotated_buffer = mem->tmp_rotated;
Torus *full_rotated_buffer = mem->tmp_rotated;
Torus *rotated_buffer = &full_rotated_buffer[big_lwe_size];

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
auto lut_univariate = mem->lut_buffers_univariate[shift_within_block];

// rotate right all the blocks in radix ciphertext
// copy result in new buffer
Expand All @@ -68,23 +68,15 @@ __host__ void host_integer_radix_scalar_shift_kb_inplace(
return;
}

// check if we have enough blocks for partial processing
if (rotations < num_blocks - 1) {
auto partial_current_blocks = &lwe_array[(rotations + 1) * big_lwe_size];
auto partial_previous_blocks = &lwe_array[rotations * big_lwe_size];
auto partial_current_blocks = &lwe_array[rotations * big_lwe_size];
auto partial_previous_blocks =
&full_rotated_buffer[rotations * big_lwe_size];

size_t partial_block_count = num_blocks - rotations - 1;
size_t partial_block_count = num_blocks - rotations;

integer_radix_apply_bivariate_lookup_table_kb<Torus>(
stream, partial_current_blocks, partial_current_blocks,
partial_previous_blocks, bsk, ksk, partial_block_count,
lut_bivariate);
}

auto rest = &lwe_array[rotations * big_lwe_size];

integer_radix_apply_univariate_lookup_table_kb<Torus>(
stream, rest, rest, bsk, ksk, 1, lut_univariate);
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
stream, partial_current_blocks, partial_current_blocks,
partial_previous_blocks, bsk, ksk, partial_block_count, lut_bivariate);

} else {
// right shift
Expand All @@ -102,23 +94,14 @@ __host__ void host_integer_radix_scalar_shift_kb_inplace(
return;
}

// check if we have enough blocks for partial processing
if (rotations < num_blocks - 1) {
auto partial_current_blocks = lwe_array;
auto partial_next_blocks = &lwe_array[big_lwe_size];
auto partial_current_blocks = lwe_array;
auto partial_next_blocks = &rotated_buffer[big_lwe_size];

size_t partial_block_count = num_blocks - rotations - 1;

integer_radix_apply_bivariate_lookup_table_kb<Torus>(
stream, partial_current_blocks, partial_current_blocks,
partial_next_blocks, bsk, ksk, partial_block_count, lut_bivariate);
}
size_t partial_block_count = num_blocks - rotations;

// The right-most block is done separately as it does not
// need to recuperate the shifted bits from its right neighbour.
auto last_block = &lwe_array[(num_blocks - rotations - 1) * big_lwe_size];
integer_radix_apply_univariate_lookup_table_kb<Torus>(
stream, last_block, last_block, bsk, ksk, 1, lut_univariate);
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
stream, partial_current_blocks, partial_current_blocks,
partial_next_blocks, bsk, ksk, partial_block_count, lut_bivariate);
}
}

Expand Down

0 comments on commit 56f9b22

Please sign in to comment.