Revert "chore(gpu): optimize mult"
This reverts commit 36e72f8.
agnesLeroy committed Jul 16, 2024
1 parent: b71751e · commit: 1a58168
Showing 2 changed files with 17 additions and 7 deletions.
backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh (15 additions & 1 deletion)
@@ -380,7 +380,7 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
       streams, gpu_indexes, active_gpu_count, new_blocks_vec, new_blocks,
       luts_message_carry->h_lwe_indexes_in,
       luts_message_carry->using_trivial_lwe_indexes, message_count,
-      big_lwe_size, false, total_count);
+      big_lwe_size, false);

   /// Apply KS to go from a big LWE dimension to a small LWE dimension
   /// After this keyswitch execution, we need to synchronize the streams
@@ -393,6 +393,20 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
       mem_ptr->params.ks_base_log,
       mem_ptr->params.ks_level, message_count, false);

+  /// Copy data back to GPU 0, rebuild the lwe array, and scatter again on a
+  /// different configuration
+  multi_gpu_gather_lwe<Torus>(streams, gpu_indexes, active_gpu_count,
+                              small_lwe_vector, small_lwe_vector_vec,
+                              luts_message_carry->h_lwe_indexes_in,
+                              luts_message_carry->using_trivial_lwe_indexes,
+                              message_count, small_lwe_size);
+
+  multi_gpu_scatter_lwe<Torus>(
+      streams, gpu_indexes, active_gpu_count, small_lwe_vector_vec,
+      small_lwe_vector, luts_message_carry->h_lwe_indexes_in,
+      luts_message_carry->using_trivial_lwe_indexes, total_count,
+      small_lwe_size, false);
+
   /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE
   /// dimension to a big LWE dimension
   execute_pbs<Torus>(
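The restored logic above replaces the single scatter (which passed total_count as an extra parameter) with an explicit gather back to GPU 0 followed by a fresh scatter sized by total_count, so the PBS step sees a consistent layout on every GPU. The sketch below models that round trip on the CPU; the balanced-split rule given for get_num_inputs_on_gpu is an assumption for illustration, not the helper's verified behavior.

// Hypothetical CPU-side model of the gather/scatter round trip. The names
// mirror the CUDA helpers, but the balanced-split rule below is an assumed
// reading of get_num_inputs_on_gpu, not its verified implementation.
#include <cstdint>
#include <iostream>

// Assumed split: GPU i owns floor(n / g) inputs, plus one extra input for
// the first n % g GPUs.
uint32_t get_num_inputs_on_gpu(uint32_t n, uint32_t i, uint32_t g) {
  return n / g + (i < n % g ? 1 : 0);
}

int main() {
  const uint32_t gpu_count = 4;
  const uint32_t message_count = 10; // layout after the keyswitch
  const uint32_t total_count = 14;   // layout the PBS step expects

  // Gather: per-GPU chunks sized by message_count merge back into one
  // array on GPU 0.
  uint32_t offset = 0;
  for (uint32_t i = 0; i < gpu_count; i++) {
    uint32_t chunk = get_num_inputs_on_gpu(message_count, i, gpu_count);
    std::cout << "gather  GPU " << i << ": inputs [" << offset << ", "
              << offset + chunk << ")\n";
    offset += chunk;
  }

  // Scatter: the merged array is redistributed, now sized by total_count,
  // so chunk sizes and offsets agree on every GPU for the PBS step.
  offset = 0;
  for (uint32_t i = 0; i < gpu_count; i++) {
    uint32_t chunk = get_num_inputs_on_gpu(total_count, i, gpu_count);
    std::cout << "scatter GPU " << i << ": inputs [" << offset << ", "
              << offset + chunk << ")\n";
    offset += chunk;
  }
  return 0;
}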
Second changed file (2 additions & 6 deletions)
@@ -74,12 +74,9 @@ void multi_gpu_scatter_lwe(cudaStream_t *streams, uint32_t *gpu_indexes,
                            Torus *src, Torus *h_src_indexes,
                            bool is_trivial_index, uint32_t num_inputs,
                            uint32_t elements_per_input,
-                           bool sync_threads = true,
-                           uint32_t num_inputs_pbs_mul = 0) {
+                           bool sync_threads = true) {

   auto active_gpu_count = get_active_gpu_count(num_inputs, gpu_count);
-  if (num_inputs_pbs_mul == 0)
-    num_inputs_pbs_mul = num_inputs;

   if (sync_threads)
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
@@ -89,8 +86,7 @@ void multi_gpu_scatter_lwe(cudaStream_t *streams, uint32_t *gpu_indexes,
     auto inputs_on_gpu = get_num_inputs_on_gpu(num_inputs, i, active_gpu_count);
     auto gpu_offset = 0;
     for (uint j = 0; j < i; j++) {
-      gpu_offset +=
-          get_num_inputs_on_gpu(num_inputs_pbs_mul, j, active_gpu_count);
+      gpu_offset += get_num_inputs_on_gpu(num_inputs, j, active_gpu_count);
     }

     if (is_trivial_index) {
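With num_inputs_pbs_mul gone, inputs_on_gpu and gpu_offset are again derived from the same num_inputs, which keeps the per-GPU slices contiguous. Below is a minimal self-check of that invariant, under the same assumed split rule as the sketch above.

// Self-check of the restored invariant: when inputs_on_gpu and gpu_offset
// are both derived from num_inputs, the per-GPU slices tile [0, num_inputs)
// exactly. Uses the same assumed balanced split as the earlier sketch.
#include <cassert>
#include <cstdint>

uint32_t get_num_inputs_on_gpu(uint32_t n, uint32_t i, uint32_t g) {
  return n / g + (i < n % g ? 1 : 0);
}

int main() {
  const uint32_t active_gpu_count = 3, num_inputs = 11;
  uint32_t expected_start = 0;
  for (uint32_t i = 0; i < active_gpu_count; i++) {
    uint32_t gpu_offset = 0;
    for (uint32_t j = 0; j < i; j++) // the restored offset loop
      gpu_offset += get_num_inputs_on_gpu(num_inputs, j, active_gpu_count);
    assert(gpu_offset == expected_start); // slices abut: no gap, no overlap
    expected_start += get_num_inputs_on_gpu(num_inputs, i, active_gpu_count);
  }
  assert(expected_start == num_inputs); // slices cover every input
  return 0;
}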
