diff --git a/tfhe/src/core_crypto/backward_compatibility/commons/dispersion.rs b/tfhe/src/core_crypto/backward_compatibility/commons/dispersion.rs index 6982f8dee1..b4fd37a9f0 100644 --- a/tfhe/src/core_crypto/backward_compatibility/commons/dispersion.rs +++ b/tfhe/src/core_crypto/backward_compatibility/commons/dispersion.rs @@ -1,7 +1,12 @@ -use crate::core_crypto::commons::dispersion::StandardDev; +use crate::core_crypto::commons::dispersion::{StandardDev, Variance}; use tfhe_versionable::VersionsDispatch; #[derive(VersionsDispatch)] pub enum StandardDevVersions { V0(StandardDev), } + +#[derive(VersionsDispatch)] +pub enum VarianceVersions { + V0(Variance), +} diff --git a/tfhe/src/core_crypto/commons/ciphertext_modulus.rs b/tfhe/src/core_crypto/commons/ciphertext_modulus.rs index beeaeb3406..e398a1a38f 100644 --- a/tfhe/src/core_crypto/commons/ciphertext_modulus.rs +++ b/tfhe/src/core_crypto/commons/ciphertext_modulus.rs @@ -345,6 +345,13 @@ impl CiphertextModulus { } } } + + pub fn raw_modulus_float(&self) -> f64 { + match self.inner { + CiphertextModulusInner::Native => 2_f64.powi(Scalar::BITS as i32), + CiphertextModulusInner::Custom(non_zero) => non_zero.get() as f64, + } + } } impl std::fmt::Display for CiphertextModulus { diff --git a/tfhe/src/core_crypto/commons/dispersion.rs b/tfhe/src/core_crypto/commons/dispersion.rs index 0a101e07d9..d2c9157d14 100644 --- a/tfhe/src/core_crypto/commons/dispersion.rs +++ b/tfhe/src/core_crypto/commons/dispersion.rs @@ -16,7 +16,9 @@ use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; -use crate::core_crypto::backward_compatibility::commons::dispersion::StandardDevVersions; +use crate::core_crypto::backward_compatibility::commons::dispersion::{ + StandardDevVersions, VarianceVersions, +}; /// A trait for types representing distribution parameters, for a given unsigned integer type. // Warning: @@ -32,17 +34,13 @@ pub trait DispersionParameter: Copy { /// $\log\_2(\sigma)=p$ fn get_log_standard_dev(&self) -> LogStandardDev; /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$. - fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev; + fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev; /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$. - fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance; + fn get_modular_variance(&self, modulus: f64) -> ModularVariance; /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$. - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev; -} - -fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 { - 2.0f64.powi(log2_modulus as i32) + fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev; } /// A distribution parameter that uses the base-2 logarithm of the standard deviation as @@ -57,12 +55,15 @@ fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 { /// assert_eq!(params.get_log_standard_dev().0, -25.); /// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2)); /// assert_eq!( -/// params.get_modular_standard_dev(32).value, +/// params.get_modular_standard_dev(2_f64.powi(32)).value, /// 2_f64.powf(32. - 25.) /// ); -/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.); /// assert_eq!( -/// params.get_modular_variance(32).value, +/// params.get_modular_log_standard_dev(2_f64.powi(32)).value, +/// 32. - 25. +/// ); +/// assert_eq!( +/// params.get_modular_variance(2_f64.powi(32)).value, /// 2_f64.powf(32. - 25.).powi(2) /// ); /// @@ -98,22 +99,24 @@ impl DispersionParameter for LogStandardDev { fn get_log_standard_dev(&self) -> Self { Self(self.0) } - fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev { + fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev { ModularStandardDev { - value: f64::powf(2., log2_modulus as f64 + self.0), - modulus: log2_modulus_to_modulus(log2_modulus), + value: 2_f64.powf(self.0) * modulus, + modulus, } } - fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance { + fn get_modular_variance(&self, modulus: f64) -> ModularVariance { + let std_dev = 2_f64.powf(self.0) * modulus; + ModularVariance { - value: f64::powf(2., (log2_modulus as f64 + self.0) * 2.), - modulus: log2_modulus_to_modulus(log2_modulus), + value: std_dev * std_dev, + modulus, } } - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev { + fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev { ModularLogStandardDev { - value: log2_modulus as f64 + self.0, - modulus: log2_modulus_to_modulus(log2_modulus), + value: modulus.log2() + self.0, + modulus, } } } @@ -129,12 +132,15 @@ impl DispersionParameter for LogStandardDev { /// assert_eq!(params.get_log_standard_dev().0, -25.); /// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2)); /// assert_eq!( -/// params.get_modular_standard_dev(32).value, +/// params.get_modular_standard_dev(2_f64.powi(32)).value, /// 2_f64.powf(32. - 25.) /// ); -/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.); /// assert_eq!( -/// params.get_modular_variance(32).value, +/// params.get_modular_log_standard_dev(2_f64.powi(32)).value, +/// 32. - 25. +/// ); +/// assert_eq!( +/// params.get_modular_variance(2_f64.powi(32)).value, /// 2_f64.powf(32. - 25.).powi(2) /// ); /// ``` @@ -168,22 +174,24 @@ impl DispersionParameter for StandardDev { fn get_log_standard_dev(&self) -> LogStandardDev { LogStandardDev(self.0.log2()) } - fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev { + fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev { ModularStandardDev { - value: 2_f64.powf(log2_modulus as f64 + self.0.log2()), - modulus: log2_modulus_to_modulus(log2_modulus), + value: self.0 * modulus, + modulus, } } - fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance { + fn get_modular_variance(&self, modulus: f64) -> ModularVariance { + let std_dev = self.0 * modulus; + ModularVariance { - value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())), - modulus: log2_modulus_to_modulus(log2_modulus), + value: std_dev * std_dev, + modulus, } } - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev { + fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev { ModularLogStandardDev { - value: log2_modulus as f64 + self.0.log2(), - modulus: log2_modulus_to_modulus(log2_modulus), + value: modulus.log2() + self.0.log2(), + modulus, } } } @@ -199,16 +207,20 @@ impl DispersionParameter for StandardDev { /// assert_eq!(params.get_log_standard_dev().0, -25.); /// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2)); /// assert_eq!( -/// params.get_modular_standard_dev(32).value, +/// params.get_modular_standard_dev(2_f64.powi(32)).value, /// 2_f64.powf(32. - 25.) /// ); -/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.); /// assert_eq!( -/// params.get_modular_variance(32).value, +/// params.get_modular_log_standard_dev(2_f64.powi(32)).value, +/// 32. - 25. +/// ); +/// assert_eq!( +/// params.get_modular_variance(2_f64.powi(32)).value, /// 2_f64.powf(32. - 25.).powi(2) /// ); /// ``` -#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)] +#[versionize(VarianceVersions)] pub struct Variance(pub f64); #[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] @@ -222,8 +234,8 @@ impl Variance { Self(var) } - pub fn from_modular_variance(var: f64, log2_modulus: u32) -> Self { - Self(var / 2_f64.powf(log2_modulus as f64 * 2.)) + pub fn from_modular_variance(var: f64, modulus: f64) -> Self { + Self(var / (modulus * modulus)) } } @@ -237,22 +249,22 @@ impl DispersionParameter for Variance { fn get_log_standard_dev(&self) -> LogStandardDev { LogStandardDev(self.0.sqrt().log2()) } - fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev { + fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev { ModularStandardDev { - value: 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()), - modulus: log2_modulus_to_modulus(log2_modulus), + value: self.0.sqrt() * modulus, + modulus, } } - fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance { + fn get_modular_variance(&self, modulus: f64) -> ModularVariance { ModularVariance { - value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())), - modulus: log2_modulus_to_modulus(log2_modulus), + value: self.0 * modulus * modulus, + modulus, } } - fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev { + fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev { ModularLogStandardDev { - value: log2_modulus as f64 + self.0.sqrt().log2(), - modulus: log2_modulus_to_modulus(log2_modulus), + value: modulus.log2() + self.0.sqrt().log2(), + modulus, } } }