Al/refactor scalar shift & rotate #2033

Merged · 2 commits · Feb 11, 2025
8 changes: 4 additions & 4 deletions backends/tfhe-cuda-backend/cuda/include/integer/integer.h
@@ -163,8 +163,8 @@ void scratch_cuda_integer_radix_logical_scalar_shift_kb_64(

void cuda_integer_radix_logical_scalar_shift_kb_64_inplace(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
void *lwe_array, uint32_t shift, int8_t *mem_ptr, void *const *bsks,
void *const *ksks, uint32_t num_blocks);
CudaRadixCiphertextFFI *lwe_array, uint32_t shift, int8_t *mem_ptr,
void *const *bsks, void *const *ksks);

void scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
@@ -291,8 +291,8 @@ void scratch_cuda_integer_radix_scalar_rotate_kb_64(

void cuda_integer_radix_scalar_rotate_kb_64_inplace(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
void *lwe_array, uint32_t n, int8_t *mem_ptr, void *const *bsks,
void *const *ksks, uint32_t num_blocks);
CudaRadixCiphertextFFI *lwe_array, uint32_t n, int8_t *mem_ptr,
void *const *bsks, void *const *ksks);

void cleanup_cuda_integer_radix_scalar_rotate(void *const *streams,
uint32_t const *gpu_indexes,
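Note the shape of the refactor visible in this header: both in-place entry points now take the ciphertext as a `CudaRadixCiphertextFFI *` and drop the trailing `num_blocks` argument, since the handle carries its own block count. A minimal, hypothetical caller-side sketch against the new declaration — `ct`, `shift_mem`, `bsks` and `ksks` are placeholder variables the caller would already own, not code from this PR:

```cpp
// Hypothetical call against the new FFI signature declared above.
// The block count is read from `ct` itself, so no num_blocks is passed.
cuda_integer_radix_logical_scalar_shift_kb_64_inplace(
    streams, gpu_indexes, gpu_count,
    ct,          // CudaRadixCiphertextFFI *, shifted in place
    /*shift=*/3, // scalar shift amount in bits
    shift_mem,   // int8_t * scratch from the matching scratch_* call
    bsks, ksks); // bootstrapping and keyswitching keys
```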
45 changes: 20 additions & 25 deletions backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
@@ -2682,7 +2682,7 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {

SHIFT_OR_ROTATE_TYPE shift_type;

Torus *tmp_rotated;
CudaRadixCiphertextFFI *tmp_rotated;

bool reuse_memory = false;

@@ -2698,16 +2698,11 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {

if (allocate_gpu_memory) {
uint32_t max_amount_of_pbs = num_radix_blocks;
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

tmp_rotated = (Torus *)cuda_malloc_async((max_amount_of_pbs + 2) *
big_lwe_size_bytes,
streams[0], gpu_indexes[0]);

cuda_memset_async(tmp_rotated, 0,
(max_amount_of_pbs + 2) * big_lwe_size_bytes,
streams[0], gpu_indexes[0]);
tmp_rotated = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams[0], gpu_indexes[0], tmp_rotated, max_amount_of_pbs + 2,
params.big_lwe_dimension);

uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);

@@ -2783,7 +2778,7 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, SHIFT_OR_ROTATE_TYPE shift_type,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory, Torus *pre_allocated_buffer) {
bool allocate_gpu_memory, CudaRadixCiphertextFFI *pre_allocated_buffer) {
this->shift_type = shift_type;
this->params = params;
tmp_rotated = pre_allocated_buffer;
@@ -2792,9 +2787,9 @@
uint32_t max_amount_of_pbs = num_radix_blocks;
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);
cuda_memset_async(tmp_rotated, 0,
(max_amount_of_pbs + 2) * big_lwe_size_bytes, streams[0],
gpu_indexes[0]);
set_zero_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
tmp_rotated, 0,
tmp_rotated->num_radix_blocks);
if (allocate_gpu_memory) {

uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);
@@ -2874,8 +2869,10 @@ template <typename Torus> struct int_logical_scalar_shift_buffer {
}
lut_buffers_bivariate.clear();

if (!reuse_memory)
cuda_drop_async(tmp_rotated, streams[0], gpu_indexes[0]);
if (!reuse_memory) {
release_radix_ciphertext(streams[0], gpu_indexes[0], tmp_rotated);
delete tmp_rotated;
}
}
};

@@ -4423,7 +4420,7 @@ template <typename Torus> struct int_scalar_mul_buffer {
int_logical_scalar_shift_buffer<Torus> *logical_scalar_shift_buffer;
int_sum_ciphertexts_vec_memory<Torus> *sum_ciphertexts_vec_mem;
Torus *preshifted_buffer;
Torus *all_shifted_buffer;
CudaRadixCiphertextFFI *all_shifted_buffer;
int_sc_prop_memory<Torus> *sc_prop_mem;
bool anticipated_buffers_drop;

@@ -4447,18 +4444,15 @@
preshifted_buffer = (Torus *)cuda_malloc_async(
num_ciphertext_bits * lwe_size_bytes, streams[0], gpu_indexes[0]);

all_shifted_buffer = (Torus *)cuda_malloc_async(
num_ciphertext_bits * num_radix_blocks * lwe_size_bytes, streams[0],
gpu_indexes[0]);
all_shifted_buffer = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams[0], gpu_indexes[0], all_shifted_buffer,
num_ciphertext_bits * num_radix_blocks, params.big_lwe_dimension);

cuda_memset_async(preshifted_buffer, 0,
num_ciphertext_bits * lwe_size_bytes, streams[0],
gpu_indexes[0]);

cuda_memset_async(all_shifted_buffer, 0,
num_ciphertext_bits * num_radix_blocks * lwe_size_bytes,
streams[0], gpu_indexes[0]);

if (num_ciphertext_bits * num_radix_blocks >= num_radix_blocks + 2)
logical_scalar_shift_buffer =
new int_logical_scalar_shift_buffer<Torus>(
@@ -4487,7 +4481,8 @@ template <typename Torus> struct int_scalar_mul_buffer {
sc_prop_mem->release(streams, gpu_indexes, gpu_count);
delete sum_ciphertexts_vec_mem;
delete sc_prop_mem;
cuda_drop_async(all_shifted_buffer, streams[0], gpu_indexes[0]);
release_radix_ciphertext(streams[0], gpu_indexes[0], all_shifted_buffer);
delete all_shifted_buffer;
if (!anticipated_buffers_drop) {
cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);
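The buffer changes in this file all follow the same pattern: raw `cuda_malloc_async` + `cuda_memset_async` pairs become `create_zero_radix_ciphertext_async`, and `cuda_drop_async` becomes `release_radix_ciphertext` followed by `delete`. A minimal sketch of that lifetime, assuming the helper declarations from `integer_utilities.h` are in scope and using illustrative variable names (not code from this PR):

```cpp
// Allocate a handle plus device storage for `num_blocks` zeroed big-LWE blocks,
// including their per-block degree/noise metadata.
auto *tmp = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<uint64_t>(
    streams[0], gpu_indexes[0], tmp, num_blocks, params.big_lwe_dimension);

// ... use `tmp` as a temporary radix ciphertext ...

// When a pre-allocated buffer is reused instead, only its blocks are reset:
// set_zero_radix_ciphertext_slice_async<uint64_t>(
//     streams[0], gpu_indexes[0], tmp, 0, tmp->num_radix_blocks);

// Release the device data behind the handle, then the handle itself.
release_radix_ciphertext(streams[0], gpu_indexes[0], tmp);
delete tmp;
```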
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh
@@ -339,7 +339,7 @@ __host__ void host_unsigned_integer_div_rem_kb(
interesting_remainder1.insert(0, numerator_block_1.first_block(),
streams[0], gpu_indexes[0]);

host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
legacy_host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
streams, gpu_indexes, gpu_count, interesting_remainder1.data, 1,
mem_ptr->shift_mem_1, bsks, ksks, interesting_remainder1.len);

@@ -369,7 +369,7 @@
auto left_shift_interesting_remainder2 = [&](cudaStream_t const *streams,
uint32_t const *gpu_indexes,
uint32_t gpu_count) {
host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
legacy_host_integer_radix_logical_scalar_shift_kb_inplace<Torus>(
streams, gpu_indexes, gpu_count, interesting_remainder2.data, 1,
mem_ptr->shift_mem_2, bsks, ksks, interesting_remainder2.len);
}; // left_shift_interesting_remainder2
51 changes: 25 additions & 26 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -94,44 +94,42 @@ __host__ void array_rotate_left(Torus *array_out, Torus *array_in,
// calculation is not inplace, so `dst` and `src` must not be the same
// one block is responsible to process single lwe ciphertext
template <typename Torus>
__host__ void host_radix_blocks_rotate_right(cudaStream_t const *streams,
uint32_t const *gpu_indexes,
uint32_t gpu_count,
CudaRadixCiphertextFFI *dst,
CudaRadixCiphertextFFI *src,
uint32_t rotations) {
__host__ void host_radix_blocks_rotate_right(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *dst,
CudaRadixCiphertextFFI *src, uint32_t rotations, uint32_t num_blocks) {
if (src == dst) {
PANIC("Cuda error (blocks_rotate_right): the source and destination "
"pointers should be different");
}
if (dst->lwe_dimension != src->lwe_dimension)
PANIC("Cuda error: input and output should have the same "
"lwe dimension")
if (dst->num_radix_blocks < num_blocks || src->num_radix_blocks < num_blocks)
PANIC("Cuda error: input and output should have more blocks than asked for "
"in the "
"function call")

auto lwe_size = src->lwe_dimension + 1;

cuda_set_device(gpu_indexes[0]);
radix_blocks_rotate_right<Torus>
<<<src->num_radix_blocks, 1024, 0, streams[0]>>>(
(Torus *)dst->ptr, (Torus *)src->ptr, rotations,
dst->num_radix_blocks, lwe_size);
radix_blocks_rotate_right<Torus><<<num_blocks, 1024, 0, streams[0]>>>(
(Torus *)dst->ptr, (Torus *)src->ptr, rotations, num_blocks, lwe_size);
check_cuda_error(cudaGetLastError());

// Rotate degrees and noise to follow blocks
array_rotate_right(dst->degrees, src->degrees, rotations,
dst->num_radix_blocks);
array_rotate_right(dst->degrees, src->degrees, rotations, num_blocks);
array_rotate_right(dst->noise_levels, src->noise_levels, rotations,
dst->num_radix_blocks);
num_blocks);
}

// rotate radix ciphertext left with specific value
// calculation is not inplace, so `dst` and `src` must not be the same
template <typename Torus>
__host__ void
host_radix_blocks_rotate_left(cudaStream_t const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count,
CudaRadixCiphertextFFI *dst,
CudaRadixCiphertextFFI *src, uint32_t value) {
__host__ void host_radix_blocks_rotate_left(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *dst,
CudaRadixCiphertextFFI *src, uint32_t value, uint32_t num_blocks) {
if (src == dst) {
PANIC("Cuda error (blocks_rotate_left): the source and destination "
"pointers should be different");
@@ -140,20 +138,21 @@ host_radix_blocks_rotate_left(cudaStream_t const *streams,
if (dst->lwe_dimension != src->lwe_dimension)
PANIC("Cuda error: input and output should have the same "
"lwe dimension")
if (dst->num_radix_blocks < num_blocks || src->num_radix_blocks < num_blocks)
PANIC("Cuda error: input and output should have more blocks than asked for "
"in the "
"function call")

auto lwe_size = src->lwe_dimension + 1;

cuda_set_device(gpu_indexes[0]);
radix_blocks_rotate_left<Torus>
<<<src->num_radix_blocks, 1024, 0, streams[0]>>>(
(Torus *)dst->ptr, (Torus *)src->ptr, value, dst->num_radix_blocks,
lwe_size);
radix_blocks_rotate_left<Torus><<<num_blocks, 1024, 0, streams[0]>>>(
(Torus *)dst->ptr, (Torus *)src->ptr, value, num_blocks, lwe_size);
check_cuda_error(cudaGetLastError());

// Rotate degrees and noise to follow blocks
array_rotate_left(dst->degrees, src->degrees, value, dst->num_radix_blocks);
array_rotate_left(dst->noise_levels, src->noise_levels, value,
dst->num_radix_blocks);
array_rotate_left(dst->degrees, src->degrees, value, num_blocks);
array_rotate_left(dst->noise_levels, src->noise_levels, value, num_blocks);
}

// rotate radix ciphertext right with specific value
Expand Down Expand Up @@ -1836,7 +1835,7 @@ void host_propagate_single_sub_borrow(cudaStream_t const *streams,

host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count,
step_output, generates_or_propagates, 1,
num_blocks, big_lwe_size);
num_blocks);
cuda_memset_async(step_output, 0, big_lwe_size_bytes, streams[0],
gpu_indexes[0]);

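Besides rotating the LWE blocks on the device, the refactored `host_radix_blocks_rotate_right`/`_left` now rotate the per-block `degrees` and `noise_levels` arrays by the same offset, so the metadata stays attached to its block. A standalone illustration of that bookkeeping, assuming `array_rotate_right` uses the usual `out[(i + r) % n] = in[i]` convention implied by its use above (this is an illustration, not code from the PR):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Rotate-right by `r`: the element at index i moves to index (i + r) % n,
// mirroring how block metadata must follow the rotated ciphertext blocks.
static std::vector<uint64_t> rotate_right(const std::vector<uint64_t> &in,
                                          size_t r) {
  std::vector<uint64_t> out(in.size());
  for (size_t i = 0; i < in.size(); i++)
    out[(i + r) % in.size()] = in[i];
  return out;
}

int main() {
  // Degrees of 4 radix blocks, least-significant block first.
  const std::vector<uint64_t> degrees = {3, 1, 0, 2};
  for (uint64_t d : rotate_right(degrees, 1))
    std::cout << d << ' '; // prints: 2 3 1 0
  std::cout << '\n';
  return 0;
}
```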
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh
@@ -55,15 +55,15 @@ __host__ void host_integer_scalar_mul_radix(
uint32_t num_ciphertext_bits = msg_bits * num_radix_blocks;

T *preshifted_buffer = mem->preshifted_buffer;
T *all_shifted_buffer = mem->all_shifted_buffer;
T *all_shifted_buffer = (T *)mem->all_shifted_buffer->ptr;

for (size_t shift_amount = 0; shift_amount < msg_bits; shift_amount++) {
T *ptr = preshifted_buffer + shift_amount * lwe_size * num_radix_blocks;
if (has_at_least_one_set[shift_amount] == 1) {
cuda_memcpy_async_gpu_to_gpu(ptr, lwe_array,
lwe_size_bytes * num_radix_blocks,
streams[0], gpu_indexes[0]);
host_integer_radix_logical_scalar_shift_kb_inplace<T>(
legacy_host_integer_radix_logical_scalar_shift_kb_inplace<T>(
streams, gpu_indexes, gpu_count, ptr, shift_amount,
mem->logical_scalar_shift_buffer, bsks, ksks, num_radix_blocks);
} else {
9 changes: 4 additions & 5 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_rotate.cu
@@ -22,14 +22,13 @@ void scratch_cuda_integer_radix_scalar_rotate_kb_64(

void cuda_integer_radix_scalar_rotate_kb_64_inplace(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
void *lwe_array, uint32_t n, int8_t *mem_ptr, void *const *bsks,
void *const *ksks, uint32_t num_blocks) {
CudaRadixCiphertextFFI *lwe_array, uint32_t n, int8_t *mem_ptr,
void *const *bsks, void *const *ksks) {

host_integer_radix_scalar_rotate_kb_inplace<uint64_t>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(lwe_array), n,
(cudaStream_t *)(streams), gpu_indexes, gpu_count, lwe_array, n,
(int_logical_scalar_shift_buffer<uint64_t> *)mem_ptr, bsks,
(uint64_t **)(ksks), num_blocks);
(uint64_t **)(ksks));
}

void cleanup_cuda_integer_radix_scalar_rotate(void *const *streams,
56 changes: 26 additions & 30 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_rotate.cuh
@@ -26,18 +26,14 @@ __host__ void scratch_cuda_integer_radix_scalar_rotate_kb(
template <typename Torus>
__host__ void host_integer_radix_scalar_rotate_kb_inplace(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, Torus *lwe_array, uint32_t n,
uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array, uint32_t n,
int_logical_scalar_shift_buffer<Torus> *mem, void *const *bsks,
Torus *const *ksks, uint32_t num_blocks) {
Torus *const *ksks) {

auto num_blocks = lwe_array->num_radix_blocks;
auto params = mem->params;
auto glwe_dimension = params.glwe_dimension;
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;

size_t big_lwe_size = glwe_dimension * polynomial_size + 1;
size_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);

size_t num_bits_in_message = (size_t)log2_int(message_modulus);
size_t total_num_bits = num_bits_in_message * num_blocks;
n = n % total_num_bits;
@@ -48,7 +44,7 @@ __host__ void host_integer_radix_scalar_rotate_kb_inplace(
size_t rotations = n / num_bits_in_message;
size_t shift_within_block = n % num_bits_in_message;

Torus *rotated_buffer = mem->tmp_rotated;
auto rotated_buffer = mem->tmp_rotated;

// rotate right all the blocks in radix ciphertext
// copy result in new buffer
@@ -57,56 +53,56 @@
// one block is responsible to process single lwe ciphertext
if (mem->shift_type == LEFT_SHIFT) {
// rotate right as the blocks are from LSB to MSB
legacy_host_radix_blocks_rotate_right<Torus>(
streams, gpu_indexes, gpu_count, rotated_buffer, lwe_array, rotations,
num_blocks, big_lwe_size);
host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count,
rotated_buffer, lwe_array, rotations,
num_blocks);

cuda_memcpy_async_gpu_to_gpu(lwe_array, rotated_buffer,
num_blocks * big_lwe_size_bytes, streams[0],
gpu_indexes[0]);
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
lwe_array, 0, num_blocks,
rotated_buffer, 0, num_blocks);

if (shift_within_block == 0) {
return;
}

auto receiver_blocks = lwe_array;
auto giver_blocks = rotated_buffer;
legacy_host_radix_blocks_rotate_right<Torus>(
streams, gpu_indexes, gpu_count, giver_blocks, lwe_array, 1, num_blocks,
big_lwe_size);
host_radix_blocks_rotate_right<Torus>(streams, gpu_indexes, gpu_count,
giver_blocks, lwe_array, 1,
num_blocks);

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

legacy_integer_radix_apply_bivariate_lookup_table_kb<Torus>(
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
giver_blocks, bsks, ksks, num_blocks, lut_bivariate,
giver_blocks, bsks, ksks, lut_bivariate, num_blocks,
lut_bivariate->params.message_modulus);

} else {
// rotate left as the blocks are from LSB to MSB
legacy_host_radix_blocks_rotate_left<Torus>(
streams, gpu_indexes, gpu_count, rotated_buffer, lwe_array, rotations,
num_blocks, big_lwe_size);
host_radix_blocks_rotate_left<Torus>(streams, gpu_indexes, gpu_count,
rotated_buffer, lwe_array, rotations,
num_blocks);

cuda_memcpy_async_gpu_to_gpu(lwe_array, rotated_buffer,
num_blocks * big_lwe_size_bytes, streams[0],
gpu_indexes[0]);
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
lwe_array, 0, num_blocks,
rotated_buffer, 0, num_blocks);

if (shift_within_block == 0) {
return;
}

auto receiver_blocks = lwe_array;
auto giver_blocks = rotated_buffer;
legacy_host_radix_blocks_rotate_left<Torus>(streams, gpu_indexes, gpu_count,
giver_blocks, lwe_array, 1,
num_blocks, big_lwe_size);
host_radix_blocks_rotate_left<Torus>(streams, gpu_indexes, gpu_count,
giver_blocks, lwe_array, 1,
num_blocks);

auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];

legacy_integer_radix_apply_bivariate_lookup_table_kb<Torus>(
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array, receiver_blocks,
giver_blocks, bsks, ksks, num_blocks, lut_bivariate,
giver_blocks, bsks, ksks, lut_bivariate, num_blocks,
lut_bivariate->params.message_modulus);
}
}
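The in-place scalar rotate above reduces the requested bit count modulo the ciphertext width, then splits it into a whole-block rotation and a residual within-block shift that is resolved with a bivariate LUT across neighbouring blocks. A standalone numeric illustration of that decomposition (not code from the PR):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Example: 4 radix blocks with 2-bit messages (message_modulus = 4).
  const uint32_t num_blocks = 4;
  const uint32_t num_bits_in_message = 2;                           // log2(message_modulus)
  const uint32_t total_num_bits = num_bits_in_message * num_blocks; // 8 bits overall

  uint32_t n = 13;     // requested rotation in bits
  n %= total_num_bits; // 13 % 8 = 5

  const uint32_t rotations = n / num_bits_in_message;          // 2 whole-block rotations
  const uint32_t shift_within_block = n % num_bits_in_message; // 1 remaining bit

  // The host function first rotates whole blocks, then applies the bivariate
  // lookup table indexed by `shift_within_block - 1` to move the leftover bit(s).
  std::printf("rotations=%u shift_within_block=%u\n", rotations, shift_within_block);
  return 0;
}
```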