Skip to content

Commit

Permalink
feat: Attributes forwarding to message enums and fields (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
kulikthebird authored Jun 28, 2024
1 parent 90af7d0 commit 55e9667
Show file tree
Hide file tree
Showing 15 changed files with 452 additions and 119 deletions.
50 changes: 47 additions & 3 deletions sylvia-derive/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::associated_types::{AssociatedTypes, ItemType, EXEC_TYPE, QUERY_TYPE};
use crate::check_generics::{CheckGenerics, GetPath};
use crate::crate_module;
use crate::interfaces::Interfaces;
use crate::parser::attributes::{MsgAttrForwarding, VariantAttrForwarding};
use crate::parser::{
parse_associated_custom_type, ContractErrorAttr, Custom, EntryPointArgs,
FilteredOverrideEntryPoints, MsgAttr, MsgType, OverrideEntryPoint, ParsedSylviaAttributes,
Expand Down Expand Up @@ -38,6 +39,7 @@ pub struct StructMessage<'a> {
result: &'a ReturnType,
msg_attr: MsgAttr,
custom: &'a Custom,
msg_attrs_to_forward: Vec<MsgAttrForwarding>,
}

impl<'a> StructMessage<'a> {
Expand All @@ -60,6 +62,12 @@ impl<'a> StructMessage<'a> {
let (used_generics, unused_generics) = generics_checker.used_unused();
let wheres = filter_wheres(&source.generics.where_clause, generics, &used_generics);

let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
.msg_attrs_forward
.into_iter()
.filter(|attr| attr.msg_type == ty)
.collect();

Some(Self {
contract_type,
fields,
Expand All @@ -71,6 +79,7 @@ impl<'a> StructMessage<'a> {
result: &method.sig.output,
msg_attr,
custom,
msg_attrs_to_forward,
})
}

Expand Down Expand Up @@ -111,9 +120,12 @@ impl<'a> StructMessage<'a> {
pub fn emit(&self) -> TokenStream {
use MsgAttr::*;

let instantiate_msg = Ident::new("InstantiateMsg", self.function_name.span());
let migrate_msg = Ident::new("MigrateMsg", self.function_name.span());

match &self.msg_attr {
Instantiate { name } => self.emit_struct(name),
Migrate { name } => self.emit_struct(name),
Instantiate { .. } => self.emit_struct(&instantiate_msg),
Migrate { .. } => self.emit_struct(&migrate_msg),
_ => {
emit_error!(Span::mixed_site(), "Invalid message type");
quote! {}
Expand All @@ -135,6 +147,7 @@ impl<'a> StructMessage<'a> {
result,
msg_attr,
custom,
msg_attrs_to_forward,
} = self;

let ctx_type = msg_attr
Expand All @@ -151,10 +164,12 @@ impl<'a> StructMessage<'a> {
let where_clause = as_where_clause(wheres);
let generics = emit_bracketed_generics(generics);
let unused_generics = emit_bracketed_generics(unused_generics);
let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);

quote! {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema)]
#( #[ #msg_attrs_to_forward ] )*
#[serde(rename_all="snake_case")]
pub struct #name #generics {
#(pub #fields,)*
Expand Down Expand Up @@ -184,6 +199,7 @@ pub struct EnumMessage<'a> {
msg_ty: MsgType,
resp_type: Type,
query_type: Type,
msg_attrs_to_forward: Vec<MsgAttrForwarding>,
}

impl<'a> EnumMessage<'a> {
Expand All @@ -209,13 +225,20 @@ impl<'a> EnumMessage<'a> {
.or(associated_query)
.unwrap_or_else(Custom::default_type);

let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
.msg_attrs_forward
.into_iter()
.filter(|attr| attr.msg_type == msg_ty)
.collect();

Self {
source,
variants,
associated_types,
msg_ty,
resp_type,
query_type,
msg_attrs_to_forward,
}
}

Expand All @@ -227,6 +250,7 @@ impl<'a> EnumMessage<'a> {
msg_ty,
resp_type,
query_type,
msg_attrs_to_forward,
} = self;

let trait_name = &source.ident;
Expand Down Expand Up @@ -259,10 +283,12 @@ impl<'a> EnumMessage<'a> {
let ep_name = msg_ty.emit_ep_name();
let messages_fn_name = Ident::new(&format!("{}_messages", ep_name), enum_name.span());
let derive_call = msg_ty.emit_derive_call();
let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);

quote! {
#[allow(clippy::derive_partial_eq_without_eq)]
#derive_call
#( #[ #msg_attrs_to_forward ] )*
#[serde(rename_all="snake_case")]
pub enum #unique_enum_name #bracketed_used_generics {
#(#msg_variants,)*
Expand Down Expand Up @@ -302,6 +328,7 @@ pub struct ContractEnumMessage<'a> {
error: &'a ContractErrorAttr,
custom: &'a Custom,
where_clause: &'a Option<WhereClause>,
msg_attrs_to_forward: Vec<MsgAttrForwarding>,
}

impl<'a> ContractEnumMessage<'a> {
Expand All @@ -314,6 +341,11 @@ impl<'a> ContractEnumMessage<'a> {
) -> Self {
let where_clause = &source.generics.where_clause;
let variants = MsgVariants::new(source.as_variants(), msg_ty, generics, where_clause);
let msg_attrs_to_forward = ParsedSylviaAttributes::new(source.attrs.iter())
.msg_attrs_forward
.into_iter()
.filter(|attr| attr.msg_type == msg_ty)
.collect();

Self {
variants,
Expand All @@ -322,6 +354,7 @@ impl<'a> ContractEnumMessage<'a> {
error,
custom,
where_clause,
msg_attrs_to_forward,
}
}

Expand All @@ -335,6 +368,7 @@ impl<'a> ContractEnumMessage<'a> {
error,
custom,
where_clause,
msg_attrs_to_forward,
..
} = self;

Expand Down Expand Up @@ -369,10 +403,12 @@ impl<'a> ContractEnumMessage<'a> {
},
false => quote! {},
};
let msg_attrs_to_forward = msg_attrs_to_forward.iter().map(|attr| &attr.attrs);

quote! {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(#sylvia ::serde::Serialize, #sylvia ::serde::Deserialize, Clone, Debug, PartialEq, #sylvia ::schemars::JsonSchema, #derive_query )]
#( #[ #msg_attrs_to_forward ] )*
#[serde(rename_all="snake_case")]
pub enum #enum_name #bracketed_used_generics {
#(#variants,)*
Expand Down Expand Up @@ -409,6 +445,7 @@ pub struct MsgVariant<'a> {
/// `returns` attribute.
return_type: Option<Type>,
msg_type: MsgType,
attrs_to_forward: Vec<VariantAttrForwarding>,
}

impl<'a> MsgVariant<'a> {
Expand All @@ -417,6 +454,7 @@ impl<'a> MsgVariant<'a> {
sig: &'a Signature,
generics_checker: &mut CheckGenerics<Generic>,
msg_attr: MsgAttr,
attrs_to_forward: Vec<VariantAttrForwarding>,
) -> MsgVariant<'a>
where
Generic: GetPath + PartialEq,
Expand All @@ -427,7 +465,7 @@ impl<'a> MsgVariant<'a> {
let fields = process_fields(sig, generics_checker);
let msg_type = msg_attr.msg_type();

let return_type = if let MsgAttr::Query { resp_type } = msg_attr {
let return_type = if let MsgAttr::Query { resp_type, .. } = msg_attr {
match resp_type {
Some(resp_type) => {
let resp_type = parse_quote! { #resp_type };
Expand All @@ -451,6 +489,7 @@ impl<'a> MsgVariant<'a> {
fields,
return_type,
msg_type,
attrs_to_forward,
}
}

Expand All @@ -461,13 +500,16 @@ impl<'a> MsgVariant<'a> {
fields,
msg_type,
return_type,
attrs_to_forward,
..
} = self;
let fields = fields.iter().map(MsgField::emit);
let returns_attribute = msg_type.emit_returns_attribute(return_type);
let attrs_to_forward = attrs_to_forward.iter().map(|attr| &attr.attrs);

quote! {
#returns_attribute
#( #[ #attrs_to_forward ] )*
#name {
#(#fields,)*
}
Expand Down Expand Up @@ -570,6 +612,7 @@ where
let variants: Vec<_> = source
.filter_map(|variant_desc| {
let msg_attr: MsgAttr = variant_desc.attr_msg()?;
let attrs_to_forward = variant_desc.attrs_to_forward();

if msg_attr.msg_type() != msg_ty {
return None;
Expand All @@ -579,6 +622,7 @@ where
variant_desc.into_sig(),
&mut generics_checker,
msg_attr,
attrs_to_forward,
))
})
.collect();
Expand Down
71 changes: 71 additions & 0 deletions sylvia-derive/src/parser/attributes/attr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use proc_macro2::{Span, TokenStream};
use proc_macro_error::emit_error;
use syn::parse::{Error, Parse, ParseStream, Parser};
use syn::spanned::Spanned;
use syn::{Ident, MetaList, Result, Token};

use super::MsgType;

#[derive(Clone, Debug)]
pub struct VariantAttrForwarding {
pub attrs: TokenStream,
pub span: Span,
}

impl VariantAttrForwarding {
pub fn new(attr: &MetaList) -> Self {
VariantAttrForwarding {
attrs: attr.tokens.clone(),
span: attr.span(),
}
}
}

#[derive(Clone, Debug)]
pub struct MsgAttrForwarding {
pub msg_type: MsgType,
pub attrs: TokenStream,
}

impl MsgAttrForwarding {
pub fn new(attr: &MetaList) -> Result<Self> {
MsgAttrForwarding::parse
.parse2(attr.tokens.clone())
.map_err(|err| {
emit_error!(attr.tokens.span(), err);
err
})
}
}

impl Parse for MsgAttrForwarding {
fn parse(input: ParseStream) -> Result<Self> {
let error_msg =
"Expected attribute of the form: `#[sv::msg_attr(msg_type, attribute_to_forward)]`";
let msg_type: Ident = input
.parse()
.map_err(|err| Error::new(err.span(), error_msg))?;
let _: Token![,] = input
.parse()
.map_err(|err| Error::new(err.span(), error_msg))?;
let attrs: TokenStream = input
.parse()
.map_err(|err| Error::new(err.span(), error_msg))?;
if attrs.is_empty() {
return Err(Error::new(attrs.span(), error_msg));
}
let msg_type = match msg_type.to_string().as_str() {
"exec" => MsgType::Exec,
"query" => MsgType::Query,
"instantiate" => MsgType::Instantiate,
"migrate" => MsgType::Migrate,
"reply" => MsgType::Reply,
"sudo" => MsgType::Sudo,
_ => return Err(Error::new(
msg_type.span(),
"Invalid message type, expected one of: `exec`, `query`, `instantiate`, `migrate`, `reply` or `sudo`.",
))
};
Ok(Self { msg_type, attrs })
}
}
36 changes: 17 additions & 19 deletions sylvia-derive/src/parser/attributes/custom.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use proc_macro_error::emit_error;
use syn::parse::{Parse, ParseStream, Parser};
use syn::spanned::Spanned;
use syn::{parse_quote, Attribute, Ident, Result, Token, Type};
use syn::{parse_quote, Error, Ident, MetaList, Result, Token, Type};

use crate::crate_module;

Expand All @@ -12,14 +11,11 @@ pub struct Custom {
}

impl Custom {
pub fn new(attr: &Attribute) -> Result<Self> {
attr.meta
.require_list()
.and_then(|meta| Custom::parse.parse2(meta.tokens.clone()))
.map_err(|err| {
emit_error!(attr.span(), err);
err
})
pub fn new(attr: &MetaList) -> Result<Self> {
Custom::parse.parse2(attr.tokens.clone()).map_err(|err| {
emit_error!(err.span(), err);
err
})
}

pub fn msg_or_default(&self) -> Type {
Expand All @@ -43,15 +39,17 @@ impl Parse for Custom {
while !input.is_empty() {
let ty: Ident = input.parse()?;
let _: Token![=] = input.parse()?;
if ty == "msg" {
custom.msg = Some(input.parse()?)
} else if ty == "query" {
custom.query = Some(input.parse()?)
} else {
emit_error!(ty.span(), "Invalid custom type.";
note = ty.span() => "Expected `#[sv::custom(msg=SomeMsg, query=SomeQuery)]`"
);
};
match ty.to_string().as_str() {
"msg" => custom.msg = Some(input.parse()?),
"query" => custom.query = Some(input.parse()?),
_ => {
return Err(Error::new(
ty.span(),
"Invalid custom type.\n
= note: Expected `#[sv::custom(msg=SomeMsg, query=SomeQuery)]`.\n",
))
}
}
if !input.peek(Token![,]) {
break;
}
Expand Down
12 changes: 5 additions & 7 deletions sylvia-derive/src/parser/attributes/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use proc_macro_error::emit_error;
use syn::parse::{Parse, ParseStream, Parser};
use syn::spanned::Spanned;
use syn::{parse_quote, Attribute, Result, Type};
use syn::{parse_quote, MetaList, Result, Type};

use crate::crate_module;

Expand All @@ -20,12 +19,11 @@ impl Default for ContractErrorAttr {
}

impl ContractErrorAttr {
pub fn new(attr: &Attribute) -> Result<Self> {
attr.meta
.require_list()
.and_then(|meta| ContractErrorAttr::parse.parse2(meta.tokens.clone()))
pub fn new(attr: &MetaList) -> Result<Self> {
ContractErrorAttr::parse
.parse2(attr.tokens.clone())
.map_err(|err| {
emit_error!(attr.span(), err);
emit_error!(err.span(), err);
err
})
}
Expand Down
Loading

0 comments on commit 55e9667

Please sign in to comment.