Skip to content

Commit

Permalink
refactor(gpu): avoid synchronizations in the keybundle
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Aug 16, 2024
1 parent 3397aa8 commit 9c10e05
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#include <vector>

template <typename Torus, class params>
__device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t grouping_factor) {
__device__ uint32_t calculates_monomial_degree(const Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t grouping_factor) {
Torus x = 0;
for (int i = 0; i < grouping_factor; i++) {
uint32_t mask_position = grouping_factor - (i + 1);
Expand All @@ -31,6 +31,13 @@ __device__ Torus calculates_monomial_degree(const Torus *lwe_array_group,
return modulus_switch(x, params::log2_degree + 1);
}

__device__ __forceinline__ int
get_start_ith_ggsw_offset(int i, uint32_t polynomial_size, int glwe_dimension,
uint32_t level_count) {
return i * polynomial_size / 2 * (glwe_dimension + 1) * (glwe_dimension + 1) *
level_count;
}

template <typename Torus, class params, sharedMemDegree SMD>
__global__ void device_multi_bit_programmable_bootstrap_keybundle(
const Torus *__restrict__ lwe_array_in,
Expand Down Expand Up @@ -61,7 +68,7 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(

if (lwe_iteration < (lwe_dimension / grouping_factor)) {
//
Torus *accumulator = (Torus *)selected_memory;
// Torus *accumulator = (Torus *)selected_memory;

const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[input_idx] * (lwe_dimension + 1)];
Expand All @@ -81,56 +88,49 @@ __global__ void device_multi_bit_programmable_bootstrap_keybundle(
const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, 0, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly = bsk_slice + poly_id * params::degree;
const Torus *bsk_poly_ini = bsk_slice + poly_id * params::degree;
Torus reg_acc[params::opt];
copy_polynomial_in_regs<Torus, params::opt, params::degree / params::opt>(
bsk_poly_ini, reg_acc);

copy_polynomial<Torus, params::opt, params::degree / params::opt>(
bsk_poly, accumulator);
int offset = get_start_ith_ggsw_offset(1, 2 * polynomial_size,
glwe_dimension, level_count);

// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {

const Torus *bsk_slice = get_multi_bit_ith_lwe_gth_group_kth_block(
bootstrapping_key, g, rev_lwe_iteration, glwe_id, level_id,
grouping_factor, 2 * polynomial_size, glwe_dimension, level_count);
const Torus *bsk_poly = bsk_slice + poly_id * params::degree;

// Calculates the monomial degree
__shared__ uint32_t mono_degs[8];
if (threadIdx.x < (1 << grouping_factor)) {
const Torus *lwe_array_group =
block_lwe_array_in + rev_lwe_iteration * grouping_factor;
uint32_t monomial_degree = calculates_monomial_degree<Torus, params>(
lwe_array_group, g, grouping_factor);

synchronize_threads_in_block();
// Multiply by the bsk element
polynomial_product_accumulate_by_monomial<Torus, params>(
accumulator, bsk_poly, monomial_degree, false);
mono_degs[threadIdx.x] = calculates_monomial_degree<Torus, params>(
lwe_array_group, threadIdx.x, grouping_factor);
}
__syncthreads();

synchronize_threads_in_block();
// Accumulate the other terms
for (int g = 1; g < (1 << grouping_factor); g++) {

// Move accumulator to local memory
double2 temp[params::opt / 2];
int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
temp[i].x = __ll2double_rn((int64_t)accumulator[tid]);
temp[i].y =
__ll2double_rn((int64_t)accumulator[tid + params::degree / 2]);
temp[i].x /= (double)std::numeric_limits<Torus>::max();
temp[i].y /= (double)std::numeric_limits<Torus>::max();
tid += params::degree / params::opt;
uint32_t monomial_degree = mono_degs[g];

const Torus *bsk_poly = bsk_poly_ini + g * offset;
// Multiply by the bsk element
polynomial_product_accumulate_by_monomial_nosync<Torus, params>(
reg_acc, bsk_poly, monomial_degree);
}

synchronize_threads_in_block();
// Move from local memory back to shared memory but as complex
tid = threadIdx.x;
int tid = threadIdx.x;
double2 *fft = (double2 *)selected_memory;
#pragma unroll
for (int i = 0; i < params::opt / 2; i++) {
fft[tid] = temp[i];
fft[tid] =
make_double2(__ll2double_rn((int64_t)reg_acc[i]) /
(double)std::numeric_limits<Torus>::max(),
__ll2double_rn((int64_t)reg_acc[i + params::opt / 2]) /
(double)std::numeric_limits<Torus>::max());
tid += params::degree / params::opt;
}
synchronize_threads_in_block();
// synchronize_threads_in_block(); // seems that we can get rid of this
// sync, since in the first iteration of fft we access only to values that
// are in regs
NSMFFT_direct<HalfDegree<params>>(fft);

// lwe iteration
Expand Down
9 changes: 9 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ __device__ void copy_polynomial(const T *__restrict__ source, T *dst) {
tid = tid + block_size;
}
}
template <typename T, int elems_per_thread, int block_size>
__device__ void copy_polynomial_in_regs(const T *__restrict__ source, T *dst) {
// int tid = threadIdx.x;
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
dst[i] = source[threadIdx.x + i * block_size];
// tid = tid + block_size;
}
}

/*
* Receives num_poly concatenated polynomials of type T. For each:
Expand Down
31 changes: 31 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/polynomial/polynomial_math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,35 @@ polynomial_product_accumulate_by_monomial(T *result, const T *__restrict__ poly,
}
}

template <typename T, class params>
__device__ void polynomial_product_accumulate_by_monomial_nosync(
T *result, const T *__restrict__ poly, uint32_t monomial_degree) {
// monomial_degree \in [0, 2 * params::degree)
// int full_cycles_count = monomial_degree / params::degree;
int full_cycles_count = monomial_degree >> log2(params::degree);

int remainder_degrees = monomial_degree & (params::degree - 1);
// int remainder_degrees = monomial_degree % params::degree;

// Every thread has a fixed position to track instead of "chasing" the
// position
// int new_pos = threadIdx.x;
#pragma unroll
for (int i = 0; i < params::opt; i++) {
// int pos = (new_pos - monomial_degree) % params::degree;
int pos =
(threadIdx.x + i * (params::degree / params::opt) - monomial_degree) &
(params::degree - 1);

T element = poly[pos];
T x = SEL(element, -element, full_cycles_count % 2); // monomial coefficient
x = SEL(-x, x,
threadIdx.x + i * (params::degree / params::opt) >=
remainder_degrees);

result[i] += x;
// new_pos += params::degree / params::opt;
}
}

#endif // CNCRT_POLYNOMIAL_MATH_H

0 comments on commit 9c10e05

Please sign in to comment.