From 865b667ffd9e4aae3c1b28496c11633cf6c089e5 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Tue, 5 Mar 2024 09:44:02 +0100 Subject: [PATCH] feat(core): add lwe ct modulus switch compression --- tfhe/src/core_crypto/algorithms/test/mod.rs | 1 + .../test/modulus_switch_compression.rs | 65 ++++ tfhe/src/core_crypto/entities/mod.rs | 1 + .../modulus_switched_lwe_ciphertext.rs | 308 ++++++++++++++++++ 4 files changed, 375 insertions(+) create mode 100644 tfhe/src/core_crypto/algorithms/test/modulus_switch_compression.rs create mode 100644 tfhe/src/core_crypto/entities/modulus_switched_lwe_ciphertext.rs diff --git a/tfhe/src/core_crypto/algorithms/test/mod.rs b/tfhe/src/core_crypto/algorithms/test/mod.rs index c6a9e347db..db0cfb6222 100644 --- a/tfhe/src/core_crypto/algorithms/test/mod.rs +++ b/tfhe/src/core_crypto/algorithms/test/mod.rs @@ -24,6 +24,7 @@ mod lwe_packing_keyswitch; mod lwe_packing_keyswitch_key_generation; mod lwe_private_functional_packing_keyswitch; pub(crate) mod lwe_programmable_bootstrapping; +mod modulus_switch_compression; mod noise_distribution; pub struct TestResources { diff --git a/tfhe/src/core_crypto/algorithms/test/modulus_switch_compression.rs b/tfhe/src/core_crypto/algorithms/test/modulus_switch_compression.rs new file mode 100644 index 0000000000..eb8f7499f4 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/test/modulus_switch_compression.rs @@ -0,0 +1,65 @@ +use super::*; +use crate::core_crypto::prelude::modulus_switched_lwe_ciphertext::PackedModulusSwitchedLweCiphertext; + +#[cfg(not(tarpaulin))] +const NB_TESTS: usize = 10; +#[cfg(tarpaulin)] +const NB_TESTS: usize = 1; + +fn encryption_ms_decryption( + params: ClassicTestParams, +) { + let ClassicTestParams { + lwe_noise_distribution, + message_modulus_log, + ciphertext_modulus, + .. + } = params; + + let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus); + + let mut rsc: TestResources = TestResources::new(); + + let msg_modulus = Scalar::ONE.shl(message_modulus_log.0); + let mut msg = msg_modulus; + let delta: Scalar = encoding_with_padding / msg_modulus; + + while msg != Scalar::ZERO { + msg = msg.wrapping_sub(Scalar::ONE); + for _ in 0..NB_TESTS { + // Create the LweSecretKey + let lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key::( + params.lwe_dimension, + &mut rsc.secret_random_generator, + ); + + let lwe = allocate_and_encrypt_new_lwe_ciphertext( + &lwe_secret_key, + Plaintext(msg * delta), + lwe_noise_distribution, + ciphertext_modulus, + &mut rsc.encryption_random_generator, + ); + + // Can be stored using much less space than the standard lwe ciphertexts + let compressed = PackedModulusSwitchedLweCiphertext::compress( + &lwe, + CiphertextModulusLog(params.polynomial_size.log2().0 + 1), + ); + + let lwe_ms_ed = compressed.extract(); + + let decrypted = decrypt_lwe_ciphertext(&lwe_secret_key, &lwe_ms_ed); + + let decoded = round_decode(decrypted.0, delta) % msg_modulus; + assert_eq!(decoded, msg); + } + + // In coverage, we break after one while loop iteration, changing message values does + // not yield higher coverage + #[cfg(tarpaulin)] + break; + } +} + +create_parametrized_test!(encryption_ms_decryption); diff --git a/tfhe/src/core_crypto/entities/mod.rs b/tfhe/src/core_crypto/entities/mod.rs index 8bfb1082f1..486f900f1e 100644 --- a/tfhe/src/core_crypto/entities/mod.rs +++ b/tfhe/src/core_crypto/entities/mod.rs @@ -22,6 +22,7 @@ pub mod lwe_private_functional_packing_keyswitch_key; pub mod lwe_private_functional_packing_keyswitch_key_list; pub mod lwe_public_key; pub mod lwe_secret_key; +pub mod modulus_switched_lwe_ciphertext; pub mod plaintext; pub mod plaintext_list; pub mod polynomial; diff --git a/tfhe/src/core_crypto/entities/modulus_switched_lwe_ciphertext.rs b/tfhe/src/core_crypto/entities/modulus_switched_lwe_ciphertext.rs new file mode 100644 index 0000000000..7f94884979 --- /dev/null +++ b/tfhe/src/core_crypto/entities/modulus_switched_lwe_ciphertext.rs @@ -0,0 +1,308 @@ +use crate::core_crypto::fft_impl::common::modulus_switch; +use crate::core_crypto::prelude::*; + +/// An object to store a ciphertext in little memory +/// The modulus of the ciphertext is decreased by rounding and the result is stored in a compact way +/// The uncompacted result can be used as the input of a blind rotation to recover a low noise lwe +/// ciphertext +/// +/// ``` +/// use concrete_csprng::seeders::Seed; +/// use tfhe::core_crypto::prelude::*; +/// use tfhe::core_crypto::fft_impl::common::modulus_switch; +/// use tfhe::core_crypto::prelude::modulus_switched_lwe_ciphertext::PackedModulusSwitchedLweCiphertext; +/// +/// let log_modulus = 12; +/// +/// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); +/// +/// // Create the LweSecretKey +/// let lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key::(LweDimension(2048), &mut secret_generator); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// +/// +/// // Unsecure parameters, do not use them +/// let lwe = allocate_and_encrypt_new_lwe_ciphertext( +/// &lwe_secret_key, +/// Plaintext(0), +/// Gaussian::from_standard_dev(StandardDev(0.), 0.), +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Can be stored using much less space than the standard lwe ciphertexts +/// let compressed = PackedModulusSwitchedLweCiphertext::compress( +/// &lwe, +/// CiphertextModulusLog(log_modulus as usize), +/// ); +/// +/// let lwe_ms_ed = compressed.extract(); +/// +/// assert_eq!( +/// modulus_switch( +/// decrypt_lwe_ciphertext(&lwe_secret_key, &lwe_ms_ed).0, +/// CiphertextModulusLog(5) +/// ), +/// 0 +/// ); +/// ``` +pub struct PackedModulusSwitchedLweCiphertext { + packed_coeffs: Vec, + lwe_dimension: LweDimension, + log_modulus: CiphertextModulusLog, + uncompressed_ciphertext_modulus: CiphertextModulus, +} + +impl PackedModulusSwitchedLweCiphertext { + /// Compresses a ciphertext by reducing its modulus + /// This operation adds a lot of noise + pub fn compress>( + ct: &LweCiphertext, + log_modulus: CiphertextModulusLog, + ) -> Self { + let switch_modulus = |x| modulus_switch(x, log_modulus); + + let log_modulus = log_modulus.0; + + let uncompressed_ciphertext_modulus = ct.ciphertext_modulus(); + + assert!( + ct.ciphertext_modulus().is_power_of_two(), + "Modulus switch compression doe not support non power of 2 input moduli", + ); + + let uncompressed_ciphertext_modulus_log = + if uncompressed_ciphertext_modulus.is_native_modulus() { + Scalar::BITS + } else { + uncompressed_ciphertext_modulus.get_custom_modulus().ilog2() as usize + }; + + assert!( + log_modulus <= uncompressed_ciphertext_modulus_log, + "The log_modulus (={log_modulus}) for modulus switch compression must be smaller than the uncompressed ciphertext_modulus_log (={uncompressed_ciphertext_modulus_log})", + ); + + let lwe_size = ct.lwe_size().0; + + let number_bits_to_pack = lwe_size * log_modulus; + + let len = number_bits_to_pack.div_ceil(Scalar::BITS); + + let slice = ct.as_ref(); + // Lowest bits are on the right + // + // Target mapping: + // log_modulus + // |-------| + // + // slice : | k+2 | k+1 | k | + // packed_coeffs: i+1 | i | i-1 + // + // |---------------| + // Scalar::BITS + // + // |---| + // start_shift + // + // |---| + // shift1 + // (1st loop iteration) + // + // |-----------| + // shift2 + // (2nd loop iteration) + // + // packed_coeffs[i] = + // slice[k] >> start_shift + // | slice[k+1] << shift1 + // | slice[k+2] << shift2 + // + // In the lowest bits of packed_coeffs[i], we want the highest bits of slice[k], + // hence the right shift + // The next bits should be the bits of slice[k+1] which we must left shifted to avoid + // overlapping + // This goes on + let packed_coeffs = (0..len) + .map(|i| { + let k = Scalar::BITS * i / log_modulus; + let mut j = k; + + let start_shift = i * Scalar::BITS - j * log_modulus; + + let mut value = switch_modulus(slice[j]) >> start_shift; + j += 1; + + while j * log_modulus < ((i + 1) * Scalar::BITS) && j < lwe_size { + let shift = j * log_modulus - i * Scalar::BITS; + + value |= switch_modulus(slice[j]) << shift; + + j += 1; + } + value + }) + .collect(); + + let log_modulus = CiphertextModulusLog(log_modulus); + + Self { + packed_coeffs, + lwe_dimension: ct.lwe_size().to_lwe_dimension(), + log_modulus, + uncompressed_ciphertext_modulus, + } + } + + /// Converts back a compressed ciphertext to its initial modulus + /// The noise added during the compression says int hte output + /// The output must got through a PBS to reduce the noise + pub fn extract(&self) -> LweCiphertextOwned { + let log_modulus = self.log_modulus.0; + + let container = (0..(self.lwe_dimension.to_lwe_size().0)) + .map(|i| { + let start = i * log_modulus; + let end = (i + 1) * log_modulus; + + let start_block = start / Scalar::BITS; + let start_remainder = start % Scalar::BITS; + + let end_block_inclusive = (end - 1) / Scalar::BITS; + + if start_block == end_block_inclusive { + // Lowest bits are on the right + // + // Target mapping: + // Scalar::BITS + // |---------------| + // + // packed_coeffs: | start_block+1 | start_block | + // container : | i+1 | i | i-1 | + // + // |-------| + // log_modulus + // + // |---| + // start_remainder + // + // In container[i] we want the bits of packed_coeffs[start_block] starting from + // index start_remainder + // + // container[i] = lowest_bits of single_part + // + // The highest bits of single_part will be discarded during scaling + // + // single_part = + self.packed_coeffs[start_block] >> start_remainder + } else { + // Lowest bits are on the right + // + // Target mapping: + // Scalar::BITS + // |---------------| + // + // packed_coeffs: | start_block+1 | start_block | + // container : | i+1 | i | i-1 | + // + // |-------| + // log_modulus + // + // |-----------| + // start_remainder + // + // |---| + // Scalar::BITS - start_remainder + // + // In the lowest bits of container[i] we want the highest bits of + // packed_coeffs[start_block] starting from index start_remainder + // + // In the next bits, we want the lowest bits of packed_coeffs[start_block + 1] + // left shifted to avoid overlapping + // + // container[i] = lowest_bits of (first_part|second_part) + // + // The highest bits of (first_part|second_part) will be discarded during scaling + assert_eq!(end_block_inclusive, start_block + 1); + + let first_part = self.packed_coeffs[start_block] >> start_remainder; + + let second_part = + self.packed_coeffs[start_block + 1] << (Scalar::BITS - start_remainder); + + first_part | second_part + } + }) + // Scaling + .map(|a| a << (Scalar::BITS - log_modulus)) + .collect(); + + LweCiphertextOwned::from_container(container, self.uncompressed_ciphertext_modulus) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::core_crypto::prelude::test::TestResources; + + #[test] + fn ms_compression_() { + ms_compression::(1, 100); + ms_compression::(10, 64); + ms_compression::(11, 700); + ms_compression::(12, 751); + + ms_compression::(1, 100); + ms_compression::(10, 64); + ms_compression::(11, 700); + ms_compression::(12, 751); + ms_compression::(33, 10); + ms_compression::(53, 37); + ms_compression::(63, 63); + + ms_compression::(127, 127); + } + + fn ms_compression + CastFrom>( + log_modulus: usize, + len: usize, + ) { + let mut rsc: TestResources = TestResources::new(); + + let ciphertext_modulus = CiphertextModulus::new_native(); + + let mut lwe = vec![Scalar::ZERO; len]; + + rsc.encryption_random_generator + .fill_slice_with_random_uniform_mask(&mut lwe); + + let lwe = LweCiphertextOwned::from_container(lwe, ciphertext_modulus); + + let compressed = + PackedModulusSwitchedLweCiphertext::compress(&lwe, CiphertextModulusLog(log_modulus)); + + let lwe_ms_ed: Vec = compressed.extract().into_container(); + + let lwe = lwe.into_container(); + + for (i, output) in lwe_ms_ed.into_iter().enumerate() { + assert_eq!( + output, + (output >> (Scalar::BITS - log_modulus)) << (Scalar::BITS - log_modulus), + ); + + assert_eq!( + output >> (Scalar::BITS - log_modulus), + modulus_switch(lwe[i], CiphertextModulusLog(log_modulus)) + ) + } + } +}