-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore(gpu): start using a struct to pass data across rust/c++
- Loading branch information
1 parent
b46affa
commit c3fc1cb
Showing
46 changed files
with
971 additions
and
807 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
backends/tfhe-cuda-backend/cuda/include/integer/radix_ciphertext.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
#ifndef CUDA_RADIX_CIPHERTEXT_H | ||
#define CUDA_RADIX_CIPHERTEXT_H | ||
|
||
#include "device.h" | ||
#include "integer.h" | ||
|
||
template <typename Torus> | ||
void create_trivial_radix_ciphertext_async( | ||
cudaStream_t const stream, uint32_t const gpu_index, | ||
CudaRadixCiphertextData *output_radix, uint32_t num_radix_blocks, | ||
uint32_t lwe_dimension) { | ||
uint32_t lwe_size_bytes = (lwe_dimension + 1) * sizeof(Torus); | ||
output_radix->ptr = (void *)cuda_malloc_async( | ||
num_radix_blocks * lwe_size_bytes, stream, gpu_index); | ||
for (uint i = 0; i < output_radix->num_radix_blocks; i++) { | ||
output_radix->degrees[i] = 0; | ||
output_radix->noise_levels[i] = 0; | ||
} | ||
output_radix->lwe_dimension = lwe_dimension; | ||
output_radix->num_radix_blocks = num_radix_blocks; | ||
} | ||
|
||
template <typename Torus> | ||
void as_radix_ciphertext_slice(CudaRadixCiphertextData *output_radix, | ||
const CudaRadixCiphertextData *input_radix, | ||
uint32_t start_lwe_index, | ||
uint32_t end_lwe_index) { | ||
if (input_radix->num_radix_blocks < start_lwe_index - end_lwe_index + 1) | ||
PANIC("Cuda error: input radix should have more blocks than the specified " | ||
"range") | ||
if (start_lwe_index <= end_lwe_index) | ||
PANIC("Cuda error: slice range should be strictly positive") | ||
|
||
auto lwe_size = input_radix->lwe_dimension + 1; | ||
output_radix->num_radix_blocks = end_lwe_index - start_lwe_index + 1; | ||
output_radix->lwe_dimension = input_radix->lwe_dimension; | ||
Torus *in_ptr = (Torus *)input_radix->ptr; | ||
output_radix->ptr = (void *)(&in_ptr[start_lwe_index * lwe_size]); | ||
for (uint i = 0; i < output_radix->num_radix_blocks; i++) { | ||
output_radix->degrees[i] = | ||
input_radix->degrees[i + start_lwe_index * lwe_size]; | ||
output_radix->noise_levels[i] = | ||
input_radix->noise_levels[i + start_lwe_index * lwe_size]; | ||
} | ||
} | ||
|
||
template <typename Torus> | ||
void copy_radix_ciphertext_to_larger_output_slice_async( | ||
cudaStream_t const stream, uint32_t const gpu_index, | ||
CudaRadixCiphertextData *output_radix, | ||
const CudaRadixCiphertextData *input_radix, | ||
uint32_t output_start_lwe_index) { | ||
if (output_radix->lwe_dimension != input_radix->lwe_dimension) | ||
PANIC("Cuda error: input lwe dimension should be equal to output lwe " | ||
"dimension") | ||
if (output_radix->num_radix_blocks - output_start_lwe_index != | ||
input_radix->num_radix_blocks) | ||
PANIC("Cuda error: input radix should have the same number of blocks as " | ||
"the output range") | ||
if (output_start_lwe_index >= output_radix->num_radix_blocks) | ||
PANIC("Cuda error: output index should be strictly smaller than the number " | ||
"of blocks") | ||
|
||
auto lwe_size = input_radix->lwe_dimension + 1; | ||
Torus *out_ptr = (Torus *)output_radix->ptr; | ||
out_ptr = &out_ptr[output_start_lwe_index * lwe_size]; | ||
|
||
cuda_memcpy_async_gpu_to_gpu(out_ptr, input_radix->ptr, | ||
input_radix->num_radix_blocks * | ||
(input_radix->lwe_dimension + 1) * | ||
sizeof(Torus), | ||
stream, gpu_index); | ||
for (uint i = 0; i < input_radix->num_radix_blocks; i++) { | ||
output_radix->degrees[i + output_start_lwe_index] = input_radix->degrees[i]; | ||
output_radix->noise_levels[i + output_start_lwe_index] = | ||
input_radix->noise_levels[i]; | ||
} | ||
} | ||
|
||
#endif // CUDA_RADIX_CIPHERTEXT_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.