Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(gpu): avoid synchronizations in the keybundle #1505

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#include <vector>

template <typename Torus, class params>
__device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t grouping_factor) {
__device__ uint32_t calculates_monomial_degree(const Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t grouping_factor) {
Torus x = 0;
for (int i = 0; i < grouping_factor; i++) {
uint32_t mask_position = grouping_factor - (i + 1);
Expand All @@ -31,6 +31,13 @@ __device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
return modulus_switch(x, params::log2_degree + 1);
}

__device__ __forceinline__ int
get_start_ith_ggsw_offset(uint32_t polynomial_size, int glwe_dimension,
uint32_t level_count) {
return polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) *
level_count;
}

template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_multi_bit_programmable_bootstrap_keybundle(
const Torus *__restrict__ lwe_array_in,
Expand Down Expand Up @@ -60,8 +67,6 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
uint32_t input_idx = blockIdx.x / lwe_chunk_size;

if (lwe_iteration < (lwe_dimension / grouping_factor)) {
//
Torus *accumulator = (Torus *)selected_memory;

const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
Expand All @@ -81,57 +86,52 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, 0, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly = bsk_slice + poly_id * params::degree;
const Torus *bsk_poly_ini = bsk_slice + poly_id * params::degree;

copy_polynomial<Torus, params::opt, params::degree / params::opt>(
bsk_poly, accumulator);
Torus reg_acc[params::opt];

// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {
copy_polynomial_in_regs<Torus, params::opt, params::degree / params::opt>(
bsk_poly_ini, reg_acc);

const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, g, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly = bsk_slice + poly_id * params::degree;
int offset =
get_start_ith_ggsw_offset(polynomial_size, glwe_dimension, level_count);

// Calculates the monomial degree
// Precalculate the monomial degrees and store them in shared memory
uint32_t *monomial_degrees = (uint32_t *)selected_memory;
if (threadIdx.x < (1 << grouping_factor)) {
const Torus *lwe_array_group =
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
uint32_t monomial_degree = calculates_monomial_degree<Torus, params>(
lwe_array_group, g, grouping_factor);

synchronize_threads_in_block();
// Multiply by the bsk element
polynomial_accumulate_monic_monomial_mul<Torus>(
accumulator, bsk_poly, monomial_degree, threadIdx.x, params::degree,
params::opt, false);
monomial_degrees[threadIdx.x] = calculates_monomial_degree<Torus, params>(
lwe_array_group, threadIdx.x, grouping_factor);
}

synchronize_threads_in_block();

// Move accumulator to local memory
double2 temp[params::opt / 2];
int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
temp[i].x = __ll2double_rn((int64_t)accumulator[tid]);
temp[i].y =
__ll2double_rn((int64_t)accumulator[tid + params::degree / 2]);
temp[i].x /= (double)std::numeric_limits<Torus>::max();
temp[i].y /= (double)std::numeric_limits<Torus>::max();
tid += params::degree / params::opt;
// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {

uint32_t monomial_degree = monomial_degrees[g];

const Torus *bsk_poly = bsk_poly_ini + g * offset;
// Multiply by the bsk element
polynomial_product_accumulate_by_monomial_nosync<Torus, params>(
reg_acc, bsk_poly, monomial_degree);
}
synchronize_threads_in_block(); // needed because we are going to reuse the
// shared memory for the fft

synchronize_threads_in_block();
// Move from local memory back to shared memory but as complex
tid = threadIdx.x;
int tid = threadIdx.x;
double2 *fft = (double2 *)selected_memory;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] = temp[i];
fft[tid] =
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
(double)std::numeric_limits<Torus>::max());
tid += params::degree / params::opt;
}
synchronize_threads_in_block();

NSMFFT_direct<HalfDegree<params>>(fft);

// lwe iteration
Expand Down
7 changes: 7 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ __device__ void copy_polynomial(const T *__restrict__ source, T *dst) {
tid = tid + block_size;
}
}
template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
dst[i] = source[threadIdx.x + i * block_size];
}
}

/*
* Receives num_poly concatenated polynomials of type T. For each:
Expand Down
25 changes: 25 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,29 @@ __device__ void polynomial_accumulate_monic_monomial_mul(
}
}

template <typename T, class params>
__device__ void polynomial_product_accumulate_by_monomial_nosync(
T *result, const T *__restrict__ poly, uint32_t monomial_degree) {
// monomial_degree \in [0, 2 * params::degree)
int full_cycles_count = monomial_degree / params::degree;
int remainder_degrees = monomial_degree % params::degree;

// Every thread has a fixed position to track instead of "chasing" the
// position
#pragma unroll
for (int i = 0; i < params::opt; i++) {
int pos =
(threadIdx.x + i * (params::degree / params::opt) - monomial_degree) &
(params::degree - 1);

T element = poly[pos];
T x = SEL(element, -element, full_cycles_count % 2);
x = SEL(-x, x,
threadIdx.x + i * (params::degree / params::opt) >=
remainder_degrees);

result[i] += x;
}
}

#endif // CNCRT_POLYNOMIAL_MATH_H
Loading