diff --git a/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs b/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs index 1075ff18f6..402b404e9c 100644 --- a/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs +++ b/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs @@ -6,7 +6,7 @@ use crate::core_crypto::commons::traits::{Container, ContainerMut, UnsignedInteg use crate::core_crypto::entities::{LweCiphertext, LweCiphertextList}; use crate::core_crypto::fft_impl::common::modulus_switch; use crate::core_crypto::prelude::{ - CiphertextModulusLog, ContiguousEntityContainer, DispersionParameter, + CiphertextModulus, CiphertextModulusLog, ContiguousEntityContainer, DispersionParameter, }; use itertools::Itertools; @@ -128,6 +128,12 @@ where 0, "Expected at least one encryption of zero" ); + assert_eq!( + lwe.ciphertext_modulus(), + CiphertextModulus::new_native(), + "Non native modulus are not supported, got {}", + lwe.ciphertext_modulus(), + ); let modulus = lwe.ciphertext_modulus().raw_modulus_float(); diff --git a/tfhe/src/core_crypto/algorithms/test/modulus_switch_noise_reduction.rs b/tfhe/src/core_crypto/algorithms/test/modulus_switch_noise_reduction.rs index ecc9ae89eb..c71767800d 100644 --- a/tfhe/src/core_crypto/algorithms/test/modulus_switch_noise_reduction.rs +++ b/tfhe/src/core_crypto/algorithms/test/modulus_switch_noise_reduction.rs @@ -466,12 +466,14 @@ fn check_noise_improve_modulus_switch_noise( variance_improved / base_variance, ); + let modulus = ciphertext_modulus.raw_modulus_float(); + let expected_base_variance = { let lwe_dim = lwe_dimension.0 as f64; let poly_size = 2_f64.powi((log_modulus.0 - 1) as i32); - (lwe_dim + 2.) * 2_f64.powi(128) / (96. * poly_size * poly_size) + (lwe_dim - 4.) / 48. + (lwe_dim + 2.) * modulus * modulus / (96. * poly_size * poly_size) + (lwe_dim - 4.) / 48. }; assert!( @@ -480,7 +482,7 @@ fn check_noise_improve_modulus_switch_noise( ); let expected_variance_improved = Variance(expected_variance_improved.0 - input_variance.0) - .get_modular_variance(2_f64.powi(64)) + .get_modular_variance(modulus) .value; assert!(