Skip to content

Commit

Permalink
feat(core): allow switching moduli during an LWE Keyswitch
Browse files Browse the repository at this point in the history
  • Loading branch information
jborfila authored and IceTDrinker committed Feb 20, 2024
1 parent e62808b commit b708abb
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 31 deletions.
102 changes: 71 additions & 31 deletions tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::core_crypto::algorithms::misc::divide_ceil;
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::ThreadCount;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
};
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
Expand Down Expand Up @@ -119,24 +121,26 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert!(
lwe_keyswitch_key.ciphertext_modulus() == input_lwe_ciphertext.ciphertext_modulus(),
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",

let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();

assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
input_lwe_ciphertext.ciphertext_modulus()
);
assert!(
lwe_keyswitch_key.ciphertext_modulus() == output_lwe_ciphertext.ciphertext_modulus(),
output_ciphertext_modulus,
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
output_lwe_ciphertext.ciphertext_modulus()
output_ciphertext_modulus
);
assert!(
output_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);

let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();

assert!(
lwe_keyswitch_key
.ciphertext_modulus()
.is_compatible_with_native_modulus(),
input_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);

Expand All @@ -146,6 +150,20 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
// Copy the input body to the output ciphertext
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;

// If the moduli are not the same, we need to round the body in the output ciphertext
if output_ciphertext_modulus != input_ciphertext_modulus
&& !output_ciphertext_modulus.is_native_modulus()
{
let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
let output_decomposer = SignedDecomposer::new(
DecompositionBaseLog(modulus_bits),
DecompositionLevelCount(1),
);

*output_lwe_ciphertext.get_mut_body().data =
output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
}

// We instantiate a decomposer
let decomposer = SignedDecomposer::new(
lwe_keyswitch_key.decomposition_base_log(),
Expand Down Expand Up @@ -385,24 +403,26 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert!(
lwe_keyswitch_key.ciphertext_modulus() == input_lwe_ciphertext.ciphertext_modulus(),
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",

let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();

assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
input_lwe_ciphertext.ciphertext_modulus()
);
assert!(
lwe_keyswitch_key.ciphertext_modulus() == output_lwe_ciphertext.ciphertext_modulus(),
output_ciphertext_modulus,
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
output_lwe_ciphertext.ciphertext_modulus()
output_ciphertext_modulus
);
assert!(
lwe_keyswitch_key
.ciphertext_modulus()
.is_compatible_with_native_modulus(),
output_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);

let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();

assert!(
input_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);

Expand All @@ -411,6 +431,28 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont
"Got thread_count == 0, this is not supported"
);

// Clear the output ciphertext, as it will get updated gradually
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);

let output_lwe_size = output_lwe_ciphertext.lwe_size();

// Copy the input body to the output ciphertext
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;

// If the moduli are not the same, we need to round the body in the output ciphertext
if output_ciphertext_modulus != input_ciphertext_modulus
&& !output_ciphertext_modulus.is_native_modulus()
{
let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
let output_decomposer = SignedDecomposer::new(
DecompositionBaseLog(modulus_bits),
DecompositionLevelCount(1),
);

*output_lwe_ciphertext.get_mut_body().data =
output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
}

// We instantiate a decomposer
let decomposer = SignedDecomposer::new(
lwe_keyswitch_key.decomposition_base_log(),
Expand All @@ -421,9 +463,6 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont
let thread_count = thread_count.0.min(rayon::current_num_threads());
let mut intermediate_accumulators = Vec::with_capacity(thread_count);

let output_lwe_size = output_lwe_ciphertext.lwe_size();
let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();

// Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size
let chunk_size = divide_ceil(input_lwe_ciphertext.lwe_size().0, thread_count);

Expand Down Expand Up @@ -475,8 +514,9 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont
.get_mut_mask()
.as_mut()
.copy_from_slice(reduced.get_mask().as_ref());
let input_lwe_body = *input_lwe_ciphertext.get_body().data;
let reduced_ksed_body = *reduced.get_body().data;
// Copy the input body to the output ciphertext
*output_lwe_ciphertext.get_mut_body().data = input_lwe_body.wrapping_add(reduced_ksed_body);

// Add the reduced body of the keyswitch to the output body to complete the keyswitch
*output_lwe_ciphertext.get_mut_body().data =
(*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body);
}
109 changes: 109 additions & 0 deletions tfhe/src/core_crypto/algorithms/test/lwe_keyswitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,112 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus + Send + Sync>(
}

create_parametrized_test!(lwe_encrypt_ks_decrypt_custom_mod);

#[test]
fn test_lwe_encrypt_ks_switch_mod_decrypt_custom_mod() {
let params = super::TEST_PARAMS_4_BITS_NATIVE_U64;

let lwe_dimension = params.lwe_dimension;
let lwe_modular_std_dev = params.lwe_modular_std_dev;
let input_ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let input_encoding_with_padding = get_encoding_with_padding(input_ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;
let ks_decomp_base_log = params.ks_base_log;
let ks_decomp_level_count = params.ks_level;

let output_ciphertext_modulus = CiphertextModulus::<u64>::try_new_power_of_2(32).unwrap();
let output_encoding_with_padding = get_encoding_with_padding(output_ciphertext_modulus);

// Try to have a 32 bits modulus for the output
assert!(ks_decomp_base_log.0 * ks_decomp_level_count.0 <= 32);

let mut rsc = TestResources::new();

let msg_modulus = 1u64 << message_modulus_log.0;
let mut msg = msg_modulus;
let input_delta = input_encoding_with_padding / msg_modulus;
let output_delta = output_encoding_with_padding / msg_modulus;

while msg != 0 {
msg -= 1;
for _ in 0..NB_TESTS {
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut rsc.secret_random_generator,
);

let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut rsc.secret_random_generator,
);

let big_lwe_sk = glwe_sk.into_lwe_secret_key();

let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key(
&big_lwe_sk,
&lwe_sk,
ks_decomp_base_log,
ks_decomp_level_count,
lwe_modular_std_dev,
output_ciphertext_modulus,
&mut rsc.encryption_random_generator,
);

assert!(check_encrypted_content_respects_mod(
&ksk_big_to_small,
output_ciphertext_modulus
));

let plaintext = Plaintext(msg * input_delta);

let ct = allocate_and_encrypt_new_lwe_ciphertext(
&big_lwe_sk,
plaintext,
lwe_modular_std_dev,
input_ciphertext_modulus,
&mut rsc.encryption_random_generator,
);

assert!(check_encrypted_content_respects_mod(
&ct,
input_ciphertext_modulus
));

let mut output_ct = LweCiphertext::new(
0u64,
lwe_sk.lwe_dimension().to_lwe_size(),
output_ciphertext_modulus,
);

let mut output_ct_parallel = LweCiphertext::new(
0u64,
lwe_sk.lwe_dimension().to_lwe_size(),
output_ciphertext_modulus,
);

keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct);

assert!(check_encrypted_content_respects_mod(
&output_ct,
output_ciphertext_modulus
));

par_keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct_parallel);
assert_eq!(output_ct.as_ref(), output_ct_parallel.as_ref());

let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);

let decoded = round_decode(decrypted.0, output_delta) % msg_modulus;

assert_eq!(msg, decoded);
}

// In coverage, we break after one while loop iteration, changing message values does not
// yield higher coverage
#[cfg(feature = "__coverage")]
break;
}
}

0 comments on commit b708abb

Please sign in to comment.