Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(versionable): Handle generics in NotVersioned #1995

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions utils/tfhe-versionable-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeE

pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";
pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";
pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
Expand Down Expand Up @@ -316,14 +317,26 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);

let mut generics = input.generics.clone();
// Versionize needs T to impl Serialize
let mut versionize_generics = input.generics.clone();
syn_unwrap!(add_trait_where_clause(
&mut generics,
&mut versionize_generics,
&[parse_quote! { Self }],
&["Clone"]
&[SERIALIZE_TRAIT_NAME]
));

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
// VersionizeOwned needs T to impl Serialize and DeserializeOwned
let mut versionize_owned_generics = input.generics.clone();
syn_unwrap!(add_trait_where_clause(
&mut versionize_owned_generics,
&[parse_quote! { Self }],
&[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME]
));

let (impl_generics, ty_generics, versionize_where_clause) =
versionize_generics.split_for_impl();
let (_, _, versionize_owned_where_clause) = versionize_owned_generics.split_for_impl();

let input_ident = &input.ident;

let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
Expand All @@ -334,16 +347,16 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {

quote! {
#[automatically_derived]
impl #impl_generics #versionize_trait for #input_ident #ty_generics #where_clause {
type Versioned<#lifetime> = &#lifetime Self;
impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause {
type Versioned<#lifetime> = &#lifetime Self where Self: 'vers;

fn versionize(&self) -> Self::Versioned<'_> {
self
}
}

#[automatically_derived]
impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #where_clause {
impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #versionize_owned_where_clause {
type VersionedOwned = Self;

fn versionize_owned(self) -> Self::VersionedOwned {
Expand All @@ -352,14 +365,14 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
}

#[automatically_derived]
impl #impl_generics #unversionize_trait for #input_ident #ty_generics #where_clause {
impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
Ok(versioned)
}
}

#[automatically_derived]
impl NotVersioned for #input_ident #ty_generics #where_clause {}
impl #impl_generics NotVersioned for #input_ident #ty_generics #versionize_owned_where_clause {}

}
.into()
Expand Down
6 changes: 3 additions & 3 deletions utils/tfhe-versionable/examples/not_versioned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use serde::{Deserialize, Serialize};
use tfhe_versionable::{NotVersioned, Versionize, VersionsDispatch};

#[derive(Clone, Serialize, Deserialize, NotVersioned)]
struct MyStructNotVersioned {
val: u32,
struct MyStructNotVersioned<Inner> {
val: Inner,
}

#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct {
inner: MyStructNotVersioned,
inner: MyStructNotVersioned<u32>,
}

#[derive(VersionsDispatch)]
Expand Down
Loading