Skip to content

Commit

Permalink
refactor(all): refactor oprf integer and hl APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Aug 2, 2024
1 parent 840841c commit 49f19ed
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 79 deletions.
6 changes: 3 additions & 3 deletions tfhe/c_api_tests/test_high_level_integers.c
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ void test_oprf(const ClientKey *client_key) {

fhe_uint8_destroy(ct);

status = generate_oblivious_pseudo_random_bits_fhe_uint8(&ct, 0, 0, 2);
status = generate_oblivious_pseudo_random_bounded_fhe_uint8(&ct, 0, 0, 2);
assert(status == 0);

status = fhe_uint8_decrypt(ct, client_key, &decrypted);
Expand All @@ -613,7 +613,7 @@ void test_oprf(const ClientKey *client_key) {
{
FheInt8 *ct = NULL;

int status = generate_oblivious_pseudo_random_full_signed_range_fhe_int8(&ct, 0, 0);
int status = generate_oblivious_pseudo_random_fhe_int8(&ct, 0, 0);
assert(status == 0);

int8_t decrypted;
Expand All @@ -623,7 +623,7 @@ void test_oprf(const ClientKey *client_key) {

fhe_int8_destroy(ct);

status = generate_oblivious_pseudo_random_unsigned_fhe_int8(&ct, 0, 0, 2);
status = generate_oblivious_pseudo_random_bounded_fhe_int8(&ct, 0, 0, 2);
assert(status == 0);

status = fhe_int8_decrypt(ct, client_key, &decrypted);
Expand Down
32 changes: 13 additions & 19 deletions tfhe/src/c_api/high_level_api/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,15 +491,13 @@ macro_rules! impl_oprf_for_uint {
seed_low_bytes: u64,
seed_high_bytes: u64,
) -> c_int {
use crate::high_level_api::IntegerId;
$crate::c_api::utils::catch_panic(|| {
let seed_low_bytes: u128 = seed_low_bytes.into();
let seed_high_bytes: u128 = seed_high_bytes.into();
let seed = crate::Seed((seed_high_bytes << 64) | seed_low_bytes);

let result = crate::FheUint::generate_oblivious_pseudo_random(
seed,
<crate::[<$name Id>] as IntegerId>::num_bits() as u64
);
*out_result = Box::into_raw(Box::new($name(result)));
})
Expand All @@ -508,7 +506,7 @@ macro_rules! impl_oprf_for_uint {

::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<generate_oblivious_pseudo_random_bits_ $name:snake>](
pub unsafe extern "C" fn [<generate_oblivious_pseudo_random_bounded_ $name:snake>](
out_result: *mut *mut $name,
seed_low_bytes: u64,
seed_high_bytes: u64,
Expand All @@ -520,7 +518,7 @@ macro_rules! impl_oprf_for_uint {
let seed_high_bytes: u128 = seed_high_bytes.into();
let seed = crate::Seed((seed_high_bytes << 64) | seed_low_bytes);

let result = crate::FheUint::generate_oblivious_pseudo_random(seed, random_bits_count);
let result = crate::FheUint::generate_oblivious_pseudo_random_bounded(seed, random_bits_count);
*out_result = Box::into_raw(Box::new($name(result)));
})
}
Expand All @@ -532,48 +530,44 @@ macro_rules! impl_oprf_for_int {
(
name: $name:ident
) => {

::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<generate_oblivious_pseudo_random_unsigned_ $name:snake>](
pub unsafe extern "C" fn [<generate_oblivious_pseudo_random_ $name:snake>](
out_result: *mut *mut $name,
seed_low_bytes: u64,
seed_high_bytes: u64,
random_bits_count: u64,
) -> c_int {
$crate::c_api::utils::catch_panic(|| {
let seed_low_bytes: u128 = seed_low_bytes.into();
let seed_high_bytes: u128 = seed_high_bytes.into();
let seed = crate::Seed((seed_high_bytes << 64) | seed_low_bytes);

let result =
crate::FheInt::generate_oblivious_pseudo_random(
seed,
crate::high_level_api::SignedRandomizationSpec::Unsigned {
random_bits_count
},
);
let result = crate::FheInt::generate_oblivious_pseudo_random(
seed,
);
*out_result = Box::into_raw(Box::new($name(result)));
})
}
}

::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<generate_oblivious_pseudo_random_full_signed_range_ $name:snake>](
pub unsafe extern "C" fn [<generate_oblivious_pseudo_random_bounded_ $name:snake>](
out_result: *mut *mut $name,
seed_low_bytes: u64,
seed_high_bytes: u64,
random_bits_count: u64,
) -> c_int {
$crate::c_api::utils::catch_panic(|| {
let seed_low_bytes: u128 = seed_low_bytes.into();
let seed_high_bytes: u128 = seed_high_bytes.into();
let seed = crate::Seed((seed_high_bytes << 64) | seed_low_bytes);

let result = crate::FheInt::generate_oblivious_pseudo_random(
seed,
crate::high_level_api::SignedRandomizationSpec::FullSigned,
);
let result =
crate::FheInt::generate_oblivious_pseudo_random_bounded(
seed,
random_bits_count,
);
*out_result = Box::into_raw(Box::new($name(result)));
})
}
Expand Down
42 changes: 35 additions & 7 deletions tfhe/src/high_level_api/integers/oprf.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::{FheIntId, FheUintId};
use crate::high_level_api::global_state;
use crate::high_level_api::keys::InternalServerKey;
use crate::integer::oprf::SignedRandomizationSpec;
use crate::{FheInt, FheUint, Seed};

impl<Id: FheUintId> FheUint<Id> {
Expand All @@ -26,11 +25,27 @@ impl<Id: FheUintId> FheUint<Id> {
/// let dec_result: u16 = ct_res.decrypt(&client_key);
/// assert!(dec_result < (1 << random_bits_count));
/// ```
pub fn generate_oblivious_pseudo_random(seed: Seed, random_bits_count: u64) -> Self {
pub fn generate_oblivious_pseudo_random(seed: Seed) -> Self {
let ct = global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(key) => key
.key
.par_generate_oblivious_pseudo_random_unsigned_integer(
seed,
Id::num_blocks(key.message_modulus()) as u64,
),
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
todo!("Cuda devices do not yet support oblivious pseudo random generation")
}
});

Self::new(ct)
}
pub fn generate_oblivious_pseudo_random_bounded(seed: Seed, random_bits_count: u64) -> Self {
let ct = global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(key) => key
.key
.par_generate_oblivious_pseudo_random_unsigned_integer_bounded(
seed,
random_bits_count,
Id::num_blocks(key.message_modulus()) as u64,
Expand Down Expand Up @@ -69,15 +84,11 @@ impl<Id: FheIntId> FheInt<Id> {
/// assert!(dec_result < 1 << 7);
/// assert!(dec_result >= -(1 << 7));
/// ```
pub fn generate_oblivious_pseudo_random(
seed: Seed,
randomizer: SignedRandomizationSpec,
) -> Self {
pub fn generate_oblivious_pseudo_random(seed: Seed) -> Self {
let ct = global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(key) => {
key.key.par_generate_oblivious_pseudo_random_signed_integer(
seed,
randomizer,
Id::num_blocks(key.message_modulus()) as u64,
)
}
Expand All @@ -87,6 +98,23 @@ impl<Id: FheIntId> FheInt<Id> {
}
});

Self::new(ct)
}
pub fn generate_oblivious_pseudo_random_bounded(seed: Seed, random_bits_count: u64) -> Self {
let ct = global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(key) => key
.key
.par_generate_oblivious_pseudo_random_signed_integer_bounded(
seed,
random_bits_count,
Id::num_blocks(key.message_modulus()) as u64,
),
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
todo!("Cuda devices do not yet support oblivious pseudo random generation")
}
});

Self::new(ct)
}
}
1 change: 0 additions & 1 deletion tfhe/src/high_level_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ macro_rules! expand_pub_use_fhe_type(
);

pub use crate::core_crypto::commons::math::random::Seed;
pub use crate::integer::oprf::SignedRandomizationSpec;
pub use crate::integer::server_key::MatchValues;
pub use config::{Config, ConfigBuilder};
pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context};
Expand Down
Loading

0 comments on commit 49f19ed

Please sign in to comment.