Skip to content

Commit

Permalink
chore(gpu): refactor small scalar mul to keep track of degree and noi…
Browse files Browse the repository at this point in the history
…se changes
  • Loading branch information
pdroalves committed Feb 3, 2025
1 parent e3aa179 commit 1a47ff8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
40 changes: 39 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ __host__ void host_integer_scalar_mul_radix(

// Small scalar_mul is used in shift/rotate
template <typename T>
__host__ void host_integer_small_scalar_mul_radix(
__host__ void host_legacy_integer_small_scalar_mul_radix(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, T *output_lwe_array, T *input_lwe_array, T scalar,
uint32_t input_lwe_dimension, uint32_t input_lwe_ciphertext_count) {
Expand All @@ -143,4 +143,42 @@ __host__ void host_integer_small_scalar_mul_radix(
input_lwe_ciphertext_count);
check_cuda_error(cudaGetLastError());
}

// Small scalar_mul is used in shift/rotate
template <typename T>
__host__ void host_integer_small_scalar_mul_radix(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *output_lwe_array,
CudaRadixCiphertextFFI *input_lwe_array, T scalar) {

if (output_lwe_array->num_radix_blocks != input_lwe_array->num_radix_blocks)
PANIC("Cuda error: input and output num radix blocks must be the same")
if (output_lwe_array->lwe_dimension != input_lwe_array->lwe_dimension)
PANIC("Cuda error: input and output lwe_dimension must be the same")

cuda_set_device(gpu_indexes[0]);
auto lwe_dimension = input_lwe_array->lwe_dimension;
auto num_radix_blocks = input_lwe_array->num_radix_blocks;

// lwe_size includes the presence of the body
// whereas lwe_dimension is the number of elements in the mask
int lwe_size = lwe_dimension + 1;
// Create a 1-dimensional grid of threads
int num_blocks = 0, num_threads = 0;
int num_entries = num_radix_blocks * lwe_size;
getNumBlocksAndThreads(num_entries, 512, num_blocks, num_threads);
dim3 grid(num_blocks, 1, 1);
dim3 thds(num_threads, 1, 1);

device_small_scalar_radix_multiplication<<<grid, thds, 0, streams[0]>>>(
(T *)output_lwe_array->ptr, (T *)input_lwe_array->ptr, scalar,
lwe_dimension, num_radix_blocks);
check_cuda_error(cudaGetLastError());

for (int i = 0; i < num_radix_blocks; i++) {
output_lwe_array->noise_levels[i] =
input_lwe_array->noise_levels[i] * scalar;
output_lwe_array->degrees[i] = input_lwe_array->degrees[i] * scalar;
}
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ __host__ void host_integer_radix_shift_and_rotate_kb_inplace(
lwe_last_out = lwe_array;
for (int i = bits_per_block - 2; i >= 0; i--) {

host_integer_small_scalar_mul_radix<Torus>(
host_legacy_integer_small_scalar_mul_radix<Torus>(
streams, gpu_indexes, gpu_count, lwe_last_out, lwe_last_out, 2,
big_lwe_dimension, num_radix_blocks);

Expand Down

0 comments on commit 1a47ff8

Please sign in to comment.