feat(core): add lwe ct modulus switch compression
mayeul-zama committed Mar 13, 2024
1 parent 3b35cc8 commit 865b667
Showing 4 changed files with 375 additions and 0 deletions.
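
The compression works by applying a rounding modulus switch to every coefficient of the LWE ciphertext (keeping only its top log_modulus bits) and then bit-packing the small results. A minimal standalone sketch of that rounding step, assuming a u64 scalar; the helper modulus_switch_u64 below is only illustrative and is not the library's fft_impl::common::modulus_switch:

fn modulus_switch_u64(x: u64, log_modulus: u32) -> u64 {
    // Keep the top `log_modulus` bits, rounding to nearest:
    // add half of the discarded range before truncating.
    let shift = 64 - log_modulus;
    x.wrapping_add(1u64 << (shift - 1)) >> shift
}

fn main() {
    // 2^63 encodes 1/2 on the torus; switched to a 2^12 modulus it becomes 2^11.
    assert_eq!(modulus_switch_u64(1u64 << 63, 12), 1 << 11);
}
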
1 change: 1 addition & 0 deletions tfhe/src/core_crypto/algorithms/test/mod.rs
@@ -24,6 +24,7 @@ mod lwe_packing_keyswitch;
mod lwe_packing_keyswitch_key_generation;
mod lwe_private_functional_packing_keyswitch;
pub(crate) mod lwe_programmable_bootstrapping;
mod modulus_switch_compression;
mod noise_distribution;

pub struct TestResources {
65 changes: 65 additions & 0 deletions tfhe/src/core_crypto/algorithms/test/modulus_switch_compression.rs
@@ -0,0 +1,65 @@
use super::*;
use crate::core_crypto::prelude::modulus_switched_lwe_ciphertext::PackedModulusSwitchedLweCiphertext;

#[cfg(not(tarpaulin))]
const NB_TESTS: usize = 10;
#[cfg(tarpaulin)]
const NB_TESTS: usize = 1;

fn encryption_ms_decryption<Scalar: UnsignedTorus + Sync + Send>(
params: ClassicTestParams<Scalar>,
) {
let ClassicTestParams {
lwe_noise_distribution,
message_modulus_log,
ciphertext_modulus,
..
} = params;

let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);

let mut rsc: TestResources = TestResources::new();

let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
let mut msg = msg_modulus;
let delta: Scalar = encoding_with_padding / msg_modulus;

while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
// Create the LweSecretKey
let lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key::<Scalar, _>(
params.lwe_dimension,
&mut rsc.secret_random_generator,
);

let lwe = allocate_and_encrypt_new_lwe_ciphertext(
&lwe_secret_key,
Plaintext(msg * delta),
lwe_noise_distribution,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);

// Can be stored using much less space than the standard lwe ciphertexts
let compressed = PackedModulusSwitchedLweCiphertext::compress(
&lwe,
CiphertextModulusLog(params.polynomial_size.log2().0 + 1),
);

let lwe_ms_ed = compressed.extract();

let decrypted = decrypt_lwe_ciphertext(&lwe_secret_key, &lwe_ms_ed);

let decoded = round_decode(decrypted.0, delta) % msg_modulus;
assert_eq!(decoded, msg);
}

// In coverage, we break after one while loop iteration, changing message values does
// not yield higher coverage
#[cfg(tarpaulin)]
break;
}
}

create_parametrized_test!(encryption_ms_decryption);
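
Note that the test compresses to CiphertextModulusLog(log2(polynomial_size) + 1), i.e. to modulus 2N: this is the modulus an LWE ciphertext is switched to before a blind rotation over a size-N polynomial, which is where the extracted ciphertext is meant to be consumed. A quick check of that arithmetic (N = 2048 is just an example value, not necessarily a parameter used by this test):

fn blind_rotation_log_modulus(polynomial_size: usize) -> usize {
    // For a power-of-two N, 2N = 2^(log2(N) + 1)
    polynomial_size.ilog2() as usize + 1
}

fn main() {
    // With N = 2048, the ciphertext is switched to modulus 2 * 2048 = 4096 = 2^12.
    assert_eq!(blind_rotation_log_modulus(2048), 12);
}
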
1 change: 1 addition & 0 deletions tfhe/src/core_crypto/entities/mod.rs
@@ -22,6 +22,7 @@ pub mod lwe_private_functional_packing_keyswitch_key;
pub mod lwe_private_functional_packing_keyswitch_key_list;
pub mod lwe_public_key;
pub mod lwe_secret_key;
pub mod modulus_switched_lwe_ciphertext;
pub mod plaintext;
pub mod plaintext_list;
pub mod polynomial;
308 changes: 308 additions & 0 deletions tfhe/src/core_crypto/entities/modulus_switched_lwe_ciphertext.rs
@@ -0,0 +1,308 @@
use crate::core_crypto::fft_impl::common::modulus_switch;
use crate::core_crypto::prelude::*;

/// An object to store an LWE ciphertext using little memory.
/// The modulus of the ciphertext is decreased by rounding and the result is stored compactly.
/// The uncompacted result can be used as the input of a blind rotation to recover a low-noise LWE
/// ciphertext.
///
/// ```
/// use concrete_csprng::seeders::Seed;
/// use tfhe::core_crypto::prelude::*;
/// use tfhe::core_crypto::fft_impl::common::modulus_switch;
/// use tfhe::core_crypto::prelude::modulus_switched_lwe_ciphertext::PackedModulusSwitchedLweCiphertext;
///
/// let log_modulus = 12;
///
/// let mut secret_generator = SecretRandomGenerator::<ActivatedRandomGenerator>::new(Seed(0));
///
/// // Create the LweSecretKey
/// let lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key::<u64, _>(LweDimension(2048), &mut secret_generator);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
///
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
///
///
/// // Insecure parameters, do not use them
/// let lwe = allocate_and_encrypt_new_lwe_ciphertext(
/// &lwe_secret_key,
/// Plaintext(0),
/// Gaussian::from_standard_dev(StandardDev(0.), 0.),
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// // Can be stored using much less space than the standard lwe ciphertexts
/// let compressed = PackedModulusSwitchedLweCiphertext::compress(
/// &lwe,
/// CiphertextModulusLog(log_modulus as usize),
/// );
///
/// let lwe_ms_ed = compressed.extract();
///
/// assert_eq!(
/// modulus_switch(
/// decrypt_lwe_ciphertext(&lwe_secret_key, &lwe_ms_ed).0,
/// CiphertextModulusLog(5)
/// ),
/// 0
/// );
/// ```
pub struct PackedModulusSwitchedLweCiphertext<Scalar: UnsignedTorus> {
packed_coeffs: Vec<Scalar>,
lwe_dimension: LweDimension,
log_modulus: CiphertextModulusLog,
uncompressed_ciphertext_modulus: CiphertextModulus<Scalar>,
}

impl<Scalar: UnsignedTorus> PackedModulusSwitchedLweCiphertext<Scalar> {
/// Compresses a ciphertext by reducing its modulus.
/// This operation adds a lot of noise.
pub fn compress<Cont: Container<Element = Scalar>>(
ct: &LweCiphertext<Cont>,
log_modulus: CiphertextModulusLog,
) -> Self {
let switch_modulus = |x| modulus_switch(x, log_modulus);

let log_modulus = log_modulus.0;

let uncompressed_ciphertext_modulus = ct.ciphertext_modulus();

assert!(
ct.ciphertext_modulus().is_power_of_two(),
"Modulus switch compression doe not support non power of 2 input moduli",
);

let uncompressed_ciphertext_modulus_log =
if uncompressed_ciphertext_modulus.is_native_modulus() {
Scalar::BITS
} else {
uncompressed_ciphertext_modulus.get_custom_modulus().ilog2() as usize
};

assert!(
log_modulus <= uncompressed_ciphertext_modulus_log,
"The log_modulus (={log_modulus}) for modulus switch compression must be smaller than the uncompressed ciphertext_modulus_log (={uncompressed_ciphertext_modulus_log})",
);

let lwe_size = ct.lwe_size().0;

let number_bits_to_pack = lwe_size * log_modulus;

let len = number_bits_to_pack.div_ceil(Scalar::BITS);

let slice = ct.as_ref();
// Lowest bits are on the right
//
// Target mapping:
// log_modulus
// |-------|
//
// slice : | k+2 | k+1 | k |
// packed_coeffs: i+1 | i | i-1
//
// |---------------|
// Scalar::BITS
//
// |---|
// start_shift
//
// |---|
// shift1
// (1st loop iteration)
//
// |-----------|
// shift2
// (2nd loop iteration)
//
// packed_coeffs[i] =
// slice[k] >> start_shift
// | slice[k+1] << shift1
// | slice[k+2] << shift2
//
// In the lowest bits of packed_coeffs[i], we want the highest bits of slice[k],
// hence the right shift
// The next bits should be the bits of slice[k+1] which we must left shifted to avoid
// overlapping
// This goes on
let packed_coeffs = (0..len)
.map(|i| {
let k = Scalar::BITS * i / log_modulus;
let mut j = k;

let start_shift = i * Scalar::BITS - j * log_modulus;

let mut value = switch_modulus(slice[j]) >> start_shift;
j += 1;

while j * log_modulus < ((i + 1) * Scalar::BITS) && j < lwe_size {
let shift = j * log_modulus - i * Scalar::BITS;

value |= switch_modulus(slice[j]) << shift;

j += 1;
}
value
})
.collect();

let log_modulus = CiphertextModulusLog(log_modulus);

Self {
packed_coeffs,
lwe_dimension: ct.lwe_size().to_lwe_dimension(),
log_modulus,
uncompressed_ciphertext_modulus,
}
}

/// Converts a compressed ciphertext back to its initial modulus.
/// The noise added during the compression stays in the output.
/// The output must go through a PBS to reduce the noise.
pub fn extract(&self) -> LweCiphertextOwned<Scalar> {
let log_modulus = self.log_modulus.0;

let container = (0..(self.lwe_dimension.to_lwe_size().0))
.map(|i| {
let start = i * log_modulus;
let end = (i + 1) * log_modulus;

let start_block = start / Scalar::BITS;
let start_remainder = start % Scalar::BITS;

let end_block_inclusive = (end - 1) / Scalar::BITS;

if start_block == end_block_inclusive {
// Lowest bits are on the right
//
// Target mapping:
// Scalar::BITS
// |---------------|
//
// packed_coeffs: | start_block+1 | start_block |
// container : | i+1 | i | i-1 |
//
// |-------|
// log_modulus
//
// |---|
// start_remainder
//
// In container[i] we want the bits of packed_coeffs[start_block] starting from
// index start_remainder
//
// container[i] = lowest_bits of single_part
//
// The highest bits of single_part will be discarded during scaling
//
// single_part =
self.packed_coeffs[start_block] >> start_remainder
} else {
// Lowest bits are on the right
//
// Target mapping:
// Scalar::BITS
// |---------------|
//
// packed_coeffs: | start_block+1 | start_block |
// container : | i+1 | i | i-1 |
//
// |-------|
// log_modulus
//
// |-----------|
// start_remainder
//
// |---|
// Scalar::BITS - start_remainder
//
// In the lowest bits of container[i] we want the highest bits of
// packed_coeffs[start_block] starting from index start_remainder
//
// In the next bits, we want the lowest bits of packed_coeffs[start_block + 1]
// left shifted to avoid overlapping
//
// container[i] = lowest_bits of (first_part|second_part)
//
// The highest bits of (first_part|second_part) will be discarded during scaling
assert_eq!(end_block_inclusive, start_block + 1);

let first_part = self.packed_coeffs[start_block] >> start_remainder;

let second_part =
self.packed_coeffs[start_block + 1] << (Scalar::BITS - start_remainder);

first_part | second_part
}
})
// Scaling
.map(|a| a << (Scalar::BITS - log_modulus))
.collect();

LweCiphertextOwned::from_container(container, self.uncompressed_ciphertext_modulus)
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::core_crypto::prelude::test::TestResources;

#[test]
fn ms_compression_() {
ms_compression::<u32>(1, 100);
ms_compression::<u32>(10, 64);
ms_compression::<u32>(11, 700);
ms_compression::<u32>(12, 751);

ms_compression::<u64>(1, 100);
ms_compression::<u64>(10, 64);
ms_compression::<u64>(11, 700);
ms_compression::<u64>(12, 751);
ms_compression::<u64>(33, 10);
ms_compression::<u64>(53, 37);
ms_compression::<u64>(63, 63);

ms_compression::<u128>(127, 127);
}

fn ms_compression<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>>(
log_modulus: usize,
len: usize,
) {
let mut rsc: TestResources = TestResources::new();

let ciphertext_modulus = CiphertextModulus::new_native();

let mut lwe = vec![Scalar::ZERO; len];

rsc.encryption_random_generator
.fill_slice_with_random_uniform_mask(&mut lwe);

let lwe = LweCiphertextOwned::from_container(lwe, ciphertext_modulus);

let compressed =
PackedModulusSwitchedLweCiphertext::compress(&lwe, CiphertextModulusLog(log_modulus));

let lwe_ms_ed: Vec<Scalar> = compressed.extract().into_container();

let lwe = lwe.into_container();

for (i, output) in lwe_ms_ed.into_iter().enumerate() {
assert_eq!(
output,
(output >> (Scalar::BITS - log_modulus)) << (Scalar::BITS - log_modulus),
);

assert_eq!(
output >> (Scalar::BITS - log_modulus),
modulus_switch(lwe[i], CiphertextModulusLog(log_modulus))
)
}
}
}
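
For a sense of the space saving: compress packs lwe_size * log_modulus bits into number_bits_to_pack.div_ceil(Scalar::BITS) machine words. A rough back-of-the-envelope check, assuming a u64 ciphertext with lwe_size 2049 and a 2^12 target modulus (example values only, not parameters taken from this commit):

fn packed_len(lwe_size: usize, log_modulus: usize, scalar_bits: usize) -> usize {
    // Same formula as `number_bits_to_pack.div_ceil(Scalar::BITS)` in `compress`
    (lwe_size * log_modulus).div_ceil(scalar_bits)
}

fn main() {
    // 2049 * 12 = 24588 bits fit in 385 u64 words instead of the original 2049.
    assert_eq!(packed_len(2049, 12, 64), 385);
}
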
