Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ns/feat/atomic patterns #2024

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1e79c39
refactor(shortint): wrap PbsOrder into AtomicPattern in ciphertext
nsarlin-zama Jan 22, 2025
3e77803
refactor(shortint): factorize generate_lookup_table
nsarlin-zama Jan 21, 2025
025d6de
refactor(shortint): use a single lwe buffer inside shortint engine
nsarlin-zama Jan 23, 2025
fd3fc7e
refactor(shortint): function to directly set noise level to nominal
nsarlin-zama Jan 27, 2025
fc7b5a8
refactor(shortint): use a dedicated type for lut size
nsarlin-zama Jan 27, 2025
8febf54
refactor(shortint): remove degree in generate_lookup_table_no_encode
nsarlin-zama Mar 6, 2025
211e603
feat(shortint): create atomic pattern trait and enum
nsarlin-zama Mar 13, 2025
c5c1a6b
feat(shortint): insert the AP inside the ServerKey
nsarlin-zama Feb 10, 2025
337ae3b
refactor(shortint): engine can create any atomic pattern sk
nsarlin-zama Mar 17, 2025
4ed64fd
feat(shortint): add the dynamic ap
nsarlin-zama Feb 10, 2025
5019f98
refactor(shortint): support any ciphertext modulus in the engine
nsarlin-zama Feb 6, 2025
b7a6baa
refactor(core): allow different input/output scalars in multibit br
nsarlin-zama Mar 17, 2025
afd3a6e
refactor(shortint): support any scalar in modswitch noise reduction
nsarlin-zama Mar 12, 2025
1505647
refactor(shortint): make oprf generic over the Scalar type
nsarlin-zama Mar 17, 2025
49dd8a4
refactor(shortint): make modswitch compression generic over scalar
nsarlin-zama Mar 12, 2025
0286201
refactor(core): make ksk generation generic over the scalar type
nsarlin-zama Mar 19, 2025
2545f82
feat(shortint): introduce the KS32 atomic pattern
nsarlin-zama Mar 17, 2025
20843dd
chore(tests): add support for AP in tests and benches
nsarlin-zama Mar 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/backward_compatibility/high_level_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::shortint::load_params;
use crate::{load_and_unversionize, TestedModule};
use std::path::Path;
use tfhe::prelude::{CiphertextList, FheDecrypt, FheEncrypt};
use tfhe::shortint::PBSParameters;
use tfhe::shortint::{AtomicPatternParameters, PBSParameters};
#[cfg(feature = "zk-pok")]
use tfhe::zk::CompactPkeCrs;
use tfhe::{
Expand All @@ -22,10 +22,10 @@ use tfhe_backward_compat_data::{
};
use tfhe_versionable::Unversionize;

fn load_hl_params(test_params: &TestParameterSet) -> PBSParameters {
fn load_hl_params(test_params: &TestParameterSet) -> AtomicPatternParameters {
let pbs_params = load_params(test_params);

PBSParameters::PBS(pbs_params)
PBSParameters::PBS(pbs_params).into()
}

/// Test HL ciphertext: loads the ciphertext and compare the decrypted value to the one in the
Expand Down
38 changes: 35 additions & 3 deletions tfhe/benches/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub mod shortint_utils {
use super::*;
use itertools::iproduct;
use std::vec::IntoIter;
use tfhe::shortint::atomic_pattern::AtomicPatternParameters;
use tfhe::shortint::parameters::compact_public_key_only::CompactPublicKeyEncryptionParameters;
#[cfg(not(feature = "gpu"))]
use tfhe::shortint::parameters::current_params::V1_1_PARAM_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
Expand All @@ -56,8 +57,10 @@ pub mod shortint_utils {
/// of parameters and a num_block to achieve a certain bit_size ciphertext
/// in radix decomposition
pub struct ParamsAndNumBlocksIter {
params_and_bit_sizes:
itertools::Product<IntoIter<tfhe::shortint::PBSParameters>, IntoIter<usize>>,
params_and_bit_sizes: itertools::Product<
IntoIter<tfhe::shortint::atomic_pattern::AtomicPatternParameters>,
IntoIter<usize>,
>,
}

impl Default for ParamsAndNumBlocksIter {
Expand Down Expand Up @@ -90,7 +93,11 @@ pub mod shortint_utils {
}

impl Iterator for ParamsAndNumBlocksIter {
type Item = (tfhe::shortint::PBSParameters, usize, usize);
type Item = (
tfhe::shortint::atomic_pattern::AtomicPatternParameters,
usize,
usize,
);

fn next(&mut self) -> Option<Self::Item> {
let (param, bit_size) = self.params_and_bit_sizes.next()?;
Expand Down Expand Up @@ -126,6 +133,31 @@ pub mod shortint_utils {
}
}

impl From<AtomicPatternParameters> for CryptoParametersRecord<u64> {
fn from(params: AtomicPatternParameters) -> Self {
CryptoParametersRecord {
lwe_dimension: Some(params.lwe_dimension()),
glwe_dimension: Some(params.glwe_dimension()),
polynomial_size: Some(params.polynomial_size()),
lwe_noise_distribution: Some(params.lwe_noise_distribution()),
glwe_noise_distribution: Some(params.glwe_noise_distribution()),
pbs_base_log: Some(params.pbs_base_log()),
pbs_level: Some(params.pbs_level()),
ks_base_log: Some(params.ks_base_log()),
ks_level: Some(params.ks_level()),
message_modulus: Some(params.message_modulus().0),
carry_modulus: Some(params.carry_modulus().0),
ciphertext_modulus: Some(
params
.ciphertext_modulus()
.try_to()
.expect("failed to convert ciphertext modulus"),
),
..Default::default()
}
}
}

impl From<ShortintKeySwitchingParameters> for CryptoParametersRecord<u64> {
fn from(params: ShortintKeySwitchingParameters) -> Self {
CryptoParametersRecord {
Expand Down
9 changes: 5 additions & 4 deletions tfhe/examples/utilities/shortint_key_sizes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::parameters::current_params::*;
use tfhe::shortint::parameters::*;
use tfhe::shortint::server_key::{ClassicalServerKey, ClassicalServerKeyView};
use tfhe::shortint::{
ClassicPBSParameters, ClientKey, CompactPrivateKey, CompressedCompactPublicKey,
CompressedKeySwitchingKey, CompressedServerKey, PBSParameters,
Expand Down Expand Up @@ -58,7 +59,7 @@ fn client_server_key_sizes(results_file: &Path) {
let keys = KEY_CACHE.get_from_param(params);

let cks = keys.client_key();
let sks = keys.server_key();
let sks = ClassicalServerKeyView::try_from(keys.server_key().as_view()).unwrap();
let ksk_size = sks.key_switching_key_size_bytes();
let test_name = format!("shortint_key_sizes_{}_ksk", params.name());

Expand Down Expand Up @@ -170,10 +171,10 @@ fn tuniform_key_set_sizes(results_file: &Path) {
let param_fhe_name = param_fhe.name();
let cks = ClientKey::new(param_fhe);
let compressed_sks = CompressedServerKey::new(&cks);
let sks = compressed_sks.decompress();
let sks = ClassicalServerKey::try_from(compressed_sks.decompress()).unwrap();

measure_serialized_size(
&sks.key_switching_key,
&sks.atomic_pattern.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk",
Expand All @@ -190,7 +191,7 @@ fn tuniform_key_set_sizes(results_file: &Path) {
);

measure_serialized_size(
&sks.bootstrapping_key,
&sks.atomic_pattern.bootstrapping_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk",
Expand Down
79 changes: 46 additions & 33 deletions tfhe/src/core_crypto/algorithms/lwe_keyswitch_key_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ use crate::core_crypto::entities::*;
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// let input_lwe_secret_key: LweSecretKeyOwned<u64> =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
/// let output_lwe_secret_key: LweSecretKeyOwned<u64> =
/// allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
///
/// let mut ksk = LweKeyswitchKey::new(
/// 0u64,
Expand All @@ -64,7 +65,8 @@ use crate::core_crypto::entities::*;
/// assert!(!ksk.as_ref().iter().all(|&x| x == 0));
/// ```
pub fn generate_lwe_keyswitch_key<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -77,11 +79,12 @@ pub fn generate_lwe_keyswitch_key<
noise_distribution: NoiseDistribution,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
Expand All @@ -106,7 +109,8 @@ pub fn generate_lwe_keyswitch_key<
}

pub fn generate_lwe_keyswitch_key_native_mod_compatible<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -119,11 +123,12 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
noise_distribution: NoiseDistribution,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
assert!(
Expand All @@ -148,7 +153,7 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<

// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
PlaintextListOwned::new(OutputScalar::ZERO, PlaintextCount(decomp_level_count.0));

// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
Expand All @@ -165,9 +170,13 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
// Here we take the decomposition term from the native torus, bring it to the torus we
// are working with by dividing by the scaling factor and the encryption will take care
// of mapping that back to the native torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
*message.0 = DecompositionTerm::new(
level,
decomp_base_log,
CastInto::<OutputScalar>::cast_into(*input_key_element),
)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}

encrypt_lwe_ciphertext_list(
Expand All @@ -181,7 +190,8 @@ pub fn generate_lwe_keyswitch_key_native_mod_compatible<
}

pub fn generate_lwe_keyswitch_key_other_mod<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -194,11 +204,12 @@ pub fn generate_lwe_keyswitch_key_other_mod<
noise_distribution: NoiseDistribution,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
assert!(
Expand All @@ -223,7 +234,7 @@ pub fn generate_lwe_keyswitch_key_other_mod<

// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
PlaintextListOwned::new(OutputScalar::ZERO, PlaintextCount(decomp_level_count.0));

// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
Expand All @@ -243,7 +254,7 @@ pub fn generate_lwe_keyswitch_key_other_mod<
*message.0 = DecompositionTermNonNative::new(
level,
decomp_base_log,
*input_key_element,
CastInto::<OutputScalar>::cast_into(*input_key_element),
ciphertext_modulus,
)
.to_approximate_recomposition_summand();
Expand All @@ -264,7 +275,8 @@ pub fn generate_lwe_keyswitch_key_other_mod<
///
/// See [`keyswitch_lwe_ciphertext`] for usage.
pub fn allocate_and_generate_new_lwe_keyswitch_key<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
Expand All @@ -275,18 +287,19 @@ pub fn allocate_and_generate_new_lwe_keyswitch_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
ciphertext_modulus: CiphertextModulus<OutputScalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
) -> LweKeyswitchKeyOwned<Scalar>
) -> LweKeyswitchKeyOwned<OutputScalar>
where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
Gen: ByteRandomGenerator,
{
let mut new_lwe_keyswitch_key = LweKeyswitchKeyOwned::new(
Scalar::ZERO,
OutputScalar::ZERO,
decomp_base_log,
decomp_level_count,
input_lwe_sk.lwe_dimension(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,18 @@ pub fn prepare_multi_bit_ggsw_mem_optimized<GgswBufferCont, GgswGroupCont, Fouri
/// "Multiplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
/// );
/// ```
pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
pub fn multi_bit_blind_rotate_assign<InputScalar, InputCont, OutputScalar, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
accumulator: &mut GlweCiphertext<OutputCont>,
multi_bit_bsk: &FourierLweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
deterministic_execution: bool,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
InputScalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
OutputScalar: UnsignedTorus + Sync,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
KeyCont: Container<Element = c64> + Sync,
{
assert_eq!(
Expand All @@ -327,14 +328,6 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
multi_bit_bsk.input_lwe_dimension(),
);

assert_eq!(
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
"Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
input.ciphertext_modulus(),
accumulator.ciphertext_modulus(),
);

let grouping_factor = multi_bit_bsk.grouping_factor();

let lut_poly_size = accumulator.polynomial_size();
Expand Down
Loading
Loading