Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ns/fix/versionable bad bounds #1551

Merged
merged 7 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/aws_tfhe_fast_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
outputs:
csprng_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.csprng_any_changed }}
zk_pok_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.zk_pok_any_changed }}
versionable_test: ${{ env.IS_PULL_REQUEST == 'false' || steps.changed-files.outputs.versionable_any_changed }}
core_crypto_test: ${{ env.IS_PULL_REQUEST == 'false' ||
steps.changed-files.outputs.core_crypto_any_changed ||
steps.changed-files.outputs.dependencies_any_changed }}
Expand Down Expand Up @@ -64,10 +65,15 @@ jobs:
- tfhe/Cargo.toml
- concrete-csprng/**
- tfhe-zk-pok/**
- utils/tfhe-versionable/**
- utils/tfhe-versionable-derive/**
csprng:
- concrete-csprng/**
zk_pok:
- tfhe-zk-pok/**
versionable:
- utils/tfhe-versionable/**
- utils/tfhe-versionable-derive/**
core_crypto:
- tfhe/src/core_crypto/**
boolean:
Expand Down Expand Up @@ -103,6 +109,7 @@ jobs:
if: ( steps.changed-files.outputs.dependencies_any_changed == 'true' ||
steps.changed-files.outputs.csprng_any_changed == 'true' ||
steps.changed-files.outputs.zk_pok_any_changed == 'true' ||
steps.changed-files.outputs.versionable_any_changed == 'true' ||
steps.changed-files.outputs.core_crypto_any_changed == 'true' ||
steps.changed-files.outputs.boolean_any_changed == 'true' ||
steps.changed-files.outputs.shortint_any_changed == 'true' ||
Expand Down Expand Up @@ -167,6 +174,11 @@ jobs:
run: |
make test_zk_pok

- name: Run tfhe-versionable tests
if: needs.should-run.outputs.versionable_test == 'true'
run: |
make test_versionable

- name: Run core tests
if: needs.should-run.outputs.core_crypto_test == 'true'
run: |
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ test_zk_pok: install_rs_build_toolchain
.PHONY: test_versionable # Run tests for tfhe-versionable subcrate
test_versionable: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
-p tfhe-versionable
--all-targets -p tfhe-versionable

# The backward compat data repo holds historical binary data but also rust code to generate and load them.
# Here we use the "patch" functionality of Cargo to make sure the repo used for the data is the same as the one used for the code.
Expand Down
2 changes: 1 addition & 1 deletion tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ sha3 = { version = "0.10", optional = true }
itertools = "0.11.0"
rand_core = { version = "0.6.4", features = ["std"] }
tfhe-zk-pok = { version = "0.3.0-alpha.1", path = "../tfhe-zk-pok", optional = true }
tfhe-versionable = { version = "0.2.1", path = "../utils/tfhe-versionable" }
tfhe-versionable = { version = "0.3.0", path = "../utils/tfhe-versionable" }

# wasm deps
wasm-bindgen = { version = "0.2.86", features = [
Expand Down
2 changes: 1 addition & 1 deletion utils/tfhe-versionable-derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tfhe-versionable-derive"
version = "0.2.1"
version = "0.3.0"
edition = "2021"
keywords = ["versioning", "serialization", "encoding", "proc-macro", "derive"]
homepage = "https://zama.ai/"
Expand Down
95 changes: 54 additions & 41 deletions utils/tfhe-versionable-derive/src/associated.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::{
parse_quote, DeriveInput, ImplGenerics, Item, ItemImpl, Lifetime, Path, Type, WhereClause,
parse_quote, DeriveInput, Generics, ImplGenerics, Item, ItemImpl, Lifetime, Path, Type,
WhereClause,
};

use crate::{
add_lifetime_bound, add_trait_bound, add_trait_where_clause, add_where_lifetime_bound,
add_lifetime_bound, add_trait_where_clause, add_where_lifetime_bound, extend_where_clause,
parse_const_str, DESERIALIZE_TRAIT_NAME, LIFETIME_NAME, SERIALIZE_TRAIT_NAME,
};

Expand Down Expand Up @@ -91,6 +92,11 @@ pub(crate) enum AssociatedTypeKind {
/// [`DispatchType`]: crate::dispatch_type::DispatchType
/// [`VersionType`]: crate::dispatch_type::VersionType
pub(crate) trait AssociatedType: Sized {
/// Bounds that will be added on the fields of the ref type definition
const REF_BOUNDS: &'static [&'static str];
/// Bounds that will be added on the fields of the owned type definition
const OWNED_BOUNDS: &'static [&'static str];

/// This will create the alternative of the type that holds a reference to the underlying data
fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self>;
/// This will create the alternative of the type that owns the underlying data
Expand All @@ -99,6 +105,30 @@ pub(crate) trait AssociatedType: Sized {
/// Generates the type declaration for this type
fn generate_type_declaration(&self) -> syn::Result<Item>;

/// Returns the kind of associated type, a ref or an owned type
fn kind(&self) -> &AssociatedTypeKind;

/// Returns the generics found in the original type definition
fn orig_type_generics(&self) -> &Generics;

/// Returns the generics and bounds that should be added to the type
fn type_generics(&self) -> syn::Result<Generics> {
let mut generics = self.orig_type_generics().clone();
if let AssociatedTypeKind::Ref(opt_lifetime) = &self.kind() {
if let Some(lifetime) = opt_lifetime {
add_lifetime_bound(&mut generics, lifetime);
}
add_trait_where_clause(&mut generics, self.inner_types()?, Self::REF_BOUNDS)?;
} else {
add_trait_where_clause(&mut generics, self.inner_types()?, Self::OWNED_BOUNDS)?;
}

Ok(generics)
}

/// Generics needed for conversions between the original type and associated types
fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result<Generics>;

/// Generates conversion methods between the origin type and the associated type. If the version
/// type is a ref, the conversion is `From<&'vers OrigType> for Associated<'vers>` because this
/// conversion is used for versioning. If the version type is owned, the conversion is
Expand All @@ -109,10 +139,6 @@ pub(crate) trait AssociatedType: Sized {
/// [`Version`]: crate::dispatch_type::VersionType
fn generate_conversion(&self) -> syn::Result<Vec<ItemImpl>>;

/// The lifetime added for this type, if it is a "ref" type. It also returns None if the type is
/// a unit type (no data)
//fn lifetime(&self) -> Option<&Lifetime>;

/// The identifier used to name this type
fn ident(&self) -> Ident;

Expand Down Expand Up @@ -144,40 +170,19 @@ pub(crate) struct AssociatingTrait<T> {
owned_type: T,
orig_type: DeriveInput,
trait_path: Path,
/// Bounds that should be added to the generics for the impl
generics_bounds: Vec<String>,
/// Bounds that should be added on the struct attributes
attributes_bounds: Vec<String>,
}

impl<T: AssociatedType> AssociatingTrait<T> {
pub(crate) fn new(
orig_type: &DeriveInput,
name: &str,
generics_bounds: &[&str],
attributes_bounds: &[&str],
) -> syn::Result<Self> {
pub(crate) fn new(orig_type: &DeriveInput, name: &str) -> syn::Result<Self> {
let ref_type = T::new_ref(orig_type)?;
let owned_type = T::new_owned(orig_type)?;
let trait_path = syn::parse_str(name)?;

let generics_bounds = generics_bounds
.iter()
.map(|bound| bound.to_string())
.collect();

let attributes_bounds = attributes_bounds
.iter()
.map(|bound| bound.to_string())
.collect();

Ok(Self {
ref_type,
owned_type,
orig_type: orig_type.clone(),
trait_path,
generics_bounds,
attributes_bounds,
})
}

Expand All @@ -189,22 +194,24 @@ impl<T: AssociatedType> AssociatingTrait<T> {
let ref_ident = self.ref_type.ident();
let owned_ident = self.owned_type.ident();

let mut generics = self.orig_type.generics.clone();

for bound in &self.generics_bounds {
add_trait_bound(&mut generics, bound)?;
}

let trait_param = self.ref_type.as_trait_param().transpose()?;

let mut ref_type_generics = generics.clone();

add_trait_where_clause(
&mut generics,
self.ref_type.inner_types()?,
&self.attributes_bounds,
)?;
// AssociatedToOrig conversion always has a stricter bound than the other side so we use it
let mut generics = self
.owned_type
.conversion_generics(ConversionDirection::AssociatedToOrig)?;

// Merge the where clause for the reference type with the one from the owned type
let owned_where_clause = generics.make_where_clause();
if let Some(ref_where_clause) = self
.ref_type
.conversion_generics(ConversionDirection::AssociatedToOrig)?
.where_clause
{
extend_where_clause(owned_where_clause, &ref_where_clause);
}

let mut ref_type_generics = self.ref_type.orig_type_generics().clone();
// If the original type has some generics, we need to add a lifetime bound on them
if let Some(lifetime) = self.ref_type.lifetime() {
add_lifetime_bound(&mut ref_type_generics, lifetime);
Expand Down Expand Up @@ -246,8 +253,13 @@ impl<T: AssociatedType> AssociatingTrait<T> {
)
]};

// Creates the type declaration. These types are the output of the versioning process, so
// they should be serializable. Serde might try to add automatic bounds on the type generics
// even if we don't need them, so we use `#[serde(bound = "")]` to disable this. The bounds
// on the generated types should be sufficient.
let owned_tokens = quote! {
#[derive(#serialize_trait, #deserialize_trait)]
#[serde(bound = "")]
#ignored_lints
#owned_decla

Expand All @@ -260,6 +272,7 @@ impl<T: AssociatedType> AssociatingTrait<T> {

let ref_tokens = quote! {
#[derive(#serialize_trait)]
#[serde(bound = "")]
#ignored_lints
#ref_decla

Expand Down
71 changes: 48 additions & 23 deletions utils/tfhe-versionable-derive/src/dispatch_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use syn::{

use crate::associated::{
generate_from_trait_impl, generate_try_from_trait_impl, AssociatedType, AssociatedTypeKind,
ConversionDirection,
};
use crate::{
add_lifetime_bound, add_trait_bound, add_trait_where_clause, parse_const_str, LIFETIME_NAME,
UNVERSIONIZE_ERROR_NAME, VERSIONIZE_TRAIT_NAME, VERSION_TRAIT_NAME,
parse_const_str, LIFETIME_NAME, UNVERSIONIZE_ERROR_NAME, UPGRADE_TRAIT_NAME, VERSION_TRAIT_NAME,
};

/// This is the enum that holds all the versions of a specific type. Each variant of the enum is
Expand Down Expand Up @@ -47,6 +47,10 @@ fn derive_input_to_enum(input: &DeriveInput) -> syn::Result<ItemEnum> {
}

impl AssociatedType for DispatchType {
const REF_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME];

const OWNED_BOUNDS: &'static [&'static str] = &[VERSION_TRAIT_NAME];

fn new_ref(orig_type: &DeriveInput) -> syn::Result<Self> {
for lt in orig_type.generics.lifetimes() {
// check for collision with other lifetimes in `orig_type`
Expand Down Expand Up @@ -93,21 +97,49 @@ impl AssociatedType for DispatchType {

Ok(ItemEnum {
ident: self.ident(),
generics: self.generics()?,
generics: self.type_generics()?,
attrs: vec![parse_quote! { #[automatically_derived] }],
variants: variants?,
..self.orig_type.clone()
}
.into())
}

fn generate_conversion(&self) -> syn::Result<Vec<ItemImpl>> {
let generics = self.generics()?;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
fn kind(&self) -> &AssociatedTypeKind {
&self.kind
}

fn orig_type_generics(&self) -> &Generics {
&self.orig_type.generics
}

fn conversion_generics(&self, direction: ConversionDirection) -> syn::Result<Generics> {
let mut generics = self.type_generics()?;
let preds = &mut generics.make_where_clause().predicates;

let upgrade_trait: Path = parse_const_str(UPGRADE_TRAIT_NAME);

if let ConversionDirection::AssociatedToOrig = direction {
if let AssociatedTypeKind::Owned = &self.kind {
// Add a bound for each version to be upgradable into the next one
for src_idx in 0..(self.versions_count() - 1) {
let src_ty = self.version_type_at(src_idx)?;
let next_ty = self.version_type_at(src_idx + 1)?;
preds.push(parse_quote! { #src_ty: #upgrade_trait<#next_ty> })
}
}
}

Ok(generics)
}

fn generate_conversion(&self) -> syn::Result<Vec<ItemImpl>> {
match &self.kind {
AssociatedTypeKind::Ref(lifetime) => {
// Wraps the highest version into the dispatch enum
let generics = self.conversion_generics(ConversionDirection::OrigToAssociated)?;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let src_type = self.latest_version_type()?;
let src = parse_quote! { &#lifetime #src_type };
let dest_ident = self.ident();
Expand All @@ -126,6 +158,9 @@ impl AssociatedType for DispatchType {
}
AssociatedTypeKind::Owned => {
// Upgrade to the highest version the convert to the main type
let generics = self.conversion_generics(ConversionDirection::AssociatedToOrig)?;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let src_ident = self.ident();
let src = parse_quote! { #src_ident #ty_generics };
let dest_type = self.latest_version_type()?;
Expand All @@ -144,6 +179,9 @@ impl AssociatedType for DispatchType {
)?;

// Wraps the highest version into the dispatch enum
let generics = self.conversion_generics(ConversionDirection::OrigToAssociated)?;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let src_type = self.latest_version_type()?;
let src = parse_quote! { #src_type };
let dest_ident = self.ident();
Expand Down Expand Up @@ -182,13 +220,13 @@ impl AssociatedType for DispatchType {
}
}

fn as_trait_param(&self) -> Option<syn::Result<&Type>> {
Some(self.latest_version_type())
}

fn inner_types(&self) -> syn::Result<Vec<&Type>> {
self.version_types()
}

fn as_trait_param(&self) -> Option<syn::Result<&Type>> {
Some(self.latest_version_type())
}
}

impl DispatchType {
Expand All @@ -200,19 +238,6 @@ impl DispatchType {
)
}

fn generics(&self) -> syn::Result<Generics> {
let mut generics = self.orig_type.generics.clone();
if let AssociatedTypeKind::Ref(Some(lifetime)) = &self.kind {
add_lifetime_bound(&mut generics, lifetime);
}

add_trait_where_clause(&mut generics, self.inner_types()?, &[VERSION_TRAIT_NAME])?;

add_trait_bound(&mut generics, VERSIONIZE_TRAIT_NAME)?;

Ok(generics)
}

/// Returns the number of versions in this dispatch enum
fn versions_count(&self) -> usize {
self.orig_type.variants.len()
Expand Down
Loading
Loading