diff --git a/.github/workflows/aws_tfhe_fast_tests.yml b/.github/workflows/aws_tfhe_fast_tests.yml index 0edd1be98d..0c90077539 100644 --- a/.github/workflows/aws_tfhe_fast_tests.yml +++ b/.github/workflows/aws_tfhe_fast_tests.yml @@ -26,6 +26,7 @@ jobs: outputs: csprng_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.csprng_any_changed }} zk_pok_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.zk_pok_any_changed }} + versionable_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.versionable_any_changed }} core_crypto_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.core_crypto_any_changed || steps.changed-files.outputs.dependencies_any_changed }} @@ -64,10 +65,15 @@ jobs: - tfhe/Cargo.toml - concrete-csprng/** - tfhe-zk-pok/** + - utils/tfhe-versionable/** + - utils/tfhe-versionable-derive/** csprng: - concrete-csprng/** zk_pok: - tfhe-zk-pok/** + versionable: + - utils/tfhe-versionable/** + - utils/tfhe-versionable-derive/** core_crypto: - tfhe/src/core_crypto/** boolean: @@ -103,6 +109,7 @@ jobs: if: ( steps.changed-files.outputs.dependencies_any_changed == 'true' || steps.changed-files.outputs.csprng_any_changed == 'true' || steps.changed-files.outputs.zk_pok_any_changed == 'true' || + steps.changed-files.outputs.versionable_any_changed == 'true' || steps.changed-files.outputs.core_crypto_any_changed == 'true' || steps.changed-files.outputs.boolean_any_changed == 'true' || steps.changed-files.outputs.shortint_any_changed == 'true' || @@ -167,6 +174,11 @@ jobs: run: | make test_zk_pok + - name: Run tfhe-versionable tests + if: needs.should-run.outputs.versionable_test == 'true' + run: | + make test_versionable + - name: Run core tests if: needs.should-run.outputs.core_crypto_test == 'true' run: | diff --git a/Makefile b/Makefile index 3ed3e38058..75619a1c09 100644 --- a/Makefile +++ b/Makefile @@ -744,7 +744,7 @@ test_zk_pok: install_rs_build_toolchain .PHONY: test_versionable # Run tests for tfhe-versionable subcrate test_versionable: install_rs_build_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \ - -p tfhe-versionable + --all-targets -p tfhe-versionable # The backward compat data repo holds historical binary data but also rust code to generate and load them. # Here we use the "patch" functionality of Cargo to make sure the repo used for the data is the same as the one used for the code. diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index 32c2db7b16..5a56e6cd5e 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -76,7 +76,7 @@ sha3 = { version = "0.10", optional = true } itertools = "0.11.0" rand_core = { version = "0.6.4", features = ["std"] } tfhe-zk-pok = { version = "0.3.0-alpha.1", path = "../tfhe-zk-pok", optional = true } -tfhe-versionable = { version = "0.2.1", path = "../utils/tfhe-versionable" } +tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" } # wasm deps wasm-bindgen = { version = "0.2.86", features = [ diff --git a/utils/tfhe-versionable-derive/Cargo.toml b/utils/tfhe-versionable-derive/Cargo.toml index dedee480f8..fb0ce1e249 100644 --- a/utils/tfhe-versionable-derive/Cargo.toml +++ b/utils/tfhe-versionable-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tfhe-versionable-derive" -version = "0.2.1" +version = "0.3.0" edition = "2021" keywords = ["versioning", "serialization", "encoding", "proc-macro", "derive"] homepage = "https://zama.ai/" diff --git a/utils/tfhe-versionable-derive/src/associated.rs b/utils/tfhe-versionable-derive/src/associated.rs index 98971ae510..12a69556f2 100644 --- a/utils/tfhe-versionable-derive/src/associated.rs +++ b/utils/tfhe-versionable-derive/src/associated.rs @@ -1,11 +1,12 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use syn::{ - parse_quote, DeriveInput, ImplGenerics, Item, ItemImpl, Lifetime, Path, Type, WhereClause, + parse_quote, DeriveInput, Generics, ImplGenerics, Item, ItemImpl, Lifetime, Path, Type, + WhereClause, }; use crate::{ - add_lifetime_bound, add_trait_bound, add_trait_where_clause, add_where_lifetime_bound, + add_lifetime_bound, add_trait_where_clause, add_where_lifetime_bound, extend_where_clause, parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME, SERIALIZE_TRAIT_NAME, }; @@ -91,6 +92,11 @@ pub(crate) enum AssociatedTypeKind { /// [`DispatchType`]: crate::dispatch_type::DispatchType /// [`VersionType`]: crate::dispatch_type::VersionType pub(crate) trait AssociatedType: Sized { + /// Bounds that will be added on the fields of the ref type definition + const REF_BOUNDS: &'static [&'static str]; + /// Bounds that will be added on the fields of the owned type definition + const OWNED_BOUNDS: &'static [&'static str]; + /// This will create the alternative of the type that holds a reference to the underlying data fn new_ref(orig_type: &DeriveInput) -> syn::Result; /// This will create the alternative of the type that owns the underlying data @@ -99,6 +105,30 @@ pub(crate) trait AssociatedType: Sized { /// Generates the type declaration for this type fn generate_type_declaration(&self) -> syn::Result; + /// Returns the kind of associated type, a ref or an owned type + fn kind(&self) -> &AssociatedTypeKind; + + /// Returns the generics found in the original type definition + fn orig_type_generics(&self) -> &Generics; + + /// Returns the generics and bounds that should be added to the type + fn type_generics(&self) -> syn::Result { + let mut generics = self.orig_type_generics().clone(); + if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind() { + if let Some(lifetime) = opt_lifetime { + add_lifetime_bound(&mut generics, lifetime); + } + add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?; + } else { + add_trait_where_clause(&mut generics, self.inner_types()?, Self::OWNED_BOUNDS)?; + } + + Ok(generics) + } + + /// Generics needed for conversions between the original type and associated types + fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result; + /// Generates conversion methods between the origin type and the associated type. If the version /// type is a ref, the conversion is `From<&'vers OrigType> for Associated<'vers>` because this /// conversion is used for versioning. If the version type is owned, the conversion is @@ -109,10 +139,6 @@ pub(crate) trait AssociatedType: Sized { /// [`Version`]: crate::dispatch_type::VersionType fn generate_conversion(&self) -> syn::Result>; - /// The lifetime added for this type, if it is a "ref" type. It also returns None if the type is - /// a unit type (no data) - //fn lifetime(&self) -> Option<&Lifetime>; - /// The identifier used to name this type fn ident(&self) -> Ident; @@ -144,40 +170,19 @@ pub(crate) struct AssociatingTrait { owned_type: T, orig_type: DeriveInput, trait_path: Path, - /// Bounds that should be added to the generics for the impl - generics_bounds: Vec, - /// Bounds that should be added on the struct attributes - attributes_bounds: Vec, } impl AssociatingTrait { - pub(crate) fn new( - orig_type: &DeriveInput, - name: &str, - generics_bounds: &[&str], - attributes_bounds: &[&str], - ) -> syn::Result { + pub(crate) fn new(orig_type: &DeriveInput, name: &str) -> syn::Result { let ref_type = T::new_ref(orig_type)?; let owned_type = T::new_owned(orig_type)?; let trait_path = syn::parse_str(name)?; - let generics_bounds = generics_bounds - .iter() - .map(|bound| bound.to_string()) - .collect(); - - let attributes_bounds = attributes_bounds - .iter() - .map(|bound| bound.to_string()) - .collect(); - Ok(Self { ref_type, owned_type, orig_type: orig_type.clone(), trait_path, - generics_bounds, - attributes_bounds, }) } @@ -189,22 +194,24 @@ impl AssociatingTrait { let ref_ident = self.ref_type.ident(); let owned_ident = self.owned_type.ident(); - let mut generics = self.orig_type.generics.clone(); - - for bound in &self.generics_bounds { - add_trait_bound(&mut generics, bound)?; - } - let trait_param = self.ref_type.as_trait_param().transpose()?; - let mut ref_type_generics = generics.clone(); - - add_trait_where_clause( - &mut generics, - self.ref_type.inner_types()?, - &self.attributes_bounds, - )?; + // AssociatedToOrig conversion always has a stricter bound than the other side so we use it + let mut generics = self + .owned_type + .conversion_generics(ConversionDirection::AssociatedToOrig)?; + + // Merge the where clause for the reference type with the one from the owned type + let owned_where_clause = generics.make_where_clause(); + if let Some(ref_where_clause) = self + .ref_type + .conversion_generics(ConversionDirection::AssociatedToOrig)? + .where_clause + { + extend_where_clause(owned_where_clause, &ref_where_clause); + } + let mut ref_type_generics = self.ref_type.orig_type_generics().clone(); // If the original type has some generics, we need to add a lifetime bound on them if let Some(lifetime) = self.ref_type.lifetime() { add_lifetime_bound(&mut ref_type_generics, lifetime); @@ -246,8 +253,13 @@ impl AssociatingTrait { ) ]}; + // Creates the type declaration. These types are the output of the versioning process, so + // they should be serializable. Serde might try to add automatic bounds on the type generics + // even if we don't need them, so we use `#[serde(bound = "")]` to disable this. The bounds + // on the generated types should be sufficient. let owned_tokens = quote! { #[derive(#serialize_trait, #deserialize_trait)] + #[serde(bound = "")] #ignored_lints #owned_decla @@ -260,6 +272,7 @@ impl AssociatingTrait { let ref_tokens = quote! { #[derive(#serialize_trait)] + #[serde(bound = "")] #ignored_lints #ref_decla diff --git a/utils/tfhe-versionable-derive/src/dispatch_type.rs b/utils/tfhe-versionable-derive/src/dispatch_type.rs index 3e6725bbb1..c39aad57ad 100644 --- a/utils/tfhe-versionable-derive/src/dispatch_type.rs +++ b/utils/tfhe-versionable-derive/src/dispatch_type.rs @@ -10,10 +10,10 @@ use syn::{ use crate::associated::{ generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind, + ConversionDirection, }; use crate::{ - add_lifetime_bound, add_trait_bound, add_trait_where_clause, parse_const_str, LIFETIME_NAME, - UNVERSIONIZE_ERROR_NAME, VERSIONIZE_TRAIT_NAME, VERSION_TRAIT_NAME, + parse_const_str, LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UPGRADE_TRAIT_NAME, VERSION_TRAIT_NAME, }; /// This is the enum that holds all the versions of a specific type. Each variant of the enum is @@ -47,6 +47,10 @@ fn derive_input_to_enum(input: &DeriveInput) -> syn::Result { } impl AssociatedType for DispatchType { + const REF_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME]; + + const OWNED_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME]; + fn new_ref(orig_type: &DeriveInput) -> syn::Result { for lt in orig_type.generics.lifetimes() { // check for collision with other lifetimes in `orig_type` @@ -93,7 +97,7 @@ impl AssociatedType for DispatchType { Ok(ItemEnum { ident: self.ident(), - generics: self.generics()?, + generics: self.type_generics()?, attrs: vec![parse_quote! { #[automatically_derived] }], variants: variants?, ..self.orig_type.clone() @@ -101,13 +105,41 @@ impl AssociatedType for DispatchType { .into()) } - fn generate_conversion(&self) -> syn::Result> { - let generics = self.generics()?; - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + fn kind(&self) -> &AssociatedTypeKind { + &self.kind + } + fn orig_type_generics(&self) -> &Generics { + &self.orig_type.generics + } + + fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result { + let mut generics = self.type_generics()?; + let preds = &mut generics.make_where_clause().predicates; + + let upgrade_trait: Path = parse_const_str(UPGRADE_TRAIT_NAME); + + if let ConversionDirection::AssociatedToOrig = direction { + if let AssociatedTypeKind::Owned = &self.kind { + // Add a bound for each version to be upgradable into the next one + for src_idx in 0..(self.versions_count() - 1) { + let src_ty = self.version_type_at(src_idx)?; + let next_ty = self.version_type_at(src_idx + 1)?; + preds.push(parse_quote! { #src_ty: #upgrade_trait<#next_ty> }) + } + } + } + + Ok(generics) + } + + fn generate_conversion(&self) -> syn::Result> { match &self.kind { AssociatedTypeKind::Ref(lifetime) => { // Wraps the highest version into the dispatch enum + let generics = self.conversion_generics(ConversionDirection::OrigToAssociated)?; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let src_type = self.latest_version_type()?; let src = parse_quote! { &#lifetime #src_type }; let dest_ident = self.ident(); @@ -126,6 +158,9 @@ impl AssociatedType for DispatchType { } AssociatedTypeKind::Owned => { // Upgrade to the highest version the convert to the main type + let generics = self.conversion_generics(ConversionDirection::AssociatedToOrig)?; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let src_ident = self.ident(); let src = parse_quote! { #src_ident #ty_generics }; let dest_type = self.latest_version_type()?; @@ -144,6 +179,9 @@ impl AssociatedType for DispatchType { )?; // Wraps the highest version into the dispatch enum + let generics = self.conversion_generics(ConversionDirection::OrigToAssociated)?; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let src_type = self.latest_version_type()?; let src = parse_quote! { #src_type }; let dest_ident = self.ident(); @@ -182,13 +220,13 @@ impl AssociatedType for DispatchType { } } - fn as_trait_param(&self) -> Option> { - Some(self.latest_version_type()) - } - fn inner_types(&self) -> syn::Result> { self.version_types() } + + fn as_trait_param(&self) -> Option> { + Some(self.latest_version_type()) + } } impl DispatchType { @@ -200,19 +238,6 @@ impl DispatchType { ) } - fn generics(&self) -> syn::Result { - let mut generics = self.orig_type.generics.clone(); - if let AssociatedTypeKind::Ref(Some(lifetime)) = &self.kind { - add_lifetime_bound(&mut generics, lifetime); - } - - add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSION_TRAIT_NAME])?; - - add_trait_bound(&mut generics, VERSIONIZE_TRAIT_NAME)?; - - Ok(generics) - } - /// Returns the number of versions in this dispatch enum fn versions_count(&self) -> usize { self.orig_type.variants.len() diff --git a/utils/tfhe-versionable-derive/src/lib.rs b/utils/tfhe-versionable-derive/src/lib.rs index c2410d037c..f40374daf0 100644 --- a/utils/tfhe-versionable-derive/src/lib.rs +++ b/utils/tfhe-versionable-derive/src/lib.rs @@ -16,12 +16,10 @@ use proc_macro2::Span; use quote::{quote, ToTokens}; use syn::parse::Parse; use syn::punctuated::Punctuated; -use syn::spanned::Spanned; use syn::{ parse_macro_input, parse_quote, DeriveInput, GenericParam, Generics, Ident, Lifetime, - LifetimeParam, Path, TraitBound, Type, TypeParam, TypeParamBound, + LifetimeParam, Path, TraitBound, Type, TypeParamBound, WhereClause, }; -use versionize_attribute::VersionizeAttribute; /// Adds the full path of the current crate name to avoid name clashes in generated code. macro_rules! crate_full_path { @@ -39,15 +37,16 @@ pub(crate) const VERSIONIZE_SLICE_TRAIT_NAME: &str = crate_full_path!("Versioniz pub(crate) const VERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("VersionizeVec"); pub(crate) const UNVERSIONIZE_TRAIT_NAME: &str = crate_full_path!("Unversionize"); pub(crate) const UNVERSIONIZE_VEC_TRAIT_NAME: &str = crate_full_path!("UnversionizeVec"); +pub(crate) const UPGRADE_TRAIT_NAME: &str = crate_full_path!("Upgrade"); pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeError"); pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize"; pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize"; -pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned"; use associated::AssociatingTrait; use crate::version_type::VersionType; +use crate::versionize_attribute::VersionizeAttribute; /// unwrap a `syn::Result` by extracting the Ok value or returning from the outer function with /// a compile error @@ -74,8 +73,6 @@ fn impl_version_trait(input: &DeriveInput) -> proc_macro2::TokenStream { let version_trait = syn_unwrap!(AssociatingTrait::::new( input, VERSION_TRAIT_NAME, - &[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME], - &[VERSIONIZE_TRAIT_NAME, UNVERSIONIZE_TRAIT_NAME] )); let version_types = syn_unwrap!(version_trait.generate_types_declarations()); @@ -102,16 +99,6 @@ pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream { let dispatch_trait = syn_unwrap!(AssociatingTrait::::new( &input, DISPATCH_TRAIT_NAME, - &[ - VERSIONIZE_TRAIT_NAME, - VERSIONIZE_VEC_TRAIT_NAME, - VERSIONIZE_SLICE_TRAIT_NAME, - UNVERSIONIZE_TRAIT_NAME, - UNVERSIONIZE_VEC_TRAIT_NAME, - SERIALIZE_TRAIT_NAME, - DESERIALIZE_OWNED_TRAIT_NAME - ], - &[] )); let dispatch_types = syn_unwrap!(dispatch_trait.generate_types_declarations()); @@ -137,9 +124,13 @@ pub fn derive_versions_dispatch(input: TokenStream) -> TokenStream { pub fn derive_versionize(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); - let attributes = syn_unwrap!(VersionizeAttribute::parse_from_attributes_list( - &input.attrs - )); + let attributes = syn_unwrap!( + VersionizeAttribute::parse_from_attributes_list(&input.attrs).and_then(|attr_opt| attr_opt + .ok_or_else(|| syn::Error::new( + Span::call_site(), + "Missing `versionize` attribute for `Versionize`", + ))) + ); // If we apply a type conversion before the call to versionize, the type that implements // the `Version` trait is the target type and not Self @@ -197,48 +188,9 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { let versionize_slice_trait: Path = parse_const_str(VERSIONIZE_SLICE_TRAIT_NAME); let unversionize_vec_trait: Path = parse_const_str(UNVERSIONIZE_VEC_TRAIT_NAME); - let mut versionize_generics = trait_generics.clone(); - for bound in attributes.versionize_bounds() { - syn_unwrap!(add_type_param_bound(&mut versionize_generics, bound)); - } - - // Add generic bounds specified by the user with the `bound` attribute - let mut unversionize_generics = trait_generics.clone(); - for bound in attributes.unversionize_bounds() { - syn_unwrap!(add_type_param_bound(&mut unversionize_generics, bound)); - } - - // Add Generics for the `VersionizeVec` and `UnversionizeVec` traits - let mut versionize_slice_generics = versionize_generics.clone(); - syn_unwrap!(add_trait_bound( - &mut versionize_slice_generics, - VERSIONIZE_TRAIT_NAME - )); - - let mut versionize_vec_generics = versionize_generics.clone(); - syn_unwrap!(add_trait_bound( - &mut versionize_vec_generics, - VERSIONIZE_OWNED_TRAIT_NAME - )); - let mut unversionize_vec_generics = unversionize_generics.clone(); - syn_unwrap!(add_trait_bound( - &mut unversionize_vec_generics, - UNVERSIONIZE_TRAIT_NAME - )); - // split generics so they can be used inside the generated code let (_, _, ref_where_clause) = ref_generics.split_for_impl(); - let (versionize_impl_generics, _, versionize_where_clause) = - versionize_generics.split_for_impl(); - let (unversionize_impl_generics, _, unversionize_where_clause) = - unversionize_generics.split_for_impl(); - - let (versionize_slice_impl_generics, _, versionize_slice_where_clause) = - versionize_slice_generics.split_for_impl(); - let (versionize_vec_impl_generics, _, versionize_vec_where_clause) = - versionize_vec_generics.split_for_impl(); - let (unversionize_vec_impl_generics, _, unversionize_vec_where_clause) = - unversionize_vec_generics.split_for_impl(); + let (trait_impl_generics, _, trait_where_clause) = trait_generics.split_for_impl(); // If we want to apply a conversion before the call to versionize we need to use the "owned" // alternative of the dispatch enum to be able to store the conversion result. @@ -257,8 +209,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { #version_trait_impl #[automatically_derived] - impl #versionize_impl_generics #versionize_trait for #input_ident #ty_generics - #versionize_where_clause + impl #trait_impl_generics #versionize_trait for #input_ident #ty_generics + #trait_where_clause { type Versioned<#lifetime> = <#dispatch_enum_path #dispatch_generics as @@ -270,8 +222,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } #[automatically_derived] - impl #versionize_impl_generics #versionize_owned_trait for #input_ident #ty_generics - #versionize_where_clause + impl #trait_impl_generics #versionize_owned_trait for #input_ident #ty_generics + #trait_where_clause { type VersionedOwned = <#dispatch_enum_path #dispatch_generics as @@ -283,8 +235,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } #[automatically_derived] - impl #unversionize_impl_generics #unversionize_trait for #input_ident #ty_generics - #unversionize_where_clause + impl #trait_impl_generics #unversionize_trait for #input_ident #ty_generics + #trait_where_clause { fn unversionize(#unversionize_arg_name: Self::VersionedOwned) -> Result { #unversionize_body @@ -292,8 +244,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } #[automatically_derived] - impl #versionize_slice_impl_generics #versionize_slice_trait for #input_ident #ty_generics - #versionize_slice_where_clause + impl #trait_impl_generics #versionize_slice_trait for #input_ident #ty_generics + #trait_where_clause { type VersionedSlice<#lifetime> = Vec<::Versioned<#lifetime>> #ref_where_clause; @@ -302,8 +254,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } } - impl #versionize_vec_impl_generics #versionize_vec_trait for #input_ident #ty_generics - #versionize_vec_where_clause + impl #trait_impl_generics #versionize_vec_trait for #input_ident #ty_generics + #trait_where_clause { type VersionedVec = Vec<::VersionedOwned> #owned_where_clause; @@ -314,8 +266,8 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream { } #[automatically_derived] - impl #unversionize_vec_impl_generics #unversionize_vec_trait for #input_ident #ty_generics - #unversionize_vec_where_clause { + impl #trait_impl_generics #unversionize_vec_trait for #input_ident #ty_generics + #trait_where_clause { fn unversionize_vec(versioned: Self::VersionedVec) -> Result, #unversionize_error> { versioned .into_iter() @@ -425,35 +377,6 @@ fn parse_trait_bound(trait_name: &str) -> syn::Result { Ok(parse_quote!(#trait_path)) } -/// Adds a trait bound for `trait_name` on all the generic types in `generics` -fn add_trait_bound(generics: &mut Generics, trait_name: &str) -> syn::Result<()> { - let trait_bound: TraitBound = parse_trait_bound(trait_name)?; - for param in generics.type_params_mut() { - param - .bounds - .push(TypeParamBound::Trait(trait_bound.clone())); - } - - Ok(()) -} - -fn add_type_param_bound(generics: &mut Generics, type_param_bound: &TypeParam) -> syn::Result<()> { - for param in generics.type_params_mut() { - if param.ident == type_param_bound.ident { - param.bounds.extend(type_param_bound.bounds.clone()); - return Ok(()); - } - } - - Err(syn::Error::new( - type_param_bound.span(), - format!( - "Bound type {} not found in target type generics", - type_param_bound.ident - ), - )) -} - /// Adds a "where clause" bound for all the input types with all the input traits fn add_trait_where_clause<'a, S: AsRef, I: IntoIterator>( generics: &mut Generics, @@ -475,6 +398,18 @@ fn add_trait_where_clause<'a, S: AsRef, I: IntoIterator>( Ok(()) } +/// Extends a where clause with predicates from another one, filtering duplicates +fn extend_where_clause(base_clause: &mut WhereClause, extension_clause: &WhereClause) { + for extend_predicate in &extension_clause.predicates { + if base_clause.predicates.iter().all(|base_predicate| { + base_predicate.to_token_stream().to_string() + != extend_predicate.to_token_stream().to_string() + }) { + base_clause.predicates.push(extend_predicate.clone()); + } + } +} + /// Creates a Result [`syn::punctuated::Punctuated`] from an iterator of Results fn punctuated_from_iter_result>>( iter: I, diff --git a/utils/tfhe-versionable-derive/src/version_type.rs b/utils/tfhe-versionable-derive/src/version_type.rs index 1c31877ff5..37d2288236 100644 --- a/utils/tfhe-versionable-derive/src/version_type.rs +++ b/utils/tfhe-versionable-derive/src/version_type.rs @@ -16,9 +16,9 @@ use crate::associated::{ ConversionDirection, }; use crate::{ - add_lifetime_bound, add_trait_where_clause, parse_const_str, parse_trait_bound, - punctuated_from_iter_result, LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, - VERSIONIZE_OWNED_TRAIT_NAME, VERSIONIZE_TRAIT_NAME, + add_trait_where_clause, parse_const_str, parse_trait_bound, punctuated_from_iter_result, + LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UNVERSIONIZE_TRAIT_NAME, VERSIONIZE_OWNED_TRAIT_NAME, + VERSIONIZE_TRAIT_NAME, }; /// The types generated for a specific version of a given exposed type. These types are identical to @@ -30,6 +30,9 @@ pub(crate) struct VersionType { } impl AssociatedType for VersionType { + const REF_BOUNDS: &'static [&'static str] = &[VERSIONIZE_TRAIT_NAME]; + const OWNED_BOUNDS: &'static [&'static str] = &[VERSIONIZE_OWNED_TRAIT_NAME]; + fn new_ref(orig_type: &DeriveInput) -> syn::Result { let lifetime = if is_unit(orig_type) { None @@ -173,40 +176,23 @@ impl AssociatedType for VersionType { } } - fn as_trait_param(&self) -> Option> { - None - } - fn inner_types(&self) -> syn::Result> { self.orig_type_fields() .iter() .map(|field| Ok(&field.ty)) .collect() } -} -impl VersionType { - /// Returns the fields of the original declaration. - fn orig_type_fields(&self) -> Punctuated<&Field, Comma> { - derive_type_fields(&self.orig_type) + fn as_trait_param(&self) -> Option> { + None } - fn type_generics(&self) -> syn::Result { - let mut generics = self.orig_type.generics.clone(); - if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind { - if let Some(lifetime) = opt_lifetime { - add_lifetime_bound(&mut generics, lifetime); - } - add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSIONIZE_TRAIT_NAME])?; - } else { - add_trait_where_clause( - &mut generics, - self.inner_types()?, - &[VERSIONIZE_OWNED_TRAIT_NAME], - )?; - } + fn kind(&self) -> &AssociatedTypeKind { + &self.kind + } - Ok(generics) + fn orig_type_generics(&self) -> &Generics { + &self.orig_type.generics } fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result { @@ -224,6 +210,13 @@ impl VersionType { Ok(generics) } +} + +impl VersionType { + /// Returns the fields of the original declaration. + fn orig_type_fields(&self) -> Punctuated<&Field, Comma> { + derive_type_fields(&self.orig_type) + } /// Generates the declaration for the Version equivalent of the input struct fn generate_struct(&self, stru: &DataStruct) -> syn::Result { diff --git a/utils/tfhe-versionable-derive/src/versionize_attribute.rs b/utils/tfhe-versionable-derive/src/versionize_attribute.rs index 9a4a529e08..36ee2e08c1 100644 --- a/utils/tfhe-versionable-derive/src/versionize_attribute.rs +++ b/utils/tfhe-versionable-derive/src/versionize_attribute.rs @@ -2,14 +2,11 @@ use proc_macro2::Span; use quote::{quote, ToTokens}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{ - Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token, TraitBound, Type, - TypeParam, -}; +use syn::{Attribute, Expr, Ident, Lit, Meta, Path, Token, TraitBound, Type}; use crate::{parse_const_str, UNVERSIONIZE_ERROR_NAME, VERSIONIZE_OWNED_TRAIT_NAME}; -/// Name of the attribute used to give arguments to our macros +/// Name of the attribute used to give arguments to the `Versionize` macro const VERSIONIZE_ATTR_NAME: &str = "versionize"; pub(crate) struct VersionizeAttribute { @@ -17,8 +14,6 @@ pub(crate) struct VersionizeAttribute { from: Option, try_from: Option, into: Option, - versionize_bounds: Vec, - unversionize_bounds: Vec, } #[derive(Default)] @@ -27,8 +22,6 @@ struct VersionizeAttributeBuilder { from: Option, try_from: Option, into: Option, - versionize_bounds: Vec, - unversionize_bounds: Vec, } impl VersionizeAttributeBuilder { @@ -42,8 +35,6 @@ impl VersionizeAttributeBuilder { from: self.from, try_from: self.try_from, into: self.into, - versionize_bounds: self.versionize_bounds, - unversionize_bounds: self.unversionize_bounds, }) } } @@ -53,18 +44,17 @@ impl VersionizeAttribute { /// `DispatchType` is the name of the type holding the dispatch enum. /// Returns an error if no `versionize` attribute has been found, if multiple attributes are /// present on the same struct or if the attribute is malformed. - pub(crate) fn parse_from_attributes_list(attributes: &[Attribute]) -> syn::Result { + pub(crate) fn parse_from_attributes_list( + attributes: &[Attribute], + ) -> syn::Result> { let version_attributes: Vec<&Attribute> = attributes .iter() .filter(|attr| attr.path().is_ident(VERSIONIZE_ATTR_NAME)) .collect(); match version_attributes.as_slice() { - [] => Err(syn::Error::new( - Span::call_site(), - "Missing `versionize` attribute for `Versionize`", - )), - [attr] => Self::parse_from_attribute(attr), + [] => Ok(None), + [attr] => Self::parse_from_attribute(attr).map(Some), [_, attr2, ..] => Err(syn::Error::new( attr2.span(), "Multiple `versionize` attributes found", @@ -91,31 +81,6 @@ impl VersionizeAttribute { attribute_builder.dispatch_enum = Some(dispatch_enum.clone()); } } - Meta::List(list) => { - // parse versionize(bound(unversionize = "Type: Bound")) - if list.path.is_ident("bound") { - let name_value: MetaNameValue = list.parse_args()?; - let bound_attr: TypeParam = match &name_value.value { - Expr::Lit(ExprLit { - attrs: _, - lit: Lit::Str(s), - }) => syn::parse_str(&s.value())?, - _ => { - return Err(Self::default_error(meta.span())); - } - }; - - if name_value.path.is_ident("versionize") { - attribute_builder.versionize_bounds.push(bound_attr); - } else if name_value.path.is_ident("unversionize") { - attribute_builder.unversionize_bounds.push(bound_attr); - } else { - return Err(Self::default_error(meta.span())); - } - } else { - return Err(Self::default_error(meta.span())); - } - } Meta::NameValue(name_value) => { // parse versionize(from = "TypeFrom") if name_value.path.is_ident("from") { @@ -142,22 +107,11 @@ impl VersionizeAttribute { Some(parse_path_ignore_quotes(&name_value.value)?); } // parse versionize(bound = "Type: Bound") - } else if name_value.path.is_ident("bound") { - let bound_attr: TypeParam = match &name_value.value { - Expr::Lit(ExprLit { - attrs: _, - lit: Lit::Str(s), - }) => syn::parse_str(&s.value())?, - _ => { - return Err(Self::default_error(meta.span())); - } - }; - attribute_builder.versionize_bounds.push(bound_attr.clone()); - attribute_builder.unversionize_bounds.push(bound_attr); } else { return Err(Self::default_error(meta.span())); } } + _ => return Err(Self::default_error(meta.span())), } } @@ -213,14 +167,6 @@ impl VersionizeAttribute { quote! { #arg_name.try_into() } } } - - pub(crate) fn versionize_bounds(&self) -> &[TypeParam] { - &self.versionize_bounds - } - - pub(crate) fn unversionize_bounds(&self) -> &[TypeParam] { - &self.unversionize_bounds - } } fn parse_path_ignore_quotes(value: &Expr) -> syn::Result { diff --git a/utils/tfhe-versionable/Cargo.toml b/utils/tfhe-versionable/Cargo.toml index 0bf1b732b4..ce7f7eaa2e 100644 --- a/utils/tfhe-versionable/Cargo.toml +++ b/utils/tfhe-versionable/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tfhe-versionable" -version = "0.2.1" +version = "0.3.0" edition = "2021" keywords = ["versioning", "serialization", "encoding"] homepage = "https://zama.ai/" @@ -26,6 +26,6 @@ toml = "0.8" [dependencies] serde = { version = "1.0", features = ["derive"] } -tfhe-versionable-derive = { version = "0.2.1", path = "../tfhe-versionable-derive" } +tfhe-versionable-derive = { version = "0.3.0", path = "../tfhe-versionable-derive" } num-complex = { version = "0.4", features = ["serde"] } aligned-vec = { version = "0.5", features = ["serde"] } diff --git a/utils/tfhe-versionable/README.md b/utils/tfhe-versionable/README.md index 3ff1cf0675..8d540a587a 100644 --- a/utils/tfhe-versionable/README.md +++ b/utils/tfhe-versionable/README.md @@ -1,9 +1,15 @@ # TFHE-versionable -This crate provides type level versioning for serialized data. It offers a way to add backward -compatibility on any data type. The versioning scheme works recursively and is independant of the -chosen serialization backend. +This crate provides type level versioning for serialized data. It offers a way +to add backward compatibility on any data type. The versioning scheme works +recursively and is independant of the chosen serialized file format. It uses the +`serde` framework. -To use it, simply define an enum that have a variant for each version of your target type. +The crate will convert any type into an equivalent packed with versions +information. This "versioned" type is then serializable using any format +compatible with `serde`. + +To use it, simply define an enum that have a variant for each version of your +target type. For example, if you have defined an internal type: ```rust @@ -19,7 +25,9 @@ enum MyStructVersions { } ``` -If at a subsequent point in time you want to add a field to this struct, the idea is to copy the previously defined version of the struct and create a new one with the added field. This mostly becomes: +If at a subsequent point in time you want to add a field to this struct, the +idea is to copy the previously defined version of the struct and create a new +one with the added field. This mostly becomes: ```rust struct MyStruct { val: u32, @@ -36,16 +44,22 @@ enum MyStructVersions { } ``` -You also have to implement the `Upgrade` trait, that tells how to go from a version to another. +You also have to implement the `Upgrade` trait, that tells how to go from a +version to another. -To make this work recursively, this crate defines 3 derive macro that should be used on these types: -- `Versionize` should be used on the current version of your type, the one that is used in your code +To make this work recursively, this crate defines 3 derive macro that should be +used on these types: +- `Versionize` should be used on the current version of your type, the one that + is used in your code - `Version` is used on every previous version of this type - `VersionsDispatch` is used on the enum holding all the versions -This will implement the `Versionize`/`Unversionize` traits with their `versionize` and `unversionize` methods that should be used before/after the calls to `serialize`/`deserialize`. +This will implement the `Versionize`/`Unversionize` traits with their +`versionize` and `unversionize` methods that should be used before/after the +calls to `serialize`/`deserialize`. -The enum variants should keep their order and names between versions. The only supported operation is to add a new variant. +The enum variants should keep their order and names between versions. The only +allowed operation on this enum is to add a new variant. # Complete example ```rust diff --git a/utils/tfhe-versionable/examples/associated_bounds.rs b/utils/tfhe-versionable/examples/associated_bounds.rs new file mode 100644 index 0000000000..7370f5b32d --- /dev/null +++ b/utils/tfhe-versionable/examples/associated_bounds.rs @@ -0,0 +1,39 @@ +/// In this example, we use a generic that is not versionable itself. Only its associated types +/// should be versioned. +use tfhe_versionable::{Versionize, VersionsDispatch}; + +trait WithAssociated { + type Assoc; + type OtherAssoc; +} + +struct Marker; + +impl WithAssociated for Marker { + type Assoc = u64; + + type OtherAssoc = u32; +} + +#[derive(VersionsDispatch)] +#[allow(unused)] +enum MyStructVersions { + V0(MyStruct), +} + +#[derive(Versionize)] +#[versionize(MyStructVersions)] +struct MyStruct { + val: T::Assoc, + other_val: T::OtherAssoc, +} + +#[test] +fn main() { + let ms = MyStruct:: { + val: 27, + other_val: 54, + }; + + ms.versionize(); +} diff --git a/utils/tfhe-versionable/examples/bounds.rs b/utils/tfhe-versionable/examples/bounds.rs index dddfdfa085..f115d9d7bf 100644 --- a/utils/tfhe-versionable/examples/bounds.rs +++ b/utils/tfhe-versionable/examples/bounds.rs @@ -1,52 +1,75 @@ -//! This example shows how to use the `bound` attribute to add a specific bound that is needed to be -//! able to derive `Versionize` +//! Example of a simple struct with an Upgrade impl that requires a specific bound. +//! In that case, the previous versions of the type used a string as a representation, but it has +//! been changed to a Generic. For the upgrade to work, we need to be able to create this generic +//! from a String. -use serde::de::DeserializeOwned; -use serde::Serialize; -use tfhe_versionable::{ - Unversionize, UnversionizeError, Versionize, VersionizeOwned, VersionsDispatch, -}; +use std::error::Error; +use std::io::Cursor; +use std::str::FromStr; -// Example of a simple struct with a manual Versionize impl that requires a specific bound -struct MyStruct { - val: T, -} +use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch}; -impl Versionize for MyStruct { - type Versioned<'vers> = &'vers T where T: 'vers; +/// The previous version of our application +mod v0 { + use tfhe_versionable::{Versionize, VersionsDispatch}; - fn versionize(&self) -> Self::Versioned<'_> { - &self.val + #[derive(Versionize)] + #[versionize(MyStructVersions)] + pub(super) struct MyStruct { + pub(super) val: String, } -} - -impl> VersionizeOwned for MyStruct { - type VersionedOwned = T; - fn versionize_owned(self) -> Self::VersionedOwned { - self.val.to_owned() + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructVersions { + V0(MyStruct), } } -impl> Unversionize for MyStruct { - fn unversionize(versioned: Self::VersionedOwned) -> Result { - Ok(MyStruct { val: versioned }) - } +#[derive(Version)] +struct MyStructV0 { + val: String, } -// The additional bound can be specified on the parent struct using this attribute. This is similar -// to what serde does. You can also use #[versionize(OuterVersions, bound(unversionize = "T: -// ToOwned"))] if the bound is only needed for the Unversionize impl. #[derive(Versionize)] -#[versionize(OuterVersions, bound = "T: ToOwned")] -struct Outer { - inner: MyStruct, +#[versionize(MyStructVersions)] +struct MyStruct { + val: T, +} + +impl Upgrade> for MyStructV0 +where + ::Err: Error + Send + Sync + 'static, +{ + type Error = ::Err; + + fn upgrade(self) -> Result, Self::Error> { + let val = T::from_str(&self.val)?; + + Ok(MyStruct { val }) + } } #[derive(VersionsDispatch)] #[allow(unused)] -enum OuterVersions> { - V0(Outer), +enum MyStructVersions { + V0(MyStructV0), + V1(MyStruct), } -fn main() {} +#[test] +fn main() { + let val = 64; + let stru_v0 = v0::MyStruct { + val: format!("{val}"), + }; + + let mut ser = Vec::new(); + ciborium::ser::into_writer(&stru_v0.versionize(), &mut ser).unwrap(); + + let unvers = + MyStruct::::unversionize(ciborium::de::from_reader(&mut Cursor::new(&ser)).unwrap()) + .unwrap(); + + assert_eq!(unvers.val, val); +} diff --git a/utils/tfhe-versionable/examples/convert.rs b/utils/tfhe-versionable/examples/convert.rs index 41d8381cf9..c5f7955143 100644 --- a/utils/tfhe-versionable/examples/convert.rs +++ b/utils/tfhe-versionable/examples/convert.rs @@ -39,6 +39,7 @@ impl From for MyStruct { } } +#[test] fn main() { let stru = MyStruct { val: 37 }; diff --git a/utils/tfhe-versionable/examples/failed_upgrade.rs b/utils/tfhe-versionable/examples/failed_upgrade.rs index 7ae4a19e24..792d24524e 100644 --- a/utils/tfhe-versionable/examples/failed_upgrade.rs +++ b/utils/tfhe-versionable/examples/failed_upgrade.rs @@ -77,6 +77,7 @@ mod v1 { } } +#[test] fn main() { let v0 = v0::MyStruct(Some(37)); let serialized = bincode::serialize(&v0.versionize()).unwrap(); diff --git a/utils/tfhe-versionable/examples/manual_impl.rs b/utils/tfhe-versionable/examples/manual_impl.rs index 26283083ad..0eebf5dd43 100644 --- a/utils/tfhe-versionable/examples/manual_impl.rs +++ b/utils/tfhe-versionable/examples/manual_impl.rs @@ -94,6 +94,7 @@ enum MyStructVersionsDispatchOwned { V1(MyStructVersionOwned), } +#[test] fn main() { let ms = MyStruct { attr: 37u64, diff --git a/utils/tfhe-versionable/examples/not_versioned.rs b/utils/tfhe-versionable/examples/not_versioned.rs index e4578e6d17..ff70bfffb2 100644 --- a/utils/tfhe-versionable/examples/not_versioned.rs +++ b/utils/tfhe-versionable/examples/not_versioned.rs @@ -1,5 +1,6 @@ //! This example shows how to create a type that should not be versioned, even if it is -//! included in other versioned types +//! included in other versioned types. Of course it means that if this type is modified in the +//! future, the parent struct should be updated. use serde::{Deserialize, Serialize}; use tfhe_versionable::{NotVersioned, Versionize, VersionsDispatch}; @@ -21,4 +22,11 @@ enum MyStructVersions { V0(MyStruct), } -fn main() {} +#[test] +fn main() { + let ms = MyStruct { + inner: MyStructNotVersioned { val: 1234 }, + }; + + let _versioned = ms.versionize(); +} diff --git a/utils/tfhe-versionable/examples/recursive.rs b/utils/tfhe-versionable/examples/recursive.rs index 76202a8379..38a4e5591d 100644 --- a/utils/tfhe-versionable/examples/recursive.rs +++ b/utils/tfhe-versionable/examples/recursive.rs @@ -2,8 +2,9 @@ use std::convert::Infallible; -use tfhe_versionable::{Upgrade, Version, Versionize, VersionsDispatch}; +use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch}; +// The inner struct is independently versioned #[derive(Versionize)] #[versionize(MyStructInnerVersions)] struct MyStructInner { @@ -13,7 +14,7 @@ struct MyStructInner { #[derive(Version)] struct MyStructInnerV0 { - attr: u32, + builtin: u32, } impl Upgrade> for MyStructInnerV0 { @@ -22,7 +23,7 @@ impl Upgrade> for MyStructInnerV0 { fn upgrade(self) -> Result, Self::Error> { Ok(MyStructInner { attr: T::default(), - builtin: 0, + builtin: self.builtin, }) } } @@ -34,6 +35,7 @@ enum MyStructInnerVersions { V1(MyStructInner), } +// An upgrade of the inner struct does not require an upgrade of the outer struct #[derive(Versionize)] #[versionize(MyStructVersions)] struct MyStruct { @@ -46,4 +48,45 @@ enum MyStructVersions { V0(MyStruct), } -fn main() {} +mod v0 { + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(MyStructInnerVersions)] + pub(super) struct MyStructInner { + pub(super) builtin: u32, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructInnerVersions { + V0(MyStructInner), + } + + #[derive(Versionize)] + #[versionize(MyStructVersions)] + pub(super) struct MyStruct { + pub(super) inner: MyStructInner, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructVersions { + V0(MyStruct), + } +} + +#[test] +fn main() { + let builtin = 654; + let inner = v0::MyStructInner { builtin: 654 }; + let ms = v0::MyStruct { inner }; + + let serialized = bincode::serialize(&ms.versionize()).unwrap(); + + // This can be called in future versions of your application, when more variants have been added + let unserialized = + MyStruct::::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + assert_eq!(unserialized.inner.builtin, builtin); +} diff --git a/utils/tfhe-versionable/examples/simple.rs b/utils/tfhe-versionable/examples/simple.rs index a9995ff5c7..53a339a10a 100644 --- a/utils/tfhe-versionable/examples/simple.rs +++ b/utils/tfhe-versionable/examples/simple.rs @@ -43,14 +43,35 @@ enum MyStructVersions { V1(MyStruct), } +mod v0 { + // This module simulates an older version of our app where we initiated the versioning process. + // In real life code this would likely be only present in your git history. + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(MyStructVersions)] + pub(super) struct MyStruct { + pub(super) builtin: u32, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructVersions { + V0(MyStruct), + } +} + +#[test] fn main() { - let ms = MyStruct { - attr: 37u64, - builtin: 1234, - }; + // In the past we saved a value + let value = 1234; + let ms = v0::MyStruct { builtin: value }; let serialized = bincode::serialize(&ms.versionize()).unwrap(); // This can be called in future versions of your application, when more variants have been added - let _unserialized = MyStruct::::unversionize(bincode::deserialize(&serialized).unwrap()); + let unserialized = + MyStruct::::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + assert_eq!(unserialized.builtin, value); } diff --git a/utils/tfhe-versionable/examples/upgrades.rs b/utils/tfhe-versionable/examples/upgrades.rs index b9019c6199..dbb9a7cd14 100644 --- a/utils/tfhe-versionable/examples/upgrades.rs +++ b/utils/tfhe-versionable/examples/upgrades.rs @@ -132,6 +132,7 @@ mod v2 { } } +#[test] fn main() { let v0 = v0::MyStruct(37); diff --git a/utils/tfhe-versionable/examples/vec.rs b/utils/tfhe-versionable/examples/vec.rs index 0ecbc8398e..5435552da0 100644 --- a/utils/tfhe-versionable/examples/vec.rs +++ b/utils/tfhe-versionable/examples/vec.rs @@ -1,6 +1,16 @@ -/// The `VersionizeVec` and `UnversionizeVec` traits are also automatically derived -/// So that Vec can be versioned as well -use tfhe_versionable::{Versionize, VersionsDispatch}; +//! The `VersionizeVec` and `UnversionizeVec` traits are also automatically derived +//! So that Vec can be versioned as well. Because of the recursivity, each element of the vec +//! has its own version tag. For built-in rust types and anything that derives `NotVersioned`, +//! the versioning of the whole vec is skipped. + +use std::convert::Infallible; + +use tfhe_versionable::{Unversionize, Upgrade, Version, Versionize, VersionsDispatch}; + +#[derive(Version)] +struct MyStructInnerV0 { + val: u64, +} #[derive(Versionize)] #[versionize(MyStructInnerVersions)] @@ -9,10 +19,22 @@ struct MyStructInner { gen: T, } +impl Upgrade> for MyStructInnerV0 { + type Error = Infallible; + + fn upgrade(self) -> Result, Self::Error> { + Ok(MyStructInner { + val: self.val, + gen: T::default(), + }) + } +} + #[derive(VersionsDispatch)] #[allow(unused)] enum MyStructInnerVersions { - V0(MyStructInner), + V0(MyStructInnerV0), + V1(MyStructInner), } #[derive(Versionize)] @@ -27,4 +49,49 @@ enum MyVecVersions { V0(MyVec), } -fn main() {} +mod v0 { + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(MyStructInnerVersions)] + pub(super) struct MyStructInner { + pub(super) val: u64, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyStructInnerVersions { + V0(MyStructInner), + } + + #[derive(Versionize)] + #[versionize(MyVecVersions)] + pub(super) struct MyVec { + pub(super) vec: Vec, + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub(super) enum MyVecVersions { + V0(MyVec), + } +} + +#[test] +fn main() { + let values: [u64; 6] = [12, 23, 34, 45, 56, 67]; + let vec = values + .iter() + .map(|val| v0::MyStructInner { val: *val }) + .collect(); + let mv = v0::MyVec { vec }; + + let serialized = bincode::serialize(&mv.versionize()).unwrap(); + + let unserialized = + MyVec::::unversionize(bincode::deserialize(&serialized).unwrap()).unwrap(); + + let unser_values: Vec = unserialized.vec.iter().map(|inner| inner.val).collect(); + + assert_eq!(unser_values, values); +} diff --git a/utils/tfhe-versionable/src/lib.rs b/utils/tfhe-versionable/src/lib.rs index 1f54902eef..8975c88936 100644 --- a/utils/tfhe-versionable/src/lib.rs +++ b/utils/tfhe-versionable/src/lib.rs @@ -632,6 +632,7 @@ impl Unversionize for () { impl NotVersioned for () {} +// TODO: use a macro for more tuple sizes impl Versionize for (T, U) { type Versioned<'vers> = (T::Versioned<'vers>, U::Versioned<'vers>) where T: 'vers, U: 'vers; @@ -654,7 +655,35 @@ impl Unversionize for (T, U) { } } -impl NotVersioned for (T, U) {} +impl VersionizeSlice for (T, U) { + type VersionedSlice<'vers> = Vec<(T::Versioned<'vers>, U::Versioned<'vers>)> where T: 'vers, U: 'vers; + + fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { + slice + .iter() + .map(|(t, u)| (t.versionize(), u.versionize())) + .collect() + } +} + +impl VersionizeVec for (T, U) { + type VersionedVec = Vec<(T::VersionedOwned, U::VersionedOwned)>; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter() + .map(|(t, u)| (t.versionize_owned(), u.versionize_owned())) + .collect() + } +} + +impl UnversionizeVec for (T, U) { + fn unversionize_vec(versioned: Self::VersionedVec) -> Result, UnversionizeError> { + versioned + .into_iter() + .map(|(t, u)| Ok((T::unversionize(t)?, U::unversionize(u)?))) + .collect() + } +} impl Versionize for (T, U, V) { type Versioned<'vers> = (T::Versioned<'vers>, U::Versioned<'vers>, V::Versioned<'vers>) where T: 'vers, U: 'vers, V: 'vers; @@ -690,7 +719,47 @@ impl Unversionize for (T, U, } } -impl NotVersioned for (T, U, V) {} +impl VersionizeSlice for (T, U, V) { + type VersionedSlice<'vers> = Vec<(T::Versioned<'vers>, U::Versioned<'vers>, V::Versioned<'vers>)> where T: 'vers, U: 'vers, V: 'vers; + + fn versionize_slice(slice: &[Self]) -> Self::VersionedSlice<'_> { + slice + .iter() + .map(|(t, u, v)| (t.versionize(), u.versionize(), v.versionize())) + .collect() + } +} + +impl VersionizeVec for (T, U, V) { + type VersionedVec = Vec<(T::VersionedOwned, U::VersionedOwned, V::VersionedOwned)>; + + fn versionize_vec(vec: Vec) -> Self::VersionedVec { + vec.into_iter() + .map(|(t, u, v)| { + ( + t.versionize_owned(), + u.versionize_owned(), + v.versionize_owned(), + ) + }) + .collect() + } +} + +impl UnversionizeVec for (T, U, V) { + fn unversionize_vec(versioned: Self::VersionedVec) -> Result, UnversionizeError> { + versioned + .into_iter() + .map(|(t, u, v)| { + Ok(( + T::unversionize(t)?, + U::unversionize(u)?, + V::unversionize(v)?, + )) + }) + .collect() + } +} // converts to `Vec` for the versioned type, so we don't have to derive // Eq/Hash on it. diff --git a/utils/tfhe-versionable/src/upgrade.rs b/utils/tfhe-versionable/src/upgrade.rs index d0417af9a2..10e4747684 100644 --- a/utils/tfhe-versionable/src/upgrade.rs +++ b/utils/tfhe-versionable/src/upgrade.rs @@ -3,6 +3,6 @@ /// This trait should be implemented for each version of the original type that is not the current /// one. The upgrade method is called in chains until we get to the last version of the type. pub trait Upgrade { - type Error: std::error::Error; + type Error: std::error::Error + Send + Sync + 'static; fn upgrade(self) -> Result; } diff --git a/utils/tfhe-versionable/tests/bounds_private_in_public.rs b/utils/tfhe-versionable/tests/bounds_private_in_public.rs new file mode 100644 index 0000000000..489d6ad0fc --- /dev/null +++ b/utils/tfhe-versionable/tests/bounds_private_in_public.rs @@ -0,0 +1,44 @@ +//! This test checks that the bounds added by the proc macro does not prevent the code to +//! compile by leaking a private type +use tfhe_versionable::Versionize; + +mod mymod { + use tfhe_versionable::{Versionize, VersionsDispatch}; + + #[derive(Versionize)] + #[versionize(PublicVersions)] + pub struct Public { + private: Private, + } + + impl Public { + pub fn new(val: u64) -> Self { + Self { + private: Private(val), + } + } + } + + #[derive(VersionsDispatch)] + #[allow(unused)] + pub enum PublicVersions { + V0(Public), + } + + #[derive(Versionize)] + #[versionize(PrivateVersions)] + struct Private(T); + + #[derive(VersionsDispatch)] + #[allow(dead_code)] + enum PrivateVersions { + V0(Private), + } +} + +#[test] +fn bounds_private_in_public() { + let public = mymod::Public::new(42); + + let _vers = public.versionize(); +}