Skip to content

Commit

Permalink
Module derive types should inherit visibility (tracel-ai#2610)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Dec 13, 2024
1 parent ebd7649 commit 9d355ef
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 11 deletions.
5 changes: 4 additions & 1 deletion crates/burn-derive/src/module/codegen_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};
use crate::shared::enum_variant::{parse_variants, EnumVariant};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Visibility;

pub(crate) struct EnumModuleCodegen {
pub variants: Vec<EnumVariant>,
pub vis: Visibility,
}

impl ModuleCodegen for EnumModuleCodegen {
Expand Down Expand Up @@ -157,14 +159,15 @@ impl ModuleCodegen for EnumModuleCodegen {
}

fn record_codegen(self) -> Self::RecordCodegen {
EnumModuleRecordCodegen::new(self.variants)
EnumModuleRecordCodegen::new(self.variants, self.vis)
}
}

impl EnumModuleCodegen {
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
variants: parse_variants(ast),
vis: ast.vis.clone(),
}
}

Expand Down
5 changes: 4 additions & 1 deletion crates/burn-derive/src/module/codegen_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ use super::{codegen::ModuleCodegen, record_struct::StructModuleRecordCodegen};
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Visibility;

pub(crate) struct StructModuleCodegen {
pub fields: Vec<FieldTypeAnalyzer>,
pub vis: Visibility,
}

impl ModuleCodegen for StructModuleCodegen {
Expand Down Expand Up @@ -182,7 +184,7 @@ impl ModuleCodegen for StructModuleCodegen {
}

fn record_codegen(self) -> Self::RecordCodegen {
StructModuleRecordCodegen::new(self.fields)
StructModuleRecordCodegen::new(self.fields, self.vis)
}
}

Expand All @@ -193,6 +195,7 @@ impl StructModuleCodegen {
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
vis: ast.vis.clone(),
}
}

Expand Down
6 changes: 4 additions & 2 deletions crates/burn-derive/src/module/record_enum.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
use crate::shared::enum_variant::EnumVariant;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Generics;
use syn::{Generics, Visibility};

use super::record::ModuleRecordCodegen;

#[derive(new)]
pub(crate) struct EnumModuleRecordCodegen {
variants: Vec<EnumVariant>,
vis: Visibility,
}

impl ModuleRecordCodegen for EnumModuleRecordCodegen {
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
let mut variants = quote! {};
let vis = &self.vis;

// Capture the Record enum variant types
for variant in self.variants.iter() {
Expand All @@ -31,7 +33,7 @@ impl ModuleRecordCodegen for EnumModuleRecordCodegen {

/// The record type for the module.
#[derive(burn::record::Record)]
pub enum #record_name #generics #generics_where {
#vis enum #record_name #generics #generics_where {
#variants
}
}
Expand Down
8 changes: 5 additions & 3 deletions crates/burn-derive/src/module/record_struct.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
use crate::shared::field::FieldTypeAnalyzer;
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::Generics;
use syn::{Generics, Visibility};

use super::record::ModuleRecordCodegen;

#[derive(new)]
pub(crate) struct StructModuleRecordCodegen {
fields: Vec<FieldTypeAnalyzer>,
vis: Visibility,
}

impl ModuleRecordCodegen for StructModuleRecordCodegen {
fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream {
let mut fields = quote! {};
let vis = &self.vis;

for field in self.fields.iter() {
let ty = &field.field.ty;
let name = &field.field.ident;

fields.extend(quote! {
/// The module record associative type.
pub #name: <#ty as burn::module::Module<B>>::Record,
#vis #name: <#ty as burn::module::Module<B>>::Record,
});
}

Expand All @@ -30,7 +32,7 @@ impl ModuleRecordCodegen for StructModuleRecordCodegen {

/// The record type for the module.
#[derive(burn::record::Record)]
pub struct #record_name #generics #generics_where {
#vis struct #record_name #generics #generics_where {
#fields
}
}
Expand Down
7 changes: 5 additions & 2 deletions crates/burn-derive/src/record/item/codegen_enum.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use crate::shared::enum_variant::{parse_variants, EnumVariant};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use syn::{parse_quote, Generics, Visibility};

use super::codegen::RecordItemCodegen;

pub(crate) struct EnumRecordItemCodegen {
/// Enum variants.
variants: Vec<EnumVariant>,
vis: Visibility,
}

impl RecordItemCodegen for EnumRecordItemCodegen {
fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
variants: parse_variants(ast),
vis: ast.vis.clone(),
}
}

Expand All @@ -25,6 +27,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen {
) -> TokenStream {
let mut variants = quote! {};
let mut bounds = quote! {};
let vis = &self.vis;

// Capture the Record enum variant types and names to transpose them in RecordItem
for variant in self.variants.iter() {
Expand Down Expand Up @@ -62,7 +65,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen {
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#[serde(bound = #bound)]
pub enum #item_name #generics #generics_where {
#vis enum #item_name #generics #generics_where {
#variants
}
}
Expand Down
7 changes: 5 additions & 2 deletions crates/burn-derive/src/record/item/codegen_struct.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{parse_quote, Generics};
use syn::{parse_quote, Generics, Visibility};

use super::codegen::RecordItemCodegen;

pub(crate) struct StructRecordItemCodegen {
fields: Vec<FieldTypeAnalyzer>,
vis: Visibility,
}

impl RecordItemCodegen for StructRecordItemCodegen {
Expand All @@ -16,6 +17,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
vis: ast.vis.clone(),
}
}

Expand All @@ -27,6 +29,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
) -> TokenStream {
let mut fields = quote! {};
let mut bounds = quote! {};
let vis = &self.vis;

for field in self.fields.iter() {
let ty = &field.field.ty;
Expand Down Expand Up @@ -60,7 +63,7 @@ impl RecordItemCodegen for StructRecordItemCodegen {
#[derive(burn::serde::Serialize, burn::serde::Deserialize)]
#[serde(crate = "burn::serde")]
#[serde(bound = #bound)]
pub struct #item_name #generics #generics_where {
#vis struct #item_name #generics #generics_where {
#fields
}
}
Expand Down

0 comments on commit 9d355ef

Please sign in to comment.