Skip to content

Commit

Permalink
chore(all): use a builder pattern for safe serialization API
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Aug 9, 2024
1 parent 5340859 commit 94f7408
Show file tree
Hide file tree
Showing 10 changed files with 370 additions and 293 deletions.
34 changes: 14 additions & 20 deletions tfhe/docs/fundamentals/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Here is an example:
use tfhe::conformance::ParameterSetConformant;
use tfhe::integer::parameters::RadixCiphertextConformanceParams;
use tfhe::prelude::*;
use tfhe::safe_deserialization::{safe_deserialize_conformant, safe_serialize};
use tfhe::safe_deserialization::{SerializationConfig, DeserializationConfig};
use tfhe::shortint::parameters::{PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_2_CARRY_2_PBS_KS};
use tfhe::conformance::ListSizeConstraint;
use tfhe::{
Expand Down Expand Up @@ -130,19 +130,15 @@ fn main() {

let mut buffer = vec![];

safe_serialize(&ct, &mut buffer, 1 << 40).unwrap();
SerializationConfig::new(1<<40).serialize_into(&ct, &mut buffer).unwrap();

assert!(safe_deserialize_conformant::<FheUint8>(
buffer.as_slice(),
1 << 20,
&conformance_params_2
).is_err());

let ct2 = safe_deserialize_conformant::<FheUint8>(
buffer.as_slice(),
1 << 20,
&conformance_params_1
).unwrap();
assert!(DeserializationConfig::new(1 << 20, &conformance_params_2)
.deserialize_from::<FheUint8>(buffer.as_slice())
.is_err());

let ct2 = DeserializationConfig::new(1 << 20, &conformance_params_1)
.deserialize_from::<FheUint8>(buffer.as_slice())
.unwrap();

let dec: u8 = ct2.decrypt(&client_key);
assert_eq!(msg, dec);
Expand All @@ -155,18 +151,16 @@ fn main() {
let compact_list = builder.build();

let mut buffer = vec![];
safe_serialize(&compact_list, &mut buffer, 1 << 40).unwrap();
SerializationConfig::new(1<<40).serialize_into(&compact_list, &mut buffer).unwrap();

let conformance_params = CompactCiphertextListConformanceParams {
shortint_params: params_1.to_shortint_conformance_param(),
num_elements_constraint: ListSizeConstraint::exact_size(2),
};
assert!(safe_deserialize_conformant::<CompactCiphertextList>(
buffer.as_slice(),
1 << 20,
&conformance_params
).is_ok());
assert!(DeserializationConfig::new(1 << 20, &conformance_params)
.deserialize_from::<CompactCiphertextList>(buffer.as_slice())
.is_ok());
}
```

You can combine this serialization/deserialization feature with the [data versioning](../guides/data\_versioning.md) feature by using the `safe_serialize_versioned` and `safe_deserialize_conformant_versioned` functions.
By default, this feature uses the [data versioning](../guides/data\_versioning.md).
1 change: 1 addition & 0 deletions tfhe/src/high_level_api/booleans/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ impl Named for FheBool {
const NAME: &'static str = "high_level_api::FheBool";
}

#[derive(Copy, Clone)]
pub struct FheBoolConformanceParams(pub(crate) CiphertextConformanceParams);

impl<P> From<P> for FheBoolConformanceParams
Expand Down
17 changes: 7 additions & 10 deletions tfhe/src/high_level_api/booleans/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ fn compressed_bool_test_case(setup_fn: impl FnOnce() -> (ClientKey, Device)) {

mod cpu {
use super::*;
use crate::safe_deserialization::safe_deserialize_conformant;
use crate::safe_deserialization::DeserializationConfig;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use crate::FheBoolConformanceParams;
use rand::random;
Expand Down Expand Up @@ -685,9 +685,9 @@ mod cpu {
assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok());

let params = FheBoolConformanceParams::from(&server_key);
let deserialized_a =
safe_deserialize_conformant::<FheBool>(serialized.as_slice(), 1 << 20, &params)
.unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<FheBool>(serialized.as_slice())
.unwrap();
let decrypted: bool = deserialized_a.decrypt(&client_key);
assert_eq!(decrypted, clear_a);

Expand All @@ -706,12 +706,9 @@ mod cpu {
assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok());

let params = FheBoolConformanceParams::from(&server_key);
let deserialized_a = safe_deserialize_conformant::<CompressedFheBool>(
serialized.as_slice(),
1 << 20,
&params,
)
.unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<CompressedFheBool>(serialized.as_slice())
.unwrap();

assert!(deserialized_a.is_conformant(&FheBoolConformanceParams::from(block_params)));

Expand Down
1 change: 1 addition & 0 deletions tfhe/src/high_level_api/integers/signed/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub struct FheInt<Id: FheIntId> {
pub(in crate::high_level_api::integers) id: Id,
}

#[derive(Copy, Clone)]
pub struct FheIntConformanceParams<Id: FheIntId> {
pub(crate) params: RadixCiphertextConformanceParams,
pub(crate) id: PhantomData<Id>,
Expand Down
13 changes: 7 additions & 6 deletions tfhe/src/high_level_api/integers/signed/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::integer::I256;
use crate::prelude::*;
use crate::safe_deserialization::safe_deserialize_conformant;
use crate::safe_deserialization::DeserializationConfig;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
use crate::{
generate_keys, set_server_key, ClientKey, CompactCiphertextList, CompactPublicKey,
Expand Down Expand Up @@ -651,8 +651,9 @@ fn test_safe_deserialize_conformant_fhe_int32() {
assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok());

let params = FheInt32ConformanceParams::from(&server_key);
let deserialized_a =
safe_deserialize_conformant::<FheInt32>(serialized.as_slice(), 1 << 20, &params).unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<FheInt32>(serialized.as_slice())
.unwrap();
let decrypted: i32 = deserialized_a.decrypt(&client_key);
assert_eq!(decrypted, clear_a);

Expand All @@ -673,9 +674,9 @@ fn test_safe_deserialize_conformant_compressed_fhe_int32() {
assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok());

let params = FheInt32ConformanceParams::from(&server_key);
let deserialized_a =
safe_deserialize_conformant::<CompressedFheInt32>(serialized.as_slice(), 1 << 20, &params)
.unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<CompressedFheInt32>(serialized.as_slice())
.unwrap();

let params = FheInt32ConformanceParams::from(block_params);
assert!(deserialized_a.is_conformant(&params));
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/high_level_api/integers/unsigned/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pub struct FheUint<Id: FheUintId> {
pub(in crate::high_level_api::integers) id: Id,
}

#[derive(Copy, Clone)]
pub struct FheUintConformanceParams<Id: FheUintId> {
pub(crate) params: RadixCiphertextConformanceParams,
pub(crate) id: PhantomData<Id>,
Expand Down
22 changes: 10 additions & 12 deletions tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::conformance::ListSizeConstraint;
use crate::high_level_api::prelude::*;
use crate::high_level_api::{generate_keys, set_server_key, ConfigBuilder, FheUint8};
use crate::integer::U256;
use crate::safe_deserialization::safe_deserialize_conformant;
use crate::safe_deserialization::DeserializationConfig;
use crate::shortint::parameters::classic::compact_pk::*;
use crate::shortint::parameters::compact_public_key_only::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::key_switching::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
Expand Down Expand Up @@ -410,8 +410,9 @@ fn test_safe_deserialize_conformant_fhe_uint32() {
assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok());

let params = FheUint32ConformanceParams::from(&server_key);
let deserialized_a =
safe_deserialize_conformant::<FheUint32>(serialized.as_slice(), 1 << 20, &params).unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<FheUint32>(serialized.as_slice())
.unwrap();
let decrypted: u32 = deserialized_a.decrypt(&client_key);
assert_eq!(decrypted, clear_a);

Expand All @@ -431,9 +432,9 @@ fn test_safe_deserialize_conformant_compressed_fhe_uint32() {
assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok());

let params = FheUint32ConformanceParams::from(&server_key);
let deserialized_a =
safe_deserialize_conformant::<CompressedFheUint32>(serialized.as_slice(), 1 << 20, &params)
.unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<CompressedFheUint32>(serialized.as_slice())
.unwrap();

assert!(deserialized_a.is_conformant(&FheUint32ConformanceParams::from(block_params)));

Expand All @@ -460,12 +461,9 @@ fn test_safe_deserialize_conformant_compact_fhe_uint32() {
shortint_params: block_params.to_shortint_conformance_param(),
num_elements_constraint: ListSizeConstraint::exact_size(clears.len()),
};
let deserialized_a = safe_deserialize_conformant::<CompactCiphertextList>(
serialized.as_slice(),
1 << 20,
&params,
)
.unwrap();
let deserialized_a = DeserializationConfig::new(1 << 20, &params)
.deserialize_from::<CompactCiphertextList>(serialized.as_slice())
.unwrap();

let expander = deserialized_a.expand().unwrap();
for (i, clear) in clears.into_iter().enumerate() {
Expand Down
11 changes: 7 additions & 4 deletions tfhe/src/high_level_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,11 @@ pub mod safe_serialize {
serialized_size_limit: u64,
) -> Result<(), String>
where
T: Named + Serialize,
T: Named + Versionize + Serialize,
{
crate::safe_deserialization::safe_serialize(a, writer, serialized_size_limit)
crate::safe_deserialization::SerializationConfig::new(serialized_size_limit)
.disable_versioning()
.serialize_into(a, writer)
.map_err(|err| err.to_string())
}

Expand All @@ -142,9 +144,10 @@ pub mod safe_serialize {
serialized_size_limit: u64,
) -> Result<(), String>
where
T: Named + Versionize,
T: Named + Versionize + Serialize,
{
crate::safe_deserialization::safe_serialize_versioned(a, writer, serialized_size_limit)
crate::safe_deserialization::SerializationConfig::new(serialized_size_limit)
.serialize_into(a, writer)
.map_err(|err| err.to_string())
}
}
2 changes: 2 additions & 0 deletions tfhe/src/integer/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ pub const PARAM_MESSAGE_1_CARRY_1_KS_PBS_32_BITS: WopbsParameters = WopbsParamet
encryption_key_choice: EncryptionKeyChoice::Big,
};

#[derive(Copy, Clone)]
pub struct RadixCiphertextConformanceParams {
pub shortint_params: CiphertextConformanceParams,
pub num_blocks_per_integer: usize,
Expand Down Expand Up @@ -210,6 +211,7 @@ impl RadixCiphertextConformanceParams {
/// Structure to store the expected properties of a ciphertext list
/// Can be used on a server to check if client inputs are well formed
/// before running a computation on them
#[derive(Copy, Clone)]
pub struct CompactCiphertextListConformanceParams {
pub shortint_params: CiphertextConformanceParams,
pub num_elements_constraint: ListSizeConstraint,
Expand Down
Loading

0 comments on commit 94f7408

Please sign in to comment.