diff --git a/tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs b/tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs index 5ee3fbf510..39c3df618b 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_keyswitch.rs @@ -5,7 +5,9 @@ use crate::core_crypto::algorithms::misc::divide_ceil; use crate::core_crypto::algorithms::slice_algorithms::*; use crate::core_crypto::commons::math::decomposition::SignedDecomposer; use crate::core_crypto::commons::numeric::UnsignedInteger; -use crate::core_crypto::commons::parameters::ThreadCount; +use crate::core_crypto::commons::parameters::{ + DecompositionBaseLog, DecompositionLevelCount, ThreadCount, +}; use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use rayon::prelude::*; @@ -119,24 +121,26 @@ pub fn keyswitch_lwe_ciphertext( lwe_keyswitch_key.output_key_lwe_dimension(), output_lwe_ciphertext.lwe_size().to_lwe_dimension(), ); - assert!( - lwe_keyswitch_key.ciphertext_modulus() == input_lwe_ciphertext.ciphertext_modulus(), - "Mismatched CiphertextModulus. \ - LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.", + + let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus(); + + assert_eq!( lwe_keyswitch_key.ciphertext_modulus(), - input_lwe_ciphertext.ciphertext_modulus() - ); - assert!( - lwe_keyswitch_key.ciphertext_modulus() == output_lwe_ciphertext.ciphertext_modulus(), + output_ciphertext_modulus, "Mismatched CiphertextModulus. \ LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.", lwe_keyswitch_key.ciphertext_modulus(), - output_lwe_ciphertext.ciphertext_modulus() + output_ciphertext_modulus + ); + assert!( + output_ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports power of 2 moduli" ); + + let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus(); + assert!( - lwe_keyswitch_key - .ciphertext_modulus() - .is_compatible_with_native_modulus(), + input_ciphertext_modulus.is_compatible_with_native_modulus(), "This operation currently only supports power of 2 moduli" ); @@ -146,6 +150,20 @@ pub fn keyswitch_lwe_ciphertext( // Copy the input body to the output ciphertext *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data; + // If the moduli are not the same, we need to round the body in the output ciphertext + if output_ciphertext_modulus != input_ciphertext_modulus + && !output_ciphertext_modulus.is_native_modulus() + { + let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize; + let output_decomposer = SignedDecomposer::new( + DecompositionBaseLog(modulus_bits), + DecompositionLevelCount(1), + ); + + *output_lwe_ciphertext.get_mut_body().data = + output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data); + } + // We instantiate a decomposer let decomposer = SignedDecomposer::new( lwe_keyswitch_key.decomposition_base_log(), @@ -385,24 +403,26 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count= input_lwe_size let chunk_size = divide_ceil(input_lwe_ciphertext.lwe_size().0, thread_count); @@ -475,8 +514,9 @@ pub fn par_keyswitch_lwe_ciphertext_with_thread_count( } create_parametrized_test!(lwe_encrypt_ks_decrypt_custom_mod); + +#[test] +fn test_lwe_encrypt_ks_switch_mod_decrypt_custom_mod() { + let params = super::TEST_PARAMS_4_BITS_NATIVE_U64; + + let lwe_dimension = params.lwe_dimension; + let lwe_modular_std_dev = params.lwe_modular_std_dev; + let input_ciphertext_modulus = params.ciphertext_modulus; + let message_modulus_log = params.message_modulus_log; + let input_encoding_with_padding = get_encoding_with_padding(input_ciphertext_modulus); + let glwe_dimension = params.glwe_dimension; + let polynomial_size = params.polynomial_size; + let ks_decomp_base_log = params.ks_base_log; + let ks_decomp_level_count = params.ks_level; + + let output_ciphertext_modulus = CiphertextModulus::::try_new_power_of_2(32).unwrap(); + let output_encoding_with_padding = get_encoding_with_padding(output_ciphertext_modulus); + + // Try to have a 32 bits modulus for the output + assert!(ks_decomp_base_log.0 * ks_decomp_level_count.0 <= 32); + + let mut rsc = TestResources::new(); + + let msg_modulus = 1u64 << message_modulus_log.0; + let mut msg = msg_modulus; + let input_delta = input_encoding_with_padding / msg_modulus; + let output_delta = output_encoding_with_padding / msg_modulus; + + while msg != 0 { + msg -= 1; + for _ in 0..NB_TESTS { + let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key( + lwe_dimension, + &mut rsc.secret_random_generator, + ); + + let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key( + glwe_dimension, + polynomial_size, + &mut rsc.secret_random_generator, + ); + + let big_lwe_sk = glwe_sk.into_lwe_secret_key(); + + let ksk_big_to_small = allocate_and_generate_new_lwe_keyswitch_key( + &big_lwe_sk, + &lwe_sk, + ks_decomp_base_log, + ks_decomp_level_count, + lwe_modular_std_dev, + output_ciphertext_modulus, + &mut rsc.encryption_random_generator, + ); + + assert!(check_encrypted_content_respects_mod( + &ksk_big_to_small, + output_ciphertext_modulus + )); + + let plaintext = Plaintext(msg * input_delta); + + let ct = allocate_and_encrypt_new_lwe_ciphertext( + &big_lwe_sk, + plaintext, + lwe_modular_std_dev, + input_ciphertext_modulus, + &mut rsc.encryption_random_generator, + ); + + assert!(check_encrypted_content_respects_mod( + &ct, + input_ciphertext_modulus + )); + + let mut output_ct = LweCiphertext::new( + 0u64, + lwe_sk.lwe_dimension().to_lwe_size(), + output_ciphertext_modulus, + ); + + let mut output_ct_parallel = LweCiphertext::new( + 0u64, + lwe_sk.lwe_dimension().to_lwe_size(), + output_ciphertext_modulus, + ); + + keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct); + + assert!(check_encrypted_content_respects_mod( + &output_ct, + output_ciphertext_modulus + )); + + par_keyswitch_lwe_ciphertext(&ksk_big_to_small, &ct, &mut output_ct_parallel); + assert_eq!(output_ct.as_ref(), output_ct_parallel.as_ref()); + + let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &output_ct); + + let decoded = round_decode(decrypted.0, output_delta) % msg_modulus; + + assert_eq!(msg, decoded); + } + + // In coverage, we break after one while loop iteration, changing message values does not + // yield higher coverage + #[cfg(feature = "__coverage")] + break; + } +}