Skip to content

Ns/feat/atomic patterns #2024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
3b54ba0
refactor(shortint): wrap PbsOrder into AtomicPattern in ciphertext
nsarlin-zama Apr 25, 2025
298c556
refactor(shortint): factorize generate_lookup_table
nsarlin-zama Jan 21, 2025
d11305c
refactor(shortint): use a single lwe buffer inside shortint engine
nsarlin-zama Jan 23, 2025
73ee949
refactor(shortint): function to directly set noise level to nominal
nsarlin-zama Jan 27, 2025
f9b3a1d
refactor(shortint): use a dedicated type for lut size
nsarlin-zama Jan 27, 2025
87954ae
refactor(shortint): remove degree in generate_lookup_table_no_encode
nsarlin-zama Mar 6, 2025
93e2aff
feat(shortint): create atomic pattern trait and enum
nsarlin-zama Apr 25, 2025
336cb22
feat(shortint): insert the AP inside the ServerKey
nsarlin-zama May 2, 2025
90f14f0
refactor(shortint): engine can create any atomic pattern sk
nsarlin-zama Mar 17, 2025
1a23dd6
feat(shortint): add the dynamic ap
nsarlin-zama Feb 10, 2025
353a56f
refactor(shortint): support any ciphertext modulus in the engine
nsarlin-zama Feb 6, 2025
f2f2622
refactor(core): allow different input/output scalars in multibit br
nsarlin-zama Mar 17, 2025
d8042e7
refactor(shortint): support any scalar in modswitch noise reduction
nsarlin-zama Mar 12, 2025
fee8641
refactor(shortint): make oprf generic over the Scalar type
nsarlin-zama Mar 17, 2025
43cb1c5
refactor(shortint): make modswitch compression generic over scalar
nsarlin-zama Mar 12, 2025
9245ed0
refactor(core): make ksk generation generic over the scalar type
nsarlin-zama Mar 19, 2025
20a792a
feat(shortint): introduce the KS32 atomic pattern
nsarlin-zama Mar 17, 2025
15b80ad
chore(tests): add support for AP in tests and benches
nsarlin-zama Mar 27, 2025
bc6937c
chore(shortint): add tests for the KS32 AP
nsarlin-zama Apr 4, 2025
8579504
feat(shortint): allow the KS32 parameters to have non native KSK modulus
IceTDrinker Apr 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def filter_shortint_tests(input_args):
msg_carry_pairs.append((4, 4))

filter_expression = [
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks_pbs.*/)"
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks(32)?_pbs.*/)"
for msg, carry in msg_carry_pairs
]
filter_expression.append("test(/^shortint::.*_ci_run_filter/)")
Expand Down
6 changes: 3 additions & 3 deletions tests/backward_compatibility/high_level_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::shortint::load_params;
use crate::{load_and_unversionize, TestedModule};
use std::path::Path;
use tfhe::prelude::{CiphertextList, FheDecrypt, FheEncrypt};
use tfhe::shortint::PBSParameters;
use tfhe::shortint::{AtomicPatternParameters, PBSParameters};
#[cfg(feature = "zk-pok")]
use tfhe::zk::CompactPkeCrs;
use tfhe::{
Expand All @@ -25,10 +25,10 @@ use tfhe_backward_compat_data::{
};
use tfhe_versionable::Unversionize;

fn load_hl_params(test_params: &TestParameterSet) -> PBSParameters {
fn load_hl_params(test_params: &TestParameterSet) -> AtomicPatternParameters {
let pbs_params = load_params(test_params);

PBSParameters::PBS(pbs_params)
PBSParameters::PBS(pbs_params).into()
}

/// Test HL ciphertext: loads the ciphertext and compare the decrypted value to the one in the
Expand Down
23 changes: 18 additions & 5 deletions tfhe/benches/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub mod shortint_utils {
use super::*;
use itertools::iproduct;
use std::vec::IntoIter;
use tfhe::shortint::atomic_pattern::AtomicPatternParameters;
use tfhe::shortint::parameters::compact_public_key_only::CompactPublicKeyEncryptionParameters;
#[cfg(not(feature = "gpu"))]
use tfhe::shortint::parameters::current_params::V1_1_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
Expand All @@ -58,8 +59,10 @@ pub mod shortint_utils {
/// of parameters and a num_block to achieve a certain bit_size ciphertext
/// in radix decomposition
pub struct ParamsAndNumBlocksIter {
params_and_bit_sizes:
itertools::Product<IntoIter<tfhe::shortint::PBSParameters>, IntoIter<usize>>,
params_and_bit_sizes: itertools::Product<
IntoIter<tfhe::shortint::atomic_pattern::AtomicPatternParameters>,
IntoIter<usize>,
>,
}

impl Default for ParamsAndNumBlocksIter {
Expand Down Expand Up @@ -92,7 +95,11 @@ pub mod shortint_utils {
}

impl Iterator for ParamsAndNumBlocksIter {
type Item = (tfhe::shortint::PBSParameters, usize, usize);
type Item = (
tfhe::shortint::atomic_pattern::AtomicPatternParameters,
usize,
usize,
);

fn next(&mut self) -> Option<Self::Item> {
let (param, bit_size) = self.params_and_bit_sizes.next()?;
Expand All @@ -105,6 +112,12 @@ pub mod shortint_utils {

impl From<PBSParameters> for CryptoParametersRecord<u64> {
fn from(params: PBSParameters) -> Self {
AtomicPatternParameters::from(params).into()
}
}

impl From<AtomicPatternParameters> for CryptoParametersRecord<u64> {
fn from(params: AtomicPatternParameters) -> Self {
CryptoParametersRecord {
lwe_dimension: Some(params.lwe_dimension()),
glwe_dimension: Some(params.glwe_dimension()),
Expand Down Expand Up @@ -757,7 +770,7 @@ mod cuda_utils {
impl<T: UnsignedInteger> CudaLocalKeys<T> {
pub fn from_cpu_keys(
cpu_keys: &CpuKeys<T>,
ms_noise_reduction_key: Option<&ModulusSwitchNoiseReductionKey>,
ms_noise_reduction_key: Option<&ModulusSwitchNoiseReductionKey<u64>>,
stream: &CudaStreams,
) -> Self {
Self {
Expand All @@ -782,7 +795,7 @@ mod cuda_utils {
#[allow(dead_code)]
pub fn cuda_local_keys_core<T: UnsignedInteger>(
cpu_keys: &CpuKeys<T>,
ms_noise_reduction_key: Option<&ModulusSwitchNoiseReductionKey>,
ms_noise_reduction_key: Option<&ModulusSwitchNoiseReductionKey<u64>>,
) -> Vec<CudaLocalKeys<T>> {
let gpu_count = get_number_of_gpus() as usize;
let mut gpu_keys_vec = Vec::with_capacity(gpu_count);
Expand Down
9 changes: 5 additions & 4 deletions tfhe/examples/utilities/shortint_key_sizes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::parameters::current_params::*;
use tfhe::shortint::parameters::*;
use tfhe::shortint::server_key::{StandardServerKey, StandardServerKeyView};
use tfhe::shortint::{
ClassicPBSParameters, ClientKey, CompactPrivateKey, CompressedCompactPublicKey,
CompressedKeySwitchingKey, CompressedServerKey, PBSParameters,
Expand Down Expand Up @@ -58,7 +59,7 @@ fn client_server_key_sizes(results_file: &Path) {
let keys = KEY_CACHE.get_from_param(params);

let cks = keys.client_key();
let sks = keys.server_key();
let sks = StandardServerKeyView::try_from(keys.server_key().as_view()).unwrap();
let ksk_size = sks.key_switching_key_size_bytes();
let test_name = format!("shortint_key_sizes_{}_ksk", params.name());

Expand Down Expand Up @@ -167,10 +168,10 @@ fn tuniform_key_set_sizes(results_file: &Path) {
let param_fhe_name = param_fhe.name();
let cks = ClientKey::new(param_fhe);
let compressed_sks = CompressedServerKey::new(&cks);
let sks = compressed_sks.decompress();
let sks = StandardServerKey::try_from(compressed_sks.decompress()).unwrap();

measure_serialized_size(
&sks.key_switching_key,
&sks.atomic_pattern.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk",
Expand All @@ -187,7 +188,7 @@ fn tuniform_key_set_sizes(results_file: &Path) {
);

measure_serialized_size(
&sks.bootstrapping_key,
&sks.atomic_pattern.bootstrapping_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk",
Expand Down
93 changes: 60 additions & 33 deletions tfhe/src/core_crypto/algorithms/lwe_keyswitch_key_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ use crate::core_crypto::entities::*;
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// let input_lwe_secret_key: LweSecretKeyOwned<u64> =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
/// let output_lwe_secret_key: LweSecretKeyOwned<u64> =
/// allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
///
/// let mut ksk = LweKeyswitchKey::new(
/// 0u64,
Expand All @@ -64,7 +65,8 @@ use crate::core_crypto::entities::*;
/// assert!(!ksk.as_ref().iter().all(|&x| x == 0));
/// ```
pub fn generate_lwe_keyswitch_key<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -77,11 +79,12 @@ pub fn generate_lwe_keyswitch_key<
noise_distribution: NoiseDistribution,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
Expand All @@ -106,7 +109,8 @@ pub fn generate_lwe_keyswitch_key<
}

pub fn generate_lwe_keyswitch_key_native_mod_compatible<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -119,11 +123,12 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
noise_distribution: NoiseDistribution,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
assert!(
Expand All @@ -140,6 +145,13 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.decomposition_base_log().0
* lwe_keyswitch_key.decomposition_level_count().0
<= OutputScalar::BITS,
"This operation only supports a DecompositionBaseLog and DecompositionLevelCount product \
smaller than the OutputScalar bit count."
);

let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
Expand All @@ -148,7 +160,7 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<

// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
PlaintextListOwned::new(OutputScalar::ZERO, PlaintextCount(decomp_level_count.0));

// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
Expand All @@ -165,9 +177,13 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
// Here we take the decomposition term from the native torus, bring it to the torus we
// are working with by dividing by the scaling factor and the encryption will take care
// of mapping that back to the native torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
*message.0 = DecompositionTerm::new(
level,
decomp_base_log,
CastInto::<OutputScalar>::cast_into(*input_key_element),
)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}

encrypt_lwe_ciphertext_list(
Expand All @@ -181,7 +197,8 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
}

pub fn generate_lwe_keyswitch_key_other_mod<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -194,11 +211,12 @@ pub fn generate_lwe_keyswitch_key_other_mod<
noise_distribution: NoiseDistribution,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
assert!(
Expand All @@ -215,6 +233,13 @@ pub fn generate_lwe_keyswitch_key_other_mod<
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.decomposition_base_log().0
* lwe_keyswitch_key.decomposition_level_count().0
<= OutputScalar::BITS,
"This operation only supports a DecompositionBaseLog and DecompositionLevelCount product \
smaller than the OutputScalar bit count."
);

let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
Expand All @@ -223,7 +248,7 @@ pub fn generate_lwe_keyswitch_key_other_mod<

// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
PlaintextListOwned::new(OutputScalar::ZERO, PlaintextCount(decomp_level_count.0));

// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
Expand All @@ -243,7 +268,7 @@ pub fn generate_lwe_keyswitch_key_other_mod<
*message.0 = DecompositionTermNonNative::new(
level,
decomp_base_log,
*input_key_element,
CastInto::<OutputScalar>::cast_into(*input_key_element),
ciphertext_modulus,
)
.to_approximate_recomposition_summand();
Expand All @@ -264,7 +289,8 @@ pub fn generate_lwe_keyswitch_key_other_mod<
///
/// See [`keyswitch_lwe_ciphertext`] for usage.
pub fn allocate_and_generate_new_lwe_keyswitch_key<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -275,18 +301,19 @@ pub fn allocate_and_generate_new_lwe_keyswitch_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
ciphertext_modulus: CiphertextModulus<OutputScalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
) -> LweKeyswitchKeyOwned<Scalar>
) -> LweKeyswitchKeyOwned<OutputScalar>
where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
let mut new_lwe_keyswitch_key = LweKeyswitchKeyOwned::new(
Scalar::ZERO,
OutputScalar::ZERO,
decomp_base_log,
decomp_level_count,
input_lwe_sk.lwe_dimension(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,18 @@ pub fn prepare_multi_bit_ggsw_mem_optimized<GgswBufferCont, GgswGroupCont, Fouri
/// "Multiplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
/// );
/// ```
pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
pub fn multi_bit_blind_rotate_assign<InputScalar, InputCont, OutputScalar, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
accumulator: &mut GlweCiphertext<OutputCont>,
multi_bit_bsk: &FourierLweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
deterministic_execution: bool,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
InputScalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
OutputScalar: UnsignedTorus + Sync,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
KeyCont: Container<Element = c64> + Sync,
{
assert_eq!(
Expand All @@ -327,14 +328,6 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
multi_bit_bsk.input_lwe_dimension(),
);

assert_eq!(
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
);

let grouping_factor = multi_bit_bsk.grouping_factor();

let lut_poly_size = accumulator.polynomial_size();
Expand Down
Loading
Loading