diff --git a/command_attr/src/impl_check/mod.rs b/command_attr/src/impl_check/mod.rs index d5e5656..c39d7e8 100644 --- a/command_attr/src/impl_check/mod.rs +++ b/command_attr/src/impl_check/mod.rs @@ -19,10 +19,11 @@ pub fn impl_check(attr: TokenStream, input: TokenStream) -> Result parse2::(attr)?.value() }; - let (_, _, data, error) = utils::parse_generics(&fun.sig)?; + let (ctx_param, _msg_param) = utils::get_first_two_parameters(&fun.sig)?; let options = Options::parse(&mut fun.attrs)?; - let builder_fn = builder_fn(&data, &error, &mut fun, &name, &options); + let ctx_ty = *utils::get_pat_type(&ctx_param)?.ty.clone(); + let builder_fn = builder_fn(&ctx_ty, &mut fun, &name, &options); let hook_macro = paths::hook_macro(); @@ -37,13 +38,7 @@ pub fn impl_check(attr: TokenStream, input: TokenStream) -> Result Ok(result) } -fn builder_fn( - data: &Type, - error: &Type, - function: &mut ItemFn, - name: &str, - options: &Options, -) -> TokenStream { +fn builder_fn(ctx_ty: &Type, function: &mut ItemFn, name: &str, options: &Options) -> TokenStream { // Derive the name of the builder from the check function. // Prepend the check function's name with an underscore to avoid name // collisions. @@ -52,7 +47,7 @@ fn builder_fn( function.sig.ident = function_name.clone(); let check_builder = paths::check_builder_type(); - let check = paths::check_type(data, error); + let check = paths::check_type(ctx_ty); let vis = &function.vis; let external = &function.attrs; diff --git a/command_attr/src/impl_command/mod.rs b/command_attr/src/impl_command/mod.rs index 3d5044c..cb77b58 100644 --- a/command_attr/src/impl_command/mod.rs +++ b/command_attr/src/impl_command/mod.rs @@ -19,12 +19,20 @@ pub fn impl_command(attr: TokenStream, input: TokenStream) -> Result(attr)?.0 }; - let (ctx_name, msg_name, data, error) = utils::parse_generics(&fun.sig)?; + let (ctx_param, msg_param) = utils::get_first_two_parameters(&fun.sig)?; + let ctx_param = utils::get_pat_type(&ctx_param)?.clone(); + let msg_param = utils::get_pat_type(&msg_param)?.clone(); + let options = Options::parse(&mut fun.attrs)?; - parse_arguments(ctx_name, msg_name, &mut fun, &options)?; + inject_argument_parsing_code( + utils::get_ident(&ctx_param.pat)?, + utils::get_ident(&msg_param.pat)?, + &mut fun, + &options, + )?; - let builder_fn = builder_fn(&data, &error, &mut fun, names, &options); + let builder_fn = builder_fn(&ctx_param.ty, &mut fun, names, &options); let hook_macro = paths::hook_macro(); @@ -39,9 +47,10 @@ pub fn impl_command(attr: TokenStream, input: TokenStream) -> Result, options: &Options, @@ -57,7 +66,7 @@ fn builder_fn( function.sig.ident = function_name.clone(); let command_builder = paths::command_builder_type(); - let command = paths::command_type(data, error); + let command = paths::command_type(ctx_ty); let vis = &function.vis; let external = &function.attrs; @@ -74,52 +83,56 @@ fn builder_fn( } } -fn parse_arguments( +// Injects a code block at the top of the user-written command function that parses the command +// parameters from the FrameworkContext and Message function arguments +fn inject_argument_parsing_code( ctx_name: Ident, msg_name: Ident, function: &mut ItemFn, options: &Options, ) -> Result<()> { let mut arguments = Vec::new(); - while function.sig.inputs.len() > 2 { let argument = function.sig.inputs.pop().unwrap().into_value(); arguments.push(Argument::new(argument)?); } - if !arguments.is_empty() { - arguments.reverse(); + // If this command has no parameters, don't bother injecting any parsing code + if arguments.is_empty() { + return Ok(()); + } - check_arguments(&arguments)?; + arguments.reverse(); - let delimiter = options.delimiter.as_ref().map_or(" ", String::as_str); - let asegsty = paths::argument_segments_type(); + validate_arguments_order(&arguments)?; - let b = &function.block; + let delimiter = options.delimiter.as_ref().map_or(" ", String::as_str); + let asegsty = paths::argument_segments_type(); - let argument_names = arguments.iter().map(|arg| &arg.name).collect::>(); - let argument_tys = arguments.iter().map(|arg| &arg.ty).collect::>(); - let argument_kinds = arguments.iter().map(|arg| &arg.kind).collect::>(); + let b = &function.block; - function.block = parse2(quote! {{ - let (#(#argument_names),*) = { - // Place the segments into its scope to allow mutation of `Context::args` - // afterwards, as `ArgumentSegments` holds a reference to the source string. - let mut __args = #asegsty::new(&#ctx_name.args, #delimiter); + let argument_names = arguments.iter().map(|arg| &arg.name).collect::>(); + let argument_tys = arguments.iter().map(|arg| &arg.ty).collect::>(); + let argument_kinds = arguments.iter().map(|arg| &arg.kind).collect::>(); - #(let #argument_names: #argument_tys = #argument_kinds( - &#ctx_name.serenity_ctx, - &#msg_name, - &mut __args - ).await?;)* + function.block = parse2(quote! {{ + let (#(#argument_names),*) = { + // Place the segments into its scope to allow mutation of `Context::args` + // afterwards, as `ArgumentSegments` holds a reference to the source string. + let mut __args = #asegsty::new(&#ctx_name.args, #delimiter); - (#(#argument_names),*) - }; + #(let #argument_names: #argument_tys = #argument_kinds( + &#ctx_name.serenity_ctx, + &#msg_name, + &mut __args + ).await?;)* - #b - }})?; - } + (#(#argument_names),*) + }; + + #b + }})?; Ok(()) } @@ -134,7 +147,7 @@ fn parse_arguments( /// - a list of arguments that only has one rest argument parameter, if present. /// - a list of arguments that only has one variadic argument parameter or one rest /// argument parameter. -fn check_arguments(args: &[Argument]) -> Result<()> { +fn validate_arguments_order(args: &[Argument]) -> Result<()> { let mut last_arg: Option<&Argument> = None; for arg in args { diff --git a/command_attr/src/paths.rs b/command_attr/src/paths.rs index d189ed0..905bd64 100644 --- a/command_attr/src/paths.rs +++ b/command_attr/src/paths.rs @@ -6,25 +6,12 @@ fn to_path(tokens: TokenStream) -> Path { parse2(tokens).unwrap() } -fn to_type(tokens: TokenStream) -> Box { - parse2(tokens).unwrap() -} - -pub fn default_data_type() -> Box { - to_type(quote! { - serenity_framework::DefaultData - }) -} - -pub fn default_error_type() -> Box { - to_type(quote! { - serenity_framework::DefaultError - }) -} - -pub fn command_type(data: &Type, error: &Type) -> Path { +pub fn command_type(ctx: &Type) -> Path { to_path(quote! { - serenity_framework::command::Command<#data, #error> + serenity_framework::command::Command< + <#ctx as serenity_framework::_DataErrorHack>::D, + <#ctx as serenity_framework::_DataErrorHack>::E, + > }) } @@ -70,9 +57,12 @@ pub fn rest_argument_func() -> Path { }) } -pub fn check_type(data: &Type, error: &Type) -> Path { +pub fn check_type(ctx: &Type) -> Path { to_path(quote! { - serenity_framework::check::Check<#data, #error> + serenity_framework::check::Check< + <#ctx as serenity_framework::_DataErrorHack>::D, + <#ctx as serenity_framework::_DataErrorHack>::E, + > }) } diff --git a/command_attr/src/utils.rs b/command_attr/src/utils.rs index fa047fc..529b237 100644 --- a/command_attr/src/utils.rs +++ b/command_attr/src/utils.rs @@ -4,10 +4,8 @@ use proc_macro2::{Ident, TokenStream}; use quote::{quote, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned; -use syn::{Attribute, Error, FnArg, GenericArgument, Lit, LitStr, Meta}; -use syn::{NestedMeta, Pat, PatType, Path, PathArguments, Result, Signature, Token, Type}; - -use crate::paths::{default_data_type, default_error_type}; +use syn::{Attribute, Error, FnArg, Lit, LitStr, Meta}; +use syn::{NestedMeta, Pat, PatType, Path, Result, Signature, Token, Type}; pub struct AttributeArgs(pub Vec); @@ -166,37 +164,7 @@ pub fn parse_bool(attr: &Attr) -> Result { }) } -pub fn parse_generics(sig: &Signature) -> Result<(Ident, Ident, Box, Box)> { - let (ctx, msg) = get_first_two_parameters(sig)?; - - let msg_indent = get_ident(&get_pat_type(msg)?.pat)?; - - let ctx_binding = get_pat_type(ctx)?; - let ctx_ident = get_ident(&ctx_binding.pat)?; - let path = get_path(&ctx_binding.ty)?; - let mut arguments = get_generic_arguments(path)?; - - let default_data = default_data_type(); - let default_error = default_error_type(); - - let data = match arguments.next() { - Some(GenericArgument::Lifetime(_)) => match arguments.next() { - Some(arg) => get_generic_type(arg)?, - None => default_data, - }, - Some(arg) => get_generic_type(arg)?, - None => default_data, - }; - - let error = match arguments.next() { - Some(arg) => get_generic_type(arg)?, - None => default_error, - }; - - Ok((ctx_ident, msg_indent, data, error)) -} - -fn get_first_two_parameters(sig: &Signature) -> Result<(&FnArg, &FnArg)> { +pub fn get_first_two_parameters(sig: &Signature) -> Result<(&FnArg, &FnArg)> { let mut parameters = sig.inputs.iter(); match (parameters.next(), parameters.next()) { (Some(first), Some(second)) => Ok((first, second)), @@ -228,23 +196,3 @@ pub fn get_path(t: &Type) -> Result<&Path> { _ => Err(Error::new(t.span(), "parameter must be a path to a context type")), } } - -fn get_generic_arguments(path: &Path) -> Result + '_> { - match &path.segments.last().unwrap().arguments { - PathArguments::None => Ok(Vec::new().into_iter()), - PathArguments::AngleBracketed(arguments) => { - Ok(arguments.args.iter().collect::>().into_iter()) - }, - _ => Err(Error::new( - path.span(), - "context type cannot have generic parameters in parenthesis", - )), - } -} - -fn get_generic_type(arg: &GenericArgument) -> Result> { - match arg { - GenericArgument::Type(t) => Ok(Box::new(t.clone())), - _ => Err(Error::new(arg.span(), "generic parameter must be a type")), - } -} diff --git a/framework/src/lib.rs b/framework/src/lib.rs index 0ab2819..6e3e912 100644 --- a/framework/src/lib.rs +++ b/framework/src/lib.rs @@ -185,3 +185,19 @@ impl Framework { Ok((ctx, func)) } } + +// Required as a reliable way for command_attr to know the data and error types +// (see https://github.com/serenity-rs/framework/issues/27) +#[doc(hidden)] +pub trait _DataErrorHack { + type D; + type E; +} +impl _DataErrorHack for Framework { + type D = D; + type E = E; +} +impl _DataErrorHack for Context { + type D = D; + type E = E; +}