Skip to content

Commit

Permalink
fix(gpu): compressed list gpu <-> cpu
Browse files Browse the repository at this point in the history
Some counts were not copied from the correct
source to the correct destination.

And more importantly, the list on the cuda side was stored
using a GlweCiphertextList even though the data was compressed
(so the list was mostly empty). This use of a GlweList
instead of a specialized type led to problems when converting
to the Cpu side.
  • Loading branch information
tmontaigu committed Feb 12, 2025
1 parent d0b0fe8 commit 16d8af1
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 103 deletions.
6 changes: 6 additions & 0 deletions tfhe/src/core_crypto/entities/glwe_ciphertext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ pub fn glwe_ciphertext_size(glwe_size: GlweSize, polynomial_size: PolynomialSize
glwe_size.0 * polynomial_size.0
}

/// Return the number of elements in the **mask** of a [`GlweCiphertext`]
/// given a [`GlweDimension`] and [`PolynomialSize`].
pub fn glwe_mask_size(glwe_dim: GlweDimension, polynomial_size: PolynomialSize) -> usize {
    // One polynomial of `polynomial_size` coefficients per mask dimension.
    let polynomial_count = glwe_dim.0;
    let coeffs_per_polynomial = polynomial_size.0;
    polynomial_count * coeffs_per_polynomial
}

/// Return the number of elements in a [`GlweMask`] given a [`GlweDimension`] and
/// [`PolynomialSize`].
pub fn glwe_ciphertext_mask_size(
Expand Down
9 changes: 1 addition & 8 deletions tfhe/src/core_crypto/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,15 +703,8 @@ pub struct CudaGlweList<T: UnsignedInteger> {

impl<T: UnsignedInteger> CudaGlweList<T> {
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
let d_vec = unsafe {
let mut d_vec = CudaVec::new_async(self.d_vec.len(), streams, 0);
d_vec.copy_from_gpu_async(&self.d_vec, streams, 0);
d_vec
};
streams.synchronize();

Self {
d_vec,
d_vec: self.d_vec.duplicate(streams),
glwe_ciphertext_count: self.glwe_ciphertext_count,
glwe_dimension: self.glwe_dimension,
polynomial_size: self.polynomial_size,
Expand Down
10 changes: 10 additions & 0 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,16 @@ impl<T: Numeric> CudaVec<T> {
/// Returns `true` when this vector holds no elements.
pub fn is_empty(&self) -> bool {
    matches!(self.len, 0)
}

/// Allocate a new `CudaVec` of the same length and copy this vector's
/// device data into it on GPU index 0 of `streams`.
///
/// The streams are synchronized before returning, so the returned copy
/// is fully materialized when this function completes.
pub fn duplicate(&self, streams: &CudaStreams) -> Self {
    let d_vec = unsafe {
        // SAFETY: the destination is allocated with exactly `self.len()`
        // elements before the async device-to-device copy is enqueued, and
        // `streams.synchronize()` below runs before the new vector escapes
        // this function, so the async allocation and copy have completed.
        let mut d_vec = Self::new_async(self.len(), streams, 0);
        d_vec.copy_from_gpu_async(self, streams, 0);
        d_vec
    };
    streams.synchronize();
    d_vec
}
}

// SAFETY
Expand Down
267 changes: 207 additions & 60 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
use crate::core_crypto::entities::packed_integers::PackedIntegers;
use crate::core_crypto::entities::GlweCiphertextList;
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::vec::GpuIndex;
use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext;
use crate::core_crypto::prelude::{
glwe_ciphertext_size, CiphertextCount, ContiguousEntityContainer, LweCiphertextCount,
};
use crate::core_crypto::prelude::{CiphertextCount, LweCiphertextCount};
use crate::integer::ciphertext::{CompressedCiphertextList, DataKind};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
CudaUnsignedRadixCiphertext,
};
use crate::integer::gpu::list_compression::server_keys::{
CudaCompressionKey, CudaDecompressionKey, CudaPackedGlweCiphertext,
CudaCompressionKey, CudaDecompressionKey, CudaPackedGlweCiphertextList,
};
use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList;
use crate::shortint::PBSOrder;
Expand Down Expand Up @@ -62,13 +58,13 @@ impl CudaExpandable for CudaBooleanBlock {
}
}
pub struct CudaCompressedCiphertextList {
pub(crate) packed_list: CudaPackedGlweCiphertext,
pub(crate) packed_list: CudaPackedGlweCiphertextList,
pub(crate) info: Vec<DataKind>,
}

impl CudaCompressedCiphertextList {
pub fn gpu_indexes(&self) -> &[GpuIndex] {
&self.packed_list.glwe_ciphertext_list.0.d_vec.gpu_indexes
&self.packed_list.data.gpu_indexes
}
pub fn len(&self) -> usize {
self.info.len()
Expand Down Expand Up @@ -182,39 +178,45 @@ impl CudaCompressedCiphertextList {
/// let converted_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams);
/// ```
pub fn to_compressed_ciphertext_list(&self, streams: &CudaStreams) -> CompressedCiphertextList {
let glwe_list = self
.packed_list
.glwe_ciphertext_list
.to_glwe_ciphertext_list(streams);
let ciphertext_modulus = self.packed_list.glwe_ciphertext_list.ciphertext_modulus();

let ciphertext_modulus = self.packed_list.ciphertext_modulus;
let message_modulus = self.packed_list.message_modulus;
let carry_modulus = self.packed_list.carry_modulus;
let lwe_per_glwe = self.packed_list.lwe_per_glwe;
let storage_log_modulus = self.packed_list.storage_log_modulus;
let glwe_dimension = self.packed_list.glwe_dimension;
let polynomial_size = self.packed_list.polynomial_size;
let mut modulus_switched_glwe_ciphertext_list =
Vec::with_capacity(self.packed_list.glwe_ciphertext_count().0);

let flat_cpu_data = unsafe {
let mut v = vec![0u64; self.packed_list.data.len()];
self.packed_list.data.copy_to_cpu_async(&mut v, streams, 0);
streams.synchronize();
v
};

let initial_len = self.packed_list.initial_len;
let number_bits_to_pack = initial_len * storage_log_modulus.0;
let len = number_bits_to_pack.div_ceil(u64::BITS as usize);

let modulus_switched_glwe_ciphertext_list = glwe_list
.iter()
.map(|x| {
let glwe_dimension = x.glwe_size().to_glwe_dimension();
let polynomial_size = x.polynomial_size();
CompressedModulusSwitchedGlweCiphertext {
packed_integers: PackedIntegers {
packed_coeffs: x.into_container()[0..len].to_vec(),
log_modulus: storage_log_modulus,
initial_len,
},
glwe_dimension,
polynomial_size,
bodies_count: LweCiphertextCount(self.packed_list.bodies_count),
uncompressed_ciphertext_modulus: ciphertext_modulus,
}
})
.collect_vec();
let mut num_bodies_left = self.packed_list.bodies_count;
let mut chunk_start = 0;
while num_bodies_left != 0 {
let bodies_count = LweCiphertextCount(num_bodies_left.min(lwe_per_glwe.0));
let initial_len = (glwe_dimension.0 * polynomial_size.0) + bodies_count.0;
let number_bits_to_pack = initial_len * storage_log_modulus.0;
let len = number_bits_to_pack.div_ceil(u64::BITS as usize);
let chunk_end = chunk_start + len;
modulus_switched_glwe_ciphertext_list.push(CompressedModulusSwitchedGlweCiphertext {
packed_integers: PackedIntegers {
packed_coeffs: flat_cpu_data[chunk_start..chunk_end].to_vec(),
log_modulus: storage_log_modulus,
initial_len,
},
glwe_dimension,
polynomial_size,
bodies_count,
uncompressed_ciphertext_modulus: ciphertext_modulus,
});
num_bodies_left = num_bodies_left.saturating_sub(lwe_per_glwe.0);
chunk_start = chunk_end;
}

let count = CiphertextCount(self.packed_list.bodies_count);
let pbs_order = PBSOrder::KeyswitchBootstrap;
Expand Down Expand Up @@ -323,39 +325,34 @@ impl CompressedCiphertextList {

let first_ct = modulus_switched_glwe_ciphertext_list.first().unwrap();
let storage_log_modulus = first_ct.packed_integers.log_modulus;
let initial_len = first_ct.packed_integers.initial_len;
let bodies_count = first_ct.bodies_count.0;
let initial_len = modulus_switched_glwe_ciphertext_list
.iter()
.map(|glwe| glwe.packed_integers.initial_len)
.sum();

let message_modulus = self.packed_list.message_modulus;
let carry_modulus = self.packed_list.carry_modulus;

let mut data = modulus_switched_glwe_ciphertext_list
let flat_cpu_data = modulus_switched_glwe_ciphertext_list
.iter()
.flat_map(|ct| ct.packed_integers.packed_coeffs.clone())
.collect_vec();
let glwe_ciphertext_size = glwe_ciphertext_size(
first_ct.glwe_dimension.to_glwe_size(),
first_ct.polynomial_size,
);
data.resize(
self.packed_list.modulus_switched_glwe_ciphertext_list.len() * glwe_ciphertext_size,
0,
);
let glwe_ciphertext_list = GlweCiphertextList::from_container(
data.as_slice(),
first_ct.glwe_dimension.to_glwe_size(),
first_ct.polynomial_size,
self.packed_list.ciphertext_modulus,
);

let flat_gpu_data = unsafe {
let v = CudaVec::from_cpu_async(flat_cpu_data.as_slice(), streams, 0);
streams.synchronize();
v
};

CudaCompressedCiphertextList {
packed_list: CudaPackedGlweCiphertext {
glwe_ciphertext_list: CudaGlweCiphertextList::from_glwe_ciphertext_list(
&glwe_ciphertext_list,
streams,
),
packed_list: CudaPackedGlweCiphertextList {
data: flat_gpu_data,
glwe_dimension: first_ct.glwe_dimension(),
polynomial_size: first_ct.polynomial_size(),
message_modulus,
carry_modulus,
bodies_count,
ciphertext_modulus: self.packed_list.ciphertext_modulus,
bodies_count: self.packed_list.count.0,
storage_log_modulus,
lwe_per_glwe,
initial_len,
Expand Down Expand Up @@ -495,7 +492,9 @@ impl<'de> serde::Deserialize<'de> for CudaCompressedCiphertextList {
#[cfg(test)]
mod tests {
use super::*;
use crate::integer::ciphertext::CompressedCiphertextListBuilder;
use crate::integer::gpu::gen_keys_radix_gpu;
use crate::integer::{ClientKey, RadixCiphertext, RadixClientKey};
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::{
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
Expand All @@ -507,6 +506,154 @@ mod tests {
const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;

#[test]
fn test_cpu_to_gpu_compressed_ciphertext_list() {
    // Regression test for Gpu <-> Cpu conversions of compressed ciphertext
    // lists: the same messages are compressed on both the Cpu and the Gpu,
    // the lists are converted back and forth between the two sides, and
    // every variant must decompress then decrypt to the original messages.
    const NUM_BLOCKS: usize = 32;
    let streams = CudaStreams::new_multi_gpu();

    let params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
    let comp_params = COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;

    let cks = ClientKey::new(params);

    // Cpu and Cuda (de)compression keys are decompressed from the same
    // compressed key material, so both compression paths are comparable.
    let private_compression_key = cks.new_compression_private_key(comp_params);
    let (compressed_compression_key, compressed_decompression_key) =
        cks.new_compressed_compression_decompression_keys(&private_compression_key);
    let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams);
    let cuda_decompression_key = compressed_decompression_key.decompress_to_cuda(
        cks.parameters().glwe_dimension(),
        cks.parameters().polynomial_size(),
        cks.parameters().message_modulus(),
        cks.parameters().carry_modulus(),
        cks.parameters().ciphertext_modulus(),
        &streams,
    );
    let cpu_compression_key = compressed_compression_key.decompress();
    let cpu_decompression_key = compressed_decompression_key.decompress();

    let radix_cks = RadixClientKey::from((cks, NUM_BLOCKS));

    // How many uints of NUM_BLOCKS we have to push in the list to ensure it
    // internally has more than one packed GLWE
    const MAX_NB_MESSAGES: usize = 1 + 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
        .lwe_per_glwe
        .0
        / NUM_BLOCKS;

    // Random messages in the range encodable by NUM_BLOCKS radix blocks.
    let mut rng = rand::thread_rng();
    let message_modulus: u128 = radix_cks.parameters().message_modulus().0 as u128;
    let modulus = message_modulus.pow(NUM_BLOCKS as u32);
    let messages = (0..MAX_NB_MESSAGES)
        .map(|_| rng.gen::<u128>() % modulus)
        .collect::<Vec<_>>();

    let cpu_cts = messages
        .iter()
        .map(|message| radix_cks.encrypt(*message))
        .collect_vec();

    // The same ciphertexts, uploaded to the Gpu.
    let cuda_cts = cpu_cts
        .iter()
        .map(|ct| CudaUnsignedRadixCiphertext::from_radix_ciphertext(ct, &streams))
        .collect_vec();

    // Compress once on the Cpu and once on the Gpu.
    let cpu_compressed_list = {
        let mut builder = CompressedCiphertextListBuilder::new();
        for d_ct in cpu_cts {
            builder.push(d_ct);
        }
        builder.build(&cpu_compression_key)
    };

    let cuda_compressed_list = {
        let mut builder = CudaCompressedCiphertextListBuilder::new();
        for d_ct in cuda_cts {
            builder.push(d_ct, &streams);
        }
        builder.build(&cuda_compression_key, &streams)
    };

    // Test Decompression on Gpu
    {
        // Roundtrip Gpu->Cpu->Gpu
        let cuda_compressed_list = cuda_compressed_list
            .to_compressed_ciphertext_list(&streams)
            .to_cuda_compressed_ciphertext_list(&streams);

        // List that was compressed on the Cpu, then moved to the Gpu.
        let cuda_compressed_list_2 =
            cpu_compressed_list.to_cuda_compressed_ciphertext_list(&streams);

        for (i, message) in messages.iter().enumerate() {
            let d_decompressed: CudaUnsignedRadixCiphertext = cuda_compressed_list
                .get(i, &cuda_decompression_key, &streams)
                .unwrap()
                .unwrap();
            let decompressed = d_decompressed.to_radix_ciphertext(&streams);
            let decrypted: u128 = radix_cks.decrypt(&decompressed);
            assert_eq!(
                decrypted, *message,
                "Invalid decompression for cuda list that roundtripped Cuda->Cpu->Cuda"
            );

            let d_decompressed: CudaUnsignedRadixCiphertext = cuda_compressed_list_2
                .get(i, &cuda_decompression_key, &streams)
                .unwrap()
                .unwrap();
            let decompressed = d_decompressed.to_radix_ciphertext(&streams);
            let decrypted: u128 = radix_cks.decrypt(&decompressed);
            assert_eq!(
                decrypted, *message,
                "Invalid decompression for cuda list that originated from Cpu"
            );
        }
    }

    // Test Decompression on CPU (to test conversions)
    {
        // The conversions must not change the amount of packed data.
        let expected_flat_len = cpu_compressed_list.packed_list.flat_len();

        // Roundtrip Cpu->Gpu->Cpu
        let cpu_compressed_list = cpu_compressed_list
            .to_cuda_compressed_ciphertext_list(&streams)
            .to_compressed_ciphertext_list(&streams);
        assert_eq!(
            cpu_compressed_list.packed_list.flat_len(),
            expected_flat_len,
            "Invalid flat len after Cpu->Gpu->Cpu"
        );

        // List that was compressed on the Gpu, then moved to the Cpu.
        let cpu_compressed_list_2 =
            cuda_compressed_list.to_compressed_ciphertext_list(&streams);
        assert_eq!(
            cpu_compressed_list_2.packed_list.flat_len(),
            expected_flat_len,
            "Invalid flat len after Gpu->Cpu"
        );

        for (i, message) in messages.iter().enumerate() {
            let decompressed: RadixCiphertext = cpu_compressed_list
                .get(i, &cpu_decompression_key)
                .unwrap()
                .unwrap();
            let decrypted: u128 = radix_cks.decrypt(&decompressed);
            assert_eq!(
                decrypted, *message,
                "Invalid decompression for cpu list that roundtripped Cpu->Gpu->Cpu"
            );

            let decompressed: RadixCiphertext = cpu_compressed_list_2
                .get(i, &cpu_decompression_key)
                .unwrap()
                .unwrap();
            let decrypted: u128 = radix_cks.decrypt(&decompressed);
            assert_eq!(
                decrypted, *message,
                "Invalid decompression for cpu list that originated from Gpu"
            );
        }
    }
}

#[test]
fn test_gpu_ciphertext_compression() {
const NUM_BLOCKS: usize = 32;
Expand Down
Loading

0 comments on commit 16d8af1

Please sign in to comment.