Skip to content

Commit

Permalink
chore(gpu): refactor template and clean arguments for the PBS
Browse files (browse the repository at this point in the history)
  • Loading branch information
agnesLeroy committed Jul 30, 2024
1 parent f937524 commit 463c22e
Show file tree
Hide file tree
Showing 21 changed files with 273 additions and 428 deletions.
27 changes: 10 additions & 17 deletions backends/tfhe-cuda-backend/cuda/include/programmable_bootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,15 @@ void cuda_programmable_bootstrap_amortized_lwe_ciphertext_vector_32(
void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key,
int8_t *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t num_luts, uint32_t lwe_idx,
uint32_t max_shared_memory);
uint32_t num_samples, uint32_t max_shared_memory);

void cuda_programmable_bootstrap_amortized_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void *lwe_output_indexes, void *lut_vector, void *lut_vector_indexes,
void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key,
int8_t *pbs_buffer, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t num_luts, uint32_t lwe_idx,
uint32_t max_shared_memory);
uint32_t num_samples, uint32_t max_shared_memory);

void cleanup_cuda_programmable_bootstrap_amortized(void *stream,
uint32_t gpu_index,
Expand All @@ -75,17 +73,15 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key,
int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t num_luts, uint32_t lwe_idx,
uint32_t max_shared_memory);
uint32_t num_samples, uint32_t max_shared_memory);

void cuda_programmable_bootstrap_lwe_ciphertext_vector_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void *lwe_output_indexes, void *lut_vector, void *lut_vector_indexes,
void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key,
int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t num_luts, uint32_t lwe_idx,
uint32_t max_shared_memory);
uint32_t num_samples, uint32_t max_shared_memory);

void cleanup_cuda_programmable_bootstrap(void *stream, uint32_t gpu_index,
int8_t **pbs_buffer);
Expand Down Expand Up @@ -353,8 +349,7 @@ void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector(
Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
uint32_t lwe_idx, uint32_t max_shared_memory);
uint32_t level_count, uint32_t num_samples, uint32_t max_shared_memory);

template <typename Torus>
void cuda_programmable_bootstrap_lwe_ciphertext_vector(
Expand All @@ -363,8 +358,7 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector(
Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
uint32_t lwe_idx, uint32_t max_shared_memory);
uint32_t level_count, uint32_t num_samples, uint32_t max_shared_memory);

#if (CUDA_ARCH >= 900)
template <typename Torus>
Expand All @@ -374,25 +368,24 @@ void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector(
Torus *lwe_array_in, Torus *lwe_input_indexes, double2 *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
uint32_t lwe_idx, uint32_t max_shared_memory);
uint32_t level_count, uint32_t num_samples, uint32_t max_shared_memory);

template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_programmable_bootstrap_tbc(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory,
bool allocate_gpu_memory);
#endif

template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_programmable_bootstrap_cg(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory,
bool allocate_gpu_memory);

template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64(
void *lwe_array_in, void *lwe_input_indexes, void *bootstrapping_key,
int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_luts,
uint32_t lwe_idx, uint32_t max_shared_memory, uint32_t lwe_chunk_size = 0);
uint32_t level_count, uint32_t num_samples, uint32_t max_shared_memory,
uint32_t lwe_chunk_size = 0);

void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream,
uint32_t gpu_index,
Expand All @@ -47,7 +47,7 @@ bool has_support_to_cuda_programmable_bootstrap_tbc_multi_bit(
uint32_t level_count, uint32_t max_shared_memory);

#if CUDA_ARCH >= 900
template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_tbc_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
Expand All @@ -63,19 +63,18 @@ void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory,
uint32_t lwe_chunk_size);
uint32_t max_shared_memory, uint32_t lwe_chunk_size);
#endif

template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_cg_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **pbs_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t grouping_factor,
uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory,
bool allocate_gpu_memory, uint32_t lwe_chunk_size = 0);

template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_cg_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
Expand All @@ -90,10 +89,9 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory,
uint32_t lwe_chunk_size = 0);
uint32_t max_shared_memory, uint32_t lwe_chunk_size = 0);

template <typename Torus, typename STorus>
template <typename Torus>
void scratch_cuda_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **pbs_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
Expand All @@ -109,8 +107,7 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_luts, uint32_t lwe_idx, uint32_t max_shared_memory,
uint32_t lwe_chunk_size = 0);
uint32_t max_shared_memory, uint32_t lwe_chunk_size = 0);

template <typename Torus>
__host__ __device__ uint64_t
Expand Down
12 changes: 6 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
grouping_factor, num_radix_blocks, 1, 0,
grouping_factor, num_radix_blocks,
cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);
} else {
/// Make sure all data that should be on GPU 0 is indeed there
Expand All @@ -204,7 +204,7 @@ __host__ void integer_radix_apply_univariate_lookup_table_kb(
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
pbs_level, grouping_factor, num_radix_blocks, 1, 0,
pbs_level, grouping_factor, num_radix_blocks,
cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);

/// Copy data back to GPU 0 and release vecs
Expand Down Expand Up @@ -270,7 +270,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0],
lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension,
small_lwe_dimension, polynomial_size, pbs_base_log, pbs_level,
grouping_factor, num_radix_blocks, 1, 0,
grouping_factor, num_radix_blocks,
cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);
} else {
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
Expand All @@ -293,7 +293,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table_kb(
lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec,
lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size, pbs_base_log,
pbs_level, grouping_factor, num_radix_blocks, 1, 0,
pbs_level, grouping_factor, num_radix_blocks,
cuda_get_max_shared_memory(gpu_indexes[0]), pbs_type);

/// Copy data back to GPU 0 and release vecs
Expand Down Expand Up @@ -696,8 +696,8 @@ void host_full_propagate_inplace(cudaStream_t *streams, uint32_t *gpu_indexes,
mem_ptr->lut->lwe_trivial_indexes, bsks, mem_ptr->lut->buffer,
params.glwe_dimension, params.small_lwe_dimension,
params.polynomial_size, params.pbs_base_log, params.pbs_level,
params.grouping_factor, 2, 2, 0,
cuda_get_max_shared_memory(gpu_indexes[0]), params.pbs_type);
params.grouping_factor, 2, cuda_get_max_shared_memory(gpu_indexes[0]),
params.pbs_type);

cuda_memcpy_async_gpu_to_gpu(cur_input_block, mem_ptr->tmp_big_lwe_vector,
big_lwe_size * sizeof(Torus), streams[0],
Expand Down
14 changes: 7 additions & 7 deletions backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu
Original file line number Diff line number Diff line change
Expand Up @@ -133,55 +133,55 @@ void cuda_integer_mult_radix_ciphertext_kb_64(

switch (polynomial_size) {
case 256:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<256>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<256>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
static_cast<uint64_t *>(radix_lwe_right), bsks, (uint64_t **)(ksks),
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
break;
case 512:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<512>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<512>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
static_cast<uint64_t *>(radix_lwe_right), bsks, (uint64_t **)(ksks),
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
break;
case 1024:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<1024>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<1024>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
static_cast<uint64_t *>(radix_lwe_right), bsks, (uint64_t **)(ksks),
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
break;
case 2048:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<2048>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<2048>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
static_cast<uint64_t *>(radix_lwe_right), bsks, (uint64_t **)(ksks),
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
break;
case 4096:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<4096>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<4096>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
static_cast<uint64_t *>(radix_lwe_right), bsks, (uint64_t **)(ksks),
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
break;
case 8192:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<8192>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<8192>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
static_cast<uint64_t *>(radix_lwe_right), bsks, (uint64_t **)(ksks),
(int_mul_memory<uint64_t> *)mem_ptr, num_blocks);
break;
case 16384:
host_integer_mult_radix_kb<uint64_t, int64_t, AmortizedDegree<16384>>(
host_integer_mult_radix_kb<uint64_t, AmortizedDegree<16384>>(
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
static_cast<uint64_t *>(radix_lwe_out),
static_cast<uint64_t *>(radix_lwe_left),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
small_lwe_vector, lwe_indexes_in, bsks, luts_message_carry->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size,
mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level,
mem_ptr->params.grouping_factor, total_count, 2, 0, max_shared_memory,
mem_ptr->params.grouping_factor, total_count, max_shared_memory,
mem_ptr->params.pbs_type);
} else {
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
Expand Down Expand Up @@ -420,7 +420,7 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
lwe_trivial_indexes_vec, bsks, luts_message_carry->buffer,
glwe_dimension, small_lwe_dimension, polynomial_size,
mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level,
mem_ptr->params.grouping_factor, total_count, 2, 0, max_shared_memory,
mem_ptr->params.grouping_factor, total_count, max_shared_memory,
mem_ptr->params.pbs_type);

multi_gpu_gather_lwe_async<Torus>(
Expand Down Expand Up @@ -457,7 +457,7 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
mem_ptr->scp_mem, bsks, ksks, num_blocks);
}

template <typename Torus, typename STorus, class params>
template <typename Torus, class params>
__host__ void host_integer_mult_radix_kb(
cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count,
uint64_t *radix_lwe_out, uint64_t *radix_lwe_left,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ void execute_pbs_async(
std::vector<int8_t *> pbs_buffer, uint32_t glwe_dimension,
uint32_t lwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t grouping_factor,
uint32_t input_lwe_ciphertext_count, uint32_t num_luts, uint32_t lwe_idx,
uint32_t max_shared_memory, PBS_TYPE pbs_type) {
uint32_t input_lwe_ciphertext_count, uint32_t max_shared_memory,
PBS_TYPE pbs_type) {
switch (sizeof(Torus)) {
case sizeof(uint32_t):
// 32 bits
Expand Down Expand Up @@ -160,8 +160,8 @@ void execute_pbs_async(
current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
current_lwe_array_in, current_lwe_input_indexes,
bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
polynomial_size, base_log, level_count, num_inputs_on_gpu, num_luts,
lwe_idx, max_shared_memory);
polynomial_size, base_log, level_count, num_inputs_on_gpu,
max_shared_memory);
}
break;
default:
Expand Down Expand Up @@ -200,7 +200,7 @@ void execute_pbs_async(
current_lwe_array_in, current_lwe_input_indexes,
bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
polynomial_size, grouping_factor, base_log, level_count,
num_inputs_on_gpu, num_luts, lwe_idx, max_shared_memory);
num_inputs_on_gpu, max_shared_memory);
}
break;
case CLASSICAL:
Expand Down Expand Up @@ -228,8 +228,8 @@ void execute_pbs_async(
current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
current_lwe_array_in, current_lwe_input_indexes,
bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
polynomial_size, base_log, level_count, num_inputs_on_gpu, num_luts,
lwe_idx, max_shared_memory);
polynomial_size, base_log, level_count, num_inputs_on_gpu,
max_shared_memory);
}
break;
default:
Expand Down
Loading

0 comments on commit 463c22e

Please sign in to comment.