From 94f7408bfeddf778aa99a0b3622899e8f1818691 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Fri, 9 Aug 2024 18:29:47 +0200 Subject: [PATCH] chore(all): use a builder pattern for safe serialization API --- tfhe/docs/fundamentals/serialization.md | 34 +- tfhe/src/high_level_api/booleans/base.rs | 1 + tfhe/src/high_level_api/booleans/tests.rs | 17 +- .../high_level_api/integers/signed/base.rs | 1 + .../high_level_api/integers/signed/tests.rs | 13 +- .../high_level_api/integers/unsigned/base.rs | 1 + .../integers/unsigned/tests/cpu.rs | 22 +- tfhe/src/high_level_api/mod.rs | 11 +- tfhe/src/integer/parameters/mod.rs | 2 + tfhe/src/safe_deserialization.rs | 561 ++++++++++-------- 10 files changed, 370 insertions(+), 293 deletions(-) diff --git a/tfhe/docs/fundamentals/serialization.md b/tfhe/docs/fundamentals/serialization.md index 52ee964133..c3d1fbab87 100644 --- a/tfhe/docs/fundamentals/serialization.md +++ b/tfhe/docs/fundamentals/serialization.md @@ -98,7 +98,7 @@ Here is an example: use tfhe::conformance::ParameterSetConformant; use tfhe::integer::parameters::RadixCiphertextConformanceParams; use tfhe::prelude::*; -use tfhe::safe_deserialization::{safe_deserialize_conformant, safe_serialize}; +use tfhe::safe_deserialization::{SerializationConfig, DeserializationConfig}; use tfhe::shortint::parameters::{PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_2_CARRY_2_PBS_KS}; use tfhe::conformance::ListSizeConstraint; use tfhe::{ @@ -130,19 +130,15 @@ fn main() { let mut buffer = vec![]; - safe_serialize(&ct, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1<<40).serialize_into(&ct, &mut buffer).unwrap(); - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &conformance_params_2 - ).is_err()); - - let ct2 = safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &conformance_params_1 - ).unwrap(); + assert!(DeserializationConfig::new(1 << 20, &conformance_params_2) + .deserialize_from::(buffer.as_slice()) + .is_err()); + + let ct2 = DeserializationConfig::new(1 << 20, &conformance_params_1) + .deserialize_from::(buffer.as_slice()) + .unwrap(); let dec: u8 = ct2.decrypt(&client_key); assert_eq!(msg, dec); @@ -155,18 +151,16 @@ fn main() { let compact_list = builder.build(); let mut buffer = vec![]; - safe_serialize(&compact_list, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1<<40).serialize_into(&compact_list, &mut buffer).unwrap(); let conformance_params = CompactCiphertextListConformanceParams { shortint_params: params_1.to_shortint_conformance_param(), num_elements_constraint: ListSizeConstraint::exact_size(2), }; - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - &conformance_params - ).is_ok()); + assert!(DeserializationConfig::new(1 << 20, &conformance_params) + .deserialize_from::(buffer.as_slice()) + .is_ok()); } ``` -You can combine this serialization/deserialization feature with the [data versioning](../guides/data\_versioning.md) feature by using the `safe_serialize_versioned` and `safe_deserialize_conformant_versioned` functions. +By default, this feature uses the [data versioning](../guides/data\_versioning.md). diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index 69da74aaa1..24a4bb66d9 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -56,6 +56,7 @@ impl Named for FheBool { const NAME: &'static str = "high_level_api::FheBool"; } +#[derive(Copy, Clone)] pub struct FheBoolConformanceParams(pub(crate) CiphertextConformanceParams); impl

From

for FheBoolConformanceParams diff --git a/tfhe/src/high_level_api/booleans/tests.rs b/tfhe/src/high_level_api/booleans/tests.rs index ddb05ddb3c..bec5190e10 100644 --- a/tfhe/src/high_level_api/booleans/tests.rs +++ b/tfhe/src/high_level_api/booleans/tests.rs @@ -318,7 +318,7 @@ fn compressed_bool_test_case(setup_fn: impl FnOnce() -> (ClientKey, Device)) { mod cpu { use super::*; - use crate::safe_deserialization::safe_deserialize_conformant; + use crate::safe_deserialization::DeserializationConfig; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; use crate::FheBoolConformanceParams; use rand::random; @@ -685,9 +685,9 @@ mod cpu { assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); let params = FheBoolConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); let decrypted: bool = deserialized_a.decrypt(&client_key); assert_eq!(decrypted, clear_a); @@ -706,12 +706,9 @@ mod cpu { assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); let params = FheBoolConformanceParams::from(&server_key); - let deserialized_a = safe_deserialize_conformant::( - serialized.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); assert!(deserialized_a.is_conformant(&FheBoolConformanceParams::from(block_params))); diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index abe998b121..9bb60c9cfd 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -41,6 +41,7 @@ pub struct FheInt { pub(in crate::high_level_api::integers) id: Id, } +#[derive(Copy, Clone)] pub struct FheIntConformanceParams { pub(crate) params: RadixCiphertextConformanceParams, pub(crate) id: PhantomData, diff --git a/tfhe/src/high_level_api/integers/signed/tests.rs b/tfhe/src/high_level_api/integers/signed/tests.rs index 3cc0d539ea..b10bbee939 100644 --- a/tfhe/src/high_level_api/integers/signed/tests.rs +++ b/tfhe/src/high_level_api/integers/signed/tests.rs @@ -1,6 +1,6 @@ use crate::integer::I256; use crate::prelude::*; -use crate::safe_deserialization::safe_deserialize_conformant; +use crate::safe_deserialization::DeserializationConfig; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; use crate::{ generate_keys, set_server_key, ClientKey, CompactCiphertextList, CompactPublicKey, @@ -651,8 +651,9 @@ fn test_safe_deserialize_conformant_fhe_int32() { assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); let params = FheInt32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms).unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); let decrypted: i32 = deserialized_a.decrypt(&client_key); assert_eq!(decrypted, clear_a); @@ -673,9 +674,9 @@ fn test_safe_deserialize_conformant_compressed_fhe_int32() { assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); let params = FheInt32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); let params = FheInt32ConformanceParams::from(block_params); assert!(deserialized_a.is_conformant(¶ms)); diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 2f58b43fba..9443fd0f45 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -80,6 +80,7 @@ pub struct FheUint { pub(in crate::high_level_api::integers) id: Id, } +#[derive(Copy, Clone)] pub struct FheUintConformanceParams { pub(crate) params: RadixCiphertextConformanceParams, pub(crate) id: PhantomData, diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs index 4f9e21bf7f..d437d23e34 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs @@ -2,7 +2,7 @@ use crate::conformance::ListSizeConstraint; use crate::high_level_api::prelude::*; use crate::high_level_api::{generate_keys, set_server_key, ConfigBuilder, FheUint8}; use crate::integer::U256; -use crate::safe_deserialization::safe_deserialize_conformant; +use crate::safe_deserialization::DeserializationConfig; use crate::shortint::parameters::classic::compact_pk::*; use crate::shortint::parameters::compact_public_key_only::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::key_switching::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; @@ -410,8 +410,9 @@ fn test_safe_deserialize_conformant_fhe_uint32() { assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); let params = FheUint32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms).unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); let decrypted: u32 = deserialized_a.decrypt(&client_key); assert_eq!(decrypted, clear_a); @@ -431,9 +432,9 @@ fn test_safe_deserialize_conformant_compressed_fhe_uint32() { assert!(crate::safe_serialize(&a, &mut serialized, 1 << 20).is_ok()); let params = FheUint32ConformanceParams::from(&server_key); - let deserialized_a = - safe_deserialize_conformant::(serialized.as_slice(), 1 << 20, ¶ms) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); assert!(deserialized_a.is_conformant(&FheUint32ConformanceParams::from(block_params))); @@ -460,12 +461,9 @@ fn test_safe_deserialize_conformant_compact_fhe_uint32() { shortint_params: block_params.to_shortint_conformance_param(), num_elements_constraint: ListSizeConstraint::exact_size(clears.len()), }; - let deserialized_a = safe_deserialize_conformant::( - serialized.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); + let deserialized_a = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(serialized.as_slice()) + .unwrap(); let expander = deserialized_a.expand().unwrap(); for (i, clear) in clears.into_iter().enumerate() { diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index 9fd70a9bef..3d9110641c 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -130,9 +130,11 @@ pub mod safe_serialize { serialized_size_limit: u64, ) -> Result<(), String> where - T: Named + Serialize, + T: Named + Versionize + Serialize, { - crate::safe_deserialization::safe_serialize(a, writer, serialized_size_limit) + crate::safe_deserialization::SerializationConfig::new(serialized_size_limit) + .disable_versioning() + .serialize_into(a, writer) .map_err(|err| err.to_string()) } @@ -142,9 +144,10 @@ pub mod safe_serialize { serialized_size_limit: u64, ) -> Result<(), String> where - T: Named + Versionize, + T: Named + Versionize + Serialize, { - crate::safe_deserialization::safe_serialize_versioned(a, writer, serialized_size_limit) + crate::safe_deserialization::SerializationConfig::new(serialized_size_limit) + .serialize_into(a, writer) .map_err(|err| err.to_string()) } } diff --git a/tfhe/src/integer/parameters/mod.rs b/tfhe/src/integer/parameters/mod.rs index 0fd52e7944..787523f1ac 100644 --- a/tfhe/src/integer/parameters/mod.rs +++ b/tfhe/src/integer/parameters/mod.rs @@ -176,6 +176,7 @@ pub const PARAM_MESSAGE_1_CARRY_1_KS_PBS_32_BITS: WopbsParameters = WopbsParamet encryption_key_choice: EncryptionKeyChoice::Big, }; +#[derive(Copy, Clone)] pub struct RadixCiphertextConformanceParams { pub shortint_params: CiphertextConformanceParams, pub num_blocks_per_integer: usize, @@ -210,6 +211,7 @@ impl RadixCiphertextConformanceParams { /// Structure to store the expected properties of a ciphertext list /// Can be used on a server to check if client inputs are well formed /// before running a computation on them +#[derive(Copy, Clone)] pub struct CompactCiphertextListConformanceParams { pub shortint_params: CiphertextConformanceParams, pub num_elements_constraint: ListSizeConstraint, diff --git a/tfhe/src/safe_deserialization.rs b/tfhe/src/safe_deserialization.rs index ad7ceb21e4..dc58e4e45c 100644 --- a/tfhe/src/safe_deserialization.rs +++ b/tfhe/src/safe_deserialization.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::fmt::Display; use crate::conformance::ParameterSetConformant; use crate::named::Named; @@ -7,35 +8,58 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tfhe_versionable::{Unversionize, Versionize}; -// The `SERIALIZATION_VERSION` is serialized along objects serialized with `safe_serialize`. -// This `SERIALIZATION_VERSION` should be changed on each release where any object serialization -// details changes (this can happen when adding a field or reorderging fields of a struct). -// When a object is deserialized using `safe_deserialize`, the deserialized version is checked -// to be equal to SERIALIZATION_VERSION. -// This helps prevent users from inadvertently deserializaing an object serialized in another -// release. -// When this happens, it also gives a clear version mismatch error rather than a generic -// deserialization error or worse, a garbage object. -const SERIALIZATION_VERSION: &str = "0.4"; +/// This is the global version of the serialization scheme that is used. This should be updated when +/// the SerializationHeader is updated. +const SERIALIZATION_VERSION: &str = "0.5"; + +/// This is the version of the versioning scheme used to add backward compatibibility on tfhe-rs +/// types. Similar to SERIALIZATION_VERSION, this number should be increased when the versioning +/// scheme is upgraded. +const VERSIONING_VERSION: &str = "0.1"; + +/// This is the current version of this crate. This is used to be able to reject unversioned data +/// if they come from a previous version. +const CRATE_VERSION: &str = concat!( + env!("CARGO_PKG_VERSION_MAJOR"), + ".", + env!("CARGO_PKG_VERSION_MINOR") +); /// Tells if this serialized object is versioned or not -#[derive(Serialize, Deserialize, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Copy, Clone, PartialEq, Eq)] // This type should not be versioned because it is part of a wrapper of versioned messages. #[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] -enum SerializationMode { +enum SerializationVersioningMode { /// Serialize with type versioning for backward compatibility Versioned, - /// Directly serialize the type as it is provided - Direct, + /// Serialize the type without versioning information + Unversioned, } -/// Header with global metadata about the serialized object. +impl Display for SerializationVersioningMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Versioned => write!(f, "versioned"), + Self::Unversioned => write!(f, "unversioned"), + } + } +} + +/// `HEADER_LENGTH_LIMIT` is the maximum `SerializationHeader` size which +/// `DeserializationConfig::deserialize_from` is going to try to read (it returns an error if +/// it's too big). +/// It helps prevent an attacker passing a very long header to exhaust memory. +const HEADER_LENGTH_LIMIT: u64 = 1000; + +/// Header with global metadata about the serialized object. This help checking that we are not +/// deserializing data that we can't handle. #[derive(Serialize, Deserialize)] // This type should not be versioned because it is part of a wrapper of versioned messages. #[cfg_attr(tfhe_lints, allow(tfhe_lints::serialize_without_versionize))] struct SerializationHeader { - mode: SerializationMode, - version: Cow<'static, str>, + header_version: Cow<'static, str>, + versioning_mode: SerializationVersioningMode, + versioning_version: Cow<'static, str>, name: Cow<'static, str>, } @@ -43,28 +67,41 @@ impl SerializationHeader { /// Creates a new header for a versioned message fn new_versioned() -> Self { Self { - mode: SerializationMode::Versioned, - version: Cow::Borrowed(VERSIONING_VERSION), + header_version: Cow::Borrowed(SERIALIZATION_VERSION), + versioning_mode: SerializationVersioningMode::Versioned, + versioning_version: Cow::Borrowed(VERSIONING_VERSION), name: Cow::Borrowed(T::NAME), } } - /// Checks the validity of a versioned message - fn check_versioned(&self) -> Result<(), String> { - if self.mode != SerializationMode::Versioned { - return Err( - "On deserialization, expected versioned type but got unversioned one".to_string(), - ); + /// Creates a new header for an unversioned message + fn new_unversioned() -> Self { + Self { + header_version: Cow::Borrowed(SERIALIZATION_VERSION), + versioning_mode: SerializationVersioningMode::Unversioned, + versioning_version: Cow::Borrowed(CRATE_VERSION), + name: Cow::Borrowed(T::NAME), } + } - // Since there is only one "VERSIONING_VERSION", a message with a different value than the - // expected one is clearly invalid, so we return an error. In the future, we want to - // be able to upgrade it to the new versioning scheme. - if self.version != VERSIONING_VERSION { + /// Checks the validity of the header + fn validate(&self) -> Result<(), String> { + if self.versioning_mode == SerializationVersioningMode::Versioned { + // For the moment there is only one versioning scheme, so another value is + // a hard error. But maybe if we upgrade it we will be able to automatically convert + // it. + if self.versioning_version != VERSIONING_VERSION { + return Err(format!( + "On deserialization, expected versioning scheme version {VERSIONING_VERSION}, \ +got version {}", + self.versioning_version + )); + } + } else if self.versioning_version != CRATE_VERSION { return Err(format!( - "On deserialization, expected versioning scheme version {VERSIONING_VERSION}, \ - got version {}", - self.version + "This {} has been saved from TFHE-rs v{}, without versioning informations. \ +Please use the versioned serialization mode for backward compatibility.", + self.name, self.versioning_version )); } @@ -80,186 +117,243 @@ impl SerializationHeader { } } -// This is the version of the versioning scheme used to add backward compatibibility on tfhe-rs -// types. Similar to SERIALIZATION_VERSION, this number should be increased when the versioning -// scheme is upgraded. -const VERSIONING_VERSION: &str = "0.1"; +/// A configuration used to Serialize *TFHE-rs* objects. This configuration decides +/// if the object will be versioned and holds the max byte size of the written data. +#[derive(Copy, Clone)] +pub struct SerializationConfig { + versioned: SerializationVersioningMode, + serialized_size_limit: u64, +} -// `VERSION_LENGTH_LIMIT` is the maximum `SERIALIZATION_VERSION` size which `safe_deserialization` -// is going to try to read (it returns an error if it's too big). -// It helps prevent an attacker passing a very long `SERIALIZATION_VERSION` to exhaust memory. -const VERSION_LENGTH_LIMIT: u64 = 100; +impl SerializationConfig { + /// Creates a new serialization config. The default configuration will serialize the object + /// with versioning information for backward compatibility. + /// `serialized_size_limit` is the size limit (in number of byte) of the serialized object + /// (excluding the header). + pub fn new(serialized_size_limit: u64) -> Self { + Self { + versioned: SerializationVersioningMode::Versioned, + serialized_size_limit, + } + } -const TYPE_NAME_LENGTH_LIMIT: u64 = 1000; + /// Creates a new serialization config without any size check. + pub fn new_with_unlimited_size() -> Self { + Self { + versioned: SerializationVersioningMode::Versioned, + serialized_size_limit: 0, + } + } -const HEADER_LENGTH_LIMIT: u64 = 1000; + /// Disables the size limit for serialized objects + pub fn disable_size_limit(self) -> Self { + Self { + serialized_size_limit: 0, + ..self + } + } -/// Serializes an object into a [writer](std::io::Write). -/// The result contains a version of the serialization and the name of the -/// serialized type to provide checks on deserialization with [safe_deserialize]. -/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object -/// (excluding version and name serialization). -pub fn safe_serialize( - object: &T, - mut writer: impl std::io::Write, - serialized_size_limit: u64, -) -> bincode::Result<()> { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); + /// Disable the versioning of serializd objects + pub fn disable_versioning(self) -> Self { + Self { + versioned: SerializationVersioningMode::Unversioned, + ..self + } + } + + /// Create a serialization header based on the current config + fn create_header(&self) -> SerializationHeader { + match self.versioned { + SerializationVersioningMode::Versioned => SerializationHeader::new_versioned::(), + SerializationVersioningMode::Unversioned => SerializationHeader::new_unversioned::(), + } + } - options - .with_limit(VERSION_LENGTH_LIMIT) - .serialize_into::<_, String>(&mut writer, &SERIALIZATION_VERSION.to_owned())?; + /// Returns the max length of the serialized header + fn header_length_limit(&self) -> u64 { + if self.serialized_size_limit == 0 { + 0 + } else { + HEADER_LENGTH_LIMIT + } + } - options - .with_limit(TYPE_NAME_LENGTH_LIMIT) - .serialize_into::<_, String>(&mut writer, &T::NAME.to_owned())?; + /// Serializes an object into a [writer](std::io::Write), based on the current config. + /// The written bytes can be deserialized using [`DeserializationConfig::deserialize_from`]. + pub fn serialize_into( + self, + object: &T, + mut writer: impl std::io::Write, + ) -> bincode::Result<()> { + let options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + let header = self.create_header::(); + options + .with_limit(self.header_length_limit()) + .serialize_into(&mut writer, &header)?; + + match self.versioned { + SerializationVersioningMode::Versioned => options + .with_limit(self.serialized_size_limit) + .serialize_into(&mut writer, &object.versionize())?, + SerializationVersioningMode::Unversioned => options + .with_limit(self.serialized_size_limit) + .serialize_into(&mut writer, &object)?, + }; - options - .with_limit(serialized_size_limit) - .serialize_into(&mut writer, object)?; + Ok(()) + } +} - Ok(()) +/// Tells if the deserialization should also check conformance. +#[derive(Copy, Clone)] +enum ConformanceMode { + Checked(Params), + Unchecked, } -/// Serializes an object into a [writer](std::io::Write) like [`safe_serialize`] does, -/// but adds versioning information before. -pub fn safe_serialize_versioned( - object: &T, - mut writer: impl std::io::Write, +/// A configuration used to Serialize *TFHE-rs* objects. This configuration decides +/// the various sanity checks that will be performed during deserialization. +#[derive(Copy, Clone)] +pub struct DeserializationConfig { serialized_size_limit: u64, -) -> bincode::Result<()> { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); + validate_header: bool, + conformance: ConformanceMode, +} - let header = SerializationHeader::new_versioned::(); - options - .with_limit(HEADER_LENGTH_LIMIT) - .serialize_into(&mut writer, &header)?; +impl DeserializationConfig { + /// Creates a new deserialization config. + /// By default, it will check that the serialization version and the name of the + /// deserialized type are correct. + /// `serialized_size_limit` is the size limit (in number of byte) of the serialized object + /// (excluding version and name serialization). + /// It will also check that the object is conformant with the parameter set given in + /// `conformance_params`. Finally, it will check the compatibility of the loaded data with + /// the current *TFHE-rs* version. + pub fn new(serialized_size_limit: u64, conformance_params: &Params) -> Self { + Self { + serialized_size_limit, + validate_header: true, + conformance: ConformanceMode::Checked(*conformance_params), + } + } - options - .with_limit(serialized_size_limit) - .serialize_into(&mut writer, &object.versionize())?; + /// Creates a new config without any size limit for the deserialized objects. + pub fn new_with_unlimited_size(conformance_params: &Params) -> Self { + Self { + serialized_size_limit: 0, + validate_header: true, + conformance: ConformanceMode::Checked(*conformance_params), + } + } - Ok(()) -} + /// Disables the size limit for the serialized objects. + pub fn disable_size_limit(self) -> Self { + Self { + serialized_size_limit: 0, + ..self + } + } -/// Deserializes an object serialized by `safe_serialize` from a [reader](std::io::Read). -/// Checks that the serialization version and the name of the -/// deserialized type are correct. -/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object -/// (excluding version and name serialization). -pub fn safe_deserialize( - mut reader: impl std::io::Read, - serialized_size_limit: u64, -) -> Result { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); - - let deserialized_version: String = options - .with_limit(VERSION_LENGTH_LIMIT) - .deserialize_from::<_, String>(&mut reader) - .map_err(|err| err.to_string())?; - - if deserialized_version != SERIALIZATION_VERSION { - return Err(format!( - "On deserialization, expected serialization version {SERIALIZATION_VERSION}, got version {deserialized_version}" - )); + /// Disables the header validation on the object. This header validations + /// checks that the serialized object is the one that is supposed to be loaded + /// and is compatible with this version of *TFHE-rs*. + pub fn disable_header_validation(self) -> Self { + Self { + validate_header: false, + ..self + } } - let deserialized_type: String = options - .with_limit(TYPE_NAME_LENGTH_LIMIT) - .deserialize_from::<_, String>(&mut reader) - .map_err(|err| err.to_string())?; - - if deserialized_type != T::NAME { - return Err(format!( - "On deserialization, expected type {}, got type {}", - T::NAME, - deserialized_type - )); + /// Creates a config with conformance checks disabled. The conformance is used + /// to validate that the loaded object is compatible with the given parameters. + pub fn new_without_conformance(serialized_size_limit: u64) -> Self { + Self { + serialized_size_limit, + validate_header: true, + conformance: ConformanceMode::Unchecked, + } } - options - .with_limit(serialized_size_limit) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string()) -} + /// Disables the conformance check on an existing config. + pub fn disable_conformance_check(self) -> Self { + Self { + conformance: ConformanceMode::Unchecked, + ..self + } + } -/// Deserializes an object with [safe_deserialize] and checks than it is conformant with the given -/// parameter set -pub fn safe_deserialize_conformant( - reader: impl std::io::Read, - serialized_size_limit: u64, - parameter_set: &T::ParameterSet, -) -> Result { - let deser: T = safe_deserialize(reader, serialized_size_limit)?; - - if !deser.is_conformant(parameter_set) { - return Err(format!( - "Deserialized object of type {} not conformant with given parameter set", - T::NAME - )); + /// Creates a new config without any sanity check. + pub fn new_unsafe() -> Self { + Self { + serialized_size_limit: 0, + validate_header: false, + conformance: ConformanceMode::Unchecked, + } } - Ok(deser) -} + fn header_length_limit(&self) -> u64 { + if self.serialized_size_limit == 0 { + 0 + } else { + HEADER_LENGTH_LIMIT + } + } -/// Deserializes an object serialized by `safe_serialize_versioned` from a [reader](std::io::Read). -/// Checks that the serialization version and the name of the -/// deserialized type are correct. -/// `serialized_size_limit` is the size limit (in number of byte) of the serialized object -/// (excluding version and name serialization). -pub fn safe_deserialize_versioned( - mut reader: impl std::io::Read, - serialized_size_limit: u64, -) -> Result { - let options = bincode::DefaultOptions::new() - .with_fixint_encoding() - .with_limit(0); - - let deserialized_header: SerializationHeader = options - .with_limit(HEADER_LENGTH_LIMIT) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string())?; - - deserialized_header.check_versioned::()?; - - options - .with_limit(serialized_size_limit) - .deserialize_from(&mut reader) - .map_err(|err| err.to_string()) - .and_then(|val| T::unversionize(val).map_err(|err| err.to_string())) -} + /// Deserializes an object serialized by [`SerializationConfig::serialize_into`] from a + /// [reader](std::io::Read). Performs various sanity checks based on the deserialization config. + pub fn deserialize_from< + T: DeserializeOwned + Unversionize + Named + ParameterSetConformant, + >( + self, + mut reader: impl std::io::Read, + ) -> Result { + let options = bincode::DefaultOptions::new() + .with_fixint_encoding() + .with_limit(0); + + let deserialized_header: SerializationHeader = options + .with_limit(self.header_length_limit()) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())?; + + if self.validate_header { + deserialized_header.validate::()?; + } -/// Deserializes an object with [safe_deserialize] and checks than it is conformant with the given -/// parameter set -pub fn safe_deserialize_conformant_versioned( - reader: impl std::io::Read, - serialized_size_limit: u64, - parameter_set: &T::ParameterSet, -) -> Result { - let deser: T = safe_deserialize_versioned(reader, serialized_size_limit)?; - - if !deser.is_conformant(parameter_set) { - return Err(format!( - "Deserialized object of type {} not conformant with given parameter set", - T::NAME - )); - } + let deser = if deserialized_header.versioning_mode == SerializationVersioningMode::Versioned + { + let deser_versioned = options + .with_limit(self.serialized_size_limit - self.header_length_limit()) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())?; + + T::unversionize(deser_versioned).map_err(|e| e.to_string())? + } else { + options + .with_limit(self.serialized_size_limit - self.header_length_limit()) + .deserialize_from(&mut reader) + .map_err(|err| err.to_string())? + }; + + if let ConformanceMode::Checked(parameter_set) = self.conformance { + if !deser.is_conformant(¶meter_set) { + return Err(format!( + "Deserialized object of type {} not conformant with given parameter set", + T::NAME + )); + } + } - Ok(deser) + Ok(deser) + } } #[cfg(all(test, feature = "shortint"))] mod test_shortint { - use crate::safe_deserialization::{ - safe_deserialize_conformant, safe_deserialize_conformant_versioned, safe_serialize, - safe_serialize_versioned, - }; + use crate::safe_deserialization::{DeserializationConfig, SerializationConfig}; use crate::shortint::parameters::{ PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, }; @@ -275,20 +369,23 @@ mod test_shortint { let mut buffer = vec![]; - safe_serialize(&ct, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1 << 40) + .disable_versioning() + .serialize_into(&ct, &mut buffer) + .unwrap(); - assert!(safe_deserialize_conformant::( - buffer.as_slice(), + assert!(DeserializationConfig::new( 1 << 20, - &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), + &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param() ) + .deserialize_from::(buffer.as_slice()) .is_err()); - let ct2 = safe_deserialize_conformant( - buffer.as_slice(), + let ct2 = DeserializationConfig::new( 1 << 20, &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), ) + .deserialize_from::(buffer.as_slice()) .unwrap(); let dec = ck.decrypt(&ct2); @@ -305,20 +402,22 @@ mod test_shortint { let mut buffer = vec![]; - safe_serialize_versioned(&ct, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1 << 40) + .serialize_into(&ct, &mut buffer) + .unwrap(); - assert!(safe_deserialize_conformant_versioned::( - buffer.as_slice(), + assert!(DeserializationConfig::new( 1 << 20, - &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param(), + &PARAM_MESSAGE_3_CARRY_3_KS_PBS.to_shortint_conformance_param() ) + .deserialize_from::(buffer.as_slice()) .is_err()); - let ct2 = safe_deserialize_conformant_versioned( - buffer.as_slice(), + let ct2 = DeserializationConfig::new( 1 << 20, &PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), ) + .deserialize_from::(buffer.as_slice()) .unwrap(); let dec = ck.decrypt(&ct2); @@ -331,10 +430,7 @@ mod test_integer { use crate::conformance::ListSizeConstraint; use crate::high_level_api::{generate_keys, ConfigBuilder}; use crate::prelude::*; - use crate::safe_deserialization::{ - safe_deserialize_conformant, safe_deserialize_conformant_versioned, safe_serialize, - safe_serialize_versioned, - }; + use crate::safe_deserialization::{DeserializationConfig, SerializationConfig}; use crate::shortint::parameters::{ PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, }; @@ -360,7 +456,10 @@ mod test_integer { let mut buffer = vec![]; - safe_serialize(&ct_list, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1 << 40) + .disable_versioning() + .serialize_into(&ct_list, &mut buffer) + .unwrap(); let to_param_set = |list_size_constraint| CompactCiphertextListConformanceParams { shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), @@ -377,12 +476,9 @@ mod test_integer { to_param_set(ListSizeConstraint::try_size_in_range(1, 2).unwrap()), to_param_set(ListSizeConstraint::try_size_in_range(4, 5).unwrap()), ] { - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - ¶m_set - ) - .is_err()); + assert!(DeserializationConfig::new(1 << 20, ¶m_set) + .deserialize_from::(buffer.as_slice()) + .is_err()); } for len_constraint in [ @@ -395,24 +491,18 @@ mod test_integer { shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), num_elements_constraint: len_constraint, }; - assert!(safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .is_ok()); + assert!(DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(buffer.as_slice()) + .is_ok()); } let params = CompactCiphertextListConformanceParams { shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), num_elements_constraint: ListSizeConstraint::exact_size(3), }; - let ct2 = safe_deserialize_conformant::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); + let ct2 = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(buffer.as_slice()) + .unwrap(); let mut cts = Vec::with_capacity(3); let expander = ct2.expand().unwrap(); @@ -442,7 +532,9 @@ mod test_integer { let mut buffer = vec![]; - safe_serialize_versioned(&ct_list, &mut buffer, 1 << 40).unwrap(); + SerializationConfig::new(1 << 40) + .serialize_into(&ct_list, &mut buffer) + .unwrap(); let to_param_set = |list_size_constraint| CompactCiphertextListConformanceParams { shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), @@ -459,14 +551,9 @@ mod test_integer { to_param_set(ListSizeConstraint::try_size_in_range(1, 2).unwrap()), to_param_set(ListSizeConstraint::try_size_in_range(4, 5).unwrap()), ] { - assert!( - safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - ¶m_set - ) - .is_err() - ); + assert!(DeserializationConfig::new(1 << 20, ¶m_set) + .deserialize_from::(buffer.as_slice()) + .is_err()); } for len_constraint in [ @@ -479,26 +566,18 @@ mod test_integer { shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), num_elements_constraint: len_constraint, }; - assert!( - safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .is_ok() - ); + assert!(DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(buffer.as_slice()) + .is_ok()); } let params = CompactCiphertextListConformanceParams { shortint_params: PARAM_MESSAGE_2_CARRY_2_KS_PBS.to_shortint_conformance_param(), num_elements_constraint: ListSizeConstraint::exact_size(3), }; - let ct2 = safe_deserialize_conformant_versioned::( - buffer.as_slice(), - 1 << 20, - ¶ms, - ) - .unwrap(); + let ct2 = DeserializationConfig::new(1 << 20, ¶ms) + .deserialize_from::(buffer.as_slice()) + .unwrap(); let mut cts = Vec::with_capacity(3); let expander = ct2.expand().unwrap();