diff --git a/tfhe/src/named.rs b/tfhe/src/named.rs index 2c03d54abd..5fa9e0a31d 100644 --- a/tfhe/src/named.rs +++ b/tfhe/src/named.rs @@ -1,3 +1,7 @@ pub trait Named { + /// Default name for the type const NAME: &'static str; + /// Aliases that should also be accepted for backward compatibility when checking the name of + /// values of this type + const BACKWARD_COMPATIBILITY_ALIASES: &'static [&'static str] = &[]; } diff --git a/tfhe/src/safe_serialization.rs b/tfhe/src/safe_serialization.rs index e29ce2f696..886ef24ac9 100644 --- a/tfhe/src/safe_serialization.rs +++ b/tfhe/src/safe_serialization.rs @@ -122,7 +122,11 @@ Please use the versioned serialization mode for backward compatibility.", } } - if self.name != T::NAME { + if self.name != T::NAME + && T::BACKWARD_COMPATIBILITY_ALIASES + .iter() + .all(|alias| self.name != *alias) + { return Err(format!( "On deserialization, expected type {}, got type {}", T::NAME, @@ -494,13 +498,17 @@ pub fn safe_deserialize_conformant< #[cfg(all(test, feature = "shortint"))] mod test_shortint { - use crate::safe_serialization::{DeserializationConfig, SerializationConfig}; + use tfhe_versionable::Versionize; + + use crate::named::Named; use crate::shortint::parameters::{ PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, V0_11_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, }; use crate::shortint::{gen_keys, Ciphertext}; + use super::*; + #[test] fn safe_deserialization_ct_unversioned() { let (ck, _sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); @@ -626,6 +634,46 @@ mod test_shortint { let dec = ck.decrypt(&ct2); assert_eq!(msg, dec); } + + #[test] + fn safe_deserialization_named() { + #[derive(Serialize, Deserialize, Versionize)] + #[repr(transparent)] + struct Foo(u64); + + impl Named for Foo { + const NAME: &'static str = "Foo"; + } + + #[derive(Deserialize, Versionize)] + #[repr(transparent)] + struct Bar(u64); + + impl Named for Bar { + const NAME: &'static str = "Bar"; + + const BACKWARD_COMPATIBILITY_ALIASES: &'static [&'static str] = &["Foo"]; + } + + #[derive(Deserialize, Versionize)] + #[repr(transparent)] + struct Baz(u64); + + impl Named for Baz { + const NAME: &'static str = "Baz"; + } + + let foo = Foo(3); + let mut foo_ser = Vec::new(); + safe_serialize(&foo, &mut foo_ser, 0x1000).unwrap(); + + let foo_deser: Foo = safe_deserialize(foo_ser.as_slice(), 0x1000).unwrap(); + let bar_deser: Bar = safe_deserialize(foo_ser.as_slice(), 0x1000).unwrap(); + + assert_eq!(foo_deser.0, bar_deser.0); + + assert!(safe_deserialize::(foo_ser.as_slice(), 0x1000).is_err()); + } } #[cfg(all(test, feature = "integer"))] diff --git a/tfhe/src/zk/mod.rs b/tfhe/src/zk/mod.rs index 760489d6e7..c62daf5bc8 100644 --- a/tfhe/src/zk/mod.rs +++ b/tfhe/src/zk/mod.rs @@ -196,6 +196,8 @@ pub enum CompactPkeCrs { impl Named for CompactPkeCrs { const NAME: &'static str = "zk::CompactPkeCrs"; + + const BACKWARD_COMPATIBILITY_ALIASES: &'static [&'static str] = &["zk::CompactPkePublicParams"]; } impl From for CompactPkeCrs {