From 6ae633c0fa4dfdc00711877ac2154a83bdc08611 Mon Sep 17 00:00:00 2001 From: Lucas Ransan Date: Thu, 26 Aug 2021 21:07:43 +0200 Subject: [PATCH] Implement server side named parameters --- derive/src/options.rs | 5 +- derive/src/rpc_trait.rs | 14 -- derive/src/to_delegate.rs | 125 +++++++++++------- derive/tests/macros.rs | 101 +++++++++++++- .../tests/ui/attr-named-params-on-server.rs | 10 -- .../ui/attr-named-params-on-server.stderr | 9 -- .../ui/trait-attr-named-params-on-server.rs | 7 - .../trait-attr-named-params-on-server.stderr | 7 - 8 files changed, 172 insertions(+), 106 deletions(-) delete mode 100644 derive/tests/ui/attr-named-params-on-server.rs delete mode 100644 derive/tests/ui/attr-named-params-on-server.stderr delete mode 100644 derive/tests/ui/trait-attr-named-params-on-server.rs delete mode 100644 derive/tests/ui/trait-attr-named-params-on-server.stderr diff --git a/derive/src/options.rs b/derive/src/options.rs index fe52c07dc..34780e474 100644 --- a/derive/src/options.rs +++ b/derive/src/options.rs @@ -62,10 +62,7 @@ impl DeriveOptions { options.enable_client = true; options.enable_server = true; } - if options.enable_server && options.params_style == ParamStyle::Named { - // This is not allowed at this time - panic!("Server code generation only supports `params = \"positional\"` (default) or `params = \"raw\" at this time.") - } + Ok(options) } } diff --git a/derive/src/rpc_trait.rs b/derive/src/rpc_trait.rs index 25bc004b6..8137f089e 100644 --- a/derive/src/rpc_trait.rs +++ b/derive/src/rpc_trait.rs @@ -1,5 +1,4 @@ use crate::options::DeriveOptions; -use crate::params_style::ParamStyle; use crate::rpc_attr::{AttributeKind, PubSubMethodKind, RpcMethodAttribute}; use crate::to_client::generate_client_module; use crate::to_delegate::{generate_trait_item_method, MethodRegistration, RpcMethod}; @@ -22,10 +21,6 @@ const MISSING_UNSUBSCRIBE_METHOD_ERR: &str = "Can't find unsubscribe method, expected a method annotated with `unsubscribe` \ e.g. `#[pubsub(subscription = \"hello\", unsubscribe, name = \"hello_unsubscribe\")]`"; -pub const USING_NAMED_PARAMS_WITH_SERVER_ERR: &str = - "`params = \"named\"` can only be used to generate a client (on a trait annotated with #[rpc(client)]). \ - At this time the server does not support named parameters."; - const RPC_MOD_NAME_PREFIX: &str = "rpc_impl_"; struct RpcTrait { @@ -222,12 +217,6 @@ fn rpc_wrapper_mod_name(rpc_trait: &syn::ItemTrait) -> syn::Ident { syn::Ident::new(&mod_name, proc_macro2::Span::call_site()) } -fn has_named_params(methods: &[RpcMethod]) -> bool { - methods - .iter() - .any(|method| method.attr.params_style == Some(ParamStyle::Named)) -} - pub fn crate_name(name: &str) -> Result { proc_macro_crate::crate_name(name) .map(|name| Ident::new(&name, Span::call_site())) @@ -264,9 +253,6 @@ pub fn rpc_impl(input: syn::Item, options: &DeriveOptions) -> Result Result { - let mut param_types: Vec<_> = self + let args = self .trait_item .sig .inputs .iter() .cloned() .filter_map(|arg| match arg { - syn::FnArg::Typed(ty) => Some(*ty.ty), + syn::FnArg::Typed(pat_type) => Some(pat_type), _ => None, }) - .collect(); + .enumerate(); + + let (special_args, fn_args) = { + // special args are those which are not passed directly via rpc params: metadata, subscriber + let mut special_args = vec![]; + let mut fn_args = vec![]; + + for (i, arg) in args { + if let Some(sarg) = Self::special_arg(i, arg.clone()) { + special_args.push(sarg); + } else { + fn_args.push(arg); + } + } + + (special_args, fn_args) + }; + + let param_types: Vec<_> = fn_args.iter().map(|arg| *arg.ty.clone()).collect(); + let arg_names: Vec<_> = fn_args.iter().map(|arg| *arg.pat.clone()).collect(); - // special args are those which are not passed directly via rpc params: metadata, subscriber - let special_args = Self::special_args(¶m_types); - param_types.retain(|ty| !special_args.iter().any(|(_, sty)| sty == ty)); if param_types.len() > TUPLE_FIELD_NAMES.len() { return Err(syn::Error::new_spanned( &self.trait_item, @@ -232,28 +248,38 @@ impl RpcMethod { .take(param_types.len()) .map(|name| ident(name)) .collect()); - let param_types = ¶m_types; - let parse_params = { - // last arguments that are `Option`-s are optional 'trailing' arguments - let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count(); - - if trailing_args_num != 0 { - self.params_with_trailing(trailing_args_num, param_types, tuple_fields) - } else if param_types.is_empty() { - quote! { let params = params.expect_no_params(); } - } else if self.attr.params_style == Some(ParamStyle::Raw) { - quote! { let params: _jsonrpc_core::Result<_> = Ok((params,)); } - } else if self.attr.params_style == Some(ParamStyle::Positional) { - quote! { let params = params.parse::<(#(#param_types, )*)>(); } - } else { - unimplemented!("Server side named parameters are not implemented"); + let parse_params = if param_types.is_empty() { + quote! { let params = params.expect_no_params(); } + } else { + match self.attr.params_style.as_ref().unwrap() { + ParamStyle::Raw => quote! { let params: _jsonrpc_core::Result<_> = Ok((params,)); }, + ParamStyle::Positional => { + // last arguments that are `Option`-s are optional 'trailing' arguments + let trailing_args_num = param_types.iter().rev().take_while(|t| is_option_type(t)).count(); + if trailing_args_num != 0 { + self.params_with_trailing(trailing_args_num, ¶m_types, tuple_fields) + } else { + quote! { let params = params.parse::<(#(#param_types, )*)>(); } + } + } + ParamStyle::Named => quote! { + #[derive(serde::Deserialize)] + #[allow(non_camel_case_types)] + struct __Params { + #( + #fn_args, + )* + } + let params = params.parse::<__Params>() + .map(|__Params { #(#arg_names, )* }| (#(#arg_names, )*)); + }, } }; let method_ident = self.ident(); let result = &self.trait_item.sig.output; - let extra_closure_args: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.0).collect(); - let extra_method_types: &Vec<_> = &special_args.iter().cloned().map(|arg| arg.1).collect(); + let extra_closure_args: Vec<_> = special_args.iter().map(|arg| *arg.pat.clone()).collect(); + let extra_method_types: Vec<_> = special_args.iter().map(|arg| *arg.ty.clone()).collect(); let closure_args = quote! { base, params, #(#extra_closure_args), * }; let method_sig = quote! { fn(&Self, #(#extra_method_types, ) * #(#param_types), *) #result }; @@ -301,34 +327,35 @@ impl RpcMethod { }) } - fn special_args(param_types: &[syn::Type]) -> Vec<(syn::Ident, syn::Type)> { - let meta_arg = param_types.first().and_then(|ty| { - if *ty == parse_quote!(Self::Metadata) { - Some(ty.clone()) - } else { - None - } - }); - let subscriber_arg = param_types.get(1).and_then(|ty| { - if let syn::Type::Path(path) = ty { - if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) { - Some(ty.clone()) - } else { - None + fn special_arg(index: usize, arg: syn::PatType) -> Option { + match index { + 0 if arg.ty == parse_quote!(Self::Metadata) => Some(syn::PatType { + pat: Box::new(syn::Pat::Ident(syn::PatIdent { + attrs: vec![], + by_ref: None, + mutability: None, + ident: ident(METADATA_CLOSURE_ARG), + subpat: None, + })), + ..arg + }), + 1 => match *arg.ty { + syn::Type::Path(ref path) if path.path.segments.iter().any(|s| s.ident == SUBSCRIBER_TYPE_IDENT) => { + Some(syn::PatType { + pat: Box::new(syn::Pat::Ident(syn::PatIdent { + attrs: vec![], + by_ref: None, + mutability: None, + ident: ident(SUBSCRIBER_CLOSURE_ARG), + subpat: None, + })), + ..arg + }) } - } else { - None - } - }); - - let mut special_args = Vec::new(); - if let Some(meta) = meta_arg { - special_args.push((ident(METADATA_CLOSURE_ARG), meta)); - } - if let Some(subscriber) = subscriber_arg { - special_args.push((ident(SUBSCRIBER_CLOSURE_ARG), subscriber)); + _ => None, + }, + _ => None, } - special_args } fn params_with_trailing( diff --git a/derive/tests/macros.rs b/derive/tests/macros.rs index 1f4483672..eca4cf92c 100644 --- a/derive/tests/macros.rs +++ b/derive/tests/macros.rs @@ -1,6 +1,7 @@ use jsonrpc_core::types::params::Params; use jsonrpc_core::{IoHandler, Response}; use jsonrpc_derive::rpc; +use serde; use serde_json; pub enum MyError {} @@ -14,6 +15,8 @@ type Result = ::std::result::Result; #[rpc] pub trait Rpc { + type Metadata; + /// Returns a protocol version. #[rpc(name = "protocolVersion")] fn protocol_version(&self) -> Result; @@ -30,6 +33,18 @@ pub trait Rpc { #[rpc(name = "raw", params = "raw")] fn raw(&self, params: Params) -> Result; + /// Adds two numbers and returns a result. + #[rpc(name = "named_add", params = "named")] + fn named_add(&self, a: u64, b: u64) -> Result; + + /// Adds one or two numbers and returns a result. + #[rpc(name = "option_named_add", params = "named")] + fn option_named_add(&self, a: u64, b: Option) -> Result; + + /// Adds two numbers and returns a result. + #[rpc(meta, name = "meta_named_add", params = "named")] + fn meta_named_add(&self, meta: Self::Metadata, a: u64, b: u64) -> Result; + /// Handles a notification. #[rpc(name = "notify")] fn notify(&self, a: u64); @@ -39,6 +54,8 @@ pub trait Rpc { struct RpcImpl; impl Rpc for RpcImpl { + type Metadata = Metadata; + fn protocol_version(&self) -> Result { Ok("version1".into()) } @@ -55,14 +72,30 @@ impl Rpc for RpcImpl { Ok("OK".into()) } + fn named_add(&self, a: u64, b: u64) -> Result { + Ok(a + b) + } + + fn option_named_add(&self, a: u64, b: Option) -> Result { + Ok(a + b.unwrap_or_default()) + } + + fn meta_named_add(&self, _meta: Self::Metadata, a: u64, b: u64) -> Result { + Ok(a + b) + } + fn notify(&self, a: u64) { println!("Received `notify` with value: {}", a); } } +#[derive(Clone, Default)] +struct Metadata; +impl jsonrpc_core::Metadata for Metadata {} + #[test] fn should_accept_empty_array_as_no_params() { - let mut io = IoHandler::new(); + let mut io = IoHandler::default(); let rpc = RpcImpl::default(); io.extend_with(rpc.to_delegate()); @@ -94,7 +127,7 @@ fn should_accept_empty_array_as_no_params() { #[test] fn should_accept_single_param() { - let mut io = IoHandler::new(); + let mut io = IoHandler::default(); let rpc = RpcImpl::default(); io.extend_with(rpc.to_delegate()); @@ -120,7 +153,7 @@ fn should_accept_single_param() { #[test] fn should_accept_multiple_params() { - let mut io = IoHandler::new(); + let mut io = IoHandler::default(); let rpc = RpcImpl::default(); io.extend_with(rpc.to_delegate()); @@ -146,7 +179,7 @@ fn should_accept_multiple_params() { #[test] fn should_use_method_name_aliases() { - let mut io = IoHandler::new(); + let mut io = IoHandler::default(); let rpc = RpcImpl::default(); io.extend_with(rpc.to_delegate()); @@ -187,7 +220,7 @@ fn should_use_method_name_aliases() { #[test] fn should_accept_any_raw_params() { - let mut io = IoHandler::new(); + let mut io = IoHandler::default(); let rpc = RpcImpl::default(); io.extend_with(rpc.to_delegate()); @@ -222,9 +255,65 @@ fn should_accept_any_raw_params() { assert_eq!(expected, result4); } +#[test] +fn should_accept_named_params() { + let mut io = IoHandler::default(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"named_add","params":{"a":1,"b":2}}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"named_add","params":{"b":2,"a":1}}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + + let expected = r#"{ + "jsonrpc": "2.0", + "result": 3, + "id": 1 + }"#; + let expected: Response = serde_json::from_str(expected).unwrap(); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!(expected, result1); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!(expected, result2); +} + +#[test] +fn should_accept_option_named_params() { + let mut io = IoHandler::default(); + let rpc = RpcImpl::default(); + io.extend_with(rpc.to_delegate()); + + // when + let req1 = r#"{"jsonrpc":"2.0","id":1,"method":"option_named_add","params":{"a":1,"b":2}}"#; + let req2 = r#"{"jsonrpc":"2.0","id":1,"method":"option_named_add","params":{"a":3}}"#; + + let res1 = io.handle_request_sync(req1); + let res2 = io.handle_request_sync(req2); + + let expected = r#"{ + "jsonrpc": "2.0", + "result": 3, + "id": 1 + }"#; + let expected: Response = serde_json::from_str(expected).unwrap(); + + // then + let result1: Response = serde_json::from_str(&res1.unwrap()).unwrap(); + assert_eq!(expected, result1); + + let result2: Response = serde_json::from_str(&res2.unwrap()).unwrap(); + assert_eq!(expected, result2); +} + #[test] fn should_accept_only_notifications() { - let mut io = IoHandler::new(); + let mut io = IoHandler::default(); let rpc = RpcImpl::default(); io.extend_with(rpc.to_delegate()); diff --git a/derive/tests/ui/attr-named-params-on-server.rs b/derive/tests/ui/attr-named-params-on-server.rs deleted file mode 100644 index 074995642..000000000 --- a/derive/tests/ui/attr-named-params-on-server.rs +++ /dev/null @@ -1,10 +0,0 @@ -use jsonrpc_derive::rpc; - -#[rpc] -pub trait Rpc { - /// Returns a protocol version - #[rpc(name = "add", params = "named")] - fn add(&self, a: u32, b: u32) -> Result; -} - -fn main() {} diff --git a/derive/tests/ui/attr-named-params-on-server.stderr b/derive/tests/ui/attr-named-params-on-server.stderr deleted file mode 100644 index 41ccc852a..000000000 --- a/derive/tests/ui/attr-named-params-on-server.stderr +++ /dev/null @@ -1,9 +0,0 @@ -error: `params = "named"` can only be used to generate a client (on a trait annotated with #[rpc(client)]). At this time the server does not support named parameters. - --> $DIR/attr-named-params-on-server.rs:4:1 - | -4 | / pub trait Rpc { -5 | | /// Returns a protocol version -6 | | #[rpc(name = "add", params = "named")] -7 | | fn add(&self, a: u32, b: u32) -> Result; -8 | | } - | |_^ diff --git a/derive/tests/ui/trait-attr-named-params-on-server.rs b/derive/tests/ui/trait-attr-named-params-on-server.rs deleted file mode 100644 index 302768fcf..000000000 --- a/derive/tests/ui/trait-attr-named-params-on-server.rs +++ /dev/null @@ -1,7 +0,0 @@ -use jsonrpc_derive::rpc; - -#[rpc(server, params = "named")] -pub trait Rpc { -} - -fn main() {} diff --git a/derive/tests/ui/trait-attr-named-params-on-server.stderr b/derive/tests/ui/trait-attr-named-params-on-server.stderr deleted file mode 100644 index c44d44465..000000000 --- a/derive/tests/ui/trait-attr-named-params-on-server.stderr +++ /dev/null @@ -1,7 +0,0 @@ -error: custom attribute panicked - --> $DIR/trait-attr-named-params-on-server.rs:3:1 - | -3 | #[rpc(server, params = "named")] - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - | - = help: message: Server code generation only supports `params = "positional"` (default) or `params = "raw" at this time.