Skip to content

Commit

Permalink
Access D and E via associated type instead of parsing generics
Browse files Browse the repository at this point in the history
  • Loading branch information
kangalio committed Feb 23, 2021
1 parent 6a21f94 commit 724b7c1
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 118 deletions.
15 changes: 5 additions & 10 deletions command_attr/src/impl_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ pub fn impl_check(attr: TokenStream, input: TokenStream) -> Result<TokenStream>
parse2::<syn::LitStr>(attr)?.value()
};

let (_, _, data, error) = utils::parse_generics(&fun.sig)?;
let (ctx_param, _msg_param) = utils::get_first_two_parameters(&fun.sig)?;
let options = Options::parse(&mut fun.attrs)?;

let builder_fn = builder_fn(&data, &error, &mut fun, &name, &options);
let ctx_ty = *utils::get_pat_type(&ctx_param)?.ty.clone();
let builder_fn = builder_fn(&ctx_ty, &mut fun, &name, &options);

let hook_macro = paths::hook_macro();

Expand All @@ -37,13 +38,7 @@ pub fn impl_check(attr: TokenStream, input: TokenStream) -> Result<TokenStream>
Ok(result)
}

fn builder_fn(
data: &Type,
error: &Type,
function: &mut ItemFn,
name: &str,
options: &Options,
) -> TokenStream {
fn builder_fn(ctx_ty: &Type, function: &mut ItemFn, name: &str, options: &Options) -> TokenStream {
// Derive the name of the builder from the check function.
// Prepend the check function's name with an underscore to avoid name
// collisions.
Expand All @@ -52,7 +47,7 @@ fn builder_fn(
function.sig.ident = function_name.clone();

let check_builder = paths::check_builder_type();
let check = paths::check_type(data, error);
let check = paths::check_type(ctx_ty);

let vis = &function.vis;
let external = &function.attrs;
Expand Down
79 changes: 46 additions & 33 deletions command_attr/src/impl_command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,20 @@ pub fn impl_command(attr: TokenStream, input: TokenStream) -> Result<TokenStream
parse2::<AttributeArgs>(attr)?.0
};

let (ctx_name, msg_name, data, error) = utils::parse_generics(&fun.sig)?;
let (ctx_param, msg_param) = utils::get_first_two_parameters(&fun.sig)?;
let ctx_param = utils::get_pat_type(&ctx_param)?.clone();
let msg_param = utils::get_pat_type(&msg_param)?.clone();

let options = Options::parse(&mut fun.attrs)?;

parse_arguments(ctx_name, msg_name, &mut fun, &options)?;
inject_argument_parsing_code(
utils::get_ident(&ctx_param.pat)?,
utils::get_ident(&msg_param.pat)?,
&mut fun,
&options,
)?;

let builder_fn = builder_fn(&data, &error, &mut fun, names, &options);
let builder_fn = builder_fn(&ctx_param.ty, &mut fun, names, &options);

let hook_macro = paths::hook_macro();

Expand All @@ -39,9 +47,10 @@ pub fn impl_command(attr: TokenStream, input: TokenStream) -> Result<TokenStream
Ok(result)
}

/// Replace the passed in function with a "builder function" that points to the renamed original
/// function
fn builder_fn(
data: &Type,
error: &Type,
ctx_ty: &Type,
function: &mut ItemFn,
mut names: Vec<String>,
options: &Options,
Expand All @@ -57,7 +66,7 @@ fn builder_fn(
function.sig.ident = function_name.clone();

let command_builder = paths::command_builder_type();
let command = paths::command_type(data, error);
let command = paths::command_type(ctx_ty);

let vis = &function.vis;
let external = &function.attrs;
Expand All @@ -74,52 +83,56 @@ fn builder_fn(
}
}

fn parse_arguments(
// Injects a code block at the top of the user-written command function that parses the command
// parameters from the FrameworkContext and Message function arguments
fn inject_argument_parsing_code(
ctx_name: Ident,
msg_name: Ident,
function: &mut ItemFn,
options: &Options,
) -> Result<()> {
let mut arguments = Vec::new();

while function.sig.inputs.len() > 2 {
let argument = function.sig.inputs.pop().unwrap().into_value();

arguments.push(Argument::new(argument)?);
}

if !arguments.is_empty() {
arguments.reverse();
// If this command has no parameters, don't bother injecting any parsing code
if arguments.is_empty() {
return Ok(());
}

check_arguments(&arguments)?;
arguments.reverse();

let delimiter = options.delimiter.as_ref().map_or(" ", String::as_str);
let asegsty = paths::argument_segments_type();
validate_arguments_order(&arguments)?;

let b = &function.block;
let delimiter = options.delimiter.as_ref().map_or(" ", String::as_str);
let asegsty = paths::argument_segments_type();

let argument_names = arguments.iter().map(|arg| &arg.name).collect::<Vec<_>>();
let argument_tys = arguments.iter().map(|arg| &arg.ty).collect::<Vec<_>>();
let argument_kinds = arguments.iter().map(|arg| &arg.kind).collect::<Vec<_>>();
let b = &function.block;

function.block = parse2(quote! {{
let (#(#argument_names),*) = {
// Place the segments into its scope to allow mutation of `Context::args`
// afterwards, as `ArgumentSegments` holds a reference to the source string.
let mut __args = #asegsty::new(&#ctx_name.args, #delimiter);
let argument_names = arguments.iter().map(|arg| &arg.name).collect::<Vec<_>>();
let argument_tys = arguments.iter().map(|arg| &arg.ty).collect::<Vec<_>>();
let argument_kinds = arguments.iter().map(|arg| &arg.kind).collect::<Vec<_>>();

#(let #argument_names: #argument_tys = #argument_kinds(
&#ctx_name.serenity_ctx,
&#msg_name,
&mut __args
).await?;)*
function.block = parse2(quote! {{
let (#(#argument_names),*) = {
// Place the segments into its scope to allow mutation of `Context::args`
// afterwards, as `ArgumentSegments` holds a reference to the source string.
let mut __args = #asegsty::new(&#ctx_name.args, #delimiter);

(#(#argument_names),*)
};
#(let #argument_names: #argument_tys = #argument_kinds(
&#ctx_name.serenity_ctx,
&#msg_name,
&mut __args
).await?;)*

#b
}})?;
}
(#(#argument_names),*)
};

#b
}})?;

Ok(())
}
Expand All @@ -134,7 +147,7 @@ fn parse_arguments(
/// - a list of arguments that only has one rest argument parameter, if present.
/// - a list of arguments that only has one variadic argument parameter or one rest
/// argument parameter.
fn check_arguments(args: &[Argument]) -> Result<()> {
fn validate_arguments_order(args: &[Argument]) -> Result<()> {
let mut last_arg: Option<&Argument> = None;

for arg in args {
Expand Down
30 changes: 10 additions & 20 deletions command_attr/src/paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,12 @@ fn to_path(tokens: TokenStream) -> Path {
parse2(tokens).unwrap()
}

fn to_type(tokens: TokenStream) -> Box<Type> {
parse2(tokens).unwrap()
}

pub fn default_data_type() -> Box<Type> {
to_type(quote! {
serenity_framework::DefaultData
})
}

pub fn default_error_type() -> Box<Type> {
to_type(quote! {
serenity_framework::DefaultError
})
}

pub fn command_type(data: &Type, error: &Type) -> Path {
pub fn command_type(ctx: &Type) -> Path {
to_path(quote! {
serenity_framework::command::Command<#data, #error>
serenity_framework::command::Command<
<#ctx as serenity_framework::_DataErrorHack>::D,
<#ctx as serenity_framework::_DataErrorHack>::E,
>
})
}

Expand Down Expand Up @@ -70,9 +57,12 @@ pub fn rest_argument_func() -> Path {
})
}

pub fn check_type(data: &Type, error: &Type) -> Path {
pub fn check_type(ctx: &Type) -> Path {
to_path(quote! {
serenity_framework::check::Check<#data, #error>
serenity_framework::check::Check<
<#ctx as serenity_framework::_DataErrorHack>::D,
<#ctx as serenity_framework::_DataErrorHack>::E,
>
})
}

Expand Down
58 changes: 3 additions & 55 deletions command_attr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ use proc_macro2::{Ident, TokenStream};
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::{Attribute, Error, FnArg, GenericArgument, Lit, LitStr, Meta};
use syn::{NestedMeta, Pat, PatType, Path, PathArguments, Result, Signature, Token, Type};

use crate::paths::{default_data_type, default_error_type};
use syn::{Attribute, Error, FnArg, Lit, LitStr, Meta};
use syn::{NestedMeta, Pat, PatType, Path, Result, Signature, Token, Type};

pub struct AttributeArgs(pub Vec<String>);

Expand Down Expand Up @@ -166,37 +164,7 @@ pub fn parse_bool(attr: &Attr) -> Result<bool> {
})
}

pub fn parse_generics(sig: &Signature) -> Result<(Ident, Ident, Box<Type>, Box<Type>)> {
let (ctx, msg) = get_first_two_parameters(sig)?;

let msg_indent = get_ident(&get_pat_type(msg)?.pat)?;

let ctx_binding = get_pat_type(ctx)?;
let ctx_ident = get_ident(&ctx_binding.pat)?;
let path = get_path(&ctx_binding.ty)?;
let mut arguments = get_generic_arguments(path)?;

let default_data = default_data_type();
let default_error = default_error_type();

let data = match arguments.next() {
Some(GenericArgument::Lifetime(_)) => match arguments.next() {
Some(arg) => get_generic_type(arg)?,
None => default_data,
},
Some(arg) => get_generic_type(arg)?,
None => default_data,
};

let error = match arguments.next() {
Some(arg) => get_generic_type(arg)?,
None => default_error,
};

Ok((ctx_ident, msg_indent, data, error))
}

fn get_first_two_parameters(sig: &Signature) -> Result<(&FnArg, &FnArg)> {
pub fn get_first_two_parameters(sig: &Signature) -> Result<(&FnArg, &FnArg)> {
let mut parameters = sig.inputs.iter();
match (parameters.next(), parameters.next()) {
(Some(first), Some(second)) => Ok((first, second)),
Expand Down Expand Up @@ -228,23 +196,3 @@ pub fn get_path(t: &Type) -> Result<&Path> {
_ => Err(Error::new(t.span(), "parameter must be a path to a context type")),
}
}

fn get_generic_arguments(path: &Path) -> Result<impl Iterator<Item = &GenericArgument> + '_> {
match &path.segments.last().unwrap().arguments {
PathArguments::None => Ok(Vec::new().into_iter()),
PathArguments::AngleBracketed(arguments) => {
Ok(arguments.args.iter().collect::<Vec<_>>().into_iter())
},
_ => Err(Error::new(
path.span(),
"context type cannot have generic parameters in parenthesis",
)),
}
}

fn get_generic_type(arg: &GenericArgument) -> Result<Box<Type>> {
match arg {
GenericArgument::Type(t) => Ok(Box::new(t.clone())),
_ => Err(Error::new(arg.span(), "generic parameter must be a type")),
}
}
16 changes: 16 additions & 0 deletions framework/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,19 @@ impl<D, E> Framework<D, E> {
Ok((ctx, func))
}
}

// Required as a reliable way for command_attr to know the data and error types
// (see https://github.com/serenity-rs/framework/issues/27)
#[doc(hidden)]
pub trait _DataErrorHack {
type D;
type E;
}
impl<D, E> _DataErrorHack for Framework<D, E> {
type D = D;
type E = E;
}
impl<D, E> _DataErrorHack for Context<D, E> {
type D = D;
type E = E;
}

0 comments on commit 724b7c1

Please sign in to comment.