Skip to content

Commit

Permalink
refactor(core): simplify fast_pbs_modulus_switch
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Mar 13, 2024
1 parent 8e19bd1 commit 3b35cc8
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::common::fast_pbs_modulus_switch;
use crate::core_crypto::fft_impl::common::pbs_modulus_switch;
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
add_external_product_assign, add_external_product_assign_scratch, update_with_fmadd_factor,
};
Expand Down Expand Up @@ -61,12 +61,7 @@ pub fn prepare_multi_bit_ggsw_mem_optimized<
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}

let switched_degree = fast_pbs_modulus_switch(
monomial_degree,
polynomial_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let switched_degree = pbs_modulus_switch(monomial_degree, polynomial_size);

let factor = fft.incomplete_monomial_forward_as_integer(
fourier_a_monomial.as_mut_view(),
Expand Down Expand Up @@ -372,12 +367,7 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
let work_queue = Mutex::new(work_queue);

let lut_poly_size = accumulator.polynomial_size();
let monomial_degree = fast_pbs_modulus_switch(
*lwe_body.data,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let monomial_degree = pbs_modulus_switch(*lwe_body.data, lut_poly_size);

// Modulus switching
accumulator
Expand Down Expand Up @@ -621,12 +611,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, InputCont, OutputCont
let work_queue = &work_queue;

let lut_poly_size = accumulator.polynomial_size();
let monomial_degree = fast_pbs_modulus_switch(
*lwe_body.data,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let monomial_degree = pbs_modulus_switch(*lwe_body.data, lut_poly_size);

// Modulus switching
accumulator
Expand Down Expand Up @@ -704,12 +689,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, InputCont, OutputCont
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}

let switched_degree = fast_pbs_modulus_switch(
monomial_degree,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let switched_degree = pbs_modulus_switch(monomial_degree, lut_poly_size);

let factor = fft.incomplete_monomial_forward_as_integer(
fourier_a_monomial.as_mut_view(),
Expand Down Expand Up @@ -1265,12 +1245,7 @@ pub fn std_prepare_multi_bit_ggsw<Scalar, GgswBufferCont, TmpGgswBufferCont, Ggs
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}

let switched_degree = fast_pbs_modulus_switch(
monomial_degree,
polynomial_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let switched_degree = pbs_modulus_switch(monomial_degree, polynomial_size);

tmp_ggsw_buffer
.as_mut_polynomial_list()
Expand Down Expand Up @@ -1371,12 +1346,7 @@ pub fn std_multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>
let work_queue = Mutex::new(work_queue);

let lut_poly_size = accumulator.polynomial_size();
let monomial_degree = fast_pbs_modulus_switch(
*lwe_body.data,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let monomial_degree = pbs_modulus_switch(*lwe_body.data, lut_poly_size);

// Modulus switching
accumulator
Expand Down Expand Up @@ -1655,12 +1625,7 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, InputCont, Output
let work_queue = &work_queue;

let lut_poly_size = accumulator.polynomial_size();
let monomial_degree = fast_pbs_modulus_switch(
*lwe_body.data,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let monomial_degree = pbs_modulus_switch(*lwe_body.data, lut_poly_size);

// Modulus switching
accumulator
Expand Down
49 changes: 16 additions & 33 deletions tfhe/src/core_crypto/fft_impl/common.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,28 @@
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
ModulusSwitchOffset, PolynomialSize,
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize,
};
use crate::core_crypto::commons::traits::Container;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::ContainerMut;
use crate::core_crypto::prelude::{CastInto, CiphertextModulusLog, ContainerMut};
use dyn_stack::{PodStack, SizeOverflow, StackReq};

/// This function switches modulus for a single coefficient of a ciphertext,
/// only in the context of a PBS
///
/// - offset: the number of msb discarded
/// - lut_count_log: the right padding
///
/// # Note
///
/// If you are switching to a modulus of $2N$ then this function may return the value $2N$ while a
/// "true" modulus switch would return $0$ in that case. It turns out that this is not affecting
/// other parts of the code relying on the modulus switch (as a rotation by $2N$ is effectively the
/// same as rotation by $0$ for polynomials of size $N$ in the ring $X^N+1$) but it could be
/// problematic for code requiring an output in the expected $[0; 2N[$ range. Also this saves a few
/// instructions which can add up when this is being called hundreds or thousands of times per PBS.
pub fn fast_pbs_modulus_switch<Scalar: UnsignedTorus + CastInto<usize>>(
pub fn pbs_modulus_switch<Scalar: UnsignedTorus + CastInto<usize>>(
input: Scalar,
poly_size: PolynomialSize,
offset: ModulusSwitchOffset,
lut_count_log: LutCountLog,
polynomial_size: PolynomialSize,
) -> usize {
// First, do the left shift (we discard the offset msb)
let mut output = input << offset.0;
// Start doing the right shift
output >>= Scalar::BITS - poly_size.log2().0 - 2 + lut_count_log.0;
// Do the rounding
output += Scalar::ONE;
// Finish the right shift
output >>= 1;
// Apply the lsb padding
output <<= lut_count_log.0;
<Scalar as CastInto<usize>>::cast_into(output)
modulus_switch(input, CiphertextModulusLog(polynomial_size.log2().0 + 1)).cast_into()
}

pub fn modulus_switch<Scalar: UnsignedTorus>(
input: Scalar,
log_modulus: CiphertextModulusLog,
) -> Scalar {
// Flooring output_to_floor is equivalent to rounding the input
let output_to_floor = input.wrapping_add(Scalar::ONE << (Scalar::BITS - log_modulus.0 - 1));

output_to_floor >> (Scalar::BITS - log_modulus.0)
}

pub trait FourierBootstrapKey<Scalar: UnsignedInteger> {
Expand Down
20 changes: 5 additions & 15 deletions tfhe/src/core_crypto/fft_impl/fft128/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
ModulusSwitchOffset, MonomialDegree, PolynomialSize,
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, MonomialDegree,
PolynomialSize,
};
use crate::core_crypto::commons::traits::{
Container, ContiguousEntityContainer, ContiguousEntityContainerMut, Split,
};
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::common::{fast_pbs_modulus_switch, FourierBootstrapKey};
use crate::core_crypto::fft_impl::common::{pbs_modulus_switch, FourierBootstrapKey};
use crate::core_crypto::prelude::ContainerMut;
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
use core::any::TypeId;
Expand Down Expand Up @@ -266,12 +266,7 @@ where
let lut_poly_size = lut.polynomial_size();
let ciphertext_modulus = lut.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let monomial_degree = fast_pbs_modulus_switch(
*lwe_body,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let monomial_degree = pbs_modulus_switch(*lwe_body, lut_poly_size);

lut.as_mut_polynomial_list()
.iter_mut()
Expand Down Expand Up @@ -303,12 +298,7 @@ where
for mut poly in ct1.as_mut_polynomial_list().iter_mut() {
polynomial_wrapping_monic_monomial_mul_assign(
&mut poly,
MonomialDegree(fast_pbs_modulus_switch(
*lwe_mask_element,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
)),
MonomialDegree(pbs_modulus_switch(*lwe_mask_element, lut_poly_size)),
);
}

Expand Down
19 changes: 4 additions & 15 deletions tfhe/src/core_crypto/fft_impl/fft128_u128/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ use super::ggsw::cmux_split;
use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext;
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, LutCountLog,
ModulusSwitchOffset, MonomialDegree,
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, MonomialDegree,
};
use crate::core_crypto::commons::traits::ContiguousEntityContainerMut;
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::common::fast_pbs_modulus_switch;
use crate::core_crypto::fft_impl::common::pbs_modulus_switch;
use crate::core_crypto::prelude::{Container, ContainerMut};
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{PodStack, ReborrowMut};
Expand Down Expand Up @@ -83,12 +82,7 @@ where
let (lwe_body, lwe_mask) = lwe.split_last().unwrap();

let lut_poly_size = lut_lo.polynomial_size();
let monomial_degree = fast_pbs_modulus_switch(
*lwe_body,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
let monomial_degree = pbs_modulus_switch(*lwe_body, lut_poly_size);

for (poly_lo, poly_hi) in izip!(
lut_lo.as_mut_polynomial_list().iter_mut(),
Expand Down Expand Up @@ -134,12 +128,7 @@ where
polynomial_wrapping_monic_monomial_mul_assign_split(
poly_lo,
poly_hi,
MonomialDegree(fast_pbs_modulus_switch(
*lwe_mask_element,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
)),
MonomialDegree(pbs_modulus_switch(*lwe_mask_element, lut_poly_size)),
);
}

Expand Down
21 changes: 6 additions & 15 deletions tfhe/src/core_crypto/fft_impl/fft64/crypto/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension,
ModulusSwitchOffset, MonomialDegree, PolynomialSize,
DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, MonomialDegree,
PolynomialSize,
};
use crate::core_crypto::commons::traits::{
Container, ContiguousEntityContainer, ContiguousEntityContainerMut, IntoContainerOwned, Split,
};
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::common::{fast_pbs_modulus_switch, FourierBootstrapKey};
use crate::core_crypto::fft_impl::common::{pbs_modulus_switch, FourierBootstrapKey};
use crate::core_crypto::fft_impl::fft64::math::fft::par_convert_polynomials_list_to_fourier;
use crate::core_crypto::prelude::ContainerMut;
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
Expand Down Expand Up @@ -251,12 +251,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
let lut_poly_size = lut.polynomial_size();
let ciphertext_modulus = lut.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let monomial_degree = MonomialDegree(fast_pbs_modulus_switch(
*lwe_body,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
));
let monomial_degree = MonomialDegree(pbs_modulus_switch(*lwe_body, lut_poly_size));

lut.as_mut_polynomial_list()
.iter_mut()
Expand All @@ -279,12 +274,8 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
for (lwe_mask_element, bootstrap_key_ggsw) in izip!(lwe_mask.iter(), self.into_ggsw_iter())
{
if *lwe_mask_element != Scalar::ZERO {
let monomial_degree = MonomialDegree(fast_pbs_modulus_switch(
*lwe_mask_element,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
));
let monomial_degree =
MonomialDegree(pbs_modulus_switch(*lwe_mask_element, lut_poly_size));

// we effectively inline the body of cmux here, merging the initial subtraction
// operation with the monic polynomial multiplication, then performing the external
Expand Down

0 comments on commit 3b35cc8

Please sign in to comment.