Skip to content

Commit

Permalink
chore(core): refactor DispersionParameter
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Feb 13, 2025
1 parent eeb6c8a commit 4305f8d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 48 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use crate::core_crypto::commons::dispersion::StandardDev;
use crate::core_crypto::commons::dispersion::{StandardDev, Variance};
use tfhe_versionable::VersionsDispatch;

#[derive(VersionsDispatch)]
pub enum StandardDevVersions {
V0(StandardDev),
}

#[derive(VersionsDispatch)]
pub enum VarianceVersions {
V0(Variance),
}
7 changes: 7 additions & 0 deletions tfhe/src/core_crypto/commons/ciphertext_modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
}
}
}

pub fn raw_modulus_float(&self) -> f64 {
match self.inner {
CiphertextModulusInner::Native => 2_f64.powi(Scalar::BITS as i32),
CiphertextModulusInner::Custom(non_zero) => non_zero.get() as f64,
}
}
}

impl<Scalar: UnsignedInteger> std::fmt::Display for CiphertextModulus<Scalar> {
Expand Down
106 changes: 59 additions & 47 deletions tfhe/src/core_crypto/commons/dispersion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

use crate::core_crypto::backward_compatibility::commons::dispersion::StandardDevVersions;
use crate::core_crypto::backward_compatibility::commons::dispersion::{
StandardDevVersions, VarianceVersions,
};

/// A trait for types representing distribution parameters, for a given unsigned integer type.
// Warning:
Expand All @@ -32,17 +34,13 @@ pub trait DispersionParameter: Copy {
/// $\log\_2(\sigma)=p$
fn get_log_standard_dev(&self) -> LogStandardDev;
/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$.
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev;
fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev;

/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$.
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance;
fn get_modular_variance(&self, modulus: f64) -> ModularVariance;

/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$.
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev;
}

fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 {
2.0f64.powi(log2_modulus as i32)
fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev;
}

/// A distribution parameter that uses the base-2 logarithm of the standard deviation as
Expand All @@ -57,12 +55,15 @@ fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 {
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// params.get_modular_standard_dev(2_f64.powi(32)).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32).value,
/// params.get_modular_log_standard_dev(2_f64.powi(32)).value,
/// 32. - 25.
/// );
/// assert_eq!(
/// params.get_modular_variance(2_f64.powi(32)).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
///
Expand Down Expand Up @@ -98,22 +99,24 @@ impl DispersionParameter for LogStandardDev {
fn get_log_standard_dev(&self) -> Self {
Self(self.0)
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
ModularStandardDev {
value: f64::powf(2., log2_modulus as f64 + self.0),
modulus: log2_modulus_to_modulus(log2_modulus),
value: 2_f64.powf(self.0) * modulus,
modulus,
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
let std_dev = 2_f64.powf(self.0) * modulus;

ModularVariance {
value: f64::powf(2., (log2_modulus as f64 + self.0) * 2.),
modulus: log2_modulus_to_modulus(log2_modulus),
value: std_dev * std_dev,
modulus,
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0,
modulus: log2_modulus_to_modulus(log2_modulus),
value: modulus.log2() + self.0,
modulus,
}
}
}
Expand All @@ -129,12 +132,15 @@ impl DispersionParameter for LogStandardDev {
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// params.get_modular_standard_dev(2_f64.powi(32)).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32).value,
/// params.get_modular_log_standard_dev(2_f64.powi(32)).value,
/// 32. - 25.
/// );
/// assert_eq!(
/// params.get_modular_variance(2_f64.powi(32)).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
Expand Down Expand Up @@ -168,22 +174,24 @@ impl DispersionParameter for StandardDev {
fn get_log_standard_dev(&self) -> LogStandardDev {
LogStandardDev(self.0.log2())
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
ModularStandardDev {
value: 2_f64.powf(log2_modulus as f64 + self.0.log2()),
modulus: log2_modulus_to_modulus(log2_modulus),
value: self.0 * modulus,
modulus,
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
let std_dev = self.0 * modulus;

ModularVariance {
value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())),
modulus: log2_modulus_to_modulus(log2_modulus),
value: std_dev * std_dev,
modulus,
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0.log2(),
modulus: log2_modulus_to_modulus(log2_modulus),
value: modulus.log2() + self.0.log2(),
modulus,
}
}
}
Expand All @@ -199,16 +207,20 @@ impl DispersionParameter for StandardDev {
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// params.get_modular_standard_dev(2_f64.powi(32)).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32).value,
/// params.get_modular_log_standard_dev(2_f64.powi(32)).value,
/// 32. - 25.
/// );
/// assert_eq!(
/// params.get_modular_variance(2_f64.powi(32)).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
#[versionize(VarianceVersions)]
pub struct Variance(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
Expand All @@ -222,8 +234,8 @@ impl Variance {
Self(var)
}

pub fn from_modular_variance(var: f64, log2_modulus: u32) -> Self {
Self(var / 2_f64.powf(log2_modulus as f64 * 2.))
pub fn from_modular_variance(var: f64, modulus: f64) -> Self {
Self(var / (modulus * modulus))
}
}

Expand All @@ -237,22 +249,22 @@ impl DispersionParameter for Variance {
fn get_log_standard_dev(&self) -> LogStandardDev {
LogStandardDev(self.0.sqrt().log2())
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
fn get_modular_standard_dev(&self, modulus: f64) -> ModularStandardDev {
ModularStandardDev {
value: 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()),
modulus: log2_modulus_to_modulus(log2_modulus),
value: self.0.sqrt() * modulus,
modulus,
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
fn get_modular_variance(&self, modulus: f64) -> ModularVariance {
ModularVariance {
value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())),
modulus: log2_modulus_to_modulus(log2_modulus),
value: self.0 * modulus * modulus,
modulus,
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
fn get_modular_log_standard_dev(&self, modulus: f64) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0.sqrt().log2(),
modulus: log2_modulus_to_modulus(log2_modulus),
value: modulus.log2() + self.0.sqrt().log2(),
modulus,
}
}
}

0 comments on commit 4305f8d

Please sign in to comment.