diff --git a/crates/burn-derive/src/module/codegen_enum.rs b/crates/burn-derive/src/module/codegen_enum.rs index 00a81edd15..81e836a39c 100644 --- a/crates/burn-derive/src/module/codegen_enum.rs +++ b/crates/burn-derive/src/module/codegen_enum.rs @@ -2,9 +2,11 @@ use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen}; use crate::shared::enum_variant::{parse_variants, EnumVariant}; use proc_macro2::{Ident, TokenStream}; use quote::quote; +use syn::Visibility; pub(crate) struct EnumModuleCodegen { pub variants: Vec, + pub vis: Visibility, } impl ModuleCodegen for EnumModuleCodegen { @@ -157,7 +159,7 @@ impl ModuleCodegen for EnumModuleCodegen { } fn record_codegen(self) -> Self::RecordCodegen { - EnumModuleRecordCodegen::new(self.variants) + EnumModuleRecordCodegen::new(self.variants, self.vis) } } @@ -165,6 +167,7 @@ impl EnumModuleCodegen { pub fn from_ast(ast: &syn::DeriveInput) -> Self { Self { variants: parse_variants(ast), + vis: ast.vis.clone(), } } diff --git a/crates/burn-derive/src/module/codegen_struct.rs b/crates/burn-derive/src/module/codegen_struct.rs index ae3e5f25e6..98ae6035a6 100644 --- a/crates/burn-derive/src/module/codegen_struct.rs +++ b/crates/burn-derive/src/module/codegen_struct.rs @@ -2,9 +2,11 @@ use super::{codegen::ModuleCodegen, record_struct::StructModuleRecordCodegen}; use crate::shared::field::{parse_fields, FieldTypeAnalyzer}; use proc_macro2::{Ident, TokenStream}; use quote::quote; +use syn::Visibility; pub(crate) struct StructModuleCodegen { pub fields: Vec, + pub vis: Visibility, } impl ModuleCodegen for StructModuleCodegen { @@ -182,7 +184,7 @@ impl ModuleCodegen for StructModuleCodegen { } fn record_codegen(self) -> Self::RecordCodegen { - StructModuleRecordCodegen::new(self.fields) + StructModuleRecordCodegen::new(self.fields, self.vis) } } @@ -193,6 +195,7 @@ impl StructModuleCodegen { .into_iter() .map(FieldTypeAnalyzer::new) .collect(), + vis: ast.vis.clone(), } } diff --git a/crates/burn-derive/src/module/record_enum.rs b/crates/burn-derive/src/module/record_enum.rs index d4c0035320..dda1911735 100644 --- a/crates/burn-derive/src/module/record_enum.rs +++ b/crates/burn-derive/src/module/record_enum.rs @@ -1,18 +1,20 @@ use crate::shared::enum_variant::EnumVariant; use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::Generics; +use syn::{Generics, Visibility}; use super::record::ModuleRecordCodegen; #[derive(new)] pub(crate) struct EnumModuleRecordCodegen { variants: Vec, + vis: Visibility, } impl ModuleRecordCodegen for EnumModuleRecordCodegen { fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { let mut variants = quote! {}; + let vis = &self.vis; // Capture the Record enum variant types for variant in self.variants.iter() { @@ -31,7 +33,7 @@ impl ModuleRecordCodegen for EnumModuleRecordCodegen { /// The record type for the module. #[derive(burn::record::Record)] - pub enum #record_name #generics #generics_where { + #vis enum #record_name #generics #generics_where { #variants } } diff --git a/crates/burn-derive/src/module/record_struct.rs b/crates/burn-derive/src/module/record_struct.rs index b54dfdaee7..0f4af32059 100644 --- a/crates/burn-derive/src/module/record_struct.rs +++ b/crates/burn-derive/src/module/record_struct.rs @@ -1,18 +1,20 @@ use crate::shared::field::FieldTypeAnalyzer; use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::Generics; +use syn::{Generics, Visibility}; use super::record::ModuleRecordCodegen; #[derive(new)] pub(crate) struct StructModuleRecordCodegen { fields: Vec, + vis: Visibility, } impl ModuleRecordCodegen for StructModuleRecordCodegen { fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> TokenStream { let mut fields = quote! {}; + let vis = &self.vis; for field in self.fields.iter() { let ty = &field.field.ty; @@ -20,7 +22,7 @@ impl ModuleRecordCodegen for StructModuleRecordCodegen { fields.extend(quote! { /// The module record associative type. - pub #name: <#ty as burn::module::Module>::Record, + #vis #name: <#ty as burn::module::Module>::Record, }); } @@ -30,7 +32,7 @@ impl ModuleRecordCodegen for StructModuleRecordCodegen { /// The record type for the module. #[derive(burn::record::Record)] - pub struct #record_name #generics #generics_where { + #vis struct #record_name #generics #generics_where { #fields } } diff --git a/crates/burn-derive/src/record/item/codegen_enum.rs b/crates/burn-derive/src/record/item/codegen_enum.rs index 112888a34d..58e4d28924 100644 --- a/crates/burn-derive/src/record/item/codegen_enum.rs +++ b/crates/burn-derive/src/record/item/codegen_enum.rs @@ -1,19 +1,21 @@ use crate::shared::enum_variant::{parse_variants, EnumVariant}; use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::{parse_quote, Generics}; +use syn::{parse_quote, Generics, Visibility}; use super::codegen::RecordItemCodegen; pub(crate) struct EnumRecordItemCodegen { /// Enum variants. variants: Vec, + vis: Visibility, } impl RecordItemCodegen for EnumRecordItemCodegen { fn from_ast(ast: &syn::DeriveInput) -> Self { Self { variants: parse_variants(ast), + vis: ast.vis.clone(), } } @@ -25,6 +27,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen { ) -> TokenStream { let mut variants = quote! {}; let mut bounds = quote! {}; + let vis = &self.vis; // Capture the Record enum variant types and names to transpose them in RecordItem for variant in self.variants.iter() { @@ -62,7 +65,7 @@ impl RecordItemCodegen for EnumRecordItemCodegen { #[derive(burn::serde::Serialize, burn::serde::Deserialize)] #[serde(crate = "burn::serde")] #[serde(bound = #bound)] - pub enum #item_name #generics #generics_where { + #vis enum #item_name #generics #generics_where { #variants } } diff --git a/crates/burn-derive/src/record/item/codegen_struct.rs b/crates/burn-derive/src/record/item/codegen_struct.rs index de1ffd23a3..9f790561a3 100644 --- a/crates/burn-derive/src/record/item/codegen_struct.rs +++ b/crates/burn-derive/src/record/item/codegen_struct.rs @@ -1,12 +1,13 @@ use crate::shared::field::{parse_fields, FieldTypeAnalyzer}; use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::{parse_quote, Generics}; +use syn::{parse_quote, Generics, Visibility}; use super::codegen::RecordItemCodegen; pub(crate) struct StructRecordItemCodegen { fields: Vec, + vis: Visibility, } impl RecordItemCodegen for StructRecordItemCodegen { @@ -16,6 +17,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { .into_iter() .map(FieldTypeAnalyzer::new) .collect(), + vis: ast.vis.clone(), } } @@ -27,6 +29,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { ) -> TokenStream { let mut fields = quote! {}; let mut bounds = quote! {}; + let vis = &self.vis; for field in self.fields.iter() { let ty = &field.field.ty; @@ -60,7 +63,7 @@ impl RecordItemCodegen for StructRecordItemCodegen { #[derive(burn::serde::Serialize, burn::serde::Deserialize)] #[serde(crate = "burn::serde")] #[serde(bound = #bound)] - pub struct #item_name #generics #generics_where { + #vis struct #item_name #generics #generics_where { #fields } }