diff --git a/.rustfmt.toml b/.rustfmt.toml
index 4bab087d205..72a3cb9cd08 100644
--- a/.rustfmt.toml
+++ b/.rustfmt.toml
@@ -1,6 +1,4 @@
 unstable_features = true
-max_width = 120
-struct_lit_width = 120
-tab_spaces = 2
+struct_lit_width = 60
 imports_granularity = "Module"
 group_imports = "StdExternalCrate"
diff --git a/autogen/src/main.rs b/autogen/src/main.rs
index 132889d9186..5da1437c63c 100644
--- a/autogen/src/main.rs
+++ b/autogen/src/main.rs
@@ -12,92 +12,96 @@ static JSON_SCHEMA_FILE: &'static str = "../examples/.tailcallrc.schema.json";
 
 #[tokio::main]
 async fn main() {
-  logger_init();
-  let args: Vec<String> = env::args().collect();
-  let arg = args.get(1);
+    logger_init();
+    let args: Vec<String> = env::args().collect();
+    let arg = args.get(1);
 
-  if arg.is_none() {
-    log::error!("An argument required, you can pass either `fix` or `check` argument");
-    return;
-  }
-  match arg.unwrap().as_str() {
-    "fix" => {
-      let result = mode_fix().await;
-      if let Err(e) = result {
-        log::error!("{}", e);
-        exit(1);
-      }
+    if arg.is_none() {
+        log::error!("An argument required, you can pass either `fix` or `check` argument");
+        return;
     }
-    "check" => {
-      let result = mode_check().await;
-      if let Err(e) = result {
-        log::error!("{}", e);
-        exit(1);
-      }
+    match arg.unwrap().as_str() {
+        "fix" => {
+            let result = mode_fix().await;
+            if let Err(e) = result {
+                log::error!("{}", e);
+                exit(1);
+            }
+        }
+        "check" => {
+            let result = mode_check().await;
+            if let Err(e) = result {
+                log::error!("{}", e);
+                exit(1);
+            }
+        }
+        &_ => {
+            log::error!("Unknown argument, you can pass either `fix` or `check` argument");
+            return;
+        }
+    }
-    &_ => {
-      log::error!("Unknown argument, you can pass either `fix` or `check` argument");
-      return;
-    }
-  }
 }
 
 async fn mode_check() -> Result<()> {
-  let json_schema = get_file_path();
-  let file_io = init_file();
-  let content = file_io
-    .read(json_schema.to_str().ok_or(anyhow!("Unable to determine path"))?)
-    .await?;
-  let content = serde_json::from_str::<serde_json::Value>(&content)?;
-  let schema = get_updated_json().await?;
-  match content.eq(&schema) {
-    true => Ok(()),
-    false => Err(anyhow!("Schema mismatch")),
-  }
+    let json_schema = get_file_path();
+    let file_io = init_file();
+    let content = file_io
+        .read(
+            json_schema
+                .to_str()
+                .ok_or(anyhow!("Unable to determine path"))?,
+        )
+        .await?;
+    let content = serde_json::from_str::<serde_json::Value>(&content)?;
+    let schema = get_updated_json().await?;
+    match content.eq(&schema) {
+        true => Ok(()),
+        false => Err(anyhow!("Schema mismatch")),
+    }
 }
 
 async fn mode_fix() -> Result<()> {
-  update_json().await?;
-  // update_gql().await?;
-  Ok(())
+    update_json().await?;
+    // update_gql().await?;
+    Ok(())
 }
 
 async fn update_json() -> Result<()> {
-  let path = get_file_path();
-  let schema = serde_json::to_string_pretty(&get_updated_json().await?)?;
-  let file_io = init_file();
-  log::info!("Updating JSON Schema: {}", path.to_str().unwrap());
-  file_io
-    .write(
-      path.to_str().ok_or(anyhow!("Unable to determine path"))?,
-      schema.as_bytes(),
-    )
-    .await?;
-  Ok(())
+    let path = get_file_path();
+    let schema = serde_json::to_string_pretty(&get_updated_json().await?)?;
+    let file_io = init_file();
+    log::info!("Updating JSON Schema: {}", path.to_str().unwrap());
+    file_io
+        .write(
+            path.to_str().ok_or(anyhow!("Unable to determine path"))?,
+            schema.as_bytes(),
+        )
+        .await?;
+    Ok(())
 }
 
 fn get_file_path() -> PathBuf {
-  PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(JSON_SCHEMA_FILE)
+    PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(JSON_SCHEMA_FILE)
 }
 
 async fn get_updated_json() -> Result<serde_json::Value> {
-  let schema = schemars::schema_for!(Config);
-  let schema = json!(schema);
-  Ok(schema)
+    let schema = schemars::schema_for!(Config);
+    let schema = json!(schema);
+    Ok(schema)
 }
 
 fn logger_init() {
-  // set the log level
-  const LONG_ENV_FILTER_VAR_NAME: &str = "TAILCALL_SCHEMA_LOG_LEVEL";
-  const SHORT_ENV_FILTER_VAR_NAME: &str = "TC_SCHEMA_LOG_LEVEL";
+    // set the log level
+    const LONG_ENV_FILTER_VAR_NAME: &str = "TAILCALL_SCHEMA_LOG_LEVEL";
+    const SHORT_ENV_FILTER_VAR_NAME: &str = "TC_SCHEMA_LOG_LEVEL";
 
-  // Select which env variable to use for the log level filter. This is because filter_or doesn't allow picking between multiple env_var for the filter value
-  let filter_env_name = env::var(LONG_ENV_FILTER_VAR_NAME)
-    .map(|_| LONG_ENV_FILTER_VAR_NAME)
-    .unwrap_or_else(|_| SHORT_ENV_FILTER_VAR_NAME);
+    // Select which env variable to use for the log level filter. This is because filter_or doesn't allow picking between multiple env_var for the filter value
+    let filter_env_name = env::var(LONG_ENV_FILTER_VAR_NAME)
+        .map(|_| LONG_ENV_FILTER_VAR_NAME)
+        .unwrap_or_else(|_| SHORT_ENV_FILTER_VAR_NAME);
 
-  // use the log level from the env if there is one, otherwise use the default.
-  let env = env_logger::Env::new().filter_or(filter_env_name, "info");
+    // use the log level from the env if there is one, otherwise use the default.
+    let env = env_logger::Env::new().filter_or(filter_env_name, "info");
 
-  env_logger::Builder::from_env(env).init();
+    env_logger::Builder::from_env(env).init();
 }
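A side note on the `logger_init` hunk above: `env_logger`'s `filter_or` takes a single environment-variable name, which is why the code manually picks between the long and short names first. A minimal standalone sketch of that pattern, using the same `env_logger`/`log` crates (the `select_filter_var` helper is illustrative, not part of this codebase):

```rust
use std::env;

// Prefer the long variable name when it is set; otherwise fall back to the
// short one. `filter_or` alone cannot express this two-name fallback.
fn select_filter_var<'a>(long: &'a str, short: &'a str) -> &'a str {
    if env::var(long).is_ok() {
        long
    } else {
        short
    }
}

fn main() {
    let name = select_filter_var("TAILCALL_SCHEMA_LOG_LEVEL", "TC_SCHEMA_LOG_LEVEL");
    let env = env_logger::Env::new().filter_or(name, "info");
    env_logger::Builder::from_env(env).init();
    log::info!("logger initialized");
}
```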
diff --git a/benches/data_loader_bench.rs b/benches/data_loader_bench.rs
index 4cd66638c0c..c667fd06818 100644
--- a/benches/data_loader_bench.rs
+++ b/benches/data_loader_bench.rs
@@ -12,43 +12,51 @@ use tailcall::HttpIO;
 
 #[derive(Clone)]
 struct MockHttpClient {
-  // To keep track of number of times execute is called
-  request_count: Arc<AtomicUsize>,
+    // To keep track of number of times execute is called
+    request_count: Arc<AtomicUsize>,
 }
 
 #[async_trait::async_trait]
impl HttpIO for MockHttpClient {
-  async fn execute(&self, _req: Request) -> anyhow::Result<Response<Bytes>> {
-    Ok(Response::empty())
-  }
+    async fn execute(&self, _req: Request) -> anyhow::Result<Response<Bytes>> {
+        Ok(Response::empty())
+    }
 }
 
 fn benchmark_data_loader(c: &mut Criterion) {
-  c.bench_function("test_data_loader", |b| {
-    b.iter(|| {
-      tokio::runtime::Runtime::new().unwrap().spawn(async {
-        let client = Arc::new(MockHttpClient { request_count: Arc::new(AtomicUsize::new(0)) });
-        let loader = HttpDataLoader::new(client.clone(), None, false);
-        let loader = loader.to_data_loader(Batch::default().delay(1));
-
-        let request1 = reqwest::Request::new(reqwest::Method::GET, "http://example.com/1".parse().unwrap());
-        let request2 = reqwest::Request::new(reqwest::Method::GET, "http://example.com/2".parse().unwrap());
-
-        let headers_to_consider = BTreeSet::from(["Header1".to_string(), "Header2".to_string()]);
-        let key1 = DataLoaderRequest::new(request1, headers_to_consider.clone());
-        let key2 = DataLoaderRequest::new(request2, headers_to_consider);
-
-        let futures1 = (0..100).map(|_| loader.load_one(key1.clone()));
-        let futures2 = (0..100).map(|_| loader.load_one(key2.clone()));
-        let _ = join_all(futures1.chain(futures2)).await;
-        assert_eq!(
-          client.request_count.load(Ordering::SeqCst),
-          2,
-          "Only one request should be made for the same key"
-        );
-      })
-    })
-  });
+    c.bench_function("test_data_loader", |b| {
+        b.iter(|| {
+            tokio::runtime::Runtime::new().unwrap().spawn(async {
+                let client =
+                    Arc::new(MockHttpClient { request_count: Arc::new(AtomicUsize::new(0)) });
+                let loader = HttpDataLoader::new(client.clone(), None, false);
+                let loader = loader.to_data_loader(Batch::default().delay(1));
+
+                let request1 = reqwest::Request::new(
+                    reqwest::Method::GET,
+                    "http://example.com/1".parse().unwrap(),
+                );
+                let request2 = reqwest::Request::new(
+                    reqwest::Method::GET,
+                    "http://example.com/2".parse().unwrap(),
+                );
+
+                let headers_to_consider =
+                    BTreeSet::from(["Header1".to_string(), "Header2".to_string()]);
+                let key1 = DataLoaderRequest::new(request1, headers_to_consider.clone());
+                let key2 = DataLoaderRequest::new(request2, headers_to_consider);
+
+                let futures1 = (0..100).map(|_| loader.load_one(key1.clone()));
+                let futures2 = (0..100).map(|_| loader.load_one(key2.clone()));
+                let _ = join_all(futures1.chain(futures2)).await;
+                assert_eq!(
+                    client.request_count.load(Ordering::SeqCst),
+                    2,
+                    "Only one request should be made for the same key"
+                );
+            })
+        })
+    });
 }
 
 criterion_group!
{ diff --git a/benches/impl_path_string_for_evaluation_context.rs b/benches/impl_path_string_for_evaluation_context.rs index aec1f0c1b61..79e4185092e 100644 --- a/benches/impl_path_string_for_evaluation_context.rs +++ b/benches/impl_path_string_for_evaluation_context.rs @@ -17,21 +17,21 @@ use tailcall::lambda::{EvaluationContext, ResolverContextLike}; use tailcall::path::PathString; const INPUT_VALUE: &[&[&str]] = &[ - // existing values - &["value", "root"], - &["value", "nested", "existing"], - // missing values - &["value", "missing"], - &["value", "nested", "missing"], + // existing values + &["value", "root"], + &["value", "nested", "existing"], + // missing values + &["value", "missing"], + &["value", "nested", "missing"], ]; const ARGS_VALUE: &[&[&str]] = &[ - // existing values - &["args", "root"], - &["args", "nested", "existing"], - // missing values - &["args", "missing"], - &["args", "nested", "missing"], + // existing values + &["args", "root"], + &["args", "nested", "existing"], + // missing values + &["args", "missing"], + &["args", "nested", "missing"], ]; const HEADERS_VALUE: &[&[&str]] = &[&["headers", "existing"], &["headers", "missing"]]; @@ -39,147 +39,156 @@ const HEADERS_VALUE: &[&[&str]] = &[&["headers", "existing"], &["headers", "miss const VARS_VALUE: &[&[&str]] = &[&["vars", "existing"], &["vars", "missing"]]; static TEST_VALUES: Lazy = Lazy::new(|| { - let mut root = IndexMap::new(); - let mut nested = IndexMap::new(); + let mut root = IndexMap::new(); + let mut nested = IndexMap::new(); - nested.insert(Name::new("existing"), Value::String("nested-test".to_owned())); + nested.insert( + Name::new("existing"), + Value::String("nested-test".to_owned()), + ); - root.insert(Name::new("root"), Value::String("root-test".to_owned())); - root.insert(Name::new("nested"), Value::Object(nested)); + root.insert(Name::new("root"), Value::String("root-test".to_owned())); + root.insert(Name::new("nested"), Value::Object(nested)); - Value::Object(root) + Value::Object(root) }); static TEST_ARGS: Lazy> = Lazy::new(|| { - let mut root = IndexMap::new(); - let mut nested = IndexMap::new(); + let mut root = IndexMap::new(); + let mut nested = IndexMap::new(); - nested.insert(Name::new("existing"), Value::String("nested-test".to_owned())); + nested.insert( + Name::new("existing"), + Value::String("nested-test".to_owned()), + ); - root.insert(Name::new("root"), Value::String("root-test".to_owned())); - root.insert(Name::new("nested"), Value::Object(nested)); + root.insert(Name::new("root"), Value::String("root-test".to_owned())); + root.insert(Name::new("nested"), Value::Object(nested)); - root + root }); static TEST_HEADERS: Lazy = Lazy::new(|| { - let mut map = HeaderMap::new(); + let mut map = HeaderMap::new(); - map.insert("x-existing", HeaderValue::from_static("header")); + map.insert("x-existing", HeaderValue::from_static("header")); - map + map }); static TEST_VARS: Lazy> = Lazy::new(|| { - let mut map = BTreeMap::new(); + let mut map = BTreeMap::new(); - map.insert("existing".to_owned(), "var".to_owned()); + map.insert("existing".to_owned(), "var".to_owned()); - map + map }); fn to_bench_id(input: &[&str]) -> BenchmarkId { - BenchmarkId::new("input", input.join(".")) + BenchmarkId::new("input", input.join(".")) } struct MockGraphqlContext; impl<'a> ResolverContextLike<'a> for MockGraphqlContext { - fn value(&'a self) -> Option<&'a Value> { - Some(&TEST_VALUES) - } + fn value(&'a self) -> Option<&'a Value> { + Some(&TEST_VALUES) + } - fn args(&'a self) -> Option<&'a IndexMap> 
{ - Some(&TEST_ARGS) - } + fn args(&'a self) -> Option<&'a IndexMap> { + Some(&TEST_ARGS) + } - fn field(&'a self) -> Option { - None - } + fn field(&'a self) -> Option { + None + } - fn add_error(&'a self, _: async_graphql::ServerError) {} + fn add_error(&'a self, _: async_graphql::ServerError) {} } // assert that everything was set up correctly for the benchmark fn assert_test(eval_ctx: &EvaluationContext<'_, MockGraphqlContext>) { - // value - assert_eq!( - eval_ctx.path_string(&["value", "root"]), - Some(Cow::Borrowed("root-test")) - ); - assert_eq!( - eval_ctx.path_string(&["value", "nested", "existing"]), - Some(Cow::Borrowed("nested-test")) - ); - assert_eq!(eval_ctx.path_string(&["value", "missing"]), None); - assert_eq!(eval_ctx.path_string(&["value", "nested", "missing"]), None); - - // args - assert_eq!( - eval_ctx.path_string(&["args", "root"]), - Some(Cow::Borrowed("root-test")) - ); - assert_eq!( - eval_ctx.path_string(&["args", "nested", "existing"]), - Some(Cow::Borrowed("nested-test")) - ); - assert_eq!(eval_ctx.path_string(&["args", "missing"]), None); - assert_eq!(eval_ctx.path_string(&["args", "nested", "missing"]), None); - - // headers - assert_eq!( - eval_ctx.path_string(&["headers", "x-existing"]), - Some(Cow::Borrowed("header")) - ); - assert_eq!(eval_ctx.path_string(&["headers", "x-missing"]), None); - - // vars - assert_eq!(eval_ctx.path_string(&["vars", "existing"]), Some(Cow::Borrowed("var"))); - assert_eq!(eval_ctx.path_string(&["vars", "missing"]), None); + // value + assert_eq!( + eval_ctx.path_string(&["value", "root"]), + Some(Cow::Borrowed("root-test")) + ); + assert_eq!( + eval_ctx.path_string(&["value", "nested", "existing"]), + Some(Cow::Borrowed("nested-test")) + ); + assert_eq!(eval_ctx.path_string(&["value", "missing"]), None); + assert_eq!(eval_ctx.path_string(&["value", "nested", "missing"]), None); + + // args + assert_eq!( + eval_ctx.path_string(&["args", "root"]), + Some(Cow::Borrowed("root-test")) + ); + assert_eq!( + eval_ctx.path_string(&["args", "nested", "existing"]), + Some(Cow::Borrowed("nested-test")) + ); + assert_eq!(eval_ctx.path_string(&["args", "missing"]), None); + assert_eq!(eval_ctx.path_string(&["args", "nested", "missing"]), None); + + // headers + assert_eq!( + eval_ctx.path_string(&["headers", "x-existing"]), + Some(Cow::Borrowed("header")) + ); + assert_eq!(eval_ctx.path_string(&["headers", "x-missing"]), None); + + // vars + assert_eq!( + eval_ctx.path_string(&["vars", "existing"]), + Some(Cow::Borrowed("var")) + ); + assert_eq!(eval_ctx.path_string(&["vars", "missing"]), None); } fn request_context() -> RequestContext { - let tailcall::config::Config { server, upstream, .. } = tailcall::config::Config::default(); - //TODO: default is used only in tests. Drop default and move it to test. - let server = Server::try_from(server).unwrap(); - - let h_client = Arc::new(init_http(&upstream)); - let h2_client = Arc::new(init_http2_only(&upstream)); - RequestContext { - req_headers: HeaderMap::new(), - h_client, - h2_client, - server, - upstream, - http_data_loaders: Arc::new(vec![]), - gql_data_loaders: Arc::new(vec![]), - cache: Arc::new(NativeChronoCache::new()), - grpc_data_loaders: Arc::new(vec![]), - min_max_age: Arc::new(Mutex::new(None)), - cache_public: Arc::new(Mutex::new(None)), - env_vars: Arc::new(init_env()), - } + let tailcall::config::Config { server, upstream, .. } = tailcall::config::Config::default(); + //TODO: default is used only in tests. Drop default and move it to test. 
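An aside on the `request_context` helper being reformatted here: it relies on destructuring a struct inside a `let`, with `..` discarding the fields it does not need. A tiny self-contained sketch of the pattern, using a stand-in `Config` rather than tailcall's:

```rust
#[derive(Default)]
#[allow(dead_code)]
struct Config {
    server: u16,
    upstream: String,
    telemetry: bool,
}

fn main() {
    // Bind only `server` and `upstream`; `..` ignores every other field,
    // mirroring how the benchmark peels two fields off `Config::default()`.
    let Config { server, upstream, .. } = Config::default();
    println!("server={server} upstream={upstream}");
}
```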
+ let server = Server::try_from(server).unwrap(); + + let h_client = Arc::new(init_http(&upstream)); + let h2_client = Arc::new(init_http2_only(&upstream)); + RequestContext { + req_headers: HeaderMap::new(), + h_client, + h2_client, + server, + upstream, + http_data_loaders: Arc::new(vec![]), + gql_data_loaders: Arc::new(vec![]), + cache: Arc::new(NativeChronoCache::new()), + grpc_data_loaders: Arc::new(vec![]), + min_max_age: Arc::new(Mutex::new(None)), + cache_public: Arc::new(Mutex::new(None)), + env_vars: Arc::new(init_env()), + } } fn bench_main(c: &mut Criterion) { - let mut req_ctx = request_context().req_headers(TEST_HEADERS.clone()); + let mut req_ctx = request_context().req_headers(TEST_HEADERS.clone()); - req_ctx.server.vars = TEST_VARS.clone(); - let eval_ctx = EvaluationContext::new(&req_ctx, &MockGraphqlContext); + req_ctx.server.vars = TEST_VARS.clone(); + let eval_ctx = EvaluationContext::new(&req_ctx, &MockGraphqlContext); - assert_test(&eval_ctx); + assert_test(&eval_ctx); - let all_inputs = INPUT_VALUE - .iter() - .chain(ARGS_VALUE) - .chain(HEADERS_VALUE) - .chain(VARS_VALUE); + let all_inputs = INPUT_VALUE + .iter() + .chain(ARGS_VALUE) + .chain(HEADERS_VALUE) + .chain(VARS_VALUE); - for input in all_inputs { - c.bench_with_input(to_bench_id(input), input, |b, input| { - b.iter(|| eval_ctx.path_string(input)); - }); - } + for input in all_inputs { + c.bench_with_input(to_bench_id(input), input, |b, input| { + b.iter(|| eval_ctx.path_string(input)); + }); + } } criterion_group!(benches, bench_main); diff --git a/benches/json_like_bench.rs b/benches/json_like_bench.rs index 8e8a4ce9655..849641140ed 100644 --- a/benches/json_like_bench.rs +++ b/benches/json_like_bench.rs @@ -2,31 +2,31 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use serde_json::json; fn benchmark_batched_body(c: &mut Criterion) { - c.bench_function("test_batched_body", |b| { - b.iter(|| { - let input = json!({ - "data": [ - {"user": {"id": "1"}}, - {"user": {"id": "2"}}, - {"user": {"id": "3"}}, - {"user": [ - {"id": "4"}, - {"id": "5"} - ] - }, - ] - }); + c.bench_function("test_batched_body", |b| { + b.iter(|| { + let input = json!({ + "data": [ + {"user": {"id": "1"}}, + {"user": {"id": "2"}}, + {"user": {"id": "3"}}, + {"user": [ + {"id": "4"}, + {"id": "5"} + ] + }, + ] + }); - black_box( - serde_json::to_value(tailcall::json::gather_path_matches( - &input, - &["data".into(), "user".into(), "id".into()], - vec![], - )) - .unwrap(), - ); - }) - }); + black_box( + serde_json::to_value(tailcall::json::gather_path_matches( + &input, + &["data".into(), "user".into(), "id".into()], + vec![], + )) + .unwrap(), + ); + }) + }); } criterion_group!(benches, benchmark_batched_body); diff --git a/benches/request_template_bench.rs b/benches/request_template_bench.rs index 086fbf0ec94..bc67c1486fb 100644 --- a/benches/request_template_bench.rs +++ b/benches/request_template_bench.rs @@ -11,51 +11,53 @@ use tailcall::path::PathString; #[derive(Setters)] struct Context { - pub value: serde_json::Value, - pub headers: HeaderMap, + pub value: serde_json::Value, + pub headers: HeaderMap, } impl Default for Context { - fn default() -> Self { - Self { value: serde_json::Value::Null, headers: HeaderMap::new() } - } + fn default() -> Self { + Self { value: serde_json::Value::Null, headers: HeaderMap::new() } + } } impl PathString for Context { - fn path_string>(&self, parts: &[T]) -> Option> { - self.value.path_string(parts) - } + fn path_string>(&self, parts: &[T]) -> Option> { + 
self.value.path_string(parts) + } } impl HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers - } + fn headers(&self) -> &HeaderMap { + &self.headers + } } fn benchmark_to_request(c: &mut Criterion) { - let tmpl_mustache = RequestTemplate::try_from(Endpoint::new( - "http://localhost:3000/{{args.b}}?a={{args.a}}&b={{args.b}}&c={{args.c}}".to_string(), - )) - .unwrap(); + let tmpl_mustache = RequestTemplate::try_from(Endpoint::new( + "http://localhost:3000/{{args.b}}?a={{args.a}}&b={{args.b}}&c={{args.c}}".to_string(), + )) + .unwrap(); - let tmpl_literal = - RequestTemplate::try_from(Endpoint::new("http://localhost:3000/foo?a=bar&b=foo&c=baz".to_string())).unwrap(); + let tmpl_literal = RequestTemplate::try_from(Endpoint::new( + "http://localhost:3000/foo?a=bar&b=foo&c=baz".to_string(), + )) + .unwrap(); - let ctx = Context::default().value(json!({ - "args": { - "b": "foo" - } - })); + let ctx = Context::default().value(json!({ + "args": { + "b": "foo" + } + })); - c.bench_function("with_mustache_literal", |b| { - b.iter(|| { - black_box(tmpl_literal.to_request(&ctx).unwrap()); - }) - }); + c.bench_function("with_mustache_literal", |b| { + b.iter(|| { + black_box(tmpl_literal.to_request(&ctx).unwrap()); + }) + }); - c.bench_function("with_mustache_expressions", |b| { - b.iter(|| { - black_box(tmpl_mustache.to_request(&ctx).unwrap()); - }) - }); + c.bench_function("with_mustache_expressions", |b| { + b.iter(|| { + black_box(tmpl_mustache.to_request(&ctx).unwrap()); + }) + }); } criterion_group! { diff --git a/cloudflare/src/cache.rs b/cloudflare/src/cache.rs index 1ca2bf36c67..ee524223f62 100644 --- a/cloudflare/src/cache.rs +++ b/cloudflare/src/cache.rs @@ -10,7 +10,7 @@ use worker::kv::KvStore; use crate::to_anyhow; pub struct CloudflareChronoCache { - env: Rc, + env: Rc, } unsafe impl Send for CloudflareChronoCache {} @@ -18,42 +18,46 @@ unsafe impl Send for CloudflareChronoCache {} unsafe impl Sync for CloudflareChronoCache {} impl CloudflareChronoCache { - pub fn init(env: Rc) -> Self { - Self { env } - } - fn get_kv(&self) -> Result { - self.env.kv("TMP_KV").map_err(to_anyhow) - } + pub fn init(env: Rc) -> Self { + Self { env } + } + fn get_kv(&self) -> Result { + self.env.kv("TMP_KV").map_err(to_anyhow) + } } // TODO: Needs fix #[async_trait::async_trait] impl Cache for CloudflareChronoCache { - type Key = u64; - type Value = ConstValue; - async fn set<'a>(&'a self, key: u64, value: ConstValue, ttl: NonZeroU64) -> Result { - let kv_store = self.get_kv()?; - let ttl = ttl.get(); - async_std::task::spawn_local(async move { - kv_store - .put(&key.to_string(), value.to_string()) - .map_err(to_anyhow)? - .expiration_ttl(ttl) - .execute() + type Key = u64; + type Value = ConstValue; + async fn set<'a>(&'a self, key: u64, value: ConstValue, ttl: NonZeroU64) -> Result { + let kv_store = self.get_kv()?; + let ttl = ttl.get(); + async_std::task::spawn_local(async move { + kv_store + .put(&key.to_string(), value.to_string()) + .map_err(to_anyhow)? + .expiration_ttl(ttl) + .execute() + .await + .map_err(to_anyhow)?; + anyhow::Ok(value) + }) .await - .map_err(to_anyhow)?; - anyhow::Ok(value) - }) - .await - } + } - async fn get<'a>(&'a self, key: &'a u64) -> Result { - let kv_store = self.get_kv()?; - let key = key.to_string(); - async_std::task::spawn_local(async move { - let val = kv_store.get(&key).json::().await.map_err(to_anyhow)?; - let val = val.ok_or(anyhow!("key not found"))?; - Ok(ConstValue::from_json(val)?) 
- }) - .await - } + async fn get<'a>(&'a self, key: &'a u64) -> Result { + let kv_store = self.get_kv()?; + let key = key.to_string(); + async_std::task::spawn_local(async move { + let val = kv_store + .get(&key) + .json::() + .await + .map_err(to_anyhow)?; + let val = val.ok_or(anyhow!("key not found"))?; + Ok(ConstValue::from_json(val)?) + }) + .await + } } diff --git a/cloudflare/src/env.rs b/cloudflare/src/env.rs index 80275a755e9..d3694e99238 100644 --- a/cloudflare/src/env.rs +++ b/cloudflare/src/env.rs @@ -4,20 +4,20 @@ use tailcall::EnvIO; use worker::Env; pub struct CloudflareEnv { - env: Rc, + env: Rc, } unsafe impl Send for CloudflareEnv {} unsafe impl Sync for CloudflareEnv {} impl EnvIO for CloudflareEnv { - fn get(&self, key: &str) -> Option { - self.env.var(key).ok().map(|s| s.to_string()) - } + fn get(&self, key: &str) -> Option { + self.env.var(key).ok().map(|s| s.to_string()) + } } impl CloudflareEnv { - pub fn init(env: Rc) -> Self { - Self { env } - } + pub fn init(env: Rc) -> Self { + Self { env } + } } diff --git a/cloudflare/src/file.rs b/cloudflare/src/file.rs index 295079745b3..e291af14796 100644 --- a/cloudflare/src/file.rs +++ b/cloudflare/src/file.rs @@ -8,48 +8,53 @@ use crate::to_anyhow; #[derive(Clone)] pub struct CloudflareFileIO { - bucket: Rc, + bucket: Rc, } impl CloudflareFileIO { - pub fn init(env: Rc, bucket_id: String) -> anyhow::Result { - let bucket = env.bucket(bucket_id.as_str()).map_err(|e| anyhow!(e.to_string()))?; - Ok(CloudflareFileIO { bucket: Rc::new(bucket) }) - } + pub fn init(env: Rc, bucket_id: String) -> anyhow::Result { + let bucket = env + .bucket(bucket_id.as_str()) + .map_err(|e| anyhow!(e.to_string()))?; + Ok(CloudflareFileIO { bucket: Rc::new(bucket) }) + } } impl CloudflareFileIO { - async fn get(&self, path: String) -> anyhow::Result { - let maybe_object = self.bucket.get(&path).execute().await.map_err(to_anyhow)?; - let object = maybe_object.ok_or(anyhow!("File '{}' was not found in bucket", path))?; - - let body = match object.body() { - Some(body) => body.text().await.map_err(to_anyhow), - None => Ok("".to_string()), - }; - body - } - - async fn put(&self, path: String, value: Vec) -> anyhow::Result<()> { - self.bucket.put(&path, value).execute().await.map_err(to_anyhow)?; - Ok(()) - } + async fn get(&self, path: String) -> anyhow::Result { + let maybe_object = self.bucket.get(&path).execute().await.map_err(to_anyhow)?; + let object = maybe_object.ok_or(anyhow!("File '{}' was not found in bucket", path))?; + + let body = match object.body() { + Some(body) => body.text().await.map_err(to_anyhow), + None => Ok("".to_string()), + }; + body + } + + async fn put(&self, path: String, value: Vec) -> anyhow::Result<()> { + self.bucket + .put(&path, value) + .execute() + .await + .map_err(to_anyhow)?; + Ok(()) + } } impl FileIO for CloudflareFileIO { - async fn write<'a>(&'a self, file_path: &'a str, content: &'a [u8]) -> anyhow::Result<()> { - self - .put(file_path.to_string(), content.to_vec()) - .await - .map_err(to_anyhow)?; - - log::info!("File write: {} ... ok", file_path); - Ok(()) - } - - async fn read<'a>(&'a self, file_path: &'a str) -> anyhow::Result { - let content = self.get(file_path.to_string()).await.map_err(to_anyhow)?; - log::info!("File read: {} ... ok", file_path); - Ok(content) - } + async fn write<'a>(&'a self, file_path: &'a str, content: &'a [u8]) -> anyhow::Result<()> { + self.put(file_path.to_string(), content.to_vec()) + .await + .map_err(to_anyhow)?; + + log::info!("File write: {} ... 
ok", file_path); + Ok(()) + } + + async fn read<'a>(&'a self, file_path: &'a str) -> anyhow::Result { + let content = self.get(file_path.to_string()).await.map_err(to_anyhow)?; + log::info!("File read: {} ... ok", file_path); + Ok(content) + } } diff --git a/cloudflare/src/handle.rs b/cloudflare/src/handle.rs index 8aed2bfaa8c..7f511c02c70 100644 --- a/cloudflare/src/handle.rs +++ b/cloudflare/src/handle.rs @@ -16,21 +16,30 @@ use crate::{init_cache, init_file, init_http}; type CloudFlareAppContext = AppContext; lazy_static! { - static ref APP_CTX: RwLock)>> = RwLock::new(None); + static ref APP_CTX: RwLock)>> = RwLock::new(None); } /// /// The handler which handles requests on cloudflare /// -pub async fn fetch(req: worker::Request, env: worker::Env, _: worker::Context) -> anyhow::Result { - log::info!("{} {:?}", req.method().to_string(), req.url().map(|u| u.to_string())); - let req = to_request(req).await?; - let env = Rc::new(env); - let app_ctx = match get_app_ctx(env, &req).await? { - Ok(app_ctx) => app_ctx, - Err(e) => return Ok(to_response(e).await?), - }; - let resp = handle_request::(req, app_ctx).await?; - Ok(to_response(resp).await?) +pub async fn fetch( + req: worker::Request, + env: worker::Env, + _: worker::Context, +) -> anyhow::Result { + log::info!( + "{} {:?}", + req.method().to_string(), + req.url().map(|u| u.to_string()) + ); + let req = to_request(req).await?; + let env = Rc::new(env); + let app_ctx = match get_app_ctx(env, &req).await? { + Ok(app_ctx) => app_ctx, + Err(e) => return Ok(to_response(e).await?), + }; + let resp = + handle_request::(req, app_ctx).await?; + Ok(to_response(resp).await?) } /// @@ -38,47 +47,54 @@ pub async fn fetch(req: worker::Request, env: worker::Env, _: worker::Context) - /// for future requests. /// async fn get_app_ctx( - env: Rc, - req: &Request, + env: Rc, + req: &Request, ) -> anyhow::Result, Response>> { - // Read context from cache - let file_path = req - .uri() - .query() - .and_then(|x| serde_qs::from_str::>(x).ok()) - .and_then(|x| x.get("config").cloned()); + // Read context from cache + let file_path = req + .uri() + .query() + .and_then(|x| serde_qs::from_str::>(x).ok()) + .and_then(|x| x.get("config").cloned()); - if let Some(file_path) = &file_path { - if let Some(app_ctx) = read_app_ctx() { - if app_ctx.0 == file_path.borrow() { - log::info!("Using cached application context"); - return Ok(Ok(app_ctx.clone().1)); - } + if let Some(file_path) = &file_path { + if let Some(app_ctx) = read_app_ctx() { + if app_ctx.0 == file_path.borrow() { + log::info!("Using cached application context"); + return Ok(Ok(app_ctx.clone().1)); + } + } } - } - // Create new context - let env_io = CloudflareEnv::init(env.clone()); - let bucket_id = env_io.get("BUCKET").ok_or(anyhow!("CONFIG var is not set"))?; - log::debug!("R2 Bucket ID: {}", bucket_id); + // Create new context + let env_io = CloudflareEnv::init(env.clone()); + let bucket_id = env_io + .get("BUCKET") + .ok_or(anyhow!("CONFIG var is not set"))?; + log::debug!("R2 Bucket ID: {}", bucket_id); - let file = init_file(env.clone(), bucket_id)?; - let http = init_http(); - let cache = init_cache(env); + let file = init_file(env.clone(), bucket_id)?; + let http = init_http(); + let cache = init_cache(env); - match showcase_get_app_ctx::(req, (http, env_io, Some(file), Arc::new(cache))).await? 
{ - Ok(app_ctx) => { - let app_ctx = Arc::new(app_ctx); - if let Some(file_path) = file_path { - *APP_CTX.write().unwrap() = Some((file_path, app_ctx.clone())); - } - log::info!("Initialized new application context"); - Ok(Ok(app_ctx)) + match showcase_get_app_ctx::( + req, + (http, env_io, Some(file), Arc::new(cache)), + ) + .await? + { + Ok(app_ctx) => { + let app_ctx = Arc::new(app_ctx); + if let Some(file_path) = file_path { + *APP_CTX.write().unwrap() = Some((file_path, app_ctx.clone())); + } + log::info!("Initialized new application context"); + Ok(Ok(app_ctx)) + } + Err(e) => Ok(Err(e)), } - Err(e) => Ok(Err(e)), - } } fn read_app_ctx() -> Option<(String, Arc)> { - APP_CTX.read().unwrap().clone() + APP_CTX.read().unwrap().clone() } diff --git a/cloudflare/src/http.rs b/cloudflare/src/http.rs index 3b3a42ee275..830df242521 100644 --- a/cloudflare/src/http.rs +++ b/cloudflare/src/http.rs @@ -9,81 +9,87 @@ use crate::to_anyhow; #[derive(Clone)] pub struct CloudflareHttp { - client: Client, + client: Client, } impl Default for CloudflareHttp { - fn default() -> Self { - Self { client: Client::new() } - } + fn default() -> Self { + Self { client: Client::new() } + } } impl CloudflareHttp { - pub fn init() -> Self { - let client = Client::new(); - Self { client } - } + pub fn init() -> Self { + let client = Client::new(); + Self { client } + } } #[async_trait::async_trait] impl HttpIO for CloudflareHttp { - // HttpClientOptions are ignored in Cloudflare - // This is because there is little control over the underlying HTTP client - async fn execute(&self, request: reqwest::Request) -> Result> { - let client = self.client.clone(); - let method = request.method().clone(); - let url = request.url().clone(); - // TODO: remove spawn local - let res = spawn_local(async move { - let response = client.execute(request).await?.error_for_status()?; - Response::from_reqwest(response).await - }) - .await?; - log::info!("{} {} {}", method, url, res.status.as_u16()); - Ok(res) - } + // HttpClientOptions are ignored in Cloudflare + // This is because there is little control over the underlying HTTP client + async fn execute(&self, request: reqwest::Request) -> Result> { + let client = self.client.clone(); + let method = request.method().clone(); + let url = request.url().clone(); + // TODO: remove spawn local + let res = spawn_local(async move { + let response = client.execute(request).await?.error_for_status()?; + Response::from_reqwest(response).await + }) + .await?; + log::info!("{} {} {}", method, url, res.status.as_u16()); + Ok(res) + } } -pub async fn to_response(response: hyper::Response) -> anyhow::Result { - let status = response.status().as_u16(); - let headers = response.headers().clone(); - let bytes = hyper::body::to_bytes(response).await?; - let body = worker::ResponseBody::Body(bytes.to_vec()); - let mut w_response = worker::Response::from_body(body).map_err(to_anyhow)?; - w_response = w_response.with_status(status); - let mut_headers = w_response.headers_mut(); - for (name, value) in headers.iter() { - let value = String::from_utf8(value.as_bytes().to_vec())?; - mut_headers.append(name.as_str(), &value).map_err(to_anyhow)?; - } +pub async fn to_response( + response: hyper::Response, +) -> anyhow::Result { + let status = response.status().as_u16(); + let headers = response.headers().clone(); + let bytes = hyper::body::to_bytes(response).await?; + let body = worker::ResponseBody::Body(bytes.to_vec()); + let mut w_response = worker::Response::from_body(body).map_err(to_anyhow)?; + w_response 
= w_response.with_status(status); + let mut_headers = w_response.headers_mut(); + for (name, value) in headers.iter() { + let value = String::from_utf8(value.as_bytes().to_vec())?; + mut_headers + .append(name.as_str(), &value) + .map_err(to_anyhow)?; + } - Ok(w_response) + Ok(w_response) } pub fn to_method(method: worker::Method) -> anyhow::Result { - let method = &*method.to_string().to_uppercase(); - match method { - "GET" => Ok(hyper::Method::GET), - "POST" => Ok(hyper::Method::POST), - "PUT" => Ok(hyper::Method::PUT), - "DELETE" => Ok(hyper::Method::DELETE), - "HEAD" => Ok(hyper::Method::HEAD), - "OPTIONS" => Ok(hyper::Method::OPTIONS), - "PATCH" => Ok(hyper::Method::PATCH), - "CONNECT" => Ok(hyper::Method::CONNECT), - "TRACE" => Ok(hyper::Method::TRACE), - method => Err(anyhow!("Unsupported HTTP method: {}", method)), - } + let method = &*method.to_string().to_uppercase(); + match method { + "GET" => Ok(hyper::Method::GET), + "POST" => Ok(hyper::Method::POST), + "PUT" => Ok(hyper::Method::PUT), + "DELETE" => Ok(hyper::Method::DELETE), + "HEAD" => Ok(hyper::Method::HEAD), + "OPTIONS" => Ok(hyper::Method::OPTIONS), + "PATCH" => Ok(hyper::Method::PATCH), + "CONNECT" => Ok(hyper::Method::CONNECT), + "TRACE" => Ok(hyper::Method::TRACE), + method => Err(anyhow!("Unsupported HTTP method: {}", method)), + } } pub async fn to_request(mut req: worker::Request) -> anyhow::Result> { - let body = req.text().await.map_err(to_anyhow)?; - let method = req.method(); - let uri = req.url().map_err(to_anyhow)?.as_str().to_string(); - let headers = req.headers(); - let mut builder = hyper::Request::builder().method(to_method(method)?).uri(uri); - for (k, v) in headers { - builder = builder.header(k, v); - } - Ok(builder.body(hyper::body::Body::from(body))?) + let body = req.text().await.map_err(to_anyhow)?; + let method = req.method(); + let uri = req.url().map_err(to_anyhow)?.as_str().to_string(); + let headers = req.headers(); + let mut builder = hyper::Request::builder() + .method(to_method(method)?) + .uri(uri); + for (k, v) in headers { + builder = builder.header(k, v); + } + Ok(builder.body(hyper::body::Body::from(body))?) 
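One observation on the `to_method` helper above: `hyper::Method` also implements `FromStr`, so the conversion could in principle lean on parsing instead of an explicit match. The two are not equivalent, though, since `FromStr` additionally accepts extension methods that the match rejects. A hedged sketch:

```rust
use std::str::FromStr;

fn main() -> anyhow::Result<()> {
    // Standard verbs parse directly.
    let get = hyper::Method::from_str("GET")?;
    assert_eq!(get, hyper::Method::GET);

    // Unlike the explicit match above, `FromStr` also accepts syntactically
    // valid extension methods such as "BREW", so the match is the stricter
    // of the two conversions.
    assert!(hyper::Method::from_str("BREW").is_ok());
    Ok(())
}
```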
} diff --git a/cloudflare/src/lib.rs b/cloudflare/src/lib.rs index 8e659c55b3e..5f77dbb3be2 100644 --- a/cloudflare/src/lib.rs +++ b/cloudflare/src/lib.rs @@ -10,41 +10,48 @@ mod handle; mod http; pub fn init_env(env: Rc) -> env::CloudflareEnv { - env::CloudflareEnv::init(env) + env::CloudflareEnv::init(env) } -pub fn init_file(env: Rc, bucket_id: String) -> anyhow::Result { - file::CloudflareFileIO::init(env, bucket_id) +pub fn init_file( + env: Rc, + bucket_id: String, +) -> anyhow::Result { + file::CloudflareFileIO::init(env, bucket_id) } pub fn init_http() -> http::CloudflareHttp { - http::CloudflareHttp::init() + http::CloudflareHttp::init() } pub fn init_cache(env: Rc) -> cache::CloudflareChronoCache { - cache::CloudflareChronoCache::init(env) + cache::CloudflareChronoCache::init(env) } #[worker::event(fetch)] -async fn fetch(req: worker::Request, env: worker::Env, context: worker::Context) -> anyhow::Result { - let result = handle::fetch(req, env, context).await; - - match result { - Ok(response) => Ok(response), - Err(message) => { - log::error!("ServerError: {}", message.to_string()); - worker::Response::error(message.to_string(), 500).map_err(to_anyhow) +async fn fetch( + req: worker::Request, + env: worker::Env, + context: worker::Context, +) -> anyhow::Result { + let result = handle::fetch(req, env, context).await; + + match result { + Ok(response) => Ok(response), + Err(message) => { + log::error!("ServerError: {}", message.to_string()); + worker::Response::error(message.to_string(), 500).map_err(to_anyhow) + } } - } } #[worker::event(start)] fn start() { - // Initialize Logger - wasm_logger::init(wasm_logger::Config::new(log::Level::Info)); - panic::set_hook(Box::new(console_error_panic_hook::hook)); + // Initialize Logger + wasm_logger::init(wasm_logger::Config::new(log::Level::Info)); + panic::set_hook(Box::new(console_error_panic_hook::hook)); } fn to_anyhow(e: T) -> anyhow::Error { - anyhow!("{}", e) + anyhow!("{}", e) } diff --git a/src/app_context.rs b/src/app_context.rs index 913f4726112..00c439feddf 100644 --- a/src/app_context.rs +++ b/src/app_context.rs @@ -13,105 +13,112 @@ use crate::lambda::{DataLoaderId, Expression, IO}; use crate::{grpc, EntityCache, EnvIO, HttpIO}; pub struct AppContext { - pub schema: dynamic::Schema, - pub universal_http_client: Arc, - pub http2_only_client: Arc, - pub blueprint: Blueprint, - pub http_data_loaders: Arc>>, - pub gql_data_loaders: Arc>>, - pub grpc_data_loaders: Arc>>, - pub cache: Arc, - pub env_vars: Arc, + pub schema: dynamic::Schema, + pub universal_http_client: Arc, + pub http2_only_client: Arc, + pub blueprint: Blueprint, + pub http_data_loaders: Arc>>, + pub gql_data_loaders: Arc>>, + pub grpc_data_loaders: Arc>>, + pub cache: Arc, + pub env_vars: Arc, } impl AppContext { - #[allow(clippy::too_many_arguments)] - pub fn new( - mut blueprint: Blueprint, - h_client: Arc, - h2_client: Arc, - env: Arc, - cache: Arc, - ) -> Self { - let mut http_data_loaders = vec![]; - let mut gql_data_loaders = vec![]; - let mut grpc_data_loaders = vec![]; - - for def in blueprint.definitions.iter_mut() { - if let Definition::ObjectTypeDefinition(def) = def { - for field in &mut def.fields { - if let Some(Expression::IO(expr)) = &mut field.resolver { - match expr { - IO::Http { req_template, group_by, .. } => { - let data_loader = HttpDataLoader::new( - h_client.clone(), - group_by.clone(), - matches!(&field.of_type, ListType { .. 
}), - ) - .to_data_loader(blueprint.upstream.batch.clone().unwrap_or_default()); - - field.resolver = Some(Expression::IO(IO::Http { - req_template: req_template.clone(), - group_by: group_by.clone(), - dl_id: Some(DataLoaderId(http_data_loaders.len())), - })); - - http_data_loaders.push(data_loader); - } - - IO::GraphQLEndpoint { req_template, field_name, batch, .. } => { - let graphql_data_loader = GraphqlDataLoader::new(h_client.clone(), *batch) - .to_data_loader(blueprint.upstream.batch.clone().unwrap_or_default()); - - field.resolver = Some(Expression::IO(IO::GraphQLEndpoint { - req_template: req_template.clone(), - field_name: field_name.clone(), - batch: *batch, - dl_id: Some(DataLoaderId(gql_data_loaders.len())), - })); - - gql_data_loaders.push(graphql_data_loader); - } - - IO::Grpc { req_template, group_by, .. } => { - let data_loader = GrpcDataLoader { - client: h2_client.clone(), - operation: req_template.operation.clone(), - group_by: group_by.clone(), - }; - let data_loader = data_loader.to_data_loader(blueprint.upstream.batch.clone().unwrap_or_default()); - - field.resolver = Some(Expression::IO(IO::Grpc { - req_template: req_template.clone(), - group_by: group_by.clone(), - dl_id: Some(DataLoaderId(grpc_data_loaders.len())), - })); - - grpc_data_loaders.push(data_loader); - } - _ => {} + #[allow(clippy::too_many_arguments)] + pub fn new( + mut blueprint: Blueprint, + h_client: Arc, + h2_client: Arc, + env: Arc, + cache: Arc, + ) -> Self { + let mut http_data_loaders = vec![]; + let mut gql_data_loaders = vec![]; + let mut grpc_data_loaders = vec![]; + + for def in blueprint.definitions.iter_mut() { + if let Definition::ObjectTypeDefinition(def) = def { + for field in &mut def.fields { + if let Some(Expression::IO(expr)) = &mut field.resolver { + match expr { + IO::Http { req_template, group_by, .. } => { + let data_loader = HttpDataLoader::new( + h_client.clone(), + group_by.clone(), + matches!(&field.of_type, ListType { .. }), + ) + .to_data_loader( + blueprint.upstream.batch.clone().unwrap_or_default(), + ); + + field.resolver = Some(Expression::IO(IO::Http { + req_template: req_template.clone(), + group_by: group_by.clone(), + dl_id: Some(DataLoaderId(http_data_loaders.len())), + })); + + http_data_loaders.push(data_loader); + } + + IO::GraphQLEndpoint { req_template, field_name, batch, .. } => { + let graphql_data_loader = + GraphqlDataLoader::new(h_client.clone(), *batch) + .to_data_loader( + blueprint.upstream.batch.clone().unwrap_or_default(), + ); + + field.resolver = Some(Expression::IO(IO::GraphQLEndpoint { + req_template: req_template.clone(), + field_name: field_name.clone(), + batch: *batch, + dl_id: Some(DataLoaderId(gql_data_loaders.len())), + })); + + gql_data_loaders.push(graphql_data_loader); + } + + IO::Grpc { req_template, group_by, .. 
} => { + let data_loader = GrpcDataLoader { + client: h2_client.clone(), + operation: req_template.operation.clone(), + group_by: group_by.clone(), + }; + let data_loader = data_loader.to_data_loader( + blueprint.upstream.batch.clone().unwrap_or_default(), + ); + + field.resolver = Some(Expression::IO(IO::Grpc { + req_template: req_template.clone(), + group_by: group_by.clone(), + dl_id: Some(DataLoaderId(grpc_data_loaders.len())), + })); + + grpc_data_loaders.push(data_loader); + } + _ => {} + } + } + } } - } } - } - } - let schema = blueprint.to_schema(); - - AppContext { - schema, - universal_http_client: h_client, - http2_only_client: h2_client, - blueprint, - http_data_loaders: Arc::new(http_data_loaders), - gql_data_loaders: Arc::new(gql_data_loaders), - cache, - grpc_data_loaders: Arc::new(grpc_data_loaders), - env_vars: env, + let schema = blueprint.to_schema(); + + AppContext { + schema, + universal_http_client: h_client, + http2_only_client: h2_client, + blueprint, + http_data_loaders: Arc::new(http_data_loaders), + gql_data_loaders: Arc::new(gql_data_loaders), + cache, + grpc_data_loaders: Arc::new(grpc_data_loaders), + env_vars: env, + } } - } - pub async fn execute(&self, request: impl Into) -> Response { - self.schema.execute(request).await - } + pub async fn execute(&self, request: impl Into) -> Response { + self.schema.execute(request).await + } } diff --git a/src/async_graphql_hyper.rs b/src/async_graphql_hyper.rs index 03e1bd477f0..5c5e7f448c9 100644 --- a/src/async_graphql_hyper.rs +++ b/src/async_graphql_hyper.rs @@ -9,10 +9,10 @@ use serde::{Deserialize, Serialize}; #[async_trait::async_trait] pub trait GraphQLRequestLike { - fn data(self, data: D) -> Self; - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor; + fn data(self, data: D) -> Self; + async fn execute(self, executor: &E) -> GraphQLResponse + where + E: Executor; } #[derive(Debug, Deserialize)] @@ -21,19 +21,19 @@ impl GraphQLBatchRequest {} #[async_trait::async_trait] impl GraphQLRequestLike for GraphQLBatchRequest { - fn data(mut self, data: D) -> Self { - for request in self.0.iter_mut() { - request.data.insert(data.clone()); + fn data(mut self, data: D) -> Self { + for request in self.0.iter_mut() { + request.data.insert(data.clone()); + } + self + } + /// Shortcut method to execute the request on the executor. + async fn execute(self, executor: &E) -> GraphQLResponse + where + E: Executor, + { + GraphQLResponse(executor.execute_batch(self.0).await) } - self - } - /// Shortcut method to execute the request on the executor. - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - GraphQLResponse(executor.execute_batch(self.0).await) - } } #[derive(Debug, Deserialize)] @@ -43,116 +43,118 @@ impl GraphQLRequest {} #[async_trait::async_trait] impl GraphQLRequestLike for GraphQLRequest { - #[must_use] - fn data(mut self, data: D) -> Self { - self.0.data.insert(data); - self - } - /// Shortcut method to execute the request on the schema. - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - GraphQLResponse(executor.execute(self.0).await.into()) - } + #[must_use] + fn data(mut self, data: D) -> Self { + self.0.data.insert(data); + self + } + /// Shortcut method to execute the request on the schema. 
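For readers skimming this file: the `Executor` bound that `GraphQLRequestLike::execute` relies on is implemented by `async_graphql::Schema`, which is what lets the same trait drive both single and batch requests. A minimal sketch of executing a query against a schema (toy `Query` type, assuming the `async-graphql` and `tokio` crates already used by this repository):

```rust
use async_graphql::{EmptyMutation, EmptySubscription, Object, Schema};

struct Query;

#[Object]
impl Query {
    async fn hello(&self) -> &'static str {
        "world"
    }
}

#[tokio::main]
async fn main() {
    // `Schema` implements the `Executor` trait, so a query string can be
    // executed directly; `GraphQLRequestLike::execute` generalizes this call.
    let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
    let response = schema.execute("{ hello }").await;
    assert!(response.errors.is_empty());
}
```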
+ async fn execute(self, executor: &E) -> GraphQLResponse + where + E: Executor, + { + GraphQLResponse(executor.execute(self.0).await.into()) + } } #[derive(Debug, Serialize)] pub struct GraphQLResponse(pub async_graphql::BatchResponse); impl From for GraphQLResponse { - fn from(batch: async_graphql::BatchResponse) -> Self { - Self(batch) - } + fn from(batch: async_graphql::BatchResponse) -> Self { + Self(batch) + } } impl From for GraphQLResponse { - fn from(res: async_graphql::Response) -> Self { - Self(res.into()) - } + fn from(res: async_graphql::Response) -> Self { + Self(res.into()) + } } impl From for GraphQLRequest { - fn from(query: GraphQLQuery) -> Self { - let mut request = async_graphql::Request::new(query.query); + fn from(query: GraphQLQuery) -> Self { + let mut request = async_graphql::Request::new(query.query); - if let Some(operation_name) = query.operation_name { - request = request.operation_name(operation_name); - } + if let Some(operation_name) = query.operation_name { + request = request.operation_name(operation_name); + } - if let Some(variables) = query.variables { - let value = serde_json::from_str(&variables).unwrap_or_default(); - let variables = async_graphql::Variables::from_json(value); - request = request.variables(variables); - } + if let Some(variables) = query.variables { + let value = serde_json::from_str(&variables).unwrap_or_default(); + let variables = async_graphql::Variables::from_json(value); + request = request.variables(variables); + } - GraphQLRequest(request) - } + GraphQLRequest(request) + } } #[derive(Debug)] pub struct GraphQLQuery { - query: String, - operation_name: Option, - variables: Option, + query: String, + operation_name: Option, + variables: Option, } impl GraphQLQuery { - /// Shortcut method to execute the request on the schema. - pub async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - let request: GraphQLRequest = self.into(); - request.execute(executor).await - } + /// Shortcut method to execute the request on the schema. + pub async fn execute(self, executor: &E) -> GraphQLResponse + where + E: Executor, + { + let request: GraphQLRequest = self.into(); + request.execute(executor).await + } } -static APPLICATION_JSON: Lazy = Lazy::new(|| HeaderValue::from_static("application/json")); +static APPLICATION_JSON: Lazy = + Lazy::new(|| HeaderValue::from_static("application/json")); impl GraphQLResponse { - pub fn to_response(self) -> Result> { - let mut response = Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, APPLICATION_JSON.as_ref()) - .body(Body::from(serde_json::to_string(&self.0)?))?; - - if self.0.is_ok() { - if let Some(cache_control) = self.0.cache_control().value() { - response - .headers_mut() - .insert(CACHE_CONTROL, HeaderValue::from_str(cache_control.as_str())?); - } + pub fn to_response(self) -> Result> { + let mut response = Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, APPLICATION_JSON.as_ref()) + .body(Body::from(serde_json::to_string(&self.0)?))?; + + if self.0.is_ok() { + if let Some(cache_control) = self.0.cache_control().value() { + response.headers_mut().insert( + CACHE_CONTROL, + HeaderValue::from_str(cache_control.as_str())?, + ); + } + } + + Ok(response) } - Ok(response) - } - - /// Sets the `cache_control` for a given `GraphQLResponse`. 
- /// - /// The function modifies the `GraphQLResponse` to set the `cache_control` `max_age` - /// to the specified `min_cache` value and `public` flag to `cache_public` - /// - /// # Arguments - /// - /// * `res` - The GraphQL response whose `cache_control` is to be set. - /// * `min_cache` - The `max_age` value to be set for `cache_control`. - /// * `cache_public` - The negation of `public` flag to be set for `cache_control`. - /// - /// # Returns - /// - /// * A modified `GraphQLResponse` with updated `cache_control` `max_age` and `public` flag. - pub fn set_cache_control(mut self, min_cache: i32, cache_public: bool) -> GraphQLResponse { - match self.0 { - BatchResponse::Single(ref mut res) => { - res.cache_control.max_age = min_cache; - res.cache_control.public = cache_public; - } - BatchResponse::Batch(ref mut list) => { - for res in list { - res.cache_control.max_age = min_cache; - res.cache_control.public = cache_public; - } - } - }; - self - } + /// Sets the `cache_control` for a given `GraphQLResponse`. + /// + /// The function modifies the `GraphQLResponse` to set the `cache_control` `max_age` + /// to the specified `min_cache` value and `public` flag to `cache_public` + /// + /// # Arguments + /// + /// * `res` - The GraphQL response whose `cache_control` is to be set. + /// * `min_cache` - The `max_age` value to be set for `cache_control`. + /// * `cache_public` - The negation of `public` flag to be set for `cache_control`. + /// + /// # Returns + /// + /// * A modified `GraphQLResponse` with updated `cache_control` `max_age` and `public` flag. + pub fn set_cache_control(mut self, min_cache: i32, cache_public: bool) -> GraphQLResponse { + match self.0 { + BatchResponse::Single(ref mut res) => { + res.cache_control.max_age = min_cache; + res.cache_control.public = cache_public; + } + BatchResponse::Batch(ref mut list) => { + for res in list { + res.cache_control.max_age = min_cache; + res.cache_control.public = cache_public; + } + } + }; + self + } } diff --git a/src/blueprint/blueprint.rs b/src/blueprint/blueprint.rs index e7612c39b92..e7534086fbb 100644 --- a/src/blueprint/blueprint.rs +++ b/src/blueprint/blueprint.rs @@ -19,169 +19,169 @@ use crate::lambda::{Expression, Lambda}; /// It's not optimized for REST APIs (yet). #[derive(Clone, Debug, Default, Setters)] pub struct Blueprint { - pub definitions: Vec, - pub schema: SchemaDefinition, - pub server: Server, - pub upstream: Upstream, + pub definitions: Vec, + pub schema: SchemaDefinition, + pub server: Server, + pub upstream: Upstream, } #[derive(Clone, Debug)] pub enum Type { - NamedType { name: String, non_null: bool }, - ListType { of_type: Box, non_null: bool }, + NamedType { name: String, non_null: bool }, + ListType { of_type: Box, non_null: bool }, } impl Default for Type { - fn default() -> Self { - Type::NamedType { name: "JSON".to_string(), non_null: false } - } + fn default() -> Self { + Type::NamedType { name: "JSON".to_string(), non_null: false } + } } impl Type { - pub fn name(&self) -> &str { - match self { - Type::NamedType { name, .. } => name, - Type::ListType { of_type, .. } => of_type.name(), + pub fn name(&self) -> &str { + match self { + Type::NamedType { name, .. } => name, + Type::ListType { of_type, .. } => of_type.name(), + } } - } - /// checks if the type is nullable - pub fn is_nullable(&self) -> bool { - !match self { - Type::NamedType { non_null, .. } => *non_null, - Type::ListType { non_null, .. 
} => *non_null, + /// checks if the type is nullable + pub fn is_nullable(&self) -> bool { + !match self { + Type::NamedType { non_null, .. } => *non_null, + Type::ListType { non_null, .. } => *non_null, + } } - } } #[derive(Clone, Debug)] pub enum Definition { - InterfaceTypeDefinition(InterfaceTypeDefinition), - ObjectTypeDefinition(ObjectTypeDefinition), - InputObjectTypeDefinition(InputObjectTypeDefinition), - ScalarTypeDefinition(ScalarTypeDefinition), - EnumTypeDefinition(EnumTypeDefinition), - UnionTypeDefinition(UnionTypeDefinition), + InterfaceTypeDefinition(InterfaceTypeDefinition), + ObjectTypeDefinition(ObjectTypeDefinition), + InputObjectTypeDefinition(InputObjectTypeDefinition), + ScalarTypeDefinition(ScalarTypeDefinition), + EnumTypeDefinition(EnumTypeDefinition), + UnionTypeDefinition(UnionTypeDefinition), } impl Definition { - pub fn name(&self) -> &str { - match self { - Definition::InterfaceTypeDefinition(def) => &def.name, - Definition::ObjectTypeDefinition(def) => &def.name, - Definition::InputObjectTypeDefinition(def) => &def.name, - Definition::ScalarTypeDefinition(def) => &def.name, - Definition::EnumTypeDefinition(def) => &def.name, - Definition::UnionTypeDefinition(def) => &def.name, + pub fn name(&self) -> &str { + match self { + Definition::InterfaceTypeDefinition(def) => &def.name, + Definition::ObjectTypeDefinition(def) => &def.name, + Definition::InputObjectTypeDefinition(def) => &def.name, + Definition::ScalarTypeDefinition(def) => &def.name, + Definition::EnumTypeDefinition(def) => &def.name, + Definition::UnionTypeDefinition(def) => &def.name, + } } - } } #[derive(Clone, Debug)] pub struct InterfaceTypeDefinition { - pub name: String, - pub fields: Vec, - pub description: Option, + pub name: String, + pub fields: Vec, + pub description: Option, } #[derive(Clone, Debug)] pub struct ObjectTypeDefinition { - pub name: String, - pub fields: Vec, - pub description: Option, - pub implements: BTreeSet, + pub name: String, + pub fields: Vec, + pub description: Option, + pub implements: BTreeSet, } #[derive(Clone, Debug)] pub struct InputObjectTypeDefinition { - pub name: String, - pub fields: Vec, - pub description: Option, + pub name: String, + pub fields: Vec, + pub description: Option, } #[derive(Clone, Debug)] pub struct EnumTypeDefinition { - pub name: String, - pub directives: Vec, - pub description: Option, - pub enum_values: Vec, + pub name: String, + pub directives: Vec, + pub description: Option, + pub enum_values: Vec, } #[derive(Clone, Debug)] pub struct EnumValueDefinition { - pub description: Option, - pub name: String, - pub directives: Vec, + pub description: Option, + pub name: String, + pub directives: Vec, } #[derive(Clone, Debug, Default)] pub struct SchemaDefinition { - pub query: String, - pub mutation: Option, - pub directives: Vec, + pub query: String, + pub mutation: Option, + pub directives: Vec, } #[derive(Clone, Debug)] pub struct InputFieldDefinition { - pub name: String, - pub of_type: Type, - pub default_value: Option, - pub description: Option, + pub name: String, + pub of_type: Type, + pub default_value: Option, + pub description: Option, } #[derive(Clone, Debug)] pub struct Cache { - pub max_age: NonZeroU64, - pub hasher: DefaultHasher, + pub max_age: NonZeroU64, + pub hasher: DefaultHasher, } #[derive(Clone, Debug, Setters, Default)] pub struct FieldDefinition { - pub name: String, - pub args: Vec, - pub of_type: Type, - pub resolver: Option, - pub directives: Vec, - pub description: Option, - pub cache: Option, + pub name: 
String, + pub args: Vec, + pub of_type: Type, + pub resolver: Option, + pub directives: Vec, + pub description: Option, + pub cache: Option, } impl FieldDefinition { - pub fn to_lambda(self) -> Option> { - self.resolver.map(Lambda::new) - } - - pub fn resolver_or_default( - mut self, - default_res: Lambda, - other: impl Fn(Lambda) -> Lambda, - ) -> Self { - self.resolver = match self.resolver { - None => Some(default_res.expression), - Some(expr) => Some(other(Lambda::new(expr)).expression), - }; - self - } + pub fn to_lambda(self) -> Option> { + self.resolver.map(Lambda::new) + } + + pub fn resolver_or_default( + mut self, + default_res: Lambda, + other: impl Fn(Lambda) -> Lambda, + ) -> Self { + self.resolver = match self.resolver { + None => Some(default_res.expression), + Some(expr) => Some(other(Lambda::new(expr)).expression), + }; + self + } } #[derive(Clone, Debug)] pub struct Directive { - pub name: String, - pub arguments: HashMap, - pub index: usize, + pub name: String, + pub arguments: HashMap, + pub index: usize, } #[derive(Clone, Debug)] pub struct ScalarTypeDefinition { - pub name: String, - pub directive: Vec, - pub description: Option, + pub name: String, + pub directive: Vec, + pub description: Option, } #[derive(Clone, Debug)] pub struct UnionTypeDefinition { - pub name: String, - pub directives: Vec, - pub description: Option, - pub types: BTreeSet, + pub name: String, + pub directives: Vec, + pub description: Option, + pub types: BTreeSet, } /// @@ -189,80 +189,80 @@ pub struct UnionTypeDefinition { /// #[derive(Copy, Clone, Debug, Default)] pub struct SchemaModifiers { - /// If true, the generated schema will not have any resolvers. - pub no_resolver: bool, + /// If true, the generated schema will not have any resolvers. + pub no_resolver: bool, } impl SchemaModifiers { - pub fn no_resolver() -> Self { - Self { no_resolver: true } - } + pub fn no_resolver() -> Self { + Self { no_resolver: true } + } } impl Blueprint { - pub fn query(&self) -> String { - self.schema.query.clone() - } - - pub fn mutation(&self) -> Option { - self.schema.mutation.clone() - } - - fn drop_resolvers(mut self) -> Self { - for def in self.definitions.iter_mut() { - if let Definition::ObjectTypeDefinition(def) = def { - for field in def.fields.iter_mut() { - field.resolver = None; - } - } + pub fn query(&self) -> String { + self.schema.query.clone() } - self - } - - /// - /// This function is used to generate a schema from a blueprint. - /// - pub fn to_schema(&self) -> Schema { - self.to_schema_with(SchemaModifiers::default()) - } - - /// - /// This function is used to generate a schema from a blueprint. - /// The generated schema can be modified using the SchemaModifiers. 
- /// - pub fn to_schema_with(&self, schema_modifiers: SchemaModifiers) -> Schema { - let blueprint = if schema_modifiers.no_resolver { - self.clone().drop_resolvers() - } else { - self.clone() - }; - - let server = &blueprint.server; - let mut schema = SchemaBuilder::from(&blueprint); - - if server.enable_apollo_tracing { - schema = schema.extension(ApolloTracing); + pub fn mutation(&self) -> Option<String> { + self.schema.mutation.clone() } - if server.global_response_timeout > 0 { - schema = schema - .data(async_graphql::Value::from(server.global_response_timeout)) - .extension(GlobalTimeout); - } + fn drop_resolvers(mut self) -> Self { + for def in self.definitions.iter_mut() { + if let Definition::ObjectTypeDefinition(def) = def { + for field in def.fields.iter_mut() { + field.resolver = None; + } + } + } - if server.get_enable_query_validation() || schema_modifiers.no_resolver { - schema = schema.validation_mode(ValidationMode::Strict); - } else { - schema = schema.validation_mode(ValidationMode::Fast); + self } - if !server.get_enable_introspection() || schema_modifiers.no_resolver { - schema = schema.disable_introspection(); + /// + /// This function is used to generate a schema from a blueprint. + /// + pub fn to_schema(&self) -> Schema { + self.to_schema_with(SchemaModifiers::default()) } - // We should safely assume the blueprint is correct and, - // generation of schema cannot fail. - schema.finish().unwrap() - } + /// + /// This function is used to generate a schema from a blueprint. + /// The generated schema can be modified using the SchemaModifiers. + /// + pub fn to_schema_with(&self, schema_modifiers: SchemaModifiers) -> Schema { + let blueprint = if schema_modifiers.no_resolver { + self.clone().drop_resolvers() + } else { + self.clone() + }; + + let server = &blueprint.server; + let mut schema = SchemaBuilder::from(&blueprint); + + if server.enable_apollo_tracing { + schema = schema.extension(ApolloTracing); + } + + if server.global_response_timeout > 0 { + schema = schema + .data(async_graphql::Value::from(server.global_response_timeout)) + .extension(GlobalTimeout); + } + + if server.get_enable_query_validation() || schema_modifiers.no_resolver { + schema = schema.validation_mode(ValidationMode::Strict); + } else { + schema = schema.validation_mode(ValidationMode::Fast); + } + + if !server.get_enable_introspection() || schema_modifiers.no_resolver { + schema = schema.disable_introspection(); + } + + // We can safely assume the blueprint is correct and + // that schema generation cannot fail. + schema.finish().unwrap() + } } diff --git a/src/blueprint/compress.rs b/src/blueprint/compress.rs index 8e87569eb98..b8cba9046d7 100644 --- a/src/blueprint/compress.rs +++ b/src/blueprint/compress.rs @@ -4,100 +4,100 @@ use super::{Blueprint, Definition}; // compress() takes a Blueprint and returns a compressed Blueprint. So that unused types are removed. pub fn compress(mut blueprint: Blueprint) -> Blueprint { - let graph = build_dependency_graph(&blueprint); + let graph = build_dependency_graph(&blueprint); - // Pre-defined root-types for graphql - let mut root_type = vec!["Query", "Mutation", "Subscription"]; + // Pre-defined root types for GraphQL + let mut root_type = vec!["Query", "Mutation", "Subscription"]; - // User-might create custom root-types other than default i.e non-default types for root-definitions. 
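// [Editor's note] The compress() hunk around this point prunes unused types:
// build_dependency_graph() records, for every definition, which type names it
// refers to, and identify_referenced_types() (further down in this hunk) walks
// that graph from the root types, keeping only what is reachable. Below is a
// minimal, self-contained sketch of that reachability walk with illustrative
// names and a toy graph, not the real Blueprint types:

use std::collections::{HashMap, HashSet};

fn reachable(graph: &HashMap<String, Vec<String>>, roots: Vec<String>) -> HashSet<String> {
    let mut stack = roots;
    let mut seen = HashSet::new();
    while let Some(name) = stack.pop() {
        // insert() returns false for an already-visited type, which is what
        // terminates the walk on cyclic schemas.
        if seen.insert(name.clone()) {
            if let Some(deps) = graph.get(&name) {
                stack.extend(deps.iter().cloned());
            }
        }
    }
    seen
}

fn main() {
    let graph: HashMap<String, Vec<String>> = [
        ("Query".to_string(), vec!["User".to_string()]),
        ("User".to_string(), vec!["Post".to_string()]),
        ("Post".to_string(), vec![]),
        ("Orphan".to_string(), vec![]), // unreferenced: compress() would drop it
    ]
    .into_iter()
    .collect();
    let kept = reachable(&graph, vec!["Query".to_string()]);
    assert!(kept.contains("Post") && !kept.contains("Orphan"));
}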
- let defined_query_type = blueprint.query().clone(); - let mutation = blueprint.mutation().unwrap_or("Mutation".to_string()); + // Users might create custom root types other than the defaults, i.e. non-default types for the root definitions. + let defined_query_type = blueprint.query().clone(); + let mutation = blueprint.mutation().unwrap_or("Mutation".to_string()); - // Push to root-types - root_type.push(defined_query_type.as_str()); - root_type.push(mutation.as_str()); + // Push to root-types + root_type.push(defined_query_type.as_str()); + root_type.push(mutation.as_str()); - let mut referenced_types = identify_referenced_types(&graph, root_type); - referenced_types.insert("Query".to_string()); - referenced_types.insert("Mutation".to_string()); - referenced_types.insert("Subscription".to_string()); - referenced_types.insert("__Schema".to_string()); - referenced_types.insert("__Type".to_string()); - referenced_types.insert("__Field".to_string()); - referenced_types.insert("__InputValue".to_string()); - referenced_types.insert("__EnumValue".to_string()); - referenced_types.insert("__Directive".to_string()); - referenced_types.insert("__DirectiveLocation".to_string()); + let mut referenced_types = identify_referenced_types(&graph, root_type); + referenced_types.insert("Query".to_string()); + referenced_types.insert("Mutation".to_string()); + referenced_types.insert("Subscription".to_string()); + referenced_types.insert("__Schema".to_string()); + referenced_types.insert("__Type".to_string()); + referenced_types.insert("__Field".to_string()); + referenced_types.insert("__InputValue".to_string()); + referenced_types.insert("__EnumValue".to_string()); + referenced_types.insert("__Directive".to_string()); + referenced_types.insert("__DirectiveLocation".to_string()); - let mut definitions = Vec::new(); - for def in blueprint.definitions.iter() { - if referenced_types.contains(def.name()) { - definitions.push(def.clone()); + let mut definitions = Vec::new(); + for def in blueprint.definitions.iter() { + if referenced_types.contains(def.name()) { + definitions.push(def.clone()); + } } - } - blueprint.definitions = definitions; - blueprint + blueprint.definitions = definitions; + blueprint } fn build_dependency_graph(blueprint: &Blueprint) -> HashMap<&str, Vec<&str>> { - let mut graph: HashMap<&str, Vec<&str>> = HashMap::new(); + let mut graph: HashMap<&str, Vec<&str>> = HashMap::new(); - for def in &blueprint.definitions { - let type_name = def.name(); - let mut dependencies: Vec<&str> = Vec::new(); + for def in &blueprint.definitions { + let type_name = def.name(); + let mut dependencies: Vec<&str> = Vec::new(); - match def { - Definition::ObjectTypeDefinition(def) => { - dependencies.extend(def.fields.iter().map(|field| field.of_type.name())); - for field in &def.fields { - dependencies.extend(field.args.iter().map(|arg| arg.of_type.name())); - } - dependencies.extend(def.implements.iter().map(|s| s.as_str())); - } - Definition::InterfaceTypeDefinition(def) => { - dependencies.extend(def.fields.iter().map(|field| field.of_type.name())); - for def_inner in &blueprint.definitions { - if let Definition::ObjectTypeDefinition(def_inner) = def_inner { - if def_inner.implements.contains(&def.name) { - dependencies.push(&def_inner.name); + match def { + Definition::ObjectTypeDefinition(def) => { + dependencies.extend(def.fields.iter().map(|field| field.of_type.name())); + for field in &def.fields { + dependencies.extend(field.args.iter().map(|arg| arg.of_type.name())); + } + 
dependencies.extend(def.implements.iter().map(|s| s.as_str())); + } + Definition::InterfaceTypeDefinition(def) => { + dependencies.extend(def.fields.iter().map(|field| field.of_type.name())); + for def_inner in &blueprint.definitions { + if let Definition::ObjectTypeDefinition(def_inner) = def_inner { + if def_inner.implements.contains(&def.name) { + dependencies.push(&def_inner.name); + } + } + } + } + Definition::InputObjectTypeDefinition(def) => { + dependencies.extend(def.fields.iter().map(|field| field.of_type.name())); + } + Definition::EnumTypeDefinition(def) => { + dependencies.extend(def.enum_values.iter().map(|value| value.name.as_str())); + } + Definition::UnionTypeDefinition(def) => { + dependencies.extend(def.types.iter().map(|s| s.as_str())); + } + Definition::ScalarTypeDefinition(sc) => { + dependencies.push(sc.name.as_str()); } - } } - } - Definition::InputObjectTypeDefinition(def) => { - dependencies.extend(def.fields.iter().map(|field| field.of_type.name())); - } - Definition::EnumTypeDefinition(def) => { - dependencies.extend(def.enum_values.iter().map(|value| value.name.as_str())); - } - Definition::UnionTypeDefinition(def) => { - dependencies.extend(def.types.iter().map(|s| s.as_str())); - } - Definition::ScalarTypeDefinition(sc) => { - dependencies.push(sc.name.as_str()); - } - } - graph.insert(type_name, dependencies); - } - graph + graph.insert(type_name, dependencies); + } + graph } // Function to perform DFS and identify all reachable types fn identify_referenced_types(graph: &HashMap<&str, Vec<&str>>, root: Vec<&str>) -> HashSet { - let mut stack = root; - let mut referenced_types = HashSet::new(); + let mut stack = root; + let mut referenced_types = HashSet::new(); - while let Some(type_name) = stack.pop() { - if referenced_types.insert(type_name.to_string()) { - if let Some(dependencies) = graph.get(type_name) { - for dependency in dependencies { - stack.push(dependency); + while let Some(type_name) = stack.pop() { + if referenced_types.insert(type_name.to_string()) { + if let Some(dependencies) = graph.get(type_name) { + for dependency in dependencies { + stack.push(dependency); + } + } } - } } - } - referenced_types + referenced_types } diff --git a/src/blueprint/definitions.rs b/src/blueprint/definitions.rs index 01843739289..4f8997aa831 100644 --- a/src/blueprint/definitions.rs +++ b/src/blueprint/definitions.rs @@ -14,289 +14,317 @@ use crate::try_fold::TryFold; use crate::valid::Valid; pub fn to_scalar_type_definition(name: &str) -> Valid { - Valid::succeed(Definition::ScalarTypeDefinition(ScalarTypeDefinition { - name: name.to_string(), - directive: Vec::new(), - description: None, - })) + Valid::succeed(Definition::ScalarTypeDefinition(ScalarTypeDefinition { + name: name.to_string(), + directive: Vec::new(), + description: None, + })) } pub fn to_union_type_definition((name, u): (&String, &Union)) -> Definition { - Definition::UnionTypeDefinition(UnionTypeDefinition { - name: name.to_owned(), - description: u.doc.clone(), - directives: Vec::new(), - types: u.types.clone(), - }) + Definition::UnionTypeDefinition(UnionTypeDefinition { + name: name.to_owned(), + description: u.doc.clone(), + directives: Vec::new(), + types: u.types.clone(), + }) } -pub fn to_input_object_type_definition(definition: ObjectTypeDefinition) -> Valid { - Valid::succeed(Definition::InputObjectTypeDefinition(InputObjectTypeDefinition { - name: definition.name, - fields: definition - .fields - .iter() - .map(|field| InputFieldDefinition { - name: field.name.clone(), - description: 
field.description.clone(), - default_value: None, - of_type: field.of_type.clone(), - }) - .collect(), - description: definition.description, - })) +pub fn to_input_object_type_definition( + definition: ObjectTypeDefinition, +) -> Valid { + Valid::succeed(Definition::InputObjectTypeDefinition( + InputObjectTypeDefinition { + name: definition.name, + fields: definition + .fields + .iter() + .map(|field| InputFieldDefinition { + name: field.name.clone(), + description: field.description.clone(), + default_value: None, + of_type: field.of_type.clone(), + }) + .collect(), + description: definition.description, + }, + )) } pub fn to_interface_type_definition(definition: ObjectTypeDefinition) -> Valid { - Valid::succeed(Definition::InterfaceTypeDefinition(InterfaceTypeDefinition { - name: definition.name, - fields: definition.fields, - description: definition.description, - })) + Valid::succeed(Definition::InterfaceTypeDefinition( + InterfaceTypeDefinition { + name: definition.name, + fields: definition.fields, + description: definition.description, + }, + )) } type InvalidPathHandler = dyn Fn(&str, &[String], &[String]) -> Valid; type PathResolverErrorHandler = dyn Fn(&str, &str, &str, &[String]) -> Valid; struct ProcessFieldWithinTypeContext<'a> { - field: &'a config::Field, - field_name: &'a str, - remaining_path: &'a [String], - type_info: &'a config::Type, - is_required: bool, - config: &'a Config, - invalid_path_handler: &'a InvalidPathHandler, - path_resolver_error_handler: &'a PathResolverErrorHandler, - original_path: &'a [String], + field: &'a config::Field, + field_name: &'a str, + remaining_path: &'a [String], + type_info: &'a config::Type, + is_required: bool, + config: &'a Config, + invalid_path_handler: &'a InvalidPathHandler, + path_resolver_error_handler: &'a PathResolverErrorHandler, + original_path: &'a [String], } #[derive(Clone)] struct ProcessPathContext<'a> { - path: &'a [String], - field: &'a config::Field, - type_info: &'a config::Type, - is_required: bool, - config: &'a Config, - invalid_path_handler: &'a InvalidPathHandler, - path_resolver_error_handler: &'a PathResolverErrorHandler, - original_path: &'a [String], + path: &'a [String], + field: &'a config::Field, + type_info: &'a config::Type, + is_required: bool, + config: &'a Config, + invalid_path_handler: &'a InvalidPathHandler, + path_resolver_error_handler: &'a PathResolverErrorHandler, + original_path: &'a [String], } fn process_field_within_type(context: ProcessFieldWithinTypeContext) -> Valid { - let field = context.field; - let field_name = context.field_name; - let remaining_path = context.remaining_path; - let type_info = context.type_info; - let is_required = context.is_required; - let config = context.config; - let invalid_path_handler = context.invalid_path_handler; - let path_resolver_error_handler = context.path_resolver_error_handler; - - if let Some(next_field) = type_info.fields.get(field_name) { - if next_field.has_resolver() { - let next_dir_http = next_field.http.as_ref().map(|_| config::Http::directive_name()); - let next_dir_const = next_field.const_field.as_ref().map(|_| config::Const::directive_name()); - return path_resolver_error_handler( - next_dir_http - .or(next_dir_const) - .unwrap_or(config::JS::directive_name()) - .as_str(), - &field.type_of, - field_name, - context.original_path, - ) - .and(process_path(ProcessPathContext { - type_info, - is_required, - config, - invalid_path_handler, - path_resolver_error_handler, - path: remaining_path, - field: next_field, - original_path: 
context.original_path, - })); - } + let field = context.field; + let field_name = context.field_name; + let remaining_path = context.remaining_path; + let type_info = context.type_info; + let is_required = context.is_required; + let config = context.config; + let invalid_path_handler = context.invalid_path_handler; + let path_resolver_error_handler = context.path_resolver_error_handler; + + if let Some(next_field) = type_info.fields.get(field_name) { + if next_field.has_resolver() { + let next_dir_http = next_field + .http + .as_ref() + .map(|_| config::Http::directive_name()); + let next_dir_const = next_field + .const_field + .as_ref() + .map(|_| config::Const::directive_name()); + return path_resolver_error_handler( + next_dir_http + .or(next_dir_const) + .unwrap_or(config::JS::directive_name()) + .as_str(), + &field.type_of, + field_name, + context.original_path, + ) + .and(process_path(ProcessPathContext { + type_info, + is_required, + config, + invalid_path_handler, + path_resolver_error_handler, + path: remaining_path, + field: next_field, + original_path: context.original_path, + })); + } - let next_is_required = is_required && next_field.required; - if is_scalar(&next_field.type_of) { - return process_path(ProcessPathContext { - type_info, - config, - invalid_path_handler, - path_resolver_error_handler, - path: remaining_path, - field: next_field, - is_required: next_is_required, - original_path: context.original_path, - }); - } + let next_is_required = is_required && next_field.required; + if is_scalar(&next_field.type_of) { + return process_path(ProcessPathContext { + type_info, + config, + invalid_path_handler, + path_resolver_error_handler, + path: remaining_path, + field: next_field, + is_required: next_is_required, + original_path: context.original_path, + }); + } - if let Some(next_type_info) = config.find_type(&next_field.type_of) { - return process_path(ProcessPathContext { - config, - invalid_path_handler, - path_resolver_error_handler, - path: remaining_path, - field: next_field, - type_info: next_type_info, - is_required: next_is_required, - original_path: context.original_path, - }) - .and_then(|of_type| { - if next_field.list { - Valid::succeed(ListType { of_type: Box::new(of_type), non_null: is_required }) - } else { - Valid::succeed(of_type) + if let Some(next_type_info) = config.find_type(&next_field.type_of) { + return process_path(ProcessPathContext { + config, + invalid_path_handler, + path_resolver_error_handler, + path: remaining_path, + field: next_field, + type_info: next_type_info, + is_required: next_is_required, + original_path: context.original_path, + }) + .and_then(|of_type| { + if next_field.list { + Valid::succeed(ListType { of_type: Box::new(of_type), non_null: is_required }) + } else { + Valid::succeed(of_type) + } + }); + } + } else if let Some((head, tail)) = remaining_path.split_first() { + if let Some(field) = type_info.fields.get(head) { + return process_path(ProcessPathContext { + path: tail, + field, + type_info, + is_required, + config, + invalid_path_handler, + path_resolver_error_handler, + original_path: context.original_path, + }); } - }); - } - } else if let Some((head, tail)) = remaining_path.split_first() { - if let Some(field) = type_info.fields.get(head) { - return process_path(ProcessPathContext { - path: tail, - field, - type_info, - is_required, - config, - invalid_path_handler, - path_resolver_error_handler, - original_path: context.original_path, - }); } - } - invalid_path_handler(field_name, remaining_path, 
context.original_path) + invalid_path_handler(field_name, remaining_path, context.original_path) } // Helper function to recursively process the path and return the corresponding type fn process_path(context: ProcessPathContext) -> Valid { - let path = context.path; - let field = context.field; - let type_info = context.type_info; - let is_required = context.is_required; - let config = context.config; - let invalid_path_handler = context.invalid_path_handler; - let path_resolver_error_handler = context.path_resolver_error_handler; - if let Some((field_name, remaining_path)) = path.split_first() { - if field_name.parse::().is_ok() { - let mut modified_field = field.clone(); - modified_field.list = false; - return process_path(ProcessPathContext { - config, - type_info, - invalid_path_handler, - path_resolver_error_handler, - path: remaining_path, - field: &modified_field, - is_required: false, - original_path: context.original_path, - }); - } - let target_type_info = type_info - .fields - .get(field_name) - .map(|_| type_info) - .or_else(|| config.find_type(&field.type_of)); - - if let Some(type_info) = target_type_info { - return process_field_within_type(ProcessFieldWithinTypeContext { - field, - field_name, - remaining_path, - type_info, - is_required, - config, - invalid_path_handler, - path_resolver_error_handler, - original_path: context.original_path, - }); + let path = context.path; + let field = context.field; + let type_info = context.type_info; + let is_required = context.is_required; + let config = context.config; + let invalid_path_handler = context.invalid_path_handler; + let path_resolver_error_handler = context.path_resolver_error_handler; + if let Some((field_name, remaining_path)) = path.split_first() { + if field_name.parse::().is_ok() { + let mut modified_field = field.clone(); + modified_field.list = false; + return process_path(ProcessPathContext { + config, + type_info, + invalid_path_handler, + path_resolver_error_handler, + path: remaining_path, + field: &modified_field, + is_required: false, + original_path: context.original_path, + }); + } + let target_type_info = type_info + .fields + .get(field_name) + .map(|_| type_info) + .or_else(|| config.find_type(&field.type_of)); + + if let Some(type_info) = target_type_info { + return process_field_within_type(ProcessFieldWithinTypeContext { + field, + field_name, + remaining_path, + type_info, + is_required, + config, + invalid_path_handler, + path_resolver_error_handler, + original_path: context.original_path, + }); + } + return invalid_path_handler(field_name, path, context.original_path); } - return invalid_path_handler(field_name, path, context.original_path); - } - Valid::succeed(to_type(field, Some(is_required))) + Valid::succeed(to_type(field, Some(is_required))) } -fn to_enum_type_definition(name: &str, type_: &config::Type, variants: &BTreeSet) -> Valid { - let enum_type_definition = Definition::EnumTypeDefinition(EnumTypeDefinition { - name: name.to_string(), - directives: Vec::new(), - description: type_.doc.clone(), - enum_values: variants - .iter() - .map(|variant| EnumValueDefinition { description: None, name: variant.clone(), directives: Vec::new() }) - .collect(), - }); - Valid::succeed(enum_type_definition) +fn to_enum_type_definition( + name: &str, + type_: &config::Type, + variants: &BTreeSet, +) -> Valid { + let enum_type_definition = Definition::EnumTypeDefinition(EnumTypeDefinition { + name: name.to_string(), + directives: Vec::new(), + description: type_.doc.clone(), + enum_values: variants + 
.iter() + .map(|variant| EnumValueDefinition { + description: None, + name: variant.clone(), + directives: Vec::new(), + }) + .collect(), + }); + Valid::succeed(enum_type_definition) } -fn to_object_type_definition(name: &str, type_of: &config::Type, config: &Config) -> Valid { - to_fields(name, type_of, config).map(|fields| { - Definition::ObjectTypeDefinition(ObjectTypeDefinition { - name: name.to_string(), - description: type_of.doc.clone(), - fields, - implements: type_of.implements.clone(), +fn to_object_type_definition( + name: &str, + type_of: &config::Type, + config: &Config, +) -> Valid { + to_fields(name, type_of, config).map(|fields| { + Definition::ObjectTypeDefinition(ObjectTypeDefinition { + name: name.to_string(), + description: type_of.doc.clone(), + fields, + implements: type_of.implements.clone(), + }) }) - }) } fn update_args<'a>( - hasher: DefaultHasher, + hasher: DefaultHasher, ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new(move |(_, field, _, name), _| { - let mut hasher = hasher.clone(); - name.hash(&mut hasher); - let cache = field - .cache - .as_ref() - .map(|config::Cache { max_age }| Cache { max_age: *max_age, hasher }); - - // TODO! assert type name - Valid::from_iter(field.args.iter(), |(name, arg)| { - Valid::succeed(InputFieldDefinition { - name: name.clone(), - description: arg.doc.clone(), - of_type: to_type(arg, None), - default_value: arg.default_value.clone(), - }) - }) - .map(|args| FieldDefinition { - name: name.to_string(), - description: field.doc.clone(), - args, - of_type: to_type(*field, None), - directives: Vec::new(), - resolver: None, - cache, - }) - }) + TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( + move |(_, field, _, name), _| { + let mut hasher = hasher.clone(); + name.hash(&mut hasher); + let cache = field + .cache + .as_ref() + .map(|config::Cache { max_age }| Cache { max_age: *max_age, hasher }); + + // TODO! 
assert type name + Valid::from_iter(field.args.iter(), |(name, arg)| { + Valid::succeed(InputFieldDefinition { + name: name.clone(), + description: arg.doc.clone(), + of_type: to_type(arg, None), + default_value: arg.default_value.clone(), + }) + }) + .map(|args| FieldDefinition { + name: name.to_string(), + description: field.doc.clone(), + args, + of_type: to_type(*field, None), + directives: Vec::new(), + resolver: None, + cache, + }) + }, + ) } fn item_is_numberic(list: &[String]) -> bool { - list.iter().any(|s| { - let re = Regex::new(r"^\d+$").unwrap(); - re.is_match(s) - }) + list.iter().any(|s| { + let re = Regex::new(r"^\d+$").unwrap(); + re.is_match(s) + }) } fn update_resolver_from_path( - context: &ProcessPathContext, - base_field: blueprint::FieldDefinition, + context: &ProcessPathContext, + base_field: blueprint::FieldDefinition, ) -> Valid { - let has_index = item_is_numberic(context.path); - - process_path(context.clone()).and_then(|of_type| { - let mut updated_base_field = base_field; - let resolver = Lambda::context_path(context.path.to_owned()); - if has_index { - updated_base_field.of_type = Type::NamedType { name: of_type.name().to_string(), non_null: false } - } else { - updated_base_field.of_type = of_type; - } + let has_index = item_is_numberic(context.path); + + process_path(context.clone()).and_then(|of_type| { + let mut updated_base_field = base_field; + let resolver = Lambda::context_path(context.path.to_owned()); + if has_index { + updated_base_field.of_type = + Type::NamedType { name: of_type.name().to_string(), non_null: false } + } else { + updated_base_field.of_type = of_type; + } - updated_base_field = updated_base_field.resolver_or_default(resolver, |r| r.to_input_path(context.path.to_owned())); - Valid::succeed(updated_base_field) - }) + updated_base_field = updated_base_field + .resolver_or_default(resolver, |r| r.to_input_path(context.path.to_owned())); + Valid::succeed(updated_base_field) + }) } /// Sets empty resolver to fields that has @@ -305,170 +333,190 @@ fn update_resolver_from_path( /// and nested resolvers won't be called pub fn update_nested_resolvers<'a>( ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( - move |(config, field, _, name), mut b_field| { - if !field.has_resolver() && validate_field_has_resolver(name, field, &config.types).is_succeed() { - b_field = b_field.resolver(Some(Expression::Literal(serde_json::Value::Object(Default::default())))); - } - - Valid::succeed(b_field) - }, - ) + TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( + move |(config, field, _, name), mut b_field| { + if !field.has_resolver() + && validate_field_has_resolver(name, field, &config.types).is_succeed() + { + b_field = b_field.resolver(Some(Expression::Literal(serde_json::Value::Object( + Default::default(), + )))); + } + + Valid::succeed(b_field) + }, + ) } fn validate_field_type_exist(config: &Config, field: &Field) -> Valid<(), String> { - let field_type = &field.type_of; - if !is_scalar(field_type) && !config.contains(field_type) { - Valid::fail(format!("Undeclared type '{field_type}' was found")) - } else { - Valid::succeed(()) - } + let field_type = &field.type_of; + if !is_scalar(field_type) && !config.contains(field_type) { + Valid::fail(format!("Undeclared type '{field_type}' was found")) + } else { + Valid::succeed(()) + } } -fn to_fields(object_name: &str, type_of: &config::Type, 
config: &Config) -> Valid, String> { - let operation_type = if config.schema.mutation.as_deref().eq(&Some(object_name)) { - GraphQLOperationType::Mutation - } else { - GraphQLOperationType::Query - }; +fn to_fields( + object_name: &str, + type_of: &config::Type, + config: &Config, +) -> Valid, String> { + let operation_type = if config.schema.mutation.as_deref().eq(&Some(object_name)) { + GraphQLOperationType::Mutation + } else { + GraphQLOperationType::Query + }; - let to_field = move |name: &String, field: &Field| { - let directives = field.resolvable_directives(); + let to_field = move |name: &String, field: &Field| { + let directives = field.resolvable_directives(); - if directives.len() > 1 { - return Valid::fail(format!("Multiple resolvers detected [{}]", directives.join(", "))); - } + if directives.len() > 1 { + return Valid::fail(format!( + "Multiple resolvers detected [{}]", + directives.join(", ") + )); + } - let mut hasher = DefaultHasher::new(); - object_name.hash(&mut hasher); - - update_args(hasher) - .and(update_http().trace(config::Http::trace_name().as_str())) - .and(update_grpc(&operation_type).trace(config::Grpc::trace_name().as_str())) - .and(update_js().trace(config::JS::trace_name().as_str())) - .and(update_const_field().trace(config::Const::trace_name().as_str())) - .and(update_graphql(&operation_type).trace(config::GraphQL::trace_name().as_str())) - .and(update_expr(&operation_type).trace(config::Expr::trace_name().as_str())) - .and(update_modify().trace(config::Modify::trace_name().as_str())) - .and(update_nested_resolvers()) - .try_fold(&(config, field, type_of, name), FieldDefinition::default()) - }; - - // Process fields that are not marked as `omit` - let fields = Valid::from_iter( - type_of.fields.iter().filter(|(_, field)| !field.is_omitted()), - |(name, field)| { - validate_field_type_exist(config, field) - .and(to_field(name, field)) - .trace(name) - }, - ); - - let to_added_field = - |add_field: &config::AddField, type_of: &config::Type| -> Valid { - let source_field = type_of - .fields - .iter() - .find(|&(field_name, _)| *field_name == add_field.path[0]); - match source_field { - Some((_, source_field)) => to_field(&add_field.name, source_field) - .and_then(|field_definition| { - let added_field_path = match source_field.http { - Some(_) => add_field.path[1..].iter().map(|s| s.to_owned()).collect::>(), - None => add_field.path.clone(), - }; - let invalid_path_handler = - |field_name: &str, _added_field_path: &[String], original_path: &[String]| -> Valid { - Valid::fail_with( - "Cannot add field".to_string(), - format!("Path [{}] does not exist", original_path.join(", ")), - ) - .trace(field_name) - }; - let path_resolver_error_handler = |resolver_name: &str, - field_type: &str, - field_name: &str, - original_path: &[String]| - -> Valid { - Valid::::fail_with( - "Cannot add field".to_string(), - format!( - "Path: [{}] contains resolver {} at [{}.{}]", - original_path.join(", "), - resolver_name, - field_type, - field_name - ), - ) - }; - update_resolver_from_path( - &ProcessPathContext { - path: &added_field_path, - field: source_field, - type_info: type_of, - is_required: false, - config, - invalid_path_handler: &invalid_path_handler, - path_resolver_error_handler: &path_resolver_error_handler, - original_path: &add_field.path, - }, - field_definition, - ) - }) - .trace(config::AddField::trace_name().as_str()), - None => Valid::fail(format!( - "Could not find field {} in path {}", - add_field.path[0], - add_field.path.join(",") - )), - } + let mut hasher 
= DefaultHasher::new(); + object_name.hash(&mut hasher); + + update_args(hasher) + .and(update_http().trace(config::Http::trace_name().as_str())) + .and(update_grpc(&operation_type).trace(config::Grpc::trace_name().as_str())) + .and(update_js().trace(config::JS::trace_name().as_str())) + .and(update_const_field().trace(config::Const::trace_name().as_str())) + .and(update_graphql(&operation_type).trace(config::GraphQL::trace_name().as_str())) + .and(update_expr(&operation_type).trace(config::Expr::trace_name().as_str())) + .and(update_modify().trace(config::Modify::trace_name().as_str())) + .and(update_nested_resolvers()) + .try_fold(&(config, field, type_of, name), FieldDefinition::default()) }; - let added_fields = Valid::from_iter(type_of.added_fields.iter(), |added_field| { - to_added_field(added_field, type_of) - }); - fields.zip(added_fields).map(|(mut fields, added_fields)| { - fields.extend(added_fields); - fields - }) + // Process fields that are not marked as `omit` + let fields = Valid::from_iter( + type_of + .fields + .iter() + .filter(|(_, field)| !field.is_omitted()), + |(name, field)| { + validate_field_type_exist(config, field) + .and(to_field(name, field)) + .trace(name) + }, + ); + + let to_added_field = |add_field: &config::AddField, + type_of: &config::Type| + -> Valid { + let source_field = type_of + .fields + .iter() + .find(|&(field_name, _)| *field_name == add_field.path[0]); + match source_field { + Some((_, source_field)) => to_field(&add_field.name, source_field) + .and_then(|field_definition| { + let added_field_path = match source_field.http { + Some(_) => add_field.path[1..] + .iter() + .map(|s| s.to_owned()) + .collect::>(), + None => add_field.path.clone(), + }; + let invalid_path_handler = |field_name: &str, + _added_field_path: &[String], + original_path: &[String]| + -> Valid { + Valid::fail_with( + "Cannot add field".to_string(), + format!("Path [{}] does not exist", original_path.join(", ")), + ) + .trace(field_name) + }; + let path_resolver_error_handler = |resolver_name: &str, + field_type: &str, + field_name: &str, + original_path: &[String]| + -> Valid { + Valid::::fail_with( + "Cannot add field".to_string(), + format!( + "Path: [{}] contains resolver {} at [{}.{}]", + original_path.join(", "), + resolver_name, + field_type, + field_name + ), + ) + }; + update_resolver_from_path( + &ProcessPathContext { + path: &added_field_path, + field: source_field, + type_info: type_of, + is_required: false, + config, + invalid_path_handler: &invalid_path_handler, + path_resolver_error_handler: &path_resolver_error_handler, + original_path: &add_field.path, + }, + field_definition, + ) + }) + .trace(config::AddField::trace_name().as_str()), + None => Valid::fail(format!( + "Could not find field {} in path {}", + add_field.path[0], + add_field.path.join(",") + )), + } + }; + + let added_fields = Valid::from_iter(type_of.added_fields.iter(), |added_field| { + to_added_field(added_field, type_of) + }); + fields.zip(added_fields).map(|(mut fields, added_fields)| { + fields.extend(added_fields); + fields + }) } pub fn to_definitions<'a>() -> TryFold<'a, Config, Vec, String> { - TryFold::, String>::new(|config, _| { - let output_types = config.output_types(); - let input_types = config.input_types(); - Valid::from_iter(config.types.iter(), |(name, type_)| { - let dbl_usage = input_types.contains(name) && output_types.contains(name); - if let Some(variants) = &type_.variants { - if !variants.is_empty() { - to_enum_type_definition(name, type_, variants).trace(name) - } 
else { - Valid::fail("No variants found for enum".to_string()) - } - } else if type_.scalar { - to_scalar_type_definition(name).trace(name) - } else if dbl_usage { - Valid::fail("type is used in input and output".to_string()).trace(name) - } else { - to_object_type_definition(name, type_, config) - .trace(name) - .and_then(|definition| match definition.clone() { - Definition::ObjectTypeDefinition(object_type_definition) => { - if config.input_types().contains(name) { - to_input_object_type_definition(object_type_definition).trace(name) - } else if type_.interface { - to_interface_type_definition(object_type_definition).trace(name) - } else { - Valid::succeed(definition) - } + TryFold::, String>::new(|config, _| { + let output_types = config.output_types(); + let input_types = config.input_types(); + Valid::from_iter(config.types.iter(), |(name, type_)| { + let dbl_usage = input_types.contains(name) && output_types.contains(name); + if let Some(variants) = &type_.variants { + if !variants.is_empty() { + to_enum_type_definition(name, type_, variants).trace(name) + } else { + Valid::fail("No variants found for enum".to_string()) + } + } else if type_.scalar { + to_scalar_type_definition(name).trace(name) + } else if dbl_usage { + Valid::fail("type is used in input and output".to_string()).trace(name) + } else { + to_object_type_definition(name, type_, config) + .trace(name) + .and_then(|definition| match definition.clone() { + Definition::ObjectTypeDefinition(object_type_definition) => { + if config.input_types().contains(name) { + to_input_object_type_definition(object_type_definition).trace(name) + } else if type_.interface { + to_interface_type_definition(object_type_definition).trace(name) + } else { + Valid::succeed(definition) + } + } + _ => Valid::succeed(definition), + }) } - _ => Valid::succeed(definition), - }) - } - }) - .map(|mut types| { - types.extend(config.unions.iter().map(to_union_type_definition)); - types + }) + .map(|mut types| { + types.extend(config.unions.iter().map(to_union_type_definition)); + types + }) }) - }) } diff --git a/src/blueprint/from_config.rs b/src/blueprint/from_config.rs index ff989ac0e15..2f7d80ff6b9 100644 --- a/src/blueprint/from_config.rs +++ b/src/blueprint/from_config.rs @@ -10,103 +10,107 @@ use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; pub fn config_blueprint<'a>() -> TryFold<'a, Config, Blueprint, String> { - let server = TryFoldConfig::::new(|config, blueprint| { - Valid::from(Server::try_from(config.server.clone())).map(|server| blueprint.server(server)) - }); + let server = TryFoldConfig::::new(|config, blueprint| { + Valid::from(Server::try_from(config.server.clone())).map(|server| blueprint.server(server)) + }); - let schema = to_schema().transform::( - |schema, blueprint| blueprint.schema(schema), - |blueprint| blueprint.schema, - ); + let schema = to_schema().transform::( + |schema, blueprint| blueprint.schema(schema), + |blueprint| blueprint.schema, + ); - let definitions = to_definitions().transform::( - |definitions, blueprint| blueprint.definitions(definitions), - |blueprint| blueprint.definitions, - ); + let definitions = to_definitions().transform::( + |definitions, blueprint| blueprint.definitions(definitions), + |blueprint| blueprint.definitions, + ); - let upstream = to_upstream().transform::( - |upstream, blueprint| blueprint.upstream(upstream), - |blueprint| blueprint.upstream, - ); + let upstream = to_upstream().transform::( + |upstream, blueprint| blueprint.upstream(upstream), + |blueprint| 
blueprint.upstream, + ); - server - .and(schema) - .and(definitions) - .and(upstream) - .update(apply_batching) - .update(compress) + server + .and(schema) + .and(definitions) + .and(upstream) + .update(apply_batching) + .update(compress) } // Apply batching if any of the fields have a @http directive with groupBy field pub fn apply_batching(mut blueprint: Blueprint) -> Blueprint { - for def in blueprint.definitions.iter() { - if let Definition::ObjectTypeDefinition(object_type_definition) = def { - for field in object_type_definition.fields.iter() { - if let Some(Expression::IO(IO::Http { group_by: Some(_), .. })) = field.resolver.clone() { - blueprint.upstream.batch = blueprint.upstream.batch.or(Some(Batch::default())); - return blueprint; + for def in blueprint.definitions.iter() { + if let Definition::ObjectTypeDefinition(object_type_definition) = def { + for field in object_type_definition.fields.iter() { + if let Some(Expression::IO(IO::Http { group_by: Some(_), .. })) = + field.resolver.clone() + { + blueprint.upstream.batch = blueprint.upstream.batch.or(Some(Batch::default())); + return blueprint; + } + } } - } } - } - blueprint + blueprint } pub fn to_json_schema_for_field(field: &Field, config: &Config) -> JsonSchema { - to_json_schema(field, config) + to_json_schema(field, config) } pub fn to_json_schema_for_args(args: &BTreeMap, config: &Config) -> JsonSchema { - let mut schema_fields = HashMap::new(); - for (name, arg) in args.iter() { - schema_fields.insert(name.clone(), to_json_schema(arg, config)); - } - JsonSchema::Obj(schema_fields) + let mut schema_fields = HashMap::new(); + for (name, arg) in args.iter() { + schema_fields.insert(name.clone(), to_json_schema(arg, config)); + } + JsonSchema::Obj(schema_fields) } fn to_json_schema(field: &T, config: &Config) -> JsonSchema where - T: TypeLike, + T: TypeLike, { - let type_of = field.name(); - let list = field.list(); - let required = field.non_null(); - let type_ = config.find_type(type_of); - let schema = match type_ { - Some(type_) => { - let mut schema_fields = HashMap::new(); - for (name, field) in type_.fields.iter() { - if field.script.is_none() && field.http.is_none() { - schema_fields.insert(name.clone(), to_json_schema_for_field(field, config)); + let type_of = field.name(); + let list = field.list(); + let required = field.non_null(); + let type_ = config.find_type(type_of); + let schema = match type_ { + Some(type_) => { + let mut schema_fields = HashMap::new(); + for (name, field) in type_.fields.iter() { + if field.script.is_none() && field.http.is_none() { + schema_fields.insert(name.clone(), to_json_schema_for_field(field, config)); + } + } + JsonSchema::Obj(schema_fields) } - } - JsonSchema::Obj(schema_fields) - } - None => match type_of { - "String" => JsonSchema::Str {}, - "Int" => JsonSchema::Num {}, - "Boolean" => JsonSchema::Bool {}, - "JSON" => JsonSchema::Obj(HashMap::new()), - _ => JsonSchema::Str {}, - }, - }; + None => match type_of { + "String" => JsonSchema::Str {}, + "Int" => JsonSchema::Num {}, + "Boolean" => JsonSchema::Bool {}, + "JSON" => JsonSchema::Obj(HashMap::new()), + _ => JsonSchema::Str {}, + }, + }; - if !required { - if list { - JsonSchema::Opt(Box::new(JsonSchema::Arr(Box::new(schema)))) + if !required { + if list { + JsonSchema::Opt(Box::new(JsonSchema::Arr(Box::new(schema)))) + } else { + JsonSchema::Opt(Box::new(schema)) + } + } else if list { + JsonSchema::Arr(Box::new(schema)) } else { - JsonSchema::Opt(Box::new(schema)) + schema } - } else if list { - 
JsonSchema::Arr(Box::new(schema)) - } else { - schema - } } impl TryFrom<&Config> for Blueprint { - type Error = ValidationError; + type Error = ValidationError; - fn try_from(config: &Config) -> Result { - config_blueprint().try_fold(config, Blueprint::default()).to_result() - } + fn try_from(config: &Config) -> Result { + config_blueprint() + .try_fold(config, Blueprint::default()) + .to_result() + } } diff --git a/src/blueprint/into_schema.rs b/src/blueprint/into_schema.rs index 9738e4044cb..2dab36d7663 100644 --- a/src/blueprint/into_schema.rs +++ b/src/blueprint/into_schema.rs @@ -12,179 +12,191 @@ use crate::json::JsonLike; use crate::lambda::{Concurrent, Eval, EvaluationContext}; fn to_type_ref(type_of: &Type) -> dynamic::TypeRef { - match type_of { - Type::NamedType { name, non_null } => { - if *non_null { - dynamic::TypeRef::NonNull(Box::from(dynamic::TypeRef::Named(Cow::Owned(name.clone())))) - } else { - dynamic::TypeRef::Named(Cow::Owned(name.clone())) - } - } - Type::ListType { of_type, non_null } => { - let inner = Box::new(to_type_ref(of_type)); - if *non_null { - dynamic::TypeRef::NonNull(Box::from(dynamic::TypeRef::List(inner))) - } else { - dynamic::TypeRef::List(inner) - } + match type_of { + Type::NamedType { name, non_null } => { + if *non_null { + dynamic::TypeRef::NonNull(Box::from(dynamic::TypeRef::Named(Cow::Owned( + name.clone(), + )))) + } else { + dynamic::TypeRef::Named(Cow::Owned(name.clone())) + } + } + Type::ListType { of_type, non_null } => { + let inner = Box::new(to_type_ref(of_type)); + if *non_null { + dynamic::TypeRef::NonNull(Box::from(dynamic::TypeRef::List(inner))) + } else { + dynamic::TypeRef::List(inner) + } + } } - } } fn get_cache_key<'a, H: Hasher + Clone>( - ctx: &'a EvaluationContext<'a, ResolverContext<'a>>, - mut hasher: H, + ctx: &'a EvaluationContext<'a, ResolverContext<'a>>, + mut hasher: H, ) -> Option { - // Hash on parent value - if let Some(const_value) = ctx - .graphql_ctx - .parent_value - .as_value() - // TODO: handle _id, id, or any field that has @key on it. - .filter(|value| value != &&ConstValue::Null) - .map(|data| data.get_key("id")) - { - // Hash on parent's id only? - helpers::value::hash(const_value?, &mut hasher); - } + // Hash on parent value + if let Some(const_value) = ctx + .graphql_ctx + .parent_value + .as_value() + // TODO: handle _id, id, or any field that has @key on it. + .filter(|value| value != &&ConstValue::Null) + .map(|data| data.get_key("id")) + { + // Hash on parent's id only? 
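// [Editor's note] get_cache_key() here hashes the parent's `id` into the
// field's hasher (the helpers::value::hash call just below), then hashes each
// (argument name, value) pair with a clone of that hasher and XOR-folds the
// per-argument digests into the final u64 key, so the key does not depend on
// argument order. A self-contained sketch of that folding scheme with toy
// string arguments; the real code hashes async-graphql values via
// helpers::value::hash and an EvaluationContext:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn cache_key(base: DefaultHasher, args: &[(&str, &str)]) -> u64 {
    args.iter()
        .map(|(k, v)| {
            // Each argument is hashed on top of a fresh clone of the base
            // hasher, so arguments do not contaminate each other's digests.
            let mut h = base.clone();
            k.hash(&mut h);
            v.hash(&mut h);
            h.finish()
        })
        .fold(base.finish(), |acc, x| acc ^ x)
}

fn main() {
    let mut base = DefaultHasher::new();
    "User:42".hash(&mut base); // stands in for the hashed parent id
    let a = cache_key(base.clone(), &[("limit", "10"), ("offset", "0")]);
    let b = cache_key(base, &[("offset", "0"), ("limit", "10")]);
    assert_eq!(a, b); // XOR is commutative: the key is argument-order independent
}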
+ helpers::value::hash(const_value?, &mut hasher); + } - let key = ctx - .graphql_ctx - .args - .iter() - .map(|(key, value)| { - let mut hasher = hasher.clone(); - key.hash(&mut hasher); - helpers::value::hash(value.as_value(), &mut hasher); - hasher.finish() - }) - .fold(hasher.finish(), |acc, val| acc ^ val); + let key = ctx + .graphql_ctx + .args + .iter() + .map(|(key, value)| { + let mut hasher = hasher.clone(); + key.hash(&mut hasher); + helpers::value::hash(value.as_value(), &mut hasher); + hasher.finish() + }) + .fold(hasher.finish(), |acc, val| acc ^ val); - Some(key) + Some(key) } fn to_type(def: &Definition) -> dynamic::Type { - match def { - Definition::ObjectTypeDefinition(def) => { - let mut object = dynamic::Object::new(def.name.clone()); - for field in def.fields.iter() { - let field = field.clone(); - let type_ref = to_type_ref(&field.of_type); - let field_name = &field.name.clone(); - let cache = field.cache.clone(); - let mut dyn_schema_field = dynamic::Field::new(field_name, type_ref, move |ctx| { - let req_ctx = ctx.ctx.data::>().unwrap(); - let field_name = &field.name; - match &field.resolver { - None => { - let ctx = EvaluationContext::new(req_ctx, &ctx); - FieldFuture::from_value(ctx.path_value(&[field_name]).map(|a| a.to_owned())) - } - Some(expr) => { - let expr = expr.to_owned(); - let cache = cache.clone(); - FieldFuture::new(async move { - let ctx = EvaluationContext::new(req_ctx, &ctx); + match def { + Definition::ObjectTypeDefinition(def) => { + let mut object = dynamic::Object::new(def.name.clone()); + for field in def.fields.iter() { + let field = field.clone(); + let type_ref = to_type_ref(&field.of_type); + let field_name = &field.name.clone(); + let cache = field.cache.clone(); + let mut dyn_schema_field = dynamic::Field::new(field_name, type_ref, move |ctx| { + let req_ctx = ctx.ctx.data::>().unwrap(); + let field_name = &field.name; + match &field.resolver { + None => { + let ctx = EvaluationContext::new(req_ctx, &ctx); + FieldFuture::from_value( + ctx.path_value(&[field_name]).map(|a| a.to_owned()), + ) + } + Some(expr) => { + let expr = expr.to_owned(); + let cache = cache.clone(); + FieldFuture::new(async move { + let ctx = EvaluationContext::new(req_ctx, &ctx); - let ttl_and_key = - cache.and_then(|Cache { max_age: ttl, hasher }| Some((ttl, get_cache_key(&ctx, hasher)?))); - let const_value = match ttl_and_key { - Some((ttl, key)) => { - if let Some(const_value) = ctx.req_ctx.cache_get(&key).await { - // Return value from cache - log::info!("Reading from cache. key = {key}"); - const_value - } else { - let const_value = expr.eval(&ctx, &Concurrent::Sequential).await?; - log::info!("Writing to cache. key = {key}"); - // Write value to cache - ctx.req_ctx.cache_insert(key, const_value.clone(), ttl).await; - const_value - } - } - _ => expr.eval(&ctx, &Concurrent::Sequential).await?, - }; + let ttl_and_key = + cache.and_then(|Cache { max_age: ttl, hasher }| { + Some((ttl, get_cache_key(&ctx, hasher)?)) + }); + let const_value = match ttl_and_key { + Some((ttl, key)) => { + if let Some(const_value) = ctx.req_ctx.cache_get(&key).await + { + // Return value from cache + log::info!("Reading from cache. key = {key}"); + const_value + } else { + let const_value = + expr.eval(&ctx, &Concurrent::Sequential).await?; + log::info!("Writing to cache. 
key = {key}"); + // Write value to cache + ctx.req_ctx + .cache_insert(key, const_value.clone(), ttl) + .await; + const_value + } + } + _ => expr.eval(&ctx, &Concurrent::Sequential).await?, + }; - let p = match const_value { - ConstValue::List(a) => FieldValue::list(a), - a => FieldValue::from(a), - }; - Ok(Some(p)) - }) + let p = match const_value { + ConstValue::List(a) => FieldValue::list(a), + a => FieldValue::from(a), + }; + Ok(Some(p)) + }) + } + } + }); + if let Some(description) = &field.description { + dyn_schema_field = dyn_schema_field.description(description); + } + for arg in field.args.iter() { + dyn_schema_field = dyn_schema_field.argument(dynamic::InputValue::new( + arg.name.clone(), + to_type_ref(&arg.of_type), + )); + } + object = object.field(dyn_schema_field); + } + for interface in def.implements.iter() { + object = object.implement(interface.clone()); } - } - }); - if let Some(description) = &field.description { - dyn_schema_field = dyn_schema_field.description(description); - } - for arg in field.args.iter() { - dyn_schema_field = - dyn_schema_field.argument(dynamic::InputValue::new(arg.name.clone(), to_type_ref(&arg.of_type))); - } - object = object.field(dyn_schema_field); - } - for interface in def.implements.iter() { - object = object.implement(interface.clone()); - } - dynamic::Type::Object(object) - } - Definition::InterfaceTypeDefinition(def) => { - let mut interface = dynamic::Interface::new(def.name.clone()); - for field in def.fields.iter() { - interface = interface.field(dynamic::InterfaceField::new( - field.name.clone(), - to_type_ref(&field.of_type), - )); - } + dynamic::Type::Object(object) + } + Definition::InterfaceTypeDefinition(def) => { + let mut interface = dynamic::Interface::new(def.name.clone()); + for field in def.fields.iter() { + interface = interface.field(dynamic::InterfaceField::new( + field.name.clone(), + to_type_ref(&field.of_type), + )); + } - dynamic::Type::Interface(interface) - } - Definition::InputObjectTypeDefinition(def) => { - let mut input_object = dynamic::InputObject::new(def.name.clone()); - for field in def.fields.iter() { - input_object = input_object.field(dynamic::InputValue::new( - field.name.clone(), - to_type_ref(&field.of_type), - )); - } + dynamic::Type::Interface(interface) + } + Definition::InputObjectTypeDefinition(def) => { + let mut input_object = dynamic::InputObject::new(def.name.clone()); + for field in def.fields.iter() { + input_object = input_object.field(dynamic::InputValue::new( + field.name.clone(), + to_type_ref(&field.of_type), + )); + } - dynamic::Type::InputObject(input_object) - } - Definition::ScalarTypeDefinition(def) => { - let mut scalar = dynamic::Scalar::new(def.name.clone()); - if let Some(description) = &def.description { - scalar = scalar.description(description); - } - dynamic::Type::Scalar(scalar) - } - Definition::EnumTypeDefinition(def) => { - let mut enum_type = dynamic::Enum::new(def.name.clone()); - for value in def.enum_values.iter() { - enum_type = enum_type.item(dynamic::EnumItem::new(value.name.clone())); - } - dynamic::Type::Enum(enum_type) - } - Definition::UnionTypeDefinition(def) => { - let mut union = dynamic::Union::new(def.name.clone()); - for type_ in def.types.iter() { - union = union.possible_type(type_.clone()); - } - dynamic::Type::Union(union) + dynamic::Type::InputObject(input_object) + } + Definition::ScalarTypeDefinition(def) => { + let mut scalar = dynamic::Scalar::new(def.name.clone()); + if let Some(description) = &def.description { + scalar = 
scalar.description(description); + } + dynamic::Type::Scalar(scalar) + } + Definition::EnumTypeDefinition(def) => { + let mut enum_type = dynamic::Enum::new(def.name.clone()); + for value in def.enum_values.iter() { + enum_type = enum_type.item(dynamic::EnumItem::new(value.name.clone())); + } + dynamic::Type::Enum(enum_type) + } + Definition::UnionTypeDefinition(def) => { + let mut union = dynamic::Union::new(def.name.clone()); + for type_ in def.types.iter() { + union = union.possible_type(type_.clone()); + } + dynamic::Type::Union(union) + } } - } } impl From<&Blueprint> for SchemaBuilder { - fn from(blueprint: &Blueprint) -> Self { - let query = blueprint.query(); - let mutation = blueprint.mutation(); - let mut schema = dynamic::Schema::build(query.as_str(), mutation.as_deref(), None); + fn from(blueprint: &Blueprint) -> Self { + let query = blueprint.query(); + let mutation = blueprint.mutation(); + let mut schema = dynamic::Schema::build(query.as_str(), mutation.as_deref(), None); - for def in blueprint.definitions.iter() { - schema = schema.register(to_type(def)); - } + for def in blueprint.definitions.iter() { + schema = schema.register(to_type(def)); + } - schema - } + schema + } } diff --git a/src/blueprint/mod.rs b/src/blueprint/mod.rs index f8e720dc6ac..43c37389ae8 100644 --- a/src/blueprint/mod.rs +++ b/src/blueprint/mod.rs @@ -26,71 +26,74 @@ use crate::try_fold::TryFold; pub type TryFoldConfig<'a, A> = TryFold<'a, Config, A, String>; pub(crate) trait TypeLike { - fn name(&self) -> &str; - fn list(&self) -> bool; - fn non_null(&self) -> bool; - fn list_type_required(&self) -> bool; + fn name(&self) -> &str; + fn list(&self) -> bool; + fn non_null(&self) -> bool; + fn list_type_required(&self) -> bool; } impl TypeLike for Field { - fn name(&self) -> &str { - &self.type_of - } + fn name(&self) -> &str { + &self.type_of + } - fn list(&self) -> bool { - self.list - } + fn list(&self) -> bool { + self.list + } - fn non_null(&self) -> bool { - self.required - } + fn non_null(&self) -> bool { + self.required + } - fn list_type_required(&self) -> bool { - self.list_type_required - } + fn list_type_required(&self) -> bool { + self.list_type_required + } } impl TypeLike for Arg { - fn name(&self) -> &str { - &self.type_of - } + fn name(&self) -> &str { + &self.type_of + } - fn list(&self) -> bool { - self.list - } + fn list(&self) -> bool { + self.list + } - fn non_null(&self) -> bool { - self.required - } + fn non_null(&self) -> bool { + self.required + } - fn list_type_required(&self) -> bool { - false - } + fn list_type_required(&self) -> bool { + false + } } pub(crate) fn to_type(field: &T, override_non_null: Option) -> Type where - T: TypeLike, + T: TypeLike, { - let name = field.name(); - let list = field.list(); - let list_type_required = field.list_type_required(); - let non_null = if let Some(non_null) = override_non_null { - non_null - } else { - field.non_null() - }; + let name = field.name(); + let list = field.list(); + let list_type_required = field.list_type_required(); + let non_null = if let Some(non_null) = override_non_null { + non_null + } else { + field.non_null() + }; - if list { - Type::ListType { - of_type: Box::new(Type::NamedType { name: name.to_string(), non_null: list_type_required }), - non_null, + if list { + Type::ListType { + of_type: Box::new(Type::NamedType { + name: name.to_string(), + non_null: list_type_required, + }), + non_null, + } + } else { + Type::NamedType { name: name.to_string(), non_null } } - } else { - Type::NamedType { name: 
name.to_string(), non_null } - } } pub fn is_scalar(type_name: &str) -> bool { - ["String", "Int", "Float", "Boolean", "ID", "JSON"].contains(&type_name) + ["String", "Int", "Float", "Boolean", "ID", "JSON"].contains(&type_name) } diff --git a/src/blueprint/mustache.rs b/src/blueprint/mustache.rs index 1a98456db64..2f5aaccea2c 100644 --- a/src/blueprint/mustache.rs +++ b/src/blueprint/mustache.rs @@ -4,169 +4,169 @@ use crate::lambda::{Expression, IO}; use crate::valid::Valid; struct MustachePartsValidator<'a> { - type_of: &'a config::Type, - config: &'a Config, - field: &'a FieldDefinition, + type_of: &'a config::Type, + config: &'a Config, + field: &'a FieldDefinition, } impl<'a> MustachePartsValidator<'a> { - fn new(type_of: &'a config::Type, config: &'a Config, field: &'a FieldDefinition) -> Self { - Self { type_of, config, field } - } - - fn validate_type(&self, parts: &[String], is_query: bool) -> Result<(), String> { - let mut len = parts.len(); - let mut type_of = self.type_of; - for item in parts { - let field = type_of.fields.get(item).ok_or_else(|| { - format!( - "no value '{}' found", - parts[0..parts.len() - len + 1].join(".").as_str() - ) - })?; - let val_type = to_type(field, None); - - if !is_query && val_type.is_nullable() { - return Err(format!("value '{}' is a nullable type", item.as_str())); - } else if len == 1 && !is_scalar(val_type.name()) { - return Err(format!("value '{}' is not of a scalar type", item.as_str())); - } else if len == 1 { - break; - } - - type_of = self - .config - .find_type(&field.type_of) - .ok_or_else(|| format!("no type '{}' found", parts.join(".").as_str()))?; - - len -= 1; + fn new(type_of: &'a config::Type, config: &'a Config, field: &'a FieldDefinition) -> Self { + Self { type_of, config, field } } - Ok(()) - } - - fn validate(&self, parts: &[String], is_query: bool) -> Valid<(), String> { - let config = self.config; - let args = &self.field.args; + fn validate_type(&self, parts: &[String], is_query: bool) -> Result<(), String> { + let mut len = parts.len(); + let mut type_of = self.type_of; + for item in parts { + let field = type_of.fields.get(item).ok_or_else(|| { + format!( + "no value '{}' found", + parts[0..parts.len() - len + 1].join(".").as_str() + ) + })?; + let val_type = to_type(field, None); + + if !is_query && val_type.is_nullable() { + return Err(format!("value '{}' is a nullable type", item.as_str())); + } else if len == 1 && !is_scalar(val_type.name()) { + return Err(format!("value '{}' is not of a scalar type", item.as_str())); + } else if len == 1 { + break; + } + + type_of = self + .config + .find_type(&field.type_of) + .ok_or_else(|| format!("no type '{}' found", parts.join(".").as_str()))?; + + len -= 1; + } - if parts.len() < 2 { - return Valid::fail("too few parts in template".to_string()); + Ok(()) } - let head = parts[0].as_str(); - let tail = parts[1].as_str(); + fn validate(&self, parts: &[String], is_query: bool) -> Valid<(), String> { + let config = self.config; + let args = &self.field.args; - match head { - "value" => { - // all items on parts except the first one - let tail = &parts[1..]; - - if let Err(e) = self.validate_type(tail, is_query) { - return Valid::fail(e); - } - } - "args" => { - // XXX this is a linear search but it's cost is less than that of - // constructing a HashMap since we'd have 3-4 arguments at max in - // most cases - if let Some(arg) = args.iter().find(|arg| arg.name == tail) { - if let Type::ListType { .. 
} = arg.of_type { - return Valid::fail(format!("can't use list type '{tail}' here")); - } - - // we can use non-scalar types in args - if !is_query && arg.default_value.is_none() && arg.of_type.is_nullable() { - return Valid::fail(format!("argument '{tail}' is a nullable type")); - } - } else { - return Valid::fail(format!("no argument '{tail}' found")); + if parts.len() < 2 { + return Valid::fail("too few parts in template".to_string()); } - } - "vars" => { - if config.server.vars.get(tail).is_none() { - return Valid::fail(format!("var '{tail}' is not set in the server config")); + + let head = parts[0].as_str(); + let tail = parts[1].as_str(); + + match head { + "value" => { + // all items on parts except the first one + let tail = &parts[1..]; + + if let Err(e) = self.validate_type(tail, is_query) { + return Valid::fail(e); + } + } + "args" => { + // XXX this is a linear search but it's cost is less than that of + // constructing a HashMap since we'd have 3-4 arguments at max in + // most cases + if let Some(arg) = args.iter().find(|arg| arg.name == tail) { + if let Type::ListType { .. } = arg.of_type { + return Valid::fail(format!("can't use list type '{tail}' here")); + } + + // we can use non-scalar types in args + if !is_query && arg.default_value.is_none() && arg.of_type.is_nullable() { + return Valid::fail(format!("argument '{tail}' is a nullable type")); + } + } else { + return Valid::fail(format!("no argument '{tail}' found")); + } + } + "vars" => { + if config.server.vars.get(tail).is_none() { + return Valid::fail(format!("var '{tail}' is not set in the server config")); + } + } + "headers" | "env" => { + // "headers" and "env" refers to values known at runtime, which we can't + // validate here + } + _ => { + return Valid::fail(format!("unknown template directive '{head}'")); + } } - } - "headers" | "env" => { - // "headers" and "env" refers to values known at runtime, which we can't - // validate here - } - _ => { - return Valid::fail(format!("unknown template directive '{head}'")); - } - } - Valid::succeed(()) - } + Valid::succeed(()) + } } impl FieldDefinition { - pub fn validate_field(&self, type_of: &config::Type, config: &Config) -> Valid<(), String> { - // XXX we could use `Mustache`'s `render` method with a mock - // struct implementing the `PathString` trait encapsulating `validation_map` - // but `render` simply falls back to the default value for a given - // type if it doesn't exist, so we wouldn't be able to get enough - // context from that method alone - // So we must duplicate some of that logic here :( - let parts_validator = MustachePartsValidator::new(type_of, config, self); - - match &self.resolver { - Some(Expression::IO(IO::Http { req_template, .. })) => { - Valid::from_iter(req_template.root_url.expression_segments(), |parts| { - parts_validator.validate(parts, false).trace("path") - }) - .and(Valid::from_iter(req_template.query.clone(), |query| { - let (_, mustache) = query; - - Valid::from_iter(mustache.expression_segments(), |parts| { - parts_validator.validate(parts, true).trace("query") - }) - })) - .unit() - } - Some(Expression::IO(IO::GraphQLEndpoint { req_template, .. 
})) => { - Valid::from_iter(req_template.headers.clone(), |(_, mustache)| { - Valid::from_iter(mustache.expression_segments(), |parts| { - parts_validator.validate(parts, true).trace("headers") - }) - }) - .and_then(|_| { - if let Some(args) = &req_template.operation_arguments { - Valid::from_iter(args, |(_, mustache)| { - Valid::from_iter(mustache.expression_segments(), |parts| { - parts_validator.validate(parts, true).trace("args") - }) - }) - } else { - Valid::succeed(Default::default()) - } - }) - .unit() - } - Some(Expression::IO(IO::Grpc { req_template, .. })) => { - Valid::from_iter(req_template.url.expression_segments(), |parts| { - parts_validator.validate(parts, false).trace("path") - }) - .and( - Valid::from_iter(req_template.headers.clone(), |(_, mustache)| { - Valid::from_iter(mustache.expression_segments(), |parts| { - parts_validator.validate(parts, true).trace("headers") - }) - }) - .unit(), - ) - .and_then(|_| { - if let Some(body) = &req_template.body { - Valid::from_iter(body.expression_segments(), |parts| { - parts_validator.validate(parts, true).trace("body") - }) - } else { - Valid::succeed(Default::default()) - } - }) - .unit() - } - _ => Valid::succeed(()), + pub fn validate_field(&self, type_of: &config::Type, config: &Config) -> Valid<(), String> { + // XXX we could use `Mustache`'s `render` method with a mock + // struct implementing the `PathString` trait encapsulating `validation_map` + // but `render` simply falls back to the default value for a given + // type if it doesn't exist, so we wouldn't be able to get enough + // context from that method alone + // So we must duplicate some of that logic here :( + let parts_validator = MustachePartsValidator::new(type_of, config, self); + + match &self.resolver { + Some(Expression::IO(IO::Http { req_template, .. })) => { + Valid::from_iter(req_template.root_url.expression_segments(), |parts| { + parts_validator.validate(parts, false).trace("path") + }) + .and(Valid::from_iter(req_template.query.clone(), |query| { + let (_, mustache) = query; + + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("query") + }) + })) + .unit() + } + Some(Expression::IO(IO::GraphQLEndpoint { req_template, .. })) => { + Valid::from_iter(req_template.headers.clone(), |(_, mustache)| { + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("headers") + }) + }) + .and_then(|_| { + if let Some(args) = &req_template.operation_arguments { + Valid::from_iter(args, |(_, mustache)| { + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("args") + }) + }) + } else { + Valid::succeed(Default::default()) + } + }) + .unit() + } + Some(Expression::IO(IO::Grpc { req_template, .. 
})) => { + Valid::from_iter(req_template.url.expression_segments(), |parts| { + parts_validator.validate(parts, false).trace("path") + }) + .and( + Valid::from_iter(req_template.headers.clone(), |(_, mustache)| { + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("headers") + }) + }) + .unit(), + ) + .and_then(|_| { + if let Some(body) = &req_template.body { + Valid::from_iter(body.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("body") + }) + } else { + Valid::succeed(Default::default()) + } + }) + .unit() + } + _ => Valid::succeed(()), + } } - } } diff --git a/src/blueprint/operation.rs b/src/blueprint/operation.rs index c2b553c2630..d7ef468abdd 100644 --- a/src/blueprint/operation.rs +++ b/src/blueprint/operation.rs @@ -7,56 +7,59 @@ use crate::valid::{Cause, Valid}; #[derive(Debug)] pub struct OperationQuery { - query: String, - file: String, + query: String, + file: String, } impl OperationQuery { - pub fn new(query: String, trace: String) -> Self { - Self { query, file: trace } - } - - fn to_cause(&self, err: &async_graphql::ServerError) -> Cause { - let mut trace = Vec::new(); - let file = self.file.as_str(); - - for loc in err.locations.iter() { - let mut message = String::new(); - message.write_str(file).unwrap(); - message - .write_str(format!(":{}:{}", loc.line, loc.column).as_str()) - .unwrap(); - - trace.push(message); + pub fn new(query: String, trace: String) -> Self { + Self { query, file: trace } } - Cause::new(err.message.clone()).trace(trace) - } - - async fn validate(&self, schema: &Schema) -> Vec> { - schema - .execute(&self.query) - .await - .errors - .iter() - .map(|e| self.to_cause(e)) - .collect() - } + fn to_cause(&self, err: &async_graphql::ServerError) -> Cause { + let mut trace = Vec::new(); + let file = self.file.as_str(); + + for loc in err.locations.iter() { + let mut message = String::new(); + message.write_str(file).unwrap(); + message + .write_str(format!(":{}:{}", loc.line, loc.column).as_str()) + .unwrap(); + + trace.push(message); + } + + Cause::new(err.message.clone()).trace(trace) + } + + async fn validate(&self, schema: &Schema) -> Vec> { + schema + .execute(&self.query) + .await + .errors + .iter() + .map(|e| self.to_cause(e)) + .collect() + } } -pub async fn validate_operations(blueprint: &Blueprint, operations: Vec) -> Valid<(), String> { - let schema = blueprint.to_schema_with(SchemaModifiers::no_resolver()); - Valid::from_iter( - futures_util::future::join_all(operations.iter().map(|op| op.validate(&schema))) - .await - .iter(), - |errors| { - if errors.is_empty() { - Valid::succeed(()) - } else { - Valid::<(), String>::from_vec_cause(errors.to_vec()) - } - }, - ) - .unit() +pub async fn validate_operations( + blueprint: &Blueprint, + operations: Vec, +) -> Valid<(), String> { + let schema = blueprint.to_schema_with(SchemaModifiers::no_resolver()); + Valid::from_iter( + futures_util::future::join_all(operations.iter().map(|op| op.validate(&schema))) + .await + .iter(), + |errors| { + if errors.is_empty() { + Valid::succeed(()) + } else { + Valid::<(), String>::from_vec_cause(errors.to_vec()) + } + }, + ) + .unit() } diff --git a/src/blueprint/operators/const_field.rs b/src/blueprint/operators/const_field.rs index 546ad89aa23..e5c31a1ea29 100644 --- a/src/blueprint/operators/const_field.rs +++ b/src/blueprint/operators/const_field.rs @@ -9,51 +9,56 @@ use crate::try_fold::TryFold; use crate::valid::Valid; fn validate_data_with_schema( - config: 
&config::Config, - field: &config::Field, - gql_value: ConstValue, + config: &config::Config, + field: &config::Field, + gql_value: ConstValue, ) -> Valid<(), String> { - match to_json_schema_for_field(field, config).validate(&gql_value).to_result() { - Ok(_) => Valid::succeed(()), - Err(err) => Valid::from_validation_err(err.transform(&(|a| a.to_owned()))), - } + match to_json_schema_for_field(field, config) + .validate(&gql_value) + .to_result() + { + Ok(_) => Valid::succeed(()), + Err(err) => Valid::from_validation_err(err.transform(&(|a| a.to_owned()))), + } } pub struct CompileConst<'a> { - pub config: &'a config::Config, - pub field: &'a config::Field, - pub value: &'a serde_json::Value, - pub validate: bool, + pub config: &'a config::Config, + pub field: &'a config::Field, + pub value: &'a serde_json::Value, + pub validate: bool, } pub fn compile_const(inputs: CompileConst) -> Valid { - let config = inputs.config; - let field = inputs.field; - let value = inputs.value; - let validate = inputs.validate; + let config = inputs.config; + let field = inputs.field; + let value = inputs.value; + let validate = inputs.validate; - let data = value.to_owned(); - match ConstValue::from_json(data.to_owned()) { - Ok(gql) => { - let validation = if validate { - validate_data_with_schema(config, field, gql) - } else { - Valid::succeed(()) - }; - validation.map(|_| Literal(data)) + let data = value.to_owned(); + match ConstValue::from_json(data.to_owned()) { + Ok(gql) => { + let validation = if validate { + validate_data_with_schema(config, field, gql) + } else { + Valid::succeed(()) + }; + validation.map(|_| Literal(data)) + } + Err(e) => Valid::fail(format!("invalid JSON: {}", e)), } - Err(e) => Valid::fail(format!("invalid JSON: {}", e)), - } } pub fn update_const_field<'a>( ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new(|(config, field, _, _), b_field| { - let Some(const_field) = &field.const_field else { - return Valid::succeed(b_field); - }; + TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( + |(config, field, _, _), b_field| { + let Some(const_field) = &field.const_field else { + return Valid::succeed(b_field); + }; - compile_const(CompileConst { config, field, value: &const_field.data, validate: true }) - .map(|resolver| b_field.resolver(Some(resolver))) - }) + compile_const(CompileConst { config, field, value: &const_field.data, validate: true }) + .map(|resolver| b_field.resolver(Some(resolver))) + }, + ) } diff --git a/src/blueprint/operators/expr.rs b/src/blueprint/operators/expr.rs index b6be79cd75a..9d482511294 100644 --- a/src/blueprint/operators/expr.rs +++ b/src/blueprint/operators/expr.rs @@ -6,828 +6,886 @@ use crate::try_fold::TryFold; use crate::valid::Valid; struct CompilationContext<'a> { - config_field: &'a config::Field, - operation_type: &'a config::GraphQLOperationType, - config: &'a config::Config, + config_field: &'a config::Field, + operation_type: &'a config::GraphQLOperationType, + config: &'a config::Config, } pub fn update_expr( - operation_type: &config::GraphQLOperationType, + operation_type: &config::GraphQLOperationType, ) -> TryFold<'_, (&Config, &Field, &config::Type, &str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new(|(config, field, _, _), b_field| { - let Some(expr) = &field.expr else { - return Valid::succeed(b_field); - }; + 
TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( + |(config, field, _, _), b_field| { + let Some(expr) = &field.expr else { + return Valid::succeed(b_field); + }; - let context = CompilationContext { config, operation_type, config_field: field }; + let context = CompilationContext { config, operation_type, config_field: field }; - compile(&context, expr.body.clone()).map(|compiled| b_field.resolver(Some(compiled))) - }) + compile(&context, expr.body.clone()).map(|compiled| b_field.resolver(Some(compiled))) + }, + ) } /// /// Compiles a list of Exprs into a list of Expressions /// -fn compile_list(context: &CompilationContext, expr_vec: Vec) -> Valid, String> { - Valid::from_iter(expr_vec, |value| compile(context, value)) +fn compile_list( + context: &CompilationContext, + expr_vec: Vec, +) -> Valid, String> { + Valid::from_iter(expr_vec, |value| compile(context, value)) } /// /// Compiles a tuple of Exprs into a tuple of Expressions /// -fn compile_ab(context: &CompilationContext, ab: (ExprBody, ExprBody)) -> Valid<(Expression, Expression), String> { - compile(context, ab.0).zip(compile(context, ab.1)) +fn compile_ab( + context: &CompilationContext, + ab: (ExprBody, ExprBody), +) -> Valid<(Expression, Expression), String> { + compile(context, ab.0).zip(compile(context, ab.1)) } /// /// Compiles expr into Expression /// fn compile(ctx: &CompilationContext, expr: ExprBody) -> Valid { - let config = ctx.config; - let field = ctx.config_field; - let operation_type = ctx.operation_type; - match expr { - // Io Expr - ExprBody::Http(http) => compile_http(config, field, &http), - ExprBody::Grpc(grpc) => { - let grpc = CompileGrpc { config, field, operation_type, grpc: &grpc, validate_with_schema: false }; - compile_grpc(grpc) - } - ExprBody::GraphQL(gql) => compile_graphql(config, operation_type, &gql), - - // Safe Expr - ExprBody::Const(value) => compile_const(CompileConst { config, field, value: &value, validate: false }), - - // Logic - ExprBody::If { cond, on_true: then, on_false: els } => compile(ctx, *cond) - .map(Box::new) - .zip(compile(ctx, *then).map(Box::new)) - .zip(compile(ctx, *els).map(Box::new)) - .map(|((cond, then), els)| Expression::Logic(Logic::If { cond, then, els })), - - ExprBody::And(ref list) => { - compile_list(ctx, list.clone()).map(|a| Expression::Logic(Logic::And(a)).parallel_when(expr.has_io())) - } - ExprBody::Or(ref list) => { - compile_list(ctx, list.clone()).map(|a| Expression::Logic(Logic::Or(a)).parallel_when(expr.has_io())) - } - ExprBody::Cond(default, list) => Valid::from_iter(list, |(cond, operation)| { - compile_ab(ctx, (*cond, *operation)).map(|(cond, operation)| (Box::new(cond), Box::new(operation))) - }) - .and_then(|mut list| { - compile(ctx, *default).map(|default| { - list.push((Box::new(Expression::Literal(true.into())), Box::new(default))); - Expression::Logic(Logic::Cond(list)) - }) - }), - ExprBody::DefaultTo(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Logic(Logic::DefaultTo(Box::new(a), Box::new(b)))) - } - ExprBody::IsEmpty(a) => compile(ctx, *a).map(|a| Expression::Logic(Logic::IsEmpty(Box::new(a)))), - ExprBody::Not(a) => compile(ctx, *a).map(|a| Expression::Logic(Logic::Not(Box::new(a)))), - - // List - ExprBody::Concat(ref values) => { - compile_list(ctx, values.clone()).map(|a| Expression::List(List::Concat(a)).parallel_when(expr.has_io())) - } - - // Relation - ExprBody::Intersection(ref values) => compile_list(ctx, values.clone()) - .map(|a| 
Expression::Relation(Relation::Intersection(a)).parallel_when(expr.has_io())), - ExprBody::Difference(a, b) => compile_list(ctx, a) - .zip(compile_list(ctx, b)) - .map(|(a, b)| Expression::Relation(Relation::Difference(a, b))), - ExprBody::Equals(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::Equals(Box::new(a), Box::new(b)))) - } - ExprBody::Gt(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::Gt(Box::new(a), Box::new(b)))) - } - ExprBody::Gte(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::Gte(Box::new(a), Box::new(b)))) - } - ExprBody::Lt(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::Lt(Box::new(a), Box::new(b)))) - } - ExprBody::Lte(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::Lte(Box::new(a), Box::new(b)))) - } - ExprBody::Max(ref list) => { - compile_list(ctx, list.clone()).map(|a| Expression::Relation(Relation::Max(a)).parallel_when(expr.has_io())) - } - ExprBody::Min(ref list) => { - compile_list(ctx, list.clone()).map(|a| Expression::Relation(Relation::Min(a)).parallel_when(expr.has_io())) - } - ExprBody::PathEq(a, path, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::PathEq(Box::new(a), path, Box::new(b)))) - } - ExprBody::PropEq(a, path, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Relation(Relation::PropEq(Box::new(a), path, Box::new(b)))) - } - ExprBody::SortPath(a, path) => { - compile(ctx, *a).map(|a| Expression::Relation(Relation::SortPath(Box::new(a), path.clone()))) - } - ExprBody::SymmetricDifference(a, b) => compile_list(ctx, a) - .zip(compile_list(ctx, b)) - .map(|(a, b)| Expression::Relation(Relation::SymmetricDifference(a, b))), - ExprBody::Union(a, b) => compile_list(ctx, a) - .zip(compile_list(ctx, b)) - .map(|(a, b)| Expression::Relation(Relation::Union(a, b))), - - // Math - ExprBody::Mod(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Math(Math::Mod(Box::new(a), Box::new(b)))) - } - ExprBody::Add(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Math(Math::Add(Box::new(a), Box::new(b)))) - } - ExprBody::Dec(a) => compile(ctx, *a).map(|a| Expression::Math(Math::Dec(Box::new(a)))), - ExprBody::Divide(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Math(Math::Divide(Box::new(a), Box::new(b)))) - } - ExprBody::Inc(a) => compile(ctx, *a).map(|a| Expression::Math(Math::Inc(Box::new(a)))), - ExprBody::Multiply(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Math(Math::Multiply(Box::new(a), Box::new(b)))) - } - ExprBody::Negate(a) => compile(ctx, *a).map(|a| Expression::Math(Math::Negate(Box::new(a)))), - ExprBody::Product(ref list) => { - compile_list(ctx, list.clone()).map(|a| Expression::Math(Math::Product(a)).parallel_when(expr.has_io())) - } - ExprBody::Subtract(a, b) => { - compile_ab(ctx, (*a, *b)).map(|(a, b)| Expression::Math(Math::Subtract(Box::new(a), Box::new(b)))) - } - ExprBody::Sum(ref list) => { - compile_list(ctx, list.clone()).map(|a| Expression::Math(Math::Sum(a)).parallel_when(expr.has_io())) + let config = ctx.config; + let field = ctx.config_field; + let operation_type = ctx.operation_type; + match expr { + // Io Expr + ExprBody::Http(http) => compile_http(config, field, &http), + ExprBody::Grpc(grpc) => { + let grpc = CompileGrpc { + config, + field, + operation_type, + grpc: &grpc, + validate_with_schema: false, + }; + compile_grpc(grpc) + } + 
ExprBody::GraphQL(gql) => compile_graphql(config, operation_type, &gql), + + // Safe Expr + ExprBody::Const(value) => { + compile_const(CompileConst { config, field, value: &value, validate: false }) + } + + // Logic + ExprBody::If { cond, on_true: then, on_false: els } => compile(ctx, *cond) + .map(Box::new) + .zip(compile(ctx, *then).map(Box::new)) + .zip(compile(ctx, *els).map(Box::new)) + .map(|((cond, then), els)| Expression::Logic(Logic::If { cond, then, els })), + + ExprBody::And(ref list) => compile_list(ctx, list.clone()) + .map(|a| Expression::Logic(Logic::And(a)).parallel_when(expr.has_io())), + ExprBody::Or(ref list) => compile_list(ctx, list.clone()) + .map(|a| Expression::Logic(Logic::Or(a)).parallel_when(expr.has_io())), + ExprBody::Cond(default, list) => Valid::from_iter(list, |(cond, operation)| { + compile_ab(ctx, (*cond, *operation)) + .map(|(cond, operation)| (Box::new(cond), Box::new(operation))) + }) + .and_then(|mut list| { + compile(ctx, *default).map(|default| { + list.push(( + Box::new(Expression::Literal(true.into())), + Box::new(default), + )); + Expression::Logic(Logic::Cond(list)) + }) + }), + ExprBody::DefaultTo(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Logic(Logic::DefaultTo(Box::new(a), Box::new(b)))), + ExprBody::IsEmpty(a) => { + compile(ctx, *a).map(|a| Expression::Logic(Logic::IsEmpty(Box::new(a)))) + } + ExprBody::Not(a) => compile(ctx, *a).map(|a| Expression::Logic(Logic::Not(Box::new(a)))), + + // List + ExprBody::Concat(ref values) => compile_list(ctx, values.clone()) + .map(|a| Expression::List(List::Concat(a)).parallel_when(expr.has_io())), + + // Relation + ExprBody::Intersection(ref values) => compile_list(ctx, values.clone()) + .map(|a| Expression::Relation(Relation::Intersection(a)).parallel_when(expr.has_io())), + ExprBody::Difference(a, b) => compile_list(ctx, a) + .zip(compile_list(ctx, b)) + .map(|(a, b)| Expression::Relation(Relation::Difference(a, b))), + ExprBody::Equals(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::Equals(Box::new(a), Box::new(b)))), + ExprBody::Gt(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::Gt(Box::new(a), Box::new(b)))), + ExprBody::Gte(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::Gte(Box::new(a), Box::new(b)))), + ExprBody::Lt(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::Lt(Box::new(a), Box::new(b)))), + ExprBody::Lte(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::Lte(Box::new(a), Box::new(b)))), + ExprBody::Max(ref list) => compile_list(ctx, list.clone()) + .map(|a| Expression::Relation(Relation::Max(a)).parallel_when(expr.has_io())), + ExprBody::Min(ref list) => compile_list(ctx, list.clone()) + .map(|a| Expression::Relation(Relation::Min(a)).parallel_when(expr.has_io())), + ExprBody::PathEq(a, path, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::PathEq(Box::new(a), path, Box::new(b)))), + ExprBody::PropEq(a, path, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Relation(Relation::PropEq(Box::new(a), path, Box::new(b)))), + ExprBody::SortPath(a, path) => compile(ctx, *a) + .map(|a| Expression::Relation(Relation::SortPath(Box::new(a), path.clone()))), + ExprBody::SymmetricDifference(a, b) => compile_list(ctx, a) + .zip(compile_list(ctx, b)) + .map(|(a, b)| Expression::Relation(Relation::SymmetricDifference(a, b))), + ExprBody::Union(a, b) => compile_list(ctx, a) + 
.zip(compile_list(ctx, b)) + .map(|(a, b)| Expression::Relation(Relation::Union(a, b))), + + // Math + ExprBody::Mod(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Math(Math::Mod(Box::new(a), Box::new(b)))), + ExprBody::Add(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Math(Math::Add(Box::new(a), Box::new(b)))), + ExprBody::Dec(a) => compile(ctx, *a).map(|a| Expression::Math(Math::Dec(Box::new(a)))), + ExprBody::Divide(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Math(Math::Divide(Box::new(a), Box::new(b)))), + ExprBody::Inc(a) => compile(ctx, *a).map(|a| Expression::Math(Math::Inc(Box::new(a)))), + ExprBody::Multiply(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Math(Math::Multiply(Box::new(a), Box::new(b)))), + ExprBody::Negate(a) => { + compile(ctx, *a).map(|a| Expression::Math(Math::Negate(Box::new(a)))) + } + ExprBody::Product(ref list) => compile_list(ctx, list.clone()) + .map(|a| Expression::Math(Math::Product(a)).parallel_when(expr.has_io())), + ExprBody::Subtract(a, b) => compile_ab(ctx, (*a, *b)) + .map(|(a, b)| Expression::Math(Math::Subtract(Box::new(a), Box::new(b)))), + ExprBody::Sum(ref list) => compile_list(ctx, list.clone()) + .map(|a| Expression::Math(Math::Sum(a)).parallel_when(expr.has_io())), } - } } #[cfg(test)] mod tests { - use std::collections::HashSet; - use std::sync::{Arc, Mutex}; - - use pretty_assertions::assert_eq; - use serde_json::{json, Number}; - - use super::{compile, CompilationContext}; - use crate::config::{Config, Expr, Field, GraphQLOperationType}; - use crate::http::RequestContext; - use crate::lambda::{Concurrent, Eval, EvaluationContext, ResolverContextLike}; - - #[derive(Default)] - struct Context<'a> { - value: Option<&'a async_graphql_value::ConstValue>, - args: Option<&'a indexmap::IndexMap>, - field: Option>, - errors: Arc>>, - } - - impl<'a> ResolverContextLike<'a> for Context<'a> { - fn value(&'a self) -> Option<&'a async_graphql_value::ConstValue> { - self.value + use std::collections::HashSet; + use std::sync::{Arc, Mutex}; + + use pretty_assertions::assert_eq; + use serde_json::{json, Number}; + + use super::{compile, CompilationContext}; + use crate::config::{Config, Expr, Field, GraphQLOperationType}; + use crate::http::RequestContext; + use crate::lambda::{Concurrent, Eval, EvaluationContext, ResolverContextLike}; + + #[derive(Default)] + struct Context<'a> { + value: Option<&'a async_graphql_value::ConstValue>, + args: Option< + &'a indexmap::IndexMap, + >, + field: Option>, + errors: Arc>>, } - fn args(&'a self) -> Option<&'a indexmap::IndexMap> { - self.args + impl<'a> ResolverContextLike<'a> for Context<'a> { + fn value(&'a self) -> Option<&'a async_graphql_value::ConstValue> { + self.value + } + + fn args( + &'a self, + ) -> Option< + &'a indexmap::IndexMap, + > { + self.args + } + + fn field(&'a self) -> Option { + self.field + } + + fn add_error(&'a self, error: async_graphql::ServerError) { + self.errors.lock().unwrap().push(error); + } } - fn field(&'a self) -> Option { - self.field + impl Expr { + async fn eval(expr: serde_json::Value) -> anyhow::Result { + let expr = serde_json::from_value::(expr)?; + let config = Config::default(); + let field = Field::default(); + let operation_type = GraphQLOperationType::Query; + let context = CompilationContext { + config: &config, + config_field: &field, + operation_type: &operation_type, + }; + let expression = compile(&context, expr.body.clone()).to_result()?; + let req_ctx = RequestContext::default(); + let 
graphql_ctx = Context::default(); + let ctx = EvaluationContext::new(&req_ctx, &graphql_ctx); + let value = expression.eval(&ctx, &Concurrent::default()).await?; + + Ok(serde_json::to_value(value)?) + } } - fn add_error(&'a self, error: async_graphql::ServerError) { - self.errors.lock().unwrap().push(error); + #[tokio::test] + async fn test_is_truthy() { + let actual = Expr::eval(json!({"body": {"inc": {"const": 1}}})) + .await + .unwrap(); + let expected = json!(2.0); + assert_eq!(actual, expected); } - } - - impl Expr { - async fn eval(expr: serde_json::Value) -> anyhow::Result { - let expr = serde_json::from_value::(expr)?; - let config = Config::default(); - let field = Field::default(); - let operation_type = GraphQLOperationType::Query; - let context = CompilationContext { config: &config, config_field: &field, operation_type: &operation_type }; - let expression = compile(&context, expr.body.clone()).to_result()?; - let req_ctx = RequestContext::default(); - let graphql_ctx = Context::default(); - let ctx = EvaluationContext::new(&req_ctx, &graphql_ctx); - let value = expression.eval(&ctx, &Concurrent::default()).await?; - Ok(serde_json::to_value(value)?) + #[tokio::test] + async fn test_math_add() { + let actual = Expr::eval(json!({"body": {"add": [{"const": 40}, {"const": 2}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); } - } - - #[tokio::test] - async fn test_is_truthy() { - let actual = Expr::eval(json!({"body": {"inc": {"const": 1}}})).await.unwrap(); - let expected = json!(2.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_add() { - let actual = Expr::eval(json!({"body": {"add": [{"const": 40}, {"const": 2}]}})) - .await - .unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_subtract() { - let actual = Expr::eval(json!({"body": {"subtract": [{"const": 52}, {"const": 10}]}})) - .await - .unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_multiply() { - let actual = Expr::eval(json!({"body": {"multiply": [{"const": 7}, {"const": 6}]}})) - .await - .unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_mod() { - let actual = Expr::eval(json!({"body": {"mod": [{"const": 1379}, {"const": 1337}]}})) - .await - .unwrap(); - let expected = json!(42); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_div1() { - let actual = Expr::eval(json!({"body": {"divide": [{"const": 9828}, {"const": 234}]}})) - .await - .unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - #[tokio::test] - async fn test_math_div2() { - let actual = Expr::eval(json!({"body": {"divide": [{"const": 105}, {"const": 2.5}]}})) - .await - .unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_inc() { - let actual = Expr::eval(json!({"body": {"inc": {"const": 41}}})).await.unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_dec() { - let actual = Expr::eval(json!({"body": {"dec": {"const": 43}}})).await.unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_product() { - let actual = Expr::eval(json!({"body": {"product": [{"const": 7}, {"const": 3}, {"const": 2}]}})) - .await - .unwrap(); - let expected = json!(42.0); - 
assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_math_sum() { - let actual = Expr::eval(json!({"body": {"sum": [{"const": 20}, {"const": 15}, {"const": 7}]}})) - .await - .unwrap(); - let expected = json!(42.0); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_logic_and_true() { - let expected = json!(true); + #[tokio::test] + async fn test_math_subtract() { + let actual = Expr::eval(json!({"body": {"subtract": [{"const": 52}, {"const": 10}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"and": [{"const": true}, {"const": true}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_math_multiply() { + let actual = Expr::eval(json!({"body": {"multiply": [{"const": 7}, {"const": 6}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"and": [{"const": true}, {"const": true}, {"const": true}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_math_mod() { + let actual = Expr::eval(json!({"body": {"mod": [{"const": 1379}, {"const": 1337}]}})) + .await + .unwrap(); + let expected = json!(42); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_and_false() { - let expected = json!(false); + #[tokio::test] + async fn test_math_div1() { + let actual = Expr::eval(json!({"body": {"divide": [{"const": 9828}, {"const": 234}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"and": [{"const": true}, {"const": false}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_math_div2() { + let actual = Expr::eval(json!({"body": {"divide": [{"const": 105}, {"const": 2.5}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"and": [{"const": true}, {"const": true}, {"const": false}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_math_inc() { + let actual = Expr::eval(json!({"body": {"inc": {"const": 41}}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"and": [{"const": false}, {"const": false}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_math_dec() { + let actual = Expr::eval(json!({"body": {"dec": {"const": 43}}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_is_empty_true() { - let expected = json!(true); + #[tokio::test] + async fn test_math_product() { + let actual = + Expr::eval(json!({"body": {"product": [{"const": 7}, {"const": 3}, {"const": 2}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": []}}})).await.unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_math_sum() { + let actual = + Expr::eval(json!({"body": {"sum": [{"const": 20}, {"const": 15}, {"const": 7}]}})) + .await + .unwrap(); + let expected = json!(42.0); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": {}}}})).await.unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn 
test_logic_and_true() { + let expected = json!(true); - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": ""}}})).await.unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"and": [{"const": true}, {"const": true}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": null}}})).await.unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval( + json!({"body": {"and": [{"const": true}, {"const": true}, {"const": true}]}}), + ) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_is_empty_false() { - let expected = json!(false); + #[tokio::test] + async fn test_logic_and_false() { + let expected = json!(false); - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": [1]}}})).await.unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"and": [{"const": true}, {"const": false}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": {"a": 1}}}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval( + json!({"body": {"and": [{"const": true}, {"const": true}, {"const": false}]}}), + ) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": "a"}}})).await.unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"and": [{"const": false}, {"const": false}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": 1}}})).await.unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_logic_is_empty_true() { + let expected = json!(true); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": []}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": {}}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": ""}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": null}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"isEmpty": {"const": false}}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_logic_is_empty_false() { + let expected = json!(false); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": [1]}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": {"a": 1}}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": "a"}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": 1}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"isEmpty": {"const": false}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_not_true() { - let expected = json!(false); + #[tokio::test] + async fn test_logic_not_true() { + let expected = json!(false); - let actual = Expr::eval(json!({"body": {"not": {"const": true}}})).await.unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"not": {"const": true}}})) + 
.await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"not": {"const": 1}}})).await.unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval(json!({"body": {"not": {"const": 1}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_not_false() { - let expected = json!(true); + #[tokio::test] + async fn test_logic_not_false() { + let expected = json!(true); - let actual = Expr::eval(json!({"body": {"not": {"const": false}}})).await.unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval(json!({"body": {"not": {"const": false}}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_or_false() { - let expected = json!(false); + #[tokio::test] + async fn test_logic_or_false() { + let expected = json!(false); - let actual = Expr::eval(json!({"body": {"or": [{"const": false}, {"const": false}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"or": [{"const": false}, {"const": false}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"or": [{"const": false}, {"const": false}, {"const": false}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval( + json!({"body": {"or": [{"const": false}, {"const": false}, {"const": false}]}}), + ) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_or_true() { - let expected = json!(true); + #[tokio::test] + async fn test_logic_or_true() { + let expected = json!(true); - let actual = Expr::eval(json!({"body": {"or": [{"const": true}, {"const": false}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"or": [{"const": true}, {"const": false}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"or": [{"const": false}, {"const": false}, {"const": true}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval( + json!({"body": {"or": [{"const": false}, {"const": false}, {"const": true}]}}), + ) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"or": [{"const": true}, {"const": true}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval(json!({"body": {"or": [{"const": true}, {"const": true}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_logic_cond() { - let expected = json!(0); + #[tokio::test] + async fn test_logic_cond() { + let expected = json!(0); - let actual = Expr::eval( + let actual = Expr::eval( json!({"body": {"cond": [{"const": 0}, [[{"const": false}, {"const": 1}], [{"const": false}, {"const": 2}]]]}}), ) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let expected = json!(1); + let expected = json!(1); - let actual = Expr::eval( + let actual = Expr::eval( json!({"body": {"cond": [{"const": 0}, [[{"const": true}, {"const": 1}], [{"const": true}, {"const": 2}]]]}}), ) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let expected = json!(2); - let actual = Expr::eval( + let expected = json!(2); + let actual = Expr::eval( json!({"body": {"cond": [{"const": 0}, [[{"const": false}, {"const": 1}], [{"const": true}, {"const": 2}]]]}}), ) .await .unwrap(); - assert_eq!(actual, expected); - } - - 
#[tokio::test] - async fn test_logic_default_to() { - let expected = json!(0); - let actual = Expr::eval(json!({"body": {"defaultTo": [{"const": null}, {"const": 0}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); + } - let expected = json!(true); - let actual = Expr::eval(json!({"body": {"defaultTo": [{"const": ""}, {"const": true}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_logic_default_to() { + let expected = json!(0); + let actual = Expr::eval(json!({"body": {"defaultTo": [{"const": null}, {"const": 0}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let expected = json!(true); + let actual = Expr::eval(json!({"body": {"defaultTo": [{"const": ""}, {"const": true}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_concat() { - let expected = json!([1, 2, 3, 4]); - let actual = Expr::eval(json!({"body": {"concat": [{"const": [1, 2]}, {"const": [3, 4]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_concat() { + let expected = json!([1, 2, 3, 4]); + let actual = + Expr::eval(json!({"body": {"concat": [{"const": [1, 2]}, {"const": [3, 4]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_intersection() { - let expected = json!([3]); - let actual = Expr::eval(json!({"body": {"intersection": [{"const": [1, 2, 3]}, {"const": [3, 4, 5]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_intersection() { + let expected = json!([3]); + let actual = Expr::eval( + json!({"body": {"intersection": [{"const": [1, 2, 3]}, {"const": [3, 4, 5]}]}}), + ) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_difference() { - let expected = json!([1]); - let actual = Expr::eval( + #[tokio::test] + async fn test_relation_difference() { + let expected = json!([1]); + let actual = Expr::eval( json!({"body": {"difference": [[{"const": 1}, {"const": 2}, {"const": 3}], [{"const": 2}, {"const": 3}]]}}), ) .await .unwrap(); - assert_eq!(actual, expected); - } + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_simmetric_difference() { - let expected = json!([1]); + #[tokio::test] + async fn test_relation_simmetric_difference() { + let expected = json!([1]); - let actual = Expr::eval( + let actual = Expr::eval( json!({"body": {"symmetricDifference": [[{"const": 1}, {"const": 2}, {"const": 3}], [{"const": 2}, {"const": 3}]]}}), ) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let actual = Expr::eval( + let actual = Expr::eval( json!({"body": {"symmetricDifference": [[{"const": 2}, {"const": 3}], [{"const": 1}, {"const": 2}, {"const": 3}]]}}), ) .await .unwrap(); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_relation_union() { - let expected = serde_json::from_value::>(json!([1, 2, 3, 4])).unwrap(); - - let actual = Expr::eval(json!({"body": {"union": [[{"const": 1}, {"const": 2}, {"const": 3}], [{"const": 2}, {"const": 3}, {"const": 4}]]}})) - .await - .unwrap(); - let actual = serde_json::from_value::>(actual).unwrap(); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_relation_eq_true() { - let expected = json!(true); - - let actual = Expr::eval(json!({"body": {"eq": [{"const": [1, 2, 3]}, {"const": [1, 2, 3]}]}})) - .await - .unwrap(); - assert_eq!(actual, 
expected); - - let actual = Expr::eval(json!({"body": {"eq": [{"const": "abc"}, {"const": "abc"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - - let actual = Expr::eval(json!({"body": {"eq": [{"const": true}, {"const": true}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_relation_eq_false() { - let expected = json!(false); - - let actual = Expr::eval(json!({"body": {"eq": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - - let actual = Expr::eval(json!({"body": {"eq": [{"const": "abc"}, {"const": 1}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - - let actual = Expr::eval(json!({"body": {"eq": [{"const": "abc"}, {"const": "ac"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_relation_gt_true() { - let expected = json!(true); - - let actual = Expr::eval(json!({"body": {"gt": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"gt": [{"const": "bc"}, {"const": "ab"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_relation_union() { + let expected = serde_json::from_value::>(json!([1, 2, 3, 4])).unwrap(); - let actual = Expr::eval(json!({"body": {"gt": [{"const": 4}, {"const": -1}]}})) + let actual = Expr::eval(json!({"body": {"union": [[{"const": 1}, {"const": 2}, {"const": 3}], [{"const": 2}, {"const": 3}, {"const": 4}]]}})) .await .unwrap(); - assert_eq!(actual, expected); - } + let actual = serde_json::from_value::>(actual).unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_gt_false() { - let expected = json!(false); + #[tokio::test] + async fn test_relation_eq_true() { + let expected = json!(true); + + let actual = + Expr::eval(json!({"body": {"eq": [{"const": [1, 2, 3]}, {"const": [1, 2, 3]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"eq": [{"const": "abc"}, {"const": "abc"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"eq": [{"const": true}, {"const": true}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"gt": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_relation_eq_false() { + let expected = json!(false); - let actual = Expr::eval(json!({"body": {"gt": [{"const": "abc"}, {"const": "z"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"eq": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"gt": [{"const": 0}, {"const": 3.74}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval(json!({"body": {"eq": [{"const": "abc"}, {"const": 1}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - #[tokio::test] - async fn test_relation_lt_true() { - let expected = json!(true); - - let actual = Expr::eval(json!({"body": {"lt": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - - let actual = Expr::eval(json!({"body": {"lt": [{"const": "abc"}, {"const": "z"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual 
= Expr::eval(json!({"body": {"eq": [{"const": "abc"}, {"const": "ac"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"lt": [{"const": 0}, {"const": 3.74}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_gt_true() { + let expected = json!(true); - #[tokio::test] - async fn test_relation_lt_false() { - let expected = json!(false); + let actual = Expr::eval(json!({"body": {"gt": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"lt": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"gt": [{"const": "bc"}, {"const": "ab"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"lt": [{"const": "bc"}, {"const": "ab"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"gt": [{"const": 4}, {"const": -1}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"lt": [{"const": 4}, {"const": -1}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_gt_false() { + let expected = json!(false); - #[tokio::test] - async fn test_relation_gte_true() { - let expected = json!(true); + let actual = Expr::eval(json!({"body": {"gt": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"gte": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"gt": [{"const": "abc"}, {"const": "z"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"gte": [{"const": "bc"}, {"const": "ab"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"gt": [{"const": 0}, {"const": 3.74}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"gte": [{"const": 4}, {"const": -1}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_relation_lt_true() { + let expected = json!(true); - let actual = Expr::eval(json!({"body": {"gte": [{"const": 4}, {"const": 4}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval(json!({"body": {"lt": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - #[tokio::test] - async fn test_relation_gte_false() { - let expected = json!(false); + let actual = Expr::eval(json!({"body": {"lt": [{"const": "abc"}, {"const": "z"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"gte": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - - let actual = Expr::eval(json!({"body": {"gte": [{"const": "abc"}, {"const": "z"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"lt": [{"const": 0}, {"const": 3.74}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"gte": [{"const": 0}, {"const": 3.74}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn 
test_relation_lt_false() { + let expected = json!(false); - #[tokio::test] - async fn test_relation_lte_true() { - let expected = json!(true); + let actual = Expr::eval(json!({"body": {"lt": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"lte": [{"const": [1, 2, 3]}, {"const": [1, 2, 3]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"lt": [{"const": "bc"}, {"const": "ab"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); - let actual = Expr::eval(json!({"body": {"lte": [{"const": 4}, {"const": 4}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + let actual = Expr::eval(json!({"body": {"lt": [{"const": 4}, {"const": -1}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"lte": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_relation_gte_true() { + let expected = json!(true); + + let actual = + Expr::eval(json!({"body": {"gte": [{"const": [1, 2, 3]}, {"const": [1, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"gte": [{"const": "bc"}, {"const": "ab"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"gte": [{"const": 4}, {"const": -1}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"gte": [{"const": 4}, {"const": 4}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"lte": [{"const": "abc"}, {"const": "z"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); + #[tokio::test] + async fn test_relation_gte_false() { + let expected = json!(false); + + let actual = + Expr::eval(json!({"body": {"gte": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"gte": [{"const": "abc"}, {"const": "z"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"gte": [{"const": 0}, {"const": 3.74}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - let actual = Expr::eval(json!({"body": {"lte": [{"const": 0}, {"const": 3.74}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_lte_true() { + let expected = json!(true); + + let actual = + Expr::eval(json!({"body": {"lte": [{"const": [1, 2, 3]}, {"const": [1, 2, 3]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"lte": [{"const": 4}, {"const": 4}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = + Expr::eval(json!({"body": {"lte": [{"const": [1, 2, 3]}, {"const": [2, 2]}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"lte": [{"const": "abc"}, {"const": "z"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + + let actual = Expr::eval(json!({"body": {"lte": [{"const": 0}, {"const": 3.74}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_lte_false() { - let expected = json!(false); + #[tokio::test] + async fn test_relation_lte_false() { + let expected = json!(false); - let actual = Expr::eval(json!({"body": {"lte": [{"const": "bc"}, 
{"const": "ab"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + let actual = Expr::eval(json!({"body": {"lte": [{"const": "bc"}, {"const": "ab"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_max() { - let expected = json!(923.83); - let actual = Expr::eval( + #[tokio::test] + async fn test_relation_max() { + let expected = json!(923.83); + let actual = Expr::eval( json!({"body": {"max": [{"const": 1}, {"const": 23}, {"const": -423}, {"const": 0}, {"const": 923.83}]}}), ) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let expected = json!("z"); - let actual = + let expected = json!("z"); + let actual = Expr::eval(json!({"body": {"max": [{"const": "abc"}, {"const": "z"}, {"const": "bcd"}, {"const": "foo"}]}})) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let expected = json!([2, 3]); - let actual = Expr::eval( + let expected = json!([2, 3]); + let actual = Expr::eval( json!({"body": {"max": [{"const": [2, 3]}, {"const": [0, 1, 2]}, {"const": [-1, 0, 0, 0]}, {"const": [1]}]}}), ) .await .unwrap(); - assert_eq!(actual, expected); - } + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_min() { - let expected = json!(-423); - let actual = Expr::eval( + #[tokio::test] + async fn test_relation_min() { + let expected = json!(-423); + let actual = Expr::eval( json!({"body": {"min": [{"const": 1}, {"const": 23}, {"const": -423}, {"const": 0}, {"const": 923.83}]}}), ) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let expected = json!("abc"); - let actual = + let expected = json!("abc"); + let actual = Expr::eval(json!({"body": {"min": [{"const": "abc"}, {"const": "z"}, {"const": "bcd"}, {"const": "foo"}]}})) .await .unwrap(); - assert_eq!(actual, expected); + assert_eq!(actual, expected); - let expected = json!([-1, 0, 0, 0]); - let actual = Expr::eval( + let expected = json!([-1, 0, 0, 0]); + let actual = Expr::eval( json!({"body": {"min": [{"const": [2, 3]}, {"const": [0, 1, 2]}, {"const": [-1, 0, 0, 0]}, {"const": [1]}]}}), ) .await .unwrap(); - assert_eq!(actual, expected); - } + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_sort_path() { - let expected = json!([2, 3, 4]); - let actual = Expr::eval(json!({"body": {"sortPath": [{"const": [4, 2, 3]}, []]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_sort_path() { + let expected = json!([2, 3, 4]); + let actual = Expr::eval(json!({"body": {"sortPath": [{"const": [4, 2, 3]}, []]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_path_eq_true() { - let expected = json!(true); - let actual = Expr::eval(json!({"body": {"pathEq": [{"const": 10}, [], {"const": 10}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_path_eq_true() { + let expected = json!(true); + let actual = Expr::eval(json!({"body": {"pathEq": [{"const": 10}, [], {"const": 10}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - #[tokio::test] - async fn test_relation_path_eq_false() { - let expected = json!(false); - let actual = Expr::eval(json!({"body": {"pathEq": [{"const": "ab"}, [], {"const": "bcd"}]}})) - .await - .unwrap(); - assert_eq!(actual, expected); - } + #[tokio::test] + async fn test_relation_path_eq_false() { + let expected = json!(false); + let actual = + 
Expr::eval(json!({"body": {"pathEq": [{"const": "ab"}, [], {"const": "bcd"}]}})) + .await + .unwrap(); + assert_eq!(actual, expected); + } - // TODO: add tests for all other expr operators + // TODO: add tests for all other expr operators } diff --git a/src/blueprint/operators/graphql.rs b/src/blueprint/operators/graphql.rs index 04e0ce505ec..4be5959e7dc 100644 --- a/src/blueprint/operators/graphql.rs +++ b/src/blueprint/operators/graphql.rs @@ -7,40 +7,49 @@ use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; pub fn compile_graphql( - config: &config::Config, - operation_type: &config::GraphQLOperationType, - graphql: &config::GraphQL, + config: &config::Config, + operation_type: &config::GraphQLOperationType, + graphql: &config::GraphQL, ) -> Valid { - let args = graphql.args.as_ref(); - Valid::from_option( - graphql.base_url.as_ref().or(config.upstream.base_url.as_ref()), - "No base URL defined".to_string(), - ) - .zip(helpers::headers::to_mustache_headers(&graphql.headers)) - .and_then(|(base_url, headers)| { - Valid::from( - RequestTemplate::new(base_url.to_owned(), operation_type, &graphql.name, args, headers) - .map_err(|e| ValidationError::new(e.to_string())), + let args = graphql.args.as_ref(); + Valid::from_option( + graphql + .base_url + .as_ref() + .or(config.upstream.base_url.as_ref()), + "No base URL defined".to_string(), ) - }) - .map(|req_template| { - let field_name = graphql.name.clone(); - Lambda::from_graphql_request_template(req_template, field_name, graphql.batch).expression - }) + .zip(helpers::headers::to_mustache_headers(&graphql.headers)) + .and_then(|(base_url, headers)| { + Valid::from( + RequestTemplate::new( + base_url.to_owned(), + operation_type, + &graphql.name, + args, + headers, + ) + .map_err(|e| ValidationError::new(e.to_string())), + ) + }) + .map(|req_template| { + let field_name = graphql.name.clone(); + Lambda::from_graphql_request_template(req_template, field_name, graphql.batch).expression + }) } pub fn update_graphql<'a>( - operation_type: &'a GraphQLOperationType, + operation_type: &'a GraphQLOperationType, ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, type_of, _), b_field| { - let Some(graphql) = &field.graphql else { - return Valid::succeed(b_field); - }; + TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( + |(config, field, type_of, _), b_field| { + let Some(graphql) = &field.graphql else { + return Valid::succeed(b_field); + }; - compile_graphql(config, operation_type, graphql) - .map(|resolver| b_field.resolver(Some(resolver))) - .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) - }, - ) + compile_graphql(config, operation_type, graphql) + .map(|resolver| b_field.resolver(Some(resolver))) + .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) + }, + ) } diff --git a/src/blueprint/operators/grpc.rs b/src/blueprint/operators/grpc.rs index f64742c12ee..8be9ac285d2 100644 --- a/src/blueprint/operators/grpc.rs +++ b/src/blueprint/operators/grpc.rs @@ -15,146 +15,167 @@ use crate::valid::{Valid, ValidationError}; use crate::{config, helpers}; fn to_url(grpc: &Grpc, config: &Config) -> Valid { - Valid::from_option( - grpc.base_url.as_ref().or(config.upstream.base_url.as_ref()), - "No base URL defined".to_string(), - ) - .and_then(|base_url| { - let mut base_url = 
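
compile_graphql shows the pipeline every operator compiler in this change follows: independent checks are combined with zip (so their errors accumulate) and the dependent step is chained with and_then. A minimal stand-in for the project's Valid type, reduced to just the combinators used here, to make that shape concrete — the real Valid in src/valid.rs carries richer error causes:

#[derive(Debug)]
struct Valid<A, E>(Result<A, Vec<E>>);

impl<A, E> Valid<A, E> {
    fn succeed(a: A) -> Self {
        Valid(Ok(a))
    }
    fn fail(e: E) -> Self {
        Valid(Err(vec![e]))
    }
    fn from_option(opt: Option<A>, e: E) -> Self {
        opt.map(Valid::succeed).unwrap_or_else(|| Valid::fail(e))
    }
    // Unlike Result::and, zip keeps going on failure and collects both error lists.
    fn zip<B>(self, other: Valid<B, E>) -> Valid<(A, B), E> {
        match (self.0, other.0) {
            (Ok(a), Ok(b)) => Valid(Ok((a, b))),
            (Err(mut ea), Err(eb)) => {
                ea.extend(eb);
                Valid(Err(ea))
            }
            (Err(e), _) | (_, Err(e)) => Valid(Err(e)),
        }
    }
    // and_then is sequential: the dependent step only runs on success.
    fn and_then<B>(self, f: impl FnOnce(A) -> Valid<B, E>) -> Valid<B, E> {
        match self.0 {
            Ok(a) => f(a),
            Err(e) => Valid(Err(e)),
        }
    }
}

fn main() {
    let base_url: Option<&str> = None;
    let headers: Valid<Vec<(String, String)>, String> = Valid::fail("bad header".to_string());
    let v = Valid::from_option(base_url, "No base URL defined".to_string())
        .zip(headers)
        .and_then(|(url, hs)| Valid::succeed(format!("{} ({} headers)", url, hs.len())));
    // Both problems are reported at once:
    assert_eq!(v.0.unwrap_err().len(), 2);
}
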
base_url.trim_end_matches('/').to_owned(); - base_url.push('/'); - base_url.push_str(&grpc.service); - base_url.push('/'); - base_url.push_str(&grpc.method); + Valid::from_option( + grpc.base_url.as_ref().or(config.upstream.base_url.as_ref()), + "No base URL defined".to_string(), + ) + .and_then(|base_url| { + let mut base_url = base_url.trim_end_matches('/').to_owned(); + base_url.push('/'); + base_url.push_str(&grpc.service); + base_url.push('/'); + base_url.push_str(&grpc.method); - helpers::url::to_url(&base_url) - }) + helpers::url::to_url(&base_url) + }) } fn to_operation(grpc: &Grpc) -> Valid { - Valid::from( - ProtobufSet::from_proto_file(Path::new(&grpc.proto_path)).map_err(|e| ValidationError::new(e.to_string())), - ) - .and_then(|set| { - Valid::from( - set - .find_service(&grpc.service) - .map_err(|e| ValidationError::new(e.to_string())), - ) - }) - .and_then(|service| { Valid::from( - service - .find_operation(&grpc.method) - .map_err(|e| ValidationError::new(e.to_string())), + ProtobufSet::from_proto_file(Path::new(&grpc.proto_path)) + .map_err(|e| ValidationError::new(e.to_string())), ) - }) + .and_then(|set| { + Valid::from( + set.find_service(&grpc.service) + .map_err(|e| ValidationError::new(e.to_string())), + ) + }) + .and_then(|service| { + Valid::from( + service + .find_operation(&grpc.method) + .map_err(|e| ValidationError::new(e.to_string())), + ) + }) } fn json_schema_from_field(config: &Config, field: &Field) -> FieldSchema { - let field_schema = crate::blueprint::to_json_schema_for_field(field, config); - let args_schema = crate::blueprint::to_json_schema_for_args(&field.args, config); - FieldSchema { args: args_schema, field: field_schema } + let field_schema = crate::blueprint::to_json_schema_for_field(field, config); + let args_schema = crate::blueprint::to_json_schema_for_args(&field.args, config); + FieldSchema { args: args_schema, field: field_schema } } pub struct FieldSchema { - pub args: JsonSchema, - pub field: JsonSchema, + pub args: JsonSchema, + pub field: JsonSchema, } -fn validate_schema(field_schema: FieldSchema, operation: &ProtobufOperation, name: &str) -> Valid<(), String> { - let input_type = &operation.input_type; - let output_type = &operation.output_type; +fn validate_schema( + field_schema: FieldSchema, + operation: &ProtobufOperation, + name: &str, +) -> Valid<(), String> { + let input_type = &operation.input_type; + let output_type = &operation.output_type; - Valid::from(JsonSchema::try_from(input_type)) - .zip(Valid::from(JsonSchema::try_from(output_type))) - .and_then(|(_input_schema, output_schema)| { - // TODO: add validation for input schema - should compare result grpc.body to schema - let fields = field_schema.field; - let _args = field_schema.args; - fields.compare(&output_schema, name) - }) + Valid::from(JsonSchema::try_from(input_type)) + .zip(Valid::from(JsonSchema::try_from(output_type))) + .and_then(|(_input_schema, output_schema)| { + // TODO: add validation for input schema - should compare result grpc.body to schema + let fields = field_schema.field; + let _args = field_schema.args; + fields.compare(&output_schema, name) + }) } fn validate_group_by( - field_schema: &FieldSchema, - operation: &ProtobufOperation, - group_by: Vec, + field_schema: &FieldSchema, + operation: &ProtobufOperation, + group_by: Vec, ) -> Valid<(), String> { - let input_type = &operation.input_type; - let output_type = &operation.output_type; - let mut field_descriptor: Result> = - None.ok_or(ValidationError::new(format!("field {} not found", 
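
to_url's slash handling is easy to get wrong, so here is the same normalization in isolation: the base URL's trailing slash is stripped before the /Service/Method suffix (the standard gRPC path shape) is appended. The argument values below are illustrative:

fn grpc_url(base_url: &str, service: &str, method: &str) -> String {
    // Strip any trailing '/' so joining never produces "//".
    let mut url = base_url.trim_end_matches('/').to_owned();
    url.push('/');
    url.push_str(service);
    url.push('/');
    url.push_str(method);
    url
}

fn main() {
    assert_eq!(
        grpc_url("http://localhost:50051/", "news.NewsService", "GetAllNews"),
        "http://localhost:50051/news.NewsService/GetAllNews"
    );
    // Works the same without the trailing slash on the base URL.
    assert_eq!(
        grpc_url("http://localhost:50051", "news.NewsService", "GetAllNews"),
        "http://localhost:50051/news.NewsService/GetAllNews"
    );
}
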
group_by[0]))); - for item in group_by.iter().take(&group_by.len() - 1) { - field_descriptor = output_type - .get_field_by_json_name(item.as_str()) - .ok_or(ValidationError::new(format!("field {} not found", item))); - } - let output_type = field_descriptor.and_then(|f| JsonSchema::try_from(&f)); + let input_type = &operation.input_type; + let output_type = &operation.output_type; + let mut field_descriptor: Result> = None.ok_or( + ValidationError::new(format!("field {} not found", group_by[0])), + ); + for item in group_by.iter().take(&group_by.len() - 1) { + field_descriptor = output_type + .get_field_by_json_name(item.as_str()) + .ok_or(ValidationError::new(format!("field {} not found", item))); + } + let output_type = field_descriptor.and_then(|f| JsonSchema::try_from(&f)); - Valid::from(JsonSchema::try_from(input_type)) - .zip(Valid::from(output_type)) - .and_then(|(_input_schema, output_schema)| { - // TODO: add validation for input schema - should compare result grpc.body to schema considering repeated message type - let fields = &field_schema.field; - let args = &field_schema.args; - let fields = JsonSchema::Arr(Box::new(fields.to_owned())); - let _args = JsonSchema::Arr(Box::new(args.to_owned())); - fields.compare(&output_schema, group_by[0].as_str()) - }) + Valid::from(JsonSchema::try_from(input_type)) + .zip(Valid::from(output_type)) + .and_then(|(_input_schema, output_schema)| { + // TODO: add validation for input schema - should compare result grpc.body to schema considering repeated message type + let fields = &field_schema.field; + let args = &field_schema.args; + let fields = JsonSchema::Arr(Box::new(fields.to_owned())); + let _args = JsonSchema::Arr(Box::new(args.to_owned())); + fields.compare(&output_schema, group_by[0].as_str()) + }) } pub struct CompileGrpc<'a> { - pub config: &'a config::Config, - pub operation_type: &'a config::GraphQLOperationType, - pub field: &'a config::Field, - pub grpc: &'a config::Grpc, - pub validate_with_schema: bool, + pub config: &'a config::Config, + pub operation_type: &'a config::GraphQLOperationType, + pub field: &'a config::Field, + pub grpc: &'a config::Grpc, + pub validate_with_schema: bool, } pub fn compile_grpc(inputs: CompileGrpc) -> Valid { - let config = inputs.config; - let operation_type = inputs.operation_type; - let field = inputs.field; - let grpc = inputs.grpc; - let validate_with_schema = inputs.validate_with_schema; + let config = inputs.config; + let operation_type = inputs.operation_type; + let field = inputs.field; + let grpc = inputs.grpc; + let validate_with_schema = inputs.validate_with_schema; - to_url(grpc, config) - .zip(to_operation(grpc)) - .zip(helpers::headers::to_mustache_headers(&grpc.headers)) - .zip(helpers::body::to_body(grpc.body.as_deref())) - .and_then(|(((url, operation), headers), body)| { - let validation = if validate_with_schema { - let field_schema = json_schema_from_field(config, field); - if grpc.group_by.is_empty() { - validate_schema(field_schema, &operation, field.name()).unit() - } else { - validate_group_by(&field_schema, &operation, grpc.group_by.clone()).unit() - } - } else { - Valid::succeed(()) - }; - validation.map(|_| (url, headers, operation, body)) - }) - .map(|(url, headers, operation, body)| { - let req_template = RequestTemplate { url, headers, operation, body, operation_type: operation_type.clone() }; - if !grpc.group_by.is_empty() { - Expression::IO(IO::Grpc { req_template, group_by: Some(GroupBy::new(grpc.group_by.clone())), dl_id: None }) - } else { - 
Lambda::from_grpc_request_template(req_template).expression - } - }) + to_url(grpc, config) + .zip(to_operation(grpc)) + .zip(helpers::headers::to_mustache_headers(&grpc.headers)) + .zip(helpers::body::to_body(grpc.body.as_deref())) + .and_then(|(((url, operation), headers), body)| { + let validation = if validate_with_schema { + let field_schema = json_schema_from_field(config, field); + if grpc.group_by.is_empty() { + validate_schema(field_schema, &operation, field.name()).unit() + } else { + validate_group_by(&field_schema, &operation, grpc.group_by.clone()).unit() + } + } else { + Valid::succeed(()) + }; + validation.map(|_| (url, headers, operation, body)) + }) + .map(|(url, headers, operation, body)| { + let req_template = RequestTemplate { + url, + headers, + operation, + body, + operation_type: operation_type.clone(), + }; + if !grpc.group_by.is_empty() { + Expression::IO(IO::Grpc { + req_template, + group_by: Some(GroupBy::new(grpc.group_by.clone())), + dl_id: None, + }) + } else { + Lambda::from_grpc_request_template(req_template).expression + } + }) } pub fn update_grpc<'a>( - operation_type: &'a GraphQLOperationType, + operation_type: &'a GraphQLOperationType, ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, type_of, _name), b_field| { - let Some(grpc) = &field.grpc else { - return Valid::succeed(b_field); - }; + TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( + |(config, field, type_of, _name), b_field| { + let Some(grpc) = &field.grpc else { + return Valid::succeed(b_field); + }; - compile_grpc(CompileGrpc { config, operation_type, field, grpc, validate_with_schema: true }) - .map(|resolver| b_field.resolver(Some(resolver))) - .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) - }, - ) + compile_grpc(CompileGrpc { + config, + operation_type, + field, + grpc, + validate_with_schema: true, + }) + .map(|resolver| b_field.resolver(Some(resolver))) + .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) + }, + ) } diff --git a/src/blueprint/operators/http.rs b/src/blueprint/operators/http.rs index b60490828a4..ce0990548b1 100644 --- a/src/blueprint/operators/http.rs +++ b/src/blueprint/operators/http.rs @@ -8,58 +8,77 @@ use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; use crate::{config, helpers}; -pub fn compile_http(config: &config::Config, field: &config::Field, http: &config::Http) -> Valid { - Valid::<(), String>::fail("GroupBy is only supported for GET requests".to_string()) - .when(|| !http.group_by.is_empty() && http.method != Method::GET) - .and( - Valid::<(), String>::fail("GroupBy can only be applied if batching is enabled".to_string()) - .when(|| (config.upstream.get_delay() < 1 || config.upstream.get_max_size() < 1) && !http.group_by.is_empty()), - ) - .and(Valid::from_option( - http.base_url.as_ref().or(config.upstream.base_url.as_ref()), - "No base URL defined".to_string(), - )) - .zip(helpers::headers::to_mustache_headers(&http.headers)) - .and_then(|(base_url, headers)| { - let mut base_url = base_url.trim_end_matches('/').to_owned(); - base_url.push_str(http.path.clone().as_str()); +pub fn compile_http( + config: &config::Config, + field: &config::Field, + http: &config::Http, +) -> Valid { + Valid::<(), String>::fail("GroupBy is only supported for GET requests".to_string()) + .when(|| 
!http.group_by.is_empty() && http.method != Method::GET) + .and( + Valid::<(), String>::fail( + "GroupBy can only be applied if batching is enabled".to_string(), + ) + .when(|| { + (config.upstream.get_delay() < 1 || config.upstream.get_max_size() < 1) + && !http.group_by.is_empty() + }), + ) + .and(Valid::from_option( + http.base_url.as_ref().or(config.upstream.base_url.as_ref()), + "No base URL defined".to_string(), + )) + .zip(helpers::headers::to_mustache_headers(&http.headers)) + .and_then(|(base_url, headers)| { + let mut base_url = base_url.trim_end_matches('/').to_owned(); + base_url.push_str(http.path.clone().as_str()); - let query = http.query.clone().iter().map(|(k, v)| (k.clone(), v.clone())).collect(); - let output_schema = to_json_schema_for_field(field, config); - let input_schema = to_json_schema_for_args(&field.args, config); + let query = http + .query + .clone() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let output_schema = to_json_schema_for_field(field, config); + let input_schema = to_json_schema_for_args(&field.args, config); - RequestTemplate::try_from( - Endpoint::new(base_url.to_string()) - .method(http.method.clone()) - .query(query) - .output(output_schema) - .input(input_schema) - .body(http.body.clone()) - .encoding(http.encoding.clone()), - ) - .map(|req_tmpl| req_tmpl.headers(headers)) - .map_err(|e| ValidationError::new(e.to_string())) - .into() - }) - .map(|req_template| { - if !http.group_by.is_empty() && http.method == Method::GET { - Expression::IO(IO::Http { req_template, group_by: Some(GroupBy::new(http.group_by.clone())), dl_id: None }) - } else { - Lambda::from_request_template(req_template).expression - } - }) + RequestTemplate::try_from( + Endpoint::new(base_url.to_string()) + .method(http.method.clone()) + .query(query) + .output(output_schema) + .input(input_schema) + .body(http.body.clone()) + .encoding(http.encoding.clone()), + ) + .map(|req_tmpl| req_tmpl.headers(headers)) + .map_err(|e| ValidationError::new(e.to_string())) + .into() + }) + .map(|req_template| { + if !http.group_by.is_empty() && http.method == Method::GET { + Expression::IO(IO::Http { + req_template, + group_by: Some(GroupBy::new(http.group_by.clone())), + dl_id: None, + }) + } else { + Lambda::from_request_template(req_template).expression + } + }) } -pub fn update_http<'a>() -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, type_of, _), b_field| { - let Some(http) = &field.http else { - return Valid::succeed(b_field); - }; +pub fn update_http<'a>( +) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( + |(config, field, type_of, _), b_field| { + let Some(http) = &field.http else { + return Valid::succeed(b_field); + }; - compile_http(config, field, http) - .map(|resolver| b_field.resolver(Some(resolver))) - .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) - }, - ) + compile_http(config, field, http) + .map(|resolver| b_field.resolver(Some(resolver))) + .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) + }, + ) } diff --git a/src/blueprint/operators/js.rs b/src/blueprint/operators/js.rs index d4ade3e3d77..dc2a8217448 100644 --- a/src/blueprint/operators/js.rs +++ b/src/blueprint/operators/js.rs @@ -5,14 +5,18 @@ use 
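
The two guards at the top of compile_http encode the preconditions for batching: group_by needs a GET request, and it needs upstream batching (a positive delay and max size) to be configured. The same checks written directly against plain data, with both failures reported together — the field names here are stand-ins for the config getters get_delay() and get_max_size():

struct Upstream {
    batch_delay: u32,    // stands in for upstream.get_delay()
    batch_max_size: u32, // stands in for upstream.get_max_size()
}

fn check_group_by(group_by: &[String], method: &str, up: &Upstream) -> Result<(), Vec<String>> {
    let mut errors = Vec::new();
    if !group_by.is_empty() && method != "GET" {
        errors.push("GroupBy is only supported for GET requests".to_string());
    }
    if !group_by.is_empty() && (up.batch_delay < 1 || up.batch_max_size < 1) {
        errors.push("GroupBy can only be applied if batching is enabled".to_string());
    }
    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors)
    }
}

fn main() {
    let up = Upstream { batch_delay: 0, batch_max_size: 0 };
    let errs = check_group_by(&["userId".to_string()], "POST", &up).unwrap_err();
    assert_eq!(errs.len(), 2); // both preconditions fail and both are reported
}
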
crate::lambda::Lambda; use crate::try_fold::TryFold; use crate::valid::Valid; -pub fn update_js<'a>() -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new(|(_, field, _, _), b_field| { - let mut updated_b_field = b_field; - if let Some(op) = &field.script { - updated_b_field = updated_b_field.resolver_or_default(Lambda::context().to_js(op.script.clone()), |r| { - r.to_js(op.script.clone()) - }); - } - Valid::succeed(updated_b_field) - }) +pub fn update_js<'a>( +) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&Config, &Field, &config::Type, &str), FieldDefinition, String>::new( + |(_, field, _, _), b_field| { + let mut updated_b_field = b_field; + if let Some(op) = &field.script { + updated_b_field = updated_b_field + .resolver_or_default(Lambda::context().to_js(op.script.clone()), |r| { + r.to_js(op.script.clone()) + }); + } + Valid::succeed(updated_b_field) + }, + ) } diff --git a/src/blueprint/operators/modify.rs b/src/blueprint/operators/modify.rs index 419ce66c109..cdf8bcb7082 100644 --- a/src/blueprint/operators/modify.rs +++ b/src/blueprint/operators/modify.rs @@ -5,26 +5,29 @@ use crate::lambda::Lambda; use crate::try_fold::TryFold; use crate::valid::Valid; -pub fn update_modify<'a>() -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { - TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, type_of, _), mut b_field| { - if let Some(modify) = field.modify.as_ref() { - if let Some(new_name) = &modify.name { - for name in type_of.implements.iter() { - let interface = config.find_type(name); - if let Some(interface) = interface { - if interface.fields.iter().any(|(name, _)| name == new_name) { - return Valid::fail("Field is already implemented from interface".to_string()); - } - } - } +pub fn update_modify<'a>( +) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { + TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( + |(config, field, type_of, _), mut b_field| { + if let Some(modify) = field.modify.as_ref() { + if let Some(new_name) = &modify.name { + for name in type_of.implements.iter() { + let interface = config.find_type(name); + if let Some(interface) = interface { + if interface.fields.iter().any(|(name, _)| name == new_name) { + return Valid::fail( + "Field is already implemented from interface".to_string(), + ); + } + } + } - let lambda = Lambda::context_field(b_field.name.clone()); - b_field = b_field.resolver_or_default(lambda, |r| r); - b_field = b_field.name(new_name.clone()); - } - } - Valid::succeed(b_field) - }, - ) + let lambda = Lambda::context_field(b_field.name.clone()); + b_field = b_field.resolver_or_default(lambda, |r| r); + b_field = b_field.name(new_name.clone()); + } + } + Valid::succeed(b_field) + }, + ) } diff --git a/src/blueprint/schema.rs b/src/blueprint/schema.rs index 5d0ca164f55..c49d5396ce4 100644 --- a/src/blueprint/schema.rs +++ b/src/blueprint/schema.rs @@ -8,88 +8,104 @@ use crate::directive::DirectiveCodec; use crate::valid::{Valid, ValidationError}; fn validate_query(config: &Config) -> Valid<(), String> { - Valid::from_option(config.schema.query.clone(), "Query root is missing".to_owned()) + Valid::from_option( + config.schema.query.clone(), + "Query root is missing".to_owned(), + ) .and_then(|ref 
query_type_name| { - let Some(query) = config.find_type(query_type_name) else { - return Valid::fail("Query type is not defined".to_owned()).trace(query_type_name); - }; + let Some(query) = config.find_type(query_type_name) else { + return Valid::fail("Query type is not defined".to_owned()).trace(query_type_name); + }; - validate_type_has_resolvers(query_type_name, query, &config.types) + validate_type_has_resolvers(query_type_name, query, &config.types) }) .unit() } /// Validates that all the root type fields has resolver /// making into the account the nesting -fn validate_type_has_resolvers(name: &str, ty: &Type, types: &BTreeMap) -> Valid<(), String> { - Valid::from_iter(ty.fields.iter(), |(name, field)| { - validate_field_has_resolver(name, field, types) - }) - .trace(name) - .unit() -} - -pub fn validate_field_has_resolver(name: &str, field: &Field, types: &BTreeMap) -> Valid<(), String> { - Valid::<(), String>::fail("No resolver has been found in the schema".to_owned()) - .when(|| { - if !field.has_resolver() { - let f_type = &field.type_of; - if let Some(ty) = types.get(f_type) { - let res = validate_type_has_resolvers(f_type, ty, types); - return !res.is_succeed(); - } else { - return true; - } - } - false +fn validate_type_has_resolvers( + name: &str, + ty: &Type, + types: &BTreeMap, +) -> Valid<(), String> { + Valid::from_iter(ty.fields.iter(), |(name, field)| { + validate_field_has_resolver(name, field, types) }) .trace(name) + .unit() +} + +pub fn validate_field_has_resolver( + name: &str, + field: &Field, + types: &BTreeMap, +) -> Valid<(), String> { + Valid::<(), String>::fail("No resolver has been found in the schema".to_owned()) + .when(|| { + if !field.has_resolver() { + let f_type = &field.type_of; + if let Some(ty) = types.get(f_type) { + let res = validate_type_has_resolvers(f_type, ty, types); + return !res.is_succeed(); + } else { + return true; + } + } + false + }) + .trace(name) } pub fn to_directive(const_directive: ConstDirective) -> Valid { - const_directive - .arguments - .into_iter() - .map(|(k, v)| { - let value = v.node.into_json(); - if let Ok(value) = value { - return Ok((k.node.to_string(), value)); - } - Err(value.unwrap_err()) - }) - .collect::, _>>() - .map_err(|e| ValidationError::new(e.to_string())) - .map(|arguments| Directive { name: const_directive.name.node.clone().to_string(), arguments, index: 0 }) - .into() + const_directive + .arguments + .into_iter() + .map(|(k, v)| { + let value = v.node.into_json(); + if let Ok(value) = value { + return Ok((k.node.to_string(), value)); + } + Err(value.unwrap_err()) + }) + .collect::, _>>() + .map_err(|e| ValidationError::new(e.to_string())) + .map(|arguments| Directive { + name: const_directive.name.node.clone().to_string(), + arguments, + index: 0, + }) + .into() } fn validate_mutation(config: &Config) -> Valid<(), String> { - let mutation_type_name = config.schema.mutation.as_ref(); + let mutation_type_name = config.schema.mutation.as_ref(); - if let Some(mutation_type_name) = mutation_type_name { - let Some(mutation) = config.find_type(mutation_type_name) else { - return Valid::fail("Mutation type is not defined".to_owned()).trace(mutation_type_name); - }; + if let Some(mutation_type_name) = mutation_type_name { + let Some(mutation) = config.find_type(mutation_type_name) else { + return Valid::fail("Mutation type is not defined".to_owned()) + .trace(mutation_type_name); + }; - validate_type_has_resolvers(mutation_type_name, mutation, &config.types) - } else { - Valid::succeed(()) - } + 
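
validate_type_has_resolvers and validate_field_has_resolver recurse into each other: a root field without its own resolver is still fine if every leaf reachable through its type eventually has one. A stripped-down, std-only version of that mutual recursion — this sketch drops the Valid tracing and, like a naive reading of the original, does not guard against cyclic types:

use std::collections::BTreeMap;

struct Field {
    type_of: String,
    has_resolver: bool,
}

struct Type {
    fields: BTreeMap<String, Field>,
}

fn type_has_resolvers(ty: &Type, types: &BTreeMap<String, Type>) -> bool {
    ty.fields.values().all(|f| field_has_resolver(f, types))
}

fn field_has_resolver(field: &Field, types: &BTreeMap<String, Type>) -> bool {
    if field.has_resolver {
        return true;
    }
    // No resolver here: the field is only valid if its type resolves further down.
    match types.get(&field.type_of) {
        Some(ty) => type_has_resolvers(ty, types),
        None => false, // scalar or unknown type without a resolver
    }
}

fn main() {
    let mut types = BTreeMap::new();
    types.insert(
        "User".to_string(),
        Type {
            fields: BTreeMap::from([(
                "posts".to_string(),
                Field { type_of: "Post".to_string(), has_resolver: true },
            )]),
        },
    );
    let query = Type {
        fields: BTreeMap::from([(
            "user".to_string(),
            Field { type_of: "User".to_string(), has_resolver: false },
        )]),
    };
    // Query.user has no resolver, but User.posts does, so the schema is accepted.
    assert!(type_has_resolvers(&query, &types));
}
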
validate_type_has_resolvers(mutation_type_name, mutation, &config.types) + } else { + Valid::succeed(()) + } } pub fn to_schema<'a>() -> TryFoldConfig<'a, SchemaDefinition> { - TryFoldConfig::new(|config, _| { - validate_query(config) - .and(validate_mutation(config)) - .and(Valid::from_option( - config.schema.query.as_ref(), - "Query root is missing".to_owned(), - )) - .zip(to_directive(config.server.to_directive())) - .map(|(query_type_name, directive)| SchemaDefinition { - query: query_type_name.to_owned(), - mutation: config.schema.mutation.clone(), - directives: vec![directive], - }) - }) + TryFoldConfig::new(|config, _| { + validate_query(config) + .and(validate_mutation(config)) + .and(Valid::from_option( + config.schema.query.as_ref(), + "Query root is missing".to_owned(), + )) + .zip(to_directive(config.server.to_directive())) + .map(|(query_type_name, directive)| SchemaDefinition { + query: query_type_name.to_owned(), + mutation: config.schema.mutation.clone(), + directives: vec![directive], + }) + }) } diff --git a/src/blueprint/server.rs b/src/blueprint/server.rs index 33819a986a1..072f1abc24a 100644 --- a/src/blueprint/server.rs +++ b/src/blueprint/server.rs @@ -10,135 +10,139 @@ use crate::valid::{Valid, ValidationError}; #[derive(Clone, Debug, Setters)] pub struct Server { - pub enable_apollo_tracing: bool, - pub enable_cache_control_header: bool, - pub enable_graphiql: bool, - pub enable_introspection: bool, - pub enable_query_validation: bool, - pub enable_response_validation: bool, - pub enable_batch_requests: bool, - pub enable_showcase: bool, - pub global_response_timeout: i64, - pub worker: usize, - pub port: u16, - pub hostname: IpAddr, - pub vars: BTreeMap, - pub response_headers: HeaderMap, - pub http: Http, - pub pipeline_flush: bool, + pub enable_apollo_tracing: bool, + pub enable_cache_control_header: bool, + pub enable_graphiql: bool, + pub enable_introspection: bool, + pub enable_query_validation: bool, + pub enable_response_validation: bool, + pub enable_batch_requests: bool, + pub enable_showcase: bool, + pub global_response_timeout: i64, + pub worker: usize, + pub port: u16, + pub hostname: IpAddr, + pub vars: BTreeMap, + pub response_headers: HeaderMap, + pub http: Http, + pub pipeline_flush: bool, } #[derive(Clone, Debug)] pub enum Http { - HTTP1, - HTTP2 { cert: String, key: String }, + HTTP1, + HTTP2 { cert: String, key: String }, } impl Default for Server { - fn default() -> Self { - // NOTE: Using unwrap because try_from default will never fail - Server::try_from(config::Server::default()).unwrap() - } + fn default() -> Self { + // NOTE: Using unwrap because try_from default will never fail + Server::try_from(config::Server::default()).unwrap() + } } impl Server { - pub fn get_enable_http_validation(&self) -> bool { - self.enable_response_validation - } - pub fn get_enable_cache_control(&self) -> bool { - self.enable_cache_control_header - } - - pub fn get_enable_introspection(&self) -> bool { - self.enable_introspection - } - - pub fn get_enable_query_validation(&self) -> bool { - self.enable_query_validation - } + pub fn get_enable_http_validation(&self) -> bool { + self.enable_response_validation + } + pub fn get_enable_cache_control(&self) -> bool { + self.enable_cache_control_header + } + + pub fn get_enable_introspection(&self) -> bool { + self.enable_introspection + } + + pub fn get_enable_query_validation(&self) -> bool { + self.enable_query_validation + } } impl TryFrom for Server { - type Error = ValidationError; - - fn 
try_from(config_server: config::Server) -> Result { - let http_server = match config_server.clone().get_version() { - HttpVersion::HTTP2 => { - let cert = Valid::from_option( - config_server.cert.clone(), - "Certificate is required for HTTP2".to_string(), - ); - let key = Valid::from_option(config_server.key.clone(), "Key is required for HTTP2".to_string()); - - cert.zip(key).map(|(cert, key)| Http::HTTP2 { cert, key }) - } - _ => Valid::succeed(Http::HTTP1), - }; - - validate_hostname((config_server).get_hostname().to_lowercase()) - .zip(http_server) - .zip(handle_response_headers((config_server).get_response_headers().0)) - .map(|((hostname, http), response_headers)| Server { - enable_apollo_tracing: (config_server).enable_apollo_tracing(), - enable_cache_control_header: (config_server).enable_cache_control(), - enable_graphiql: (config_server).enable_graphiql(), - enable_introspection: (config_server).enable_introspection(), - enable_query_validation: (config_server).enable_query_validation(), - enable_response_validation: (config_server).enable_http_validation(), - enable_batch_requests: (config_server).enable_batch_requests(), - enable_showcase: (config_server).enable_showcase(), - global_response_timeout: (config_server).get_global_response_timeout(), - http, - worker: (config_server).get_workers(), - port: (config_server).get_port(), - hostname, - vars: (config_server).get_vars(), - pipeline_flush: (config_server).get_pipeline_flush(), - response_headers, - }) - .to_result() - } + type Error = ValidationError; + + fn try_from(config_server: config::Server) -> Result { + let http_server = match config_server.clone().get_version() { + HttpVersion::HTTP2 => { + let cert = Valid::from_option( + config_server.cert.clone(), + "Certificate is required for HTTP2".to_string(), + ); + let key = Valid::from_option( + config_server.key.clone(), + "Key is required for HTTP2".to_string(), + ); + + cert.zip(key).map(|(cert, key)| Http::HTTP2 { cert, key }) + } + _ => Valid::succeed(Http::HTTP1), + }; + + validate_hostname((config_server).get_hostname().to_lowercase()) + .zip(http_server) + .zip(handle_response_headers( + (config_server).get_response_headers().0, + )) + .map(|((hostname, http), response_headers)| Server { + enable_apollo_tracing: (config_server).enable_apollo_tracing(), + enable_cache_control_header: (config_server).enable_cache_control(), + enable_graphiql: (config_server).enable_graphiql(), + enable_introspection: (config_server).enable_introspection(), + enable_query_validation: (config_server).enable_query_validation(), + enable_response_validation: (config_server).enable_http_validation(), + enable_batch_requests: (config_server).enable_batch_requests(), + enable_showcase: (config_server).enable_showcase(), + global_response_timeout: (config_server).get_global_response_timeout(), + http, + worker: (config_server).get_workers(), + port: (config_server).get_port(), + hostname, + vars: (config_server).get_vars(), + pipeline_flush: (config_server).get_pipeline_flush(), + response_headers, + }) + .to_result() + } } fn validate_hostname(hostname: String) -> Valid { - if hostname == "localhost" { - Valid::succeed(IpAddr::from([127, 0, 0, 1])) - } else { - Valid::from( - hostname - .parse() - .map_err(|e: AddrParseError| ValidationError::new(format!("Parsing failed because of {}", e))), - ) - .trace("hostname") - .trace("@server") - .trace("schema") - } + if hostname == "localhost" { + Valid::succeed(IpAddr::from([127, 0, 0, 1])) + } else { + Valid::from(hostname.parse().map_err(|e: 
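
In the Server conversion, the HTTP/2 branch is another from_option + zip: cert and key are each validated independently, so a config missing both reports both. The decision in isolation, with Http mirroring the blueprint enum above:

#[derive(Debug)]
enum Http {
    Http1,
    Http2 { cert: String, key: String },
}

fn http_config(
    is_http2: bool,
    cert: Option<String>,
    key: Option<String>,
) -> Result<Http, Vec<String>> {
    if !is_http2 {
        return Ok(Http::Http1);
    }
    match (cert, key) {
        (Some(cert), Some(key)) => Ok(Http::Http2 { cert, key }),
        (cert, key) => {
            // Accumulate rather than short-circuit, like Valid::zip does.
            let mut errors = Vec::new();
            if cert.is_none() {
                errors.push("Certificate is required for HTTP2".to_string());
            }
            if key.is_none() {
                errors.push("Key is required for HTTP2".to_string());
            }
            Err(errors)
        }
    }
}

fn main() {
    let errs = http_config(true, None, None).unwrap_err();
    assert_eq!(errs.len(), 2);
    assert!(matches!(http_config(false, None, None), Ok(Http::Http1)));
}
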
AddrParseError| { + ValidationError::new(format!("Parsing failed because of {}", e)) + })) + .trace("hostname") + .trace("@server") + .trace("schema") + } } fn handle_response_headers(resp_headers: BTreeMap) -> Valid { - Valid::from_iter(resp_headers.iter(), |(k, v)| { - let name = Valid::from( - HeaderName::from_bytes(k.as_bytes()) - .map_err(|e| ValidationError::new(format!("Parsing failed because of {}", e))), - ); - let value = Valid::from( - HeaderValue::from_str(v.as_str()).map_err(|e| ValidationError::new(format!("Parsing failed because of {}", e))), - ); - name.zip(value) - }) - .map(|headers| headers.into_iter().collect::()) - .trace("responseHeaders") - .trace("@server") - .trace("schema") + Valid::from_iter(resp_headers.iter(), |(k, v)| { + let name = Valid::from( + HeaderName::from_bytes(k.as_bytes()) + .map_err(|e| ValidationError::new(format!("Parsing failed because of {}", e))), + ); + let value = Valid::from( + HeaderValue::from_str(v.as_str()) + .map_err(|e| ValidationError::new(format!("Parsing failed because of {}", e))), + ); + name.zip(value) + }) + .map(|headers| headers.into_iter().collect::()) + .trace("responseHeaders") + .trace("@server") + .trace("schema") } #[cfg(test)] mod tests { - use crate::config; + use crate::config; - #[test] - fn test_try_from_default() { - let actual = super::Server::try_from(config::Server::default()); - assert!(actual.is_ok()) - } + #[test] + fn test_try_from_default() { + let actual = super::Server::try_from(config::Server::default()); + assert!(actual.is_ok()) + } } diff --git a/src/blueprint/timeout.rs b/src/blueprint/timeout.rs index 4e18a47e72a..7f7318371f0 100644 --- a/src/blueprint/timeout.rs +++ b/src/blueprint/timeout.rs @@ -9,34 +9,40 @@ use tokio::time::timeout; pub struct GlobalTimeout; impl ExtensionFactory for GlobalTimeout { - fn create(&self) -> Arc { - Arc::new(GlobalTimeoutExtension) - } + fn create(&self) -> Arc { + Arc::new(GlobalTimeoutExtension) + } } struct GlobalTimeoutExtension; #[async_trait::async_trait] impl Extension for GlobalTimeoutExtension { - async fn execute(&self, ctx: &ExtensionContext<'_>, operation_name: Option<&str>, next: NextExecute<'_>) -> Response { - let future = next.run(ctx, operation_name); - if let ConstValue::Number(number) = ctx.data_unchecked::() { - let timeout_duration = number.as_u64().unwrap_or(0); - if timeout_duration > 0 { - let result = timeout(Duration::from_millis(timeout_duration), future).await; - match result { - Ok(result) => result, - Err(_) => { - let mut response = Response::new(ConstValue::Null); - response.errors = vec![ServerError::new("Global timeout".to_string(), None)]; - response - } + async fn execute( + &self, + ctx: &ExtensionContext<'_>, + operation_name: Option<&str>, + next: NextExecute<'_>, + ) -> Response { + let future = next.run(ctx, operation_name); + if let ConstValue::Number(number) = ctx.data_unchecked::() { + let timeout_duration = number.as_u64().unwrap_or(0); + if timeout_duration > 0 { + let result = timeout(Duration::from_millis(timeout_duration), future).await; + match result { + Ok(result) => result, + Err(_) => { + let mut response = Response::new(ConstValue::Null); + response.errors = + vec![ServerError::new("Global timeout".to_string(), None)]; + response + } + } + } else { + future.await + } + } else { + future.await } - } else { - future.await - } - } else { - future.await } - } } diff --git a/src/blueprint/upstream.rs b/src/blueprint/upstream.rs index aaea2161447..88bec02d341 100644 --- a/src/blueprint/upstream.rs +++ 
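
GlobalTimeoutExtension is a thin wrapper over tokio::time::timeout: a timeout of 0 means "no limit", otherwise the execution future races the clock and an elapsed timer maps to a fallback response. The same pattern, detached from async-graphql:

use std::future::Future;
use std::time::Duration;

use tokio::time::timeout;

// Resolves `fut`, or returns `fallback` if it takes longer than `ms` milliseconds.
// A zero timeout disables the limit, as in the extension above.
async fn with_global_timeout<T>(ms: u64, fut: impl Future<Output = T>, fallback: T) -> T {
    if ms == 0 {
        return fut.await;
    }
    match timeout(Duration::from_millis(ms), fut).await {
        Ok(value) => value,
        Err(_elapsed) => fallback,
    }
}

#[tokio::main]
async fn main() {
    let slow = async {
        tokio::time::sleep(Duration::from_secs(5)).await;
        "real result"
    };
    let result = with_global_timeout(10, slow, "Global timeout").await;
    assert_eq!(result, "Global timeout");
}
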
b/src/blueprint/upstream.rs @@ -4,13 +4,15 @@ use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; pub fn to_upstream<'a>() -> TryFold<'a, Config, Upstream, String> { - TryFoldConfig::::new(|config, up| { - let upstream = up.merge_right(config.upstream.clone()); - if let Some(ref base_url) = upstream.base_url { - Valid::from(reqwest::Url::parse(base_url).map_err(|e| ValidationError::new(e.to_string()))) - .map_to(upstream.clone()) - } else { - Valid::succeed(upstream.clone()) - } - }) + TryFoldConfig::::new(|config, up| { + let upstream = up.merge_right(config.upstream.clone()); + if let Some(ref base_url) = upstream.base_url { + Valid::from( + reqwest::Url::parse(base_url).map_err(|e| ValidationError::new(e.to_string())), + ) + .map_to(upstream.clone()) + } else { + Valid::succeed(upstream.clone()) + } + }) } diff --git a/src/cache.rs b/src/cache.rs index 078477ed7f0..c71e92df959 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -5,20 +5,20 @@ pub struct Cache(Mutex>); impl Cache where - K: std::cmp::Eq, - K: PartialEq, - K: core::hash::Hash, - V: std::clone::Clone, + K: std::cmp::Eq, + K: PartialEq, + K: core::hash::Hash, + V: std::clone::Clone, { - pub fn get(&self, key: &K) -> Option { - self.0.lock().unwrap().get(key).cloned() - } + pub fn get(&self, key: &K) -> Option { + self.0.lock().unwrap().get(key).cloned() + } - pub fn insert(&self, key: K, value: V) { - self.0.lock().unwrap().insert(key, value); - } + pub fn insert(&self, key: K, value: V) { + self.0.lock().unwrap().insert(key, value); + } - pub fn empty() -> Self { - Self(Mutex::new(HashMap::new())) - } + pub fn empty() -> Self { + Self(Mutex::new(HashMap::new())) + } } diff --git a/src/cli/cache.rs b/src/cli/cache.rs index 41ffea92fd8..99d90f74b3b 100644 --- a/src/cli/cache.rs +++ b/src/cli/cache.rs @@ -11,42 +11,40 @@ use crate::Cache; const CACHE_CAPACITY: usize = 100000; pub struct NativeChronoCache { - data: Arc>>, + data: Arc>>, } impl Default for NativeChronoCache { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl NativeChronoCache { - pub fn new() -> Self { - NativeChronoCache { data: Arc::new(RwLock::new(TtlCache::new(CACHE_CAPACITY))) } - } + pub fn new() -> Self { + NativeChronoCache { data: Arc::new(RwLock::new(TtlCache::new(CACHE_CAPACITY))) } + } } #[async_trait::async_trait] impl Cache for NativeChronoCache { - type Key = K; - type Value = V; - #[allow(clippy::too_many_arguments)] - async fn set<'a>(&'a self, key: K, value: V, ttl: NonZeroU64) -> Result { - let ttl = Duration::from_millis(ttl.get()); - self - .data - .write() - .unwrap() - .insert(key, value, ttl) - .ok_or(anyhow!("unable to insert value")) - } + type Key = K; + type Value = V; + #[allow(clippy::too_many_arguments)] + async fn set<'a>(&'a self, key: K, value: V, ttl: NonZeroU64) -> Result { + let ttl = Duration::from_millis(ttl.get()); + self.data + .write() + .unwrap() + .insert(key, value, ttl) + .ok_or(anyhow!("unable to insert value")) + } - async fn get<'a>(&'a self, key: &'a K) -> Result { - self - .data - .read() - .unwrap() - .get(key) - .cloned() - .ok_or(anyhow!("key not found")) - } + async fn get<'a>(&'a self, key: &'a K) -> Result { + self.data + .read() + .unwrap() + .get(key) + .cloned() + .ok_or(anyhow!("key not found")) + } } diff --git a/src/cli/command.rs b/src/cli/command.rs index b7cba960729..d1a83c7ded5 100644 --- a/src/cli/command.rs +++ b/src/cli/command.rs @@ -3,8 +3,8 @@ use clap::{Parser, Subcommand}; use crate::config::Source; const VERSION: &str = match 
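
NativeChronoCache wraps a capacity-bounded TTL cache in Arc<RwLock<...>> so reads and writes can happen from any thread through &self, and get clones values out so the lock guard never escapes. A std-only sketch of the same idea with expiry tracked via Instant — the real implementation delegates capacity and eviction to the ttl_cache crate:

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::RwLock;
use std::time::{Duration, Instant};

struct TtlMap<K, V> {
    data: RwLock<HashMap<K, (V, Instant)>>,
}

impl<K: Eq + Hash, V: Clone> TtlMap<K, V> {
    fn new() -> Self {
        TtlMap { data: RwLock::new(HashMap::new()) }
    }

    fn set(&self, key: K, value: V, ttl: Duration) {
        let expires_at = Instant::now() + ttl;
        self.data.write().unwrap().insert(key, (value, expires_at));
    }

    // Values are cloned out so the read guard is released before returning.
    fn get(&self, key: &K) -> Option<V> {
        let guard = self.data.read().unwrap();
        match guard.get(key) {
            Some((value, expires_at)) if *expires_at > Instant::now() => Some(value.clone()),
            _ => None, // missing or expired
        }
    }
}

fn main() {
    let cache: TtlMap<&str, u32> = TtlMap::new();
    cache.set("answer", 42, Duration::from_secs(60));
    assert_eq!(cache.get(&"answer"), Some(42));
}
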
option_env!("APP_VERSION") { - Some(version) => version, - _ => "0.1.0-dev", + Some(version) => version, + _ => "0.1.0-dev", }; const ABOUT: &str = r" __ _ __ ____ @@ -16,53 +16,53 @@ const ABOUT: &str = r" #[derive(Parser)] #[command(name ="tailcall",author, version = VERSION, about, long_about = Some(ABOUT))] pub struct Cli { - #[command(subcommand)] - pub command: Command, + #[command(subcommand)] + pub command: Command, } #[derive(Subcommand)] pub enum Command { - /// Starts the GraphQL server on the configured port - Start { - /// Path for the configuration files or http(s) link to config files separated by spaces if more than one - #[arg(required = true)] - file_paths: Vec, - }, + /// Starts the GraphQL server on the configured port + Start { + /// Path for the configuration files or http(s) link to config files separated by spaces if more than one + #[arg(required = true)] + file_paths: Vec, + }, - /// Validate a composition spec - Check { - /// Path for the configuration files separated by spaces if more than one - #[arg(required = true)] - file_paths: Vec, + /// Validate a composition spec + Check { + /// Path for the configuration files separated by spaces if more than one + #[arg(required = true)] + file_paths: Vec, - /// N plus one queries - #[arg(short, long)] - n_plus_one_queries: bool, + /// N plus one queries + #[arg(short, long)] + n_plus_one_queries: bool, - /// Display schema - #[arg(short, long)] - schema: bool, + /// Display schema + #[arg(short, long)] + schema: bool, - /// Operations to check - #[arg(short, long, value_delimiter=',', num_args = 1..)] - operations: Vec, - }, + /// Operations to check + #[arg(short, long, value_delimiter=',', num_args = 1..)] + operations: Vec, + }, - /// Merge multiple configuration file into one - Compose { - /// Path for the configuration files separated by spaces if more than one - #[arg(required = true)] - file_paths: Vec, + /// Merge multiple configuration file into one + Compose { + /// Path for the configuration files separated by spaces if more than one + #[arg(required = true)] + file_paths: Vec, - /// Format of the result. Accepted values: JSON|YML|GQL. - #[clap(short, long, default_value = "gql")] - format: Source, - }, + /// Format of the result. Accepted values: JSON|YML|GQL. 
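
Command mirrors the CLI surface one-to-one through clap's derive API; the one subtle flag is operations, where value_delimiter = ',' with num_args = 1.. lets users pass either repeated -o flags or a single comma-separated list. A cut-down version of just that subcommand, sketched against clap 4's derive feature with illustrative argument values:

use clap::{Parser, Subcommand};

#[derive(Parser)]
#[command(name = "demo")]
struct Cli {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    Check {
        #[arg(required = true)]
        file_paths: Vec<String>,

        #[arg(short, long, value_delimiter = ',', num_args = 1..)]
        operations: Vec<String>,
    },
}

fn main() {
    // `demo check schema.graphql -o getUsers,getPosts` and
    // `demo check schema.graphql -o getUsers -o getPosts` parse identically.
    let cli = Cli::parse_from(["demo", "check", "schema.graphql", "-o", "getUsers,getPosts"]);
    let Command::Check { operations, .. } = cli.command;
    assert_eq!(operations, ["getUsers", "getPosts"]);
}
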
+ #[clap(short, long, default_value = "gql")] + format: Source, + }, - /// Initialize a new project - Init { - // default is current directory - #[arg(default_value = ".")] - folder_path: String, - }, + /// Initialize a new project + Init { + // default is current directory + #[arg(default_value = ".")] + folder_path: String, + }, } diff --git a/src/cli/env.rs b/src/cli/env.rs index 9002af685de..76f0101954e 100644 --- a/src/cli/env.rs +++ b/src/cli/env.rs @@ -4,17 +4,17 @@ use crate::EnvIO; #[derive(Clone)] pub struct EnvNative { - vars: HashMap, + vars: HashMap, } impl EnvIO for EnvNative { - fn get(&self, key: &str) -> Option { - self.vars.get(key).cloned() - } + fn get(&self, key: &str) -> Option { + self.vars.get(key).cloned() + } } impl EnvNative { - pub fn init() -> Self { - Self { vars: std::env::vars().collect() } - } + pub fn init() -> Self { + Self { vars: std::env::vars().collect() } + } } diff --git a/src/cli/error.rs b/src/cli/error.rs index 38450d2cc26..7602a72205a 100644 --- a/src/cli/error.rs +++ b/src/cli/error.rs @@ -8,325 +8,341 @@ use crate::valid::ValidationError; #[derive(Debug, Error, Setters)] pub struct CLIError { - is_root: bool, - #[setters(skip)] - color: bool, - message: String, - #[setters(strip_option)] - description: Option, - trace: Vec, - - #[setters(skip)] - caused_by: Vec, + is_root: bool, + #[setters(skip)] + color: bool, + message: String, + #[setters(strip_option)] + description: Option, + trace: Vec, + + #[setters(skip)] + caused_by: Vec, } impl CLIError { - pub fn new(message: &str) -> Self { - CLIError { - is_root: true, - color: false, - message: message.to_string(), - description: Default::default(), - trace: Default::default(), - caused_by: Default::default(), + pub fn new(message: &str) -> Self { + CLIError { + is_root: true, + color: false, + message: message.to_string(), + description: Default::default(), + trace: Default::default(), + caused_by: Default::default(), + } } - } - pub fn caused_by(mut self, error: Vec) -> Self { - self.caused_by = error; + pub fn caused_by(mut self, error: Vec) -> Self { + self.caused_by = error; - for error in self.caused_by.iter_mut() { - error.is_root = false; - } + for error in self.caused_by.iter_mut() { + error.is_root = false; + } - self - } + self + } - fn colored<'a>(&'a self, str: &'a str, color: colored::Color) -> String { - if self.color { - str.color(color).to_string() - } else { - str.to_string() + fn colored<'a>(&'a self, str: &'a str, color: colored::Color) -> String { + if self.color { + str.color(color).to_string() + } else { + str.to_string() + } } - } - fn dimmed<'a>(&'a self, str: &'a str) -> String { - if self.color { - str.dimmed().to_string() - } else { - str.to_string() + fn dimmed<'a>(&'a self, str: &'a str) -> String { + if self.color { + str.dimmed().to_string() + } else { + str.to_string() + } } - } - pub fn color(mut self, color: bool) -> Self { - self.color = color; - for inner in self.caused_by.iter_mut() { - inner.color = color; + pub fn color(mut self, color: bool) -> Self { + self.color = color; + for inner in self.caused_by.iter_mut() { + inner.color = color; + } + self } - self - } } fn margin(str: &str, margin: usize) -> String { - let mut result = String::new(); - for line in str.split_inclusive('\n') { - result.push_str(&format!("{}{}", " ".repeat(margin), line)); - } - result + let mut result = String::new(); + for line in str.split_inclusive('\n') { + result.push_str(&format!("{}{}", " ".repeat(margin), line)); + } + result } fn bullet(str: &str) -> String { - let mut 
chars = margin(str, 2).chars().collect::>(); - chars[0] = '•'; - chars[1] = ' '; - chars.into_iter().collect::() + let mut chars = margin(str, 2).chars().collect::>(); + chars[0] = '•'; + chars[1] = ' '; + chars.into_iter().collect::() } impl Display for CLIError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let error_prefix = "Error: "; - let default_padding = 2; - let root_padding_size = if self.is_root { - error_prefix.len() - } else { - default_padding - }; - - if self.is_root { - f.write_str(self.colored(error_prefix, colored::Color::Red).as_str())?; - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let error_prefix = "Error: "; + let default_padding = 2; + let root_padding_size = if self.is_root { + error_prefix.len() + } else { + default_padding + }; + + if self.is_root { + f.write_str(self.colored(error_prefix, colored::Color::Red).as_str())?; + } - f.write_str(&self.message.to_string())?; - - if let Some(description) = &self.description { - f.write_str("\n")?; - let color = if self.is_root { - colored::Color::Yellow - } else { - colored::Color::White - }; - f.write_str( - margin( - &self.colored(format!("❯ {}", description).as_str(), color), - root_padding_size, - ) - .as_str(), - )?; - } + f.write_str(&self.message.to_string())?; + + if let Some(description) = &self.description { + f.write_str("\n")?; + let color = if self.is_root { + colored::Color::Yellow + } else { + colored::Color::White + }; + f.write_str( + margin( + &self.colored(format!("❯ {}", description).as_str(), color), + root_padding_size, + ) + .as_str(), + )?; + } - if !self.trace.is_empty() { - let mut buf = String::new(); - buf.push_str(" [at "); - let len = self.trace.len(); - for (i, trace) in self.trace.iter().enumerate() { - buf.push_str(&trace.to_string()); - if i < len - 1 { - buf.push('.'); + if !self.trace.is_empty() { + let mut buf = String::new(); + buf.push_str(" [at "); + let len = self.trace.len(); + for (i, trace) in self.trace.iter().enumerate() { + buf.push_str(&trace.to_string()); + if i < len - 1 { + buf.push('.'); + } + } + buf.push(']'); + f.write_str(&self.colored(&buf, colored::Color::Cyan))?; } - } - buf.push(']'); - f.write_str(&self.colored(&buf, colored::Color::Cyan))?; - } - if !self.caused_by.is_empty() { - f.write_str(self.dimmed("\nCaused by:\n").as_str())?; - for (i, error) in self.caused_by.iter().enumerate() { - let message = &error.to_string(); - f.write_str(&margin(bullet(message.as_str()).as_str(), default_padding))?; + if !self.caused_by.is_empty() { + f.write_str(self.dimmed("\nCaused by:\n").as_str())?; + for (i, error) in self.caused_by.iter().enumerate() { + let message = &error.to_string(); + f.write_str(&margin(bullet(message.as_str()).as_str(), default_padding))?; - if i < self.caused_by.len() - 1 { - f.write_str("\n")?; + if i < self.caused_by.len() - 1 { + f.write_str("\n")?; + } + } } - } - } - Ok(()) - } + Ok(()) + } } impl From for CLIError { - fn from(error: hyper::Error) -> Self { - // TODO: add type-safety to CLIError conversion - let cli_error = CLIError::new("Server Failed"); - let message = error.to_string(); - if message.to_lowercase().contains("os error 48") { - cli_error - .description("The port is already in use".to_string()) - .caused_by(vec![CLIError::new(message.as_str())]) - } else { - cli_error.description(message) + fn from(error: hyper::Error) -> Self { + // TODO: add type-safety to CLIError conversion + let cli_error = CLIError::new("Server Failed"); + let message = error.to_string(); + if 
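
The Display impl builds the " [at a.b.c]" suffix with an index-tracked loop; functionally it is just a join over the trace segments. A one-liner equivalent, for reference only:

fn render_trace(trace: &[String]) -> String {
    if trace.is_empty() {
        return String::new();
    }
    format!(" [at {}]", trace.join("."))
}

fn main() {
    let trace = vec![
        "User".to_string(),
        "posts".to_string(),
        "@http".to_string(),
        "baseURL".to_string(),
    ];
    assert_eq!(render_trace(&trace), " [at User.posts.@http.baseURL]");
}
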
message.to_lowercase().contains("os error 48") { + cli_error + .description("The port is already in use".to_string()) + .caused_by(vec![CLIError::new(message.as_str())]) + } else { + cli_error.description(message) + } } - } } impl From for CLIError { - fn from(error: rustls::Error) -> Self { - let cli_error = CLIError::new("Failed to create TLS Acceptor"); - let message = error.to_string(); + fn from(error: rustls::Error) -> Self { + let cli_error = CLIError::new("Failed to create TLS Acceptor"); + let message = error.to_string(); - cli_error.description(message) - } + cli_error.description(message) + } } impl From for CLIError { - fn from(error: std::io::Error) -> Self { - let cli_error = CLIError::new("IO Error"); - let message = error.to_string(); + fn from(error: std::io::Error) -> Self { + let cli_error = CLIError::new("IO Error"); + let message = error.to_string(); - cli_error.description(message) - } + cli_error.description(message) + } } impl<'a> From> for CLIError { - fn from(error: ValidationError<&'a str>) -> Self { - CLIError::new("Invalid Configuration").caused_by( - error - .as_vec() - .iter() - .map(|cause| { - let mut err = CLIError::new(cause.message).trace(Vec::from(cause.trace.clone())); - if let Some(description) = cause.description { - err = err.description(description.to_owned()); - } - err - }) - .collect(), - ) - } + fn from(error: ValidationError<&'a str>) -> Self { + CLIError::new("Invalid Configuration").caused_by( + error + .as_vec() + .iter() + .map(|cause| { + let mut err = + CLIError::new(cause.message).trace(Vec::from(cause.trace.clone())); + if let Some(description) = cause.description { + err = err.description(description.to_owned()); + } + err + }) + .collect(), + ) + } } impl From> for CLIError { - fn from(error: ValidationError) -> Self { - CLIError::new("Invalid Configuration").caused_by( - error - .as_vec() - .iter() - .map(|cause| CLIError::new(cause.message.as_str()).trace(Vec::from(cause.trace.clone()))) - .collect(), - ) - } + fn from(error: ValidationError) -> Self { + CLIError::new("Invalid Configuration").caused_by( + error + .as_vec() + .iter() + .map(|cause| { + CLIError::new(cause.message.as_str()).trace(Vec::from(cause.trace.clone())) + }) + .collect(), + ) + } } impl From> for CLIError { - fn from(value: Box) -> Self { - CLIError::new(value.to_string().as_str()) - } + fn from(value: Box) -> Self { + CLIError::new(value.to_string().as_str()) + } } #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; - use stripmargin::StripMargin; - - use super::*; - use crate::valid::Cause; - - #[test] - fn test_no_newline() { - let input = "Hello"; - let expected = " Hello"; - assert_eq!(margin(input, 4), expected); - } - - #[test] - fn test_with_newline() { - let input = "Hello\nWorld"; - let expected = " Hello\n World"; - assert_eq!(margin(input, 4), expected); - } - - #[test] - fn test_empty_string() { - let input = ""; - let expected = ""; - assert_eq!(margin(input, 4), expected); - } - - #[test] - fn test_zero_margin() { - let input = "Hello"; - let expected = "Hello"; - assert_eq!(margin(input, 0), expected); - } - - #[test] - fn test_zero_margin_with_newline() { - let input = "Hello\nWorld"; - let expected = "Hello\nWorld"; - assert_eq!(margin(input, 0), expected); - } - - #[test] - fn test_title() { - let error = CLIError::new("Server could not be started"); - let expected = r"Error: Server could not be started".strip_margin(); - assert_eq!(error.to_string(), expected); - } - - #[test] - fn test_title_description() { - let error = 
CLIError::new("Server could not be started").description("The port is already in use".to_string()); - let expected = r"|Error: Server could not be started + use pretty_assertions::assert_eq; + use stripmargin::StripMargin; + + use super::*; + use crate::valid::Cause; + + #[test] + fn test_no_newline() { + let input = "Hello"; + let expected = " Hello"; + assert_eq!(margin(input, 4), expected); + } + + #[test] + fn test_with_newline() { + let input = "Hello\nWorld"; + let expected = " Hello\n World"; + assert_eq!(margin(input, 4), expected); + } + + #[test] + fn test_empty_string() { + let input = ""; + let expected = ""; + assert_eq!(margin(input, 4), expected); + } + + #[test] + fn test_zero_margin() { + let input = "Hello"; + let expected = "Hello"; + assert_eq!(margin(input, 0), expected); + } + + #[test] + fn test_zero_margin_with_newline() { + let input = "Hello\nWorld"; + let expected = "Hello\nWorld"; + assert_eq!(margin(input, 0), expected); + } + + #[test] + fn test_title() { + let error = CLIError::new("Server could not be started"); + let expected = r"Error: Server could not be started".strip_margin(); + assert_eq!(error.to_string(), expected); + } + + #[test] + fn test_title_description() { + let error = CLIError::new("Server could not be started") + .description("The port is already in use".to_string()); + let expected = r"|Error: Server could not be started | ❯ The port is already in use" - .strip_margin(); + .strip_margin(); - assert_eq!(error.to_string(), expected); - } + assert_eq!(error.to_string(), expected); + } - #[test] - fn test_title_description_trace() { - let error = CLIError::new("Server could not be started") - .description("The port is already in use".to_string()) - .trace(vec!["@server".into(), "port".into()]); + #[test] + fn test_title_description_trace() { + let error = CLIError::new("Server could not be started") + .description("The port is already in use".to_string()) + .trace(vec!["@server".into(), "port".into()]); - let expected = r"|Error: Server could not be started + let expected = r"|Error: Server could not be started | ❯ The port is already in use [at @server.port]" - .strip_margin(); + .strip_margin(); - assert_eq!(error.to_string(), expected); - } - - #[test] - fn test_title_trace_caused_by() { - let error = CLIError::new("Configuration Error").caused_by(vec![CLIError::new("Base URL needs to be specified") - .trace(vec!["User".into(), "posts".into(), "@http".into(), "baseURL".into()])]); + assert_eq!(error.to_string(), expected); + } - let expected = r"|Error: Configuration Error + #[test] + fn test_title_trace_caused_by() { + let error = CLIError::new("Configuration Error").caused_by(vec![CLIError::new( + "Base URL needs to be specified", + ) + .trace(vec![ + "User".into(), + "posts".into(), + "@http".into(), + "baseURL".into(), + ])]); + + let expected = r"|Error: Configuration Error |Caused by: | • Base URL needs to be specified [at User.posts.@http.baseURL]" - .strip_margin(); + .strip_margin(); + + assert_eq!(error.to_string(), expected); + } - assert_eq!(error.to_string(), expected); - } - - #[test] - fn test_title_trace_multiple_caused_by() { - let error = CLIError::new("Configuration Error").caused_by(vec![ - CLIError::new("Base URL needs to be specified").trace(vec![ - "User".into(), - "posts".into(), - "@http".into(), - "baseURL".into(), - ]), - CLIError::new("Base URL needs to be specified").trace(vec![ - "Post".into(), - "users".into(), - "@http".into(), - "baseURL".into(), - ]), - CLIError::new("Base URL needs to be specified") - 
.description("Set `baseURL` in @http or @server directives".into()) - .trace(vec!["Query".into(), "users".into(), "@http".into(), "baseURL".into()]), - CLIError::new("Base URL needs to be specified").trace(vec![ - "Query".into(), - "posts".into(), - "@http".into(), - "baseURL".into(), - ]), - ]); - - let expected = r"|Error: Configuration Error + #[test] + fn test_title_trace_multiple_caused_by() { + let error = CLIError::new("Configuration Error").caused_by(vec![ + CLIError::new("Base URL needs to be specified").trace(vec![ + "User".into(), + "posts".into(), + "@http".into(), + "baseURL".into(), + ]), + CLIError::new("Base URL needs to be specified").trace(vec![ + "Post".into(), + "users".into(), + "@http".into(), + "baseURL".into(), + ]), + CLIError::new("Base URL needs to be specified") + .description("Set `baseURL` in @http or @server directives".into()) + .trace(vec![ + "Query".into(), + "users".into(), + "@http".into(), + "baseURL".into(), + ]), + CLIError::new("Base URL needs to be specified").trace(vec![ + "Query".into(), + "posts".into(), + "@http".into(), + "baseURL".into(), + ]), + ]); + + let expected = r"|Error: Configuration Error |Caused by: | • Base URL needs to be specified [at User.posts.@http.baseURL] | • Base URL needs to be specified [at Post.users.@http.baseURL] @@ -335,22 +351,22 @@ mod tests { | • Base URL needs to be specified [at Query.posts.@http.baseURL]" .strip_margin(); - assert_eq!(error.to_string(), expected); - } - - #[test] - fn test_from_validation() { - let cause = Cause::new("Base URL needs to be specified") - .description("Set `baseURL` in @http or @server directives") - .trace(vec!["Query", "users", "@http", "baseURL"]); - let valid = ValidationError::from(cause); - let error = CLIError::from(valid); - let expected = r"|Error: Invalid Configuration + assert_eq!(error.to_string(), expected); + } + + #[test] + fn test_from_validation() { + let cause = Cause::new("Base URL needs to be specified") + .description("Set `baseURL` in @http or @server directives") + .trace(vec!["Query", "users", "@http", "baseURL"]); + let valid = ValidationError::from(cause); + let error = CLIError::from(valid); + let expected = r"|Error: Invalid Configuration |Caused by: | • Base URL needs to be specified | ❯ Set `baseURL` in @http or @server directives [at Query.users.@http.baseURL]" .strip_margin(); - assert_eq!(error.to_string(), expected); - } + assert_eq!(error.to_string(), expected); + } } diff --git a/src/cli/file.rs b/src/cli/file.rs index bdde83e95e4..50f2d649b06 100644 --- a/src/cli/file.rs +++ b/src/cli/file.rs @@ -8,24 +8,26 @@ use crate::FileIO; pub struct NativeFileIO {} impl NativeFileIO { - pub fn init() -> Self { - NativeFileIO {} - } + pub fn init() -> Self { + NativeFileIO {} + } } impl FileIO for NativeFileIO { - async fn write<'a>(&'a self, file_path: &'a str, content: &'a [u8]) -> Result<()> { - let mut file = tokio::fs::File::create(file_path).await?; - file.write_all(content).await.map_err(CLIError::from)?; - log::info!("File write: {} ... ok", file_path); - Ok(()) - } + async fn write<'a>(&'a self, file_path: &'a str, content: &'a [u8]) -> Result<()> { + let mut file = tokio::fs::File::create(file_path).await?; + file.write_all(content).await.map_err(CLIError::from)?; + log::info!("File write: {} ... 
ok", file_path); + Ok(()) + } - async fn read<'a>(&'a self, file_path: &'a str) -> Result { - let mut file = tokio::fs::File::open(file_path).await?; - let mut buffer = Vec::new(); - file.read_to_end(&mut buffer).await.map_err(CLIError::from)?; - log::info!("File read: {} ... ok", file_path); - Ok(String::from_utf8(buffer)?) - } + async fn read<'a>(&'a self, file_path: &'a str) -> Result { + let mut file = tokio::fs::File::open(file_path).await?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer) + .await + .map_err(CLIError::from)?; + log::info!("File read: {} ... ok", file_path); + Ok(String::from_utf8(buffer)?) + } } diff --git a/src/cli/fmt.rs b/src/cli/fmt.rs index 56479a1ba4f..533aad6157d 100644 --- a/src/cli/fmt.rs +++ b/src/cli/fmt.rs @@ -5,87 +5,91 @@ use crate::config::Config; pub struct Fmt {} impl Fmt { - pub fn heading(heading: &String) -> String { - format!("{}", heading.bold()) - } + pub fn heading(heading: &String) -> String { + format!("{}", heading.bold()) + } - pub fn meta(meta: &String) -> String { - format!("{}", meta.yellow()) - } + pub fn meta(meta: &String) -> String { + format!("{}", meta.yellow()) + } - pub fn display(s: String) { - println!("{}", s); - } + pub fn display(s: String) { + println!("{}", s); + } - pub fn table(labels: Vec<(String, String)>) -> String { - let max_length = labels.iter().map(|(key, _)| key.len()).max().unwrap_or(0) + 1; - let padding = " ".repeat(max_length); - let mut table = labels - .iter() - .map(|(key, value)| { - Fmt::heading( - &(key.clone() + ":" + padding.as_str()) - .chars() - .take(max_length) - .collect::(), - ) + " " - + value - }) - .collect::>() - .join("\n"); - table.push('\n'); - table - } + pub fn table(labels: Vec<(String, String)>) -> String { + let max_length = labels.iter().map(|(key, _)| key.len()).max().unwrap_or(0) + 1; + let padding = " ".repeat(max_length); + let mut table = labels + .iter() + .map(|(key, value)| { + Fmt::heading( + &(key.clone() + ":" + padding.as_str()) + .chars() + .take(max_length) + .collect::(), + ) + " " + + value + }) + .collect::>() + .join("\n"); + table.push('\n'); + table + } - pub fn format_n_plus_one_queries(n_plus_one_info: Vec>) -> String { - let query_paths: Vec> = n_plus_one_info - .iter() - .map(|item| item.iter().map(|(_, field_name)| field_name).collect::>()) - .collect(); + pub fn format_n_plus_one_queries(n_plus_one_info: Vec>) -> String { + let query_paths: Vec> = n_plus_one_info + .iter() + .map(|item| { + item.iter() + .map(|(_, field_name)| field_name) + .collect::>() + }) + .collect(); - let query_data: Vec = query_paths - .iter() - .map(|query_path| { - let mut path = " query { ".to_string(); - path.push_str( - query_path + let query_data: Vec = query_paths .iter() - .rfold("".to_string(), |s, field_name| { - if s.is_empty() { - field_name.to_string() - } else { - format!("{} {{ {} }}", field_name, s) - } + .map(|query_path| { + let mut path = " query { ".to_string(); + path.push_str( + query_path + .iter() + .rfold("".to_string(), |s, field_name| { + if s.is_empty() { + field_name.to_string() + } else { + format!("{} {{ {} }}", field_name, s) + } + }) + .as_str(), + ); + path.push_str(" }"); + path }) - .as_str(), - ); - path.push_str(" }"); - path - }) - .collect(); + .collect(); - Fmt::meta(&query_data.iter().rfold("".to_string(), |s, query| { - if s.is_empty() { - query.to_string() - } else { - format!("{}\n{}", query, s) - } - })) - } + Fmt::meta(&query_data.iter().rfold("".to_string(), |s, query| { + if s.is_empty() { + query.to_string() + } else 
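+ // rfold starts from the last query and prepends each earlier one, so the joined output keeps the queries in their original order.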
{ + format!("{}\n{}", query, s) + } + })) + } - pub fn n_plus_one_data(n_plus_one_queries: bool, config: &Config) -> (String, String) { - let n_plus_one_info = config.n_plus_one(); - if n_plus_one_queries { - ( - "N + 1".to_string(), - [ - n_plus_one_info.len().to_string(), - Self::format_n_plus_one_queries(n_plus_one_info), - ] - .join("\n"), - ) - } else { - ("N + 1".to_string(), n_plus_one_info.len().to_string()) + pub fn n_plus_one_data(n_plus_one_queries: bool, config: &Config) -> (String, String) { + let n_plus_one_info = config.n_plus_one(); + if n_plus_one_queries { + ( + "N + 1".to_string(), + [ + n_plus_one_info.len().to_string(), + Self::format_n_plus_one_queries(n_plus_one_info), + ] + .join("\n"), + ) + } else { + ("N + 1".to_string(), n_plus_one_info.len().to_string()) + } } - } } diff --git a/src/cli/http.rs b/src/cli/http.rs index 761fb6f6e02..5a5f9db453e 100644 --- a/src/cli/http.rs +++ b/src/cli/http.rs @@ -12,61 +12,77 @@ use crate::http::Response; #[derive(Clone)] pub struct NativeHttp { - client: ClientWithMiddleware, - http2_only: bool, + client: ClientWithMiddleware, + http2_only: bool, } impl Default for NativeHttp { - fn default() -> Self { - Self { client: ClientBuilder::new(Client::new()).build(), http2_only: false } - } + fn default() -> Self { + Self { + client: ClientBuilder::new(Client::new()).build(), + http2_only: false, + } + } } impl NativeHttp { - pub fn init(upstream: &Upstream) -> Self { - let mut builder = Client::builder() - .tcp_keepalive(Some(Duration::from_secs(upstream.get_tcp_keep_alive()))) - .timeout(Duration::from_secs(upstream.get_timeout())) - .connect_timeout(Duration::from_secs(upstream.get_connect_timeout())) - .http2_keep_alive_interval(Some(Duration::from_secs(upstream.get_keep_alive_interval()))) - .http2_keep_alive_timeout(Duration::from_secs(upstream.get_keep_alive_timeout())) - .http2_keep_alive_while_idle(upstream.get_keep_alive_while_idle()) - .pool_idle_timeout(Some(Duration::from_secs(upstream.get_pool_idle_timeout()))) - .pool_max_idle_per_host(upstream.get_pool_max_idle_per_host()) - .user_agent(upstream.get_user_agent()); + pub fn init(upstream: &Upstream) -> Self { + let mut builder = Client::builder() + .tcp_keepalive(Some(Duration::from_secs(upstream.get_tcp_keep_alive()))) + .timeout(Duration::from_secs(upstream.get_timeout())) + .connect_timeout(Duration::from_secs(upstream.get_connect_timeout())) + .http2_keep_alive_interval(Some(Duration::from_secs( + upstream.get_keep_alive_interval(), + ))) + .http2_keep_alive_timeout(Duration::from_secs(upstream.get_keep_alive_timeout())) + .http2_keep_alive_while_idle(upstream.get_keep_alive_while_idle()) + .pool_idle_timeout(Some(Duration::from_secs(upstream.get_pool_idle_timeout()))) + .pool_max_idle_per_host(upstream.get_pool_max_idle_per_host()) + .user_agent(upstream.get_user_agent()); - // Add Http2 Prior Knowledge - if upstream.get_http_2_only() { - log::info!("Enabled Http2 prior knowledge"); - builder = builder.http2_prior_knowledge(); - } + // Add Http2 Prior Knowledge + if upstream.get_http_2_only() { + log::info!("Enabled Http2 prior knowledge"); + builder = builder.http2_prior_knowledge(); + } - // Add Http Proxy - if let Some(ref proxy) = upstream.proxy { - builder = builder.proxy(reqwest::Proxy::http(proxy.url.clone()).expect("Failed to set proxy in http client")); - } + // Add Http Proxy + if let Some(ref proxy) = upstream.proxy { + builder = builder.proxy( + reqwest::Proxy::http(proxy.url.clone()) + .expect("Failed to set proxy in http client"), + ); + } - let mut 
client = ClientBuilder::new(builder.build().expect("Failed to build client")); + let mut client = ClientBuilder::new(builder.build().expect("Failed to build client")); - if upstream.get_enable_http_cache() { - client = client.with(Cache(HttpCache { - mode: CacheMode::Default, - manager: MokaManager::default(), - options: HttpCacheOptions::default(), - })) + if upstream.get_enable_http_cache() { + client = client.with(Cache(HttpCache { + mode: CacheMode::Default, + manager: MokaManager::default(), + options: HttpCacheOptions::default(), + })) + } + Self { + client: client.build(), + http2_only: upstream.get_http_2_only(), + } } - Self { client: client.build(), http2_only: upstream.get_http_2_only() } - } } #[async_trait::async_trait] impl HttpIO for NativeHttp { - async fn execute(&self, mut request: reqwest::Request) -> Result> { - if self.http2_only { - *request.version_mut() = reqwest::Version::HTTP_2; + async fn execute(&self, mut request: reqwest::Request) -> Result> { + if self.http2_only { + *request.version_mut() = reqwest::Version::HTTP_2; + } + log::info!( + "{} {} {:?}", + request.method(), + request.url(), + request.version() + ); + let response = self.client.execute(request).await?.error_for_status()?; + Ok(Response::from_reqwest(response).await?) } - log::info!("{} {} {:?}", request.method(), request.url(), request.version()); - let response = self.client.execute(request).await?.error_for_status()?; - Ok(Response::from_reqwest(response).await?) - } } diff --git a/src/cli/mod.rs b/src/cli/mod.rs index e2e19c86971..a37c4ce94dd 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -23,24 +23,24 @@ pub use http::NativeHttp; // Provides access to env in native rust environment pub fn init_env() -> env::EnvNative { - env::EnvNative::init() + env::EnvNative::init() } // Provides access to file system in native rust environment pub fn init_file() -> file::NativeFileIO { - file::NativeFileIO::init() + file::NativeFileIO::init() } // Provides access to http in native rust environment pub fn init_http(upstream: &Upstream) -> http::NativeHttp { - http::NativeHttp::init(upstream) + http::NativeHttp::init(upstream) } // Provides access to http in native rust environment pub fn init_http2_only(upstream: &Upstream) -> http::NativeHttp { - http::NativeHttp::init(&upstream.clone().http2_only(true)) + http::NativeHttp::init(&upstream.clone().http2_only(true)) } pub fn init_chrono_cache() -> NativeChronoCache { - NativeChronoCache::new() + NativeChronoCache::new() } diff --git a/src/cli/server/http_1.rs b/src/cli/server/http_1.rs index e3cbc0d434e..0c5708e0f0d 100644 --- a/src/cli/server/http_1.rs +++ b/src/cli/server/http_1.rs @@ -10,41 +10,53 @@ use crate::cli::http::NativeHttp; use crate::cli::CLIError; use crate::http::handle_request; -pub async fn start_http_1(sc: Arc, server_up_sender: Option>) -> anyhow::Result<()> { - let addr = sc.addr(); - let make_svc_single_req = make_service_fn(|_conn| { - let state = Arc::clone(&sc); - async move { - Ok::<_, anyhow::Error>(service_fn(move |req| { - handle_request::(req, state.server_context.clone()) - })) +pub async fn start_http_1( + sc: Arc, + server_up_sender: Option>, +) -> anyhow::Result<()> { + let addr = sc.addr(); + let make_svc_single_req = make_service_fn(|_conn| { + let state = Arc::clone(&sc); + async move { + Ok::<_, anyhow::Error>(service_fn(move |req| { + handle_request::( + req, + state.server_context.clone(), + ) + })) + } + }); + + let make_svc_batch_req = make_service_fn(|_conn| { + let state = Arc::clone(&sc); + async move { + 
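+ // each accepted connection gets its own service_fn here; the shared app state travels in via cheap Arc clones.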
Ok::<_, anyhow::Error>(service_fn(move |req| { + handle_request::( + req, + state.server_context.clone(), + ) + })) + } + }); + let builder = hyper::Server::try_bind(&addr) + .map_err(CLIError::from)? + .http1_pipeline_flush(sc.server_context.blueprint.server.pipeline_flush); + super::log_launch_and_open_browser(sc.as_ref()); + + if let Some(sender) = server_up_sender { + sender + .send(()) + .or(Err(anyhow::anyhow!("Failed to send message")))?; } - }); - - let make_svc_batch_req = make_service_fn(|_conn| { - let state = Arc::clone(&sc); - async move { - Ok::<_, anyhow::Error>(service_fn(move |req| { - handle_request::(req, state.server_context.clone()) - })) - } - }); - let builder = hyper::Server::try_bind(&addr) - .map_err(CLIError::from)? - .http1_pipeline_flush(sc.server_context.blueprint.server.pipeline_flush); - super::log_launch_and_open_browser(sc.as_ref()); - - if let Some(sender) = server_up_sender { - sender.send(()).or(Err(anyhow::anyhow!("Failed to send message")))?; - } - let server: std::prelude::v1::Result<(), hyper::Error> = if sc.blueprint.server.enable_batch_requests { - builder.serve(make_svc_batch_req).await - } else { - builder.serve(make_svc_single_req).await - }; + let server: std::prelude::v1::Result<(), hyper::Error> = + if sc.blueprint.server.enable_batch_requests { + builder.serve(make_svc_batch_req).await + } else { + builder.serve(make_svc_single_req).await + }; - let result = server.map_err(CLIError::from); + let result = server.map_err(CLIError::from); - Ok(result?) + Ok(result?) } diff --git a/src/cli/server/http_2.rs b/src/cli/server/http_2.rs index ff9daee7d48..23a618e9ac6 100644 --- a/src/cli/server/http_2.rs +++ b/src/cli/server/http_2.rs @@ -7,7 +7,9 @@ use hyper::server::conn::AddrIncoming; use hyper::service::{make_service_fn, service_fn}; use hyper::Server; use hyper_rustls::TlsAcceptor; -use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, PrivateSec1KeyDer}; +use rustls_pki_types::{ + CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, PrivateSec1KeyDer, +}; use tokio::fs::File; use tokio::sync::oneshot; @@ -19,83 +21,96 @@ use crate::cli::CLIError; use crate::http::handle_request; async fn load_cert(filename: String) -> Result>, std::io::Error> { - let file = File::open(filename).await?; - let file = file.into_std().await; - let mut file = BufReader::new(file); + let file = File::open(filename).await?; + let file = file.into_std().await; + let mut file = BufReader::new(file); - let certificates = rustls_pemfile::certs(&mut file)?; + let certificates = rustls_pemfile::certs(&mut file)?; - Ok(certificates.into_iter().map(CertificateDer::from).collect()) + Ok(certificates.into_iter().map(CertificateDer::from).collect()) } async fn load_private_key(filename: String) -> anyhow::Result> { - let file = File::open(filename).await?; - let file = file.into_std().await; - let mut file = BufReader::new(file); + let file = File::open(filename).await?; + let file = file.into_std().await; + let mut file = BufReader::new(file); - let keys = rustls_pemfile::read_all(&mut file)?; + let keys = rustls_pemfile::read_all(&mut file)?; - if keys.len() != 1 { - return Err(CLIError::new("Expected a single private key").into()); - } - - let key = keys.into_iter().find_map(|key| match key { - rustls_pemfile::Item::RSAKey(key) => Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(key))), - rustls_pemfile::Item::ECKey(key) => Some(PrivateKeyDer::Sec1(PrivateSec1KeyDer::from(key))), - 
rustls_pemfile::Item::PKCS8Key(key) => Some(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key))), - _ => None, - }); + if keys.len() != 1 { + return Err(CLIError::new("Expected a single private key").into()); + } - key.ok_or(CLIError::new("Invalid private key").into()) + let key = keys.into_iter().find_map(|key| match key { + rustls_pemfile::Item::RSAKey(key) => { + Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(key))) + } + rustls_pemfile::Item::ECKey(key) => Some(PrivateKeyDer::Sec1(PrivateSec1KeyDer::from(key))), + rustls_pemfile::Item::PKCS8Key(key) => { + Some(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key))) + } + _ => None, + }); + + key.ok_or(CLIError::new("Invalid private key").into()) } pub async fn start_http_2( - sc: Arc, - cert: String, - key: String, - server_up_sender: Option>, + sc: Arc, + cert: String, + key: String, + server_up_sender: Option>, ) -> anyhow::Result<()> { - let addr = sc.addr(); - let cert_chain = load_cert(cert).await?; - let key = load_private_key(key).await?; - let incoming = AddrIncoming::bind(&addr)?; - let acceptor = TlsAcceptor::builder() - .with_single_cert(cert_chain, key)? - .with_http2_alpn() - .with_incoming(incoming); - let make_svc_single_req = make_service_fn(|_conn| { - let state = Arc::clone(&sc); - async move { - Ok::<_, anyhow::Error>(service_fn(move |req| { - handle_request::(req, state.server_context.clone()) - })) - } - }); - - let make_svc_batch_req = make_service_fn(|_conn| { - let state = Arc::clone(&sc); - async move { - Ok::<_, anyhow::Error>(service_fn(move |req| { - handle_request::(req, state.server_context.clone()) - })) + let addr = sc.addr(); + let cert_chain = load_cert(cert).await?; + let key = load_private_key(key).await?; + let incoming = AddrIncoming::bind(&addr)?; + let acceptor = TlsAcceptor::builder() + .with_single_cert(cert_chain, key)? + .with_http2_alpn() + .with_incoming(incoming); + let make_svc_single_req = make_service_fn(|_conn| { + let state = Arc::clone(&sc); + async move { + Ok::<_, anyhow::Error>(service_fn(move |req| { + handle_request::( + req, + state.server_context.clone(), + ) + })) + } + }); + + let make_svc_batch_req = make_service_fn(|_conn| { + let state = Arc::clone(&sc); + async move { + Ok::<_, anyhow::Error>(service_fn(move |req| { + handle_request::( + req, + state.server_context.clone(), + ) + })) + } + }); + + let builder = Server::builder(acceptor).http2_only(true); + + super::log_launch_and_open_browser(sc.as_ref()); + + if let Some(sender) = server_up_sender { + sender + .send(()) + .or(Err(anyhow::anyhow!("Failed to send message")))?; } - }); - - let builder = Server::builder(acceptor).http2_only(true); - - super::log_launch_and_open_browser(sc.as_ref()); - - if let Some(sender) = server_up_sender { - sender.send(()).or(Err(anyhow::anyhow!("Failed to send message")))?; - } - let server: std::prelude::v1::Result<(), hyper::Error> = if sc.blueprint.server.enable_batch_requests { - builder.serve(make_svc_batch_req).await - } else { - builder.serve(make_svc_single_req).await - }; + let server: std::prelude::v1::Result<(), hyper::Error> = + if sc.blueprint.server.enable_batch_requests { + builder.serve(make_svc_batch_req).await + } else { + builder.serve(make_svc_single_req).await + }; - let result = server.map_err(CLIError::from); + let result = server.map_err(CLIError::from); - Ok(result?) + Ok(result?) 
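+ // map_err(CLIError::from) above wraps hyper's error so startup and serve failures render with the CLI's formatted error output.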
} diff --git a/src/cli/server/mod.rs b/src/cli/server/mod.rs index f23211a40de..6eeda5e1e85 100644 --- a/src/cli/server/mod.rs +++ b/src/cli/server/mod.rs @@ -8,12 +8,16 @@ pub use server::Server; use self::server_config::ServerConfig; fn log_launch_and_open_browser(sc: &ServerConfig) { - let addr = sc.addr().to_string(); - log::info!("🚀 Tailcall launched at [{}] over {}", addr, sc.http_version()); - if sc.graphiql() { - let url = sc.graphiql_url(); - log::info!("🌍 Playground: {}", url); + let addr = sc.addr().to_string(); + log::info!( + "🚀 Tailcall launched at [{}] over {}", + addr, + sc.http_version() + ); + if sc.graphiql() { + let url = sc.graphiql_url(); + log::info!("🌍 Playground: {}", url); - let _ = webbrowser::open(url.as_str()); - } + let _ = webbrowser::open(url.as_str()); + } } diff --git a/src/cli/server/server.rs b/src/cli/server/server.rs index 22b975bdb58..1a90c703eed 100644 --- a/src/cli/server/server.rs +++ b/src/cli/server/server.rs @@ -11,44 +11,46 @@ use crate::cli::CLIError; use crate::config::Config; pub struct Server { - config: Config, - server_up_sender: Option>, + config: Config, + server_up_sender: Option>, } impl Server { - pub fn new(config: Config) -> Self { - Self { config, server_up_sender: None } - } - - pub fn server_up_receiver(&mut self) -> oneshot::Receiver<()> { - let (tx, rx) = oneshot::channel(); + pub fn new(config: Config) -> Self { + Self { config, server_up_sender: None } + } - self.server_up_sender = Some(tx); + pub fn server_up_receiver(&mut self) -> oneshot::Receiver<()> { + let (tx, rx) = oneshot::channel(); - rx - } + self.server_up_sender = Some(tx); - /// Starts the server in the current Runtime - pub async fn start(self) -> Result<()> { - let blueprint = Blueprint::try_from(&self.config).map_err(CLIError::from)?; - let server_config = Arc::new(ServerConfig::new(blueprint.clone())); + rx + } - match blueprint.server.http.clone() { - Http::HTTP2 { cert, key } => start_http_2(server_config, cert, key, self.server_up_sender).await, - Http::HTTP1 => start_http_1(server_config, self.server_up_sender).await, + /// Starts the server in the current Runtime + pub async fn start(self) -> Result<()> { + let blueprint = Blueprint::try_from(&self.config).map_err(CLIError::from)?; + let server_config = Arc::new(ServerConfig::new(blueprint.clone())); + + match blueprint.server.http.clone() { + Http::HTTP2 { cert, key } => { + start_http_2(server_config, cert, key, self.server_up_sender).await + } + Http::HTTP1 => start_http_1(server_config, self.server_up_sender).await, + } } - } - /// Starts the server in its own multithreaded Runtime - pub async fn fork_start(self) -> anyhow::Result<()> { - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(self.config.server.get_workers()) - .enable_all() - .build()?; + /// Starts the server in its own multithreaded Runtime + pub async fn fork_start(self) -> anyhow::Result<()> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(self.config.server.get_workers()) + .enable_all() + .build()?; - let result = runtime.spawn(async { self.start().await }).await?; - runtime.shutdown_background(); + let result = runtime.spawn(async { self.start().await }).await?; + runtime.shutdown_background(); - result - } + result + } } diff --git a/src/cli/server/server_config.rs b/src/cli/server/server_config.rs index b8337a46ed2..8619615a070 100644 --- a/src/cli/server/server_config.rs +++ b/src/cli/server/server_config.rs @@ -8,52 +8,52 @@ use crate::cli::{init_chrono_cache, 
init_env, init_http, init_http2_only}; use crate::http::AppContext; pub struct ServerConfig { - pub blueprint: Blueprint, - pub server_context: Arc>, + pub blueprint: Blueprint, + pub server_context: Arc>, } impl ServerConfig { - pub fn new(blueprint: Blueprint) -> Self { - let h_client = Arc::new(init_http(&blueprint.upstream)); - let h2_client = Arc::new(init_http2_only(&blueprint.upstream)); - let env = init_env(); - let chrono_cache = init_chrono_cache(); - let server_context = Arc::new(AppContext::new( - blueprint.clone(), - h_client, - h2_client, - Arc::new(env), - Arc::new(chrono_cache), - )); - Self { server_context, blueprint } - } - - pub fn addr(&self) -> SocketAddr { - (self.blueprint.server.hostname, self.blueprint.server.port).into() - } - - pub fn http_version(&self) -> String { - match self.blueprint.server.http { - Http::HTTP2 { cert: _, key: _ } => "HTTP/2".to_string(), - _ => "HTTP/1.1".to_string(), + pub fn new(blueprint: Blueprint) -> Self { + let h_client = Arc::new(init_http(&blueprint.upstream)); + let h2_client = Arc::new(init_http2_only(&blueprint.upstream)); + let env = init_env(); + let chrono_cache = init_chrono_cache(); + let server_context = Arc::new(AppContext::new( + blueprint.clone(), + h_client, + h2_client, + Arc::new(env), + Arc::new(chrono_cache), + )); + Self { server_context, blueprint } } - } - pub fn graphiql_url(&self) -> String { - let protocol = match self.http_version().as_str() { - "HTTP/2" => "https", - _ => "http", - }; - let mut addr = self.addr(); + pub fn addr(&self) -> SocketAddr { + (self.blueprint.server.hostname, self.blueprint.server.port).into() + } - if addr.ip().is_unspecified() { - addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), addr.port()); + pub fn http_version(&self) -> String { + match self.blueprint.server.http { + Http::HTTP2 { cert: _, key: _ } => "HTTP/2".to_string(), + _ => "HTTP/1.1".to_string(), + } } - format!("{}://{}", protocol, addr) - } + pub fn graphiql_url(&self) -> String { + let protocol = match self.http_version().as_str() { + "HTTP/2" => "https", + _ => "http", + }; + let mut addr = self.addr(); + + if addr.ip().is_unspecified() { + addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), addr.port()); + } - pub fn graphiql(&self) -> bool { - self.blueprint.server.enable_graphiql - } + format!("{}://{}", protocol, addr) + } + + pub fn graphiql(&self) -> bool { + self.blueprint.server.enable_graphiql + } } diff --git a/src/cli/tc.rs b/src/cli/tc.rs index 52c02f0dd20..5e60d223787 100644 --- a/src/cli/tc.rs +++ b/src/cli/tc.rs @@ -20,162 +20,171 @@ const FILE_NAME: &str = ".tailcallrc.graphql"; const YML_FILE_NAME: &str = ".graphqlrc.yml"; pub async fn run() -> anyhow::Result<()> { - let cli = Cli::parse(); - - logger_init(); - let file_io = init_file(); - let default_http_io = init_http(&Upstream::default()); - let config_reader = ConfigReader::init(file_io.clone(), default_http_io); - match cli.command { - Command::Start { file_paths } => { - let config = config_reader.read(&file_paths).await?; - log::info!("N + 1: {}", config.n_plus_one().len().to_string()); - let server = Server::new(config); - server.fork_start().await?; - Ok(()) - } - Command::Check { file_paths, n_plus_one_queries, schema, operations } => { - let config = (config_reader.read(&file_paths)).await?; - let blueprint = Blueprint::try_from(&config).map_err(CLIError::from); - - match blueprint { - Ok(blueprint) => { - log::info!("{}", "Config successfully validated".to_string()); - display_config(&config, n_plus_one_queries); - if 
schema { - display_schema(&blueprint); - } - - let ops: Vec = futures_util::future::join_all(operations.iter().map(|op| async { - file_io - .read(op) - .await - .map(|query| OperationQuery::new(query, op.clone())) - })) - .await - .into_iter() - .collect::>>()?; - - validate_operations(&blueprint, ops) - .await - .to_result() - .map_err(|e| CLIError::from(e).message("Invalid Operation".to_string()).into()) + let cli = Cli::parse(); + + logger_init(); + let file_io = init_file(); + let default_http_io = init_http(&Upstream::default()); + let config_reader = ConfigReader::init(file_io.clone(), default_http_io); + match cli.command { + Command::Start { file_paths } => { + let config = config_reader.read(&file_paths).await?; + log::info!("N + 1: {}", config.n_plus_one().len().to_string()); + let server = Server::new(config); + server.fork_start().await?; + Ok(()) + } + Command::Check { file_paths, n_plus_one_queries, schema, operations } => { + let config = (config_reader.read(&file_paths)).await?; + let blueprint = Blueprint::try_from(&config).map_err(CLIError::from); + + match blueprint { + Ok(blueprint) => { + log::info!("{}", "Config successfully validated".to_string()); + display_config(&config, n_plus_one_queries); + if schema { + display_schema(&blueprint); + } + + let ops: Vec = + futures_util::future::join_all(operations.iter().map(|op| async { + file_io + .read(op) + .await + .map(|query| OperationQuery::new(query, op.clone())) + })) + .await + .into_iter() + .collect::>>()?; + + validate_operations(&blueprint, ops) + .await + .to_result() + .map_err(|e| { + CLIError::from(e) + .message("Invalid Operation".to_string()) + .into() + }) + } + Err(e) => Err(e.into()), + } + } + Command::Init { folder_path } => init(&folder_path).await, + Command::Compose { file_paths, format } => { + let config = (config_reader.read(&file_paths).await)?; + Fmt::display(format.encode(&config)?); + Ok(()) } - Err(e) => Err(e.into()), - } - } - Command::Init { folder_path } => init(&folder_path).await, - Command::Compose { file_paths, format } => { - let config = (config_reader.read(&file_paths).await)?; - Fmt::display(format.encode(&config)?); - Ok(()) } - } } pub async fn init(folder_path: &str) -> Result<()> { - let folder_exists = fs::metadata(folder_path).is_ok(); + let folder_exists = fs::metadata(folder_path).is_ok(); - if !folder_exists { - let confirm = Confirm::new(&format!("Do you want to create the folder {}?", folder_path)) - .with_default(false) - .prompt()?; + if !folder_exists { + let confirm = Confirm::new(&format!( + "Do you want to create the folder {}?", + folder_path + )) + .with_default(false) + .prompt()?; - if confirm { - fs::create_dir_all(folder_path)?; - } else { - return Ok(()); - }; - } + if confirm { + fs::create_dir_all(folder_path)?; + } else { + return Ok(()); + }; + } - let tailcallrc = include_str!("../../examples/.tailcallrc.graphql"); + let tailcallrc = include_str!("../../examples/.tailcallrc.graphql"); - let file_path = Path::new(folder_path).join(FILE_NAME); - let yml_file_path = Path::new(folder_path).join(YML_FILE_NAME); + let file_path = Path::new(folder_path).join(FILE_NAME); + let yml_file_path = Path::new(folder_path).join(YML_FILE_NAME); - let tailcall_exists = fs::metadata(&file_path).is_ok(); + let tailcall_exists = fs::metadata(&file_path).is_ok(); - if tailcall_exists { - // confirm overwrite - let confirm = Confirm::new(&format!("Do you want to overwrite the file {}?", FILE_NAME)) - .with_default(false) - .prompt()?; + if tailcall_exists { + // confirm 
overwrite + let confirm = Confirm::new(&format!("Do you want to overwrite the file {}?", FILE_NAME)) + .with_default(false) + .prompt()?; - if confirm { - fs::write(&file_path, tailcallrc.as_bytes())?; + if confirm { + fs::write(&file_path, tailcallrc.as_bytes())?; + } + } else { + fs::write(&file_path, tailcallrc.as_bytes())?; } - } else { - fs::write(&file_path, tailcallrc.as_bytes())?; - } - let yml_exists = fs::metadata(&yml_file_path).is_ok(); + let yml_exists = fs::metadata(&yml_file_path).is_ok(); - if !yml_exists { - fs::write(&yml_file_path, "")?; + if !yml_exists { + fs::write(&yml_file_path, "")?; - let graphqlrc = r"|schema: + let graphqlrc = r"|schema: |- './.tailcallrc.graphql' " - .strip_margin(); - - fs::write(&yml_file_path, graphqlrc)?; - } - - let graphqlrc = fs::read_to_string(&yml_file_path)?; + .strip_margin(); - let file_path = file_path.to_str().unwrap(); - - let mut yaml: serde_yaml::Value = serde_yaml::from_str(&graphqlrc)?; - - if let Some(mapping) = yaml.as_mapping_mut() { - let schema = mapping - .entry("schema".into()) - .or_insert(serde_yaml::Value::Sequence(Default::default())); - if let Some(schema) = schema.as_sequence_mut() { - if !schema - .iter() - .any(|v| v == &serde_yaml::Value::from("./.tailcallrc.graphql")) - { - let confirm = Confirm::new(&format!("Do you want to add {} to the schema?", file_path)) - .with_default(false) - .prompt()?; + fs::write(&yml_file_path, graphqlrc)?; + } - if confirm { - schema.push(serde_yaml::Value::from("./.tailcallrc.graphql")); - let updated = serde_yaml::to_string(&yaml)?; - fs::write(yml_file_path, updated)?; + let graphqlrc = fs::read_to_string(&yml_file_path)?; + + let file_path = file_path.to_str().unwrap(); + + let mut yaml: serde_yaml::Value = serde_yaml::from_str(&graphqlrc)?; + + if let Some(mapping) = yaml.as_mapping_mut() { + let schema = mapping + .entry("schema".into()) + .or_insert(serde_yaml::Value::Sequence(Default::default())); + if let Some(schema) = schema.as_sequence_mut() { + if !schema + .iter() + .any(|v| v == &serde_yaml::Value::from("./.tailcallrc.graphql")) + { + let confirm = + Confirm::new(&format!("Do you want to add {} to the schema?", file_path)) + .with_default(false) + .prompt()?; + + if confirm { + schema.push(serde_yaml::Value::from("./.tailcallrc.graphql")); + let updated = serde_yaml::to_string(&yaml)?; + fs::write(yml_file_path, updated)?; + } + } } - } } - } - Ok(()) + Ok(()) } pub fn display_schema(blueprint: &Blueprint) { - Fmt::display(Fmt::heading(&"GraphQL Schema:\n".to_string())); - let sdl = blueprint.to_schema(); - Fmt::display(format!("{}\n", print_schema::print_schema(sdl))); + Fmt::display(Fmt::heading(&"GraphQL Schema:\n".to_string())); + let sdl = blueprint.to_schema(); + Fmt::display(format!("{}\n", print_schema::print_schema(sdl))); } fn display_config(config: &Config, n_plus_one_queries: bool) { - let seq = vec![Fmt::n_plus_one_data(n_plus_one_queries, config)]; - Fmt::display(Fmt::table(seq)); + let seq = vec![Fmt::n_plus_one_data(n_plus_one_queries, config)]; + Fmt::display(Fmt::table(seq)); } // initialize logger fn logger_init() { - // set the log level - const LONG_ENV_FILTER_VAR_NAME: &str = "TAILCALL_LOG_LEVEL"; - const SHORT_ENV_FILTER_VAR_NAME: &str = "TC_LOG_LEVEL"; + // set the log level + const LONG_ENV_FILTER_VAR_NAME: &str = "TAILCALL_LOG_LEVEL"; + const SHORT_ENV_FILTER_VAR_NAME: &str = "TC_LOG_LEVEL"; - // Select which env variable to use for the log level filter. 
This is because filter_or doesn't allow picking between multiple env_var for the filter value - let filter_env_name = env::var(LONG_ENV_FILTER_VAR_NAME) - .map(|_| LONG_ENV_FILTER_VAR_NAME) - .unwrap_or_else(|_| SHORT_ENV_FILTER_VAR_NAME); + // Select which env variable to use for the log level filter. This is because filter_or doesn't allow picking between multiple env vars for the filter value + let filter_env_name = env::var(LONG_ENV_FILTER_VAR_NAME) + .map(|_| LONG_ENV_FILTER_VAR_NAME) + .unwrap_or_else(|_| SHORT_ENV_FILTER_VAR_NAME); - // use the log level from the env if there is one, otherwise use the default. - let env = Env::new().filter_or(filter_env_name, "info"); + // use the log level from the env if there is one, otherwise use the default. + let env = Env::new().filter_or(filter_env_name, "info"); - env_logger::Builder::from_env(env).init(); + env_logger::Builder::from_env(env).init(); } diff --git a/src/config/config.rs b/src/config/config.rs index 251fb61d20b..a33e149820a 100644 --- a/src/config/config.rs +++ b/src/config/config.rs @@ -17,153 +17,157 @@ use crate::http::Method; use crate::json::JsonSchema; use crate::valid::Valid; -#[derive(Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema, +)] #[serde(rename_all = "camelCase")] pub struct Config { - /// - /// Dictates how the server behaves and helps tune tailcall for all ingress requests. - /// Features such as request batching, SSL, HTTP2 etc. can be configured here. - /// - #[serde(default)] - pub server: Server, - - /// - /// Dictates how tailcall should handle upstream requests/responses. - /// Tuning upstream can improve performance and reliability for connections. - /// - #[serde(default)] - pub upstream: Upstream, - - /// - /// Specifies the entry points for query and mutation in the generated GraphQL schema. - /// - pub schema: RootSchema, - - /// - /// A map of all the types in the schema. - /// - #[serde(default)] - #[setters(skip)] - pub types: BTreeMap, - - /// - /// A map of all the union types in the schema. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub unions: BTreeMap, + /// + /// Dictates how the server behaves and helps tune tailcall for all ingress requests. + /// Features such as request batching, SSL, HTTP2 etc. can be configured here. + /// + #[serde(default)] + pub server: Server, + + /// + /// Dictates how tailcall should handle upstream requests/responses. + /// Tuning upstream can improve performance and reliability for connections. + /// + #[serde(default)] + pub upstream: Upstream, + + /// + /// Specifies the entry points for query and mutation in the generated GraphQL schema. + /// + pub schema: RootSchema, + + /// + /// A map of all the types in the schema. + /// + #[serde(default)] + #[setters(skip)] + pub types: BTreeMap, + + /// + /// A map of all the union types in the schema.
+ /// + #[serde(default, skip_serializing_if = "is_default")] + pub unions: BTreeMap, } impl Config { - pub fn port(&self) -> u16 { - self.server.port.unwrap_or(8000) - } + pub fn port(&self) -> u16 { + self.server.port.unwrap_or(8000) + } + + pub fn output_types(&self) -> HashSet<&String> { + let mut types = HashSet::new(); + let input_types = self.input_types(); - pub fn output_types(&self) -> HashSet<&String> { - let mut types = HashSet::new(); - let input_types = self.input_types(); + if let Some(ref query) = &self.schema.query { + types.insert(query); + } - if let Some(ref query) = &self.schema.query { - types.insert(query); + if let Some(ref mutation) = &self.schema.mutation { + types.insert(mutation); + } + for (type_name, type_of) in self.types.iter() { + if (type_of.interface || !type_of.fields.is_empty()) + && !input_types.contains(&type_name) + { + for (_, field) in type_of.fields.iter() { + types.insert(&field.type_of); + } + } + } + types } - if let Some(ref mutation) = &self.schema.mutation { - types.insert(mutation); + pub fn recurse_type<'a>(&'a self, type_of: &str, types: &mut HashSet<&'a String>) { + if let Some(type_) = self.find_type(type_of) { + for (_, field) in type_.fields.iter() { + if !types.contains(&field.type_of) { + types.insert(&field.type_of); + self.recurse_type(&field.type_of, types); + } + } + } } - for (type_name, type_of) in self.types.iter() { - if (type_of.interface || !type_of.fields.is_empty()) && !input_types.contains(&type_name) { - for (_, field) in type_of.fields.iter() { - types.insert(&field.type_of); + + pub fn input_types(&self) -> HashSet<&String> { + let mut types = HashSet::new(); + for (_, type_of) in self.types.iter() { + if !type_of.interface { + for (_, field) in type_of.fields.iter() { + for (_, arg) in field.args.iter() { + if let Some(t) = self.find_type(&arg.type_of) { + t.fields.iter().for_each(|(_, f)| { + types.insert(&f.type_of); + self.recurse_type(&f.type_of, &mut types) + }) + } + types.insert(&arg.type_of); + } + } + } } - } + types + } + pub fn find_type(&self, name: &str) -> Option<&Type> { + self.types.get(name) } - types - } - pub fn recurse_type<'a>(&'a self, type_of: &str, types: &mut HashSet<&'a String>) { - if let Some(type_) = self.find_type(type_of) { - for (_, field) in type_.fields.iter() { - if !types.contains(&field.type_of) { - types.insert(&field.type_of); - self.recurse_type(&field.type_of, types); + pub fn find_union(&self, name: &str) -> Option<&Union> { + self.unions.get(name) + } + + pub fn to_yaml(&self) -> Result { + Ok(serde_yaml::to_string(self)?) + } + + pub fn to_json(&self, pretty: bool) -> Result { + if pretty { + Ok(serde_json::to_string_pretty(self)?) + } else { + Ok(serde_json::to_string(self)?) 
} - } - } - } - - pub fn input_types(&self) -> HashSet<&String> { - let mut types = HashSet::new(); - for (_, type_of) in self.types.iter() { - if !type_of.interface { - for (_, field) in type_of.fields.iter() { - for (_, arg) in field.args.iter() { - if let Some(t) = self.find_type(&arg.type_of) { - t.fields.iter().for_each(|(_, f)| { - types.insert(&f.type_of); - self.recurse_type(&f.type_of, &mut types) - }) - } - types.insert(&arg.type_of); - } + } + + pub fn to_document(&self) -> ServiceDocument { + self.clone().into() + } + + pub fn to_sdl(&self) -> String { + let doc = self.to_document(); + crate::document::print(doc) + } + + pub fn query(mut self, query: &str) -> Self { + self.schema.query = Some(query.to_string()); + self + } + + pub fn types(mut self, types: Vec<(&str, Type)>) -> Self { + let mut graphql_types = BTreeMap::new(); + for (name, type_) in types { + graphql_types.insert(name.to_string(), type_); } - } - } - types - } - pub fn find_type(&self, name: &str) -> Option<&Type> { - self.types.get(name) - } - - pub fn find_union(&self, name: &str) -> Option<&Union> { - self.unions.get(name) - } - - pub fn to_yaml(&self) -> Result { - Ok(serde_yaml::to_string(self)?) - } - - pub fn to_json(&self, pretty: bool) -> Result { - if pretty { - Ok(serde_json::to_string_pretty(self)?) - } else { - Ok(serde_json::to_string(self)?) - } - } - - pub fn to_document(&self) -> ServiceDocument { - self.clone().into() - } - - pub fn to_sdl(&self) -> String { - let doc = self.to_document(); - crate::document::print(doc) - } - - pub fn query(mut self, query: &str) -> Self { - self.schema.query = Some(query.to_string()); - self - } - - pub fn types(mut self, types: Vec<(&str, Type)>) -> Self { - let mut graphql_types = BTreeMap::new(); - for (name, type_) in types { - graphql_types.insert(name.to_string(), type_); - } - self.types = graphql_types; - self - } - - pub fn contains(&self, name: &str) -> bool { - self.types.contains_key(name) || self.unions.contains_key(name) - } - - pub fn merge_right(self, other: &Self) -> Self { - let server = self.server.merge_right(other.server.clone()); - let types = merge_types(self.types, other.types.clone()); - let unions = merge_unions(self.unions, other.unions.clone()); - let schema = self.schema.merge_right(other.schema.clone()); - let upstream = self.upstream.merge_right(other.upstream.clone()); - - Self { server, upstream, types, schema, unions } - } + self.types = graphql_types; + self + } + + pub fn contains(&self, name: &str) -> bool { + self.types.contains_key(name) || self.unions.contains_key(name) + } + + pub fn merge_right(self, other: &Self) -> Self { + let server = self.server.merge_right(other.server.clone()); + let types = merge_types(self.types, other.types.clone()); + let unions = merge_unions(self.unions, other.unions.clone()); + let schema = self.schema.merge_right(other.schema.clone()); + let upstream = self.upstream.merge_right(other.upstream.clone()); + + Self { server, upstream, types, schema, unions } + } } /// @@ -172,122 +176,127 @@ impl Config { /// #[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] pub struct Type { - /// - /// A map of field name and its definition. - /// - pub fields: BTreeMap, - #[serde(default, skip_serializing_if = "is_default")] - /// - /// Additional fields to be added to the type - /// - pub added_fields: Vec, - #[serde(default, skip_serializing_if = "is_default")] - /// - /// Documentation for the type that is publicly visible. 
- /// - pub doc: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// - /// Flag to indicate if the type is an interface. - /// - pub interface: bool, - #[serde(default, skip_serializing_if = "is_default")] - /// - /// Interfaces that the type implements. - /// - pub implements: BTreeSet, - #[serde(rename = "enum", default, skip_serializing_if = "is_default")] - /// - /// Variants for the type if it's an enum - /// - pub variants: Option>, - #[serde(default, skip_serializing_if = "is_default")] - /// - /// Flag to indicate if the type is a scalar. - /// - pub scalar: bool, - #[serde(default)] - /// - /// Setting to indicate if the type is cacheable. - /// - pub cache: Option, + /// + /// A map of field name and its definition. + /// + pub fields: BTreeMap, + #[serde(default, skip_serializing_if = "is_default")] + /// + /// Additional fields to be added to the type + /// + pub added_fields: Vec, + #[serde(default, skip_serializing_if = "is_default")] + /// + /// Documentation for the type that is publicly visible. + /// + pub doc: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// + /// Flag to indicate if the type is an interface. + /// + pub interface: bool, + #[serde(default, skip_serializing_if = "is_default")] + /// + /// Interfaces that the type implements. + /// + pub implements: BTreeSet, + #[serde(rename = "enum", default, skip_serializing_if = "is_default")] + /// + /// Variants for the type if it's an enum + /// + pub variants: Option>, + #[serde(default, skip_serializing_if = "is_default")] + /// + /// Flag to indicate if the type is a scalar. + /// + pub scalar: bool, + #[serde(default)] + /// + /// Setting to indicate if the type is cacheable. + /// + pub cache: Option, } impl Type { - pub fn fields(mut self, fields: Vec<(&str, Field)>) -> Self { - let mut graphql_fields = BTreeMap::new(); - for (name, field) in fields { - graphql_fields.insert(name.to_string(), field); - } - self.fields = graphql_fields; - self - } - pub fn merge_right(mut self, other: &Self) -> Self { - let mut fields = self.fields.clone(); - fields.extend(other.fields.clone()); - self.implements.extend(other.implements.clone()); - if let Some(ref variants) = self.variants { - if let Some(ref other) = other.variants { - self.variants = Some(variants.union(other).cloned().collect()); - } - } else { - self.variants = other.variants.clone(); - } - Self { fields, ..self.clone() } - } + pub fn fields(mut self, fields: Vec<(&str, Field)>) -> Self { + let mut graphql_fields = BTreeMap::new(); + for (name, field) in fields { + graphql_fields.insert(name.to_string(), field); + } + self.fields = graphql_fields; + self + } + pub fn merge_right(mut self, other: &Self) -> Self { + let mut fields = self.fields.clone(); + fields.extend(other.fields.clone()); + self.implements.extend(other.implements.clone()); + if let Some(ref variants) = self.variants { + if let Some(ref other) = other.variants { + self.variants = Some(variants.union(other).cloned().collect()); + } + } else { + self.variants = other.variants.clone(); + } + Self { fields, ..self.clone() } + } } #[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Eq, schemars::JsonSchema)] /// The @cache operator enables caching for the query, field or type it is applied to. #[serde(rename_all = "camelCase")] pub struct Cache { - /// Specifies the duration, in milliseconds, of how long the value has to be stored in the cache. 
- pub max_age: NonZeroU64, + /// Specifies the duration, in milliseconds, of how long the value has to be stored in the cache. + pub max_age: NonZeroU64, } -fn merge_types(mut self_types: BTreeMap, other_types: BTreeMap) -> BTreeMap { - for (name, mut other_type) in other_types { - if let Some(self_type) = self_types.remove(&name) { - other_type = self_type.merge_right(&other_type); - } +fn merge_types( + mut self_types: BTreeMap, + other_types: BTreeMap, +) -> BTreeMap { + for (name, mut other_type) in other_types { + if let Some(self_type) = self_types.remove(&name) { + other_type = self_type.merge_right(&other_type); + } - self_types.insert(name, other_type); - } - self_types + self_types.insert(name, other_type); + } + self_types } fn merge_unions( - mut self_unions: BTreeMap, - other_unions: BTreeMap, + mut self_unions: BTreeMap, + other_unions: BTreeMap, ) -> BTreeMap { - for (name, mut other_union) in other_unions { - if let Some(self_union) = self_unions.remove(&name) { - other_union = self_union.merge_right(other_union); + for (name, mut other_union) in other_unions { + if let Some(self_union) = self_unions.remove(&name) { + other_union = self_union.merge_right(other_union); + } + self_unions.insert(name, other_union); } - self_unions.insert(name, other_union); - } - self_unions + self_unions } -#[derive(Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema, +)] #[setters(strip_option)] pub struct RootSchema { - pub query: Option, - #[serde(default, skip_serializing_if = "is_default")] - pub mutation: Option, - #[serde(default, skip_serializing_if = "is_default")] - pub subscription: Option, + pub query: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub mutation: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub subscription: Option, } impl RootSchema { - // TODO: add unit-tests - fn merge_right(self, other: Self) -> Self { - Self { - query: other.query.or(self.query), - mutation: other.mutation.or(self.mutation), - subscription: other.subscription.or(self.subscription), - } - } + // TODO: add unit-tests + fn merge_right(self, other: Self) -> Self { + Self { + query: other.query.or(self.query), + mutation: other.mutation.or(self.mutation), + subscription: other.subscription.or(self.subscription), + } + } } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] @@ -296,206 +305,213 @@ pub struct Omit {} /// /// A field definition containing all the metadata information about resolving a field. /// -#[derive(Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, Clone, Debug, Default, Setters, PartialEq, Eq, schemars::JsonSchema, +)] #[setters(strip_option)] pub struct Field { - /// - /// Refers to the type of the value the field can be resolved to. - /// - #[serde(rename = "type", default, skip_serializing_if = "is_default")] - pub type_of: String, - - /// - /// Flag to indicate the type is a list. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub list: bool, - - /// - /// Flag to indicate the type is required. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub required: bool, - - /// - /// Flag to indicate if the type inside the list is required. 
- /// - #[serde(default, skip_serializing_if = "is_default")] - pub list_type_required: bool, - - /// - /// Map of argument name and its definition. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub args: BTreeMap, - - /// - /// Publicly visible documentation for the field. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub doc: Option, - - /// - /// Allows modifying existing fields. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub modify: Option, - - /// - /// Omits a field from public consumption. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub omit: Option, - - /// - /// Inserts an HTTP resolver for the field. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub http: Option, - - /// - /// Inserts a GRPC resolver for the field. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub grpc: Option, - - /// - /// Inserts a Javascript resolver for the field. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub script: Option, - - /// - /// Inserts a constant resolver for the field. - /// - #[serde(rename = "const", default, skip_serializing_if = "is_default")] - pub const_field: Option, - - /// - /// Inserts a GraphQL resolver for the field. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub graphql: Option, - - /// - /// Inserts an Expression resolver for the field. - /// - #[serde(default, skip_serializing_if = "is_default")] - pub expr: Option, - /// - /// Sets the cache configuration for a field - /// - pub cache: Option, + /// + /// Refers to the type of the value the field can be resolved to. + /// + #[serde(rename = "type", default, skip_serializing_if = "is_default")] + pub type_of: String, + + /// + /// Flag to indicate the type is a list. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub list: bool, + + /// + /// Flag to indicate the type is required. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub required: bool, + + /// + /// Flag to indicate if the type inside the list is required. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub list_type_required: bool, + + /// + /// Map of argument name and its definition. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub args: BTreeMap, + + /// + /// Publicly visible documentation for the field. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub doc: Option, + + /// + /// Allows modifying existing fields. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub modify: Option, + + /// + /// Omits a field from public consumption. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub omit: Option, + + /// + /// Inserts an HTTP resolver for the field. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub http: Option, + + /// + /// Inserts a GRPC resolver for the field. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub grpc: Option, + + /// + /// Inserts a Javascript resolver for the field. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub script: Option, + + /// + /// Inserts a constant resolver for the field. + /// + #[serde(rename = "const", default, skip_serializing_if = "is_default")] + pub const_field: Option, + + /// + /// Inserts a GraphQL resolver for the field. + /// + #[serde(default, skip_serializing_if = "is_default")] + pub graphql: Option, + + /// + /// Inserts an Expression resolver for the field. 
+ /// + #[serde(default, skip_serializing_if = "is_default")] + pub expr: Option, + /// + /// Sets the cache configuration for a field + /// + pub cache: Option, } impl Field { - pub fn has_resolver(&self) -> bool { - self.http.is_some() - || self.script.is_some() - || self.const_field.is_some() - || self.graphql.is_some() - || self.grpc.is_some() - || self.expr.is_some() - } - pub fn resolvable_directives(&self) -> Vec { - let mut directives = Vec::new(); - if self.http.is_some() { - directives.push(Http::trace_name()); - } - if self.graphql.is_some() { - directives.push(GraphQL::trace_name()); - } - if self.script.is_some() { - directives.push(JS::trace_name()); - } - if self.const_field.is_some() { - directives.push(Const::trace_name()); - } - if self.grpc.is_some() { - directives.push(Grpc::trace_name()); - } - directives - } - pub fn has_batched_resolver(&self) -> bool { - self.http.as_ref().is_some_and(|http| !http.group_by.is_empty()) - || self.graphql.as_ref().is_some_and(|graphql| graphql.batch) - || self.grpc.as_ref().is_some_and(|grpc| !grpc.group_by.is_empty()) - } - pub fn to_list(mut self) -> Self { - self.list = true; - self - } - - pub fn int() -> Self { - Self { type_of: "Int".to_string(), ..Default::default() } - } - - pub fn string() -> Self { - Self { type_of: "String".to_string(), ..Default::default() } - } - - pub fn float() -> Self { - Self { type_of: "Float".to_string(), ..Default::default() } - } - - pub fn boolean() -> Self { - Self { type_of: "Boolean".to_string(), ..Default::default() } - } - - pub fn id() -> Self { - Self { type_of: "ID".to_string(), ..Default::default() } - } - - pub fn is_omitted(&self) -> bool { - self.omit.is_some() || self.modify.as_ref().map(|m| m.omit).unwrap_or(false) - } + pub fn has_resolver(&self) -> bool { + self.http.is_some() + || self.script.is_some() + || self.const_field.is_some() + || self.graphql.is_some() + || self.grpc.is_some() + || self.expr.is_some() + } + pub fn resolvable_directives(&self) -> Vec { + let mut directives = Vec::new(); + if self.http.is_some() { + directives.push(Http::trace_name()); + } + if self.graphql.is_some() { + directives.push(GraphQL::trace_name()); + } + if self.script.is_some() { + directives.push(JS::trace_name()); + } + if self.const_field.is_some() { + directives.push(Const::trace_name()); + } + if self.grpc.is_some() { + directives.push(Grpc::trace_name()); + } + directives + } + pub fn has_batched_resolver(&self) -> bool { + self.http + .as_ref() + .is_some_and(|http| !http.group_by.is_empty()) + || self.graphql.as_ref().is_some_and(|graphql| graphql.batch) + || self + .grpc + .as_ref() + .is_some_and(|grpc| !grpc.group_by.is_empty()) + } + pub fn to_list(mut self) -> Self { + self.list = true; + self + } + + pub fn int() -> Self { + Self { type_of: "Int".to_string(), ..Default::default() } + } + + pub fn string() -> Self { + Self { type_of: "String".to_string(), ..Default::default() } + } + + pub fn float() -> Self { + Self { type_of: "Float".to_string(), ..Default::default() } + } + + pub fn boolean() -> Self { + Self { type_of: "Boolean".to_string(), ..Default::default() } + } + + pub fn id() -> Self { + Self { type_of: "ID".to_string(), ..Default::default() } + } + + pub fn is_omitted(&self) -> bool { + self.omit.is_some() || self.modify.as_ref().map(|m| m.omit).unwrap_or(false) + } } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] pub struct JS { - pub script: String, + pub script: String, } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, 
Eq, schemars::JsonSchema)] pub struct Modify { - #[serde(default, skip_serializing_if = "is_default")] - pub name: Option, - #[serde(default, skip_serializing_if = "is_default")] - pub omit: bool, + #[serde(default, skip_serializing_if = "is_default")] + pub name: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub omit: bool, } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] pub struct Inline { - pub path: Vec, + pub path: Vec, } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] pub struct Arg { - #[serde(rename = "type")] - pub type_of: String, - #[serde(default, skip_serializing_if = "is_default")] - pub list: bool, - #[serde(default, skip_serializing_if = "is_default")] - pub required: bool, - #[serde(default, skip_serializing_if = "is_default")] - pub doc: Option, - #[serde(default, skip_serializing_if = "is_default")] - pub modify: Option, - #[serde(default, skip_serializing_if = "is_default")] - pub default_value: Option, + #[serde(rename = "type")] + pub type_of: String, + #[serde(default, skip_serializing_if = "is_default")] + pub list: bool, + #[serde(default, skip_serializing_if = "is_default")] + pub required: bool, + #[serde(default, skip_serializing_if = "is_default")] + pub doc: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub modify: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub default_value: Option, } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] pub struct Union { - pub types: BTreeSet, - pub doc: Option, + pub types: BTreeSet, + pub doc: Option, } impl Union { - pub fn merge_right(mut self, other: Self) -> Self { - self.types.extend(other.types); - self - } + pub fn merge_right(mut self, other: Self) -> Self { + self.types.extend(other.types); + self + } } #[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] @@ -505,37 +521,37 @@ impl Union { /// The path argument specifies the path of the REST API. /// In this scenario, the GraphQL server will make a GET request to the API endpoint specified when the `users` field is queried. pub struct Http { - /// This refers to the API endpoint you're going to call. For instance https://jsonplaceholder.typicode.com/users`. - /// - /// For dynamic segments in your API endpoint, use Mustache templates for variable substitution. For instance, to fetch a specific user, use `/users/{{args.id}}`. - pub path: String, - #[serde(default, skip_serializing_if = "is_default")] - /// This refers to the HTTP method of the API call. Commonly used methods include `GET`, `POST`, `PUT`, `DELETE` etc. @default `GET`. - pub method: Method, - #[serde(default, skip_serializing_if = "is_default")] - /// This represents the query parameters of your API call. You can pass it as a static object or use Mustache template for dynamic parameters. These parameters will be added to the URL. - pub query: KeyValues, - #[serde(default, skip_serializing_if = "is_default")] - /// Schema of the input of the API call. It is automatically inferred in most cases. - pub input: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// Schema of the output of the API call. It is automatically inferred in most cases. - pub output: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The body of the API call. It's used for methods like POST or PUT that send data to the server. 
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] @@ -505,37 +521,37 @@ impl Union { /// The path argument specifies the path of the REST API. /// In this scenario, the GraphQL server will make a GET request to the API endpoint specified when the `users` field is queried. pub struct Http { - /// This refers to the API endpoint you're going to call. For instance `https://jsonplaceholder.typicode.com/users`. - /// - /// For dynamic segments in your API endpoint, use Mustache templates for variable substitution. For instance, to fetch a specific user, use `/users/{{args.id}}`. - pub path: String, - #[serde(default, skip_serializing_if = "is_default")] - /// This refers to the HTTP method of the API call. Commonly used methods include `GET`, `POST`, `PUT`, `DELETE` etc. @default `GET`. - pub method: Method, - #[serde(default, skip_serializing_if = "is_default")] - /// This represents the query parameters of your API call. You can pass it as a static object or use a Mustache template for dynamic parameters. These parameters will be added to the URL. - pub query: KeyValues, - #[serde(default, skip_serializing_if = "is_default")] - /// Schema of the input of the API call. It is automatically inferred in most cases. - pub input: Option<JsonSchema>, - #[serde(default, skip_serializing_if = "is_default")] - /// Schema of the output of the API call. It is automatically inferred in most cases. - pub output: Option<JsonSchema>, - #[serde(default, skip_serializing_if = "is_default")] - /// The body of the API call. It's used for methods like POST or PUT that send data to the server. You can pass it as a static object or use a Mustache template to substitute variables from the GraphQL variables. - pub body: Option<String>, - #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] - /// This refers to the base URL of the API. If not specified, the default base URL is the one specified in the `@upstream` operator. - pub base_url: Option<String>, - #[serde(default, skip_serializing_if = "is_default")] - /// The `headers` parameter allows you to customize the headers of the HTTP request made by the `@http` operator. It is used by specifying a key-value map of header names and their values. - pub headers: KeyValues, - #[serde(rename = "groupBy", default, skip_serializing_if = "is_default")] - /// The `groupBy` parameter groups multiple data requests into a single call. For more details please refer to our [n + 1 guide](https://tailcall.run/docs/guides/n+1#solving-using-batching). - pub group_by: Vec<String>, - #[serde(default, skip_serializing_if = "is_default")] - /// The `encoding` parameter specifies the encoding of the request body. It can be `ApplicationJson` or `ApplicationXWwwFormUrlEncoded`. @default `ApplicationJson`. - pub encoding: Encoding, + /// This refers to the API endpoint you're going to call. For instance `https://jsonplaceholder.typicode.com/users`. + /// + /// For dynamic segments in your API endpoint, use Mustache templates for variable substitution. For instance, to fetch a specific user, use `/users/{{args.id}}`. + pub path: String, + #[serde(default, skip_serializing_if = "is_default")] + /// This refers to the HTTP method of the API call. Commonly used methods include `GET`, `POST`, `PUT`, `DELETE` etc. @default `GET`. + pub method: Method, + #[serde(default, skip_serializing_if = "is_default")] + /// This represents the query parameters of your API call. You can pass it as a static object or use a Mustache template for dynamic parameters. These parameters will be added to the URL. + pub query: KeyValues, + #[serde(default, skip_serializing_if = "is_default")] + /// Schema of the input of the API call. It is automatically inferred in most cases. + pub input: Option<JsonSchema>, + #[serde(default, skip_serializing_if = "is_default")] + /// Schema of the output of the API call. It is automatically inferred in most cases. + pub output: Option<JsonSchema>, + #[serde(default, skip_serializing_if = "is_default")] + /// The body of the API call. It's used for methods like POST or PUT that send data to the server. You can pass it as a static object or use a Mustache template to substitute variables from the GraphQL variables. + pub body: Option<String>, + #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] + /// This refers to the base URL of the API. If not specified, the default base URL is the one specified in the `@upstream` operator. + pub base_url: Option<String>, + #[serde(default, skip_serializing_if = "is_default")] + /// The `headers` parameter allows you to customize the headers of the HTTP request made by the `@http` operator. It is used by specifying a key-value map of header names and their values. + pub headers: KeyValues, + #[serde(rename = "groupBy", default, skip_serializing_if = "is_default")] + /// The `groupBy` parameter groups multiple data requests into a single call. For more details please refer to our [n + 1 guide](https://tailcall.run/docs/guides/n+1#solving-using-batching). + pub group_by: Vec<String>, + #[serde(default, skip_serializing_if = "is_default")] + /// The `encoding` parameter specifies the encoding of the request body. It can be `ApplicationJson` or `ApplicationXWwwFormUrlEncoded`. @default `ApplicationJson`. + pub encoding: Encoding, }
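To see these fields exercised end to end, a hedged sketch (the schema, upstream URL, and `groupBy` key are all illustrative) that parses an SDL carrying `@http` and inspects the resulting `Field`:

    let sdl = r#"
        schema @server(port: 8000) @upstream(baseURL: "https://jsonplaceholder.typicode.com") {
            query: Query
        }
        type User { id: Int name: String }
        type Query { users: [User] @http(path: "/users", groupBy: ["id"]) }
    "#;
    let config = Config::from_sdl(sdl).to_result().unwrap();
    let users = &config.types["Query"].fields["users"];
    assert!(users.has_resolver());
    assert!(users.has_batched_resolver()); // groupBy was picked up from the directive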
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] @@ -547,147 +563,157 @@ pub struct Http { /// The `method` argument specifies the name of the gRPC method. /// In this scenario, the GraphQL server will make a gRPC request to the gRPC endpoint specified when the `users` field is queried. pub struct Grpc { - /// This refers to the gRPC service you're going to call. For instance `NewsService`. - pub service: String, - /// This refers to the gRPC method you're going to call. For instance `GetAllNews`. - pub method: String, - #[serde(default, skip_serializing_if = "is_default")] - /// This refers to the arguments of your gRPC call. You can pass it as a static object or use a Mustache template for dynamic parameters. These parameters will be added in the body in `protobuf` format. - pub body: Option<String>, - #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] - /// This refers to the base URL of the API. If not specified, the default base URL is the one specified in the `@upstream` operator. - pub base_url: Option<String>, - #[serde(default, skip_serializing_if = "is_default")] - /// The `headers` parameter allows you to customize the headers of the HTTP request made by the `@grpc` operator. It is used by specifying a key-value map of header names and their values. Note: content-type is automatically set to application/grpc - pub headers: KeyValues, - /// The `protoPath` parameter allows you to specify the path to the proto file which contains service and method definitions and is used to encode and decode the request and response body. - pub proto_path: String, - #[serde(default, skip_serializing_if = "is_default")] - /// The key path in the response which should be used to group multiple requests. For instance `["news","id"]`. For more details please refer to our [n + 1 guide](https://tailcall.run/docs/guides/n+1#solving-using-batching). - pub group_by: Vec<String>, + /// This refers to the gRPC service you're going to call. For instance `NewsService`. + pub service: String, + /// This refers to the gRPC method you're going to call. For instance `GetAllNews`. + pub method: String, + #[serde(default, skip_serializing_if = "is_default")] + /// This refers to the arguments of your gRPC call. You can pass it as a static object or use a Mustache template for dynamic parameters. These parameters will be added in the body in `protobuf` format. + pub body: Option<String>, + #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] + /// This refers to the base URL of the API. If not specified, the default base URL is the one specified in the `@upstream` operator. + pub base_url: Option<String>, + #[serde(default, skip_serializing_if = "is_default")] + /// The `headers` parameter allows you to customize the headers of the HTTP request made by the `@grpc` operator. It is used by specifying a key-value map of header names and their values. Note: content-type is automatically set to application/grpc + pub headers: KeyValues, + /// The `protoPath` parameter allows you to specify the path to the proto file which contains service and method definitions and is used to encode and decode the request and response body. + pub proto_path: String, + #[serde(default, skip_serializing_if = "is_default")] + /// The key path in the response which should be used to group multiple requests. For instance `["news","id"]`. For more details please refer to our [n + 1 guide](https://tailcall.run/docs/guides/n+1#solving-using-batching). + pub group_by: Vec<String>, }
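Analogously for gRPC, a hedged sketch (the service, method, proto path, and upstream address are invented) of an SDL using `@grpc`:

    let sdl = r#"
        schema @upstream(baseURL: "http://localhost:50051") { query: Query }
        type News { id: Int title: String }
        type Query {
            news: [News] @grpc(service: "NewsService", method: "GetAllNews", protoPath: "news.proto")
        }
    "#;
    let config = Config::from_sdl(sdl).to_result().unwrap();
    assert!(config.types["Query"].fields["news"].has_resolver());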
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq, Eq, schemars::JsonSchema)] /// The @graphQL operator allows you to specify a GraphQL API server request to fetch data from. pub struct GraphQL { - /// Specifies the root field on the upstream to request data from. This maps a field in your schema to a field in the upstream schema. When a query is received for this field, Tailcall requests data from the corresponding upstream field. - pub name: String, - #[serde(default, skip_serializing_if = "is_default")] - /// Named arguments for the requested field. More info [here](https://tailcall.run/docs/guides/operators/#args) - pub args: Option<KeyValues>, - #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] - /// This refers to the base URL of the API. If not specified, the default base URL is the one specified in the `@upstream` operator. - pub base_url: Option<String>, - #[serde(default, skip_serializing_if = "is_default")] - /// The headers parameter allows you to customize the headers of the GraphQL request made by the `@graphQL` operator. It is used by specifying a key-value map of header names and their values. - pub headers: KeyValues, - #[serde(default, skip_serializing_if = "is_default")] - /// If the upstream GraphQL server supports request batching, you can specify the 'batch' argument to batch several requests into a single batch request. - /// - /// Make sure you have also specified batch settings to the `@upstream` and to the `@graphQL` operator. - pub batch: bool, + /// Specifies the root field on the upstream to request data from. This maps a field in your schema to a field in the upstream schema. When a query is received for this field, Tailcall requests data from the corresponding upstream field. + pub name: String, + #[serde(default, skip_serializing_if = "is_default")] + /// Named arguments for the requested field. More info [here](https://tailcall.run/docs/guides/operators/#args) + pub args: Option<KeyValues>, + #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] + /// This refers to the base URL of the API. If not specified, the default base URL is the one specified in the `@upstream` operator. + pub base_url: Option<String>, + #[serde(default, skip_serializing_if = "is_default")] + /// The headers parameter allows you to customize the headers of the GraphQL request made by the `@graphQL` operator. It is used by specifying a key-value map of header names and their values. + pub headers: KeyValues, + #[serde(default, skip_serializing_if = "is_default")] + /// If the upstream GraphQL server supports request batching, you can specify the 'batch' argument to batch several requests into a single batch request. + /// + /// Make sure you have also specified batch settings to the `@upstream` and to the `@graphQL` operator. + pub batch: bool, }
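A hedged sketch of the batching knob at the config level (upstream URL invented; as the doc comment above notes, actually batching requests also requires batch settings on `@upstream`):

    let sdl = r#"
        schema @upstream(baseURL: "https://example.com/graphql") { query: Query }
        type User { id: Int }
        type Query { user: User @graphQL(name: "user", batch: true) }
    "#;
    let config = Config::from_sdl(sdl).to_result().unwrap();
    assert!(config.types["Query"].fields["user"].graphql.as_ref().unwrap().batch);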
#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum GraphQLOperationType { - #[default] - Query, - Mutation, + #[default] + Query, + Mutation, } impl Display for GraphQLOperationType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match self { - Self::Query => "query", - Self::Mutation => "mutation", - }) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + Self::Query => "query", + Self::Mutation => "mutation", + }) + } } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] /// The `@const` operator allows us to embed a constant response for the schema. pub struct Const { - pub data: Value, + pub data: Value, } #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] /// The @addField operator simplifies data structures and queries by adding a field that inlines or flattens a nested field or node within your schema. More info [here](https://tailcall.run/docs/guides/operators/#addfield) pub struct AddField { - /// Name of the new field to be added - pub name: String, - /// Path of the data where the field should point to - pub path: Vec<String>, + /// Name of the new field to be added + pub name: String, + /// Path of the data where the field should point to + pub path: Vec<String>, } impl Config { - pub fn from_json(json: &str) -> Result<Self> { - Ok(serde_json::from_str(json)?) - } + pub fn from_json(json: &str) -> Result<Self> { + Ok(serde_json::from_str(json)?) + } - pub fn from_yaml(yaml: &str) -> Result<Self> { - Ok(serde_yaml::from_str(yaml)?) - } + pub fn from_yaml(yaml: &str) -> Result<Self> { + Ok(serde_yaml::from_str(yaml)?) + } - pub fn from_sdl(sdl: &str) -> Valid<Config, String> { - let doc = async_graphql::parser::parse_schema(sdl); - match doc { - Ok(doc) => from_document(doc), - Err(e) => Valid::fail(e.to_string()), + pub fn from_sdl(sdl: &str) -> Valid<Config, String> { + let doc = async_graphql::parser::parse_schema(sdl); + match doc { + Ok(doc) => from_document(doc), + Err(e) => Valid::fail(e.to_string()), + } } - } - pub fn from_source(source: Source, schema: &str) -> Result<Self> { - match source { - Source::GraphQL => Ok(Config::from_sdl(schema).to_result()?), - Source::Json => Ok(Config::from_json(schema)?), - Source::Yml => Ok(Config::from_yaml(schema)?), + pub fn from_source(source: Source, schema: &str) -> Result<Self> { + match source { + Source::GraphQL => Ok(Config::from_sdl(schema).to_result()?), + Source::Json => Ok(Config::from_json(schema)?), + Source::Yml => Ok(Config::from_yaml(schema)?), + } } - } - pub fn n_plus_one(&self) -> Vec<Vec<(String, String)>> { - super::n_plus_one::n_plus_one(self) - } + pub fn n_plus_one(&self) -> Vec<Vec<(String, String)>> { + super::n_plus_one::n_plus_one(self) + } } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default, schemars::JsonSchema)] +#[derive( + Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default, schemars::JsonSchema, +)] pub enum Encoding { - #[default] - ApplicationJson, - ApplicationXWwwFormUrlencoded, + #[default] + ApplicationJson, + ApplicationXWwwFormUrlencoded, }
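`from_source` is just a dispatcher over the three input formats; a minimal sketch (the one-type schema is arbitrary):

    let config = Config::from_source(Source::GraphQL, "type Foo {a: Int}").unwrap();
    assert!(config.types.contains_key("Foo"));
    // The same document could equally arrive as JSON or YAML via Source::Json / Source::Yml.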
#[cfg(test)] mod tests { - use pretty_assertions::assert_eq; + use pretty_assertions::assert_eq; - use super::*; + use super::*; - #[test] - fn test_field_has_or_not_batch_resolver() { - let f1 = Field { ..Default::default() }; + #[test] + fn test_field_has_or_not_batch_resolver() { + let f1 = Field { ..Default::default() }; - let f2 = - Field { http: Some(Http { group_by: vec!["id".to_string()], ..Default::default() }), ..Default::default() }; + let f2 = Field { + http: Some(Http { group_by: vec!["id".to_string()], ..Default::default() }), + ..Default::default() + }; - let f3 = Field { http: Some(Http { group_by: vec![], ..Default::default() }), ..Default::default() }; + let f3 = Field { + http: Some(Http { group_by: vec![], ..Default::default() }), + ..Default::default() + }; - assert!(!f1.has_batched_resolver()); - assert!(f2.has_batched_resolver()); - assert!(!f3.has_batched_resolver()); - } + assert!(!f1.has_batched_resolver()); + assert!(f2.has_batched_resolver()); + assert!(!f3.has_batched_resolver()); + } - #[test] - fn test_graphql_directive_name() { - let name = GraphQL::directive_name(); - assert_eq!(name, "graphQL"); - } + #[test] + fn test_graphql_directive_name() { + let name = GraphQL::directive_name(); + assert_eq!(name, "graphQL"); + } - #[test] - fn test_from_sdl_empty() { - let actual = Config::from_sdl("type Foo {a: Int}").to_result().unwrap(); - let expected = Config::default().types(vec![("Foo", Type::default().fields(vec![("a", Field::int())]))]); - assert_eq!(actual, expected); - } + #[test] + fn test_from_sdl_empty() { + let actual = Config::from_sdl("type Foo {a: Int}").to_result().unwrap(); + let expected = Config::default().types(vec![( + "Foo", + Type::default().fields(vec![("a", Field::int())]), + )]); + assert_eq!(actual, expected); + } } diff --git a/src/config/expr.rs b/src/config/expr.rs index 7727a26a142..5f3531b9b26 100644 --- a/src/config/expr.rs +++ b/src/config/expr.rs @@ -6,156 +6,167 @@ use super::{GraphQL, Grpc, Http}; #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, schemars::JsonSchema)] /// Allows composing operators as simple expressions pub struct Expr { - /// Root of the expression AST - pub body: ExprBody, + /// Root of the expression AST + pub body: ExprBody, }
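`ExprBody` (below) relies on serde's default externally tagged enum representation, so each variant's rename is the JSON key a config author writes. A minimal sketch:

    // {"const": 42} deserializes straight into ExprBody::Const.
    let e: ExprBody = serde_json::from_str(r#"{ "const": 42 }"#).unwrap();
    assert_eq!(e, ExprBody::Const(serde_json::json!(42)));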
"else")] + on_false: Box, + }, + #[serde(rename = "and")] + And(Vec), + #[serde(rename = "or")] + Or(Vec), + #[serde(rename = "cond")] + Cond(Box, Vec<(Box, Box)>), + #[serde(rename = "defaultTo")] + DefaultTo(Box, Box), + #[serde(rename = "isEmpty")] + IsEmpty(Box), + #[serde(rename = "not")] + Not(Box), - // List - #[serde(rename = "concat")] - Concat(Vec), + // List + #[serde(rename = "concat")] + Concat(Vec), - // Relation - #[serde(rename = "intersection")] - Intersection(Vec), - #[serde(rename = "difference")] - Difference(Vec, Vec), - #[serde(rename = "eq")] - Equals(Box, Box), - #[serde(rename = "gt")] - Gt(Box, Box), - #[serde(rename = "gte")] - Gte(Box, Box), - #[serde(rename = "lt")] - Lt(Box, Box), - #[serde(rename = "lte")] - Lte(Box, Box), - #[serde(rename = "max")] - Max(Vec), - #[serde(rename = "min")] - Min(Vec), - #[serde(rename = "pathEq")] - PathEq(Box, Vec, Box), - #[serde(rename = "propEq")] - PropEq(Box, String, Box), - #[serde(rename = "sortPath")] - SortPath(Box, Vec), - #[serde(rename = "symmetricDifference")] - SymmetricDifference(Vec, Vec), - #[serde(rename = "union")] - Union(Vec, Vec), + // Relation + #[serde(rename = "intersection")] + Intersection(Vec), + #[serde(rename = "difference")] + Difference(Vec, Vec), + #[serde(rename = "eq")] + Equals(Box, Box), + #[serde(rename = "gt")] + Gt(Box, Box), + #[serde(rename = "gte")] + Gte(Box, Box), + #[serde(rename = "lt")] + Lt(Box, Box), + #[serde(rename = "lte")] + Lte(Box, Box), + #[serde(rename = "max")] + Max(Vec), + #[serde(rename = "min")] + Min(Vec), + #[serde(rename = "pathEq")] + PathEq(Box, Vec, Box), + #[serde(rename = "propEq")] + PropEq(Box, String, Box), + #[serde(rename = "sortPath")] + SortPath(Box, Vec), + #[serde(rename = "symmetricDifference")] + SymmetricDifference(Vec, Vec), + #[serde(rename = "union")] + Union(Vec, Vec), - // Math - #[serde(rename = "mod")] - Mod(Box, Box), - #[serde(rename = "add")] - Add(Box, Box), - #[serde(rename = "dec")] - Dec(Box), - #[serde(rename = "divide")] - Divide(Box, Box), - #[serde(rename = "inc")] - Inc(Box), - #[serde(rename = "multiply")] - Multiply(Box, Box), - #[serde(rename = "negate")] - Negate(Box), - #[serde(rename = "product")] - Product(Vec), - #[serde(rename = "subtract")] - Subtract(Box, Box), - #[serde(rename = "sum")] - Sum(Vec), + // Math + #[serde(rename = "mod")] + Mod(Box, Box), + #[serde(rename = "add")] + Add(Box, Box), + #[serde(rename = "dec")] + Dec(Box), + #[serde(rename = "divide")] + Divide(Box, Box), + #[serde(rename = "inc")] + Inc(Box), + #[serde(rename = "multiply")] + Multiply(Box, Box), + #[serde(rename = "negate")] + Negate(Box), + #[serde(rename = "product")] + Product(Vec), + #[serde(rename = "subtract")] + Subtract(Box, Box), + #[serde(rename = "sum")] + Sum(Vec), } impl ExprBody { - /// - /// Performs a deep check on if the expression has any IO. 
impl ExprBody { - /// - /// Performs a deep check to determine whether the expression has any IO. - /// - pub fn has_io(&self) -> bool { - match self { - ExprBody::Http(_) => true, - ExprBody::Grpc(_) => true, - ExprBody::GraphQL(_) => true, - ExprBody::Const(_) => false, - ExprBody::If { cond, on_true, on_false } => cond.has_io() || on_true.has_io() || on_false.has_io(), - ExprBody::And(l) => l.iter().any(|e| e.has_io()), - ExprBody::Or(l) => l.iter().any(|e| e.has_io()), - ExprBody::Cond(default, branches) => { - default.has_io() || branches.iter().any(|(cond, expr)| cond.has_io() || expr.has_io()) - } - ExprBody::DefaultTo(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::IsEmpty(expr) => expr.has_io(), - ExprBody::Not(expr) => expr.has_io(), - ExprBody::Concat(l) => l.iter().any(|e| e.has_io()), - ExprBody::Intersection(l) => l.iter().any(|e| e.has_io()), - ExprBody::Mod(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Add(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Dec(expr) => expr.has_io(), - ExprBody::Divide(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Inc(expr) => expr.has_io(), - ExprBody::Multiply(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Negate(expr) => expr.has_io(), - ExprBody::Product(l) => l.iter().any(|e| e.has_io()), - ExprBody::Subtract(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Sum(l) => l.iter().any(|e| e.has_io()), - ExprBody::Difference(l1, l2) => l1.iter().any(|e| e.has_io()) || l2.iter().any(|e| e.has_io()), - ExprBody::Equals(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Gt(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Gte(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Lt(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Lte(expr1, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::Max(l) => l.iter().any(|e| e.has_io()), - ExprBody::Min(l) => l.iter().any(|e| e.has_io()), - ExprBody::PathEq(expr1, _, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::PropEq(expr1, _, expr2) => expr1.has_io() || expr2.has_io(), - ExprBody::SortPath(l, _) => l.has_io(), - ExprBody::SymmetricDifference(l1, l2) => l1.iter().any(|e| e.has_io()) || l2.iter().any(|e| e.has_io()), - ExprBody::Union(l1, l2) => l1.iter().any(|e| e.has_io()) || l2.iter().any(|e| e.has_io()), + /// + /// Performs a deep check to determine whether the expression has any IO. 
+ /// + pub fn has_io(&self) -> bool { + match self { + ExprBody::Http(_) => true, + ExprBody::Grpc(_) => true, + ExprBody::GraphQL(_) => true, + ExprBody::Const(_) => false, + ExprBody::If { cond, on_true, on_false } => { + cond.has_io() || on_true.has_io() || on_false.has_io() + } + ExprBody::And(l) => l.iter().any(|e| e.has_io()), + ExprBody::Or(l) => l.iter().any(|e| e.has_io()), + ExprBody::Cond(default, branches) => { + default.has_io() + || branches + .iter() + .any(|(cond, expr)| cond.has_io() || expr.has_io()) + } + ExprBody::DefaultTo(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::IsEmpty(expr) => expr.has_io(), + ExprBody::Not(expr) => expr.has_io(), + ExprBody::Concat(l) => l.iter().any(|e| e.has_io()), + ExprBody::Intersection(l) => l.iter().any(|e| e.has_io()), + ExprBody::Mod(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Add(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Dec(expr) => expr.has_io(), + ExprBody::Divide(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Inc(expr) => expr.has_io(), + ExprBody::Multiply(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Negate(expr) => expr.has_io(), + ExprBody::Product(l) => l.iter().any(|e| e.has_io()), + ExprBody::Subtract(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Sum(l) => l.iter().any(|e| e.has_io()), + ExprBody::Difference(l1, l2) => { + l1.iter().any(|e| e.has_io()) || l2.iter().any(|e| e.has_io()) + } + ExprBody::Equals(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Gt(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Gte(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Lt(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Lte(expr1, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::Max(l) => l.iter().any(|e| e.has_io()), + ExprBody::Min(l) => l.iter().any(|e| e.has_io()), + ExprBody::PathEq(expr1, _, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::PropEq(expr1, _, expr2) => expr1.has_io() || expr2.has_io(), + ExprBody::SortPath(l, _) => l.has_io(), + ExprBody::SymmetricDifference(l1, l2) => { + l1.iter().any(|e| e.has_io()) || l2.iter().any(|e| e.has_io()) + } + ExprBody::Union(l1, l2) => { + l1.iter().any(|e| e.has_io()) || l2.iter().any(|e| e.has_io()) + } + } } - } } diff --git a/src/config/from_document.rs b/src/config/from_document.rs index c3eea04928f..191772a2bc9 100644 --- a/src/config/from_document.rs +++ b/src/config/from_document.rs @@ -1,383 +1,442 @@ use std::collections::BTreeMap; use async_graphql::parser::types::{ - BaseType, ConstDirective, EnumType, FieldDefinition, InputObjectType, InputValueDefinition, InterfaceType, - ObjectType, SchemaDefinition, ServiceDocument, Type, TypeDefinition, TypeKind, TypeSystemDefinition, UnionType, + BaseType, ConstDirective, EnumType, FieldDefinition, InputObjectType, InputValueDefinition, + InterfaceType, ObjectType, SchemaDefinition, ServiceDocument, Type, TypeDefinition, TypeKind, + TypeSystemDefinition, UnionType, }; use async_graphql::parser::Positioned; use async_graphql::Name; use super::JS; -use crate::config::{self, Cache, Config, Expr, GraphQL, Grpc, Modify, Omit, RootSchema, Server, Union, Upstream}; +use crate::config::{ + self, Cache, Config, Expr, GraphQL, Grpc, Modify, Omit, RootSchema, Server, Union, Upstream, +}; use crate::directive::DirectiveCodec; use crate::valid::Valid; -const DEFAULT_SCHEMA_DEFINITION: &SchemaDefinition = - &SchemaDefinition { extend: false, directives: Vec::new(), 
query: None, mutation: None, subscription: None }; +const DEFAULT_SCHEMA_DEFINITION: &SchemaDefinition = &SchemaDefinition { + extend: false, + directives: Vec::new(), + query: None, + mutation: None, + subscription: None, +}; pub fn from_document(doc: ServiceDocument) -> Valid { - let type_definitions: Vec<_> = doc - .definitions - .iter() - .filter_map(|def| match def { - TypeSystemDefinition::Type(td) => Some(td), - _ => None, - }) - .collect(); + let type_definitions: Vec<_> = doc + .definitions + .iter() + .filter_map(|def| match def { + TypeSystemDefinition::Type(td) => Some(td), + _ => None, + }) + .collect(); - let types = to_types(&type_definitions); - let unions = to_union_types(&type_definitions); - let schema = schema_definition(&doc).map(to_root_schema); + let types = to_types(&type_definitions); + let unions = to_union_types(&type_definitions); + let schema = schema_definition(&doc).map(to_root_schema); - schema_definition(&doc) - .and_then(|sd| server(sd).zip(upstream(sd)).zip(types).zip(unions).zip(schema)) - .map(|((((server, upstream), types), unions), schema)| Config { server, upstream, types, unions, schema }) + schema_definition(&doc) + .and_then(|sd| { + server(sd) + .zip(upstream(sd)) + .zip(types) + .zip(unions) + .zip(schema) + }) + .map(|((((server, upstream), types), unions), schema)| Config { + server, + upstream, + types, + unions, + schema, + }) } fn schema_definition(doc: &ServiceDocument) -> Valid<&SchemaDefinition, String> { - doc - .definitions - .iter() - .find_map(|def| match def { - TypeSystemDefinition::Schema(schema_definition) => Some(&schema_definition.node), - _ => None, - }) - .map_or_else(|| Valid::succeed(DEFAULT_SCHEMA_DEFINITION), Valid::succeed) + doc.definitions + .iter() + .find_map(|def| match def { + TypeSystemDefinition::Schema(schema_definition) => Some(&schema_definition.node), + _ => None, + }) + .map_or_else(|| Valid::succeed(DEFAULT_SCHEMA_DEFINITION), Valid::succeed) } fn process_schema_directives + Default>( - schema_definition: &SchemaDefinition, - directive_name: &str, + schema_definition: &SchemaDefinition, + directive_name: &str, ) -> Valid { - let mut res = Valid::succeed(T::default()); - for directive in schema_definition.directives.iter() { - if directive.node.name.node.as_ref() == directive_name { - res = T::from_directive(&directive.node); + let mut res = Valid::succeed(T::default()); + for directive in schema_definition.directives.iter() { + if directive.node.name.node.as_ref() == directive_name { + res = T::from_directive(&directive.node); + } } - } - res + res } fn server(schema_definition: &SchemaDefinition) -> Valid { - process_schema_directives(schema_definition, config::Server::directive_name().as_str()) + process_schema_directives(schema_definition, config::Server::directive_name().as_str()) } fn upstream(schema_definition: &SchemaDefinition) -> Valid { - process_schema_directives(schema_definition, config::Upstream::directive_name().as_str()) + process_schema_directives( + schema_definition, + config::Upstream::directive_name().as_str(), + ) } fn to_root_schema(schema_definition: &SchemaDefinition) -> RootSchema { - let query = schema_definition.query.as_ref().map(pos_name_to_string); - let mutation = schema_definition.mutation.as_ref().map(pos_name_to_string); - let subscription = schema_definition.subscription.as_ref().map(pos_name_to_string); + let query = schema_definition.query.as_ref().map(pos_name_to_string); + let mutation = schema_definition.mutation.as_ref().map(pos_name_to_string); + let subscription = 
schema_definition + .subscription + .as_ref() + .map(pos_name_to_string); - RootSchema { query, mutation, subscription } + RootSchema { query, mutation, subscription } } fn pos_name_to_string(pos: &Positioned) -> String { - pos.node.to_string() -} -fn to_types(type_definitions: &Vec<&Positioned>) -> Valid, String> { - Valid::from_iter(type_definitions, |type_definition| { - let type_name = pos_name_to_string(&type_definition.node.name); - let directives = &type_definition.node.directives; - Cache::from_directives(directives.iter()) - .and_then(|cache| match type_definition.node.kind.clone() { - TypeKind::Object(object_type) => to_object_type( - &object_type, - &type_definition.node.description, - &type_definition.node.directives, - cache, - ) - .some(), - TypeKind::Interface(interface_type) => to_object_type( - &interface_type, - &type_definition.node.description, - &type_definition.node.directives, - cache, + pos.node.to_string() +} +fn to_types( + type_definitions: &Vec<&Positioned>, +) -> Valid, String> { + Valid::from_iter(type_definitions, |type_definition| { + let type_name = pos_name_to_string(&type_definition.node.name); + let directives = &type_definition.node.directives; + Cache::from_directives(directives.iter()) + .and_then(|cache| match type_definition.node.kind.clone() { + TypeKind::Object(object_type) => to_object_type( + &object_type, + &type_definition.node.description, + &type_definition.node.directives, + cache, + ) + .some(), + TypeKind::Interface(interface_type) => to_object_type( + &interface_type, + &type_definition.node.description, + &type_definition.node.directives, + cache, + ) + .some(), + TypeKind::Enum(enum_type) => Valid::succeed(Some(to_enum(enum_type))), + TypeKind::InputObject(input_object_type) => { + to_input_object(input_object_type, cache).some() + } + TypeKind::Union(_) => Valid::none(), + TypeKind::Scalar => Valid::succeed(Some(to_scalar_type())), + }) + .map(|option| (type_name, option)) + }) + .map(|vec| { + BTreeMap::from_iter( + vec.into_iter() + .filter_map(|(name, option)| option.map(|tpe| (name, tpe))), ) - .some(), - TypeKind::Enum(enum_type) => Valid::succeed(Some(to_enum(enum_type))), - TypeKind::InputObject(input_object_type) => to_input_object(input_object_type, cache).some(), - TypeKind::Union(_) => Valid::none(), - TypeKind::Scalar => Valid::succeed(Some(to_scalar_type())), - }) - .map(|option| (type_name, option)) - }) - .map(|vec| { - BTreeMap::from_iter( - vec - .into_iter() - .filter_map(|(name, option)| option.map(|tpe| (name, tpe))), - ) - }) + }) } fn to_scalar_type() -> config::Type { - config::Type { scalar: true, ..Default::default() } -} -fn to_union_types(type_definitions: &Vec<&Positioned>) -> Valid, String> { - let mut unions = BTreeMap::new(); - for type_definition in type_definitions { - let type_name = pos_name_to_string(&type_definition.node.name); - let type_opt = match type_definition.node.kind.clone() { - TypeKind::Union(union_type) => to_union( - union_type, - &type_definition.node.description.to_owned().map(|pos| pos.node), - ), - _ => continue, - }; - unions.insert(type_name, type_opt); - } + config::Type { scalar: true, ..Default::default() } +} +fn to_union_types( + type_definitions: &Vec<&Positioned>, +) -> Valid, String> { + let mut unions = BTreeMap::new(); + for type_definition in type_definitions { + let type_name = pos_name_to_string(&type_definition.node.name); + let type_opt = match type_definition.node.kind.clone() { + TypeKind::Union(union_type) => to_union( + union_type, + &type_definition + .node + 
.description + .to_owned() + .map(|pos| pos.node), + ), + _ => continue, + }; + unions.insert(type_name, type_opt); + } - Valid::succeed(unions) + Valid::succeed(unions) } #[allow(clippy::too_many_arguments)] fn to_object_type( - object: &T, - description: &Option>, - directives: &[Positioned], - cache: Option, + object: &T, + description: &Option>, + directives: &[Positioned], + cache: Option, ) -> Valid where - T: ObjectLike, + T: ObjectLike, { - let fields = object.fields(); - let implements = object.implements(); - let interface = object.is_interface(); + let fields = object.fields(); + let implements = object.implements(); + let interface = object.is_interface(); - to_fields(fields, cache).map(|fields| { - let doc = description.to_owned().map(|pos| pos.node); - let implements = implements.iter().map(|pos| pos.node.to_string()).collect(); - let added_fields = to_add_fields_from_directives(directives); - config::Type { fields, added_fields, doc, interface, implements, ..Default::default() } - }) + to_fields(fields, cache).map(|fields| { + let doc = description.to_owned().map(|pos| pos.node); + let implements = implements.iter().map(|pos| pos.node.to_string()).collect(); + let added_fields = to_add_fields_from_directives(directives); + config::Type { + fields, + added_fields, + doc, + interface, + implements, + ..Default::default() + } + }) } fn to_enum(enum_type: EnumType) -> config::Type { - let variants = enum_type - .values - .iter() - .map(|value| value.node.value.to_string()) - .collect(); - config::Type { variants: Some(variants), ..Default::default() } -} -fn to_input_object(input_object_type: InputObjectType, cache: Option) -> Valid { - to_input_object_fields(&input_object_type.fields, cache).map(|fields| config::Type { fields, ..Default::default() }) + let variants = enum_type + .values + .iter() + .map(|value| value.node.value.to_string()) + .collect(); + config::Type { variants: Some(variants), ..Default::default() } +} +fn to_input_object( + input_object_type: InputObjectType, + cache: Option, +) -> Valid { + to_input_object_fields(&input_object_type.fields, cache) + .map(|fields| config::Type { fields, ..Default::default() }) } fn to_fields_inner( - fields: &Vec>, - cache: Option, - transform: F, + fields: &Vec>, + cache: Option, + transform: F, ) -> Valid, String> where - F: Fn(&T, Option) -> Valid, - T: HasName, + F: Fn(&T, Option) -> Valid, + T: HasName, { - Valid::from_iter(fields, |field| { - let field_name = pos_name_to_string(field.node.name()); - transform(&field.node, cache.clone()).map(|field| (field_name, field)) - }) - .map(BTreeMap::from_iter) + Valid::from_iter(fields, |field| { + let field_name = pos_name_to_string(field.node.name()); + transform(&field.node, cache.clone()).map(|field| (field_name, field)) + }) + .map(BTreeMap::from_iter) } fn to_fields( - fields: &Vec>, - cache: Option, + fields: &Vec>, + cache: Option, ) -> Valid, String> { - to_fields_inner(fields, cache, to_field) + to_fields_inner(fields, cache, to_field) } fn to_input_object_fields( - input_object_fields: &Vec>, - cache: Option, + input_object_fields: &Vec>, + cache: Option, ) -> Valid, String> { - to_fields_inner(input_object_fields, cache, to_input_object_field) + to_fields_inner(input_object_fields, cache, to_input_object_field) } -fn to_field(field_definition: &FieldDefinition, cache: Option) -> Valid { - to_common_field(field_definition, to_args(field_definition), cache) +fn to_field( + field_definition: &FieldDefinition, + cache: Option, +) -> Valid { + 
to_common_field(field_definition, to_args(field_definition), cache) } fn to_input_object_field( - field_definition: &InputValueDefinition, - cache: Option, + field_definition: &InputValueDefinition, + cache: Option, ) -> Valid { - to_common_field(field_definition, BTreeMap::new(), cache) + to_common_field(field_definition, BTreeMap::new(), cache) } fn to_common_field( - field: &F, - args: BTreeMap, - parent_cache: Option, + field: &F, + args: BTreeMap, + parent_cache: Option, ) -> Valid where - F: Fieldlike, + F: Fieldlike, { - let type_of = field.type_of(); - let base = &type_of.base; - let nullable = &type_of.nullable; - let description = field.description(); - let directives = field.directives(); + let type_of = field.type_of(); + let base = &type_of.base; + let nullable = &type_of.nullable; + let description = field.description(); + let directives = field.directives(); - let type_of = to_type_of(type_of); - let list = matches!(&base, BaseType::List(_)); - let list_type_required = matches!(&base, BaseType::List(type_of) if !type_of.nullable); - let doc = description.to_owned().map(|pos| pos.node); - config::Http::from_directives(directives.iter()) - .zip(GraphQL::from_directives(directives.iter())) - .zip(Cache::from_directives(directives.iter())) - .zip(Grpc::from_directives(directives.iter())) - .zip(Expr::from_directives(directives.iter())) - .zip(Omit::from_directives(directives.iter())) - .zip(Modify::from_directives(directives.iter())) - .zip(JS::from_directives(directives.iter())) - .map(|(((((((http, graphql), cache), grpc), expr), omit), modify), script)| { - let const_field = to_const_field(directives); - config::Field { - type_of, - list, - required: !nullable, - list_type_required, - args, - doc, - modify, - omit, - http, - grpc, - script, - const_field, - graphql, - expr, - cache: cache.or(parent_cache), - } - }) + let type_of = to_type_of(type_of); + let list = matches!(&base, BaseType::List(_)); + let list_type_required = matches!(&base, BaseType::List(type_of) if !type_of.nullable); + let doc = description.to_owned().map(|pos| pos.node); + config::Http::from_directives(directives.iter()) + .zip(GraphQL::from_directives(directives.iter())) + .zip(Cache::from_directives(directives.iter())) + .zip(Grpc::from_directives(directives.iter())) + .zip(Expr::from_directives(directives.iter())) + .zip(Omit::from_directives(directives.iter())) + .zip(Modify::from_directives(directives.iter())) + .zip(JS::from_directives(directives.iter())) + .map( + |(((((((http, graphql), cache), grpc), expr), omit), modify), script)| { + let const_field = to_const_field(directives); + config::Field { + type_of, + list, + required: !nullable, + list_type_required, + args, + doc, + modify, + omit, + http, + grpc, + script, + const_field, + graphql, + expr, + cache: cache.or(parent_cache), + } + }, + ) } fn to_type_of(type_: &Type) -> String { - match &type_.base { - BaseType::Named(name) => name.to_string(), - BaseType::List(ty) => to_type_of(ty), - } + match &type_.base { + BaseType::Named(name) => name.to_string(), + BaseType::List(ty) => to_type_of(ty), + } } fn to_args(field_definition: &FieldDefinition) -> BTreeMap { - let mut args: BTreeMap = BTreeMap::new(); + let mut args: BTreeMap = BTreeMap::new(); - for arg in field_definition.arguments.iter() { - let arg_name = pos_name_to_string(&arg.node.name); - let arg_val = to_arg(&arg.node); - args.insert(arg_name, arg_val); - } + for arg in field_definition.arguments.iter() { + let arg_name = pos_name_to_string(&arg.node.name); + let arg_val = 
to_arg(&arg.node); + args.insert(arg_name, arg_val); + } - args + args } fn to_arg(input_value_definition: &InputValueDefinition) -> config::Arg { - let type_of = to_type_of(&input_value_definition.ty.node); - let list = matches!(&input_value_definition.ty.node.base, BaseType::List(_)); - let required = !input_value_definition.ty.node.nullable; - let doc = input_value_definition.description.to_owned().map(|pos| pos.node); - let modify = Modify::from_directives(input_value_definition.directives.iter()) - .to_result() - .ok() - .flatten(); - let default_value = if let Some(pos) = input_value_definition.default_value.as_ref() { - let value = &pos.node; - serde_json::to_value(value).ok() - } else { - None - }; - config::Arg { type_of, list, required, doc, modify, default_value } + let type_of = to_type_of(&input_value_definition.ty.node); + let list = matches!(&input_value_definition.ty.node.base, BaseType::List(_)); + let required = !input_value_definition.ty.node.nullable; + let doc = input_value_definition + .description + .to_owned() + .map(|pos| pos.node); + let modify = Modify::from_directives(input_value_definition.directives.iter()) + .to_result() + .ok() + .flatten(); + let default_value = if let Some(pos) = input_value_definition.default_value.as_ref() { + let value = &pos.node; + serde_json::to_value(value).ok() + } else { + None + }; + config::Arg { type_of, list, required, doc, modify, default_value } } fn to_union(union_type: UnionType, doc: &Option) -> Union { - let types = union_type - .members - .iter() - .map(|member| member.node.to_string()) - .collect(); - Union { types, doc: doc.clone() } + let types = union_type + .members + .iter() + .map(|member| member.node.to_string()) + .collect(); + Union { types, doc: doc.clone() } } fn to_const_field(directives: &[Positioned]) -> Option { - directives.iter().find_map(|directive| { - if directive.node.name.node == config::Const::directive_name() { - config::Const::from_directive(&directive.node).to_result().ok() - } else { - None - } - }) + directives.iter().find_map(|directive| { + if directive.node.name.node == config::Const::directive_name() { + config::Const::from_directive(&directive.node) + .to_result() + .ok() + } else { + None + } + }) } -fn to_add_fields_from_directives(directives: &[Positioned]) -> Vec { - directives - .iter() - .filter_map(|directive| { - if directive.node.name.node == config::AddField::directive_name() { - config::AddField::from_directive(&directive.node).to_result().ok() - } else { - None - } - }) - .collect::>() +fn to_add_fields_from_directives( + directives: &[Positioned], +) -> Vec { + directives + .iter() + .filter_map(|directive| { + if directive.node.name.node == config::AddField::directive_name() { + config::AddField::from_directive(&directive.node) + .to_result() + .ok() + } else { + None + } + }) + .collect::>() } trait HasName { - fn name(&self) -> &Positioned; + fn name(&self) -> &Positioned; } impl HasName for FieldDefinition { - fn name(&self) -> &Positioned { - &self.name - } + fn name(&self) -> &Positioned { + &self.name + } } impl HasName for InputValueDefinition { - fn name(&self) -> &Positioned { - &self.name - } + fn name(&self) -> &Positioned { + &self.name + } } trait Fieldlike { - fn type_of(&self) -> &Type; - fn description(&self) -> &Option>; - fn directives(&self) -> &[Positioned]; + fn type_of(&self) -> &Type; + fn description(&self) -> &Option>; + fn directives(&self) -> &[Positioned]; } impl Fieldlike for FieldDefinition { - fn type_of(&self) -> &Type { - &self.ty.node - } 
- fn description(&self) -> &Option<Positioned<String>> { - &self.description - } - fn directives(&self) -> &[Positioned<ConstDirective>] { - &self.directives - } + fn type_of(&self) -> &Type { + &self.ty.node + } + fn description(&self) -> &Option<Positioned<String>> { + &self.description + } + fn directives(&self) -> &[Positioned<ConstDirective>] { + &self.directives + } } impl Fieldlike for InputValueDefinition { - fn type_of(&self) -> &Type { - &self.ty.node - } - fn description(&self) -> &Option<Positioned<String>> { - &self.description - } - fn directives(&self) -> &[Positioned<ConstDirective>] { - &self.directives - } + fn type_of(&self) -> &Type { + &self.ty.node + } + fn description(&self) -> &Option<Positioned<String>> { + &self.description + } + fn directives(&self) -> &[Positioned<ConstDirective>] { + &self.directives + } } trait ObjectLike { - fn fields(&self) -> &Vec<Positioned<FieldDefinition>>; - fn implements(&self) -> &Vec<Positioned<Name>>; - fn is_interface(&self) -> bool; + fn fields(&self) -> &Vec<Positioned<FieldDefinition>>; + fn implements(&self) -> &Vec<Positioned<Name>>; + fn is_interface(&self) -> bool; } impl ObjectLike for ObjectType { - fn fields(&self) -> &Vec<Positioned<FieldDefinition>> { - &self.fields - } - fn implements(&self) -> &Vec<Positioned<Name>> { - &self.implements - } - fn is_interface(&self) -> bool { - false - } + fn fields(&self) -> &Vec<Positioned<FieldDefinition>> { + &self.fields + } + fn implements(&self) -> &Vec<Positioned<Name>> { + &self.implements + } + fn is_interface(&self) -> bool { + false + } } impl ObjectLike for InterfaceType { - fn fields(&self) -> &Vec<Positioned<FieldDefinition>> { - &self.fields - } - fn implements(&self) -> &Vec<Positioned<Name>> { - &self.implements - } - fn is_interface(&self) -> bool { - true - } + fn fields(&self) -> &Vec<Positioned<FieldDefinition>> { + &self.fields + } + fn implements(&self) -> &Vec<Positioned<Name>> { + &self.implements + } + fn is_interface(&self) -> bool { + true + } }
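`from_document` is exercised indirectly through `Config::from_sdl`; a hedged sketch (field and argument names invented) showing that argument defaults and nullability survive the SDL-to-`Config` pass:

    let sdl = r#"
        schema { query: Query }
        type User { id: Int }
        type Query { user(id: Int = 1): User @http(path: "/users/{{args.id}}") }
    "#;
    let config = Config::from_sdl(sdl).to_result().unwrap();
    let arg = &config.types["Query"].fields["user"].args["id"];
    assert!(arg.default_value.is_some()); // the `= 1` default was captured
    assert!(!arg.required); // `Int` without `!` stays nullable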
diff --git a/src/config/group_by.rs b/src/config/group_by.rs index 1705a58438a..e2064114a58 100644 --- a/src/config/group_by.rs +++ b/src/config/group_by.rs @@ -4,31 +4,31 @@ use crate::config::is_default; #[derive(Clone, Debug, Eq, Serialize, Deserialize, PartialEq, schemars::JsonSchema)] /// The `groupBy` parameter groups multiple data requests into a single call. For more details please refer to our [n + 1 guide](https://tailcall.run/docs/guides/n+1#solving-using-batching). pub struct GroupBy { - #[serde(default, skip_serializing_if = "is_default")] - path: Vec<String>, + #[serde(default, skip_serializing_if = "is_default")] + path: Vec<String>, } impl GroupBy { - pub fn new(path: Vec<String>) -> Self { - Self { path } - } + pub fn new(path: Vec<String>) -> Self { + Self { path } + } - pub fn path(&self) -> Vec<String> { - if self.path.is_empty() { - return vec![String::from(ID)]; + pub fn path(&self) -> Vec<String> { + if self.path.is_empty() { + return vec![String::from(ID)]; + } + self.path.clone() } - self.path.clone() - } - pub fn key(&self) -> &str { - self.path.last().map(|a| a.as_str()).unwrap_or(ID) - } + pub fn key(&self) -> &str { + self.path.last().map(|a| a.as_str()).unwrap_or(ID) + } } const ID: &str = "id"; impl Default for GroupBy { - fn default() -> Self { - Self { path: vec![ID.to_string()] } - } + fn default() -> Self { + Self { path: vec![ID.to_string()] } + } } diff --git a/src/config/into_document.rs b/src/config/into_document.rs index 3545d43e828..4cd4509c1a0 100644 --- a/src/config/into_document.rs +++ b/src/config/into_document.rs @@ -7,194 +7,225 @@ use crate::blueprint::TypeLike; use crate::directive::DirectiveCodec; fn pos<A>(a: A) -> Positioned<A> { - Positioned::new(a, Pos::default()) + Positioned::new(a, Pos::default()) } fn config_document(config: &Config) -> ServiceDocument { - let mut definitions = Vec::new(); - let schema_definition = SchemaDefinition { - extend: false, - directives: vec![pos(config.server.to_directive()), pos(config.upstream.to_directive())], - query: config.schema.query.clone().map(|name| pos(Name::new(name))), - mutation: config.schema.mutation.clone().map(|name| pos(Name::new(name))), - subscription: config.schema.subscription.clone().map(|name| pos(Name::new(name))), - }; - definitions.push(TypeSystemDefinition::Schema(pos(schema_definition))); - for (type_name, type_def) in config.types.iter() { - let kind = if type_def.interface { - TypeKind::Interface(InterfaceType { - implements: type_def - .implements - .iter() - .map(|name| pos(Name::new(name.clone()))) - .collect(), - fields: type_def - .fields - .clone() - .iter() - .map(|(name, field)| { - let directives = get_directives(field); - let base_type = if field.list { - BaseType::List(Box::new(Type { - nullable: !field.list_type_required, - base: BaseType::Named(Name::new(field.type_of.clone())), - })) - } else { - BaseType::Named(Name::new(field.type_of.clone())) - }; - pos(FieldDefinition { - description: field.doc.clone().map(pos), - name: pos(Name::new(name.clone())), - arguments: vec![], - ty: pos(Type { nullable: !field.required, base: base_type }), + let mut definitions = Vec::new(); + let schema_definition = SchemaDefinition { + extend: false, + directives: vec![ + pos(config.server.to_directive()), + pos(config.upstream.to_directive()), + ], + query: config.schema.query.clone().map(|name| pos(Name::new(name))), + mutation: config + .schema + .mutation + .clone() + .map(|name| pos(Name::new(name))), + subscription: config + .schema + .subscription + .clone() + .map(|name| pos(Name::new(name))), + }; + definitions.push(TypeSystemDefinition::Schema(pos(schema_definition))); + for (type_name, type_def) in config.types.iter() { + let kind = if type_def.interface { + TypeKind::Interface(InterfaceType { + implements: type_def + .implements + .iter() + .map(|name| pos(Name::new(name.clone()))) + .collect(), + fields: type_def + .fields + .clone() + .iter() + .map(|(name, field)| { + let directives = get_directives(field); + let base_type = if field.list { + BaseType::List(Box::new(Type { + nullable: 
!field.list_type_required, + base: BaseType::Named(Name::new(field.type_of.clone())), + })) + } else { + BaseType::Named(Name::new(field.type_of.clone())) + }; + pos(FieldDefinition { + description: field.doc.clone().map(pos), + name: pos(Name::new(name.clone())), + arguments: vec![], + ty: pos(Type { nullable: !field.required, base: base_type }), - directives, + directives, + }) + }) + .collect::>>(), }) - }) - .collect::>>(), - }) - } else if let Some(variants) = &type_def.variants { - TypeKind::Enum(EnumType { - values: variants - .iter() - .map(|value| { - pos(EnumValueDefinition { description: None, value: pos(Name::new(value.clone())), directives: Vec::new() }) - }) - .collect(), - }) - } else if config.input_types().contains(type_name) { - TypeKind::InputObject(InputObjectType { - fields: type_def - .fields - .clone() - .iter() - .map(|(name, field)| { - let directives = get_directives(field); - let base_type = if field.list { - async_graphql::parser::types::BaseType::List(Box::new(Type { - nullable: !field.list_type_required, - base: async_graphql::parser::types::BaseType::Named(Name::new(field.type_of.clone())), - })) - } else { - async_graphql::parser::types::BaseType::Named(Name::new(field.type_of.clone())) - }; + } else if let Some(variants) = &type_def.variants { + TypeKind::Enum(EnumType { + values: variants + .iter() + .map(|value| { + pos(EnumValueDefinition { + description: None, + value: pos(Name::new(value.clone())), + directives: Vec::new(), + }) + }) + .collect(), + }) + } else if config.input_types().contains(type_name) { + TypeKind::InputObject(InputObjectType { + fields: type_def + .fields + .clone() + .iter() + .map(|(name, field)| { + let directives = get_directives(field); + let base_type = if field.list { + async_graphql::parser::types::BaseType::List(Box::new(Type { + nullable: !field.list_type_required, + base: async_graphql::parser::types::BaseType::Named(Name::new( + field.type_of.clone(), + )), + })) + } else { + async_graphql::parser::types::BaseType::Named(Name::new( + field.type_of.clone(), + )) + }; - pos(async_graphql::parser::types::InputValueDefinition { - description: field.doc.clone().map(pos), - name: pos(Name::new(name.clone())), - ty: pos(Type { nullable: !field.required, base: base_type }), + pos(async_graphql::parser::types::InputValueDefinition { + description: field.doc.clone().map(pos), + name: pos(Name::new(name.clone())), + ty: pos(Type { nullable: !field.required, base: base_type }), - default_value: None, - directives, + default_value: None, + directives, + }) + }) + .collect::>>(), }) - }) - .collect::>>(), - }) - } else if type_def.fields.is_empty() { - TypeKind::Scalar - } else { - TypeKind::Object(ObjectType { - implements: type_def - .implements - .iter() - .map(|name| pos(Name::new(name.clone()))) - .collect(), - fields: type_def - .fields - .clone() - .iter() - .map(|(name, field)| { - let directives = get_directives(field); - let base_type = if field.list { - async_graphql::parser::types::BaseType::List(Box::new(Type { - nullable: !field.list_type_required, - base: async_graphql::parser::types::BaseType::Named(Name::new(field.type_of.clone())), - })) - } else { - async_graphql::parser::types::BaseType::Named(Name::new(field.type_of.clone())) - }; + } else if type_def.fields.is_empty() { + TypeKind::Scalar + } else { + TypeKind::Object(ObjectType { + implements: type_def + .implements + .iter() + .map(|name| pos(Name::new(name.clone()))) + .collect(), + fields: type_def + .fields + .clone() + .iter() + .map(|(name, field)| { + 
let directives = get_directives(field); + let base_type = if field.list { + async_graphql::parser::types::BaseType::List(Box::new(Type { + nullable: !field.list_type_required, + base: async_graphql::parser::types::BaseType::Named(Name::new( + field.type_of.clone(), + )), + })) + } else { + async_graphql::parser::types::BaseType::Named(Name::new( + field.type_of.clone(), + )) + }; - let args_map = field.args.clone(); - let args = args_map - .iter() - .map(|(name, arg)| { - let base_type = if arg.list { - async_graphql::parser::types::BaseType::List(Box::new(Type { - nullable: !arg.list_type_required(), - base: async_graphql::parser::types::BaseType::Named(Name::new(arg.type_of.clone())), - })) - } else { - async_graphql::parser::types::BaseType::Named(Name::new(arg.type_of.clone())) - }; - pos(async_graphql::parser::types::InputValueDefinition { - description: arg.doc.clone().map(pos), - name: pos(Name::new(name.clone())), - ty: pos(Type { nullable: !arg.required, base: base_type }), + let args_map = field.args.clone(); + let args = args_map + .iter() + .map(|(name, arg)| { + let base_type = if arg.list { + async_graphql::parser::types::BaseType::List(Box::new(Type { + nullable: !arg.list_type_required(), + base: async_graphql::parser::types::BaseType::Named( + Name::new(arg.type_of.clone()), + ), + })) + } else { + async_graphql::parser::types::BaseType::Named(Name::new( + arg.type_of.clone(), + )) + }; + pos(async_graphql::parser::types::InputValueDefinition { + description: arg.doc.clone().map(pos), + name: pos(Name::new(name.clone())), + ty: pos(Type { nullable: !arg.required, base: base_type }), - default_value: arg - .default_value - .clone() - .map(|v| pos(ConstValue::String(v.to_string()))), - directives: Vec::new(), - }) - }) - .collect::>>(); + default_value: arg + .default_value + .clone() + .map(|v| pos(ConstValue::String(v.to_string()))), + directives: Vec::new(), + }) + }) + .collect::>>(); - pos(async_graphql::parser::types::FieldDefinition { - description: field.doc.clone().map(pos), - name: pos(Name::new(name.clone())), - arguments: args, - ty: pos(Type { nullable: !field.required, base: base_type }), + pos(async_graphql::parser::types::FieldDefinition { + description: field.doc.clone().map(pos), + name: pos(Name::new(name.clone())), + arguments: args, + ty: pos(Type { nullable: !field.required, base: base_type }), - directives, + directives, + }) + }) + .collect::>>(), }) - }) - .collect::>>(), - }) - }; - definitions.push(TypeSystemDefinition::Type(pos(TypeDefinition { - extend: false, - description: None, - name: pos(Name::new(type_name.clone())), - directives: type_def - .added_fields - .iter() - .map(|added_field| pos(added_field.to_directive())) - .collect::>(), - kind, - }))); - } - for (name, union) in config.unions.iter() { - definitions.push(TypeSystemDefinition::Type(pos(TypeDefinition { - extend: false, - description: None, - name: pos(Name::new(name)), - directives: Vec::new(), - kind: TypeKind::Union(UnionType { - members: union.types.iter().map(|name| pos(Name::new(name.clone()))).collect(), - }), - }))); - } + }; + definitions.push(TypeSystemDefinition::Type(pos(TypeDefinition { + extend: false, + description: None, + name: pos(Name::new(type_name.clone())), + directives: type_def + .added_fields + .iter() + .map(|added_field| pos(added_field.to_directive())) + .collect::>(), + kind, + }))); + } + for (name, union) in config.unions.iter() { + definitions.push(TypeSystemDefinition::Type(pos(TypeDefinition { + extend: false, + description: None, + name: 
pos(Name::new(name)), + directives: Vec::new(), + kind: TypeKind::Union(UnionType { + members: union + .types + .iter() + .map(|name| pos(Name::new(name.clone()))) + .collect(), + }), + }))); + } - ServiceDocument { definitions } + ServiceDocument { definitions } } fn get_directives(field: &crate::config::Field) -> Vec> { - let directives = vec![ - field.http.as_ref().map(|d| pos(d.to_directive())), - field.script.as_ref().map(|d| pos(d.to_directive())), - field.const_field.as_ref().map(|d| pos(d.to_directive())), - field.modify.as_ref().map(|d| pos(d.to_directive())), - field.omit.as_ref().map(|d| pos(d.to_directive())), - field.graphql.as_ref().map(|d| pos(d.to_directive())), - field.grpc.as_ref().map(|d| pos(d.to_directive())), - field.expr.as_ref().map(|d| pos(d.to_directive())), - ]; + let directives = vec![ + field.http.as_ref().map(|d| pos(d.to_directive())), + field.script.as_ref().map(|d| pos(d.to_directive())), + field.const_field.as_ref().map(|d| pos(d.to_directive())), + field.modify.as_ref().map(|d| pos(d.to_directive())), + field.omit.as_ref().map(|d| pos(d.to_directive())), + field.graphql.as_ref().map(|d| pos(d.to_directive())), + field.grpc.as_ref().map(|d| pos(d.to_directive())), + field.expr.as_ref().map(|d| pos(d.to_directive())), + ]; - directives.into_iter().flatten().collect() + directives.into_iter().flatten().collect() } impl From for ServiceDocument { - fn from(value: Config) -> Self { - config_document(&value) - } + fn from(value: Config) -> Self { + config_document(&value) + } } diff --git a/src/config/key_values.rs b/src/config/key_values.rs index 4a42f5ab743..d2c31ae475a 100644 --- a/src/config/key_values.rs +++ b/src/config/key_values.rs @@ -7,87 +7,87 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; pub struct KeyValues(pub BTreeMap); impl Deref for KeyValues { - type Target = BTreeMap; + type Target = BTreeMap; - fn deref(&self) -> &Self::Target { - &self.0 - } + fn deref(&self) -> &Self::Target { + &self.0 + } } #[derive(Serialize, Deserialize, Clone, Debug, Default, Eq, PartialEq)] pub struct KeyValue { - pub key: String, - pub value: String, + pub key: String, + pub value: String, } impl Serialize for KeyValues { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let vec: Vec = self - .0 - .iter() - .map(|(k, v)| KeyValue { key: k.clone(), value: v.clone() }) - .collect(); - vec.serialize(serializer) - } + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let vec: Vec = self + .0 + .iter() + .map(|(k, v)| KeyValue { key: k.clone(), value: v.clone() }) + .collect(); + vec.serialize(serializer) + } } impl<'de> Deserialize<'de> for KeyValues { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let vec: Vec = Vec::deserialize(deserializer)?; - let btree_map = vec.into_iter().map(|kv| (kv.key, kv.value)).collect(); - Ok(KeyValues(btree_map)) - } + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let vec: Vec = Vec::deserialize(deserializer)?; + let btree_map = vec.into_iter().map(|kv| (kv.key, kv.value)).collect(); + Ok(KeyValues(btree_map)) + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn test_serialize_empty_keyvalues() { - let kv = KeyValues::default(); - let serialized = serde_json::to_string(&kv).unwrap(); - assert_eq!(serialized, "[]"); - } + #[test] + fn test_serialize_empty_keyvalues() { + let kv = KeyValues::default(); + let serialized = serde_json::to_string(&kv).unwrap(); + 
assert_eq!(serialized, "[]"); + } - #[test] - fn test_serialize_non_empty_keyvalues() { - let mut kv = KeyValues::default(); - kv.0.insert("a".to_string(), "b".to_string()); - let serialized = serde_json::to_string(&kv).unwrap(); - assert_eq!(serialized, r#"[{"key":"a","value":"b"}]"#); - } + #[test] + fn test_serialize_non_empty_keyvalues() { + let mut kv = KeyValues::default(); + kv.0.insert("a".to_string(), "b".to_string()); + let serialized = serde_json::to_string(&kv).unwrap(); + assert_eq!(serialized, r#"[{"key":"a","value":"b"}]"#); + } - #[test] - fn test_deserialize_empty_keyvalues() { - let data = "[]"; - let kv: KeyValues = serde_json::from_str(data).unwrap(); - assert_eq!(kv, KeyValues::default()); - } + #[test] + fn test_deserialize_empty_keyvalues() { + let data = "[]"; + let kv: KeyValues = serde_json::from_str(data).unwrap(); + assert_eq!(kv, KeyValues::default()); + } - #[test] - fn test_deserialize_non_empty_keyvalues() { - let data = r#"[{"key":"a","value":"b"}]"#; - let kv: KeyValues = serde_json::from_str(data).unwrap(); - assert_eq!(kv.0["a"], "b"); - } + #[test] + fn test_deserialize_non_empty_keyvalues() { + let data = r#"[{"key":"a","value":"b"}]"#; + let kv: KeyValues = serde_json::from_str(data).unwrap(); + assert_eq!(kv.0["a"], "b"); + } - #[test] - fn test_default_keyvalues() { - let kv = KeyValues::default(); - assert_eq!(kv.0.len(), 0); - } + #[test] + fn test_default_keyvalues() { + let kv = KeyValues::default(); + assert_eq!(kv.0.len(), 0); + } - #[test] - fn test_deref() { - let mut kv = KeyValues::default(); - kv.0.insert("a".to_string(), "b".to_string()); - // Using the deref trait - assert_eq!(kv["a"], "b"); - } + #[test] + fn test_deref() { + let mut kv = KeyValues::default(); + kv.0.insert("a".to_string(), "b".to_string()); + // Using the deref trait + assert_eq!(kv["a"], "b"); + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 69efb3dec5f..610821b75e2 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,5 +17,5 @@ mod source; mod upstream; fn is_default(val: &T) -> bool { - *val == T::default() + *val == T::default() } diff --git a/src/config/n_plus_one.rs b/src/config/n_plus_one.rs index 856e17f089a..68b84aa2eaa 100644 --- a/src/config/n_plus_one.rs +++ b/src/config/n_plus_one.rs @@ -1,332 +1,370 @@ use crate::config::Config; struct FindFanOutContext<'a> { - config: &'a Config, - type_name: &'a String, - path: Vec<(String, String)>, - is_list: bool, + config: &'a Config, + type_name: &'a String, + path: Vec<(String, String)>, + is_list: bool, } fn find_fan_out(context: FindFanOutContext) -> Vec> { - let config = context.config; - let type_name = context.type_name; - let path = context.path; - let is_list = context.is_list; - match config.find_type(type_name) { - Some(type_) => type_ - .fields - .iter() - .flat_map(|(field_name, field)| { - let mut new_path = path.clone(); - new_path.push((type_name.clone(), field_name.clone())); - if path.iter().any(|item| &item.0 == type_name && &item.1 == field_name) { - Vec::new() - } else if field.has_resolver() && !field.has_batched_resolver() && is_list { - vec![new_path] - } else { - find_fan_out(FindFanOutContext { - config, - type_name: &field.type_of, - path: new_path, - is_list: field.list || is_list, - }) - } - }) - .collect(), - None => Vec::new(), - } + let config = context.config; + let type_name = context.type_name; + let path = context.path; + let is_list = context.is_list; + match config.find_type(type_name) { + Some(type_) => type_ + .fields + .iter() + .flat_map(|(field_name, 
field)| { + let mut new_path = path.clone(); + new_path.push((type_name.clone(), field_name.clone())); + if path + .iter() + .any(|item| &item.0 == type_name && &item.1 == field_name) + { + Vec::new() + } else if field.has_resolver() && !field.has_batched_resolver() && is_list { + vec![new_path] + } else { + find_fan_out(FindFanOutContext { + config, + type_name: &field.type_of, + path: new_path, + is_list: field.list || is_list, + }) + } + }) + .collect(), + None => Vec::new(), + } } pub fn n_plus_one(config: &Config) -> Vec> { - if let Some(query) = &config.schema.query { - find_fan_out(FindFanOutContext { config, type_name: query, path: Vec::new(), is_list: false }) - } else { - Vec::new() - } + if let Some(query) = &config.schema.query { + find_fan_out(FindFanOutContext { + config, + type_name: query, + path: Vec::new(), + is_list: false, + }) + } else { + Vec::new() + } } #[cfg(test)] mod tests { - use crate::config::{Config, Field, Http, Type}; + use crate::config::{Config, Field, Http, Type}; - #[test] - fn test_nplusone_resolvers() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default() - .type_of("F1".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![( - "f2", - Field::default() - .type_of("F2".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("String".to_string()))]), - ), - ]); + #[test] + fn test_nplusone_resolvers() { + let config = Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![( + "f2", + Field::default() + .type_of("F2".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F2", + Type::default() + .fields(vec![("f3", Field::default().type_of("String".to_string()))]), + ), + ]); - let actual = config.n_plus_one(); - let expected = vec![vec![ - ("Query".to_string(), "f1".to_string()), - ("F1".to_string(), "f2".to_string()), - ]]; - assert_eq!(actual, expected) - } + let actual = config.n_plus_one(); + let expected = vec![vec![ + ("Query".to_string(), "f1".to_string()), + ("F1".to_string(), "f2".to_string()), + ]]; + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_batched_resolvers() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default() - .type_of("F1".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![( - "f2", - Field::default() - .type_of("F2".to_string()) - .to_list() - .http(Http { group_by: vec!["id".into()], ..Default::default() }), - )]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("String".to_string()))]), - ), - ]); + #[test] + fn test_nplusone_batched_resolvers() { + let config = Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![( + "f2", + Field::default() + .type_of("F2".to_string()) + .to_list() + .http(Http { group_by: vec!["id".into()], ..Default::default() }), + )]), + ), + ( + "F2", + Type::default() + .fields(vec![("f3", 
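The `path.iter().any(...)` guard in `find_fan_out` above is what keeps the traversal finite on recursive schemas: an edge identified by its `(type, field)` pair is never followed twice on the same path. A hedged sketch of just that idea, with a `BTreeMap` standing in for `Config`:

use std::collections::BTreeMap;

// type name -> [(field name, target type name)]
type Edges = BTreeMap<&'static str, Vec<(&'static str, &'static str)>>;

fn walk(
    edges: &Edges,
    ty: &'static str,
    path: &mut Vec<(&'static str, &'static str)>,
    out: &mut Vec<Vec<(&'static str, &'static str)>>,
) {
    if let Some(fields) = edges.get(ty) {
        for &(field, target) in fields {
            if path.iter().any(|&(t, f)| t == ty && f == field) {
                continue; // this (type, field) edge is already on the path: a cycle
            }
            path.push((ty, field));
            out.push(path.clone()); // stand-in for the real resolver checks
            walk(edges, target, path, out);
            path.pop();
        }
    }
}

fn main() {
    let mut edges = Edges::new();
    edges.insert("F1", vec![("f1", "F1")]); // a self-referential type
    let mut paths = Vec::new();
    walk(&edges, "F1", &mut Vec::new(), &mut paths);
    assert_eq!(paths.len(), 1); // the cycle is entered once, then cut off
}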
Field::default().type_of("String".to_string()))]), + ), + ]); - let actual = config.n_plus_one(); - let expected: Vec> = vec![]; - assert_eq!(actual, expected) - } + let actual = config.n_plus_one(); + let expected: Vec> = vec![]; + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_nested_resolvers() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default() - .type_of("F1".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![("f2", Field::default().type_of("F2".to_string()).to_list())]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("F3".to_string()).to_list())]), - ), - ( - "F3", - Type::default().fields(vec![( - "f4", - Field::default().type_of("String".to_string()).http(Http::default()), - )]), - ), - ]); + #[test] + fn test_nplusone_nested_resolvers() { + let config = Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![( + "f2", + Field::default().type_of("F2".to_string()).to_list(), + )]), + ), + ( + "F2", + Type::default().fields(vec![( + "f3", + Field::default().type_of("F3".to_string()).to_list(), + )]), + ), + ( + "F3", + Type::default().fields(vec![( + "f4", + Field::default() + .type_of("String".to_string()) + .http(Http::default()), + )]), + ), + ]); - let actual = config.n_plus_one(); - let expected = vec![vec![ - ("Query".to_string(), "f1".to_string()), - ("F1".to_string(), "f2".to_string()), - ("F2".to_string(), "f3".to_string()), - ("F3".to_string(), "f4".to_string()), - ]]; - assert_eq!(actual, expected) - } + let actual = config.n_plus_one(); + let expected = vec![vec![ + ("Query".to_string(), "f1".to_string()), + ("F1".to_string(), "f2".to_string()), + ("F2".to_string(), "f3".to_string()), + ("F3".to_string(), "f4".to_string()), + ]]; + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_nested_resolvers_non_list_resolvers() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default().type_of("F1".to_string()).http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![("f2", Field::default().type_of("F2".to_string()).to_list())]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("F3".to_string()).to_list())]), - ), - ( - "F3", - Type::default().fields(vec![( - "f4", - Field::default().type_of("String".to_string()).http(Http::default()), - )]), - ), - ]); + #[test] + fn test_nplusone_nested_resolvers_non_list_resolvers() { + let config = Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![( + "f2", + Field::default().type_of("F2".to_string()).to_list(), + )]), + ), + ( + "F2", + Type::default().fields(vec![( + "f3", + Field::default().type_of("F3".to_string()).to_list(), + )]), + ), + ( + "F3", + Type::default().fields(vec![( + "f4", + Field::default() + .type_of("String".to_string()) + .http(Http::default()), + )]), + ), + ]); - let actual = config.n_plus_one(); - let expected = vec![vec![ - ("Query".to_string(), "f1".to_string()), - ("F1".to_string(), "f2".to_string()), - ("F2".to_string(), "f3".to_string()), 
- ("F3".to_string(), "f4".to_string()), - ]]; - assert_eq!(actual, expected) - } + let actual = config.n_plus_one(); + let expected = vec![vec![ + ("Query".to_string(), "f1".to_string()), + ("F1".to_string(), "f2".to_string()), + ("F2".to_string(), "f3".to_string()), + ("F3".to_string(), "f4".to_string()), + ]]; + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_nested_resolvers_without_resolvers() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default() - .type_of("F1".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![("f2", Field::default().type_of("F2".to_string()).to_list())]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("String".to_string()))]), - ), - ]); + #[test] + fn test_nplusone_nested_resolvers_without_resolvers() { + let config = Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![( + "f2", + Field::default().type_of("F2".to_string()).to_list(), + )]), + ), + ( + "F2", + Type::default() + .fields(vec![("f3", Field::default().type_of("String".to_string()))]), + ), + ]); - let actual = config.n_plus_one(); - let expected: Vec> = vec![]; - assert_eq!(actual, expected) - } + let actual = config.n_plus_one(); + let expected: Vec> = vec![]; + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_cycles() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default() - .type_of("F1".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![ - ("f1", Field::default().type_of("F1".to_string())), - ("f2", Field::default().type_of("F2".to_string()).to_list()), - ]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("String".to_string()))]), - ), - ]); + #[test] + fn test_nplusone_cycles() { + let config = Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![ + ("f1", Field::default().type_of("F1".to_string())), + ("f2", Field::default().type_of("F2".to_string()).to_list()), + ]), + ), + ( + "F2", + Type::default() + .fields(vec![("f3", Field::default().type_of("String".to_string()))]), + ), + ]); - let actual = config.n_plus_one(); - let expected: Vec> = vec![]; - assert_eq!(actual, expected) - } + let actual = config.n_plus_one(); + let expected: Vec> = vec![]; + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_cycles_with_resolvers() { - let config = Config::default().query("Query").types(vec![ - ( - "Query", - Type::default().fields(vec![( - "f1", - Field::default() - .type_of("F1".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "F1", - Type::default().fields(vec![ - ("f1", Field::default().type_of("F1".to_string()).to_list()), - ( - "f2", - Field::default().type_of("String".to_string()).http(Http::default()), - ), - ]), - ), - ( - "F2", - Type::default().fields(vec![("f3", Field::default().type_of("String".to_string()))]), - ), - ]); + #[test] + fn test_nplusone_cycles_with_resolvers() { + let config = 
Config::default().query("Query").types(vec![ + ( + "Query", + Type::default().fields(vec![( + "f1", + Field::default() + .type_of("F1".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "F1", + Type::default().fields(vec![ + ("f1", Field::default().type_of("F1".to_string()).to_list()), + ( + "f2", + Field::default() + .type_of("String".to_string()) + .http(Http::default()), + ), + ]), + ), + ( + "F2", + Type::default() + .fields(vec![("f3", Field::default().type_of("String".to_string()))]), + ), + ]); - let actual = config.n_plus_one(); - let expected = vec![ - vec![ - ("Query".to_string(), "f1".to_string()), - ("F1".to_string(), "f1".to_string()), - ("F1".to_string(), "f2".to_string()), - ], - vec![ - ("Query".to_string(), "f1".to_string()), - ("F1".to_string(), "f2".to_string()), - ], - ]; + let actual = config.n_plus_one(); + let expected = vec![ + vec![ + ("Query".to_string(), "f1".to_string()), + ("F1".to_string(), "f1".to_string()), + ("F1".to_string(), "f2".to_string()), + ], + vec![ + ("Query".to_string(), "f1".to_string()), + ("F1".to_string(), "f2".to_string()), + ], + ]; - assert_eq!(actual, expected) - } + assert_eq!(actual, expected) + } - #[test] - fn test_nplusone_nested_non_list() { - let f_field = Field::default().type_of("F".to_string()).http(Http::default()); + #[test] + fn test_nplusone_nested_non_list() { + let f_field = Field::default() + .type_of("F".to_string()) + .http(Http::default()); - let config = Config::default().query("Query").types(vec![ - ("Query", Type::default().fields(vec![("f", f_field)])), - ( - "F", - Type::default().fields(vec![( - "g", - Field::default() - .type_of("G".to_string()) - .to_list() - .http(Http::default()), - )]), - ), - ( - "G", - Type::default().fields(vec![("e", Field::default().type_of("String".to_string()))]), - ), - ]); + let config = Config::default().query("Query").types(vec![ + ("Query", Type::default().fields(vec![("f", f_field)])), + ( + "F", + Type::default().fields(vec![( + "g", + Field::default() + .type_of("G".to_string()) + .to_list() + .http(Http::default()), + )]), + ), + ( + "G", + Type::default().fields(vec![("e", Field::default().type_of("String".to_string()))]), + ), + ]); - let actual = config.n_plus_one(); - let expected = Vec::>::new(); + let actual = config.n_plus_one(); + let expected = Vec::>::new(); - assert_eq!(actual, expected) - } + assert_eq!(actual, expected) + } } diff --git a/src/config/reader.rs b/src/config/reader.rs index 33285e8022b..ecd95878783 100644 --- a/src/config/reader.rs +++ b/src/config/reader.rs @@ -6,136 +6,145 @@ use crate::{FileIO, HttpIO}; /// Reads the configuration from a file or from an HTTP URL and resolves all linked assets. pub struct ConfigReader { - file: File, - http: Http, + file: File, + http: Http, } struct FileRead { - content: String, - path: String, + content: String, + path: String, } impl ConfigReader { - pub fn init(file: File, http: Http) -> Self { - Self { file, http } - } - - /// Reads a file from the filesystem or from an HTTP URL - async fn read_file(&self, file: T) -> anyhow::Result { - // Is an HTTP URL - let content = if let Ok(url) = Url::parse(&file.to_string()) { - let response = self - .http - .execute(reqwest::Request::new(reqwest::Method::GET, url)) - .await?; - - String::from_utf8(response.body.to_vec())? - } else { - // Is a file path - self.file.read(&file.to_string()).await? 
- }; - - Ok(FileRead { content, path: file.to_string() }) - } - - /// Reads all the files in parallel - async fn read_files(&self, files: &[T]) -> anyhow::Result> { - let files = files.iter().map(|x| self.read_file(x.to_string())); - let content = join_all(files).await.into_iter().collect::>>()?; - Ok(content) - } - - pub async fn read(&self, files: &[T]) -> anyhow::Result { - let files = self.read_files(files).await?; - let mut config = Config::default(); - for file in files.iter() { - let source = Source::detect(&file.path)?; - let schema = &file.content; - let new_config = Config::from_source(source, schema)?; - config = config.merge_right(&new_config); + pub fn init(file: File, http: Http) -> Self { + Self { file, http } } - Ok(config) - } + /// Reads a file from the filesystem or from an HTTP URL + async fn read_file(&self, file: T) -> anyhow::Result { + // Is an HTTP URL + let content = if let Ok(url) = Url::parse(&file.to_string()) { + let response = self + .http + .execute(reqwest::Request::new(reqwest::Method::GET, url)) + .await?; + + String::from_utf8(response.body.to_vec())? + } else { + // Is a file path + self.file.read(&file.to_string()).await? + }; + + Ok(FileRead { content, path: file.to_string() }) + } + + /// Reads all the files in parallel + async fn read_files(&self, files: &[T]) -> anyhow::Result> { + let files = files.iter().map(|x| self.read_file(x.to_string())); + let content = join_all(files) + .await + .into_iter() + .collect::>>()?; + Ok(content) + } + + pub async fn read(&self, files: &[T]) -> anyhow::Result { + let files = self.read_files(files).await?; + let mut config = Config::default(); + for file in files.iter() { + let source = Source::detect(&file.path)?; + let schema = &file.content; + let new_config = Config::from_source(source, schema)?; + config = config.merge_right(&new_config); + } + + Ok(config) + } } #[cfg(test)] mod reader_tests { - use tokio::io::AsyncReadExt; - - use crate::cli::{init_file, init_http}; - use crate::config::reader::ConfigReader; - use crate::config::{Config, Type, Upstream}; - - fn start_mock_server() -> httpmock::MockServer { - httpmock::MockServer::start() - } - - #[tokio::test] - async fn test_all() { - let mut cfg = Config::default(); - cfg.schema.query = Some("Test".to_string()); - cfg = cfg.types([("Test", Type::default())].to_vec()); - - let server = start_mock_server(); - let header_serv = server.mock(|when, then| { - when.method(httpmock::Method::GET).path("/bar.graphql"); - then.status(200).body(cfg.to_sdl()); - }); - - let mut json = String::new(); - tokio::fs::File::open("examples/jsonplaceholder.json") - .await - .unwrap() - .read_to_string(&mut json) - .await - .unwrap(); - - let foo_json_server = server.mock(|when, then| { - when.method(httpmock::Method::GET).path("/foo.json"); - then.status(200).body(json); - }); - - let port = server.port(); - let files: Vec = [ - "examples/jsonplaceholder.yml", // config from local file - format!("http://localhost:{port}/bar.graphql").as_str(), // with content-type header - format!("http://localhost:{port}/foo.json").as_str(), // with url extension - ] - .iter() - .map(|x| x.to_string()) - .collect(); - let cr = ConfigReader::init(init_file(), init_http(&Upstream::default())); - let c = cr.read(&files).await.unwrap(); - assert_eq!( - ["Post", "Query", "Test", "User"] + use tokio::io::AsyncReadExt; + + use crate::cli::{init_file, init_http}; + use crate::config::reader::ConfigReader; + use crate::config::{Config, Type, Upstream}; + + fn start_mock_server() -> 
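Because `read_file` above dispatches on whether the argument parses as a URL, a single `read` call can mix local paths and HTTP endpoints; results are folded left-to-right through `merge_right`, so later files win where options overlap. A hedged usage sketch (the override URL is hypothetical; the constructors are the same ones the tests below use):

use crate::cli::{init_file, init_http};
use crate::config::reader::ConfigReader;
use crate::config::Upstream;

async fn load_merged() -> anyhow::Result<()> {
    let reader = ConfigReader::init(init_file(), init_http(&Upstream::default()));
    // Later entries win wherever options overlap, via `merge_right`.
    let files = vec![
        "examples/jsonplaceholder.yml".to_string(), // local file
        "http://localhost:8000/override.json".to_string(), // hypothetical URL
    ];
    let config = reader.read(&files).await?;
    println!("{} types merged", config.types.len());
    Ok(())
}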
httpmock::MockServer { + httpmock::MockServer::start() + } + + #[tokio::test] + async fn test_all() { + let mut cfg = Config::default(); + cfg.schema.query = Some("Test".to_string()); + cfg = cfg.types([("Test", Type::default())].to_vec()); + + let server = start_mock_server(); + let header_serv = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/bar.graphql"); + then.status(200).body(cfg.to_sdl()); + }); + + let mut json = String::new(); + tokio::fs::File::open("examples/jsonplaceholder.json") + .await + .unwrap() + .read_to_string(&mut json) + .await + .unwrap(); + + let foo_json_server = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/foo.json"); + then.status(200).body(json); + }); + + let port = server.port(); + let files: Vec = [ + "examples/jsonplaceholder.yml", // config from local file + format!("http://localhost:{port}/bar.graphql").as_str(), // with content-type header + format!("http://localhost:{port}/foo.json").as_str(), // with url extension + ] .iter() - .map(|i| i.to_string()) - .collect::>(), - c.types.keys().map(|i| i.to_string()).collect::>() - ); - foo_json_server.assert(); // checks if the request was actually made - header_serv.assert(); - } - - #[tokio::test] - async fn test_local_files() { - let files: Vec = [ - "examples/jsonplaceholder.yml", - "examples/jsonplaceholder.graphql", - "examples/jsonplaceholder.json", - ] - .iter() - .map(|x| x.to_string()) - .collect(); - let cr = ConfigReader::init(init_file(), init_http(&Upstream::default())); - let c = cr.read(&files).await.unwrap(); - assert_eq!( - ["Post", "Query", "User"] + .map(|x| x.to_string()) + .collect(); + let cr = ConfigReader::init(init_file(), init_http(&Upstream::default())); + let c = cr.read(&files).await.unwrap(); + assert_eq!( + ["Post", "Query", "Test", "User"] + .iter() + .map(|i| i.to_string()) + .collect::>(), + c.types + .keys() + .map(|i| i.to_string()) + .collect::>() + ); + foo_json_server.assert(); // checks if the request was actually made + header_serv.assert(); + } + + #[tokio::test] + async fn test_local_files() { + let files: Vec = [ + "examples/jsonplaceholder.yml", + "examples/jsonplaceholder.graphql", + "examples/jsonplaceholder.json", + ] .iter() - .map(|i| i.to_string()) - .collect::>(), - c.types.keys().map(|i| i.to_string()).collect::>() - ); - } + .map(|x| x.to_string()) + .collect(); + let cr = ConfigReader::init(init_file(), init_http(&Upstream::default())); + let c = cr.read(&files).await.unwrap(); + assert_eq!( + ["Post", "Query", "User"] + .iter() + .map(|i| i.to_string()) + .collect::>(), + c.types + .keys() + .map(|i| i.to_string()) + .collect::>() + ); + } } diff --git a/src/config/server.rs b/src/config/server.rs index 095cf7387d3..145519829c7 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -8,148 +8,150 @@ use crate::config::{is_default, KeyValues}; #[serde(rename_all = "camelCase")] /// The `@server` directive, when applied at the schema level, offers a comprehensive set of server configurations. It dictates how the server behaves and helps tune tailcall for various use-cases. pub struct Server { - #[serde(default, skip_serializing_if = "is_default")] - /// `apolloTracing` exposes GraphQL query performance data, including execution time of queries and individual resolvers. - pub apollo_tracing: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `cacheControlHeader` sends `Cache-Control` headers in responses when activated. 
The `max-age` value is the least of the values received from upstream services. @default `false`. - pub cache_control_header: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `graphiql` activates the GraphiQL IDE at the root path within Tailcall, a tool for query development and testing. @default `false`. - pub graphiql: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `introspection` allows clients to fetch schema information directly, aiding tools and applications in understanding available types, fields, and operations. @default `true`. - pub introspection: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `queryValidation` checks incoming GraphQL queries against the schema, preventing errors from invalid queries. Can be disabled for performance. @default `false`. - pub query_validation: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `responseValidation` Tailcall automatically validates responses from upstream services using inferred schema. @default `false`. - pub response_validation: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `batchRequests` combines multiple requests into one, improving performance but potentially introducing latency and complicating debugging. Use judiciously. @default `false` - pub batch_requests: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `globalResponseTimeout` sets the maximum query duration before termination, acting as a safeguard against long-running queries. - pub global_response_timeout: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `showcase` enables the /showcase/graphql endpoint. - pub showcase: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `workers` sets the number of worker threads. @default the number of system cores. - pub workers: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `hostname` sets the server hostname. - pub hostname: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `port` sets the Tailcall running port. @default `8000`. - pub port: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// This configuration defines local variables for server operations. Useful for storing constant configurations, secrets, or shared information. - pub vars: KeyValues, - #[serde(skip_serializing_if = "is_default", default)] - /// The responseHeader is a key-value pair array. These headers are included in every server response. Useful for headers like Access-Control-Allow-Origin for cross-origin requests, or additional headers like X-Allowed-Roles for downstream services. - pub response_headers: KeyValues, - #[serde(default, skip_serializing_if = "is_default")] - /// `version` sets the HTTP version for the server. Options are `HTTP1` and `HTTP2`. @default `HTTP1`. - pub version: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `cert` sets the path to certificate(s) for running the server over HTTP2 (HTTPS). @default `null`. - pub cert: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `key` sets the path to key for running the server over HTTP2 (HTTPS). @default `null`. - pub key: Option, - #[serde(default, skip_serializing_if = "is_default")] - pub pipeline_flush: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `apolloTracing` exposes GraphQL query performance data, including execution time of queries and individual resolvers. 
+ pub apollo_tracing: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `cacheControlHeader` sends `Cache-Control` headers in responses when activated. The `max-age` value is the least of the values received from upstream services. @default `false`. + pub cache_control_header: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `graphiql` activates the GraphiQL IDE at the root path within Tailcall, a tool for query development and testing. @default `false`. + pub graphiql: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `introspection` allows clients to fetch schema information directly, aiding tools and applications in understanding available types, fields, and operations. @default `true`. + pub introspection: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `queryValidation` checks incoming GraphQL queries against the schema, preventing errors from invalid queries. Can be disabled for performance. @default `false`. + pub query_validation: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `responseValidation` Tailcall automatically validates responses from upstream services using inferred schema. @default `false`. + pub response_validation: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `batchRequests` combines multiple requests into one, improving performance but potentially introducing latency and complicating debugging. Use judiciously. @default `false` + pub batch_requests: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `globalResponseTimeout` sets the maximum query duration before termination, acting as a safeguard against long-running queries. + pub global_response_timeout: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `showcase` enables the /showcase/graphql endpoint. + pub showcase: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `workers` sets the number of worker threads. @default the number of system cores. + pub workers: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `hostname` sets the server hostname. + pub hostname: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `port` sets the Tailcall running port. @default `8000`. + pub port: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// This configuration defines local variables for server operations. Useful for storing constant configurations, secrets, or shared information. + pub vars: KeyValues, + #[serde(skip_serializing_if = "is_default", default)] + /// The responseHeader is a key-value pair array. These headers are included in every server response. Useful for headers like Access-Control-Allow-Origin for cross-origin requests, or additional headers like X-Allowed-Roles for downstream services. + pub response_headers: KeyValues, + #[serde(default, skip_serializing_if = "is_default")] + /// `version` sets the HTTP version for the server. Options are `HTTP1` and `HTTP2`. @default `HTTP1`. + pub version: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `cert` sets the path to certificate(s) for running the server over HTTP2 (HTTPS). @default `null`. + pub cert: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `key` sets the path to key for running the server over HTTP2 (HTTPS). @default `null`. 
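Every `@server` option above is optional in the struct; the getters further down supply the documented defaults. A small sketch of how that resolves at runtime, assuming `Server` derives `Default` (its derive line sits outside this hunk):

use crate::config::Server;

fn server_defaults() {
    let server = Server::default();
    assert_eq!(server.get_port(), 8000); // port defaults to 8000
    assert_eq!(server.get_hostname(), "127.0.0.1");
    assert!(server.enable_introspection()); // introspection defaults to true
    assert!(!server.enable_graphiql()); // graphiql defaults to false
    assert!(!server.enable_batch_requests());
}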
+ pub key: Option, + #[serde(default, skip_serializing_if = "is_default")] + pub pipeline_flush: Option, } #[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Clone, Default, schemars::JsonSchema)] pub enum HttpVersion { - #[default] - HTTP1, - HTTP2, + #[default] + HTTP1, + HTTP2, } impl Server { - pub fn enable_apollo_tracing(&self) -> bool { - self.apollo_tracing.unwrap_or(false) - } - pub fn enable_graphiql(&self) -> bool { - self.graphiql.unwrap_or(false) - } - pub fn get_global_response_timeout(&self) -> i64 { - self.global_response_timeout.unwrap_or(0) - } + pub fn enable_apollo_tracing(&self) -> bool { + self.apollo_tracing.unwrap_or(false) + } + pub fn enable_graphiql(&self) -> bool { + self.graphiql.unwrap_or(false) + } + pub fn get_global_response_timeout(&self) -> i64 { + self.global_response_timeout.unwrap_or(0) + } - pub fn get_workers(&self) -> usize { - self.workers.unwrap_or(num_cpus::get()) - } + pub fn get_workers(&self) -> usize { + self.workers.unwrap_or(num_cpus::get()) + } - pub fn get_port(&self) -> u16 { - self.port.unwrap_or(8000) - } - pub fn enable_http_validation(&self) -> bool { - self.response_validation.unwrap_or(false) - } - pub fn enable_cache_control(&self) -> bool { - self.cache_control_header.unwrap_or(false) - } - pub fn enable_introspection(&self) -> bool { - self.introspection.unwrap_or(true) - } - pub fn enable_query_validation(&self) -> bool { - self.query_validation.unwrap_or(false) - } - pub fn enable_batch_requests(&self) -> bool { - self.batch_requests.unwrap_or(false) - } - pub fn enable_showcase(&self) -> bool { - self.showcase.unwrap_or(false) - } + pub fn get_port(&self) -> u16 { + self.port.unwrap_or(8000) + } + pub fn enable_http_validation(&self) -> bool { + self.response_validation.unwrap_or(false) + } + pub fn enable_cache_control(&self) -> bool { + self.cache_control_header.unwrap_or(false) + } + pub fn enable_introspection(&self) -> bool { + self.introspection.unwrap_or(true) + } + pub fn enable_query_validation(&self) -> bool { + self.query_validation.unwrap_or(false) + } + pub fn enable_batch_requests(&self) -> bool { + self.batch_requests.unwrap_or(false) + } + pub fn enable_showcase(&self) -> bool { + self.showcase.unwrap_or(false) + } - pub fn get_hostname(&self) -> String { - self.hostname.clone().unwrap_or("127.0.0.1".to_string()) - } + pub fn get_hostname(&self) -> String { + self.hostname.clone().unwrap_or("127.0.0.1".to_string()) + } - pub fn get_vars(&self) -> BTreeMap { - self.vars.clone().0 - } + pub fn get_vars(&self) -> BTreeMap { + self.vars.clone().0 + } - pub fn get_response_headers(&self) -> KeyValues { - self.response_headers.clone() - } + pub fn get_response_headers(&self) -> KeyValues { + self.response_headers.clone() + } - pub fn get_version(self) -> HttpVersion { - self.version.unwrap_or(HttpVersion::HTTP1) - } + pub fn get_version(self) -> HttpVersion { + self.version.unwrap_or(HttpVersion::HTTP1) + } - pub fn get_pipeline_flush(&self) -> bool { - self.pipeline_flush.unwrap_or(true) - } + pub fn get_pipeline_flush(&self) -> bool { + self.pipeline_flush.unwrap_or(true) + } - pub fn merge_right(mut self, other: Self) -> Self { - self.apollo_tracing = other.apollo_tracing.or(self.apollo_tracing); - self.cache_control_header = other.cache_control_header.or(self.cache_control_header); - self.graphiql = other.graphiql.or(self.graphiql); - self.introspection = other.introspection.or(self.introspection); - self.query_validation = other.query_validation.or(self.query_validation); - self.response_validation = 
other.response_validation.or(self.response_validation); - self.batch_requests = other.batch_requests.or(self.batch_requests); - self.global_response_timeout = other.global_response_timeout.or(self.global_response_timeout); - self.showcase = other.showcase.or(self.showcase); - self.workers = other.workers.or(self.workers); - self.port = other.port.or(self.port); - self.hostname = other.hostname.or(self.hostname); - let mut vars = self.vars.0.clone(); - vars.extend(other.vars.0); - self.vars = KeyValues(vars); - let mut response_headers = self.response_headers.0.clone(); - response_headers.extend(other.response_headers.0); - self.response_headers = KeyValues(response_headers); - self.version = other.version.or(self.version); - self.cert = other.cert.or(self.cert); - self.key = other.key.or(self.key); - self.pipeline_flush = other.pipeline_flush.or(self.pipeline_flush); - self - } + pub fn merge_right(mut self, other: Self) -> Self { + self.apollo_tracing = other.apollo_tracing.or(self.apollo_tracing); + self.cache_control_header = other.cache_control_header.or(self.cache_control_header); + self.graphiql = other.graphiql.or(self.graphiql); + self.introspection = other.introspection.or(self.introspection); + self.query_validation = other.query_validation.or(self.query_validation); + self.response_validation = other.response_validation.or(self.response_validation); + self.batch_requests = other.batch_requests.or(self.batch_requests); + self.global_response_timeout = other + .global_response_timeout + .or(self.global_response_timeout); + self.showcase = other.showcase.or(self.showcase); + self.workers = other.workers.or(self.workers); + self.port = other.port.or(self.port); + self.hostname = other.hostname.or(self.hostname); + let mut vars = self.vars.0.clone(); + vars.extend(other.vars.0); + self.vars = KeyValues(vars); + let mut response_headers = self.response_headers.0.clone(); + response_headers.extend(other.response_headers.0); + self.response_headers = KeyValues(response_headers); + self.version = other.version.or(self.version); + self.cert = other.cert.or(self.cert); + self.key = other.key.or(self.key); + self.pipeline_flush = other.pipeline_flush.or(self.pipeline_flush); + self + } } diff --git a/src/config/source.rs b/src/config/source.rs index f5e9adbc556..e974553e27a 100644 --- a/src/config/source.rs +++ b/src/config/source.rs @@ -4,9 +4,9 @@ use super::Config; #[derive(Clone)] pub enum Source { - Json, - Yml, - GraphQL, + Json, + Yml, + GraphQL, } const JSON_EXT: &str = "json"; @@ -19,43 +19,42 @@ const ALL: [Source; 3] = [Source::Json, Source::Yml, Source::GraphQL]; pub struct UnsupportedFileFormat(String); impl std::str::FromStr for Source { - type Err = UnsupportedFileFormat; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "json" => Ok(Source::Json), - "yml" | "yaml" => Ok(Source::Yml), - "graphql" | "gql" => Ok(Source::GraphQL), - _ => Err(UnsupportedFileFormat(s.to_string())), + type Err = UnsupportedFileFormat; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "json" => Ok(Source::Json), + "yml" | "yaml" => Ok(Source::Yml), + "graphql" | "gql" => Ok(Source::GraphQL), + _ => Err(UnsupportedFileFormat(s.to_string())), + } } - } } impl Source { - pub fn ext(&self) -> &'static str { - match self { - Source::Json => JSON_EXT, - Source::Yml => YML_EXT, - Source::GraphQL => GRAPHQL_EXT, + pub fn ext(&self) -> &'static str { + match self { + Source::Json => JSON_EXT, + Source::Yml => YML_EXT, + Source::GraphQL => GRAPHQL_EXT, + } } - 
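The `merge_right` implementation above encodes a consistent precedence: scalar options take `other.x.or(self.x)`, so the right-hand config wins, while `vars` and `response_headers` are unioned with right-hand entries overwriting duplicate keys. A hedged sketch, relying only on the fields being `pub` as shown:

use crate::config::Server;

fn merge_precedence() {
    let mut base = Server::default();
    base.port = Some(8000);
    base.graphiql = Some(true);
    let mut overlay = Server::default();
    overlay.port = Some(9000);
    let merged = base.merge_right(overlay);
    assert_eq!(merged.get_port(), 9000); // overlay wins where both are set
    assert!(merged.enable_graphiql()); // base survives where overlay is unset
}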
} - - fn ends_with(&self, file: &str) -> bool { - file.ends_with(&format!(".{}", self.ext())) - } - - pub fn detect(name: &str) -> Result { - ALL - .into_iter() - .find(|format| format.ends_with(name)) - .ok_or(UnsupportedFileFormat(name.to_string())) - } - - pub fn encode(&self, config: &Config) -> Result { - match self { - Source::Yml => Ok(config.to_yaml()?), - Source::GraphQL => Ok(config.to_sdl()), - Source::Json => Ok(config.to_json(true)?), + + fn ends_with(&self, file: &str) -> bool { + file.ends_with(&format!(".{}", self.ext())) + } + + pub fn detect(name: &str) -> Result { + ALL.into_iter() + .find(|format| format.ends_with(name)) + .ok_or(UnsupportedFileFormat(name.to_string())) + } + + pub fn encode(&self, config: &Config) -> Result { + match self { + Source::Yml => Ok(config.to_yaml()?), + Source::GraphQL => Ok(config.to_sdl()), + Source::Json => Ok(config.to_json(true)?), + } } - } } diff --git a/src/config/upstream.rs b/src/config/upstream.rs index 4b8e07c0489..41be880d2b6 100644 --- a/src/config/upstream.rs +++ b/src/config/upstream.rs @@ -8,152 +8,156 @@ use crate::config::is_default; #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, schemars::JsonSchema)] #[serde(rename_all = "camelCase", default)] pub struct Batch { - pub max_size: usize, - pub delay: usize, - pub headers: BTreeSet, + pub max_size: usize, + pub delay: usize, + pub headers: BTreeSet, } impl Default for Batch { - fn default() -> Self { - Batch { max_size: 100, delay: 0, headers: BTreeSet::new() } - } + fn default() -> Self { + Batch { max_size: 100, delay: 0, headers: BTreeSet::new() } + } } #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, schemars::JsonSchema)] pub struct Proxy { - pub url: String, + pub url: String, } -#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, Default, schemars::JsonSchema)] +#[derive( + Serialize, Deserialize, PartialEq, Eq, Clone, Debug, Setters, Default, schemars::JsonSchema, +)] #[serde(rename_all = "camelCase", default)] /// The `upstream` directive allows you to control various aspects of the upstream server connection. This includes settings like connection timeouts, keep-alive intervals, and more. If not specified, default values are used. pub struct Upstream { - #[serde(default, skip_serializing_if = "is_default")] - /// The time in seconds that the connection pool will wait before closing idle connections. - pub pool_idle_timeout: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The maximum number of idle connections that will be maintained per host. - pub pool_max_idle_per_host: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The time in seconds between each keep-alive message sent to maintain the connection. - pub keep_alive_interval: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The time in seconds that the connection will wait for a keep-alive message before closing. - pub keep_alive_timeout: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// A boolean value that determines whether keep-alive messages should be sent while the connection is idle. - pub keep_alive_while_idle: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The `proxy` setting defines an intermediary server through which the upstream requests will be routed before reaching their intended endpoint. By specifying a proxy URL, you introduce an additional layer, enabling custom routing and security policies. 
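`Source::detect` above dispatches purely on the file extension, with no content sniffing; `FromStr` additionally accepts the bare format names and their aliases. A usage sketch, assuming `Source` is re-exported from `crate::config`:

use std::str::FromStr;
use crate::config::Source;

fn formats() {
    assert!(Source::detect("app.graphql").is_ok());
    assert!(Source::detect("app.yml").is_ok());
    assert!(Source::detect("app.toml").is_err()); // unsupported extension
    // "yml"/"yaml" and "graphql"/"gql" are both accepted by FromStr.
    assert!(Source::from_str("yaml").is_ok());
    assert!(Source::from_str("gql").is_ok());
}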
- pub proxy: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The time in seconds that the connection will wait for a response before timing out. - pub connect_timeout: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The maximum time in seconds that the connection will wait for a response. - pub timeout: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The time in seconds between each TCP keep-alive message sent to maintain the connection. - pub tcp_keep_alive: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// The User-Agent header value to be used in HTTP requests. @default `Tailcall/1.0` - pub user_agent: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// `allowedHeaders` defines the HTTP headers allowed to be forwarded to upstream services. If not set, no headers are forwarded, enhancing security but possibly limiting data flow. - pub allowed_headers: Option>, - #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] - /// This refers to the default base URL for your APIs. If it's not explicitly mentioned in the `@upstream` operator, then each [@http](#http) operator must specify its own `baseURL`. If neither `@upstream` nor [@http](#http) provides a `baseURL`, it results in a compilation error. - pub base_url: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// Activating this enables Tailcall's HTTP caching, adhering to the [HTTP Caching RFC](https://tools.ietf.org/html/rfc7234), to enhance performance by minimizing redundant data fetches. Defaults to `false` if unspecified. - pub http_cache: Option, - #[serde(default, skip_serializing_if = "is_default")] - /// An object that specifies the batch settings, including `maxSize` (the maximum size of the batch), `delay` (the delay in milliseconds between each batch), and `headers` (an array of HTTP headers to be included in the batch). - pub batch: Option, - #[setters(strip_option)] - #[serde(rename = "http2Only", default, skip_serializing_if = "is_default")] - /// The `http2Only` setting allows you to specify whether the client should always issue HTTP2 requests, without checking if the server supports it or not. By default it is set to `false` for all HTTP requests made by the server, but is automatically set to true for GRPC. - pub http2_only: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The time in seconds that the connection pool will wait before closing idle connections. + pub pool_idle_timeout: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The maximum number of idle connections that will be maintained per host. + pub pool_max_idle_per_host: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The time in seconds between each keep-alive message sent to maintain the connection. + pub keep_alive_interval: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The time in seconds that the connection will wait for a keep-alive message before closing. + pub keep_alive_timeout: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// A boolean value that determines whether keep-alive messages should be sent while the connection is idle. + pub keep_alive_while_idle: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The `proxy` setting defines an intermediary server through which the upstream requests will be routed before reaching their intended endpoint. 
By specifying a proxy URL, you introduce an additional layer, enabling custom routing and security policies. + pub proxy: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The time in seconds that the connection will wait for a response before timing out. + pub connect_timeout: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The maximum time in seconds that the connection will wait for a response. + pub timeout: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The time in seconds between each TCP keep-alive message sent to maintain the connection. + pub tcp_keep_alive: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// The User-Agent header value to be used in HTTP requests. @default `Tailcall/1.0` + pub user_agent: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// `allowedHeaders` defines the HTTP headers allowed to be forwarded to upstream services. If not set, no headers are forwarded, enhancing security but possibly limiting data flow. + pub allowed_headers: Option>, + #[serde(rename = "baseURL", default, skip_serializing_if = "is_default")] + /// This refers to the default base URL for your APIs. If it's not explicitly mentioned in the `@upstream` operator, then each [@http](#http) operator must specify its own `baseURL`. If neither `@upstream` nor [@http](#http) provides a `baseURL`, it results in a compilation error. + pub base_url: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// Activating this enables Tailcall's HTTP caching, adhering to the [HTTP Caching RFC](https://tools.ietf.org/html/rfc7234), to enhance performance by minimizing redundant data fetches. Defaults to `false` if unspecified. + pub http_cache: Option, + #[serde(default, skip_serializing_if = "is_default")] + /// An object that specifies the batch settings, including `maxSize` (the maximum size of the batch), `delay` (the delay in milliseconds between each batch), and `headers` (an array of HTTP headers to be included in the batch). + pub batch: Option, + #[setters(strip_option)] + #[serde(rename = "http2Only", default, skip_serializing_if = "is_default")] + /// The `http2Only` setting allows you to specify whether the client should always issue HTTP2 requests, without checking if the server supports it or not. By default it is set to `false` for all HTTP requests made by the server, but is automatically set to true for GRPC. 
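The batch settings described above bottom out in `Batch::default()` from the top of this file: `maxSize` 100, `delay` 0, and an empty header set. As a sketch (the `crate::config::Batch` path is an assumed re-export, like the other config types):

use crate::config::Batch;

fn batch_defaults() {
    let batch = Batch::default();
    assert_eq!(batch.max_size, 100); // flush after 100 queued requests
    assert_eq!(batch.delay, 0); // no artificial delay between batches
    assert!(batch.headers.is_empty());
}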
+ pub http2_only: Option, } impl Upstream { - pub fn get_pool_idle_timeout(&self) -> u64 { - self.pool_idle_timeout.unwrap_or(60) - } - pub fn get_pool_max_idle_per_host(&self) -> usize { - self.pool_max_idle_per_host.unwrap_or(60) - } - pub fn get_keep_alive_interval(&self) -> u64 { - self.keep_alive_interval.unwrap_or(60) - } - pub fn get_keep_alive_timeout(&self) -> u64 { - self.keep_alive_timeout.unwrap_or(60) - } - pub fn get_keep_alive_while_idle(&self) -> bool { - self.keep_alive_while_idle.unwrap_or(false) - } - pub fn get_connect_timeout(&self) -> u64 { - self.connect_timeout.unwrap_or(60) - } - pub fn get_timeout(&self) -> u64 { - self.timeout.unwrap_or(60) - } - pub fn get_tcp_keep_alive(&self) -> u64 { - self.tcp_keep_alive.unwrap_or(5) - } - pub fn get_user_agent(&self) -> String { - self.user_agent.clone().unwrap_or("Tailcall/1.0".to_string()) - } - pub fn get_enable_http_cache(&self) -> bool { - self.http_cache.unwrap_or(false) - } - pub fn get_allowed_headers(&self) -> BTreeSet { - self.allowed_headers.clone().unwrap_or_default() - } - pub fn get_delay(&self) -> usize { - self.batch.clone().unwrap_or_default().delay - } + pub fn get_pool_idle_timeout(&self) -> u64 { + self.pool_idle_timeout.unwrap_or(60) + } + pub fn get_pool_max_idle_per_host(&self) -> usize { + self.pool_max_idle_per_host.unwrap_or(60) + } + pub fn get_keep_alive_interval(&self) -> u64 { + self.keep_alive_interval.unwrap_or(60) + } + pub fn get_keep_alive_timeout(&self) -> u64 { + self.keep_alive_timeout.unwrap_or(60) + } + pub fn get_keep_alive_while_idle(&self) -> bool { + self.keep_alive_while_idle.unwrap_or(false) + } + pub fn get_connect_timeout(&self) -> u64 { + self.connect_timeout.unwrap_or(60) + } + pub fn get_timeout(&self) -> u64 { + self.timeout.unwrap_or(60) + } + pub fn get_tcp_keep_alive(&self) -> u64 { + self.tcp_keep_alive.unwrap_or(5) + } + pub fn get_user_agent(&self) -> String { + self.user_agent + .clone() + .unwrap_or("Tailcall/1.0".to_string()) + } + pub fn get_enable_http_cache(&self) -> bool { + self.http_cache.unwrap_or(false) + } + pub fn get_allowed_headers(&self) -> BTreeSet { + self.allowed_headers.clone().unwrap_or_default() + } + pub fn get_delay(&self) -> usize { + self.batch.clone().unwrap_or_default().delay + } - pub fn get_max_size(&self) -> usize { - self.batch.clone().unwrap_or_default().max_size - } + pub fn get_max_size(&self) -> usize { + self.batch.clone().unwrap_or_default().max_size + } + + pub fn get_http_2_only(&self) -> bool { + self.http2_only.unwrap_or(false) + } - pub fn get_http_2_only(&self) -> bool { - self.http2_only.unwrap_or(false) - } + // TODO: add unit tests for merge + pub fn merge_right(mut self, other: Self) -> Self { + self.allowed_headers = other.allowed_headers.map(|other| { + if let Some(mut self_headers) = self.allowed_headers { + self_headers.extend(&mut other.iter().map(|s| s.to_owned())); + self_headers + } else { + other + } + }); + self.base_url = other.base_url.or(self.base_url); + self.connect_timeout = other.connect_timeout.or(self.connect_timeout); + self.http_cache = other.http_cache.or(self.http_cache); + self.keep_alive_interval = other.keep_alive_interval.or(self.keep_alive_interval); + self.keep_alive_timeout = other.keep_alive_timeout.or(self.keep_alive_timeout); + self.keep_alive_while_idle = other.keep_alive_while_idle.or(self.keep_alive_while_idle); + self.pool_idle_timeout = other.pool_idle_timeout.or(self.pool_idle_timeout); + self.pool_max_idle_per_host = other.pool_max_idle_per_host.or(self.pool_max_idle_per_host); + 
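All of the `Upstream` getters above fall back to hard-coded defaults when a field is unset, so an empty `@upstream` block is fully usable. A sketch pinning down a few of those defaults (import path assumed, as above):

use crate::config::Upstream;

fn upstream_defaults() {
    let upstream = Upstream::default();
    assert_eq!(upstream.get_connect_timeout(), 60); // seconds
    assert_eq!(upstream.get_timeout(), 60);
    assert_eq!(upstream.get_tcp_keep_alive(), 5);
    assert_eq!(upstream.get_user_agent(), "Tailcall/1.0");
    assert!(!upstream.get_http_2_only()); // HTTP/2 is opt-in except for gRPC
}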
self.proxy = other.proxy.or(self.proxy); + self.tcp_keep_alive = other.tcp_keep_alive.or(self.tcp_keep_alive); + self.timeout = other.timeout.or(self.timeout); + self.user_agent = other.user_agent.or(self.user_agent); - // TODO: add unit tests for merge - pub fn merge_right(mut self, other: Self) -> Self { - self.allowed_headers = other.allowed_headers.map(|other| { - if let Some(mut self_headers) = self.allowed_headers { - self_headers.extend(&mut other.iter().map(|s| s.to_owned())); - self_headers - } else { - other - } - }); - self.base_url = other.base_url.or(self.base_url); - self.connect_timeout = other.connect_timeout.or(self.connect_timeout); - self.http_cache = other.http_cache.or(self.http_cache); - self.keep_alive_interval = other.keep_alive_interval.or(self.keep_alive_interval); - self.keep_alive_timeout = other.keep_alive_timeout.or(self.keep_alive_timeout); - self.keep_alive_while_idle = other.keep_alive_while_idle.or(self.keep_alive_while_idle); - self.pool_idle_timeout = other.pool_idle_timeout.or(self.pool_idle_timeout); - self.pool_max_idle_per_host = other.pool_max_idle_per_host.or(self.pool_max_idle_per_host); - self.proxy = other.proxy.or(self.proxy); - self.tcp_keep_alive = other.tcp_keep_alive.or(self.tcp_keep_alive); - self.timeout = other.timeout.or(self.timeout); - self.user_agent = other.user_agent.or(self.user_agent); + if let Some(other) = other.batch { + let mut batch = self.batch.unwrap_or_default(); + batch.max_size = other.max_size; + batch.delay = other.delay; + batch.headers.extend(other.headers); - if let Some(other) = other.batch { - let mut batch = self.batch.unwrap_or_default(); - batch.max_size = other.max_size; - batch.delay = other.delay; - batch.headers.extend(other.headers); + self.batch = Some(batch); + } - self.batch = Some(batch); + self.http2_only = other.http2_only.or(self.http2_only); + self } - - self.http2_only = other.http2_only.or(self.http2_only); - self - } } diff --git a/src/data_loader/cache.rs b/src/data_loader/cache.rs index 319ca6c582a..085818d8c75 100644 --- a/src/data_loader/cache.rs +++ b/src/data_loader/cache.rs @@ -13,169 +13,169 @@ pub struct NoCache; impl CacheFactory for NoCache where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, { - type Storage = NoCacheImpl; + type Storage = NoCacheImpl; - fn create(&self) -> Self::Storage { - NoCacheImpl { _mark1: PhantomData, _mark2: PhantomData } - } + fn create(&self) -> Self::Storage { + NoCacheImpl { _mark1: PhantomData, _mark2: PhantomData } + } } pub struct NoCacheImpl { - _mark1: PhantomData, - _mark2: PhantomData, + _mark1: PhantomData, + _mark2: PhantomData, } impl CacheStorage for NoCacheImpl where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, { - type Key = K; - type Value = V; + type Key = K; + type Value = V; - #[inline] - fn get(&mut self, _key: &K) -> Option<&V> { - None - } + #[inline] + fn get(&mut self, _key: &K) -> Option<&V> { + None + } - #[inline] - fn insert(&mut self, _key: Cow<'_, Self::Key>, _val: Cow<'_, Self::Value>) {} + #[inline] + fn insert(&mut self, _key: Cow<'_, Self::Key>, _val: Cow<'_, Self::Value>) {} - #[inline] - fn remove(&mut self, _key: &K) {} + #[inline] + fn remove(&mut self, _key: &K) {} - #[inline] - fn clear(&mut self) {} + #[inline] + fn clear(&mut self) {} - fn iter(&self) -> Box + '_> { - 
Box::new(std::iter::empty()) - } + fn iter(&self) -> Box + '_> { + Box::new(std::iter::empty()) + } } /// [std::collections::HashMap] cache. pub struct HashMapCache { - _mark: PhantomData, + _mark: PhantomData, } impl HashMapCache { - /// Use specified `S: BuildHasher` to create a `HashMap` cache. - pub fn new() -> Self { - Self { _mark: PhantomData } - } + /// Use specified `S: BuildHasher` to create a `HashMap` cache. + pub fn new() -> Self { + Self { _mark: PhantomData } + } } impl Default for HashMapCache { - fn default() -> Self { - Self { _mark: PhantomData } - } + fn default() -> Self { + Self { _mark: PhantomData } + } } impl CacheFactory for HashMapCache where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, { - type Storage = HashMapCacheImpl; + type Storage = HashMapCacheImpl; - fn create(&self) -> Self::Storage { - HashMapCacheImpl(HashMap::default()) - } + fn create(&self) -> Self::Storage { + HashMapCacheImpl(HashMap::default()) + } } pub struct HashMapCacheImpl(HashMap); impl CacheStorage for HashMapCacheImpl where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, - S: Send + Sync + BuildHasher + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, + S: Send + Sync + BuildHasher + 'static, { - type Key = K; - type Value = V; - - #[inline] - fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> { - self.0.get(key) - } - - #[inline] - fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) { - self.0.insert(key.into_owned(), val.into_owned()); - } - - #[inline] - fn remove(&mut self, key: &Self::Key) { - self.0.remove(key); - } - - #[inline] - fn clear(&mut self) { - self.0.clear(); - } - - fn iter(&self) -> Box + '_> { - Box::new(self.0.iter()) - } + type Key = K; + type Value = V; + + #[inline] + fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> { + self.0.get(key) + } + + #[inline] + fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) { + self.0.insert(key.into_owned(), val.into_owned()); + } + + #[inline] + fn remove(&mut self, key: &Self::Key) { + self.0.remove(key); + } + + #[inline] + fn clear(&mut self) { + self.0.clear(); + } + + fn iter(&self) -> Box + '_> { + Box::new(self.0.iter()) + } } /// LRU cache. pub struct LruCache { - cap: usize, + cap: usize, } impl LruCache { - /// Creates a new LRU Cache that holds at most `cap` items. - pub fn new(cap: usize) -> Self { - Self { cap } - } + /// Creates a new LRU Cache that holds at most `cap` items. 
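Taken together, this file offers three cache policies behind the same `CacheFactory` abstraction: `NoCache` never retains anything, `HashMapCache` grows without bound, and `LruCache` evicts the least-recently-used entry once `cap` is exceeded. A hedged selection sketch (the `crate::data_loader` re-exports and the hasher type parameter are assumptions):

use std::collections::hash_map::RandomState;
use crate::data_loader::{HashMapCache, LruCache, NoCache};

fn pick_cache() {
    let _none = NoCache; // every lookup misses; nothing is retained
    let _unbounded = HashMapCache::<RandomState>::new(); // grows without eviction
    let _bounded = LruCache::new(256); // keeps at most 256 entries, evicting LRU
}

Which one is appropriate depends on key cardinality: unbounded caching is only safe when the key space is small and stable.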
+ pub fn new(cap: usize) -> Self { + Self { cap } + } } impl CacheFactory for LruCache where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, { - type Storage = LruCacheImpl; + type Storage = LruCacheImpl; - fn create(&self) -> Self::Storage { - LruCacheImpl(lru::LruCache::new(NonZeroUsize::new(self.cap).unwrap())) - } + fn create(&self) -> Self::Storage { + LruCacheImpl(lru::LruCache::new(NonZeroUsize::new(self.cap).unwrap())) + } } pub struct LruCacheImpl(lru::LruCache); impl CacheStorage for LruCacheImpl where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, { - type Key = K; - type Value = V; - - #[inline] - fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> { - self.0.get(key) - } - - #[inline] - fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) { - self.0.put(key.into_owned(), val.into_owned()); - } - - #[inline] - fn remove(&mut self, key: &Self::Key) { - self.0.pop(key); - } - - #[inline] - fn clear(&mut self) { - self.0.clear(); - } - - fn iter(&self) -> Box + '_> { - Box::new(self.0.iter()) - } + type Key = K; + type Value = V; + + #[inline] + fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> { + self.0.get(key) + } + + #[inline] + fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) { + self.0.put(key.into_owned(), val.into_owned()); + } + + #[inline] + fn remove(&mut self, key: &Self::Key) { + self.0.pop(key); + } + + #[inline] + fn clear(&mut self) { + self.0.clear(); + } + + fn iter(&self) -> Box + '_> { + Box::new(self.0.iter()) + } } diff --git a/src/data_loader/data_loader.rs b/src/data_loader/data_loader.rs index 82c494d0046..65187d11f07 100644 --- a/src/data_loader/data_loader.rs +++ b/src/data_loader/data_loader.rs @@ -22,536 +22,573 @@ pub use super::storage::CacheStorage; /// /// Reference: pub struct DataLoader< - K: Send + Sync + Eq + Clone + Hash + 'static, - T: Loader, - C: CacheFactory = NoCache, + K: Send + Sync + Eq + Clone + Hash + 'static, + T: Loader, + C: CacheFactory = NoCache, > { - inner: Arc>, - delay: Duration, - max_batch_size: usize, - disable_cache: AtomicBool, + inner: Arc>, + delay: Duration, + max_batch_size: usize, + disable_cache: AtomicBool, } impl DataLoader where - K: Send + Sync + Hash + Eq + Clone + 'static, - T: Loader, + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, { - /// Use `Loader` to create a [DataLoader] that does not cache records. - pub fn new(loader: T) -> Self { - Self { - inner: Arc::new(DataLoaderInner { requests: Mutex::new(Requests::new(&NoCache)), loader }), - delay: Duration::from_millis(1), - max_batch_size: 1000, - disable_cache: false.into(), + /// Use `Loader` to create a [DataLoader] that does not cache records. + pub fn new(loader: T) -> Self { + Self { + inner: Arc::new(DataLoaderInner { + requests: Mutex::new(Requests::new(&NoCache)), + loader, + }), + delay: Duration::from_millis(1), + max_batch_size: 1000, + disable_cache: false.into(), + } } - } } impl DataLoader where - K: Send + Sync + Hash + Eq + Clone + 'static, - T: Loader, - C: CacheFactory, -{ - /// Use `Loader` to create a [DataLoader] with a cache factory. 
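
Editor's note: a hypothetical wiring of these pieces, with `UserLoader` standing in for any type implementing the `Loader` trait, and the stripped `DataLoader`/`LruCache` generics assumed to match async-graphql's dataloader, which this module adapts:

```rust
use std::time::Duration;

fn build_loaders() {
    // Batching only: NoCache is the default factory, so nothing is memoized.
    let _plain = DataLoader::new(UserLoader);

    // Batching plus a bounded LRU. Note that create() unwraps
    // NonZeroUsize::new(self.cap), so LruCache::new(0) would panic the
    // first time a storage is created.
    let _cached = DataLoader::with_cache(UserLoader, LruCache::new(512))
        .delay(Duration::from_millis(2))
        .max_batch_size(100);
}
```
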
- pub fn with_cache(loader: T, cache_factory: C) -> Self {
-    Self {
-      inner: Arc::new(DataLoaderInner { requests: Mutex::new(Requests::new(&cache_factory)), loader }),
-      delay: Duration::from_millis(1),
-      max_batch_size: 1000,
-      disable_cache: false.into(),
-    }
-  }
-
-  /// Specify the delay time for loading data, the default is `1ms`.
-  #[must_use]
-  pub fn delay(self, delay: Duration) -> Self {
-    Self { delay, ..self }
-  }
-
-  /// Specify the max batch size for loading data, the default is
-  /// `1000`.
-  ///
-  /// If the keys waiting to be loaded reach the threshold, they are loaded
-  /// immediately.
-  #[must_use]
-  pub fn max_batch_size(self, max_batch_size: usize) -> Self {
-    Self { max_batch_size, ..self }
-  }
-
-  /// Get the loader.
-  #[inline]
-  pub fn loader(&self) -> &T {
-    &self.inner.loader
-  }
-
-  /// Enable/Disable cache of all loaders.
-  pub fn enable_all_cache(&self, enable: bool) {
-    self.disable_cache.store(!enable, Ordering::SeqCst);
-  }
-
-  /// Enable/Disable cache of specified loader.
-  pub fn enable_cache(&self, enable: bool)
-  where K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader,
-  {
-    let mut requests = self.inner.requests.lock().unwrap();
-    requests.disable_cache = !enable;
-  }
-
-  /// Use this `DataLoader` to load a single value.
-  #[cfg_attr(feature = "tracing", instrument(skip_all))]
-  pub async fn load_one(&self, key: K) -> Result, T::Error>
-  where
-    K: Send + Sync + Hash + Eq + Clone + 'static,
-    T: Loader,
-  {
-    let mut values = self.load_many(std::iter::once(key.clone())).await?;
-    Ok(values.remove(&key))
-  }
-
-  /// Use this `DataLoader` to load some data.
-  #[cfg_attr(feature = "tracing", instrument(skip_all))]
-  pub async fn load_many(&self, keys: I) -> Result, T::Error>
-  where
-    K: Send + Sync + Hash + Eq + Clone + 'static,
-    I: IntoIterator,
-    T: Loader,
-  {
-    enum Action> {
-      ImmediateLoad(KeysAndSender),
-      StartFetch,
-      Delay,
+    C: CacheFactory,
+{
+    /// Use `Loader` to create a [DataLoader] with a cache factory.
+    pub fn with_cache(loader: T, cache_factory: C) -> Self {
+        Self {
+            inner: Arc::new(DataLoaderInner {
+                requests: Mutex::new(Requests::new(&cache_factory)),
+                loader,
+            }),
+            delay: Duration::from_millis(1),
+            max_batch_size: 1000,
+            disable_cache: false.into(),
+        }
+    }

-    let (action, rx) = {
-      let mut requests = self.inner.requests.lock().unwrap();
-      let prev_count = requests.keys.len();
-      let mut keys_set = HashSet::new();
-      let mut use_cache_values = HashMap::new();
-
-      if requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
-        keys_set = keys.into_iter().collect();
-      } else {
-        for key in keys {
-          if let Some(value) = requests.cache_storage.get(&key) {
-            // Already in cache
-            use_cache_values.insert(key.clone(), value.clone());
-          } else {
-            keys_set.insert(key);
-          }
-        }
-      }

+    /// Specify the delay time for loading data, the default is `1ms`.
+    #[must_use]
+    pub fn delay(self, delay: Duration) -> Self {
+        Self { delay, ..self }
+    }
+
+    /// Specify the max batch size for loading data, the default is
+    /// `1000`.
+    ///
+    /// If the keys waiting to be loaded reach the threshold, they are loaded
+    /// immediately.
+    #[must_use]
+    pub fn max_batch_size(self, max_batch_size: usize) -> Self {
+        Self { max_batch_size, ..self }
+    }

-      if !use_cache_values.is_empty() && keys_set.is_empty() {
-        return Ok(use_cache_values);
-      } else if use_cache_values.is_empty() && keys_set.is_empty() {
-        return Ok(Default::default());
-      }

+    /// Get the loader.
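
Editor's note: because the diff interleaves the old and new bodies of `load_many`, the batching decision is hard to follow. Condensing the branch that appears below (a restatement, not new behavior): once the pending key set reaches `max_batch_size`, the batch is flushed immediately; otherwise the first caller to register keys (`prev_count == 0`) schedules a flush after `delay`, and every later caller just waits on its oneshot.

```rust
// Condensed restatement of the selection logic in load_many:
let action = if requests.keys.len() >= self.max_batch_size {
    Action::ImmediateLoad(requests.take()) // threshold hit: flush right now
} else if !requests.keys.is_empty() && prev_count == 0 {
    Action::StartFetch // first keys in: start the delayed flush
} else {
    Action::Delay // a flush is already scheduled; just await the result
};
```
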
+    #[inline]
+    pub fn loader(&self) -> &T {
+        &self.inner.loader
+    }

+    /// Enable/Disable cache of all loaders.
+    pub fn enable_all_cache(&self, enable: bool) {
+        self.disable_cache.store(!enable, Ordering::SeqCst);
+    }

+    /// Enable/Disable cache of specified loader.
+    pub fn enable_cache(&self, enable: bool)
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        T: Loader,
+    {
+        let mut requests = self.inner.requests.lock().unwrap();
+        requests.disable_cache = !enable;
+    }

+    /// Use this `DataLoader` to load a single value.
+    #[cfg_attr(feature = "tracing", instrument(skip_all))]
+    pub async fn load_one(&self, key: K) -> Result, T::Error>
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        T: Loader,
+    {
+        let mut values = self.load_many(std::iter::once(key.clone())).await?;
+        Ok(values.remove(&key))
+    }

+    /// Use this `DataLoader` to load some data.
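
Editor's note: `load_one` is just `load_many` over a single-element iterator, and pending keys live in a `HashSet`, so duplicate keys issued concurrently collapse into one batched fetch (the `test_duplicate_keys` test further down relies on exactly this). Hypothetical usage with the `MyLoader` fixture from the tests module below, which maps every key to itself:

```rust
#[tokio::test]
async fn load_one_demo() {
    let loader = DataLoader::new(MyLoader);
    // One key, one batch of size one; absent keys would come back as None.
    assert_eq!(loader.load_one(7).await.unwrap(), Some(7));
}
```
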
+ #[cfg_attr(feature = "tracing", instrument(skip_all))] + pub async fn load_many(&self, keys: I) -> Result, T::Error> + where + K: Send + Sync + Hash + Eq + Clone + 'static, + I: IntoIterator, + T: Loader, + { + enum Action> { + ImmediateLoad(KeysAndSender), + StartFetch, + Delay, + } + + let (action, rx) = { + let mut requests = self.inner.requests.lock().unwrap(); + let prev_count = requests.keys.len(); + let mut keys_set = HashSet::new(); + let mut use_cache_values = HashMap::new(); + + if requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) { + keys_set = keys.into_iter().collect(); + } else { + for key in keys { + if let Some(value) = requests.cache_storage.get(&key) { + // Already in cache + use_cache_values.insert(key.clone(), value.clone()); + } else { + keys_set.insert(key); + } + } + } + + if !use_cache_values.is_empty() && keys_set.is_empty() { + return Ok(use_cache_values); + } else if use_cache_values.is_empty() && keys_set.is_empty() { + return Ok(Default::default()); + } + + requests.keys.extend(keys_set.clone()); + let (tx, rx) = oneshot::channel(); + requests + .pending + .push((keys_set, ResSender { use_cache_values, tx })); + + if requests.keys.len() >= self.max_batch_size { + (Action::ImmediateLoad(requests.take()), rx) + } else { + ( + if !requests.keys.is_empty() && prev_count == 0 { + Action::StartFetch + } else { + Action::Delay + }, + rx, + ) + } }; - #[cfg(feature = "tracing")] - let task = task.instrument(info_span!("start_fetch")).in_current_span(); - #[cfg(not(target_arch = "wasm32"))] - tokio::spawn(Box::pin(task)); - #[cfg(target_arch = "wasm32")] - async_std::task::spawn_local(Box::pin(task)); - } - Action::Delay => {} + + match action { + Action::ImmediateLoad(keys) => { + let inner = self.inner.clone(); + let disable_cache = self.disable_cache.load(Ordering::SeqCst); + let task = async move { inner.do_load(disable_cache, keys).await }; + #[cfg(feature = "tracing")] + let task = task + .instrument(info_span!("immediate_load")) + .in_current_span(); + + #[cfg(not(target_arch = "wasm32"))] + tokio::spawn(Box::pin(task)); + #[cfg(target_arch = "wasm32")] + async_std::task::spawn_local(Box::pin(task)); + } + Action::StartFetch => { + let inner = self.inner.clone(); + let disable_cache = self.disable_cache.load(Ordering::SeqCst); + let delay = self.delay; + + let task = async move { + Delay::new(delay).await; + + let keys = { + let mut requests = inner.requests.lock().unwrap(); + requests.take() + }; + + if !keys.0.is_empty() { + inner.do_load(disable_cache, keys).await + } + }; + #[cfg(feature = "tracing")] + let task = task.instrument(info_span!("start_fetch")).in_current_span(); + #[cfg(not(target_arch = "wasm32"))] + tokio::spawn(Box::pin(task)); + #[cfg(target_arch = "wasm32")] + async_std::task::spawn_local(Box::pin(task)); + } + Action::Delay => {} + } + + rx.await.unwrap() } - rx.await.unwrap() - } + /// Feed some data into the cache. + /// + /// **NOTE: If the cache type is [NoCache], this function will not take + /// effect. ** + #[cfg_attr(feature = "tracing", instrument(skip_all))] + pub async fn feed_many(&self, values: I) + where + K: Send + Sync + Hash + Eq + Clone + 'static, + I: IntoIterator, + T: Loader, + { + let mut requests = self.inner.requests.lock().unwrap(); + for (key, value) in values { + requests + .cache_storage + .insert(Cow::Owned(key), Cow::Owned(value)); + } + } - /// Feed some data into the cache. - /// - /// **NOTE: If the cache type is [NoCache], this function will not take - /// effect. 
** - #[cfg_attr(feature = "tracing", instrument(skip_all))] - pub async fn feed_many(&self, values: I) - where - K: Send + Sync + Hash + Eq + Clone + 'static, - I: IntoIterator, - T: Loader, - { - let mut requests = self.inner.requests.lock().unwrap(); - for (key, value) in values { - requests.cache_storage.insert(Cow::Owned(key), Cow::Owned(value)); + /// Feed some data into the cache. + /// + /// **NOTE: If the cache type is [NoCache], this function will not take + /// effect. ** + #[cfg_attr(feature = "tracing", instrument(skip_all))] + pub async fn feed_one(&self, key: K, value: T::Value) + where + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + { + self.feed_many(std::iter::once((key, value))).await; + } + + /// Clears the cache. + /// + /// **NOTE: If the cache type is [NoCache], this function will not take + /// effect. ** + #[cfg_attr(feature = "tracing", instrument(skip_all))] + pub fn clear(&self) + where + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + { + let _tid = TypeId::of::(); + let mut requests = self.inner.requests.lock().unwrap(); + requests.cache_storage.clear(); + } + + /// Gets all values in the cache. + pub fn get_cached_values(&self) -> HashMap + where + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + { + let _tid = TypeId::of::(); + let requests = self.inner.requests.lock().unwrap(); + requests + .cache_storage + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect() } - } - - /// Feed some data into the cache. - /// - /// **NOTE: If the cache type is [NoCache], this function will not take - /// effect. ** - #[cfg_attr(feature = "tracing", instrument(skip_all))] - pub async fn feed_one(&self, key: K, value: T::Value) - where - K: Send + Sync + Hash + Eq + Clone + 'static, - T: Loader, - { - self.feed_many(std::iter::once((key, value))).await; - } - - /// Clears the cache. - /// - /// **NOTE: If the cache type is [NoCache], this function will not take - /// effect. ** - #[cfg_attr(feature = "tracing", instrument(skip_all))] - pub fn clear(&self) - where - K: Send + Sync + Hash + Eq + Clone + 'static, - T: Loader, - { - let _tid = TypeId::of::(); - let mut requests = self.inner.requests.lock().unwrap(); - requests.cache_storage.clear(); - } - - /// Gets all values in the cache. 
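
Editor's note: a sketch in the style of the tests below, tying `feed_one`/`feed_many`, `get_cached_values`, and `clear` together. As the NOTE above says, this is only meaningful with a real cache factory; with `NoCache` the fed values are dropped.

```rust
#[tokio::test]
async fn feed_and_snapshot_demo() {
    let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
    loader.feed_one(1, 10).await;
    loader.feed_many(vec![(2, 20), (3, 30)]).await;

    // Served from the cache, no Loader::load round-trip:
    assert_eq!(loader.load_one(1).await.unwrap(), Some(10));

    // Snapshot the storage, then wipe it:
    assert_eq!(loader.get_cached_values().len(), 3);
    loader.clear();
    assert!(loader.get_cached_values().is_empty());
}
```
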
- pub fn get_cached_values(&self) -> HashMap - where - K: Send + Sync + Hash + Eq + Clone + 'static, - T: Loader, - { - let _tid = TypeId::of::(); - let requests = self.inner.requests.lock().unwrap(); - requests - .cache_storage - .iter() - .map(|(k, v)| (k.clone(), v.clone())) - .collect() - } } #[allow(clippy::type_complexity)] struct ResSender> { - use_cache_values: HashMap, - tx: oneshot::Sender, T::Error>>, + use_cache_values: HashMap, + tx: oneshot::Sender, T::Error>>, } -struct Requests, C: CacheFactory> { - keys: HashSet, - pending: Vec<(HashSet, ResSender)>, - cache_storage: C::Storage, - disable_cache: bool, +struct Requests< + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + C: CacheFactory, +> { + keys: HashSet, + pending: Vec<(HashSet, ResSender)>, + cache_storage: C::Storage, + disable_cache: bool, } type KeysAndSender = (HashSet, Vec<(HashSet, ResSender)>); -impl, C: CacheFactory> Requests { - fn new(cache_factory: &C) -> Self { - Self { keys: Default::default(), pending: Vec::new(), cache_storage: cache_factory.create(), disable_cache: false } - } +impl, C: CacheFactory> + Requests +{ + fn new(cache_factory: &C) -> Self { + Self { + keys: Default::default(), + pending: Vec::new(), + cache_storage: cache_factory.create(), + disable_cache: false, + } + } - fn take(&mut self) -> KeysAndSender { - (std::mem::take(&mut self.keys), std::mem::take(&mut self.pending)) - } + fn take(&mut self) -> KeysAndSender { + ( + std::mem::take(&mut self.keys), + std::mem::take(&mut self.pending), + ) + } } -struct DataLoaderInner, C: CacheFactory> { - requests: Mutex>, - loader: T, +struct DataLoaderInner< + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + C: CacheFactory, +> { + requests: Mutex>, + loader: T, } impl DataLoaderInner where - K: Send + Sync + Hash + Eq + Clone + 'static, - T: Loader, - C: CacheFactory, -{ - #[cfg_attr(feature = "tracing", instrument(skip_all))] - async fn do_load(&self, disable_cache: bool, (keys, senders): KeysAndSender) - where K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader, - { - let keys = keys.into_iter().collect::>(); - - match self.loader.load(&keys).await { - Ok(values) => { - // update cache - let mut requests = self.requests.lock().unwrap(); - let disable_cache = requests.disable_cache || disable_cache; - if !disable_cache { - for (key, value) in &values { - requests.cache_storage.insert(Cow::Borrowed(key), Cow::Borrowed(value)); - } - } - - // send response - for (keys, sender) in senders { - let mut res = HashMap::new(); - res.extend(sender.use_cache_values); - for key in &keys { - res.extend(values.get(key).map(|value| (key.clone(), value.clone()))); - } - sender.tx.send(Ok(res)).ok(); - } - } - Err(err) => { - for (_, sender) in senders { - sender.tx.send(Err(err.clone())).ok(); + C: CacheFactory, +{ + #[cfg_attr(feature = "tracing", instrument(skip_all))] + async fn do_load(&self, disable_cache: bool, (keys, senders): KeysAndSender) + where + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + { + let keys = keys.into_iter().collect::>(); + + match self.loader.load(&keys).await { + Ok(values) => { + // update cache + let mut requests = self.requests.lock().unwrap(); + let disable_cache = requests.disable_cache || disable_cache; + if !disable_cache { + for (key, value) in &values { + requests + .cache_storage + .insert(Cow::Borrowed(key), Cow::Borrowed(value)); + } + } + + // send response + for (keys, sender) in senders { + let mut res = HashMap::new(); + res.extend(sender.use_cache_values); + for key in 
&keys { + res.extend(values.get(key).map(|value| (key.clone(), value.clone()))); + } + sender.tx.send(Ok(res)).ok(); + } + } + Err(err) => { + for (_, sender) in senders { + sender.tx.send(Err(err.clone())).ok(); + } + } } - } } - } } #[cfg(test)] mod tests { - use std::sync::Arc; - use std::time::Duration; + use std::sync::Arc; + use std::time::Duration; + + use fnv::FnvBuildHasher; + + use super::*; + use crate::data_loader::HashMapCache; + + struct MyLoader; - use fnv::FnvBuildHasher; + #[async_trait::async_trait] + impl Loader for MyLoader { + type Value = i32; + type Error = (); - use super::*; - use crate::data_loader::HashMapCache; + async fn load(&self, keys: &[i32]) -> Result, Self::Error> { + assert!(keys.len() <= 10); + Ok(keys.iter().copied().map(|k| (k, k)).collect()) + } + } - struct MyLoader; + #[async_trait::async_trait] + impl Loader for MyLoader { + type Value = i64; + type Error = (); - #[async_trait::async_trait] - impl Loader for MyLoader { - type Value = i32; - type Error = (); + async fn load(&self, keys: &[i64]) -> Result, Self::Error> { + assert!(keys.len() <= 10); + Ok(keys.iter().copied().map(|k| (k, k)).collect()) + } + } - async fn load(&self, keys: &[i32]) -> Result, Self::Error> { - assert!(keys.len() <= 10); - Ok(keys.iter().copied().map(|k| (k, k)).collect()) + #[tokio::test] + async fn test_dataloader() { + let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10)); + assert_eq!( + futures_util::future::try_join_all((0..100i32).map({ + let loader = loader.clone(); + move |n| { + let loader = loader.clone(); + async move { loader.load_one(n).await } + } + })) + .await + .unwrap(), + (0..100).map(Option::Some).collect::>() + ); } - } - #[async_trait::async_trait] - impl Loader for MyLoader { - type Value = i64; - type Error = (); + #[tokio::test] + async fn test_duplicate_keys() { + let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10)); + assert_eq!( + futures_util::future::try_join_all([1, 3, 5, 1, 7, 8, 3, 7].iter().copied().map({ + let loader = loader.clone(); + move |n| { + let loader = loader.clone(); + async move { loader.load_one(n).await } + } + })) + .await + .unwrap(), + [1, 3, 5, 1, 7, 8, 3, 7] + .iter() + .copied() + .map(Option::Some) + .collect::>() + ); + } - async fn load(&self, keys: &[i64]) -> Result, Self::Error> { - assert!(keys.len() <= 10); - Ok(keys.iter().copied().map(|k| (k, k)).collect()) + #[tokio::test] + async fn test_dataloader_load_empty() { + let loader = DataLoader::new(MyLoader); + assert!(loader + .load_many::>(vec![]) + .await + .unwrap() + .is_empty()); } - } - - #[tokio::test] - async fn test_dataloader() { - let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10)); - assert_eq!( - futures_util::future::try_join_all((0..100i32).map({ - let loader = loader.clone(); - move |n| { - let loader = loader.clone(); - async move { loader.load_one(n).await } - } - })) - .await - .unwrap(), - (0..100).map(Option::Some).collect::>() - ); - } - - #[tokio::test] - async fn test_duplicate_keys() { - let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10)); - assert_eq!( - futures_util::future::try_join_all([1, 3, 5, 1, 7, 8, 3, 7].iter().copied().map({ - let loader = loader.clone(); - move |n| { - let loader = loader.clone(); - async move { loader.load_one(n).await } - } - })) - .await - .unwrap(), - [1, 3, 5, 1, 7, 8, 3, 7] - .iter() - .copied() - .map(Option::Some) - .collect::>() - ); - } - - #[tokio::test] - async fn test_dataloader_load_empty() { - let loader = 
DataLoader::new(MyLoader); - assert!(loader.load_many::>(vec![]).await.unwrap().is_empty()); - } - - #[tokio::test] - async fn test_dataloader_with_cache() { - let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); - loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; - - // All from the cache - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() - ); - - // Part from the cache - assert_eq!( - loader.load_many(vec![1, 5, 6]).await.unwrap(), - vec![(1, 10), (5, 5), (6, 6)].into_iter().collect() - ); - - // All from the loader - assert_eq!( - loader.load_many(vec![8, 9, 10]).await.unwrap(), - vec![(8, 8), (9, 9), (10, 10)].into_iter().collect() - ); - - // Clear cache - loader.clear(); - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() - ); - } - - #[tokio::test] - async fn test_dataloader_with_cache_hashmap_fnv() { - let loader = DataLoader::with_cache(MyLoader, HashMapCache::::new()); - loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; - - // All from the cache - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() - ); - - // Part from the cache - assert_eq!( - loader.load_many(vec![1, 5, 6]).await.unwrap(), - vec![(1, 10), (5, 5), (6, 6)].into_iter().collect() - ); - - // All from the loader - assert_eq!( - loader.load_many(vec![8, 9, 10]).await.unwrap(), - vec![(8, 8), (9, 9), (10, 10)].into_iter().collect() - ); - - // Clear cache - loader.clear(); - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() - ); - } - - #[tokio::test] - async fn test_dataloader_disable_all_cache() { - let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); - loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; - - // All from the loader - loader.enable_all_cache(false); - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() - ); - - // All from the cache - loader.enable_all_cache(true); - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() - ); - } - - #[tokio::test] - async fn test_dataloader_disable_cache() { - let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); - loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; - - // All from the loader - loader.enable_cache(false); - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() - ); - - // All from the cache - loader.enable_cache(true); - assert_eq!( - loader.load_many(vec![1, 2, 3]).await.unwrap(), - vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() - ); - } - - #[tokio::test] - async fn test_dataloader_dead_lock() { - struct MyDelayLoader; - #[async_trait::async_trait] - impl Loader for MyDelayLoader { - type Value = i32; - type Error = (); - - async fn load(&self, keys: &[i32]) -> Result, Self::Error> { - tokio::time::sleep(Duration::from_secs(1)).await; - Ok(keys.iter().copied().map(|k| (k, k)).collect()) - } + #[tokio::test] + async fn test_dataloader_with_cache() { + let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); + loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; + + // All from the cache + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 10), (2, 20), (3, 
30)].into_iter().collect() + ); + + // Part from the cache + assert_eq!( + loader.load_many(vec![1, 5, 6]).await.unwrap(), + vec![(1, 10), (5, 5), (6, 6)].into_iter().collect() + ); + + // All from the loader + assert_eq!( + loader.load_many(vec![8, 9, 10]).await.unwrap(), + vec![(8, 8), (9, 9), (10, 10)].into_iter().collect() + ); + + // Clear cache + loader.clear(); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() + ); } - let loader = Arc::new(DataLoader::with_cache(MyDelayLoader, NoCache).delay(Duration::from_secs(1))); - let handle = tokio::spawn({ - let loader = loader.clone(); - async move { - loader.load_many(vec![1, 2, 3]).await.unwrap(); - } - }); - - tokio::time::sleep(Duration::from_millis(500)).await; - handle.abort(); - loader.load_many(vec![4, 5, 6]).await.unwrap(); - } + #[tokio::test] + async fn test_dataloader_with_cache_hashmap_fnv() { + let loader = DataLoader::with_cache(MyLoader, HashMapCache::::new()); + loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; + + // All from the cache + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() + ); + + // Part from the cache + assert_eq!( + loader.load_many(vec![1, 5, 6]).await.unwrap(), + vec![(1, 10), (5, 5), (6, 6)].into_iter().collect() + ); + + // All from the loader + assert_eq!( + loader.load_many(vec![8, 9, 10]).await.unwrap(), + vec![(8, 8), (9, 9), (10, 10)].into_iter().collect() + ); + + // Clear cache + loader.clear(); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() + ); + } + + #[tokio::test] + async fn test_dataloader_disable_all_cache() { + let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); + loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; + + // All from the loader + loader.enable_all_cache(false); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() + ); + + // All from the cache + loader.enable_all_cache(true); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() + ); + } + + #[tokio::test] + async fn test_dataloader_disable_cache() { + let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); + loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; + + // All from the loader + loader.enable_cache(false); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() + ); + + // All from the cache + loader.enable_cache(true); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() + ); + } + + #[tokio::test] + async fn test_dataloader_dead_lock() { + struct MyDelayLoader; + + #[async_trait::async_trait] + impl Loader for MyDelayLoader { + type Value = i32; + type Error = (); + + async fn load(&self, keys: &[i32]) -> Result, Self::Error> { + tokio::time::sleep(Duration::from_secs(1)).await; + Ok(keys.iter().copied().map(|k| (k, k)).collect()) + } + } + + let loader = + Arc::new(DataLoader::with_cache(MyDelayLoader, NoCache).delay(Duration::from_secs(1))); + let handle = tokio::spawn({ + let loader = loader.clone(); + async move { + loader.load_many(vec![1, 2, 3]).await.unwrap(); + } + }); + + tokio::time::sleep(Duration::from_millis(500)).await; + handle.abort(); + loader.load_many(vec![4, 5, 6]).await.unwrap(); + } 
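
Editor's note: the dead-lock test above is worth unpacking. Reading it against the `load_many` body earlier in this diff, the property it guards is that the flush future is `tokio::spawn`ed in `Action::StartFetch` rather than driven by the caller, so aborting a caller cannot strand keys that are already registered:

```rust
// Timeline of test_dataloader_dead_lock (an interpretation of the code
// above, not a new API):
//
//   t=0ms    task A: load_many([1,2,3]) -> Action::StartFetch
//            (the flush future is spawned; A only awaits its oneshot)
//   t=500ms  A is aborted -- the spawned flush survives
//   t>500ms  load_many([4,5,6]) sees prev_count != 0 -> Action::Delay
//   t=1000ms the surviving flush take()s ALL pending keys, 4,5,6 included,
//            and answers the second caller's oneshot (the aborted caller's
//            dropped receiver is ignored via .ok()).
//
// Had the flush been driven by task A itself, the second caller would wait
// on a oneshot that nobody will ever fire.
```
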
} diff --git a/src/data_loader/factory.rs b/src/data_loader/factory.rs index b31943fb4c3..f66ef31e87e 100644 --- a/src/data_loader/factory.rs +++ b/src/data_loader/factory.rs @@ -5,12 +5,12 @@ use super::storage::CacheStorage; /// Factory for creating cache storage. pub trait CacheFactory: Send + Sync + 'static where - K: Send + Sync + Clone + Eq + Hash + 'static, - V: Send + Sync + Clone + 'static, + K: Send + Sync + Clone + Eq + Hash + 'static, + V: Send + Sync + Clone + 'static, { - type Storage: CacheStorage; + type Storage: CacheStorage; - /// Create a cache storage. - /// - fn create(&self) -> Self::Storage; + /// Create a cache storage. + /// + fn create(&self) -> Self::Storage; } diff --git a/src/data_loader/loader.rs b/src/data_loader/loader.rs index 703103e7859..1db5d1d44d4 100644 --- a/src/data_loader/loader.rs +++ b/src/data_loader/loader.rs @@ -4,12 +4,12 @@ use std::hash::Hash; /// Trait for batch loading. #[async_trait::async_trait] pub trait Loader: Send + Sync + 'static { - /// type of value. - type Value: Send + Sync + Clone + 'static; + /// type of value. + type Value: Send + Sync + Clone + 'static; - /// Type of error. - type Error: Send + Clone + 'static; + /// Type of error. + type Error: Send + Clone + 'static; - /// Load the data set specified by the `keys`. - async fn load(&self, keys: &[K]) -> Result, Self::Error>; + /// Load the data set specified by the `keys`. + async fn load(&self, keys: &[K]) -> Result, Self::Error>; } diff --git a/src/data_loader/storage.rs b/src/data_loader/storage.rs index 6c3b83a39c7..ea36c74f835 100644 --- a/src/data_loader/storage.rs +++ b/src/data_loader/storage.rs @@ -3,26 +3,26 @@ use std::hash::Hash; /// Cache storage for [DataLoader](crate::dataloader::DataLoader). pub trait CacheStorage: Send + Sync + 'static { - /// The key type of the record. - type Key: Send + Sync + Clone + Eq + Hash + 'static; + /// The key type of the record. + type Key: Send + Sync + Clone + Eq + Hash + 'static; - /// The value type of the record. - type Value: Send + Sync + Clone + 'static; + /// The value type of the record. + type Value: Send + Sync + Clone + 'static; - /// Returns a reference to the value of the key in the cache or None if it - /// is not present in the cache. - fn get(&mut self, key: &Self::Key) -> Option<&Self::Value>; + /// Returns a reference to the value of the key in the cache or None if it + /// is not present in the cache. + fn get(&mut self, key: &Self::Key) -> Option<&Self::Value>; - /// Puts a key-value pair into the cache. If the key already exists in the - /// cache, then it updates the key's value. - fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>); + /// Puts a key-value pair into the cache. If the key already exists in the + /// cache, then it updates the key's value. + fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>); - /// Removes the value corresponding to the key from the cache. - fn remove(&mut self, key: &Self::Key); + /// Removes the value corresponding to the key from the cache. + fn remove(&mut self, key: &Self::Key); - /// Clears the cache, removing all key-value pairs. - fn clear(&mut self); + /// Clears the cache, removing all key-value pairs. + fn clear(&mut self); - /// Returns an iterator over the key-value pairs in the cache. - fn iter(&self) -> Box + '_>; + /// Returns an iterator over the key-value pairs in the cache. 
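
Editor's note: the generic parameters in the three trait definitions above were stripped from this diff's text. Reconstructing them is an assumption, but the `where` clauses that survive (over `K` and `V`) and the async-graphql dataloader these files are adapted from strongly suggest `CacheFactory<K, V>` with `type Storage: CacheStorage<Key = K, Value = V>`, and `Loader<K>` returning `Result<HashMap<K, Self::Value>, Self::Error>`. Under that reading, a minimal `Loader` implementation looks like:

```rust
use std::collections::HashMap;

struct SquareLoader;

#[async_trait::async_trait]
impl Loader<i32> for SquareLoader {
    type Value = i32;
    type Error = ();

    async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
        // One round-trip answers the whole batch; keys missing from the map
        // simply surface as None from load_one.
        Ok(keys.iter().map(|&k| (k, k * k)).collect())
    }
}
```
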
+ fn iter(&self) -> Box + '_>; } diff --git a/src/directive.rs b/src/directive.rs index e88bb4743cc..5c27fd3fc5f 100644 --- a/src/directive.rs +++ b/src/directive.rs @@ -10,128 +10,133 @@ use crate::blueprint; use crate::valid::{Valid, ValidationError}; fn pos(a: A) -> Positioned { - Positioned::new(a, Pos::default()) + Positioned::new(a, Pos::default()) } fn to_const_directive(directive: &blueprint::Directive) -> Valid { - Valid::from_iter(directive.arguments.iter(), |(k, v)| { - let name = pos(Name::new(k.clone())); - Valid::from( - serde_json::from_value(v.clone()) - .map(pos) - .map_err(|e| ValidationError::new(e.to_string()).trace(format!("@{}", directive.name).as_str())), - ) - .map(|value| (name, value)) - }) - .map(|arguments| ConstDirective { name: pos(Name::new(directive.name.clone())), arguments }) + Valid::from_iter(directive.arguments.iter(), |(k, v)| { + let name = pos(Name::new(k.clone())); + Valid::from(serde_json::from_value(v.clone()).map(pos).map_err(|e| { + ValidationError::new(e.to_string()).trace(format!("@{}", directive.name).as_str()) + })) + .map(|value| (name, value)) + }) + .map(|arguments| ConstDirective { name: pos(Name::new(directive.name.clone())), arguments }) } pub trait DirectiveCodec { - fn directive_name() -> String; - fn from_directive(directive: &ConstDirective) -> Valid; - fn from_blueprint_directive(directive: &blueprint::Directive) -> Valid { - to_const_directive(directive).and_then(|a| Self::from_directive(&a)) - } - fn to_directive(&self) -> ConstDirective; - fn trace_name() -> String { - format!("@{}", Self::directive_name()) - } - fn from_directives(directives: Iter<'_, Positioned>) -> Valid, String> { - for directive in directives { - if directive.node.name.node == Self::directive_name() { - return Self::from_directive(&directive.node).map(Some); - } + fn directive_name() -> String; + fn from_directive(directive: &ConstDirective) -> Valid; + fn from_blueprint_directive(directive: &blueprint::Directive) -> Valid { + to_const_directive(directive).and_then(|a| Self::from_directive(&a)) + } + fn to_directive(&self) -> ConstDirective; + fn trace_name() -> String { + format!("@{}", Self::directive_name()) + } + fn from_directives( + directives: Iter<'_, Positioned>, + ) -> Valid, String> { + for directive in directives { + if directive.node.name.node == Self::directive_name() { + return Self::from_directive(&directive.node).map(Some); + } + } + Valid::succeed(None) } - Valid::succeed(None) - } } fn lower_case_first_letter(s: String) -> String { - if s.len() <= 2 { - s.to_lowercase() - } else if let Some(first_char) = s.chars().next() { - first_char.to_string().to_lowercase() + &s[first_char.len_utf8()..] - } else { - s.to_string() - } + if s.len() <= 2 { + s.to_lowercase() + } else if let Some(first_char) = s.chars().next() { + first_char.to_string().to_lowercase() + &s[first_char.len_utf8()..] 
+ } else { + s.to_string() + } } impl<'a, A: Deserialize<'a> + Serialize + 'a> DirectiveCodec for A { - fn directive_name() -> String { - lower_case_first_letter( - std::any::type_name::() - .split("::") - .last() - .unwrap_or_default() - .to_string(), - ) - } - - fn from_directive(directive: &ConstDirective) -> Valid { - Valid::from_iter(directive.arguments.iter(), |(k, v)| { - Valid::from( - serde_json::to_value(&v.node) - .map_err(|e| ValidationError::new(e.to_string()).trace(format!("@{}", directive.name.node).as_str())), - ) - .map(|v| (k.node.as_str().to_string(), v)) - }) - .map(|items| { - items.iter().fold(Map::new(), |mut map, (k, v)| { - map.insert(k.clone(), v.clone()); - map - }) - }) - .and_then(|map| match deserialize(Value::Object(map)) { - Ok(a) => Valid::succeed(a), - Err(e) => { - Valid::from_validation_err(ValidationError::from(e).trace(format!("@{}", directive.name.node).as_str())) - } - }) - } - - fn to_directive(&self) -> ConstDirective { - let name = Self::directive_name(); - let value = serde_json::to_value(self).unwrap(); - let default_map = &Map::new(); - let map = value.as_object().unwrap_or(default_map); - - let mut arguments = Vec::new(); - for (k, v) in map { - arguments.push(( - pos(Name::new(k.clone())), - pos(serde_json::from_value(v.to_owned()).unwrap()), - )); + fn directive_name() -> String { + lower_case_first_letter( + std::any::type_name::() + .split("::") + .last() + .unwrap_or_default() + .to_string(), + ) + } + + fn from_directive(directive: &ConstDirective) -> Valid { + Valid::from_iter(directive.arguments.iter(), |(k, v)| { + Valid::from(serde_json::to_value(&v.node).map_err(|e| { + ValidationError::new(e.to_string()) + .trace(format!("@{}", directive.name.node).as_str()) + })) + .map(|v| (k.node.as_str().to_string(), v)) + }) + .map(|items| { + items.iter().fold(Map::new(), |mut map, (k, v)| { + map.insert(k.clone(), v.clone()); + map + }) + }) + .and_then(|map| match deserialize(Value::Object(map)) { + Ok(a) => Valid::succeed(a), + Err(e) => Valid::from_validation_err( + ValidationError::from(e).trace(format!("@{}", directive.name.node).as_str()), + ), + }) } - ConstDirective { name: pos(Name::new(name)), arguments } - } + fn to_directive(&self) -> ConstDirective { + let name = Self::directive_name(); + let value = serde_json::to_value(self).unwrap(); + let default_map = &Map::new(); + let map = value.as_object().unwrap_or(default_map); + + let mut arguments = Vec::new(); + for (k, v) in map { + arguments.push(( + pos(Name::new(k.clone())), + pos(serde_json::from_value(v.to_owned()).unwrap()), + )); + } + + ConstDirective { name: pos(Name::new(name)), arguments } + } } #[cfg(test)] mod tests { - use async_graphql::parser::types::ConstDirective; - use async_graphql_value::Name; - use pretty_assertions::assert_eq; - - use crate::blueprint::Directive; - use crate::directive::{pos, to_const_directive}; - - #[test] - fn test_to_const_directive() { - let directive = Directive { - name: "test".to_string(), - arguments: vec![("a".to_string(), serde_json::json!(1.0))].into_iter().collect(), - index: 0, - }; - - let const_directive: ConstDirective = to_const_directive(&directive).to_result().unwrap(); - let expected_directive: ConstDirective = ConstDirective { - name: pos(Name::new("test")), - arguments: vec![(pos(Name::new("a")), pos(async_graphql::Value::from(1.0)))] - .into_iter() - .collect(), - }; - - assert_eq!(format!("{:?}", const_directive), format!("{:?}", expected_directive)); - } + use async_graphql::parser::types::ConstDirective; + use 
async_graphql_value::Name; + use pretty_assertions::assert_eq; + + use crate::blueprint::Directive; + use crate::directive::{pos, to_const_directive}; + + #[test] + fn test_to_const_directive() { + let directive = Directive { + name: "test".to_string(), + arguments: vec![("a".to_string(), serde_json::json!(1.0))] + .into_iter() + .collect(), + index: 0, + }; + + let const_directive: ConstDirective = to_const_directive(&directive).to_result().unwrap(); + let expected_directive: ConstDirective = ConstDirective { + name: pos(Name::new("test")), + arguments: vec![(pos(Name::new("a")), pos(async_graphql::Value::from(1.0)))] + .into_iter() + .collect(), + }; + + assert_eq!( + format!("{:?}", const_directive), + format!("{:?}", expected_directive) + ); + } } diff --git a/src/document.rs b/src/document.rs index 85aebace29a..d0f2c8c805c 100644 --- a/src/document.rs +++ b/src/document.rs @@ -3,277 +3,287 @@ use async_graphql::{Pos, Positioned}; use async_graphql_value::{ConstValue, Name}; fn pos(a: A) -> Positioned { - Positioned::new(a, Pos::default()) + Positioned::new(a, Pos::default()) } fn print_schema(schema: &SchemaDefinition) -> String { - let directives = schema - .directives - .iter() - .map(|d| print_directive(&const_directive_to_sdl(&d.node))) - .collect::>() - .join(" "); + let directives = schema + .directives + .iter() + .map(|d| print_directive(&const_directive_to_sdl(&d.node))) + .collect::>() + .join(" "); - let query = schema - .query - .as_ref() - .map_or(String::new(), |q| format!(" query: {}\n", q.node)); - let mutation = schema - .mutation - .as_ref() - .map_or(String::new(), |m| format!(" mutation: {}\n", m.node)); - let subscription = schema - .subscription - .as_ref() - .map_or(String::new(), |s| format!(" subscription: {}\n", s.node)); - if directives.is_empty() { - format!("schema {{\n{}{}{}}}\n", query, mutation, subscription) - } else { - format!("schema {} {{\n{}{}{}}}\n", directives, query, mutation, subscription) - } -} -fn const_directive_to_sdl(directive: &ConstDirective) -> DirectiveDefinition { - DirectiveDefinition { - description: None, - name: pos(Name::new(directive.name.node.clone())), - arguments: directive - .arguments - .iter() - .filter_map(|(name, value)| { - if value.node.clone() != ConstValue::Null { - Some(pos(InputValueDefinition { - description: None, - name: pos(Name::new(name.node.clone())), - ty: pos(Type { - nullable: true, - base: async_graphql::parser::types::BaseType::Named(Name::new(value.node.clone().to_string())), - }), - default_value: Some(pos(ConstValue::String(value.node.clone().to_string()))), - directives: Vec::new(), - })) - } else { - None - } - }) - .collect(), - is_repeatable: true, - locations: vec![], - } -} -fn print_type_def(type_def: &TypeDefinition) -> String { - match &type_def.kind { - TypeKind::Scalar => { - format!("scalar {}\n", type_def.name.node) - } - TypeKind::Union(union) => { - format!( - "union {} = {}\n", - type_def.name.node, - union - .members - .iter() - .map(|name| name.node.clone()) - .collect::>() - .join(" | ") - ) - } - TypeKind::InputObject(input) => { - format!( - "input {} {{\n{}\n}}\n", - type_def.name.node, - input - .fields - .iter() - .map(|f| print_input_value(&f.node)) - .collect::>() - .join("\n") - ) - } - TypeKind::Interface(interface) => { - let implements = if !interface.implements.is_empty() { + let query = schema + .query + .as_ref() + .map_or(String::new(), |q| format!(" query: {}\n", q.node)); + let mutation = schema + .mutation + .as_ref() + .map_or(String::new(), |m| format!(" 
mutation: {}\n", m.node)); + let subscription = schema + .subscription + .as_ref() + .map_or(String::new(), |s| format!(" subscription: {}\n", s.node)); + if directives.is_empty() { + format!("schema {{\n{}{}{}}}\n", query, mutation, subscription) + } else { format!( - "implements {} ", - interface - .implements - .iter() - .map(|name| name.node.clone()) - .collect::>() - .join(" & ") + "schema {} {{\n{}{}{}}}\n", + directives, query, mutation, subscription ) - } else { - String::new() - }; - format!( - "interface {} {}{{\n{}\n}}\n", - type_def.name.node, - implements, - interface - .fields - .iter() - .map(|f| print_field(&f.node)) - .collect::>() - .join("\n") - ) } - TypeKind::Object(object) => { - let implements = if !object.implements.is_empty() { - format!( - "implements {} ", - object - .implements - .iter() - .map(|name| name.node.clone()) - .collect::>() - .join(" & ") - ) - } else { - String::new() - }; - let directives = if !type_def.directives.is_empty() { - format!( - "{} ", - type_def - .directives +} +fn const_directive_to_sdl(directive: &ConstDirective) -> DirectiveDefinition { + DirectiveDefinition { + description: None, + name: pos(Name::new(directive.name.node.clone())), + arguments: directive + .arguments .iter() - .map(|d| print_directive(&const_directive_to_sdl(&d.node))) - .collect::>() - .join(" ") - ) - } else { - String::new() - }; + .filter_map(|(name, value)| { + if value.node.clone() != ConstValue::Null { + Some(pos(InputValueDefinition { + description: None, + name: pos(Name::new(name.node.clone())), + ty: pos(Type { + nullable: true, + base: async_graphql::parser::types::BaseType::Named(Name::new( + value.node.clone().to_string(), + )), + }), + default_value: Some(pos(ConstValue::String( + value.node.clone().to_string(), + ))), + directives: Vec::new(), + })) + } else { + None + } + }) + .collect(), + is_repeatable: true, + locations: vec![], + } +} +fn print_type_def(type_def: &TypeDefinition) -> String { + match &type_def.kind { + TypeKind::Scalar => { + format!("scalar {}\n", type_def.name.node) + } + TypeKind::Union(union) => { + format!( + "union {} = {}\n", + type_def.name.node, + union + .members + .iter() + .map(|name| name.node.clone()) + .collect::>() + .join(" | ") + ) + } + TypeKind::InputObject(input) => { + format!( + "input {} {{\n{}\n}}\n", + type_def.name.node, + input + .fields + .iter() + .map(|f| print_input_value(&f.node)) + .collect::>() + .join("\n") + ) + } + TypeKind::Interface(interface) => { + let implements = if !interface.implements.is_empty() { + format!( + "implements {} ", + interface + .implements + .iter() + .map(|name| name.node.clone()) + .collect::>() + .join(" & ") + ) + } else { + String::new() + }; + format!( + "interface {} {}{{\n{}\n}}\n", + type_def.name.node, + implements, + interface + .fields + .iter() + .map(|f| print_field(&f.node)) + .collect::>() + .join("\n") + ) + } + TypeKind::Object(object) => { + let implements = if !object.implements.is_empty() { + format!( + "implements {} ", + object + .implements + .iter() + .map(|name| name.node.clone()) + .collect::>() + .join(" & ") + ) + } else { + String::new() + }; + let directives = if !type_def.directives.is_empty() { + format!( + "{} ", + type_def + .directives + .iter() + .map(|d| print_directive(&const_directive_to_sdl(&d.node))) + .collect::>() + .join(" ") + ) + } else { + String::new() + }; - format!( - "type {} {}{}{{\n{}\n}}\n", - type_def.name.node, - implements, - directives, - object - .fields - .iter() - .map(|f| print_field(&f.node)) - 
.collect::>() - .join("\n") - ) + format!( + "type {} {}{}{{\n{}\n}}\n", + type_def.name.node, + implements, + directives, + object + .fields + .iter() + .map(|f| print_field(&f.node)) + .collect::>() + .join("\n") + ) + } + TypeKind::Enum(en) => format!( + "enum {} {{\n{}\n}}\n", + type_def.name.node, + en.values + .iter() + .map(|v| format!(" {}", v.node.value)) + .collect::>() + .join("\n") + ), + // Handle other type kinds... } - TypeKind::Enum(en) => format!( - "enum {} {{\n{}\n}}\n", - type_def.name.node, - en.values - .iter() - .map(|v| format!(" {}", v.node.value)) - .collect::>() - .join("\n") - ), - // Handle other type kinds... - } } fn print_field(field: &async_graphql::parser::types::FieldDefinition) -> String { - let directives: Vec = field - .directives - .iter() - .map(|d| print_directive(&const_directive_to_sdl(&d.node))) - .collect(); - let directives_str = if !directives.is_empty() { - format!(" {}", directives.join(" ")) - } else { - String::new() - }; + let directives: Vec = field + .directives + .iter() + .map(|d| print_directive(&const_directive_to_sdl(&d.node))) + .collect(); + let directives_str = if !directives.is_empty() { + format!(" {}", directives.join(" ")) + } else { + String::new() + }; - let args_str = if !field.arguments.is_empty() { - let args = field - .arguments - .iter() - .map(|arg| { - let nullable = if arg.node.ty.node.nullable { "" } else { "!" }; - format!("{}: {}{}", arg.node.name, arg.node.ty.node.base, nullable) - }) - .collect::>() - .join(", "); - format!("({})", args) - } else { - String::new() - }; - let doc = field.description.as_ref().map_or(String::new(), |d| { - format!(r#" """{} {}{} """{}"#, "\n", d.node, "\n", "\n") - }); - let node = &format!(" {}{}: {}{}", field.name.node, args_str, field.ty.node, directives_str); - doc + node + let args_str = if !field.arguments.is_empty() { + let args = field + .arguments + .iter() + .map(|arg| { + let nullable = if arg.node.ty.node.nullable { "" } else { "!" 
}; + format!("{}: {}{}", arg.node.name, arg.node.ty.node.base, nullable) + }) + .collect::>() + .join(", "); + format!("({})", args) + } else { + String::new() + }; + let doc = field.description.as_ref().map_or(String::new(), |d| { + format!(r#" """{} {}{} """{}"#, "\n", d.node, "\n", "\n") + }); + let node = &format!( + " {}{}: {}{}", + field.name.node, args_str, field.ty.node, directives_str + ); + doc + node } fn print_input_value(field: &async_graphql::parser::types::InputValueDefinition) -> String { - let directives: Vec = field - .directives - .iter() - .map(|d| print_directive(&const_directive_to_sdl(&d.node))) - .collect(); + let directives: Vec = field + .directives + .iter() + .map(|d| print_directive(&const_directive_to_sdl(&d.node))) + .collect(); - let directives_str = if !directives.is_empty() { - format!(" {}", directives.join(" ")) - } else { - String::new() - }; + let directives_str = if !directives.is_empty() { + format!(" {}", directives.join(" ")) + } else { + String::new() + }; - format!(" {}: {}{}", field.name.node, field.ty.node, directives_str) + format!(" {}: {}{}", field.name.node, field.ty.node, directives_str) } fn print_directive(directive: &DirectiveDefinition) -> String { - let args = directive - .arguments - .iter() - .map(|arg| { - let type_str = format!("{}", arg.node.ty.node); - if type_str.starts_with('[') || type_str.starts_with('{') { - let parts: Vec<&str> = type_str.split(',').collect(); - format!("{}: {}", arg.node.name.node, parts.join(", ")) - } else { - format!("{}: {}", arg.node.name.node, type_str) - } - }) - .collect::>() - .join(", "); + let args = directive + .arguments + .iter() + .map(|arg| { + let type_str = format!("{}", arg.node.ty.node); + if type_str.starts_with('[') || type_str.starts_with('{') { + let parts: Vec<&str> = type_str.split(',').collect(); + format!("{}: {}", arg.node.name.node, parts.join(", ")) + } else { + format!("{}: {}", arg.node.name.node, type_str) + } + }) + .collect::>() + .join(", "); - if args.is_empty() { - format!("@{}", directive.name.node) - } else { - format!("@{}({})", directive.name.node, args) - } + if args.is_empty() { + format!("@{}", directive.name.node) + } else { + format!("@{}({})", directive.name.node, args) + } } pub fn print(sd: ServiceDocument) -> String { - // Separate the definitions by type - let definitions_len = sd.definitions.len(); - let mut schemas = Vec::with_capacity(definitions_len); - let mut scalars = Vec::with_capacity(definitions_len); - let mut interfaces = Vec::with_capacity(definitions_len); - let mut objects = Vec::with_capacity(definitions_len); - let mut enums = Vec::with_capacity(definitions_len); - let mut unions = Vec::with_capacity(definitions_len); - let mut inputs = Vec::with_capacity(definitions_len); + // Separate the definitions by type + let definitions_len = sd.definitions.len(); + let mut schemas = Vec::with_capacity(definitions_len); + let mut scalars = Vec::with_capacity(definitions_len); + let mut interfaces = Vec::with_capacity(definitions_len); + let mut objects = Vec::with_capacity(definitions_len); + let mut enums = Vec::with_capacity(definitions_len); + let mut unions = Vec::with_capacity(definitions_len); + let mut inputs = Vec::with_capacity(definitions_len); - for def in sd.definitions.iter() { - match def { - TypeSystemDefinition::Schema(schema) => schemas.push(print_schema(&schema.node)), - TypeSystemDefinition::Type(type_def) => match &type_def.node.kind { - TypeKind::Scalar => scalars.push(print_type_def(&type_def.node)), - 
TypeKind::Interface(_) => interfaces.push(print_type_def(&type_def.node)), - TypeKind::Enum(_) => enums.push(print_type_def(&type_def.node)), - TypeKind::Object(_) => objects.push(print_type_def(&type_def.node)), - TypeKind::Union(_) => unions.push(print_type_def(&type_def.node)), - TypeKind::InputObject(_) => inputs.push(print_type_def(&type_def.node)), - }, - TypeSystemDefinition::Directive(_) => todo!("Directives are not supported yet"), + for def in sd.definitions.iter() { + match def { + TypeSystemDefinition::Schema(schema) => schemas.push(print_schema(&schema.node)), + TypeSystemDefinition::Type(type_def) => match &type_def.node.kind { + TypeKind::Scalar => scalars.push(print_type_def(&type_def.node)), + TypeKind::Interface(_) => interfaces.push(print_type_def(&type_def.node)), + TypeKind::Enum(_) => enums.push(print_type_def(&type_def.node)), + TypeKind::Object(_) => objects.push(print_type_def(&type_def.node)), + TypeKind::Union(_) => unions.push(print_type_def(&type_def.node)), + TypeKind::InputObject(_) => inputs.push(print_type_def(&type_def.node)), + }, + TypeSystemDefinition::Directive(_) => todo!("Directives are not supported yet"), + } } - } - // Concatenate the definitions in the desired order - let sdl_string = schemas - .into_iter() - .chain(scalars) - .chain(inputs) - .chain(interfaces) - .chain(unions) - .chain(enums) - .chain(objects) - // Chain other types as needed... - .collect::>() - .join("\n"); + // Concatenate the definitions in the desired order + let sdl_string = schemas + .into_iter() + .chain(scalars) + .chain(inputs) + .chain(interfaces) + .chain(unions) + .chain(enums) + .chain(objects) + // Chain other types as needed... + .collect::>() + .join("\n"); - sdl_string.trim_end_matches('\n').to_string() + sdl_string.trim_end_matches('\n').to_string() } diff --git a/src/endpoint.rs b/src/endpoint.rs index c715e561fbd..564cdaee1da 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -7,29 +7,29 @@ use crate::json::JsonSchema; #[derive(Clone, Debug, Setters)] pub struct Endpoint { - pub path: String, - pub query: Vec<(String, String)>, - pub method: Method, - pub input: JsonSchema, - pub output: JsonSchema, - pub headers: HeaderMap, - pub body: Option, - pub description: Option, - pub encoding: Encoding, + pub path: String, + pub query: Vec<(String, String)>, + pub method: Method, + pub input: JsonSchema, + pub output: JsonSchema, + pub headers: HeaderMap, + pub body: Option, + pub description: Option, + pub encoding: Encoding, } impl Endpoint { - pub fn new(url: String) -> Endpoint { - Self { - path: url, - query: Default::default(), - method: Default::default(), - input: Default::default(), - output: Default::default(), - headers: Default::default(), - body: Default::default(), - description: Default::default(), - encoding: Default::default(), + pub fn new(url: String) -> Endpoint { + Self { + path: url, + query: Default::default(), + method: Default::default(), + input: Default::default(), + output: Default::default(), + headers: Default::default(), + body: Default::default(), + description: Default::default(), + encoding: Default::default(), + } } - } } diff --git a/src/graphql/data_loader.rs b/src/graphql/data_loader.rs index 51db346f213..9d08b420b4e 100644 --- a/src/graphql/data_loader.rs +++ b/src/graphql/data_loader.rs @@ -12,123 +12,137 @@ use crate::http::{DataLoaderRequest, Response}; use crate::HttpIO; pub struct GraphqlDataLoader { - pub client: Arc, - pub batch: bool, + pub client: Arc, + pub batch: bool, } impl GraphqlDataLoader { - pub fn 
new(client: Arc, batch: bool) -> Self { - GraphqlDataLoader { client, batch } - } - - pub fn to_data_loader(self, batch: Batch) -> DataLoader { - DataLoader::new(self) - .delay(Duration::from_millis(batch.delay as u64)) - .max_batch_size(batch.max_size) - } + pub fn new(client: Arc, batch: bool) -> Self { + GraphqlDataLoader { client, batch } + } + + pub fn to_data_loader(self, batch: Batch) -> DataLoader { + DataLoader::new(self) + .delay(Duration::from_millis(batch.delay as u64)) + .max_batch_size(batch.max_size) + } } #[async_trait::async_trait] impl Loader for GraphqlDataLoader { - type Value = Response; - type Error = Arc; - - #[allow(clippy::mutable_key_type)] - async fn load( - &self, - keys: &[DataLoaderRequest], - ) -> async_graphql::Result, Self::Error> { - if self.batch { - let batched_req = create_batched_request(keys); - let result = self.client.execute(batched_req).await?.to_json(); - let hashmap = extract_responses(result, keys); - Ok(hashmap) - } else { - let results = keys.iter().map(|key| async { - let result = self.client.execute(key.to_request()).await; - (key.clone(), result) - }); - let results = join_all(results).await; - #[allow(clippy::mutable_key_type)] - let mut hashmap = HashMap::new(); - for (key, value) in results { - hashmap.insert(key, value?.to_json()?); - } - - Ok(hashmap) + type Value = Response; + type Error = Arc; + + #[allow(clippy::mutable_key_type)] + async fn load( + &self, + keys: &[DataLoaderRequest], + ) -> async_graphql::Result, Self::Error> { + if self.batch { + let batched_req = create_batched_request(keys); + let result = self.client.execute(batched_req).await?.to_json(); + let hashmap = extract_responses(result, keys); + Ok(hashmap) + } else { + let results = keys.iter().map(|key| async { + let result = self.client.execute(key.to_request()).await; + (key.clone(), result) + }); + let results = join_all(results).await; + #[allow(clippy::mutable_key_type)] + let mut hashmap = HashMap::new(); + for (key, value) in results { + hashmap.insert(key, value?.to_json()?); + } + + Ok(hashmap) + } } - } } fn collect_request_bodies(dataloader_requests: &[DataLoaderRequest]) -> String { - let batched_query = dataloader_requests - .iter() - .filter_map(|dataloader_req| { - dataloader_req - .body() - .and_then(|body| body.as_bytes()) - // PERF: conversion from bytes to string with utf8 validation - .and_then(|body| from_utf8(body).ok()) - .or(Some("")) - }) - .collect::>() - .join(","); - format!("[{}]", batched_query) + let batched_query = dataloader_requests + .iter() + .filter_map(|dataloader_req| { + dataloader_req + .body() + .and_then(|body| body.as_bytes()) + // PERF: conversion from bytes to string with utf8 validation + .and_then(|body| from_utf8(body).ok()) + .or(Some("")) + }) + .collect::>() + .join(","); + format!("[{}]", batched_query) } fn create_batched_request(dataloader_requests: &[DataLoaderRequest]) -> reqwest::Request { - let batched_query = collect_request_bodies(dataloader_requests); - - let first_req = dataloader_requests.first().unwrap(); - let mut batched_req = first_req.to_request(); - batched_req.body_mut().replace(reqwest::Body::from(batched_query)); - batched_req + let batched_query = collect_request_bodies(dataloader_requests); + + let first_req = dataloader_requests.first().unwrap(); + let mut batched_req = first_req.to_request(); + batched_req + .body_mut() + .replace(reqwest::Body::from(batched_query)); + batched_req } #[allow(clippy::mutable_key_type)] fn extract_responses( - result: Result, anyhow::Error>, - keys: 
&[DataLoaderRequest], + result: Result, anyhow::Error>, + keys: &[DataLoaderRequest], ) -> HashMap> { - let mut hashmap = HashMap::new(); - if let Ok(res) = result { - if let async_graphql_value::ConstValue::List(values) = res.body { - for (i, request) in keys.iter().enumerate() { - let value = values.get(i).unwrap_or(&async_graphql_value::ConstValue::Null); - hashmap.insert( - request.clone(), - Response { status: res.status, headers: res.headers.clone(), body: value.clone() }, - ); - } + let mut hashmap = HashMap::new(); + if let Ok(res) = result { + if let async_graphql_value::ConstValue::List(values) = res.body { + for (i, request) in keys.iter().enumerate() { + let value = values + .get(i) + .unwrap_or(&async_graphql_value::ConstValue::Null); + hashmap.insert( + request.clone(), + Response { + status: res.status, + headers: res.headers.clone(), + body: value.clone(), + }, + ); + } + } } - } - hashmap + hashmap } #[cfg(test)] mod tests { - use std::collections::BTreeSet; - - use reqwest::Url; - - use super::*; - use crate::http::DataLoaderRequest; - - #[test] - fn test_collect_request_bodies() { - let url = Url::parse("http://example.com").unwrap(); - let mut request1 = reqwest::Request::new(reqwest::Method::GET, url.clone()); - request1.body_mut().replace(reqwest::Body::from("a".to_string())); - let mut request2 = reqwest::Request::new(reqwest::Method::GET, url.clone()); - request2.body_mut().replace(reqwest::Body::from("b".to_string())); - let mut request3 = reqwest::Request::new(reqwest::Method::GET, url.clone()); - request3.body_mut().replace(reqwest::Body::from("c".to_string())); - - let dl_req1 = DataLoaderRequest::new(request1, BTreeSet::new()); - let dl_req2 = DataLoaderRequest::new(request2, BTreeSet::new()); - let dl_req3 = DataLoaderRequest::new(request3, BTreeSet::new()); - - let body = collect_request_bodies(&[dl_req1, dl_req2, dl_req3]); - assert_eq!(body, "[a,b,c]"); - } + use std::collections::BTreeSet; + + use reqwest::Url; + + use super::*; + use crate::http::DataLoaderRequest; + + #[test] + fn test_collect_request_bodies() { + let url = Url::parse("http://example.com").unwrap(); + let mut request1 = reqwest::Request::new(reqwest::Method::GET, url.clone()); + request1 + .body_mut() + .replace(reqwest::Body::from("a".to_string())); + let mut request2 = reqwest::Request::new(reqwest::Method::GET, url.clone()); + request2 + .body_mut() + .replace(reqwest::Body::from("b".to_string())); + let mut request3 = reqwest::Request::new(reqwest::Method::GET, url.clone()); + request3 + .body_mut() + .replace(reqwest::Body::from("c".to_string())); + + let dl_req1 = DataLoaderRequest::new(request1, BTreeSet::new()); + let dl_req2 = DataLoaderRequest::new(request2, BTreeSet::new()); + let dl_req3 = DataLoaderRequest::new(request3, BTreeSet::new()); + + let body = collect_request_bodies(&[dl_req1, dl_req2, dl_req3]); + assert_eq!(body, "[a,b,c]"); + } } diff --git a/src/graphql/request_template.rs b/src/graphql/request_template.rs index 4525374bac0..3038047ddf2 100644 --- a/src/graphql/request_template.rs +++ b/src/graphql/request_template.rs @@ -15,176 +15,179 @@ use crate::path::PathGraphql; /// RequestTemplate for GraphQL requests (See RequestTemplate documentation) #[derive(Setters, Debug, Clone)] pub struct RequestTemplate { - // TODO: should be Mustache as for other templates - pub url: String, - pub operation_type: GraphQLOperationType, - pub operation_name: String, - pub operation_arguments: Option>, - pub headers: MustacheHeaders, + // TODO: should be Mustache as for other 
templates + pub url: String, + pub operation_type: GraphQLOperationType, + pub operation_name: String, + pub operation_arguments: Option>, + pub headers: MustacheHeaders, } impl RequestTemplate { - fn create_headers(&self, ctx: &C) -> HeaderMap { - let mut header_map = HeaderMap::new(); + fn create_headers(&self, ctx: &C) -> HeaderMap { + let mut header_map = HeaderMap::new(); - for (k, v) in &self.headers { - if let Ok(header_value) = HeaderValue::from_str(&v.render_graphql(ctx)) { - header_map.insert(k, header_value); - } + for (k, v) in &self.headers { + if let Ok(header_value) = HeaderValue::from_str(&v.render_graphql(ctx)) { + header_map.insert(k, header_value); + } + } + + header_map } - header_map - } + fn set_headers( + &self, + mut req: reqwest::Request, + ctx: &C, + ) -> reqwest::Request { + let headers = req.headers_mut(); + let config_headers = self.create_headers(ctx); - fn set_headers(&self, mut req: reqwest::Request, ctx: &C) -> reqwest::Request { - let headers = req.headers_mut(); - let config_headers = self.create_headers(ctx); + if !config_headers.is_empty() { + headers.extend(config_headers); + } + headers.insert( + reqwest::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + headers.extend(ctx.headers().to_owned()); + req + } - if !config_headers.is_empty() { - headers.extend(config_headers); + pub fn to_request( + &self, + ctx: &C, + ) -> anyhow::Result { + let mut req = reqwest::Request::new(POST.to_hyper(), url::Url::parse(self.url.as_str())?); + req = self.set_headers(req, ctx); + req = self.set_body(req, ctx); + Ok(req) } - headers.insert( - reqwest::header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - headers.extend(ctx.headers().to_owned()); - req - } - - pub fn to_request( - &self, - ctx: &C, - ) -> anyhow::Result { - let mut req = reqwest::Request::new(POST.to_hyper(), url::Url::parse(self.url.as_str())?); - req = self.set_headers(req, ctx); - req = self.set_body(req, ctx); - Ok(req) - } - - fn set_body( - &self, - mut req: reqwest::Request, - ctx: &C, - ) -> reqwest::Request { - let operation_type = &self.operation_type; - let selection_set = ctx.selection_set().unwrap_or_default(); - let operation = self - .operation_arguments - .as_ref() - .map(|args| { - args - .iter() - .map(|(k, v)| format!(r#"{}: {}"#, k, v.render_graphql(ctx).escape_default())) - .collect::>() - .join(", ") - }) - .map(|args| format!("{}({})", self.operation_name, args)) - .unwrap_or(self.operation_name.clone()); - - let graphql_query = format!(r#"{{ "query": "{operation_type} {{ {operation} {selection_set} }}" }}"#); - - req.body_mut().replace(graphql_query.into()); - req - } - - pub fn new( - url: String, - operation_type: &GraphQLOperationType, - operation_name: &str, - args: Option<&KeyValues>, - headers: MustacheHeaders, - ) -> anyhow::Result { - let mut operation_arguments = None; - - if let Some(args) = args.as_ref() { - operation_arguments = Some( - args - .iter() - .map(|(k, v)| Ok((k.to_owned(), Mustache::parse(v)?))) - .collect::>>()?, - ); + + fn set_body( + &self, + mut req: reqwest::Request, + ctx: &C, + ) -> reqwest::Request { + let operation_type = &self.operation_type; + let selection_set = ctx.selection_set().unwrap_or_default(); + let operation = self + .operation_arguments + .as_ref() + .map(|args| { + args.iter() + .map(|(k, v)| format!(r#"{}: {}"#, k, v.render_graphql(ctx).escape_default())) + .collect::>() + .join(", ") + }) + .map(|args| format!("{}({})", self.operation_name, args)) + 
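// e.g. with arguments this renders `create(id: \"baz\", ...)` as in the
+            // mutation test below; with no arguments it falls back to the bare
+            // operation name, e.g. `myQuery` (an explanatory note, not part of
+            // the original source).
+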
.unwrap_or(self.operation_name.clone()); + + let graphql_query = + format!(r#"{{ "query": "{operation_type} {{ {operation} {selection_set} }}" }}"#); + + req.body_mut().replace(graphql_query.into()); + req } - Ok(Self { - url, - operation_type: operation_type.to_owned(), - operation_name: operation_name.to_owned(), - operation_arguments, - headers, - }) - } + pub fn new( + url: String, + operation_type: &GraphQLOperationType, + operation_name: &str, + args: Option<&KeyValues>, + headers: MustacheHeaders, + ) -> anyhow::Result { + let mut operation_arguments = None; + + if let Some(args) = args.as_ref() { + operation_arguments = Some( + args.iter() + .map(|(k, v)| Ok((k.to_owned(), Mustache::parse(v)?))) + .collect::>>()?, + ); + } + + Ok(Self { + url, + operation_type: operation_type.to_owned(), + operation_name: operation_name.to_owned(), + operation_arguments, + headers, + }) + } } #[cfg(test)] mod tests { - use async_graphql::Value; - use hyper::HeaderMap; - use pretty_assertions::assert_eq; - use serde_json::json; - - use crate::config::GraphQLOperationType; - use crate::graphql::RequestTemplate; - use crate::has_headers::HasHeaders; - use crate::json::JsonLike; - use crate::lambda::GraphQLOperationContext; - use crate::path::PathGraphql; - - struct Context { - pub value: Value, - pub headers: HeaderMap, - } - - impl PathGraphql for Context { - fn path_graphql>(&self, path: &[T]) -> Option { - self.value.get_path(path).map(|v| v.to_string()) + use async_graphql::Value; + use hyper::HeaderMap; + use pretty_assertions::assert_eq; + use serde_json::json; + + use crate::config::GraphQLOperationType; + use crate::graphql::RequestTemplate; + use crate::has_headers::HasHeaders; + use crate::json::JsonLike; + use crate::lambda::GraphQLOperationContext; + use crate::path::PathGraphql; + + struct Context { + pub value: Value, + pub headers: HeaderMap, } - } - impl HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + impl PathGraphql for Context { + fn path_graphql>(&self, path: &[T]) -> Option { + self.value.get_path(path).map(|v| v.to_string()) + } } - } - impl GraphQLOperationContext for Context { - fn selection_set(&self) -> Option { - Some("{ a,b,c }".to_owned()) + impl HasHeaders for Context { + fn headers(&self) -> &HeaderMap { + &self.headers + } } - } - #[test] - fn test_query_without_args() { - let tmpl = RequestTemplate::new( - "http://localhost:3000".to_string(), - &GraphQLOperationType::Query, - "myQuery", - None, - vec![], - ) - .unwrap(); - let ctx = Context { - value: Value::from_json(json!({ - "foo": { - "bar": "baz", - "header": "abc" + impl GraphQLOperationContext for Context { + fn selection_set(&self) -> Option { + Some("{ a,b,c }".to_owned()) } - })) - .unwrap(), - headers: Default::default(), - }; - - let req = tmpl.to_request(&ctx).unwrap(); - let body = req.body().unwrap().as_bytes().unwrap().to_owned(); - - assert_eq!( - std::str::from_utf8(&body).unwrap(), - r#"{ "query": "query { myQuery { a,b,c } }" }"# - ); - } - - #[test] - fn test_query_with_args() { - let tmpl = RequestTemplate::new( + } + + #[test] + fn test_query_without_args() { + let tmpl = RequestTemplate::new( + "http://localhost:3000".to_string(), + &GraphQLOperationType::Query, + "myQuery", + None, + vec![], + ) + .unwrap(); + let ctx = Context { + value: Value::from_json(json!({ + "foo": { + "bar": "baz", + "header": "abc" + } + })) + .unwrap(), + headers: Default::default(), + }; + + let req = tmpl.to_request(&ctx).unwrap(); + let body = 
req.body().unwrap().as_bytes().unwrap().to_owned(); + + assert_eq!( + std::str::from_utf8(&body).unwrap(), + r#"{ "query": "query { myQuery { a,b,c } }" }"# + ); + } + + #[test] + fn test_query_with_args() { + let tmpl = RequestTemplate::new( "http://localhost:3000".to_string(), &GraphQLOperationType::Mutation, "create", @@ -196,23 +199,23 @@ mod tests { vec![], ) .unwrap(); - let ctx = Context { - value: Value::from_json(json!({ - "foo": { - "bar": "baz", - "header": "abc" - } - })) - .unwrap(), - headers: Default::default(), - }; - - let req = tmpl.to_request(&ctx).unwrap(); - let body = req.body().unwrap().as_bytes().unwrap().to_owned(); - - assert_eq!( - std::str::from_utf8(&body).unwrap(), - r#"{ "query": "mutation { create(id: \"baz\", struct: {bar: \"baz\",header: \"abc\"}) { a,b,c } }" }"# - ); - } + let ctx = Context { + value: Value::from_json(json!({ + "foo": { + "bar": "baz", + "header": "abc" + } + })) + .unwrap(), + headers: Default::default(), + }; + + let req = tmpl.to_request(&ctx).unwrap(); + let body = req.body().unwrap().as_bytes().unwrap().to_owned(); + + assert_eq!( + std::str::from_utf8(&body).unwrap(), + r#"{ "query": "mutation { create(id: \"baz\", struct: {bar: \"baz\",header: \"abc\"}) { a,b,c } }" }"# + ); + } } diff --git a/src/grpc/data_loader.rs b/src/grpc/data_loader.rs index af7190d155b..bd9fb622fe7 100644 --- a/src/grpc/data_loader.rs +++ b/src/grpc/data_loader.rs @@ -20,93 +20,98 @@ use crate::HttpIO; #[derive(Clone)] pub struct GrpcDataLoader { - pub(crate) client: Arc, - pub(crate) operation: ProtobufOperation, - pub(crate) group_by: Option, + pub(crate) client: Arc, + pub(crate) operation: ProtobufOperation, + pub(crate) group_by: Option, } impl GrpcDataLoader { - pub fn to_data_loader(self, batch: Batch) -> DataLoader { - DataLoader::new(self) - .delay(Duration::from_millis(batch.delay as u64)) - .max_batch_size(batch.max_size) - } - - async fn load_dedupe_only( - &self, - keys: &[DataLoaderRequest], - ) -> anyhow::Result>> { - let results = keys.iter().map(|key| async { - let result = match key.to_request() { - Ok(req) => execute_grpc_request(&self.client, &self.operation, req).await, - Err(error) => Err(error), - }; - - // TODO: do we have to clone keys here? 
join_all seems to return the results in the order they were passed
-      (key.clone(), result)
-    });
-
-    let results = join_all(results).await;
-
-    #[allow(clippy::mutable_key_type)]
-    let mut hashmap = HashMap::new();
-    for (key, value) in results {
-      hashmap.insert(key, value?);
+    pub fn to_data_loader(self, batch: Batch) -> DataLoader {
+        DataLoader::new(self)
+            .delay(Duration::from_millis(batch.delay as u64))
+            .max_batch_size(batch.max_size)
    }
-    Ok(hashmap)
-  }
-
-  async fn load_with_group_by(
-    &self,
-    group_by: &GroupBy,
-    keys: &[DataLoaderRequest],
-  ) -> Result<HashMap<DataLoaderRequest, Response<ConstValue>>> {
-    let inputs = keys.iter().map(|key| key.template.body.as_str());
-    let (multiple_body, grouped_keys) = self.operation.convert_multiple_inputs(inputs, group_by.key())?;
-
-    let first_request = keys[0].clone();
-    let multiple_request = create_grpc_request(
-      first_request.template.url,
-      first_request.template.headers,
-      multiple_body,
-    );
-
-    let response = execute_grpc_request(&self.client, &self.operation, multiple_request).await?;
-
-    let path = &group_by.path();
-    let response_body = response.body.group_by(path);
-
-    let mut result = HashMap::new();
-
-    for (key, id) in keys.iter().zip(grouped_keys) {
-      let res = response.clone().body(
-        response_body
-          .get(&id)
-          .and_then(|a| a.first().cloned().cloned())
-          .unwrap_or(ConstValue::Null),
-      );
-
-      result.insert(key.clone(), res);
+    async fn load_dedupe_only(
+        &self,
+        keys: &[DataLoaderRequest],
+    ) -> anyhow::Result<HashMap<DataLoaderRequest, Response<ConstValue>>> {
+        let results = keys.iter().map(|key| async {
+            let result = match key.to_request() {
+                Ok(req) => execute_grpc_request(&self.client, &self.operation, req).await,
+                Err(error) => Err(error),
+            };
+
+            // TODO: do we have to clone keys here? join_all seems to return the results in the order they were passed
+            (key.clone(), result)
+        });
+
+        let results = join_all(results).await;
+
+        #[allow(clippy::mutable_key_type)]
+        let mut hashmap = HashMap::new();
+        for (key, value) in results {
+            hashmap.insert(key, value?);
+        }
+
+        Ok(hashmap)
    }
-    Ok(result)
-  }
+    async fn load_with_group_by(
+        &self,
+        group_by: &GroupBy,
+        keys: &[DataLoaderRequest],
+    ) -> Result<HashMap<DataLoaderRequest, Response<ConstValue>>> {
+        let inputs = keys.iter().map(|key| key.template.body.as_str());
+        let (multiple_body, grouped_keys) = self
+            .operation
+            .convert_multiple_inputs(inputs, group_by.key())?;
+
+        let first_request = keys[0].clone();
+        let multiple_request = create_grpc_request(
+            first_request.template.url,
+            first_request.template.headers,
+            multiple_body,
+        );
+
+        let response =
+            execute_grpc_request(&self.client, &self.operation, multiple_request).await?;
+
+        let path = &group_by.path();
+        let response_body = response.body.group_by(path);
+
+        let mut result = HashMap::new();
+
+        for (key, id) in keys.iter().zip(grouped_keys) {
+            let res = response.clone().body(
+                response_body
+                    .get(&id)
+                    .and_then(|a| a.first().cloned().cloned())
+                    .unwrap_or(ConstValue::Null),
+            );
+
+            result.insert(key.clone(), res);
+        }
+
+        Ok(result)
+    }
}

#[async_trait::async_trait]
impl Loader<DataLoaderRequest> for GrpcDataLoader {
-  type Value = Response<ConstValue>;
-  type Error = Arc<anyhow::Error>;
-
-  async fn load(
-    &self,
-    keys: &[DataLoaderRequest],
-  ) -> async_graphql::Result<HashMap<DataLoaderRequest, Self::Value>, Self::Error> {
-    if let Some(group_by) = &self.group_by {
-      self.load_with_group_by(group_by, keys).await.map_err(Arc::new)
-    } else {
-      self.load_dedupe_only(keys).await.map_err(Arc::new)
+    type Value = Response<ConstValue>;
+    type Error = Arc<anyhow::Error>;
+
+    async fn load(
+        &self,
+        keys: &[DataLoaderRequest],
+    ) -> async_graphql::Result<HashMap<DataLoaderRequest, Self::Value>, Self::Error> {
+        if let Some(group_by) = &self.group_by {
+            self.load_with_group_by(group_by,
keys) + .await + .map_err(Arc::new) + } else { + self.load_dedupe_only(keys).await.map_err(Arc::new) + } } - } } diff --git a/src/grpc/data_loader_request.rs b/src/grpc/data_loader_request.rs index 8ad15a4e8c0..55ba03db141 100644 --- a/src/grpc/data_loader_request.rs +++ b/src/grpc/data_loader_request.rs @@ -8,121 +8,121 @@ use super::request_template::RenderedRequestTemplate; #[derive(Debug, Clone, Eq)] pub struct DataLoaderRequest { - pub template: RenderedRequestTemplate, - batch_headers: BTreeSet, + pub template: RenderedRequestTemplate, + batch_headers: BTreeSet, } impl Hash for DataLoaderRequest { - fn hash(&self, state: &mut H) { - self.template.url.hash(state); - self.template.body.hash(state); - - for name in &self.batch_headers { - if let Some(value) = self.template.headers.get(name) { - name.hash(state); - value.hash(state); - } + fn hash(&self, state: &mut H) { + self.template.url.hash(state); + self.template.body.hash(state); + + for name in &self.batch_headers { + if let Some(value) = self.template.headers.get(name) { + name.hash(state); + value.hash(state); + } + } } - } } impl PartialEq for DataLoaderRequest { - fn eq(&self, other: &Self) -> bool { - let mut hasher_self = DefaultHasher::new(); - self.hash(&mut hasher_self); - let hash_self = hasher_self.finish(); + fn eq(&self, other: &Self) -> bool { + let mut hasher_self = DefaultHasher::new(); + self.hash(&mut hasher_self); + let hash_self = hasher_self.finish(); - let mut hasher_other = DefaultHasher::new(); - other.hash(&mut hasher_other); - let hash_other = hasher_other.finish(); + let mut hasher_other = DefaultHasher::new(); + other.hash(&mut hasher_other); + let hash_other = hasher_other.finish(); - hash_self == hash_other - } + hash_self == hash_other + } } impl DataLoaderRequest { - pub fn new(template: RenderedRequestTemplate, batch_headers: BTreeSet) -> Self { - Self { template, batch_headers } - } + pub fn new(template: RenderedRequestTemplate, batch_headers: BTreeSet) -> Self { + Self { template, batch_headers } + } - pub fn to_request(&self) -> Result { - self.template.to_request() - } + pub fn to_request(&self) -> Result { + self.template.to_request() + } } #[cfg(test)] mod tests { - use std::collections::BTreeSet; - use std::path::PathBuf; - - use hyper::header::{HeaderName, HeaderValue}; - use hyper::HeaderMap; - use once_cell::sync::Lazy; - use pretty_assertions::assert_eq; - use url::Url; - - use super::DataLoaderRequest; - use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; - use crate::grpc::request_template::RenderedRequestTemplate; - - static PROTOBUF_OPERATION: Lazy = Lazy::new(|| { - let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let mut test_file = root_dir.join(file!()); - - test_file.pop(); - test_file.push("tests"); - test_file.push("greetings.proto"); - - let protobuf_set = ProtobufSet::from_proto_file(&test_file).unwrap(); - let service = protobuf_set.find_service("Greeter").unwrap(); - - service.find_operation("SayHello").unwrap() - }); - - #[test] - fn dataloader_req_empty_headers() { - let batch_headers = BTreeSet::default(); - let tmpl = RenderedRequestTemplate { - url: Url::parse("http://localhost:3000/").unwrap(), - headers: HeaderMap::new(), - operation: PROTOBUF_OPERATION.clone(), - body: "{}".to_owned(), - }; - - let dl_req_1 = DataLoaderRequest::new(tmpl.clone(), batch_headers.clone()); - let dl_req_2 = DataLoaderRequest::new(tmpl.clone(), batch_headers); - - assert_eq!(dl_req_1, dl_req_2); - } - - #[test] - fn dataloader_req_batch_headers() { - let 
batch_headers = BTreeSet::from_iter(["test-header".to_owned()]); - let tmpl_1 = RenderedRequestTemplate { - url: Url::parse("http://localhost:3000/").unwrap(), - headers: HeaderMap::from_iter([( - HeaderName::from_static("test-header"), - HeaderValue::from_static("value1"), - )]), - operation: PROTOBUF_OPERATION.clone(), - body: "{}".to_owned(), - }; - let tmpl_2 = tmpl_1.clone(); - - let dl_req_1 = DataLoaderRequest::new(tmpl_1.clone(), batch_headers.clone()); - let dl_req_2 = DataLoaderRequest::new(tmpl_2, batch_headers.clone()); - - assert_eq!(dl_req_1, dl_req_2); - - let tmpl_2 = RenderedRequestTemplate { - headers: HeaderMap::from_iter([( - HeaderName::from_static("test-header"), - HeaderValue::from_static("value2"), - )]), - ..tmpl_1.clone() - }; - let dl_req_2 = DataLoaderRequest::new(tmpl_2, batch_headers.clone()); - - assert_ne!(dl_req_1, dl_req_2); - } + use std::collections::BTreeSet; + use std::path::PathBuf; + + use hyper::header::{HeaderName, HeaderValue}; + use hyper::HeaderMap; + use once_cell::sync::Lazy; + use pretty_assertions::assert_eq; + use url::Url; + + use super::DataLoaderRequest; + use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; + use crate::grpc::request_template::RenderedRequestTemplate; + + static PROTOBUF_OPERATION: Lazy = Lazy::new(|| { + let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut test_file = root_dir.join(file!()); + + test_file.pop(); + test_file.push("tests"); + test_file.push("greetings.proto"); + + let protobuf_set = ProtobufSet::from_proto_file(&test_file).unwrap(); + let service = protobuf_set.find_service("Greeter").unwrap(); + + service.find_operation("SayHello").unwrap() + }); + + #[test] + fn dataloader_req_empty_headers() { + let batch_headers = BTreeSet::default(); + let tmpl = RenderedRequestTemplate { + url: Url::parse("http://localhost:3000/").unwrap(), + headers: HeaderMap::new(), + operation: PROTOBUF_OPERATION.clone(), + body: "{}".to_owned(), + }; + + let dl_req_1 = DataLoaderRequest::new(tmpl.clone(), batch_headers.clone()); + let dl_req_2 = DataLoaderRequest::new(tmpl.clone(), batch_headers); + + assert_eq!(dl_req_1, dl_req_2); + } + + #[test] + fn dataloader_req_batch_headers() { + let batch_headers = BTreeSet::from_iter(["test-header".to_owned()]); + let tmpl_1 = RenderedRequestTemplate { + url: Url::parse("http://localhost:3000/").unwrap(), + headers: HeaderMap::from_iter([( + HeaderName::from_static("test-header"), + HeaderValue::from_static("value1"), + )]), + operation: PROTOBUF_OPERATION.clone(), + body: "{}".to_owned(), + }; + let tmpl_2 = tmpl_1.clone(); + + let dl_req_1 = DataLoaderRequest::new(tmpl_1.clone(), batch_headers.clone()); + let dl_req_2 = DataLoaderRequest::new(tmpl_2, batch_headers.clone()); + + assert_eq!(dl_req_1, dl_req_2); + + let tmpl_2 = RenderedRequestTemplate { + headers: HeaderMap::from_iter([( + HeaderName::from_static("test-header"), + HeaderValue::from_static("value2"), + )]), + ..tmpl_1.clone() + }; + let dl_req_2 = DataLoaderRequest::new(tmpl_2, batch_headers.clone()); + + assert_ne!(dl_req_1, dl_req_2); + } } diff --git a/src/grpc/protobuf.rs b/src/grpc/protobuf.rs index b586be657eb..0b3f8cc44d2 100644 --- a/src/grpc/protobuf.rs +++ b/src/grpc/protobuf.rs @@ -6,344 +6,385 @@ use anyhow::{anyhow, bail, Context, Result}; use async_graphql::Value; use prost::bytes::BufMut; use prost::Message; -use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor, MethodDescriptor, ServiceDescriptor}; +use prost_reflect::{ + DescriptorPool, DynamicMessage, 
MessageDescriptor, MethodDescriptor, ServiceDescriptor,
+};
use serde_json::Deserializer;

fn to_message(descriptor: &MessageDescriptor, input: &str) -> Result<DynamicMessage> {
-  let mut deserializer = Deserializer::from_str(input);
-  let message = DynamicMessage::deserialize(descriptor.clone(), &mut deserializer)
-    .with_context(|| format!("Failed to parse input according to type {}", descriptor.full_name()))?;
-  deserializer.end()?;
-
-  Ok(message)
+    let mut deserializer = Deserializer::from_str(input);
+    let message =
+        DynamicMessage::deserialize(descriptor.clone(), &mut deserializer).with_context(|| {
+            format!(
+                "Failed to parse input according to type {}",
+                descriptor.full_name()
+            )
+        })?;
+    deserializer.end()?;
+
+    Ok(message)
}

fn message_to_bytes(message: DynamicMessage) -> Result<Vec<u8>> {
-  let mut buf: Vec<u8> = Vec::with_capacity(message.encoded_len() + 5);
-  // set compression flag
-  buf.put_u8(0);
-  // next 4 bytes should encode message length
-  buf.put_u32(message.encoded_len() as u32);
-  // encode the message itself
-  message.encode(&mut buf)?;
-
-  Ok(buf)
+    let mut buf: Vec<u8> = Vec::with_capacity(message.encoded_len() + 5);
+    // set compression flag
+    buf.put_u8(0);
+    // next 4 bytes should encode message length
+    buf.put_u32(message.encoded_len() as u32);
+    // encode the message itself
+    message.encode(&mut buf)?;
+
+    Ok(buf)
}

pub fn protobuf_value_as_str(value: &prost_reflect::Value) -> String {
-  use prost_reflect::Value;
-
-  match value {
-    Value::I32(v) => v.to_string(),
-    Value::I64(v) => v.to_string(),
-    Value::U32(v) => v.to_string(),
-    Value::U64(v) => v.to_string(),
-    Value::F32(v) => v.to_string(),
-    Value::F64(v) => v.to_string(),
-    Value::EnumNumber(v) => v.to_string(),
-    Value::String(s) => s.clone(),
-    _ => Default::default(),
-  }
+    use prost_reflect::Value;
+
+    match value {
+        Value::I32(v) => v.to_string(),
+        Value::I64(v) => v.to_string(),
+        Value::U32(v) => v.to_string(),
+        Value::U64(v) => v.to_string(),
+        Value::F32(v) => v.to_string(),
+        Value::F64(v) => v.to_string(),
+        Value::EnumNumber(v) => v.to_string(),
+        Value::String(s) => s.clone(),
+        _ => Default::default(),
+    }
}

pub fn get_field_value_as_str(message: &DynamicMessage, field_name: &str) -> Result<String> {
-  let field = message
-    .get_field_by_name(field_name)
-    .ok_or(anyhow!("Unable to find key"))?;
+    let field = message
+        .get_field_by_name(field_name)
+        .ok_or(anyhow!("Unable to find key"))?;

-  Ok(protobuf_value_as_str(&field))
+    Ok(protobuf_value_as_str(&field))
}

#[derive(Debug)]
pub struct ProtobufSet {
-  descriptor_pool: DescriptorPool,
+    descriptor_pool: DescriptorPool,
}

// TODO: support for reflection
impl ProtobufSet {
-  // TODO: load definitions from proto file for now, but in the future
-  // it could be more convenient to load FileDescriptorSet instead
-  // either from file or server reflection
-  pub fn from_proto_file(proto_path: &Path) -> Result<Self> {
-    let proto_path = if proto_path.is_relative() {
-      let dir = current_dir()?;
-
-      dir.join(proto_path)
-    } else {
-      PathBuf::from(proto_path)
-    };
-
-    let parent_dir = proto_path
-      .parent()
-      .context("Failed to resolve parent dir for proto file")?;
-
-    let file_descriptor_set = protox::compile([proto_path.as_path()], [parent_dir])
-      .with_context(|| "Failed to parse or load proto file".to_string())?;
-
-    let descriptor_pool = DescriptorPool::from_file_descriptor_set(file_descriptor_set)?;
-
-    Ok(Self { descriptor_pool })
-  }
-
-  pub fn find_service(&self, name: &str) -> Result<ProtobufService> {
-    let service_descriptor = self
-      .descriptor_pool
-      .get_service_by_name(name)
-
.with_context(|| format!("Couldn't find definitions for service {name}"))?; - - Ok(ProtobufService { service_descriptor }) - } + // TODO: load definitions from proto file for now, but in future + // it could be more convenient to load FileDescriptorSet instead + // either from file or server reflection + pub fn from_proto_file(proto_path: &Path) -> Result { + let proto_path = if proto_path.is_relative() { + let dir = current_dir()?; + + dir.join(proto_path) + } else { + PathBuf::from(proto_path) + }; + + let parent_dir = proto_path + .parent() + .context("Failed to resolve parent dir for proto file")?; + + let file_descriptor_set = protox::compile([proto_path.as_path()], [parent_dir]) + .with_context(|| "Failed to parse or load proto file".to_string())?; + + let descriptor_pool = DescriptorPool::from_file_descriptor_set(file_descriptor_set)?; + + Ok(Self { descriptor_pool }) + } + + pub fn find_service(&self, name: &str) -> Result { + let service_descriptor = self + .descriptor_pool + .get_service_by_name(name) + .with_context(|| format!("Couldn't find definitions for service {name}"))?; + + Ok(ProtobufService { service_descriptor }) + } } #[derive(Debug)] pub struct ProtobufService { - service_descriptor: ServiceDescriptor, + service_descriptor: ServiceDescriptor, } impl ProtobufService { - pub fn find_operation(&self, method_name: &str) -> Result { - let method = self - .service_descriptor - .methods() - .find(|method| method.name() == method_name) - .with_context(|| format!("Couldn't find method {method_name}"))?; - - let input_type = method.input(); - let output_type = method.output(); - - Ok(ProtobufOperation { method, input_type, output_type }) - } + pub fn find_operation(&self, method_name: &str) -> Result { + let method = self + .service_descriptor + .methods() + .find(|method| method.name() == method_name) + .with_context(|| format!("Couldn't find method {method_name}"))?; + + let input_type = method.input(); + let output_type = method.output(); + + Ok(ProtobufOperation { method, input_type, output_type }) + } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ProtobufOperation { - method: MethodDescriptor, - pub input_type: MessageDescriptor, - pub output_type: MessageDescriptor, + method: MethodDescriptor, + pub input_type: MessageDescriptor, + pub output_type: MessageDescriptor, } // TODO: support compression impl ProtobufOperation { - pub fn name(&self) -> &str { - self.method.name() - } - - pub fn service_name(&self) -> &str { - self.method.parent_service().name() - } - - pub fn convert_input(&self, input: &str) -> Result> { - let message = to_message(&self.input_type, input)?; - - message_to_bytes(message) - } - - pub fn convert_multiple_inputs<'a>( - &self, - child_inputs: impl Iterator, - id: &str, - ) -> Result<(Vec, Vec)> { - // Find the field of list type that should hold child messages - let field_descriptor = self - .input_type - .fields() - .find(|field| field.is_list()) - .ok_or(anyhow!("Unable to find list field on type"))?; - - let field_kind = field_descriptor.kind(); - let child_message_descriptor = field_kind.as_message().ok_or(anyhow!("Couldn't resolve message"))?; - let mut message = DynamicMessage::new(self.input_type.clone()); - - let child_messages = child_inputs - .map(|input| to_message(child_message_descriptor, input)) - .collect::>>()?; - - let ids = child_messages - .iter() - .map(|message| get_field_value_as_str(message, id)) - .collect::>>()?; - - message.set_field( - &field_descriptor, - 
prost_reflect::Value::List(child_messages.into_iter().map(prost_reflect::Value::Message).collect()),
-    );
-
-    message_to_bytes(message).map(|result| (result, ids))
-  }
-
-  pub fn convert_output(&self, bytes: &[u8]) -> Result<Value> {
-    if bytes.len() < 5 {
-      bail!("Empty response");
+    pub fn name(&self) -> &str {
+        self.method.name()
+    }
+
+    pub fn service_name(&self) -> &str {
+        self.method.parent_service().name()
    }
-    // ignore the first 5 bytes as they are part of Length-Prefixed Message Framing
-    // see https://www.oreilly.com/library/view/grpc-up-and/9781492058328/ch04.html#:~:text=Length%2DPrefixed%20Message%20Framing
-    // 1st byte - compression flag
-    // 2-4th bytes - length of the message
-    let message = DynamicMessage::decode(self.output_type.clone(), &bytes[5..])
-      .with_context(|| format!("Failed to parse response for type {}", self.output_type.full_name()))?;
-    let json = serde_json::to_value(message)?;
+    pub fn convert_input(&self, input: &str) -> Result<Vec<u8>> {
+        let message = to_message(&self.input_type, input)?;
+
+        message_to_bytes(message)
+    }
+
+    pub fn convert_multiple_inputs<'a>(
+        &self,
+        child_inputs: impl Iterator<Item = &'a str>,
+        id: &str,
+    ) -> Result<(Vec<u8>, Vec<String>)> {
+        // Find the field of list type that should hold child messages
+        let field_descriptor = self
+            .input_type
+            .fields()
+            .find(|field| field.is_list())
+            .ok_or(anyhow!("Unable to find list field on type"))?;
+
+        let field_kind = field_descriptor.kind();
+        let child_message_descriptor = field_kind
+            .as_message()
+            .ok_or(anyhow!("Couldn't resolve message"))?;
+        let mut message = DynamicMessage::new(self.input_type.clone());
+
+        let child_messages = child_inputs
+            .map(|input| to_message(child_message_descriptor, input))
+            .collect::<Result<Vec<_>>>()?;
+
+        let ids = child_messages
+            .iter()
+            .map(|message| get_field_value_as_str(message, id))
+            .collect::<Result<Vec<_>>>()?;
+
+        message.set_field(
+            &field_descriptor,
+            prost_reflect::Value::List(
+                child_messages
+                    .into_iter()
+                    .map(prost_reflect::Value::Message)
+                    .collect(),
+            ),
+        );
+
+        message_to_bytes(message).map(|result| (result, ids))
+    }

-    Ok(async_graphql::Value::from_json(json)?)
-  }
+    pub fn convert_output(&self, bytes: &[u8]) -> Result<Value> {
+        if bytes.len() < 5 {
+            bail!("Empty response");
+        }
+        // ignore the first 5 bytes as they are part of Length-Prefixed Message Framing
+        // see https://www.oreilly.com/library/view/grpc-up-and/9781492058328/ch04.html#:~:text=Length%2DPrefixed%20Message%20Framing
+        // 1st byte - compression flag
+        // 2-4th bytes - length of the message
+        let message =
+            DynamicMessage::decode(self.output_type.clone(), &bytes[5..]).with_context(|| {
+                format!(
+                    "Failed to parse response for type {}",
+                    self.output_type.full_name()
+                )
+            })?;
+
+        let json = serde_json::to_value(message)?;
+
+        Ok(async_graphql::Value::from_json(json)?)
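+        // Worked framing example (an illustrative note, not produced by this
+        // formatting change): the SayHello fixture used in the tests below,
+        // b"\0\0\0\0\x0e\n\x0ctest message", splits into \0 (compression flag,
+        // 0 = uncompressed), \0\0\0\x0e (big-endian u32 payload length, 14) and
+        // \n\x0ctest message (the 14 protobuf-encoded reply bytes);
+        // `message_to_bytes` writes exactly this 5-byte prefix and
+        // `convert_output` strips it again before decoding.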
+ } } #[cfg(test)] mod tests { - use std::path::PathBuf; - - use anyhow::Result; - use once_cell::sync::Lazy; - use prost_reflect::Value; - use serde_json::json; - - use super::*; - - static TEST_DIR: Lazy = Lazy::new(|| { - let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let mut test_dir = root_dir.join(file!()); + use std::path::PathBuf; + + use anyhow::Result; + use once_cell::sync::Lazy; + use prost_reflect::Value; + use serde_json::json; + + use super::*; - test_dir.pop(); - test_dir.push("tests"); - - test_dir - }); + static TEST_DIR: Lazy = Lazy::new(|| { + let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut test_dir = root_dir.join(file!()); - fn get_test_file(name: &str) -> PathBuf { - let mut test_file = TEST_DIR.clone(); + test_dir.pop(); + test_dir.push("tests"); - test_file.push(name); - test_file - } - - #[test] - fn convert_value() { - assert_eq!( - protobuf_value_as_str(&Value::String("test string".to_owned())), - "test string".to_owned() - ); - assert_eq!(protobuf_value_as_str(&Value::I32(25)), "25".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::F32(1.25)), "1.25".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::I64(35)), "35".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::F64(3.38)), "3.38".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::EnumNumber(55)), "55".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::Bool(false)), "".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::Map(Default::default())), "".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::List(Default::default())), "".to_owned()); - assert_eq!(protobuf_value_as_str(&Value::Bytes(Default::default())), "".to_owned()); - } + test_dir + }); - #[test] - fn unknown_file() -> Result<()> { - let proto_file = get_test_file("_unknown.proto"); - let error = ProtobufSet::from_proto_file(&proto_file).unwrap_err(); - - assert_eq!(error.to_string(), format!("Failed to parse or load proto file")); - - Ok(()) - } - - #[test] - fn service_not_found() -> Result<()> { - let proto_file = get_test_file("greetings.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; - let error = file.find_service("_unknown").unwrap_err(); - - assert_eq!(error.to_string(), "Couldn't find definitions for service _unknown"); - - Ok(()) - } - - #[test] - fn method_not_found() -> Result<()> { - let proto_file = get_test_file("greetings.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; - let service = file.find_service("Greeter")?; - let error = service.find_operation("_unknown").unwrap_err(); - - assert_eq!(error.to_string(), "Couldn't find method _unknown"); - - Ok(()) - } - - #[test] - fn greetings_proto_file() -> Result<()> { - let proto_file = get_test_file("greetings.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; - let service = file.find_service("Greeter")?; - let operation = service.find_operation("SayHello")?; - - let output = b"\0\0\0\0\x0e\n\x0ctest message"; - - let parsed = operation.convert_output(output)?; - - assert_eq!( - serde_json::to_value(parsed)?, - json!({ - "message": "test message" - }) - ); - - Ok(()) - } - - #[test] - fn news_proto_file() -> Result<()> { - let proto_file = get_test_file("news.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; - let service = file.find_service("NewsService")?; - let operation = service.find_operation("GetNews")?; - - let input = operation.convert_input(r#"{ "id": 1 }"#)?; - - assert_eq!(input, b"\0\0\0\0\x02\x08\x01"); - - let output = 
b"\0\0\0\x005\x08\x01\x12\x06Note 1\x1a\tContent 1\"\x0cPost image 1"; - - let parsed = operation.convert_output(output)?; - - assert_eq!( - serde_json::to_value(parsed)?, - json!({ - "id": 1, "title": "Note 1", "body": "Content 1", "postImage": "Post image 1" - }) - ); - - Ok(()) - } - - #[test] - fn news_proto_file_multiple_messages() -> Result<()> { - let proto_file = get_test_file("news.proto"); - let file = ProtobufSet::from_proto_file(&proto_file)?; - let service = file.find_service("NewsService")?; - let multiple_operation = service.find_operation("GetMultipleNews")?; - - let child_messages = vec![r#"{ "id": 3 }"#, r#"{ "id": 5 }"#, r#"{ "id": 1 }"#]; - - let (multiple_message, grouped) = multiple_operation.convert_multiple_inputs(child_messages.into_iter(), "id")?; - - assert_eq!( - multiple_message, - b"\0\0\0\0\x0c\n\x02\x08\x03\n\x02\x08\x05\n\x02\x08\x01" - ); - assert_eq!(grouped, vec!["3".to_owned(), "5".to_owned(), "1".to_owned()]); + fn get_test_file(name: &str) -> PathBuf { + let mut test_file = TEST_DIR.clone(); - let output = b"\0\0\0\0o\n#\x08\x01\x12\x06Note 1\x1a\tContent 1\"\x0cPost image 1\n#\x08\x03\x12\x06Note 3\x1a\tContent 3\"\x0cPost image 3\n#\x08\x05\x12\x06Note 5\x1a\tContent 5\"\x0cPost image 5"; + test_file.push(name); + test_file + } + + #[test] + fn convert_value() { + assert_eq!( + protobuf_value_as_str(&Value::String("test string".to_owned())), + "test string".to_owned() + ); + assert_eq!(protobuf_value_as_str(&Value::I32(25)), "25".to_owned()); + assert_eq!(protobuf_value_as_str(&Value::F32(1.25)), "1.25".to_owned()); + assert_eq!(protobuf_value_as_str(&Value::I64(35)), "35".to_owned()); + assert_eq!(protobuf_value_as_str(&Value::F64(3.38)), "3.38".to_owned()); + assert_eq!( + protobuf_value_as_str(&Value::EnumNumber(55)), + "55".to_owned() + ); + assert_eq!(protobuf_value_as_str(&Value::Bool(false)), "".to_owned()); + assert_eq!( + protobuf_value_as_str(&Value::Map(Default::default())), + "".to_owned() + ); + assert_eq!( + protobuf_value_as_str(&Value::List(Default::default())), + "".to_owned() + ); + assert_eq!( + protobuf_value_as_str(&Value::Bytes(Default::default())), + "".to_owned() + ); + } - let parsed = multiple_operation.convert_output(output)?; + #[test] + fn unknown_file() -> Result<()> { + let proto_file = get_test_file("_unknown.proto"); + let error = ProtobufSet::from_proto_file(&proto_file).unwrap_err(); - assert_eq!( - serde_json::to_value(parsed)?, - json!({ - "news": [ - { "id": 1, "title": "Note 1", "body": "Content 1", "postImage": "Post image 1" }, - { "id": 3, "title": "Note 3", "body": "Content 3", "postImage": "Post image 3" }, - { "id": 5, "title": "Note 5", "body": "Content 5", "postImage": "Post image 5" }, - ] - }) - ); + assert_eq!( + error.to_string(), + format!("Failed to parse or load proto file") + ); - Ok(()) - } + Ok(()) + } + + #[test] + fn service_not_found() -> Result<()> { + let proto_file = get_test_file("greetings.proto"); + let file = ProtobufSet::from_proto_file(&proto_file)?; + let error = file.find_service("_unknown").unwrap_err(); + + assert_eq!( + error.to_string(), + "Couldn't find definitions for service _unknown" + ); + + Ok(()) + } + + #[test] + fn method_not_found() -> Result<()> { + let proto_file = get_test_file("greetings.proto"); + let file = ProtobufSet::from_proto_file(&proto_file)?; + let service = file.find_service("Greeter")?; + let error = service.find_operation("_unknown").unwrap_err(); + + assert_eq!(error.to_string(), "Couldn't find method _unknown"); + + Ok(()) + } + + #[test] + fn 
greetings_proto_file() -> Result<()> { + let proto_file = get_test_file("greetings.proto"); + let file = ProtobufSet::from_proto_file(&proto_file)?; + let service = file.find_service("Greeter")?; + let operation = service.find_operation("SayHello")?; + + let output = b"\0\0\0\0\x0e\n\x0ctest message"; + + let parsed = operation.convert_output(output)?; + + assert_eq!( + serde_json::to_value(parsed)?, + json!({ + "message": "test message" + }) + ); + + Ok(()) + } + + #[test] + fn news_proto_file() -> Result<()> { + let proto_file = get_test_file("news.proto"); + let file = ProtobufSet::from_proto_file(&proto_file)?; + let service = file.find_service("NewsService")?; + let operation = service.find_operation("GetNews")?; + + let input = operation.convert_input(r#"{ "id": 1 }"#)?; + + assert_eq!(input, b"\0\0\0\0\x02\x08\x01"); + + let output = b"\0\0\0\x005\x08\x01\x12\x06Note 1\x1a\tContent 1\"\x0cPost image 1"; + + let parsed = operation.convert_output(output)?; + + assert_eq!( + serde_json::to_value(parsed)?, + json!({ + "id": 1, "title": "Note 1", "body": "Content 1", "postImage": "Post image 1" + }) + ); + + Ok(()) + } + + #[test] + fn news_proto_file_multiple_messages() -> Result<()> { + let proto_file = get_test_file("news.proto"); + let file = ProtobufSet::from_proto_file(&proto_file)?; + let service = file.find_service("NewsService")?; + let multiple_operation = service.find_operation("GetMultipleNews")?; + + let child_messages = vec![r#"{ "id": 3 }"#, r#"{ "id": 5 }"#, r#"{ "id": 1 }"#]; + + let (multiple_message, grouped) = + multiple_operation.convert_multiple_inputs(child_messages.into_iter(), "id")?; + + assert_eq!( + multiple_message, + b"\0\0\0\0\x0c\n\x02\x08\x03\n\x02\x08\x05\n\x02\x08\x01" + ); + assert_eq!( + grouped, + vec!["3".to_owned(), "5".to_owned(), "1".to_owned()] + ); + + let output = b"\0\0\0\0o\n#\x08\x01\x12\x06Note 1\x1a\tContent 1\"\x0cPost image 1\n#\x08\x03\x12\x06Note 3\x1a\tContent 3\"\x0cPost image 3\n#\x08\x05\x12\x06Note 5\x1a\tContent 5\"\x0cPost image 5"; + + let parsed = multiple_operation.convert_output(output)?; + + assert_eq!( + serde_json::to_value(parsed)?, + json!({ + "news": [ + { "id": 1, "title": "Note 1", "body": "Content 1", "postImage": "Post image 1" }, + { "id": 3, "title": "Note 3", "body": "Content 3", "postImage": "Post image 3" }, + { "id": 5, "title": "Note 5", "body": "Content 5", "postImage": "Post image 5" }, + ] + }) + ); + + Ok(()) + } } diff --git a/src/grpc/request.rs b/src/grpc/request.rs index 4a2224978f0..d2a8635f7b8 100644 --- a/src/grpc/request.rs +++ b/src/grpc/request.rs @@ -10,23 +10,23 @@ use crate::http::Response; use crate::HttpIO; pub fn create_grpc_request(url: Url, headers: HeaderMap, body: Vec) -> Request { - let mut req = Request::new(Method::POST, url); - req.headers_mut().extend(headers.clone()); - req.body_mut().replace(body.into()); + let mut req = Request::new(Method::POST, url); + req.headers_mut().extend(headers.clone()); + req.body_mut().replace(body.into()); - req + req } pub async fn execute_grpc_request( - client: &Arc, - operation: &ProtobufOperation, - request: Request, + client: &Arc, + operation: &ProtobufOperation, + request: Request, ) -> Result> { - let response = client.execute(request).await?; + let response = client.execute(request).await?; - if response.status.is_success() { - return response.to_grpc_value(operation); - } + if response.status.is_success() { + return response.to_grpc_value(operation); + } - bail!("Failed to execute request") + bail!("Failed to execute request") } diff 
--git a/src/grpc/request_template.rs b/src/grpc/request_template.rs index ef7dc16175b..3c0bf71f428 100644 --- a/src/grpc/request_template.rs +++ b/src/grpc/request_template.rs @@ -17,190 +17,190 @@ static GRPC_MIME_TYPE: HeaderValue = HeaderValue::from_static("application/grpc" #[derive(Setters, Debug, Clone)] pub struct RequestTemplate { - pub url: Mustache, - pub headers: MustacheHeaders, - pub body: Option, - pub operation: ProtobufOperation, - pub operation_type: GraphQLOperationType, + pub url: Mustache, + pub headers: MustacheHeaders, + pub body: Option, + pub operation: ProtobufOperation, + pub operation_type: GraphQLOperationType, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct RenderedRequestTemplate { - pub url: Url, - pub headers: HeaderMap, - pub body: String, - pub operation: ProtobufOperation, + pub url: Url, + pub headers: HeaderMap, + pub body: String, + pub operation: ProtobufOperation, } impl RequestTemplate { - fn create_url(&self, ctx: &C) -> Result { - let url = url::Url::parse(self.url.render(ctx).as_str())?; + fn create_url(&self, ctx: &C) -> Result { + let url = url::Url::parse(self.url.render(ctx).as_str())?; - Ok(url) - } + Ok(url) + } - fn create_headers(&self, ctx: &C) -> HeaderMap { - let mut header_map = HeaderMap::new(); + fn create_headers(&self, ctx: &C) -> HeaderMap { + let mut header_map = HeaderMap::new(); - header_map.insert(CONTENT_TYPE, GRPC_MIME_TYPE.to_owned()); + header_map.insert(CONTENT_TYPE, GRPC_MIME_TYPE.to_owned()); - for (k, v) in &self.headers { - if let Ok(header_value) = HeaderValue::from_str(&v.render(ctx)) { - header_map.insert(k, header_value); - } - } + for (k, v) in &self.headers { + if let Ok(header_value) = HeaderValue::from_str(&v.render(ctx)) { + header_map.insert(k, header_value); + } + } - header_map - } - - pub fn render(&self, ctx: &C) -> Result { - let url = self.create_url(ctx)?; - let headers = self.render_headers(ctx); - let body = self.render_body(ctx); - Ok(RenderedRequestTemplate { url, headers, body, operation: self.operation.clone() }) - } - - fn render_body(&self, ctx: &C) -> String { - if let Some(body) = &self.body { - body.render(ctx) - } else { - "{}".to_owned() + header_map } - } - fn render_headers(&self, ctx: &C) -> HeaderMap { - let mut req_headers = HeaderMap::new(); + pub fn render(&self, ctx: &C) -> Result { + let url = self.create_url(ctx)?; + let headers = self.render_headers(ctx); + let body = self.render_body(ctx); + Ok(RenderedRequestTemplate { url, headers, body, operation: self.operation.clone() }) + } - let headers = self.create_headers(ctx); - if !headers.is_empty() { - req_headers.extend(headers); + fn render_body(&self, ctx: &C) -> String { + if let Some(body) = &self.body { + body.render(ctx) + } else { + "{}".to_owned() + } } - req_headers.extend(ctx.headers().to_owned()); + fn render_headers(&self, ctx: &C) -> HeaderMap { + let mut req_headers = HeaderMap::new(); + + let headers = self.create_headers(ctx); + if !headers.is_empty() { + req_headers.extend(headers); + } - req_headers - } + req_headers.extend(ctx.headers().to_owned()); + + req_headers + } } impl RenderedRequestTemplate { - pub fn to_request(&self) -> Result { - let mut req = reqwest::Request::new(Method::POST, self.url.clone()); - req.headers_mut().extend(self.headers.clone()); - - Ok(create_grpc_request( - self.url.clone(), - self.headers.clone(), - self.operation.convert_input(self.body.as_str())?, - )) - } + pub fn to_request(&self) -> Result { + let mut req = reqwest::Request::new(Method::POST, self.url.clone()); + 
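// NOTE: this locally built request is not the one that is returned;
+        // `create_grpc_request` below constructs its own POST request from the
+        // same url and headers, so `req` here only mirrors its contents.
+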
req.headers_mut().extend(self.headers.clone()); + + Ok(create_grpc_request( + self.url.clone(), + self.headers.clone(), + self.operation.convert_input(self.body.as_str())?, + )) + } } #[cfg(test)] mod tests { - use std::borrow::Cow; - use std::path::PathBuf; - - use derive_setters::Setters; - use hyper::header::{HeaderName, HeaderValue}; - use hyper::{HeaderMap, Method}; - use once_cell::sync::Lazy; - use pretty_assertions::assert_eq; - - use super::RequestTemplate; - use crate::config::GraphQLOperationType; - use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; - use crate::mustache::Mustache; - - static PROTOBUF_OPERATION: Lazy = Lazy::new(|| { - let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let mut test_file = root_dir.join(file!()); - - test_file.pop(); - test_file.push("tests"); - test_file.push("greetings.proto"); - - let protobuf_set = ProtobufSet::from_proto_file(&test_file).unwrap(); - let service = protobuf_set.find_service("Greeter").unwrap(); - - service.find_operation("SayHello").unwrap() - }); - - #[derive(Setters)] - struct Context { - pub value: serde_json::Value, - pub headers: HeaderMap, - } + use std::borrow::Cow; + use std::path::PathBuf; + + use derive_setters::Setters; + use hyper::header::{HeaderName, HeaderValue}; + use hyper::{HeaderMap, Method}; + use once_cell::sync::Lazy; + use pretty_assertions::assert_eq; + + use super::RequestTemplate; + use crate::config::GraphQLOperationType; + use crate::grpc::protobuf::{ProtobufOperation, ProtobufSet}; + use crate::mustache::Mustache; + + static PROTOBUF_OPERATION: Lazy = Lazy::new(|| { + let root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let mut test_file = root_dir.join(file!()); + + test_file.pop(); + test_file.push("tests"); + test_file.push("greetings.proto"); + + let protobuf_set = ProtobufSet::from_proto_file(&test_file).unwrap(); + let service = protobuf_set.find_service("Greeter").unwrap(); + + service.find_operation("SayHello").unwrap() + }); + + #[derive(Setters)] + struct Context { + pub value: serde_json::Value, + pub headers: HeaderMap, + } - impl Default for Context { - fn default() -> Self { - Self { value: serde_json::Value::Null, headers: HeaderMap::new() } + impl Default for Context { + fn default() -> Self { + Self { value: serde_json::Value::Null, headers: HeaderMap::new() } + } } - } - impl crate::path::PathString for Context { - fn path_string>(&self, parts: &[T]) -> Option> { - self.value.path_string(parts) + impl crate::path::PathString for Context { + fn path_string>(&self, parts: &[T]) -> Option> { + self.value.path_string(parts) + } } - } - impl crate::has_headers::HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + impl crate::has_headers::HasHeaders for Context { + fn headers(&self) -> &HeaderMap { + &self.headers + } } - } - - #[test] - fn request_with_empty_body() { - let tmpl = RequestTemplate { - url: Mustache::parse("http://localhost:3000/").unwrap(), - headers: vec![( - HeaderName::from_static("test-header"), - Mustache::parse("value").unwrap(), - )], - operation: PROTOBUF_OPERATION.clone(), - body: None, - operation_type: GraphQLOperationType::Query, - }; - let ctx = Context::default(); - let rendered = tmpl.render(&ctx).unwrap(); - let req = rendered.to_request().unwrap(); - - assert_eq!(req.url().as_str(), "http://localhost:3000/"); - assert_eq!(req.method(), Method::POST); - assert_eq!( - req.headers(), - &HeaderMap::from_iter([ - ( - HeaderName::from_static("test-header"), - HeaderValue::from_static("value") - ), - ( - 
HeaderName::from_static("content-type"), - HeaderValue::from_static("application/grpc") - ) - ]) - ); - - if let Some(body) = req.body() { - assert_eq!(body.as_bytes(), Some(b"\0\0\0\0\0".as_ref())) + + #[test] + fn request_with_empty_body() { + let tmpl = RequestTemplate { + url: Mustache::parse("http://localhost:3000/").unwrap(), + headers: vec![( + HeaderName::from_static("test-header"), + Mustache::parse("value").unwrap(), + )], + operation: PROTOBUF_OPERATION.clone(), + body: None, + operation_type: GraphQLOperationType::Query, + }; + let ctx = Context::default(); + let rendered = tmpl.render(&ctx).unwrap(); + let req = rendered.to_request().unwrap(); + + assert_eq!(req.url().as_str(), "http://localhost:3000/"); + assert_eq!(req.method(), Method::POST); + assert_eq!( + req.headers(), + &HeaderMap::from_iter([ + ( + HeaderName::from_static("test-header"), + HeaderValue::from_static("value") + ), + ( + HeaderName::from_static("content-type"), + HeaderValue::from_static("application/grpc") + ) + ]) + ); + + if let Some(body) = req.body() { + assert_eq!(body.as_bytes(), Some(b"\0\0\0\0\0".as_ref())) + } } - } - - #[test] - fn request_with_body() { - let tmpl = RequestTemplate { - url: Mustache::parse("http://localhost:3000/").unwrap(), - headers: vec![], - operation: PROTOBUF_OPERATION.clone(), - body: Some(Mustache::parse(r#"{ "name": "test" }"#).unwrap()), - operation_type: GraphQLOperationType::Query, - }; - let ctx = Context::default(); - let rendered = tmpl.render(&ctx).unwrap(); - let req = rendered.to_request().unwrap(); - - if let Some(body) = req.body() { - assert_eq!(body.as_bytes(), Some(b"\0\0\0\0\x06\n\x04test".as_ref())) + + #[test] + fn request_with_body() { + let tmpl = RequestTemplate { + url: Mustache::parse("http://localhost:3000/").unwrap(), + headers: vec![], + operation: PROTOBUF_OPERATION.clone(), + body: Some(Mustache::parse(r#"{ "name": "test" }"#).unwrap()), + operation_type: GraphQLOperationType::Query, + }; + let ctx = Context::default(); + let rendered = tmpl.render(&ctx).unwrap(); + let req = rendered.to_request().unwrap(); + + if let Some(body) = req.body() { + assert_eq!(body.as_bytes(), Some(b"\0\0\0\0\x06\n\x04test".as_ref())) + } } - } } diff --git a/src/has_headers.rs b/src/has_headers.rs index 1d1fc543ccc..c2c04cda940 100644 --- a/src/has_headers.rs +++ b/src/has_headers.rs @@ -3,11 +3,11 @@ use hyper::HeaderMap; use crate::lambda::{EvaluationContext, ResolverContextLike}; pub trait HasHeaders { - fn headers(&self) -> &HeaderMap; + fn headers(&self) -> &HeaderMap; } impl<'a, Ctx: ResolverContextLike<'a>> HasHeaders for EvaluationContext<'a, Ctx> { - fn headers(&self) -> &HeaderMap { - self.headers() - } + fn headers(&self) -> &HeaderMap { + self.headers() + } } diff --git a/src/helpers/body.rs b/src/helpers/body.rs index b6bc5ace3ea..9940bbb338b 100644 --- a/src/helpers/body.rs +++ b/src/helpers/body.rs @@ -2,34 +2,37 @@ use crate::mustache::Mustache; use crate::valid::{Valid, ValidationError}; pub fn to_body(body: Option<&str>) -> Valid, String> { - let Some(body) = body else { - return Valid::succeed(None); - }; + let Some(body) = body else { + return Valid::succeed(None); + }; - Valid::from( - Mustache::parse(body) - .map(Some) - .map_err(|e| ValidationError::new(e.to_string())), - ) + Valid::from( + Mustache::parse(body) + .map(Some) + .map_err(|e| ValidationError::new(e.to_string())), + ) } #[cfg(test)] mod tests { - use super::to_body; - use crate::mustache::Mustache; - use crate::valid::Valid; + use super::to_body; + use 
crate::mustache::Mustache; + use crate::valid::Valid; - #[test] - fn no_body() { - let result = to_body(None); + #[test] + fn no_body() { + let result = to_body(None); - assert_eq!(result, Valid::succeed(None)); - } + assert_eq!(result, Valid::succeed(None)); + } - #[test] - fn body_parse_success() { - let result = to_body(Some("content")); + #[test] + fn body_parse_success() { + let result = to_body(Some("content")); - assert_eq!(result, Valid::succeed(Some(Mustache::parse("content").unwrap()))); - } + assert_eq!( + result, + Valid::succeed(Some(Mustache::parse("content").unwrap())) + ); + } } diff --git a/src/helpers/headers.rs b/src/helpers/headers.rs index 438358b9111..365ce940572 100644 --- a/src/helpers/headers.rs +++ b/src/helpers/headers.rs @@ -7,55 +7,63 @@ use crate::valid::{Valid, ValidationError}; pub type MustacheHeaders = Vec<(HeaderName, Mustache)>; pub fn to_mustache_headers(headers: &KeyValues) -> Valid { - Valid::from_iter(headers.iter(), |(k, v)| { - let name = - Valid::from(HeaderName::from_bytes(k.as_bytes()).map_err(|e| ValidationError::new(e.to_string()))).trace(k); + Valid::from_iter(headers.iter(), |(k, v)| { + let name = Valid::from( + HeaderName::from_bytes(k.as_bytes()).map_err(|e| ValidationError::new(e.to_string())), + ) + .trace(k); - let value = Valid::from(Mustache::parse(v.as_str()).map_err(|e| ValidationError::new(e.to_string()))).trace(v); + let value = Valid::from( + Mustache::parse(v.as_str()).map_err(|e| ValidationError::new(e.to_string())), + ) + .trace(v); - name.zip(value).map(|(name, value)| (name, value)) - }) + name.zip(value).map(|(name, value)| (name, value)) + }) } #[cfg(test)] mod tests { - use anyhow::Result; - use hyper::header::HeaderName; - - use super::to_mustache_headers; - use crate::config::KeyValues; - use crate::mustache::Mustache; - - #[test] - fn valid_headers() -> Result<()> { - let input: KeyValues = serde_json::from_str(r#"[{"key": "a", "value": "str"}, {"key": "b", "value": "123"}]"#)?; - - let headers = to_mustache_headers(&input).to_result()?; - - assert_eq!( - headers, - vec![ - (HeaderName::from_bytes(b"a")?, Mustache::parse("str")?), - (HeaderName::from_bytes(b"b")?, Mustache::parse("123")?) - ] - ); - - Ok(()) - } - - #[test] - fn not_valid_due_to_utf8() { - let input: KeyValues = - serde_json::from_str(r#"[{"key": "😅", "value": "str"}, {"key": "b", "value": "🦀"}]"#).unwrap(); - let error = to_mustache_headers(&input).to_result().unwrap_err(); - - // HeaderValue should be parsed just fine despite non-visible ascii symbols range - // see https://github.com/hyperium/http/issues/519 - assert_eq!( - error.to_string(), - r"Validation Error + use anyhow::Result; + use hyper::header::HeaderName; + + use super::to_mustache_headers; + use crate::config::KeyValues; + use crate::mustache::Mustache; + + #[test] + fn valid_headers() -> Result<()> { + let input: KeyValues = serde_json::from_str( + r#"[{"key": "a", "value": "str"}, {"key": "b", "value": "123"}]"#, + )?; + + let headers = to_mustache_headers(&input).to_result()?; + + assert_eq!( + headers, + vec![ + (HeaderName::from_bytes(b"a")?, Mustache::parse("str")?), + (HeaderName::from_bytes(b"b")?, Mustache::parse("123")?) 
+ ] + ); + + Ok(()) + } + + #[test] + fn not_valid_due_to_utf8() { + let input: KeyValues = + serde_json::from_str(r#"[{"key": "😅", "value": "str"}, {"key": "b", "value": "🦀"}]"#) + .unwrap(); + let error = to_mustache_headers(&input).to_result().unwrap_err(); + + // HeaderValue should be parsed just fine despite non-visible ascii symbols range + // see https://github.com/hyperium/http/issues/519 + assert_eq!( + error.to_string(), + r"Validation Error • invalid HTTP header name [😅] " - ); - } + ); + } } diff --git a/src/helpers/url.rs b/src/helpers/url.rs index 168eb379bb5..996c3f12383 100644 --- a/src/helpers/url.rs +++ b/src/helpers/url.rs @@ -2,20 +2,23 @@ use crate::mustache::Mustache; use crate::valid::{Valid, ValidationError}; pub fn to_url(url: &str) -> Valid { - Valid::from(Mustache::parse(url).map_err(|e| ValidationError::new(e.to_string()))) + Valid::from(Mustache::parse(url).map_err(|e| ValidationError::new(e.to_string()))) } #[cfg(test)] mod tests { - use super::to_url; + use super::to_url; - #[test] - fn parse_url() { - use crate::mustache::Mustache; - use crate::valid::Valid; + #[test] + fn parse_url() { + use crate::mustache::Mustache; + use crate::valid::Valid; - let url = to_url("http://localhost:3000"); + let url = to_url("http://localhost:3000"); - assert_eq!(url, Valid::succeed(Mustache::parse("http://localhost:3000").unwrap())); - } + assert_eq!( + url, + Valid::succeed(Mustache::parse("http://localhost:3000").unwrap()) + ); + } } diff --git a/src/helpers/value.rs b/src/helpers/value.rs index 4b025ce446c..26a972932fd 100644 --- a/src/helpers/value.rs +++ b/src/helpers/value.rs @@ -6,33 +6,33 @@ use async_graphql_value::ConstValue; pub struct HashableConstValue(pub ConstValue); impl Hash for HashableConstValue { - fn hash(&self, state: &mut H) { - hash(&self.0, state) - } + fn hash(&self, state: &mut H) { + hash(&self.0, state) + } } impl PartialEq for HashableConstValue { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } } pub fn hash(const_value: &ConstValue, state: &mut H) { - match const_value { - ConstValue::Null => {} - ConstValue::Boolean(val) => val.hash(state), - ConstValue::Enum(name) => name.hash(state), - ConstValue::Number(num) => num.hash(state), - ConstValue::Binary(bytes) => bytes.hash(state), - ConstValue::String(string) => string.hash(state), - ConstValue::List(list) => list.iter().for_each(|val| hash(val, state)), - ConstValue::Object(object) => { - let mut tmp_list: Vec<_> = object.iter().collect(); - tmp_list.sort_by(|(key1, _), (key2, _)| key1.cmp(key2)); - tmp_list.iter().for_each(|(key, value)| { - key.hash(state); - hash(value, state); - }) + match const_value { + ConstValue::Null => {} + ConstValue::Boolean(val) => val.hash(state), + ConstValue::Enum(name) => name.hash(state), + ConstValue::Number(num) => num.hash(state), + ConstValue::Binary(bytes) => bytes.hash(state), + ConstValue::String(string) => string.hash(state), + ConstValue::List(list) => list.iter().for_each(|val| hash(val, state)), + ConstValue::Object(object) => { + let mut tmp_list: Vec<_> = object.iter().collect(); + tmp_list.sort_by(|(key1, _), (key2, _)| key1.cmp(key2)); + tmp_list.iter().for_each(|(key, value)| { + key.hash(state); + hash(value, state); + }) + } } - } } diff --git a/src/http/cache.rs b/src/http/cache.rs index 124c70d3ae6..106d7e20f6f 100644 --- a/src/http/cache.rs +++ b/src/http/cache.rs @@ -5,105 +5,106 @@ use cache_control::{Cachability, CacheControl}; use super::Response; pub fn 
cache_policy(res: &Response) -> Option { - let header = res.headers.get(hyper::header::CACHE_CONTROL)?; - let value = header.to_str().ok()?; + let header = res.headers.get(hyper::header::CACHE_CONTROL)?; + let value = header.to_str().ok()?; - CacheControl::from_value(value) + CacheControl::from_value(value) } pub fn max_age(res: &Response) -> Option { - match cache_policy(res) { - Some(value) => value.max_age, - None => None, - } + match cache_policy(res) { + Some(value) => value.max_age, + None => None, + } } pub fn cache_visibility(res: &Response) -> String { - let cachability = cache_policy(res).and_then(|value| value.cachability); - - match cachability { - Some(Cachability::Public) => "public".to_string(), - Some(Cachability::Private) => "private".to_string(), - Some(Cachability::NoCache) => "no-cache".to_string(), - _ => "".to_string(), - } + let cachability = cache_policy(res).and_then(|value| value.cachability); + + match cachability { + Some(Cachability::Public) => "public".to_string(), + Some(Cachability::Private) => "private".to_string(), + Some(Cachability::NoCache) => "no-cache".to_string(), + _ => "".to_string(), + } } /// Returns the minimum TTL of the given responses. pub fn min_ttl<'a>(res_vec: impl Iterator>) -> i32 { - let mut min = -1; - - for res in res_vec { - if let Some(max_age) = max_age(res) { - let ttl = max_age.as_secs() as i32; - if min == -1 || ttl < min { - min = ttl; - } + let mut min = -1; + + for res in res_vec { + if let Some(max_age) = max_age(res) { + let ttl = max_age.as_secs() as i32; + if min == -1 || ttl < min { + min = ttl; + } + } } - } - min + min } #[cfg(test)] mod tests { - use std::time::Duration; - - use hyper::HeaderMap; - - use crate::http::Response; - - fn cache_control_header(i: i32) -> HeaderMap { - let mut headers = reqwest::header::HeaderMap::default(); - headers.append("Cache-Control", format!("max-age={}", i).parse().unwrap()); - headers - } - - fn cache_control_header_visibility(i: i32, visibility: &str) -> HeaderMap { - let mut headers = reqwest::header::HeaderMap::default(); - headers.append( - "Cache-Control", - format!("max-age={}, {}", i, visibility).parse().unwrap(), - ); - headers - } - - #[test] - fn test_max_age_none() { - let response = Response::default(); - assert_eq!(super::max_age(&response), None); - } - - #[test] - fn test_max_age_some() { - let headers = cache_control_header(3600); - let response = Response::default().headers(headers); - - assert_eq!(super::max_age(&response), Some(Duration::from_secs(3600))); - } - - #[test] - fn test_min_ttl() { - let max_ages = [3600, 1800, 7200].map(|i| Response::default().headers(cache_control_header(i))); - let min = super::min_ttl(max_ages.iter()); - assert_eq!(min, 1800); - } - - #[test] - fn test_cache_visibility_public() { - let headers = cache_control_header_visibility(3600, "public"); - let response = Response::default().headers(headers); - - assert_eq!(super::max_age(&response), Some(Duration::from_secs(3600))); - assert_eq!(super::cache_visibility(&response), "public"); - } - - #[test] - fn test_cache_visibility_private() { - let headers = cache_control_header_visibility(3600, "private"); - let response = Response::default().headers(headers); - - assert_eq!(super::max_age(&response), Some(Duration::from_secs(3600))); - assert_eq!(super::cache_visibility(&response), "private"); - } + use std::time::Duration; + + use hyper::HeaderMap; + + use crate::http::Response; + + fn cache_control_header(i: i32) -> HeaderMap { + let mut headers = 
reqwest::header::HeaderMap::default(); + headers.append("Cache-Control", format!("max-age={}", i).parse().unwrap()); + headers + } + + fn cache_control_header_visibility(i: i32, visibility: &str) -> HeaderMap { + let mut headers = reqwest::header::HeaderMap::default(); + headers.append( + "Cache-Control", + format!("max-age={}, {}", i, visibility).parse().unwrap(), + ); + headers + } + + #[test] + fn test_max_age_none() { + let response = Response::default(); + assert_eq!(super::max_age(&response), None); + } + + #[test] + fn test_max_age_some() { + let headers = cache_control_header(3600); + let response = Response::default().headers(headers); + + assert_eq!(super::max_age(&response), Some(Duration::from_secs(3600))); + } + + #[test] + fn test_min_ttl() { + let max_ages = + [3600, 1800, 7200].map(|i| Response::default().headers(cache_control_header(i))); + let min = super::min_ttl(max_ages.iter()); + assert_eq!(min, 1800); + } + + #[test] + fn test_cache_visibility_public() { + let headers = cache_control_header_visibility(3600, "public"); + let response = Response::default().headers(headers); + + assert_eq!(super::max_age(&response), Some(Duration::from_secs(3600))); + assert_eq!(super::cache_visibility(&response), "public"); + } + + #[test] + fn test_cache_visibility_private() { + let headers = cache_control_header_visibility(3600, "private"); + let response = Response::default().headers(headers); + + assert_eq!(super::max_age(&response), Some(Duration::from_secs(3600))); + assert_eq!(super::cache_visibility(&response), "private"); + } } diff --git a/src/http/data_loader.rs b/src/http/data_loader.rs index 8940778eb2a..7b202d4d1f9 100644 --- a/src/http/data_loader.rs +++ b/src/http/data_loader.rs @@ -14,101 +14,102 @@ use crate::json::JsonLike; use crate::HttpIO; fn get_body_value_single(body_value: &HashMap>, id: &str) -> ConstValue { - body_value - .get(id) - .and_then(|a| a.first().cloned().cloned()) - .unwrap_or(ConstValue::Null) + body_value + .get(id) + .and_then(|a| a.first().cloned().cloned()) + .unwrap_or(ConstValue::Null) } fn get_body_value_list(body_value: &HashMap>, id: &str) -> ConstValue { - ConstValue::List( - body_value - .get(id) - .unwrap_or(&Vec::new()) - .iter() - .map(|&o| o.to_owned()) - .collect::>(), - ) + ConstValue::List( + body_value + .get(id) + .unwrap_or(&Vec::new()) + .iter() + .map(|&o| o.to_owned()) + .collect::>(), + ) } #[derive(Clone)] pub struct HttpDataLoader { - pub client: Arc, - pub group_by: Option, - pub body: fn(&HashMap>, &str) -> ConstValue, + pub client: Arc, + pub group_by: Option, + pub body: fn(&HashMap>, &str) -> ConstValue, } impl HttpDataLoader { - pub fn new(client: Arc, group_by: Option, is_list: bool) -> Self { - HttpDataLoader { - client, - group_by, - body: if is_list { - get_body_value_list - } else { - get_body_value_single - }, + pub fn new(client: Arc, group_by: Option, is_list: bool) -> Self { + HttpDataLoader { + client, + group_by, + body: if is_list { + get_body_value_list + } else { + get_body_value_single + }, + } } - } - pub fn to_data_loader(self, batch: Batch) -> DataLoader { - DataLoader::new(self) - .delay(Duration::from_millis(batch.delay as u64)) - .max_batch_size(batch.max_size) - } + pub fn to_data_loader(self, batch: Batch) -> DataLoader { + DataLoader::new(self) + .delay(Duration::from_millis(batch.delay as u64)) + .max_batch_size(batch.max_size) + } } #[async_trait::async_trait] impl Loader for HttpDataLoader { - type Value = Response; - type Error = Arc; + type Value = Response; + type Error = Arc; - async 
fn load( - &self, - keys: &[DataLoaderRequest], - ) -> async_graphql::Result, Self::Error> { - if let Some(group_by) = &self.group_by { - let mut keys = keys.to_vec(); - keys.sort_by(|a, b| a.to_request().url().cmp(b.to_request().url())); + async fn load( + &self, + keys: &[DataLoaderRequest], + ) -> async_graphql::Result, Self::Error> { + if let Some(group_by) = &self.group_by { + let mut keys = keys.to_vec(); + keys.sort_by(|a, b| a.to_request().url().cmp(b.to_request().url())); - let mut request = keys[0].to_request(); - let first_url = request.url_mut(); + let mut request = keys[0].to_request(); + let first_url = request.url_mut(); - for key in &keys[1..] { - let request = key.to_request(); - let url = request.url(); - first_url.query_pairs_mut().extend_pairs(url.query_pairs()); - } + for key in &keys[1..] { + let request = key.to_request(); + let url = request.url(); + first_url.query_pairs_mut().extend_pairs(url.query_pairs()); + } - let res = self.client.execute(request).await?.to_json()?; - #[allow(clippy::mutable_key_type)] - let mut hashmap = HashMap::with_capacity(keys.len()); - let path = &group_by.path(); - let body_value = res.body.group_by(path); + let res = self.client.execute(request).await?.to_json()?; + #[allow(clippy::mutable_key_type)] + let mut hashmap = HashMap::with_capacity(keys.len()); + let path = &group_by.path(); + let body_value = res.body.group_by(path); - for key in &keys { - let req = key.to_request(); - let query_set: std::collections::HashMap<_, _> = req.url().query_pairs().collect(); - let id = query_set - .get(group_by.key()) - .ok_or(anyhow::anyhow!("Unable to find key {} in query params", group_by.key()))?; - hashmap.insert(key.clone(), res.clone().body((self.body)(&body_value, id))); - } - Ok(hashmap) - } else { - let results = keys.iter().map(|key| async { - let result = self.client.execute(key.to_request()).await; - (key.clone(), result) - }); + for key in &keys { + let req = key.to_request(); + let query_set: std::collections::HashMap<_, _> = req.url().query_pairs().collect(); + let id = query_set.get(group_by.key()).ok_or(anyhow::anyhow!( + "Unable to find key {} in query params", + group_by.key() + ))?; + hashmap.insert(key.clone(), res.clone().body((self.body)(&body_value, id))); + } + Ok(hashmap) + } else { + let results = keys.iter().map(|key| async { + let result = self.client.execute(key.to_request()).await; + (key.clone(), result) + }); - let results = join_all(results).await; + let results = join_all(results).await; - #[allow(clippy::mutable_key_type)] - let mut hashmap = HashMap::new(); - for (key, value) in results { - hashmap.insert(key, value?.to_json()?); - } + #[allow(clippy::mutable_key_type)] + let mut hashmap = HashMap::new(); + for (key, value) in results { + hashmap.insert(key, value?.to_json()?); + } - Ok(hashmap) + Ok(hashmap) + } } - } } diff --git a/src/http/data_loader_request.rs b/src/http/data_loader_request.rs index 8c80940bf35..03b6991b743 100644 --- a/src/http/data_loader_request.rs +++ b/src/http/data_loader_request.rs @@ -7,234 +7,237 @@ use std::ops::Deref; pub struct DataLoaderRequest(reqwest::Request, BTreeSet); impl DataLoaderRequest { - pub fn new(req: reqwest::Request, headers: BTreeSet) -> Self { - // TODO: req should already have headers builtin, no? - DataLoaderRequest(req, headers) - } - pub fn to_request(&self) -> reqwest::Request { - // TODO: excessive clone for the whole structure instead of cloning only part of it - // check if we really need to clone anything at all or just pass references? 
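// (editorial aside — illustrative sketch, not part of this diff) The TODO
// above asks whether callers could borrow instead of cloning. Read-only call
// sites already can, because `DataLoaderRequest` implements
// `Deref<Target = reqwest::Request>` further down in this file; a
// hypothetical helper showing borrowed access:
//
//     fn log_batched_url(req: &DataLoaderRequest) {
//         // `req.url()` resolves through `Deref`; no clone is needed for reads
//         log::info!("batched request to {}", req.url());
//     }
//
// A full clone remains necessary when ownership is handed to
// `HttpIO::execute`, since `reqwest::Request` is consumed on execution.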
- self.clone().0 - } - pub fn headers(&self) -> &BTreeSet { - &self.1 - } + pub fn new(req: reqwest::Request, headers: BTreeSet) -> Self { + // TODO: req should already have headers builtin, no? + DataLoaderRequest(req, headers) + } + pub fn to_request(&self) -> reqwest::Request { + // TODO: excessive clone for the whole structure instead of cloning only part of it + // check if we really need to clone anything at all or just pass references? + self.clone().0 + } + pub fn headers(&self) -> &BTreeSet { + &self.1 + } } impl Hash for DataLoaderRequest { - fn hash(&self, state: &mut H) { - self.0.url().hash(state); - // use body in hash for graphql queries with query operation as they used to fetch data - // while http post and graphql mutation should not be loaded through dataloader at all! - if let Some(body) = self.0.body() { - body.as_bytes().hash(state); - } - for name in &self.1 { - if let Some(value) = self.0.headers().get(name) { - name.hash(state); - value.hash(state); - } - } - } + fn hash(&self, state: &mut H) { + self.0.url().hash(state); + // use body in hash for graphql queries with query operation as they used to fetch data + // while http post and graphql mutation should not be loaded through dataloader at all! + if let Some(body) = self.0.body() { + body.as_bytes().hash(state); + } + for name in &self.1 { + if let Some(value) = self.0.headers().get(name) { + name.hash(state); + value.hash(state); + } + } + } } impl PartialEq for DataLoaderRequest { - fn eq(&self, other: &Self) -> bool { - let mut hasher_self = DefaultHasher::new(); - self.hash(&mut hasher_self); - let hash_self = hasher_self.finish(); + fn eq(&self, other: &Self) -> bool { + let mut hasher_self = DefaultHasher::new(); + self.hash(&mut hasher_self); + let hash_self = hasher_self.finish(); - let mut hasher_other = DefaultHasher::new(); - other.hash(&mut hasher_other); - let hash_other = hasher_other.finish(); + let mut hasher_other = DefaultHasher::new(); + other.hash(&mut hasher_other); + let hash_other = hasher_other.finish(); - hash_self == hash_other - } + hash_self == hash_other + } } impl Eq for DataLoaderRequest {} impl Clone for DataLoaderRequest { - fn clone(&self) -> Self { - let req = self.0.try_clone().unwrap_or_else(|| { - let mut req = reqwest::Request::new(self.0.method().clone(), self.0.url().clone()); - req.headers_mut().extend(self.0.headers().clone()); - req - }); - - DataLoaderRequest(req, self.1.clone()) - } + fn clone(&self) -> Self { + let req = self.0.try_clone().unwrap_or_else(|| { + let mut req = reqwest::Request::new(self.0.method().clone(), self.0.url().clone()); + req.headers_mut().extend(self.0.headers().clone()); + req + }); + + DataLoaderRequest(req, self.1.clone()) + } } impl Deref for DataLoaderRequest { - type Target = reqwest::Request; + type Target = reqwest::Request; - fn deref(&self) -> &Self::Target { - &self.0 - } + fn deref(&self) -> &Self::Target { + &self.0 + } } #[cfg(test)] mod tests { - use hyper::header::{HeaderName, HeaderValue}; - - use super::*; - fn create_request_with_headers(url: &str, headers: Vec<(&str, &str)>) -> reqwest::Request { - let mut req = reqwest::Request::new(reqwest::Method::GET, url.parse().unwrap()); - for (name, value) in headers { - req.headers_mut().insert( - name.parse::().unwrap(), - value.parse::().unwrap(), - ); - } - req - } - - fn create_endpoint_key( - url: &str, - headers: Vec<(&str, &str)>, - hash_key_headers: BTreeSet, - ) -> DataLoaderRequest { - DataLoaderRequest::new(create_request_with_headers(url, headers), hash_key_headers) 
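// (editorial aside — not part of this diff) Note the equality contract
// established above: `PartialEq` is literally "same `DefaultHasher` output",
// so two requests are equal exactly when their URL, body bytes, and the
// headers listed in the `BTreeSet` hash identically. The HTTP method is
// deliberately ignored (see `test_different_http_methods` below), and a
// header outside the tracked set cannot distinguish keys, e.g.:
//
//     // same URL, same tracked header "a"; header "c" is not tracked,
//     // so these two keys deduplicate into a single batched request
//     let a = create_endpoint_key(url, vec![("a", "1")], set(["a"]));
//     let b = create_endpoint_key(url, vec![("a", "1"), ("c", "3")], set(["a"]));
//     assert_eq!(a, b);
//
// (`create_endpoint_key` as in the test helpers below; `set` is shorthand
// for `BTreeSet::from` with owned `String`s.)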
- } - - #[test] - fn test_hash_endpoint_key() { - let endpoint_key_1 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); - let endpoint_key_2 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); - assert_eq!(endpoint_key_1, endpoint_key_2); - } - - #[test] - fn test_with_endpoint_key_with_headers() { - let endpoint_key_1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1"), ("b", "2")], - BTreeSet::from(["a".to_string(), "b".to_string()]), - ); - let endpoint_key_2 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1"), ("b", "2"), ("c", "3")], - BTreeSet::from(["a".to_string(), "b".to_string()]), - ); - assert_eq!(endpoint_key_1, endpoint_key_2); - } - - #[test] - fn test_with_endpoint_key_with_headers_ne() { - let endpoint_key_1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1"), ("b", "2"), ("c", "4")], - BTreeSet::from(["a".to_string(), "b".to_string(), "c".to_string()]), - ); - let endpoint_key_2 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1"), ("b", "2"), ("c", "3")], - BTreeSet::from(["a".to_string(), "b".to_string(), "c".to_string()]), - ); - assert_ne!(endpoint_key_1, endpoint_key_2); - } - #[test] - fn test_different_http_methods() { - let key1 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); - let req = reqwest::Request::new(reqwest::Method::POST, "http://localhost:8080".parse().unwrap()); - let key2 = DataLoaderRequest::new(req, BTreeSet::new()); - assert_eq!(key1, key2); - } - - #[test] - fn test_different_urls() { - let key1 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); - let key2 = create_endpoint_key("http://example.com:8080", vec![], BTreeSet::new()); - assert_ne!(key1, key2); - } - - #[test] - fn test_mismatched_header_names() { - let key1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1")], - BTreeSet::from(["a".to_string()]), - ); - let key2 = create_endpoint_key( - "http://localhost:8080", - vec![("b", "1")], - BTreeSet::from(["b".to_string()]), - ); - assert_ne!(key1, key2); - } - - #[test] - fn test_mismatched_header_values() { - let key1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1")], - BTreeSet::from(["a".to_string()]), - ); - let key2 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "2")], - BTreeSet::from(["a".to_string()]), - ); - assert_ne!(key1, key2); - } - - #[test] - fn test_differing_number_of_headers() { - let key1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1")], - BTreeSet::from(["a".to_string()]), - ); - let key2 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1"), ("b", "2")], - BTreeSet::from(["a".to_string(), "b".to_string()]), - ); - assert_ne!(key1, key2); - } - #[test] - fn test_clone_trait() { - let key1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1")], - BTreeSet::from(["a".to_string()]), - ); - let key2 = key1.clone(); - - // The cloned key should be equal to the original - assert_eq!(key1, key2); - } - - #[test] - fn test_partial_eq_trait() { - let key1 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1")], - BTreeSet::from(["a".to_string()]), - ); - let key2 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1")], - BTreeSet::from(["a".to_string()]), - ); - - // Both keys have the same data, so they should be equal - assert_eq!(key1, key2); - } - - #[test] - fn test_partial_eq_not_equal() { - let key1 = create_endpoint_key( - "http://localhost:8080", - 
vec![("a", "1"), ("b", "1")], - BTreeSet::from(["a".to_string(), "b".to_string()]), - ); - let key2 = create_endpoint_key( - "http://localhost:8080", - vec![("a", "1"), ("b", "1")], - BTreeSet::from(["a".to_string()]), - ); - - assert_ne!(key1, key2); - } + use hyper::header::{HeaderName, HeaderValue}; + + use super::*; + fn create_request_with_headers(url: &str, headers: Vec<(&str, &str)>) -> reqwest::Request { + let mut req = reqwest::Request::new(reqwest::Method::GET, url.parse().unwrap()); + for (name, value) in headers { + req.headers_mut().insert( + name.parse::().unwrap(), + value.parse::().unwrap(), + ); + } + req + } + + fn create_endpoint_key( + url: &str, + headers: Vec<(&str, &str)>, + hash_key_headers: BTreeSet, + ) -> DataLoaderRequest { + DataLoaderRequest::new(create_request_with_headers(url, headers), hash_key_headers) + } + + #[test] + fn test_hash_endpoint_key() { + let endpoint_key_1 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); + let endpoint_key_2 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); + assert_eq!(endpoint_key_1, endpoint_key_2); + } + + #[test] + fn test_with_endpoint_key_with_headers() { + let endpoint_key_1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", "2")], + BTreeSet::from(["a".to_string(), "b".to_string()]), + ); + let endpoint_key_2 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", "2"), ("c", "3")], + BTreeSet::from(["a".to_string(), "b".to_string()]), + ); + assert_eq!(endpoint_key_1, endpoint_key_2); + } + + #[test] + fn test_with_endpoint_key_with_headers_ne() { + let endpoint_key_1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", "2"), ("c", "4")], + BTreeSet::from(["a".to_string(), "b".to_string(), "c".to_string()]), + ); + let endpoint_key_2 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", "2"), ("c", "3")], + BTreeSet::from(["a".to_string(), "b".to_string(), "c".to_string()]), + ); + assert_ne!(endpoint_key_1, endpoint_key_2); + } + #[test] + fn test_different_http_methods() { + let key1 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); + let req = reqwest::Request::new( + reqwest::Method::POST, + "http://localhost:8080".parse().unwrap(), + ); + let key2 = DataLoaderRequest::new(req, BTreeSet::new()); + assert_eq!(key1, key2); + } + + #[test] + fn test_different_urls() { + let key1 = create_endpoint_key("http://localhost:8080", vec![], BTreeSet::new()); + let key2 = create_endpoint_key("http://example.com:8080", vec![], BTreeSet::new()); + assert_ne!(key1, key2); + } + + #[test] + fn test_mismatched_header_names() { + let key1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1")], + BTreeSet::from(["a".to_string()]), + ); + let key2 = create_endpoint_key( + "http://localhost:8080", + vec![("b", "1")], + BTreeSet::from(["b".to_string()]), + ); + assert_ne!(key1, key2); + } + + #[test] + fn test_mismatched_header_values() { + let key1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1")], + BTreeSet::from(["a".to_string()]), + ); + let key2 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "2")], + BTreeSet::from(["a".to_string()]), + ); + assert_ne!(key1, key2); + } + + #[test] + fn test_differing_number_of_headers() { + let key1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1")], + BTreeSet::from(["a".to_string()]), + ); + let key2 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", 
"2")], + BTreeSet::from(["a".to_string(), "b".to_string()]), + ); + assert_ne!(key1, key2); + } + #[test] + fn test_clone_trait() { + let key1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1")], + BTreeSet::from(["a".to_string()]), + ); + let key2 = key1.clone(); + + // The cloned key should be equal to the original + assert_eq!(key1, key2); + } + + #[test] + fn test_partial_eq_trait() { + let key1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1")], + BTreeSet::from(["a".to_string()]), + ); + let key2 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1")], + BTreeSet::from(["a".to_string()]), + ); + + // Both keys have the same data, so they should be equal + assert_eq!(key1, key2); + } + + #[test] + fn test_partial_eq_not_equal() { + let key1 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", "1")], + BTreeSet::from(["a".to_string(), "b".to_string()]), + ); + let key2 = create_endpoint_key( + "http://localhost:8080", + vec![("a", "1"), ("b", "1")], + BTreeSet::from(["a".to_string()]), + ); + + assert_ne!(key1, key2); + } } diff --git a/src/http/method.rs b/src/http/method.rs index b59b516afdf..34bd994b366 100644 --- a/src/http/method.rs +++ b/src/http/method.rs @@ -1,30 +1,32 @@ use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default, schemars::JsonSchema)] +#[derive( + Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Default, schemars::JsonSchema, +)] pub enum Method { - #[default] - GET, - POST, - PUT, - PATCH, - DELETE, - HEAD, - OPTIONS, - CONNECT, - TRACE, + #[default] + GET, + POST, + PUT, + PATCH, + DELETE, + HEAD, + OPTIONS, + CONNECT, + TRACE, } impl Method { - pub fn to_hyper(self) -> hyper::Method { - match self { - Method::GET => hyper::Method::GET, - Method::POST => hyper::Method::POST, - Method::PUT => hyper::Method::PUT, - Method::PATCH => hyper::Method::PATCH, - Method::DELETE => hyper::Method::DELETE, - Method::HEAD => hyper::Method::HEAD, - Method::OPTIONS => hyper::Method::OPTIONS, - Method::CONNECT => hyper::Method::CONNECT, - Method::TRACE => hyper::Method::TRACE, + pub fn to_hyper(self) -> hyper::Method { + match self { + Method::GET => hyper::Method::GET, + Method::POST => hyper::Method::POST, + Method::PUT => hyper::Method::PUT, + Method::PATCH => hyper::Method::PATCH, + Method::DELETE => hyper::Method::DELETE, + Method::HEAD => hyper::Method::HEAD, + Method::OPTIONS => hyper::Method::OPTIONS, + Method::CONNECT => hyper::Method::CONNECT, + Method::TRACE => hyper::Method::TRACE, + } } - } } diff --git a/src/http/request_context.rs b/src/http/request_context.rs index 0d7ad1a8d7f..7b62c1747e1 100644 --- a/src/http/request_context.rs +++ b/src/http/request_context.rs @@ -16,187 +16,193 @@ use crate::{grpc, EntityCache, EnvIO, HttpIO}; #[derive(Setters)] pub struct RequestContext { - // TODO: consider storing http clients where they are used i.e. 
expression and dataloaders - pub h_client: Arc, - // http2 only client is required for grpc in cases the server supports only http2 - // and the request will fail on protocol negotiation - // having separate client for now looks like the only way to do with reqwest - pub h2_client: Arc, - pub server: Server, - pub upstream: Upstream, - pub req_headers: HeaderMap, - pub http_data_loaders: Arc>>, - pub gql_data_loaders: Arc>>, - pub grpc_data_loaders: Arc>>, - pub min_max_age: Arc>>, - pub cache_public: Arc>>, - pub env_vars: Arc, - pub cache: Arc, + // TODO: consider storing http clients where they are used i.e. expression and dataloaders + pub h_client: Arc, + // http2 only client is required for grpc in cases the server supports only http2 + // and the request will fail on protocol negotiation + // having separate client for now looks like the only way to do with reqwest + pub h2_client: Arc, + pub server: Server, + pub upstream: Upstream, + pub req_headers: HeaderMap, + pub http_data_loaders: Arc>>, + pub gql_data_loaders: Arc>>, + pub grpc_data_loaders: Arc>>, + pub min_max_age: Arc>>, + pub cache_public: Arc>>, + pub env_vars: Arc, + pub cache: Arc, } impl RequestContext { - fn set_min_max_age_conc(&self, min_max_age: i32) { - *self.min_max_age.lock().unwrap() = Some(min_max_age); - } - pub fn get_min_max_age(&self) -> Option { - *self.min_max_age.lock().unwrap() - } - - pub fn set_cache_public_false(&self) { - *self.cache_public.lock().unwrap() = Some(false); - } - - pub fn is_cache_public(&self) -> Option { - *self.cache_public.lock().unwrap() - } - - pub fn set_min_max_age(&self, max_age: i32) { - let min_max_age_lock = self.get_min_max_age(); - match min_max_age_lock { - Some(min_max_age) if max_age < min_max_age => { - self.set_min_max_age_conc(max_age); - } - None => { - self.set_min_max_age_conc(max_age); - } - _ => {} - } - } - - pub fn set_cache_visibility(&self, cachability: &Option) { - if let Some(Cachability::Private) = cachability { - self.set_cache_public_false() - } - } - - pub fn set_cache_control(&self, cache_policy: CacheControl) { - if let Some(max_age) = cache_policy.max_age { - self.set_min_max_age(max_age.as_secs() as i32); - } - self.set_cache_visibility(&cache_policy.cachability); - if Some(Cachability::NoCache) == cache_policy.cachability { - self.set_min_max_age(-1); - } - } - - pub async fn cache_get(&self, key: &u64) -> Option { - self.cache.get(key).await.ok() - } - - #[allow(clippy::too_many_arguments)] - pub async fn cache_insert(&self, key: u64, value: ConstValue, ttl: NonZeroU64) -> Option { - self.cache.set(key, value, ttl).await.ok() - } - - pub fn is_batching_enabled(&self) -> bool { - self.upstream.batch.is_some() && (self.upstream.get_delay() >= 1 || self.upstream.get_max_size() >= 1) - } + fn set_min_max_age_conc(&self, min_max_age: i32) { + *self.min_max_age.lock().unwrap() = Some(min_max_age); + } + pub fn get_min_max_age(&self) -> Option { + *self.min_max_age.lock().unwrap() + } + + pub fn set_cache_public_false(&self) { + *self.cache_public.lock().unwrap() = Some(false); + } + + pub fn is_cache_public(&self) -> Option { + *self.cache_public.lock().unwrap() + } + + pub fn set_min_max_age(&self, max_age: i32) { + let min_max_age_lock = self.get_min_max_age(); + match min_max_age_lock { + Some(min_max_age) if max_age < min_max_age => { + self.set_min_max_age_conc(max_age); + } + None => { + self.set_min_max_age_conc(max_age); + } + _ => {} + } + } + + pub fn set_cache_visibility(&self, cachability: &Option) { + if let Some(Cachability::Private) 
= cachability { + self.set_cache_public_false() + } + } + + pub fn set_cache_control(&self, cache_policy: CacheControl) { + if let Some(max_age) = cache_policy.max_age { + self.set_min_max_age(max_age.as_secs() as i32); + } + self.set_cache_visibility(&cache_policy.cachability); + if Some(Cachability::NoCache) == cache_policy.cachability { + self.set_min_max_age(-1); + } + } + + pub async fn cache_get(&self, key: &u64) -> Option { + self.cache.get(key).await.ok() + } + + #[allow(clippy::too_many_arguments)] + pub async fn cache_insert( + &self, + key: u64, + value: ConstValue, + ttl: NonZeroU64, + ) -> Option { + self.cache.set(key, value, ttl).await.ok() + } + + pub fn is_batching_enabled(&self) -> bool { + self.upstream.batch.is_some() + && (self.upstream.get_delay() >= 1 || self.upstream.get_max_size() >= 1) + } } impl From<&AppContext> for RequestContext { - fn from(server_ctx: &AppContext) -> Self { - Self { - h_client: server_ctx.universal_http_client.clone(), - h2_client: server_ctx.http2_only_client.clone(), - server: server_ctx.blueprint.server.clone(), - upstream: server_ctx.blueprint.upstream.clone(), - req_headers: HeaderMap::new(), - http_data_loaders: server_ctx.http_data_loaders.clone(), - gql_data_loaders: server_ctx.gql_data_loaders.clone(), - cache: server_ctx.cache.clone(), - grpc_data_loaders: server_ctx.grpc_data_loaders.clone(), - min_max_age: Arc::new(Mutex::new(None)), - cache_public: Arc::new(Mutex::new(None)), - env_vars: server_ctx.env_vars.clone(), - } - } + fn from(server_ctx: &AppContext) -> Self { + Self { + h_client: server_ctx.universal_http_client.clone(), + h2_client: server_ctx.http2_only_client.clone(), + server: server_ctx.blueprint.server.clone(), + upstream: server_ctx.blueprint.upstream.clone(), + req_headers: HeaderMap::new(), + http_data_loaders: server_ctx.http_data_loaders.clone(), + gql_data_loaders: server_ctx.gql_data_loaders.clone(), + cache: server_ctx.cache.clone(), + grpc_data_loaders: server_ctx.grpc_data_loaders.clone(), + min_max_age: Arc::new(Mutex::new(None)), + cache_public: Arc::new(Mutex::new(None)), + env_vars: server_ctx.env_vars.clone(), + } + } } #[cfg(test)] mod test { - use std::sync::{Arc, Mutex}; - - use cache_control::Cachability; - use hyper::HeaderMap; - - use crate::blueprint::Server; - use crate::cli::cache::NativeChronoCache; - use crate::cli::{init_env, init_http, init_http2_only}; - use crate::config::{self, Batch}; - use crate::http::RequestContext; - - impl Default for RequestContext { - fn default() -> Self { - let crate::config::Config { server, upstream, .. } = crate::config::Config::default(); - //TODO: default is used only in tests. Drop default and move it to test. 
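// (editorial aside — not part of this diff) For orientation on the cache
// logic above: `set_cache_control` folds each upstream Cache-Control header
// into request-wide state — a `max-age` lowers the shared minimum TTL,
// `private` flips `cache_public` to `Some(false)`, and `no-cache` forces the
// minimum to -1 so the merged response is treated as uncacheable.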
- let server = Server::try_from(server).unwrap(); - - let h_client = Arc::new(init_http(&upstream)); - let h2_client = Arc::new(init_http2_only(&upstream.clone())); - RequestContext { - req_headers: HeaderMap::new(), - h_client, - h2_client, - server, - upstream, - http_data_loaders: Arc::new(vec![]), - gql_data_loaders: Arc::new(vec![]), - cache: Arc::new(NativeChronoCache::new()), - grpc_data_loaders: Arc::new(vec![]), - min_max_age: Arc::new(Mutex::new(None)), - cache_public: Arc::new(Mutex::new(None)), - env_vars: Arc::new(init_env()), - } - } - } - - #[test] - fn test_update_max_age_less_than_existing() { - let req_ctx = RequestContext::default(); - req_ctx.set_min_max_age(120); - req_ctx.set_min_max_age(60); - assert_eq!(req_ctx.get_min_max_age(), Some(60)); - } - - #[test] - fn test_update_max_age_greater_than_existing() { - let req_ctx = RequestContext::default(); - req_ctx.set_min_max_age(60); - req_ctx.set_min_max_age(120); - assert_eq!(req_ctx.get_min_max_age(), Some(60)); - } - - #[test] - fn test_update_max_age_no_existing_value() { - let req_ctx = RequestContext::default(); - req_ctx.set_min_max_age(120); - assert_eq!(req_ctx.get_min_max_age(), Some(120)); - } - - #[test] - fn test_update_cache_visibility_private() { - let req_ctx = RequestContext::default(); - req_ctx.set_cache_visibility(&Some(Cachability::Private)); - assert_eq!(req_ctx.is_cache_public(), Some(false)); - } - - #[test] - fn test_update_cache_visibility_public() { - let req_ctx: RequestContext = RequestContext::default(); - req_ctx.set_cache_visibility(&Some(Cachability::Public)); - assert_eq!(req_ctx.is_cache_public(), None); - } - - #[test] - fn test_is_batching_enabled_default() { - // create ctx with default batch - let config = config::Config::default(); - let mut upstream = config.upstream.clone(); - upstream.batch = Some(Batch::default()); - let server = Server::try_from(config.server.clone()).unwrap(); - - let req_ctx: RequestContext = RequestContext::default().upstream(upstream).server(server); - - assert!(req_ctx.is_batching_enabled()); - } + use std::sync::{Arc, Mutex}; + + use cache_control::Cachability; + use hyper::HeaderMap; + + use crate::blueprint::Server; + use crate::cli::cache::NativeChronoCache; + use crate::cli::{init_env, init_http, init_http2_only}; + use crate::config::{self, Batch}; + use crate::http::RequestContext; + + impl Default for RequestContext { + fn default() -> Self { + let crate::config::Config { server, upstream, .. } = crate::config::Config::default(); + //TODO: default is used only in tests. Drop default and move it to test. 
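// (editorial aside — hypothetical refactor, not part of this diff) On the
// TODO above: this `Default` impl already lives inside `#[cfg(test)] mod
// test`, so it never ships in release builds. Dropping the trait impl would
// amount to renaming it to an explicit constructor, e.g. a (hypothetical)
// `fn test_default() -> RequestContext` with the same body, called from the
// tests below instead of `RequestContext::default()`.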
+ let server = Server::try_from(server).unwrap(); + + let h_client = Arc::new(init_http(&upstream)); + let h2_client = Arc::new(init_http2_only(&upstream.clone())); + RequestContext { + req_headers: HeaderMap::new(), + h_client, + h2_client, + server, + upstream, + http_data_loaders: Arc::new(vec![]), + gql_data_loaders: Arc::new(vec![]), + cache: Arc::new(NativeChronoCache::new()), + grpc_data_loaders: Arc::new(vec![]), + min_max_age: Arc::new(Mutex::new(None)), + cache_public: Arc::new(Mutex::new(None)), + env_vars: Arc::new(init_env()), + } + } + } + + #[test] + fn test_update_max_age_less_than_existing() { + let req_ctx = RequestContext::default(); + req_ctx.set_min_max_age(120); + req_ctx.set_min_max_age(60); + assert_eq!(req_ctx.get_min_max_age(), Some(60)); + } + + #[test] + fn test_update_max_age_greater_than_existing() { + let req_ctx = RequestContext::default(); + req_ctx.set_min_max_age(60); + req_ctx.set_min_max_age(120); + assert_eq!(req_ctx.get_min_max_age(), Some(60)); + } + + #[test] + fn test_update_max_age_no_existing_value() { + let req_ctx = RequestContext::default(); + req_ctx.set_min_max_age(120); + assert_eq!(req_ctx.get_min_max_age(), Some(120)); + } + + #[test] + fn test_update_cache_visibility_private() { + let req_ctx = RequestContext::default(); + req_ctx.set_cache_visibility(&Some(Cachability::Private)); + assert_eq!(req_ctx.is_cache_public(), Some(false)); + } + + #[test] + fn test_update_cache_visibility_public() { + let req_ctx: RequestContext = RequestContext::default(); + req_ctx.set_cache_visibility(&Some(Cachability::Public)); + assert_eq!(req_ctx.is_cache_public(), None); + } + + #[test] + fn test_is_batching_enabled_default() { + // create ctx with default batch + let config = config::Config::default(); + let mut upstream = config.upstream.clone(); + upstream.batch = Some(Batch::default()); + let server = Server::try_from(config.server.clone()).unwrap(); + + let req_ctx: RequestContext = RequestContext::default().upstream(upstream).server(server); + + assert!(req_ctx.is_batching_enabled()); + } } diff --git a/src/http/request_handler.rs b/src/http/request_handler.rs index b8bd9c67d76..086794528f5 100644 --- a/src/http/request_handler.rs +++ b/src/http/request_handler.rs @@ -15,129 +15,138 @@ use crate::async_graphql_hyper::{GraphQLRequestLike, GraphQLResponse}; use crate::{EnvIO, HttpIO}; pub fn graphiql(req: &Request) -> Result> { - let query = req.uri().query(); - let endpoint = "/graphql"; - let endpoint = if let Some(query) = query { - if query.is_empty() { - Cow::Borrowed(endpoint) + let query = req.uri().query(); + let endpoint = "/graphql"; + let endpoint = if let Some(query) = query { + if query.is_empty() { + Cow::Borrowed(endpoint) + } else { + Cow::Owned(format!("{}?{}", endpoint, query)) + } } else { - Cow::Owned(format!("{}?{}", endpoint, query)) - } - } else { - Cow::Borrowed(endpoint) - }; + Cow::Borrowed(endpoint) + }; - log::info!("GraphiQL endpoint: {}", endpoint); - Ok(Response::new(Body::from(playground_source( - GraphQLPlaygroundConfig::new(&endpoint).title("Tailcall - GraphQL IDE"), - )))) + log::info!("GraphiQL endpoint: {}", endpoint); + Ok(Response::new(Body::from(playground_source( + GraphQLPlaygroundConfig::new(&endpoint).title("Tailcall - GraphQL IDE"), + )))) } fn not_found() -> Result> { - Ok(Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty())?) + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::empty())?) 
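// (editorial aside — not part of this diff) `graphiql` above preserves the
// incoming query string when building the playground's endpoint: a request
// to `/?apiKey=123` yields the endpoint `/graphql?apiKey=123`, while a bare
// `/` (or an empty query) leaves it as `/graphql`. The `Cow` avoids
// allocating in the common no-query case.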
} fn create_request_context( - req: &Request, - server_ctx: &AppContext, + req: &Request, + server_ctx: &AppContext, ) -> RequestContext { - let upstream = server_ctx.blueprint.upstream.clone(); - let allowed = upstream.get_allowed_headers(); - let headers = create_allowed_headers(req.headers(), &allowed); - RequestContext::from(server_ctx).req_headers(headers) + let upstream = server_ctx.blueprint.upstream.clone(); + let allowed = upstream.get_allowed_headers(); + let headers = create_allowed_headers(req.headers(), &allowed); + RequestContext::from(server_ctx).req_headers(headers) } fn update_cache_control_header( - response: GraphQLResponse, - server_ctx: &AppContext, - req_ctx: Arc, + response: GraphQLResponse, + server_ctx: &AppContext, + req_ctx: Arc, ) -> GraphQLResponse { - if server_ctx.blueprint.server.enable_cache_control_header { - let ttl = req_ctx.get_min_max_age().unwrap_or(0); - let cache_public_flag = req_ctx.is_cache_public().unwrap_or(true); - return response.set_cache_control(ttl, cache_public_flag); - } - response + if server_ctx.blueprint.server.enable_cache_control_header { + let ttl = req_ctx.get_min_max_age().unwrap_or(0); + let cache_public_flag = req_ctx.is_cache_public().unwrap_or(true); + return response.set_cache_control(ttl, cache_public_flag); + } + response } -pub fn update_response_headers(resp: &mut hyper::Response, server_ctx: &AppContext) { - if !server_ctx.blueprint.server.response_headers.is_empty() { - resp - .headers_mut() - .extend(server_ctx.blueprint.server.response_headers.clone()); - } +pub fn update_response_headers( + resp: &mut hyper::Response, + server_ctx: &AppContext, +) { + if !server_ctx.blueprint.server.response_headers.is_empty() { + resp.headers_mut() + .extend(server_ctx.blueprint.server.response_headers.clone()); + } } pub async fn graphql_request( - req: Request, - server_ctx: &AppContext, + req: Request, + server_ctx: &AppContext, ) -> Result> { - let req_ctx = Arc::new(create_request_context(&req, server_ctx)); - let bytes = hyper::body::to_bytes(req.into_body()).await?; - let request = serde_json::from_slice::(&bytes); - match request { - Ok(request) => { - let mut response = request.data(req_ctx.clone()).execute(&server_ctx.schema).await; - response = update_cache_control_header(response, server_ctx, req_ctx); - let mut resp = response.to_response()?; - update_response_headers(&mut resp, server_ctx); - Ok(resp) - } - Err(err) => { - log::error!( - "Failed to parse request: {}", - String::from_utf8(bytes.to_vec()).unwrap() - ); + let req_ctx = Arc::new(create_request_context(&req, server_ctx)); + let bytes = hyper::body::to_bytes(req.into_body()).await?; + let request = serde_json::from_slice::(&bytes); + match request { + Ok(request) => { + let mut response = request + .data(req_ctx.clone()) + .execute(&server_ctx.schema) + .await; + response = update_cache_control_header(response, server_ctx, req_ctx); + let mut resp = response.to_response()?; + update_response_headers(&mut resp, server_ctx); + Ok(resp) + } + Err(err) => { + log::error!( + "Failed to parse request: {}", + String::from_utf8(bytes.to_vec()).unwrap() + ); - let mut response = async_graphql::Response::default(); - let server_error = ServerError::new(format!("Unexpected GraphQL Request: {}", err), None); - response.errors = vec![server_error]; + let mut response = async_graphql::Response::default(); + let server_error = + ServerError::new(format!("Unexpected GraphQL Request: {}", err), None); + response.errors = vec![server_error]; - 
Ok(GraphQLResponse::from(response).to_response()?) + Ok(GraphQLResponse::from(response).to_response()?) + } } - } } fn create_allowed_headers(headers: &HeaderMap, allowed: &BTreeSet) -> HeaderMap { - let mut new_headers = HeaderMap::new(); - for (k, v) in headers.iter() { - if allowed.contains(k.as_str()) { - new_headers.insert(k, v.clone()); + let mut new_headers = HeaderMap::new(); + for (k, v) in headers.iter() { + if allowed.contains(k.as_str()) { + new_headers.insert(k, v.clone()); + } } - } - new_headers + new_headers } pub async fn handle_request( - req: Request, - state: Arc>, + req: Request, + state: Arc>, ) -> Result> { - match *req.method() { - hyper::Method::POST - if state.blueprint.server.enable_showcase && req.uri().path().ends_with("/showcase/graphql") => - { - let server_ctx = match showcase_get_app_ctx::( - &req, - ( - state.universal_http_client.clone(), - DummyEnvIO, - None, - state.cache.clone(), - ), - ) - .await? - { - Ok(server_ctx) => server_ctx, - Err(res) => return Ok(res), - }; + match *req.method() { + hyper::Method::POST + if state.blueprint.server.enable_showcase + && req.uri().path().ends_with("/showcase/graphql") => + { + let server_ctx = match showcase_get_app_ctx::( + &req, + ( + state.universal_http_client.clone(), + DummyEnvIO, + None, + state.cache.clone(), + ), + ) + .await? + { + Ok(server_ctx) => server_ctx, + Err(res) => return Ok(res), + }; - graphql_request::(req, &server_ctx).await - } - hyper::Method::POST if req.uri().path().ends_with("/graphql") => { - graphql_request::(req, state.as_ref()).await + graphql_request::(req, &server_ctx).await + } + hyper::Method::POST if req.uri().path().ends_with("/graphql") => { + graphql_request::(req, state.as_ref()).await + } + hyper::Method::GET if state.blueprint.server.enable_graphiql => graphiql(&req), + _ => not_found(), } - hyper::Method::GET if state.blueprint.server.enable_graphiql => graphiql(&req), - _ => not_found(), - } } diff --git a/src/http/request_template.rs b/src/http/request_template.rs index d27080cc1e7..8f88a2a3e4b 100644 --- a/src/http/request_template.rs +++ b/src/http/request_template.rs @@ -18,577 +18,646 @@ use crate::path::PathString; /// To call `to_request` we need to provide a context. #[derive(Setters, Debug, Clone)] pub struct RequestTemplate { - pub root_url: Mustache, - pub query: Vec<(String, Mustache)>, - pub method: reqwest::Method, - pub headers: MustacheHeaders, - pub body_path: Option, - pub endpoint: Endpoint, - pub encoding: Encoding, + pub root_url: Mustache, + pub query: Vec<(String, Mustache)>, + pub method: reqwest::Method, + pub headers: MustacheHeaders, + pub body_path: Option, + pub endpoint: Endpoint, + pub encoding: Encoding, } impl RequestTemplate { - /// Creates a URL for the context - /// Fills in all the mustache templates with required values. 
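// (editorial aside — worked example, not part of this diff) For `create_url`
// below: a const root URL with no extra query params skips the merge
// entirely; otherwise empty values are dropped on both sides. With root_url
// "http://host/p?a=1&b=" and query [("c", "{{x}}")], rendering with x = 2
// yields "http://host/p?a=1&c=2" (`b` is dropped); if `x` renders to the
// empty string, `c` is dropped as well, leaving "http://host/p?a=1". Only
// when no pairs survive is the query cleared via `set_query(None)`.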
- fn create_url(&self, ctx: &C) -> anyhow::Result { - let mut url = url::Url::parse(self.root_url.render(ctx).as_str())?; - if self.query.is_empty() && self.root_url.is_const() { - return Ok(url); - } - let extra_qp = self.query.iter().filter_map(|(k, v)| { - let value = v.render(ctx); - if value.is_empty() { - None - } else { - Some((Cow::Borrowed(k.as_str()), Cow::Owned(value))) - } - }); - - let base_qp = url - .query_pairs() - .filter_map(|(k, v)| if v.is_empty() { None } else { Some((k, v)) }); - - let qp_string = base_qp - .chain(extra_qp) - .map(|(k, v)| format!("{}={}", k, v)) - .fold("".to_string(), |str, item| { - if str.is_empty() { - item + /// Creates a URL for the context + /// Fills in all the mustache templates with required values. + fn create_url(&self, ctx: &C) -> anyhow::Result { + let mut url = url::Url::parse(self.root_url.render(ctx).as_str())?; + if self.query.is_empty() && self.root_url.is_const() { + return Ok(url); + } + let extra_qp = self.query.iter().filter_map(|(k, v)| { + let value = v.render(ctx); + if value.is_empty() { + None + } else { + Some((Cow::Borrowed(k.as_str()), Cow::Owned(value))) + } + }); + + let base_qp = url + .query_pairs() + .filter_map(|(k, v)| if v.is_empty() { None } else { Some((k, v)) }); + + let qp_string = base_qp + .chain(extra_qp) + .map(|(k, v)| format!("{}={}", k, v)) + .fold("".to_string(), |str, item| { + if str.is_empty() { + item + } else { + format!("{}&{}", str, item) + } + }); + + if qp_string.is_empty() { + url.set_query(None); + Ok(url) } else { - format!("{}&{}", str, item) + url.set_query(Some(qp_string.as_str())); + Ok(url) } - }); - - if qp_string.is_empty() { - url.set_query(None); - Ok(url) - } else { - url.set_query(Some(qp_string.as_str())); - Ok(url) } - } - - /// Checks if the template has any mustache templates or not - /// Returns true if there are not templates - pub fn is_const(&self) -> bool { - self.root_url.is_const() - && self.body_path.as_ref().map_or(true, Mustache::is_const) - && self.query.iter().all(|(_, v)| v.is_const()) - && self.headers.iter().all(|(_, v)| v.is_const()) - } - - /// Creates a HeaderMap for the context - fn create_headers(&self, ctx: &C) -> HeaderMap { - let mut header_map = HeaderMap::new(); - - for (k, v) in &self.headers { - if let Ok(header_value) = HeaderValue::from_str(&v.render(ctx)) { - header_map.insert(k, header_value); - } + + /// Checks if the template has any mustache templates or not + /// Returns true if there are not templates + pub fn is_const(&self) -> bool { + self.root_url.is_const() + && self.body_path.as_ref().map_or(true, Mustache::is_const) + && self.query.iter().all(|(_, v)| v.is_const()) + && self.headers.iter().all(|(_, v)| v.is_const()) } - header_map - } - - /// Creates a Request for the given context - pub fn to_request(&self, ctx: &C) -> anyhow::Result { - // Create url - let url = self.create_url(ctx)?; - let method = self.method.clone(); - let mut req = reqwest::Request::new(method, url); - req = self.set_headers(req, ctx); - req = self.set_body(req, ctx)?; - - Ok(req) - } - - /// Sets the body for the request - fn set_body( - &self, - mut req: reqwest::Request, - ctx: &C, - ) -> anyhow::Result { - if let Some(body_path) = &self.body_path { - match &self.encoding { - Encoding::ApplicationJson => { - req.body_mut().replace(body_path.render(ctx).into()); + /// Creates a HeaderMap for the context + fn create_headers(&self, ctx: &C) -> HeaderMap { + let mut header_map = HeaderMap::new(); + + for (k, v) in &self.headers { + if let Ok(header_value) = 
HeaderValue::from_str(&v.render(ctx)) { + header_map.insert(k, header_value); + } } - Encoding::ApplicationXWwwFormUrlencoded => { - // TODO: this is a performance bottleneck - // We first encode everything to string and then back to form-urlencoded - let body: String = body_path.render(ctx); - let form_data = match serde_json::from_str::(&body) { - Ok(deserialized_data) => serde_urlencoded::to_string(deserialized_data)?, - Err(_) => body, - }; - - req.body_mut().replace(form_data.into()); + + header_map + } + + /// Creates a Request for the given context + pub fn to_request( + &self, + ctx: &C, + ) -> anyhow::Result { + // Create url + let url = self.create_url(ctx)?; + let method = self.method.clone(); + let mut req = reqwest::Request::new(method, url); + req = self.set_headers(req, ctx); + req = self.set_body(req, ctx)?; + + Ok(req) + } + + /// Sets the body for the request + fn set_body( + &self, + mut req: reqwest::Request, + ctx: &C, + ) -> anyhow::Result { + if let Some(body_path) = &self.body_path { + match &self.encoding { + Encoding::ApplicationJson => { + req.body_mut().replace(body_path.render(ctx).into()); + } + Encoding::ApplicationXWwwFormUrlencoded => { + // TODO: this is a performance bottleneck + // We first encode everything to string and then back to form-urlencoded + let body: String = body_path.render(ctx); + let form_data = match serde_json::from_str::(&body) { + Ok(deserialized_data) => serde_urlencoded::to_string(deserialized_data)?, + Err(_) => body, + }; + + req.body_mut().replace(form_data.into()); + } + } + } + Ok(req) + } + + /// Sets the headers for the request + fn set_headers( + &self, + mut req: reqwest::Request, + ctx: &C, + ) -> reqwest::Request { + let headers = self.create_headers(ctx); + if !headers.is_empty() { + req.headers_mut().extend(headers); } - } + + let headers = req.headers_mut(); + // We want to set the header value based on encoding + headers.insert( + reqwest::header::CONTENT_TYPE, + match self.encoding { + Encoding::ApplicationJson => HeaderValue::from_static("application/json"), + Encoding::ApplicationXWwwFormUrlencoded => { + HeaderValue::from_static("application/x-www-form-urlencoded") + } + }, + ); + + headers.extend(ctx.headers().to_owned()); + req } - Ok(req) - } - - /// Sets the headers for the request - fn set_headers(&self, mut req: reqwest::Request, ctx: &C) -> reqwest::Request { - let headers = self.create_headers(ctx); - if !headers.is_empty() { - req.headers_mut().extend(headers); + + pub fn new(root_url: &str) -> anyhow::Result { + Ok(Self { + root_url: Mustache::parse(root_url)?, + query: Default::default(), + method: reqwest::Method::GET, + headers: Default::default(), + body_path: Default::default(), + endpoint: Endpoint::new(root_url.to_string()), + encoding: Default::default(), + }) } - let headers = req.headers_mut(); - // We want to set the header value based on encoding - headers.insert( - reqwest::header::CONTENT_TYPE, - match self.encoding { - Encoding::ApplicationJson => HeaderValue::from_static("application/json"), - Encoding::ApplicationXWwwFormUrlencoded => HeaderValue::from_static("application/x-www-form-urlencoded"), - }, - ); - - headers.extend(ctx.headers().to_owned()); - req - } - - pub fn new(root_url: &str) -> anyhow::Result { - Ok(Self { - root_url: Mustache::parse(root_url)?, - query: Default::default(), - method: reqwest::Method::GET, - headers: Default::default(), - body_path: Default::default(), - endpoint: Endpoint::new(root_url.to_string()), - encoding: Default::default(), - }) - } - - pub fn 
form_encoded_url(url: &str) -> anyhow::Result { - Ok(Self::new(url)?.encoding(Encoding::ApplicationXWwwFormUrlencoded)) - } + pub fn form_encoded_url(url: &str) -> anyhow::Result { + Ok(Self::new(url)?.encoding(Encoding::ApplicationXWwwFormUrlencoded)) + } } impl TryFrom for RequestTemplate { - type Error = anyhow::Error; - fn try_from(endpoint: Endpoint) -> anyhow::Result { - let path = Mustache::parse(endpoint.path.as_str())?; - let query = endpoint - .query - .iter() - .map(|(k, v)| Ok((k.to_owned(), Mustache::parse(v.as_str())?))) - .collect::>>()?; - let method = endpoint.method.clone().to_hyper(); - let headers = endpoint - .headers - .iter() - .map(|(k, v)| Ok((k.clone(), Mustache::parse(v.to_str()?)?))) - .collect::>>()?; - - let body = if let Some(body) = &endpoint.body { - Some(Mustache::parse(body.as_str())?) - } else { - None - }; - let encoding = endpoint.encoding.clone(); - - Ok(Self { root_url: path, query, method, headers, body_path: body, endpoint, encoding }) - } + type Error = anyhow::Error; + fn try_from(endpoint: Endpoint) -> anyhow::Result { + let path = Mustache::parse(endpoint.path.as_str())?; + let query = endpoint + .query + .iter() + .map(|(k, v)| Ok((k.to_owned(), Mustache::parse(v.as_str())?))) + .collect::>>()?; + let method = endpoint.method.clone().to_hyper(); + let headers = endpoint + .headers + .iter() + .map(|(k, v)| Ok((k.clone(), Mustache::parse(v.to_str()?)?))) + .collect::>>()?; + + let body = if let Some(body) = &endpoint.body { + Some(Mustache::parse(body.as_str())?) + } else { + None + }; + let encoding = endpoint.encoding.clone(); + + Ok(Self { + root_url: path, + query, + method, + headers, + body_path: body, + endpoint, + encoding, + }) + } } #[cfg(test)] mod tests { - use std::borrow::Cow; - - use derive_setters::Setters; - use hyper::header::HeaderName; - use hyper::HeaderMap; - use pretty_assertions::assert_eq; - use serde_json::json; - - use super::RequestTemplate; - use crate::has_headers::HasHeaders; - use crate::mustache::Mustache; - use crate::path::PathString; - - #[derive(Setters)] - struct Context { - pub value: serde_json::Value, - pub headers: HeaderMap, - } - - impl Default for Context { - fn default() -> Self { - Self { value: serde_json::Value::Null, headers: HeaderMap::new() } + use std::borrow::Cow; + + use derive_setters::Setters; + use hyper::header::HeaderName; + use hyper::HeaderMap; + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::RequestTemplate; + use crate::has_headers::HasHeaders; + use crate::mustache::Mustache; + use crate::path::PathString; + + #[derive(Setters)] + struct Context { + pub value: serde_json::Value, + pub headers: HeaderMap, } - } - impl crate::path::PathString for Context { - fn path_string>(&self, parts: &[T]) -> Option> { - self.value.path_string(parts) + + impl Default for Context { + fn default() -> Self { + Self { value: serde_json::Value::Null, headers: HeaderMap::new() } + } } - } - impl crate::has_headers::HasHeaders for Context { - fn headers(&self) -> &HeaderMap { - &self.headers + impl crate::path::PathString for Context { + fn path_string>(&self, parts: &[T]) -> Option> { + self.value.path_string(parts) + } } - } - - impl RequestTemplate { - fn to_body(&self, ctx: &C) -> anyhow::Result { - let body = self - .to_request(ctx)? 
- .body() - .and_then(|a| a.as_bytes()) - .map(|a| a.to_vec()) - .unwrap_or_default(); - - Ok(std::str::from_utf8(&body)?.to_string()) + impl crate::has_headers::HasHeaders for Context { + fn headers(&self) -> &HeaderMap { + &self.headers + } } - } - #[test] - fn test_url() { - let tmpl = RequestTemplate::new("http://localhost:3000/").unwrap(); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/"); - } - - #[test] - fn test_url_path() { - let tmpl = RequestTemplate::new("http://localhost:3000/foo/bar").unwrap(); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar"); - } - - #[test] - fn test_url_path_template() { - let tmpl = RequestTemplate::new("http://localhost:3000/foo/{{bar.baz}}").unwrap(); - let ctx = Context::default().value(json!({ - "bar": { - "baz": "bar" - } - })); - - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar"); - } - #[test] - fn test_url_path_template_multi() { - let tmpl = RequestTemplate::new("http://localhost:3000/foo/{{bar.baz}}/boozes/{{bar.booz}}").unwrap(); - let ctx = Context::default().value(json!({ - "bar": { - "baz": "bar", - "booz": 1 - } - })); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar/boozes/1"); - } - #[test] - fn test_url_query_params() { - let query = vec![ - ("foo".to_string(), Mustache::parse("0").unwrap()), - ("bar".to_string(), Mustache::parse("1").unwrap()), - ("baz".to_string(), Mustache::parse("2").unwrap()), - ]; - let tmpl = RequestTemplate::new("http://localhost:3000").unwrap().query(query); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/?foo=0&bar=1&baz=2"); - } - #[test] - fn test_url_query_params_template() { - let query = vec![ - ("foo".to_string(), Mustache::parse("0").unwrap()), - ("bar".to_string(), Mustache::parse("{{bar.id}}").unwrap()), - ("baz".to_string(), Mustache::parse("{{baz.id}}").unwrap()), - ]; - let tmpl = RequestTemplate::new("http://localhost:3000/").unwrap().query(query); - let ctx = Context::default().value(json!({ - "bar": { - "id": 1 - }, - "baz": { - "id": 2 - } - })); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/?foo=0&bar=1&baz=2"); - } - - #[test] - fn test_headers() { - let headers = vec![ - (HeaderName::from_static("foo"), Mustache::parse("foo").unwrap()), - (HeaderName::from_static("bar"), Mustache::parse("bar").unwrap()), - (HeaderName::from_static("baz"), Mustache::parse("baz").unwrap()), - ]; - let tmpl = RequestTemplate::new("http://localhost:3000").unwrap().headers(headers); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.headers().get("foo").unwrap(), "foo"); - assert_eq!(req.headers().get("bar").unwrap(), "bar"); - assert_eq!(req.headers().get("baz").unwrap(), "baz"); - } - #[test] - fn test_header_template() { - let headers = vec![ - (HeaderName::from_static("foo"), Mustache::parse("0").unwrap()), - (HeaderName::from_static("bar"), Mustache::parse("{{bar.id}}").unwrap()), - (HeaderName::from_static("baz"), Mustache::parse("{{baz.id}}").unwrap()), - ]; - let tmpl = RequestTemplate::new("http://localhost:3000").unwrap().headers(headers); - let ctx = Context::default().value(json!({ - "bar": { - "id": 1 - }, - "baz": { 
- "id": 2 - } - })); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.headers().get("foo").unwrap(), "0"); - assert_eq!(req.headers().get("bar").unwrap(), "1"); - assert_eq!(req.headers().get("baz").unwrap(), "2"); - } - #[test] - fn test_header_encoding_application_json() { - let tmpl = RequestTemplate::new("http://localhost:3000") - .unwrap() - .encoding(crate::config::Encoding::ApplicationJson); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.headers().get("Content-Type").unwrap(), "application/json"); - } - #[test] - fn test_header_encoding_application_x_www_form_urlencoded() { - let tmpl = RequestTemplate::new("http://localhost:3000") - .unwrap() - .encoding(crate::config::Encoding::ApplicationXWwwFormUrlencoded); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!( - req.headers().get("Content-Type").unwrap(), - "application/x-www-form-urlencoded" - ); - } - #[test] - fn test_method() { - let tmpl = RequestTemplate::new("http://localhost:3000") - .unwrap() - .method(reqwest::Method::POST); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.method(), reqwest::Method::POST); - } - #[test] - fn test_body() { - let tmpl = RequestTemplate::new("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse("foo").unwrap())); - let ctx = Context::default(); - let body = tmpl.to_body(&ctx).unwrap(); - assert_eq!(body, "foo"); - } - #[test] - fn test_body_template() { - let tmpl = RequestTemplate::new("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse("{{foo.bar}}").unwrap())); - let ctx = Context::default().value(json!({ - "foo": { - "bar": "baz" - } - })); - let body = tmpl.to_body(&ctx).unwrap(); - assert_eq!(body, "baz"); - } - #[test] - fn test_body_encoding_application_json() { - let tmpl = RequestTemplate::new("http://localhost:3000") - .unwrap() - .encoding(crate::config::Encoding::ApplicationJson) - .body_path(Some(Mustache::parse("{{foo.bar}}").unwrap())); - let ctx = Context::default().value(json!({ - "foo": { - "bar": "baz" - } - })); - let body = tmpl.to_body(&ctx).unwrap(); - assert_eq!(body, "baz"); - } - - mod endpoint { - use hyper::HeaderMap; - use serde_json::json; - use crate::http::request_template::tests::Context; - use crate::http::RequestTemplate; + impl RequestTemplate { + fn to_body(&self, ctx: &C) -> anyhow::Result { + let body = self + .to_request(ctx)? 
+ .body() + .and_then(|a| a.as_bytes()) + .map(|a| a.to_vec()) + .unwrap_or_default(); - #[test] - fn test_from_endpoint() { - let mut headers = HeaderMap::new(); - headers.insert("foo", "bar".parse().unwrap()); - let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/".to_string()) - .method(crate::http::Method::POST) - .headers(headers) - .body(Some("foo".into())); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.method(), reqwest::Method::POST); - assert_eq!(req.headers().get("foo").unwrap(), "bar"); - let body = req.body().unwrap().as_bytes().unwrap().to_owned(); - assert_eq!(body, "foo".as_bytes()); - assert_eq!(req.url().to_string(), "http://localhost:3000/"); + Ok(std::str::from_utf8(&body)?.to_string()) + } } #[test] - fn test_from_endpoint_template() { - let mut headers = HeaderMap::new(); - headers.insert("foo", "{{foo.header}}".parse().unwrap()); - let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/{{foo.bar}}".to_string()) - .method(crate::http::Method::POST) - .query(vec![("foo".to_string(), "{{foo.bar}}".to_string())]) - .headers(headers) - .body(Some("{{foo.bar}}".into())); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let ctx = Context::default().value(json!({ - "foo": { - "bar": "baz", - "header": "abc" - } - })); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.method(), reqwest::Method::POST); - assert_eq!(req.headers().get("foo").unwrap(), "abc"); - let body = req.body().unwrap().as_bytes().unwrap().to_owned(); - assert_eq!(body, "baz".as_bytes()); - assert_eq!(req.url().to_string(), "http://localhost:3000/baz?foo=baz"); + fn test_url() { + let tmpl = RequestTemplate::new("http://localhost:3000/").unwrap(); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.url().to_string(), "http://localhost:3000/"); } #[test] - fn test_from_endpoint_template_null_value() { - let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/?a={{args.a}}".to_string()); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/"); + fn test_url_path() { + let tmpl = RequestTemplate::new("http://localhost:3000/foo/bar").unwrap(); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar"); } #[test] - fn test_from_endpoint_template_with_query_null_value() { - let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/?a={{args.a}}&q=1".to_string()).query(vec![ - ("b".to_string(), "1".to_string()), - ("c".to_string(), "{{args.c}}".to_string()), - ]); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let ctx = Context::default(); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/?q=1&b=1"); + fn test_url_path_template() { + let tmpl = RequestTemplate::new("http://localhost:3000/foo/{{bar.baz}}").unwrap(); + let ctx = Context::default().value(json!({ + "bar": { + "baz": "bar" + } + })); + + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.url().to_string(), "http://localhost:3000/foo/bar"); + } + #[test] + fn test_url_path_template_multi() { + let tmpl = + RequestTemplate::new("http://localhost:3000/foo/{{bar.baz}}/boozes/{{bar.booz}}") + .unwrap(); + let ctx = 
Context::default().value(json!({ + "bar": { + "baz": "bar", + "booz": 1 + } + })); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.url().to_string(), + "http://localhost:3000/foo/bar/boozes/1" + ); + } + #[test] + fn test_url_query_params() { + let query = vec![ + ("foo".to_string(), Mustache::parse("0").unwrap()), + ("bar".to_string(), Mustache::parse("1").unwrap()), + ("baz".to_string(), Mustache::parse("2").unwrap()), + ]; + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .query(query); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.url().to_string(), + "http://localhost:3000/?foo=0&bar=1&baz=2" + ); + } + #[test] + fn test_url_query_params_template() { + let query = vec![ + ("foo".to_string(), Mustache::parse("0").unwrap()), + ("bar".to_string(), Mustache::parse("{{bar.id}}").unwrap()), + ("baz".to_string(), Mustache::parse("{{baz.id}}").unwrap()), + ]; + let tmpl = RequestTemplate::new("http://localhost:3000/") + .unwrap() + .query(query); + let ctx = Context::default().value(json!({ + "bar": { + "id": 1 + }, + "baz": { + "id": 2 + } + })); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.url().to_string(), + "http://localhost:3000/?foo=0&bar=1&baz=2" + ); } #[test] - fn test_from_endpoint_template_few_null_value() { - let endpoint = crate::endpoint::Endpoint::new( + fn test_headers() { + let headers = vec![ + ( + HeaderName::from_static("foo"), + Mustache::parse("foo").unwrap(), + ), + ( + HeaderName::from_static("bar"), + Mustache::parse("bar").unwrap(), + ), + ( + HeaderName::from_static("baz"), + Mustache::parse("baz").unwrap(), + ), + ]; + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .headers(headers); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.headers().get("foo").unwrap(), "foo"); + assert_eq!(req.headers().get("bar").unwrap(), "bar"); + assert_eq!(req.headers().get("baz").unwrap(), "baz"); + } + #[test] + fn test_header_template() { + let headers = vec![ + ( + HeaderName::from_static("foo"), + Mustache::parse("0").unwrap(), + ), + ( + HeaderName::from_static("bar"), + Mustache::parse("{{bar.id}}").unwrap(), + ), + ( + HeaderName::from_static("baz"), + Mustache::parse("{{baz.id}}").unwrap(), + ), + ]; + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .headers(headers); + let ctx = Context::default().value(json!({ + "bar": { + "id": 1 + }, + "baz": { + "id": 2 + } + })); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.headers().get("foo").unwrap(), "0"); + assert_eq!(req.headers().get("bar").unwrap(), "1"); + assert_eq!(req.headers().get("baz").unwrap(), "2"); + } + #[test] + fn test_header_encoding_application_json() { + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .encoding(crate::config::Encoding::ApplicationJson); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.headers().get("Content-Type").unwrap(), + "application/json" + ); + } + #[test] + fn test_header_encoding_application_x_www_form_urlencoded() { + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .encoding(crate::config::Encoding::ApplicationXWwwFormUrlencoded); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.headers().get("Content-Type").unwrap(), + "application/x-www-form-urlencoded" + ); + } + #[test] + fn test_method() { + let tmpl = 
RequestTemplate::new("http://localhost:3000") + .unwrap() + .method(reqwest::Method::POST); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.method(), reqwest::Method::POST); + } + #[test] + fn test_body() { + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse("foo").unwrap())); + let ctx = Context::default(); + let body = tmpl.to_body(&ctx).unwrap(); + assert_eq!(body, "foo"); + } + #[test] + fn test_body_template() { + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse("{{foo.bar}}").unwrap())); + let ctx = Context::default().value(json!({ + "foo": { + "bar": "baz" + } + })); + let body = tmpl.to_body(&ctx).unwrap(); + assert_eq!(body, "baz"); + } + #[test] + fn test_body_encoding_application_json() { + let tmpl = RequestTemplate::new("http://localhost:3000") + .unwrap() + .encoding(crate::config::Encoding::ApplicationJson) + .body_path(Some(Mustache::parse("{{foo.bar}}").unwrap())); + let ctx = Context::default().value(json!({ + "foo": { + "bar": "baz" + } + })); + let body = tmpl.to_body(&ctx).unwrap(); + assert_eq!(body, "baz"); + } + + mod endpoint { + use hyper::HeaderMap; + use serde_json::json; + + use crate::http::request_template::tests::Context; + use crate::http::RequestTemplate; + + #[test] + fn test_from_endpoint() { + let mut headers = HeaderMap::new(); + headers.insert("foo", "bar".parse().unwrap()); + let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/".to_string()) + .method(crate::http::Method::POST) + .headers(headers) + .body(Some("foo".into())); + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.method(), reqwest::Method::POST); + assert_eq!(req.headers().get("foo").unwrap(), "bar"); + let body = req.body().unwrap().as_bytes().unwrap().to_owned(); + assert_eq!(body, "foo".as_bytes()); + assert_eq!(req.url().to_string(), "http://localhost:3000/"); + } + #[test] + fn test_from_endpoint_template() { + let mut headers = HeaderMap::new(); + headers.insert("foo", "{{foo.header}}".parse().unwrap()); + let endpoint = + crate::endpoint::Endpoint::new("http://localhost:3000/{{foo.bar}}".to_string()) + .method(crate::http::Method::POST) + .query(vec![("foo".to_string(), "{{foo.bar}}".to_string())]) + .headers(headers) + .body(Some("{{foo.bar}}".into())); + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let ctx = Context::default().value(json!({ + "foo": { + "bar": "baz", + "header": "abc" + } + })); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.method(), reqwest::Method::POST); + assert_eq!(req.headers().get("foo").unwrap(), "abc"); + let body = req.body().unwrap().as_bytes().unwrap().to_owned(); + assert_eq!(body, "baz".as_bytes()); + assert_eq!(req.url().to_string(), "http://localhost:3000/baz?foo=baz"); + } + + #[test] + fn test_from_endpoint_template_null_value() { + let endpoint = + crate::endpoint::Endpoint::new("http://localhost:3000/?a={{args.a}}".to_string()); + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.url().to_string(), "http://localhost:3000/"); + } + + #[test] + fn test_from_endpoint_template_with_query_null_value() { + let endpoint = crate::endpoint::Endpoint::new( + "http://localhost:3000/?a={{args.a}}&q=1".to_string(), + ) + .query(vec![ + ("b".to_string(), 
"1".to_string()), + ("c".to_string(), "{{args.c}}".to_string()), + ]); + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let ctx = Context::default(); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.url().to_string(), "http://localhost:3000/?q=1&b=1"); + } + + #[test] + fn test_from_endpoint_template_few_null_value() { + let endpoint = crate::endpoint::Endpoint::new( "http://localhost:3000/{{args.b}}?a={{args.a}}&b={{args.b}}&c={{args.c}}&d={{args.d}}".to_string(), ); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let ctx = Context::default().value(json!({ - "args": { - "b": "foo", - "d": "bar" + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let ctx = Context::default().value(json!({ + "args": { + "b": "foo", + "d": "bar" + } + })); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.url().to_string(), + "http://localhost:3000/foo?b=foo&d=bar" + ); } - })); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/foo?b=foo&d=bar"); - } - #[test] - fn test_from_endpoint_template_few_null_value_mixed() { - let endpoint = crate::endpoint::Endpoint::new( + #[test] + fn test_from_endpoint_template_few_null_value_mixed() { + let endpoint = crate::endpoint::Endpoint::new( "http://localhost:3000/{{args.b}}?a={{args.a}}&b={{args.b}}&c={{args.c}}&d={{args.d}}".to_string(), ) .query(vec![ ("e".to_string(), "{{args.e}}".to_string()), ("f".to_string(), "{{args.f}}".to_string()), ]); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let ctx = Context::default().value(json!({ - "args": { - "b": "foo", - "d": "bar", - "f": "baz" + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let ctx = Context::default().value(json!({ + "args": { + "b": "foo", + "d": "bar", + "f": "baz" + } + })); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!( + req.url().to_string(), + "http://localhost:3000/foo?b=foo&d=bar&f=baz" + ); + } + #[test] + fn test_headers_forward() { + let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/".to_string()); + let tmpl = RequestTemplate::try_from(endpoint).unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("baz", "qux".parse().unwrap()); + let ctx = Context::default().headers(headers); + let req = tmpl.to_request(&ctx).unwrap(); + assert_eq!(req.headers().get("baz").unwrap(), "qux"); } - })); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.url().to_string(), "http://localhost:3000/foo?b=foo&d=bar&f=baz"); - } - #[test] - fn test_headers_forward() { - let endpoint = crate::endpoint::Endpoint::new("http://localhost:3000/".to_string()); - let tmpl = RequestTemplate::try_from(endpoint).unwrap(); - let mut headers = HeaderMap::new(); - headers.insert("baz", "qux".parse().unwrap()); - let ctx = Context::default().headers(headers); - let req = tmpl.to_request(&ctx).unwrap(); - assert_eq!(req.headers().get("baz").unwrap(), "qux"); } - } - mod form_encoded_url { - use serde_json::json; - - use crate::http::request_template::tests::Context; - use crate::http::RequestTemplate; - use crate::mustache::Mustache; - - #[test] - fn test_with_string() { - let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse("{{foo.bar}}").unwrap())); - let ctx = Context::default().value(json!({"foo": {"bar": "baz"}})); - let request_body = tmpl.to_body(&ctx); - let body = request_body.unwrap(); - assert_eq!(body, "baz"); - } - #[test] - fn test_with_json_template() { - let tmpl = 
RequestTemplate::form_encoded_url("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse(r#"{"foo": "{{baz}}"}"#).unwrap())); - let ctx = Context::default().value(json!({"baz": "baz"})); - let body = tmpl.to_body(&ctx).unwrap(); - assert_eq!(body, "foo=%7B%7Bbaz%7D%7D"); - } + mod form_encoded_url { + use serde_json::json; + + use crate::http::request_template::tests::Context; + use crate::http::RequestTemplate; + use crate::mustache::Mustache; + + #[test] + fn test_with_string() { + let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse("{{foo.bar}}").unwrap())); + let ctx = Context::default().value(json!({"foo": {"bar": "baz"}})); + let request_body = tmpl.to_body(&ctx); + let body = request_body.unwrap(); + assert_eq!(body, "baz"); + } + #[test] + fn test_with_json_template() { + let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse(r#"{"foo": "{{baz}}"}"#).unwrap())); + let ctx = Context::default().value(json!({"baz": "baz"})); + let body = tmpl.to_body(&ctx).unwrap(); + assert_eq!(body, "foo=%7B%7Bbaz%7D%7D"); + } - #[test] - fn test_with_json_body() { - let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse("{{foo}}").unwrap())); - let ctx = Context::default().value(json!({"foo": {"bar": "baz"}})); - let body = tmpl.to_body(&ctx).unwrap(); - assert_eq!(body, "bar=baz"); - } + #[test] + fn test_with_json_body() { + let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse("{{foo}}").unwrap())); + let ctx = Context::default().value(json!({"foo": {"bar": "baz"}})); + let body = tmpl.to_body(&ctx).unwrap(); + assert_eq!(body, "bar=baz"); + } - #[test] - fn test_with_json_body_nested() { - let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse("{{a}}").unwrap())); - let ctx = Context::default().value(json!({"a": {"special chars": "a !@#$%^&*()<>?:{}-=1[];',./"}})); - let a = tmpl.to_body(&ctx).unwrap(); - let e = "special+chars=a+%21%40%23%24%25%5E%26*%28%29%3C%3E%3F%3A%7B%7D-%3D1%5B%5D%3B%27%2C.%2F"; - assert_eq!(a, e); - } - #[test] - fn test_with_mustache_literal() { - let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") - .unwrap() - .body_path(Some(Mustache::parse(r#"{"foo": "bar"}"#).unwrap())); - let ctx = Context::default().value(json!({})); - let body = tmpl.to_body(&ctx).unwrap(); - assert_eq!(body, r#"foo=bar"#); + #[test] + fn test_with_json_body_nested() { + let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse("{{a}}").unwrap())); + let ctx = Context::default() + .value(json!({"a": {"special chars": "a !@#$%^&*()<>?:{}-=1[];',./"}})); + let a = tmpl.to_body(&ctx).unwrap(); + let e = "special+chars=a+%21%40%23%24%25%5E%26*%28%29%3C%3E%3F%3A%7B%7D-%3D1%5B%5D%3B%27%2C.%2F"; + assert_eq!(a, e); + } + #[test] + fn test_with_mustache_literal() { + let tmpl = RequestTemplate::form_encoded_url("http://localhost:3000") + .unwrap() + .body_path(Some(Mustache::parse(r#"{"foo": "bar"}"#).unwrap())); + let ctx = Context::default().value(json!({})); + let body = tmpl.to_body(&ctx).unwrap(); + assert_eq!(body, r#"foo=bar"#); + } } - } } diff --git a/src/http/response.rs b/src/http/response.rs index 6e1cb8e5408..057c629953b 100644 --- a/src/http/response.rs +++ b/src/http/response.rs @@ -6,41 +6,52 
@@ use crate::grpc::protobuf::ProtobufOperation;

 #[derive(Clone, Debug, Default, Setters)]
 pub struct Response<Body> {
-  pub status: reqwest::StatusCode,
-  pub headers: reqwest::header::HeaderMap,
-  pub body: Body,
+    pub status: reqwest::StatusCode,
+    pub headers: reqwest::header::HeaderMap,
+    pub body: Body,
 }

 impl Response<Bytes> {
-  pub async fn from_reqwest(resp: reqwest::Response) -> Result<Response<Bytes>> {
-    let status = resp.status();
-    let headers = resp.headers().to_owned();
-    let body = resp.bytes().await?;
-    Ok(Response { status, headers, body })
-  }
-  pub fn empty() -> Self {
-    Response { status: reqwest::StatusCode::OK, headers: reqwest::header::HeaderMap::default(), body: Bytes::new() }
-  }
+    pub async fn from_reqwest(resp: reqwest::Response) -> Result<Response<Bytes>> {
+        let status = resp.status();
+        let headers = resp.headers().to_owned();
+        let body = resp.bytes().await?;
+        Ok(Response { status, headers, body })
+    }
+    pub fn empty() -> Self {
+        Response {
+            status: reqwest::StatusCode::OK,
+            headers: reqwest::header::HeaderMap::default(),
+            body: Bytes::new(),
+        }
+    }

-  pub fn to_json(self) -> Result<Response<async_graphql::Value>> {
-    let mut resp = Response::default();
-    let body = serde_json::from_slice::<async_graphql::Value>(&self.body)?;
-    resp.body = body;
-    resp.status = self.status;
-    resp.headers = self.headers;
-    Ok(resp)
-  }
+    pub fn to_json(self) -> Result<Response<async_graphql::Value>> {
+        let mut resp = Response::default();
+        let body = serde_json::from_slice::<async_graphql::Value>(&self.body)?;
+        resp.body = body;
+        resp.status = self.status;
+        resp.headers = self.headers;
+        Ok(resp)
+    }

-  pub fn to_grpc_value(self, operation: &ProtobufOperation) -> Result<Response<async_graphql::Value>> {
-    let mut resp = Response::default();
-    let body = operation.convert_output(&self.body)?;
-    resp.body = body;
-    resp.status = self.status;
-    resp.headers = self.headers;
-    Ok(resp)
-  }
+    pub fn to_grpc_value(
+        self,
+        operation: &ProtobufOperation,
+    ) -> Result<Response<async_graphql::Value>> {
+        let mut resp = Response::default();
+        let body = operation.convert_output(&self.body)?;
+        resp.body = body;
+        resp.status = self.status;
+        resp.headers = self.headers;
+        Ok(resp)
+    }

-  pub fn to_resp_string(self) -> Result<Response<String>> {
-    Ok(Response::<String> { body: String::from_utf8(self.body.to_vec())?, status: self.status, headers: self.headers })
-  }
+    pub fn to_resp_string(self) -> Result<Response<String>> {
+        Ok(Response::<String> {
+            body: String::from_utf8(self.body.to_vec())?,
+            status: self.status,
+            headers: self.headers,
+        })
+    }
 }
diff --git a/src/http/showcase.rs b/src/http/showcase.rs
index c29b77ca668..a95a948f048 100644
--- a/src/http/showcase.rs
+++ b/src/http/showcase.rs
@@ -15,78 +15,84 @@ use crate::{EntityCache, EnvIO, FileIO, HttpIO};

 pub struct DummyFileIO;

 impl FileIO for DummyFileIO {
-  async fn write<'a>(&'a self, _file_path: &'a str, _content: &'a [u8]) -> anyhow::Result<()> {
-    Err(anyhow!("DummyFileIO"))
-  }
+    async fn write<'a>(&'a self, _file_path: &'a str, _content: &'a [u8]) -> anyhow::Result<()> {
+        Err(anyhow!("DummyFileIO"))
+    }

-  async fn read<'a>(&'a self, _file_path: &'a str) -> anyhow::Result<String> {
-    Err(anyhow!("DummyFileIO"))
-  }
+    async fn read<'a>(&'a self, _file_path: &'a str) -> anyhow::Result<String> {
+        Err(anyhow!("DummyFileIO"))
+    }
 }

 pub struct DummyEnvIO;

 impl EnvIO for DummyEnvIO {
-  fn get(&self, _key: &str) -> Option<String> {
-    None
-  }
+    fn get(&self, _key: &str) -> Option<String> {
+        None
+    }
 }

 pub async fn showcase_get_app_ctx<
-  T: DeserializeOwned + GraphQLRequestLike,
-  Http: HttpIO + Clone,
-  File: FileIO,
-  Env: EnvIO,
+    T: DeserializeOwned + GraphQLRequestLike,
+    Http: HttpIO + Clone,
+    File: FileIO,
+    Env: EnvIO,
 >(
-  req: &Request<Body>,
-  (http, env, file, cache): (Http, Env, Option<File>, Arc<EntityCache>),
+ req: &Request, + (http, env, file, cache): (Http, Env, Option, Arc), ) -> Result, Response>> { - let url = Url::parse(&req.uri().to_string())?; - let mut query = url.query_pairs(); + let url = Url::parse(&req.uri().to_string())?; + let mut query = url.query_pairs(); - let config_url = if let Some(pair) = query.find(|x| x.0 == "config") { - pair.1 - } else { - let mut response = async_graphql::Response::default(); - let server_error = ServerError::new("No Config URL specified", None); - response.errors = vec![server_error]; - return Ok(Err(GraphQLResponse::from(response).to_response()?)); - }; + let config_url = if let Some(pair) = query.find(|x| x.0 == "config") { + pair.1 + } else { + let mut response = async_graphql::Response::default(); + let server_error = ServerError::new("No Config URL specified", None); + response.errors = vec![server_error]; + return Ok(Err(GraphQLResponse::from(response).to_response()?)); + }; - let config = if let Some(file) = file { - let reader = ConfigReader::init(file, http.clone()); - reader.read(&[config_url]).await - } else { - let reader = ConfigReader::init(DummyFileIO, http.clone()); - reader.read(&[config_url]).await - }; + let config = if let Some(file) = file { + let reader = ConfigReader::init(file, http.clone()); + reader.read(&[config_url]).await + } else { + let reader = ConfigReader::init(DummyFileIO, http.clone()); + reader.read(&[config_url]).await + }; - let config = match config { - Ok(config) => config, - Err(e) => { - let mut response = async_graphql::Response::default(); - let server_error = if e.to_string() == "DummyFileIO" { - ServerError::new("Invalid Config URL specified", None) - } else { - ServerError::new(format!("{}", e), None) - }; - response.errors = vec![server_error]; - return Ok(Err(GraphQLResponse::from(response).to_response()?)); - } - }; + let config = match config { + Ok(config) => config, + Err(e) => { + let mut response = async_graphql::Response::default(); + let server_error = if e.to_string() == "DummyFileIO" { + ServerError::new("Invalid Config URL specified", None) + } else { + ServerError::new(format!("{}", e), None) + }; + response.errors = vec![server_error]; + return Ok(Err(GraphQLResponse::from(response).to_response()?)); + } + }; - let blueprint = match Blueprint::try_from(&config) { - Ok(blueprint) => blueprint, - Err(e) => { - let mut response = async_graphql::Response::default(); - let server_error = ServerError::new(format!("{}", e), None); - response.errors = vec![server_error]; - return Ok(Err(GraphQLResponse::from(response).to_response()?)); - } - }; + let blueprint = match Blueprint::try_from(&config) { + Ok(blueprint) => blueprint, + Err(e) => { + let mut response = async_graphql::Response::default(); + let server_error = ServerError::new(format!("{}", e), None); + response.errors = vec![server_error]; + return Ok(Err(GraphQLResponse::from(response).to_response()?)); + } + }; - let http = Arc::new(http); - let env = Arc::new(env); + let http = Arc::new(http); + let env = Arc::new(env); - Ok(Ok(AppContext::new(blueprint, http.clone(), http, env, cache))) + Ok(Ok(AppContext::new( + blueprint, + http.clone(), + http, + env, + cache, + ))) } diff --git a/src/javascript.rs b/src/javascript.rs index 0150664ccc0..41ed7d3d209 100644 --- a/src/javascript.rs +++ b/src/javascript.rs @@ -5,49 +5,53 @@ use mini_v8::{Error, MiniV8, Script}; // TODO: Performance optimizations // This function can be optimized quite heavily pub fn execute_js( - script: &str, - input: async_graphql::Value, - timeout: Option, + 
script: &str,
+    input: async_graphql::Value,
+    timeout: Option<Duration>,
 ) -> Result<async_graphql::Value, Error> {
-  let mv8 = MiniV8::new();
-  let source = create_source(script, input);
-  let value: String = mv8.eval(Script { source, timeout, origin: None })?;
-  let json = serde_json::from_str(value.as_str()).unwrap();
-  Ok(json)
+    let mv8 = MiniV8::new();
+    let source = create_source(script, input);
+    let value: String = mv8.eval(Script { source, timeout, origin: None })?;
+    let json = serde_json::from_str(value.as_str()).unwrap();
+    Ok(json)
 }

 fn create_source(script: &str, input: async_graphql::Value) -> String {
-  let template = "(function (ctx) {return JSON.stringify(--SCRIPT--)} )(--INPUT--);";
+    let template = "(function (ctx) {return JSON.stringify(--SCRIPT--)} )(--INPUT--);";

-  template
-    .replace("--SCRIPT--", script)
-    .replace("--INPUT--", input.to_string().as_str())
+    template
+        .replace("--SCRIPT--", script)
+        .replace("--INPUT--", input.to_string().as_str())
 }

 #[cfg(test)]
 #[test]
 fn test_json() {
-  let json = r#"
+    let json = r#"
 {
   "name": "John Doe",
   "age": 43
 }
 "#;
-  let json = serde_json::from_str(json).unwrap();
-  let script = "ctx.name";
-  let actual = execute_js(script, json, Some(Duration::from_secs(1))).unwrap();
-  let expected = async_graphql::Value::from("John Doe");
+    let json = serde_json::from_str(json).unwrap();
+    let script = "ctx.name";
+    let actual = execute_js(script, json, Some(Duration::from_secs(1))).unwrap();
+    let expected = async_graphql::Value::from("John Doe");

-  assert_eq!(actual, expected);
+    assert_eq!(actual, expected);
 }

 #[cfg(test)]
 #[test]
 fn test_timeout() {
-  let script = "(function () {while(true) {};})()";
-  let actual = execute_js(script, async_graphql::Value::Null, Some(Duration::from_millis(10)));
-  match actual {
-    Err(Error::Timeout) => {} // Success case
-    _ => panic!("Expected a Timeout error, but got {:?}", actual), // Failure case
-  }
+    let script = "(function () {while(true) {};})()";
+    let actual = execute_js(
+        script,
+        async_graphql::Value::Null,
+        Some(Duration::from_millis(10)),
+    );
+    match actual {
+        Err(Error::Timeout) => {} // Success case
+        _ => panic!("Expected a Timeout error, but got {:?}", actual), // Failure case
+    }
 }
diff --git a/src/json/json_like.rs b/src/json/json_like.rs
index a02d9321069..2172528afb8 100644
--- a/src/json/json_like.rs
+++ b/src/json/json_like.rs
@@ -3,342 +3,345 @@ use std::collections::HashMap;

 use async_graphql_value::ConstValue;

 pub trait JsonLike {
-  type Output;
-  fn as_array_ok(&self) -> Result<&Vec<Self::Output>, &str>;
-  fn as_str_ok(&self) -> Result<&str, &str>;
-  fn as_string_ok(&self) -> Result<&String, &str>;
-  fn as_i64_ok(&self) -> Result<i64, &str>;
-  fn as_u64_ok(&self) -> Result<u64, &str>;
-  fn as_f64_ok(&self) -> Result<f64, &str>;
-  fn as_bool_ok(&self) -> Result<bool, &str>;
-  fn as_null_ok(&self) -> Result<(), &str>;
-  fn as_option_ok(&self) -> Result<Option<&Self>, &str>;
-  fn get_path<T: AsRef<str>>(&self, path: &[T]) -> Option<&Self::Output>;
-  fn get_key(&self, path: &str) -> Option<&Self::Output>;
-  fn new(value: &Self::Output) -> &Self;
-  fn group_by<'a>(&'a self, path: &'a [String]) -> HashMap<String, Vec<&'a Self::Output>>;
+    type Output;
+    fn as_array_ok(&self) -> Result<&Vec<Self::Output>, &str>;
+    fn as_str_ok(&self) -> Result<&str, &str>;
+    fn as_string_ok(&self) -> Result<&String, &str>;
+    fn as_i64_ok(&self) -> Result<i64, &str>;
+    fn as_u64_ok(&self) -> Result<u64, &str>;
+    fn as_f64_ok(&self) -> Result<f64, &str>;
+    fn as_bool_ok(&self) -> Result<bool, &str>;
+    fn as_null_ok(&self) -> Result<(), &str>;
+    fn as_option_ok(&self) -> Result<Option<&Self>, &str>;
+    fn get_path<T: AsRef<str>>(&self, path: &[T]) -> Option<&Self::Output>;
+    fn get_key(&self, path: &str) -> 
Option<&Self::Output>; + fn new(value: &Self::Output) -> &Self; + fn group_by<'a>(&'a self, path: &'a [String]) -> HashMap>; } impl JsonLike for serde_json::Value { - type Output = serde_json::Value; - fn as_array_ok(&self) -> Result<&Vec, &str> { - self.as_array().ok_or("expected array") - } - fn as_str_ok(&self) -> Result<&str, &str> { - self.as_str().ok_or("expected str") - } - fn as_i64_ok(&self) -> Result { - self.as_i64().ok_or("expected i64") - } - fn as_u64_ok(&self) -> Result { - self.as_u64().ok_or("expected u64") - } - fn as_f64_ok(&self) -> Result { - self.as_f64().ok_or("expected f64") - } - fn as_bool_ok(&self) -> Result { - self.as_bool().ok_or("expected bool") - } - fn as_null_ok(&self) -> Result<(), &str> { - self.as_null().ok_or("expected null") - } - - fn as_option_ok(&self) -> Result, &str> { - match self { - serde_json::Value::Null => Ok(None), - _ => Ok(Some(self)), - } - } - - fn get_path>(&self, path: &[T]) -> Option<&Self::Output> { - let mut val = self; - for token in path { - val = match val { - serde_json::Value::Array(arr) => { - let index = token.as_ref().parse::().ok()?; - arr.get(index)? + type Output = serde_json::Value; + fn as_array_ok(&self) -> Result<&Vec, &str> { + self.as_array().ok_or("expected array") + } + fn as_str_ok(&self) -> Result<&str, &str> { + self.as_str().ok_or("expected str") + } + fn as_i64_ok(&self) -> Result { + self.as_i64().ok_or("expected i64") + } + fn as_u64_ok(&self) -> Result { + self.as_u64().ok_or("expected u64") + } + fn as_f64_ok(&self) -> Result { + self.as_f64().ok_or("expected f64") + } + fn as_bool_ok(&self) -> Result { + self.as_bool().ok_or("expected bool") + } + fn as_null_ok(&self) -> Result<(), &str> { + self.as_null().ok_or("expected null") + } + + fn as_option_ok(&self) -> Result, &str> { + match self { + serde_json::Value::Null => Ok(None), + _ => Ok(Some(self)), + } + } + + fn get_path>(&self, path: &[T]) -> Option<&Self::Output> { + let mut val = self; + for token in path { + val = match val { + serde_json::Value::Array(arr) => { + let index = token.as_ref().parse::().ok()?; + arr.get(index)? 
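+                    // Numeric path tokens index into arrays; a token that does not
+                    // parse as `usize` makes the whole lookup return `None`.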
+ } + serde_json::Value::Object(map) => map.get(token.as_ref())?, + _ => return None, + }; } - serde_json::Value::Object(map) => map.get(token.as_ref())?, - _ => return None, - }; + Some(val) } - Some(val) - } - fn new(value: &Self::Output) -> &Self { - value - } + fn new(value: &Self::Output) -> &Self { + value + } - fn get_key(&self, path: &str) -> Option<&Self::Output> { - match self { - serde_json::Value::Object(map) => map.get(path), - _ => None, + fn get_key(&self, path: &str) -> Option<&Self::Output> { + match self { + serde_json::Value::Object(map) => map.get(path), + _ => None, + } } - } - fn as_string_ok(&self) -> Result<&String, &str> { - match self { - serde_json::Value::String(s) => Ok(s), - _ => Err("expected string"), + fn as_string_ok(&self) -> Result<&String, &str> { + match self { + serde_json::Value::String(s) => Ok(s), + _ => Err("expected string"), + } } - } - fn group_by<'a>(&'a self, path: &'a [String]) -> HashMap> { - let src = gather_path_matches(self, path, vec![]); - group_by_key(src) - } + fn group_by<'a>(&'a self, path: &'a [String]) -> HashMap> { + let src = gather_path_matches(self, path, vec![]); + group_by_key(src) + } } impl JsonLike for async_graphql::Value { - type Output = async_graphql::Value; + type Output = async_graphql::Value; - fn as_array_ok(&self) -> Result<&Vec, &str> { - match self { - ConstValue::List(seq) => Ok(seq), - _ => Err("array"), + fn as_array_ok(&self) -> Result<&Vec, &str> { + match self { + ConstValue::List(seq) => Ok(seq), + _ => Err("array"), + } } - } - fn as_str_ok(&self) -> Result<&str, &str> { - match self { - ConstValue::String(s) => Ok(s), - _ => Err("str"), + fn as_str_ok(&self) -> Result<&str, &str> { + match self { + ConstValue::String(s) => Ok(s), + _ => Err("str"), + } } - } - fn as_i64_ok(&self) -> Result { - match self { - ConstValue::Number(n) => n.as_i64().ok_or("expected i64"), - _ => Err("i64"), + fn as_i64_ok(&self) -> Result { + match self { + ConstValue::Number(n) => n.as_i64().ok_or("expected i64"), + _ => Err("i64"), + } } - } - fn as_u64_ok(&self) -> Result { - match self { - ConstValue::Number(n) => n.as_u64().ok_or("expected u64"), - _ => Err("u64"), + fn as_u64_ok(&self) -> Result { + match self { + ConstValue::Number(n) => n.as_u64().ok_or("expected u64"), + _ => Err("u64"), + } } - } - fn as_f64_ok(&self) -> Result { - match self { - ConstValue::Number(n) => n.as_f64().ok_or("expected f64"), - _ => Err("f64"), + fn as_f64_ok(&self) -> Result { + match self { + ConstValue::Number(n) => n.as_f64().ok_or("expected f64"), + _ => Err("f64"), + } } - } - fn as_bool_ok(&self) -> Result { - match self { - ConstValue::Boolean(b) => Ok(*b), - _ => Err("bool"), + fn as_bool_ok(&self) -> Result { + match self { + ConstValue::Boolean(b) => Ok(*b), + _ => Err("bool"), + } } - } - fn as_null_ok(&self) -> Result<(), &str> { - match self { - ConstValue::Null => Ok(()), - _ => Err("null"), + fn as_null_ok(&self) -> Result<(), &str> { + match self { + ConstValue::Null => Ok(()), + _ => Err("null"), + } } - } - fn as_option_ok(&self) -> Result, &str> { - match self { - ConstValue::Null => Ok(None), - _ => Ok(Some(self)), + fn as_option_ok(&self) -> Result, &str> { + match self { + ConstValue::Null => Ok(None), + _ => Ok(Some(self)), + } } - } - fn get_path>(&self, path: &[T]) -> Option<&Self::Output> { - let mut val = self; - for token in path { - val = match val { - ConstValue::List(seq) => { - let index = token.as_ref().parse::().ok()?; - seq.get(index)? 
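+    // Path traversal for ConstValue mirrors the serde_json::Value version above:
+    // numeric tokens index lists, everything else is looked up as an object key.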
+ fn get_path>(&self, path: &[T]) -> Option<&Self::Output> { + let mut val = self; + for token in path { + val = match val { + ConstValue::List(seq) => { + let index = token.as_ref().parse::().ok()?; + seq.get(index)? + } + ConstValue::Object(map) => map.get(token.as_ref())?, + _ => return None, + }; } - ConstValue::Object(map) => map.get(token.as_ref())?, - _ => return None, - }; - } - Some(val) - } - - fn new(value: &Self::Output) -> &Self { - value - } - - fn get_key(&self, path: &str) -> Option<&Self::Output> { - match self { - ConstValue::Object(map) => map.get(&async_graphql::Name::new(path)), - _ => None, - } - } - fn as_string_ok(&self) -> Result<&String, &str> { - match self { - ConstValue::String(s) => Ok(s), - _ => Err("expected string"), - } - } - - fn group_by<'a>(&'a self, path: &'a [String]) -> HashMap> { - let src = gather_path_matches(self, path, vec![]); - group_by_key(src) - } + Some(val) + } + + fn new(value: &Self::Output) -> &Self { + value + } + + fn get_key(&self, path: &str) -> Option<&Self::Output> { + match self { + ConstValue::Object(map) => map.get(&async_graphql::Name::new(path)), + _ => None, + } + } + fn as_string_ok(&self) -> Result<&String, &str> { + match self { + ConstValue::String(s) => Ok(s), + _ => Err("expected string"), + } + } + + fn group_by<'a>(&'a self, path: &'a [String]) -> HashMap> { + let src = gather_path_matches(self, path, vec![]); + group_by_key(src) + } } // Highly micro-optimized and benchmarked version of get_path_all // Any further changes should be verified with benchmarks pub fn gather_path_matches<'a, J: JsonLike>( - root: &'a J, - path: &'a [String], - mut vector: Vec<(&'a J, &'a J)>, + root: &'a J, + path: &'a [String], + mut vector: Vec<(&'a J, &'a J)>, ) -> Vec<(&'a J, &'a J)> { - if let Ok(root) = root.as_array_ok() { - for value in root { - vector = gather_path_matches(J::new(value), path, vector); - } - } else if let Some((key, tail)) = path.split_first() { - if let Some(value) = root.get_key(key) { - if tail.is_empty() { - vector.push((J::new(value), root)); - } else { - vector = gather_path_matches(J::new(value), tail, vector); - } + if let Ok(root) = root.as_array_ok() { + for value in root { + vector = gather_path_matches(J::new(value), path, vector); + } + } else if let Some((key, tail)) = path.split_first() { + if let Some(value) = root.get_key(key) { + if tail.is_empty() { + vector.push((J::new(value), root)); + } else { + vector = gather_path_matches(J::new(value), tail, vector); + } + } } - } - vector + vector } pub fn group_by_key<'a, J: JsonLike>(src: Vec<(&'a J, &'a J)>) -> HashMap> { - let mut map: HashMap> = HashMap::new(); - for (key, value) in src { - // Need to handle number and string keys - let key_str = key - .as_string_ok() - .cloned() - .or_else(|_| key.as_f64_ok().map(|a| a.to_string())); - - if let Ok(key) = key_str { - if let Some(values) = map.get_mut(&key) { - values.push(value); - } else { - map.insert(key, vec![value]); - } - } - } - map + let mut map: HashMap> = HashMap::new(); + for (key, value) in src { + // Need to handle number and string keys + let key_str = key + .as_string_ok() + .cloned() + .or_else(|_| key.as_f64_ok().map(|a| a.to_string())); + + if let Ok(key) = key_str { + if let Some(values) = map.get_mut(&key) { + values.push(value); + } else { + map.insert(key, vec![value]); + } + } + } + map } #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; - use serde_json::json; - - use crate::json::group_by_key; - use crate::json::json_like::gather_path_matches; - - #[test] - fn 
test_gather_path_matches() { - let input = json!([ - {"id": "1"}, - {"id": "2"}, - {"id": "3"} - ]); - - let actual = serde_json::to_value(gather_path_matches(&input, &["id".into()], vec![])).unwrap(); - - let expected = json!( - [ - ["1", {"id": "1"}], - ["2", {"id": "2"}], - ["3", {"id": "3"}], - ] - ); - - assert_eq!(actual, expected) - } - - #[test] - fn test_gather_path_matches_nested() { - let input = json!({ - "data": [ - {"user": {"id": "1"}}, - {"user": {"id": "2"}}, - {"user": {"id": "3"}}, - {"user": [ - {"id": "4"}, - {"id": "5"} - ] - }, - ] - }); - - let actual = serde_json::to_value(gather_path_matches( - &input, - &["data".into(), "user".into(), "id".into()], - vec![], - )) - .unwrap(); - - let expected = json!( - [ - ["1", {"id": "1"}], - ["2", {"id": "2"}], - ["3", {"id": "3"}], - ["4", {"id": "4"}], - ["5", {"id": "5"}], - - ] - ); - - assert_eq!(actual, expected) - } - - #[test] - fn test_group_by_key() { - let arr = vec![ - (json!("1"), json!({"id": "1"})), - (json!("2"), json!({"id": "2"})), - (json!("2"), json!({"id": "2"})), - (json!("3"), json!({"id": "3"})), - ]; - let input: Vec<(&serde_json::Value, &serde_json::Value)> = arr.iter().map(|a| (&a.0, &a.1)).collect(); - - let actual = serde_json::to_value(group_by_key(input)).unwrap(); - - let expected = json!( - { - "1": [{"id": "1"}], - "2": [{"id": "2"}, {"id": "2"}], - "3": [{"id": "3"}], - } - ); - - assert_eq!(actual, expected) - } - - #[test] - fn test_group_by_numeric_key() { - let arr = vec![ - (json!(1), json!({"id": 1})), - (json!(2), json!({"id": 2})), - (json!(2), json!({"id": 2})), - (json!(3), json!({"id": 3})), - ]; - let input: Vec<(&serde_json::Value, &serde_json::Value)> = arr.iter().map(|a| (&a.0, &a.1)).collect(); - - let actual = serde_json::to_value(group_by_key(input)).unwrap(); - - let expected = json!( - { - "1": [{"id": 1}], - "2": [{"id": 2}, {"id": 2}], - "3": [{"id": 3}], - } - ); + use pretty_assertions::assert_eq; + use serde_json::json; + + use crate::json::group_by_key; + use crate::json::json_like::gather_path_matches; + + #[test] + fn test_gather_path_matches() { + let input = json!([ + {"id": "1"}, + {"id": "2"}, + {"id": "3"} + ]); + + let actual = + serde_json::to_value(gather_path_matches(&input, &["id".into()], vec![])).unwrap(); + + let expected = json!( + [ + ["1", {"id": "1"}], + ["2", {"id": "2"}], + ["3", {"id": "3"}], + ] + ); + + assert_eq!(actual, expected) + } + + #[test] + fn test_gather_path_matches_nested() { + let input = json!({ + "data": [ + {"user": {"id": "1"}}, + {"user": {"id": "2"}}, + {"user": {"id": "3"}}, + {"user": [ + {"id": "4"}, + {"id": "5"} + ] + }, + ] + }); + + let actual = serde_json::to_value(gather_path_matches( + &input, + &["data".into(), "user".into(), "id".into()], + vec![], + )) + .unwrap(); + + let expected = json!( + [ + ["1", {"id": "1"}], + ["2", {"id": "2"}], + ["3", {"id": "3"}], + ["4", {"id": "4"}], + ["5", {"id": "5"}], + + ] + ); + + assert_eq!(actual, expected) + } - assert_eq!(actual, expected) - } + #[test] + fn test_group_by_key() { + let arr = vec![ + (json!("1"), json!({"id": "1"})), + (json!("2"), json!({"id": "2"})), + (json!("2"), json!({"id": "2"})), + (json!("3"), json!({"id": "3"})), + ]; + let input: Vec<(&serde_json::Value, &serde_json::Value)> = + arr.iter().map(|a| (&a.0, &a.1)).collect(); + + let actual = serde_json::to_value(group_by_key(input)).unwrap(); + + let expected = json!( + { + "1": [{"id": "1"}], + "2": [{"id": "2"}, {"id": "2"}], + "3": [{"id": "3"}], + } + ); + + assert_eq!(actual, expected) + } + 
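+    // Usage sketch: callers reach these two helpers through `JsonLike::group_by`
+    // (this assumes `use crate::json::JsonLike;` is in scope). The grouping
+    // semantics exercised above then read as:
+    //
+    //     let input = json!([{"id": "1"}, {"id": "2"}, {"id": "2"}]);
+    //     let groups = input.group_by(&["id".to_string()]);
+    //     assert_eq!(groups.get("2").map(Vec::len), Some(2));
+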
+ #[test] + fn test_group_by_numeric_key() { + let arr = vec![ + (json!(1), json!({"id": 1})), + (json!(2), json!({"id": 2})), + (json!(2), json!({"id": 2})), + (json!(3), json!({"id": 3})), + ]; + let input: Vec<(&serde_json::Value, &serde_json::Value)> = + arr.iter().map(|a| (&a.0, &a.1)).collect(); + + let actual = serde_json::to_value(group_by_key(input)).unwrap(); + + let expected = json!( + { + "1": [{"id": 1}], + "2": [{"id": 2}, {"id": 2}], + "3": [{"id": 3}], + } + ); + + assert_eq!(actual, expected) + } } diff --git a/src/json/json_schema.rs b/src/json/json_schema.rs index d92b26f1922..cfd09bb3d95 100644 --- a/src/json/json_schema.rs +++ b/src/json/json_schema.rs @@ -8,252 +8,270 @@ use crate::valid::Valid; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, schemars::JsonSchema)] #[serde(rename = "schema")] pub enum JsonSchema { - Obj(HashMap), - Arr(Box), - Opt(Box), - Str, - Num, - Bool, + Obj(HashMap), + Arr(Box), + Opt(Box), + Str, + Num, + Bool, } impl From<[(&'static str, JsonSchema); L]> for JsonSchema { - fn from(fields: [(&'static str, JsonSchema); L]) -> Self { - let mut map = HashMap::new(); - for (name, schema) in fields { - map.insert(name.to_string(), schema); + fn from(fields: [(&'static str, JsonSchema); L]) -> Self { + let mut map = HashMap::new(); + for (name, schema) in fields { + map.insert(name.to_string(), schema); + } + JsonSchema::Obj(map) } - JsonSchema::Obj(map) - } } impl Default for JsonSchema { - fn default() -> Self { - JsonSchema::Obj(HashMap::new()) - } + fn default() -> Self { + JsonSchema::Obj(HashMap::new()) + } } impl JsonSchema { - // TODO: validate `JsonLike` instead of fixing on `async_graphql::Value` - pub fn validate(&self, value: &async_graphql::Value) -> Valid<(), &'static str> { - match self { - JsonSchema::Str => match value { - async_graphql::Value::String(_) => Valid::succeed(()), - _ => Valid::fail("expected string"), - }, - JsonSchema::Num => match value { - async_graphql::Value::Number(_) => Valid::succeed(()), - _ => Valid::fail("expected number"), - }, - JsonSchema::Bool => match value { - async_graphql::Value::Boolean(_) => Valid::succeed(()), - _ => Valid::fail("expected boolean"), - }, - JsonSchema::Arr(schema) => match value { - async_graphql::Value::List(list) => { - // TODO: add unit tests - Valid::from_iter(list.iter().enumerate(), |(i, item)| { - schema.validate(item).trace(i.to_string().as_str()) - }) - .unit() - } - _ => Valid::fail("expected array"), - }, - JsonSchema::Obj(fields) => { - let field_schema_list: Vec<(&String, &JsonSchema)> = fields.iter().collect(); - match value { - async_graphql::Value::Object(map) => Valid::from_iter(field_schema_list, |(name, schema)| { - if schema.is_required() { - if let Some(field_value) = map.get::(name.as_ref()) { - schema.validate(field_value).trace(name) - } else { - Valid::fail("expected field to be non-nullable").trace(name) - } - } else if let Some(field_value) = map.get::(name.as_ref()) { - schema.validate(field_value).trace(name) - } else { - Valid::succeed(()) + // TODO: validate `JsonLike` instead of fixing on `async_graphql::Value` + pub fn validate(&self, value: &async_graphql::Value) -> Valid<(), &'static str> { + match self { + JsonSchema::Str => match value { + async_graphql::Value::String(_) => Valid::succeed(()), + _ => Valid::fail("expected string"), + }, + JsonSchema::Num => match value { + async_graphql::Value::Number(_) => Valid::succeed(()), + _ => Valid::fail("expected number"), + }, + JsonSchema::Bool => match value { + 
async_graphql::Value::Boolean(_) => Valid::succeed(()), + _ => Valid::fail("expected boolean"), + }, + JsonSchema::Arr(schema) => match value { + async_graphql::Value::List(list) => { + // TODO: add unit tests + Valid::from_iter(list.iter().enumerate(), |(i, item)| { + schema.validate(item).trace(i.to_string().as_str()) + }) + .unit() + } + _ => Valid::fail("expected array"), + }, + JsonSchema::Obj(fields) => { + let field_schema_list: Vec<(&String, &JsonSchema)> = fields.iter().collect(); + match value { + async_graphql::Value::Object(map) => { + Valid::from_iter(field_schema_list, |(name, schema)| { + if schema.is_required() { + if let Some(field_value) = map.get::(name.as_ref()) { + schema.validate(field_value).trace(name) + } else { + Valid::fail("expected field to be non-nullable").trace(name) + } + } else if let Some(field_value) = map.get::(name.as_ref()) { + schema.validate(field_value).trace(name) + } else { + Valid::succeed(()) + } + }) + .unit() + } + _ => Valid::fail("expected object"), + } } - }) - .unit(), - _ => Valid::fail("expected object"), + JsonSchema::Opt(schema) => match value { + async_graphql::Value::Null => Valid::succeed(()), + _ => schema.validate(value), + }, } - } - JsonSchema::Opt(schema) => match value { - async_graphql::Value::Null => Valid::succeed(()), - _ => schema.validate(value), - }, } - } - // TODO: add unit tests - pub fn compare(&self, other: &JsonSchema, name: &str) -> Valid<(), String> { - match self { - JsonSchema::Obj(a) => { - if let JsonSchema::Obj(b) = other { - return Valid::from_iter(b.iter(), |(key, b)| { - Valid::from_option(a.get(key), format!("missing key: {}", key)).and_then(|a| a.compare(b, key)) - }) - .trace(name) - .unit(); - } else { - return Valid::fail("expected Object type".to_string()).trace(name); - } - } - JsonSchema::Arr(a) => { - if let JsonSchema::Arr(b) = other { - return a.compare(b, name); - } else { - return Valid::fail("expected Non repeatable type".to_string()).trace(name); - } - } - JsonSchema::Opt(a) => { - if let JsonSchema::Opt(b) = other { - return a.compare(b, name); - } else { - return Valid::fail("expected type to be required".to_string()).trace(name); - } - } - JsonSchema::Str => { - if other != self { - return Valid::fail(format!("expected String, got {:?}", other)).trace(name); - } - } - JsonSchema::Num => { - if other != self { - return Valid::fail(format!("expected Number, got {:?}", other)).trace(name); - } - } - JsonSchema::Bool => { - if other != self { - return Valid::fail(format!("expected Boolean, got {:?}", other)).trace(name); + // TODO: add unit tests + pub fn compare(&self, other: &JsonSchema, name: &str) -> Valid<(), String> { + match self { + JsonSchema::Obj(a) => { + if let JsonSchema::Obj(b) = other { + return Valid::from_iter(b.iter(), |(key, b)| { + Valid::from_option(a.get(key), format!("missing key: {}", key)) + .and_then(|a| a.compare(b, key)) + }) + .trace(name) + .unit(); + } else { + return Valid::fail("expected Object type".to_string()).trace(name); + } + } + JsonSchema::Arr(a) => { + if let JsonSchema::Arr(b) = other { + return a.compare(b, name); + } else { + return Valid::fail("expected Non repeatable type".to_string()).trace(name); + } + } + JsonSchema::Opt(a) => { + if let JsonSchema::Opt(b) = other { + return a.compare(b, name); + } else { + return Valid::fail("expected type to be required".to_string()).trace(name); + } + } + JsonSchema::Str => { + if other != self { + return Valid::fail(format!("expected String, got {:?}", other)).trace(name); + } + } + JsonSchema::Num => { 
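+                // Like the Str arm above, scalar schemas compare by plain equality;
+                // any mismatch is reported under the current field `name`.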
+ if other != self { + return Valid::fail(format!("expected Number, got {:?}", other)).trace(name); + } + } + JsonSchema::Bool => { + if other != self { + return Valid::fail(format!("expected Boolean, got {:?}", other)).trace(name); + } + } } - } + Valid::succeed(()) } - Valid::succeed(()) - } - pub fn optional(self) -> JsonSchema { - JsonSchema::Opt(Box::new(self)) - } + pub fn optional(self) -> JsonSchema { + JsonSchema::Opt(Box::new(self)) + } - pub fn is_optional(&self) -> bool { - matches!(self, JsonSchema::Opt(_)) - } + pub fn is_optional(&self) -> bool { + matches!(self, JsonSchema::Opt(_)) + } - pub fn is_required(&self) -> bool { - !self.is_optional() - } + pub fn is_required(&self) -> bool { + !self.is_optional() + } } impl TryFrom<&MessageDescriptor> for JsonSchema { - type Error = crate::valid::ValidationError; + type Error = crate::valid::ValidationError; - fn try_from(value: &MessageDescriptor) -> Result { - let mut map = std::collections::HashMap::new(); - let fields = value.fields(); + fn try_from(value: &MessageDescriptor) -> Result { + let mut map = std::collections::HashMap::new(); + let fields = value.fields(); - for field in fields { - let field_schema = JsonSchema::try_from(&field)?; + for field in fields { + let field_schema = JsonSchema::try_from(&field)?; - map.insert(field.name().to_string(), field_schema); - } + map.insert(field.name().to_string(), field_schema); + } - Ok(JsonSchema::Obj(map)) - } + Ok(JsonSchema::Obj(map)) + } } impl TryFrom<&FieldDescriptor> for JsonSchema { - type Error = crate::valid::ValidationError; + type Error = crate::valid::ValidationError; - fn try_from(value: &FieldDescriptor) -> Result { - let field_schema = match value.kind() { - Kind::Double => JsonSchema::Num, - Kind::Float => JsonSchema::Num, - Kind::Int32 => JsonSchema::Num, - Kind::Int64 => JsonSchema::Num, - Kind::Uint32 => JsonSchema::Num, - Kind::Uint64 => JsonSchema::Num, - Kind::Sint32 => JsonSchema::Num, - Kind::Sint64 => JsonSchema::Num, - Kind::Fixed32 => JsonSchema::Num, - Kind::Fixed64 => JsonSchema::Num, - Kind::Sfixed32 => JsonSchema::Num, - Kind::Sfixed64 => JsonSchema::Num, - Kind::Bool => JsonSchema::Bool, - Kind::String => JsonSchema::Str, - Kind::Bytes => JsonSchema::Str, - Kind::Message(msg) => JsonSchema::try_from(&msg)?, - Kind::Enum(_) => { - todo!("Enum") - } - }; - let field_schema = if value.cardinality().eq(&prost_reflect::Cardinality::Optional) { - JsonSchema::Opt(Box::new(field_schema)) - } else { - field_schema - }; - let field_schema = if value.is_list() { - JsonSchema::Arr(Box::new(field_schema)) - } else { - field_schema - }; + fn try_from(value: &FieldDescriptor) -> Result { + let field_schema = match value.kind() { + Kind::Double => JsonSchema::Num, + Kind::Float => JsonSchema::Num, + Kind::Int32 => JsonSchema::Num, + Kind::Int64 => JsonSchema::Num, + Kind::Uint32 => JsonSchema::Num, + Kind::Uint64 => JsonSchema::Num, + Kind::Sint32 => JsonSchema::Num, + Kind::Sint64 => JsonSchema::Num, + Kind::Fixed32 => JsonSchema::Num, + Kind::Fixed64 => JsonSchema::Num, + Kind::Sfixed32 => JsonSchema::Num, + Kind::Sfixed64 => JsonSchema::Num, + Kind::Bool => JsonSchema::Bool, + Kind::String => JsonSchema::Str, + Kind::Bytes => JsonSchema::Str, + Kind::Message(msg) => JsonSchema::try_from(&msg)?, + Kind::Enum(_) => { + todo!("Enum") + } + }; + let field_schema = if value + .cardinality() + .eq(&prost_reflect::Cardinality::Optional) + { + JsonSchema::Opt(Box::new(field_schema)) + } else { + field_schema + }; + let field_schema = if value.is_list() { + 
JsonSchema::Arr(Box::new(field_schema)) + } else { + field_schema + }; - Ok(field_schema) - } + Ok(field_schema) + } } #[cfg(test)] mod tests { - use async_graphql::Name; - use indexmap::IndexMap; + use async_graphql::Name; + use indexmap::IndexMap; - use crate::json::JsonSchema; - use crate::valid::Valid; + use crate::json::JsonSchema; + use crate::valid::Valid; - #[test] - fn test_validate_string() { - let schema = JsonSchema::Str; - let value = async_graphql::Value::String("hello".to_string()); - let result = schema.validate(&value); - assert_eq!(result, Valid::succeed(())); - } + #[test] + fn test_validate_string() { + let schema = JsonSchema::Str; + let value = async_graphql::Value::String("hello".to_string()); + let result = schema.validate(&value); + assert_eq!(result, Valid::succeed(())); + } - #[test] - fn test_validate_valid_object() { - let schema = JsonSchema::from([("name", JsonSchema::Str), ("age", JsonSchema::Num)]); - let value = async_graphql::Value::Object({ - let mut map = IndexMap::new(); - map.insert(Name::new("name"), async_graphql::Value::String("hello".to_string())); - map.insert(Name::new("age"), async_graphql::Value::Number(1.into())); - map - }); - let result = schema.validate(&value); - assert_eq!(result, Valid::succeed(())); - } + #[test] + fn test_validate_valid_object() { + let schema = JsonSchema::from([("name", JsonSchema::Str), ("age", JsonSchema::Num)]); + let value = async_graphql::Value::Object({ + let mut map = IndexMap::new(); + map.insert( + Name::new("name"), + async_graphql::Value::String("hello".to_string()), + ); + map.insert(Name::new("age"), async_graphql::Value::Number(1.into())); + map + }); + let result = schema.validate(&value); + assert_eq!(result, Valid::succeed(())); + } - #[test] - fn test_validate_invalid_object() { - let schema = JsonSchema::from([("name", JsonSchema::Str), ("age", JsonSchema::Num)]); - let value = async_graphql::Value::Object({ - let mut map = IndexMap::new(); - map.insert(Name::new("name"), async_graphql::Value::String("hello".to_string())); - map.insert(Name::new("age"), async_graphql::Value::String("1".to_string())); - map - }); - let result = schema.validate(&value); - assert_eq!(result, Valid::fail("expected number").trace("age")); - } + #[test] + fn test_validate_invalid_object() { + let schema = JsonSchema::from([("name", JsonSchema::Str), ("age", JsonSchema::Num)]); + let value = async_graphql::Value::Object({ + let mut map = IndexMap::new(); + map.insert( + Name::new("name"), + async_graphql::Value::String("hello".to_string()), + ); + map.insert( + Name::new("age"), + async_graphql::Value::String("1".to_string()), + ); + map + }); + let result = schema.validate(&value); + assert_eq!(result, Valid::fail("expected number").trace("age")); + } - #[test] - fn test_null_key() { - let schema = JsonSchema::from([("name", JsonSchema::Str.optional()), ("age", JsonSchema::Num)]); - let value = async_graphql::Value::Object({ - let mut map = IndexMap::new(); - map.insert(Name::new("age"), async_graphql::Value::Number(1.into())); - map - }); + #[test] + fn test_null_key() { + let schema = JsonSchema::from([ + ("name", JsonSchema::Str.optional()), + ("age", JsonSchema::Num), + ]); + let value = async_graphql::Value::Object({ + let mut map = IndexMap::new(); + map.insert(Name::new("age"), async_graphql::Value::Number(1.into())); + map + }); - let result = schema.validate(&value); - assert_eq!(result, Valid::succeed(())); - } + let result = schema.validate(&value); + assert_eq!(result, Valid::succeed(())); + } } diff --git 
a/src/lambda/concurrent.rs b/src/lambda/concurrent.rs index 10944623b6f..bd3face86de 100644 --- a/src/lambda/concurrent.rs +++ b/src/lambda/concurrent.rs @@ -8,49 +8,52 @@ use futures_util::{Future, StreamExt}; /// #[derive(Clone, Debug, Default)] pub enum Concurrent { - Parallel, - #[default] - Sequential, + Parallel, + #[default] + Sequential, } impl Concurrent { - pub async fn fold<A, B, F>( - &self, - iter: impl Iterator<Item = F>, - acc: B, - f: impl Fn(B, A) -> anyhow::Result<B>, - ) -> anyhow::Result<B> - where - F: Future<Output = A>, - { - match self { - Concurrent::Sequential => { - let mut output = acc; - for future in iter.into_iter() { - output = f(output, future.await)?; + pub async fn fold<A, B, F>( + &self, + iter: impl Iterator<Item = F>, + acc: B, + f: impl Fn(B, A) -> anyhow::Result<B>, + ) -> anyhow::Result<B> + where + F: Future<Output = A>, + { + match self { + Concurrent::Sequential => { + let mut output = acc; + for future in iter.into_iter() { + output = f(output, future.await)?; + } + Ok(output) + } + Concurrent::Parallel => { + let mut futures: FuturesUnordered<_> = iter.into_iter().collect(); + let mut output = acc; + while let Some(result) = futures.next().await { + output = f(output, result)?; + } + Ok(output) + } } - Ok(output) - } - Concurrent::Parallel => { - let mut futures: FuturesUnordered<_> = iter.into_iter().collect(); - let mut output = acc; - while let Some(result) = futures.next().await { - output = f(output, result)?; - } - Ok(output) - } } - } - pub async fn foreach<A, B, F>(&self, iter: impl Iterator<Item = F>, f: impl Fn(A) -> B) -> anyhow::Result<Vec<B>> - where - F: Future<Output = anyhow::Result<A>>, - { - self - .fold(iter, vec![], |mut acc, val| { - acc.push(f(val?)); - Ok(acc) - }) - .await - } + pub async fn foreach<A, B, F>( + &self, + iter: impl Iterator<Item = F>, + f: impl Fn(A) -> B, + ) -> anyhow::Result<Vec<B>> + where + F: Future<Output = anyhow::Result<A>>, + { + self.fold(iter, vec![], |mut acc, val| { + acc.push(f(val?)); + Ok(acc) + }) + .await + } } diff --git a/src/lambda/eval.rs b/src/lambda/eval.rs index fcb84172412..3337682428d 100644 --- a/src/lambda/eval.rs +++ b/src/lambda/eval.rs @@ -7,13 +7,13 @@ use super::{Concurrent, EvaluationContext, ResolverContextLike}; pub trait Eval<Output = async_graphql::Value> where - Self: Send + Sync, + Self: Send + Sync, { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin<Box<dyn Future<Output = Result<Output>> + 'a + Send>> - where - Output: 'a; + fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a EvaluationContext<'a, Ctx>, + conc: &'a Concurrent, + ) -> Pin<Box<dyn Future<Output = Result<Output>> + 'a + Send>> + where + Output: 'a; } diff --git a/src/lambda/evaluation_context.rs b/src/lambda/evaluation_context.rs index b84fe687ba0..690f73e3555 100644 --- a/src/lambda/evaluation_context.rs +++ b/src/lambda/evaluation_context.rs @@ -13,190 +13,194 @@ use crate::http::RequestContext; #[derive(Clone, Setters)] #[setters(strip_option)] pub struct EvaluationContext<'a, Ctx: ResolverContextLike<'a>> { - pub req_ctx: &'a RequestContext, - pub graphql_ctx: &'a Ctx, + pub req_ctx: &'a RequestContext, + pub graphql_ctx: &'a Ctx, - // TODO: JS timeout should be read from server settings - pub timeout: Duration, + // TODO: JS timeout should be read from server settings + pub timeout: Duration, } impl<'a, Ctx: ResolverContextLike<'a>> EvaluationContext<'a, Ctx> { - pub fn new(req_ctx: &'a RequestContext, graphql_ctx: &'a Ctx) -> EvaluationContext<'a, Ctx> { - Self { timeout: Duration::from_millis(5), req_ctx, graphql_ctx } - } + pub fn new(req_ctx: &'a RequestContext, graphql_ctx: &'a Ctx) -> EvaluationContext<'a, Ctx> { + Self { timeout:
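// NOTE: 5 ms below is a hard-coded default; per the TODO above, the JS timeout is meant to come from server settings.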
Duration::from_millis(5), req_ctx, graphql_ctx } + } - pub fn value(&self) -> Option<&Value> { - self.graphql_ctx.value() - } + pub fn value(&self) -> Option<&Value> { + self.graphql_ctx.value() + } - pub fn arg<T: AsRef<str>>(&self, path: &[T]) -> Option<&'a Value> { - let arg = self.graphql_ctx.args()?.get(path[0].as_ref()); + pub fn arg<T: AsRef<str>>(&self, path: &[T]) -> Option<&'a Value> { + let arg = self.graphql_ctx.args()?.get(path[0].as_ref()); - get_path_value(arg?, &path[1..]) - } + get_path_value(arg?, &path[1..]) + } - pub fn path_value<T: AsRef<str>>(&self, path: &[T]) -> Option<&'a Value> { - get_path_value(self.graphql_ctx.value()?, path) - } + pub fn path_value<T: AsRef<str>>(&self, path: &[T]) -> Option<&'a Value> { + get_path_value(self.graphql_ctx.value()?, path) + } - pub fn headers(&self) -> &HeaderMap { - &self.req_ctx.req_headers - } + pub fn headers(&self) -> &HeaderMap { + &self.req_ctx.req_headers + } - pub fn header(&self, key: &str) -> Option<&str> { - let value = self.headers().get(key)?; + pub fn header(&self, key: &str) -> Option<&str> { + let value = self.headers().get(key)?; - value.to_str().ok() - } + value.to_str().ok() + } - pub fn env_var(&self, key: &str) -> Option<String> { - self.req_ctx.env_vars.get(key) - } + pub fn env_var(&self, key: &str) -> Option<String> { + self.req_ctx.env_vars.get(key) + } - pub fn var(&self, key: &str) -> Option<&str> { - let vars = &self.req_ctx.server.vars; + pub fn var(&self, key: &str) -> Option<&str> { + let vars = &self.req_ctx.server.vars; - vars.get(key).map(|v| v.as_str()) - } + vars.get(key).map(|v| v.as_str()) + } - pub fn vars(&self) -> &BTreeMap<String, String> { - &self.req_ctx.server.vars - } + pub fn vars(&self) -> &BTreeMap<String, String> { + &self.req_ctx.server.vars + } - pub fn add_error(&self, error: ServerError) { - self.graphql_ctx.add_error(error) - } + pub fn add_error(&self, error: ServerError) { + self.graphql_ctx.add_error(error) + } } impl<'a, Ctx: ResolverContextLike<'a>> GraphQLOperationContext for EvaluationContext<'a, Ctx> { - fn selection_set(&self) -> Option<String> { - let selection_set = self.graphql_ctx.field()?.selection_set(); + fn selection_set(&self) -> Option<String> { + let selection_set = self.graphql_ctx.field()?.selection_set(); - format_selection_set(selection_set) - } + format_selection_set(selection_set) + } } -fn format_selection_set<'a>(selection_set: impl Iterator<Item = SelectionField<'a>>) -> Option<String> { - let set = selection_set.map(format_selection_field).collect::<Vec<_>>(); +fn format_selection_set<'a>( + selection_set: impl Iterator<Item = SelectionField<'a>>, +) -> Option<String> { + let set = selection_set + .map(format_selection_field) + .collect::<Vec<_>>(); - if set.is_empty() { - return None; - } + if set.is_empty() { + return None; + } - Some(format!("{{ {} }}", set.join(" "))) + Some(format!("{{ {} }}", set.join(" "))) } fn format_selection_field(field: SelectionField) -> String { - let name = field.name(); - let arguments = format_selection_field_arguments(field); - let selection_set = format_selection_set(field.selection_set()); - - if let Some(set) = selection_set { - format!("{}{} {}", name, arguments, set) - } else { - format!("{}{}", name, arguments) - } + let name = field.name(); + let arguments = format_selection_field_arguments(field); + let selection_set = format_selection_set(field.selection_set()); + + if let Some(set) = selection_set { + format!("{}{} {}", name, arguments, set) + } else { + format!("{}{}", name, arguments) + } } fn format_selection_field_arguments(field: SelectionField) -> Cow<'static, str> { - let name = field.name(); - let arguments = field - .arguments() - .map_err(|error| { - log::warn!("Failed to resolve
arguments for field {name}, due to error: {error}"); - - error - }) - .unwrap_or_default(); - - if arguments.is_empty() { - return Cow::Borrowed(""); - } - - let args = arguments - .iter() - .map(|(name, value)| format!("{}: {}", name, value)) - .collect::<Vec<_>>() - .join(","); - - Cow::Owned(format!("({})", args)) + let name = field.name(); + let arguments = field + .arguments() + .map_err(|error| { + log::warn!("Failed to resolve arguments for field {name}, due to error: {error}"); + + error + }) + .unwrap_or_default(); + + if arguments.is_empty() { + return Cow::Borrowed(""); + } + + let args = arguments + .iter() + .map(|(name, value)| format!("{}: {}", name, value)) + .collect::<Vec<_>>() + .join(","); + + Cow::Owned(format!("({})", args)) } // TODO: this is the same code as src/json/json_like.rs::get_path pub fn get_path_value<'a, T: AsRef<str>>(input: &'a Value, path: &[T]) -> Option<&'a Value> { - let mut value = Some(input); - for name in path { - match value { - Some(Value::Object(map)) => { - value = map.get(name.as_ref()); - } + let mut value = Some(input); + for name in path { + match value { + Some(Value::Object(map)) => { + value = map.get(name.as_ref()); + } - Some(Value::List(list)) => { - value = list.get(name.as_ref().parse::<usize>().ok()?); - } - _ => return None, + Some(Value::List(list)) => { + value = list.get(name.as_ref().parse::<usize>().ok()?); + } + _ => return None, + } } - } - value + value } #[cfg(test)] mod tests { - use async_graphql::Value; - use serde_json::json; - - use crate::lambda::evaluation_context::get_path_value; - - #[test] - fn test_path_value() { - let json = json!( - { - "a": { - "b": { - "c": "d" + use async_graphql::Value; + use serde_json::json; + + use crate::lambda::evaluation_context::get_path_value; + + #[test] + fn test_path_value() { + let json = json!( + { + "a": { + "b": { + "c": "d" + } } - } - }); - - let async_value = Value::from_json(json).unwrap(); - - let path = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - let result = get_path_value(&async_value, &path); - assert!(result.is_some()); - assert_eq!(result.unwrap(), &Value::String("d".to_string())); - } - - #[test] - fn test_path_not_found() { - let json = json!( - { - "a": { - "b": "c" - } - }); - - let async_value = Value::from_json(json).unwrap(); - - let path = vec!["a".to_string(), "b".to_string(), "c".to_string()]; - let result = get_path_value(&async_value, &path); - assert!(result.is_none()); - } - - #[test] - fn test_numeric_path() { - let json = json!( - { - "a": [{ - "b": "c" - }] - }); - - let async_value = Value::from_json(json).unwrap(); - - let path = vec!["a".to_string(), "0".to_string(), "b".to_string()]; - let result = get_path_value(&async_value, &path); - assert!(result.is_some()); - assert_eq!(result.unwrap(), &Value::String("c".to_string())); - } + }); + + let async_value = Value::from_json(json).unwrap(); + + let path = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let result = get_path_value(&async_value, &path); + assert!(result.is_some()); + assert_eq!(result.unwrap(), &Value::String("d".to_string())); + } + + #[test] + fn test_path_not_found() { + let json = json!( + { + "a": { + "b": "c" + } + }); + + let async_value = Value::from_json(json).unwrap(); + + let path = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let result = get_path_value(&async_value, &path); + assert!(result.is_none()); + } + + #[test] + fn test_numeric_path() { + let json = json!( + { + "a": [{ + "b": "c" + }] + }); + + let async_value = Value::from_json(json).unwrap(); + 
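// A numeric segment such as "0" is parsed as a list index, exercising the Value::List branch of get_path_value above.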
+ let path = vec!["a".to_string(), "0".to_string(), "b".to_string()]; + let result = get_path_value(&async_value, &path); + assert!(result.is_some()); + assert_eq!(result.unwrap(), &Value::String("c".to_string())); + } } diff --git a/src/lambda/expression.rs b/src/lambda/expression.rs index d9948241f64..875fe0e3a2c 100644 --- a/src/lambda/expression.rs +++ b/src/lambda/expression.rs @@ -14,95 +14,109 @@ use crate::json::JsonLike; #[derive(Clone, Debug)] pub enum Expression { - Context(Context), - Literal(Value), // TODO: this should async_graphql::Value - EqualTo(Box<Expression>, Box<Expression>), - IO(IO), - Input(Box<Expression>, Vec<String>), - Logic(Logic), - Relation(Relation), - List(List), - Math(Math), - Concurrency(Concurrent, Box<Expression>), + Context(Context), + Literal(Value), // TODO: this should async_graphql::Value + EqualTo(Box<Expression>, Box<Expression>), + IO(IO), + Input(Box<Expression>, Vec<String>), + Logic(Logic), + Relation(Relation), + List(List), + Math(Math), + Concurrency(Concurrent, Box<Expression>), } #[derive(Clone, Debug)] pub enum Context { - Value, - Path(Vec<String>), + Value, + Path(Vec<String>), } #[derive(Debug, Error)] pub enum EvaluationError { - #[error("IOException: {0}")] - IOException(String), + #[error("IOException: {0}")] + IOException(String), - #[error("JSException: {0}")] - JSException(String), + #[error("JSException: {0}")] + JSException(String), - #[error("APIValidationError: {0:?}")] - APIValidationError(Vec<String>), + #[error("APIValidationError: {0:?}")] + APIValidationError(Vec<String>), - #[error("ExprEvalError: {0:?}")] - ExprEvalError(String), + #[error("ExprEvalError: {0:?}")] + ExprEvalError(String), } impl<'a> From<crate::valid::ValidationError<&'a str>> for EvaluationError { - fn from(_value: crate::valid::ValidationError<&'a str>) -> Self { - EvaluationError::APIValidationError(_value.as_vec().iter().map(|e| e.message.to_owned()).collect()) - } + fn from(_value: crate::valid::ValidationError<&'a str>) -> Self { + EvaluationError::APIValidationError( + _value + .as_vec() + .iter() + .map(|e| e.message.to_owned()) + .collect(), + ) + } } impl Expression { - pub fn concurrency(self, conc: Concurrent) -> Self { - Expression::Concurrency(conc, Box::new(self)) - } + pub fn concurrency(self, conc: Concurrent) -> Self { + Expression::Concurrency(conc, Box::new(self)) + } - pub fn in_parallel(self) -> Self { - self.concurrency(Concurrent::Parallel) - } + pub fn in_parallel(self) -> Self { + self.concurrency(Concurrent::Parallel) + } - pub fn parallel_when(self, cond: bool) -> Self { - if cond { - self.concurrency(Concurrent::Parallel) - } else { - self + pub fn parallel_when(self, cond: bool) -> Self { + if cond { + self.concurrency(Concurrent::Parallel) + } else { + self + } } - } - pub fn in_sequence(self) -> Self { - self.concurrency(Concurrent::Sequential) - } + pub fn in_sequence(self) -> Self { + self.concurrency(Concurrent::Sequential) + } } impl Eval for Expression { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin<Box<dyn Future<Output = Result<async_graphql::Value>> + 'a + Send>> { - Box::pin(async move { - match self { - Expression::Concurrency(conc, expr) => Ok(expr.eval(ctx, conc).await?), - Expression::Context(op) => match op { - Context::Value => Ok(ctx.value().cloned().unwrap_or(async_graphql::Value::Null)), - Context::Path(path) => Ok(ctx.path_value(path).cloned().unwrap_or(async_graphql::Value::Null)), - }, - Expression::Input(input, path) => { - let inp = &input.eval(ctx, conc).await?; - Ok(inp.get_path(path).unwrap_or(&async_graphql::Value::Null).clone()) - } - Expression::Literal(value) => Ok(serde_json::from_value(value.clone())?), - 
Expression::EqualTo(left, right) => Ok(async_graphql::Value::from( - left.eval(ctx, conc).await? == right.eval(ctx, conc).await?, - )), - Expression::IO(operation) => operation.eval(ctx, conc).await, + fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a EvaluationContext<'a, Ctx>, + conc: &'a Concurrent, + ) -> Pin<Box<dyn Future<Output = Result<async_graphql::Value>> + 'a + Send>> { + Box::pin(async move { + match self { + Expression::Concurrency(conc, expr) => Ok(expr.eval(ctx, conc).await?), + Expression::Context(op) => match op { + Context::Value => { + Ok(ctx.value().cloned().unwrap_or(async_graphql::Value::Null)) + } + Context::Path(path) => Ok(ctx + .path_value(path) + .cloned() + .unwrap_or(async_graphql::Value::Null)), + }, + Expression::Input(input, path) => { + let inp = &input.eval(ctx, conc).await?; + Ok(inp + .get_path(path) + .unwrap_or(&async_graphql::Value::Null) + .clone()) + } + Expression::Literal(value) => Ok(serde_json::from_value(value.clone())?), + Expression::EqualTo(left, right) => Ok(async_graphql::Value::from( + left.eval(ctx, conc).await? == right.eval(ctx, conc).await?, + )), + Expression::IO(operation) => operation.eval(ctx, conc).await, - Expression::Relation(relation) => relation.eval(ctx, conc).await, - Expression::Logic(logic) => logic.eval(ctx, conc).await, - Expression::List(list) => list.eval(ctx, conc).await, - Expression::Math(math) => math.eval(ctx, conc).await, - } - }) - } + Expression::Relation(relation) => relation.eval(ctx, conc).await, + Expression::Logic(logic) => logic.eval(ctx, conc).await, + Expression::List(list) => list.eval(ctx, conc).await, + Expression::Math(math) => math.eval(ctx, conc).await, + } + }) + } } diff --git a/src/lambda/graphql_operation_context.rs b/src/lambda/graphql_operation_context.rs index c8d53211f26..74358041586 100644 --- a/src/lambda/graphql_operation_context.rs +++ b/src/lambda/graphql_operation_context.rs @@ -1,3 +1,3 @@ pub trait GraphQLOperationContext { - fn selection_set(&self) -> Option<String>; + fn selection_set(&self) -> Option<String>; } diff --git a/src/lambda/io.rs b/src/lambda/io.rs index edd143a1a4b..4970222917a 100644 --- a/src/lambda/io.rs +++ b/src/lambda/io.rs @@ -24,220 +24,225 @@ use crate::{grpc, http}; #[derive(Clone, Debug)] pub enum IO { - Http { - req_template: http::RequestTemplate, - group_by: Option<GroupBy>, - dl_id: Option<DataLoaderId>, - }, - GraphQLEndpoint { - req_template: graphql::RequestTemplate, - field_name: String, - batch: bool, - dl_id: Option<DataLoaderId>, - }, - Grpc { - req_template: grpc::RequestTemplate, - group_by: Option<GroupBy>, - dl_id: Option<DataLoaderId>, - }, - JS(Box<Expression>, String), + Http { + req_template: http::RequestTemplate, + group_by: Option<GroupBy>, + dl_id: Option<DataLoaderId>, + }, + GraphQLEndpoint { + req_template: graphql::RequestTemplate, + field_name: String, + batch: bool, + dl_id: Option<DataLoaderId>, + }, + Grpc { + req_template: grpc::RequestTemplate, + group_by: Option<GroupBy>, + dl_id: Option<DataLoaderId>, + }, + JS(Box<Expression>, String), } #[derive(Clone, Copy, Debug)] pub struct DataLoaderId(pub usize); impl Eval for IO { - fn eval<'a, Ctx: super::ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a super::EvaluationContext<'a, Ctx>, - _conc: &'a super::Concurrent, - ) -> Pin<Box<dyn Future<Output = Result<async_graphql::Value>> + 'a + Send>> { - Box::pin(async move { - match self { - IO::Http { req_template, dl_id, ..
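// Batching note: only idempotent GET requests are routed through the data loader, and only when batching is enabled.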
} => { - let req = req_template.to_request(ctx)?; - let is_get = req.method() == reqwest::Method::GET; - - let res = if is_get && ctx.req_ctx.is_batching_enabled() { - let data_loader: Option<&DataLoader> = - dl_id.and_then(|index| ctx.req_ctx.http_data_loaders.get(index.0)); - execute_request_with_dl(ctx, req, data_loader).await? - } else { - execute_raw_request(ctx, req).await? - }; - - if ctx.req_ctx.server.get_enable_http_validation() { - req_template - .endpoint - .output - .validate(&res.body) - .to_result() - .map_err(EvaluationError::from)?; - } - - set_cache_control(ctx, &res); - - Ok(res.body) - } - IO::GraphQLEndpoint { req_template, field_name, dl_id, .. } => { - let req = req_template.to_request(ctx)?; - - let res = if ctx.req_ctx.upstream.batch.is_some() - && matches!(req_template.operation_type, GraphQLOperationType::Query) - { - let data_loader: Option<&DataLoader> = - dl_id.and_then(|index| ctx.req_ctx.gql_data_loaders.get(index.0)); - execute_request_with_dl(ctx, req, data_loader).await? - } else { - execute_raw_request(ctx, req).await? - }; - - set_cache_control(ctx, &res); - parse_graphql_response(ctx, res, field_name) - } - IO::Grpc { req_template, dl_id, .. } => { - let rendered = req_template.render(ctx)?; - - let res = if ctx.req_ctx.upstream.batch.is_some() && + fn eval<'a, Ctx: super::ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a super::EvaluationContext<'a, Ctx>, + _conc: &'a super::Concurrent, + ) -> Pin> + 'a + Send>> { + Box::pin(async move { + match self { + IO::Http { req_template, dl_id, .. } => { + let req = req_template.to_request(ctx)?; + let is_get = req.method() == reqwest::Method::GET; + + let res = if is_get && ctx.req_ctx.is_batching_enabled() { + let data_loader: Option<&DataLoader> = + dl_id.and_then(|index| ctx.req_ctx.http_data_loaders.get(index.0)); + execute_request_with_dl(ctx, req, data_loader).await? + } else { + execute_raw_request(ctx, req).await? + }; + + if ctx.req_ctx.server.get_enable_http_validation() { + req_template + .endpoint + .output + .validate(&res.body) + .to_result() + .map_err(EvaluationError::from)?; + } + + set_cache_control(ctx, &res); + + Ok(res.body) + } + IO::GraphQLEndpoint { req_template, field_name, dl_id, .. } => { + let req = req_template.to_request(ctx)?; + + let res = if ctx.req_ctx.upstream.batch.is_some() + && matches!(req_template.operation_type, GraphQLOperationType::Query) + { + let data_loader: Option<&DataLoader> = + dl_id.and_then(|index| ctx.req_ctx.gql_data_loaders.get(index.0)); + execute_request_with_dl(ctx, req, data_loader).await? + } else { + execute_raw_request(ctx, req).await? + }; + + set_cache_control(ctx, &res); + parse_graphql_response(ctx, res, field_name) + } + IO::Grpc { req_template, dl_id, .. } => { + let rendered = req_template.render(ctx)?; + + let res = if ctx.req_ctx.upstream.batch.is_some() && // TODO: share check for operation_type for resolvers matches!(req_template.operation_type, GraphQLOperationType::Query) - { - let data_loader: Option<&DataLoader> = - dl_id.and_then(|index| ctx.req_ctx.grpc_data_loaders.get(index.0)); - execute_grpc_request_with_dl(ctx, rendered, data_loader).await? - } else { - let req = rendered.to_request()?; - execute_raw_grpc_request(ctx, req, &req_template.operation).await? 
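// As with HTTP, gRPC calls go through the data loader only for query-type operations; the TODO above notes this check should be shared across resolvers.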
- }; - - set_cache_control(ctx, &res); - - Ok(res.body) - } - IO::JS(input, script) => { - let result; - #[cfg(not(feature = "unsafe-js"))] - { - let _ = script; - let _ = input; - result = Err(EvaluationError::JSException("JS execution is disabled".to_string()).into()); - } - - #[cfg(feature = "unsafe-js")] - { - let input = input.eval(ctx, _conc).await?; - result = javascript::execute_js(script, input, Some(ctx.timeout)) - .map_err(|e| EvaluationError::JSException(e.to_string()).into()); - } - result - } - } - }) - } + { + let data_loader: Option< + &DataLoader, + > = dl_id.and_then(|index| ctx.req_ctx.grpc_data_loaders.get(index.0)); + execute_grpc_request_with_dl(ctx, rendered, data_loader).await? + } else { + let req = rendered.to_request()?; + execute_raw_grpc_request(ctx, req, &req_template.operation).await? + }; + + set_cache_control(ctx, &res); + + Ok(res.body) + } + IO::JS(input, script) => { + let result; + #[cfg(not(feature = "unsafe-js"))] + { + let _ = script; + let _ = input; + result = Err(EvaluationError::JSException( + "JS execution is disabled".to_string(), + ) + .into()); + } + + #[cfg(feature = "unsafe-js")] + { + let input = input.eval(ctx, _conc).await?; + result = javascript::execute_js(script, input, Some(ctx.timeout)) + .map_err(|e| EvaluationError::JSException(e.to_string()).into()); + } + result + } + } + }) + } } fn set_cache_control<'ctx, Ctx: ResolverContextLike<'ctx>>( - ctx: &EvaluationContext<'ctx, Ctx>, - res: &Response, + ctx: &EvaluationContext<'ctx, Ctx>, + res: &Response, ) { - if ctx.req_ctx.server.get_enable_cache_control() && res.status.is_success() { - if let Some(policy) = cache_policy(res) { - ctx.req_ctx.set_cache_control(policy); + if ctx.req_ctx.server.get_enable_cache_control() && res.status.is_success() { + if let Some(policy) = cache_policy(res) { + ctx.req_ctx.set_cache_control(policy); + } } - } } async fn execute_raw_request<'ctx, Ctx: ResolverContextLike<'ctx>>( - ctx: &EvaluationContext<'ctx, Ctx>, - req: Request, + ctx: &EvaluationContext<'ctx, Ctx>, + req: Request, ) -> Result> { - ctx - .req_ctx - .h_client - .execute(req) - .await - .map_err(|e| EvaluationError::IOException(e.to_string()))? - .to_json() + ctx.req_ctx + .h_client + .execute(req) + .await + .map_err(|e| EvaluationError::IOException(e.to_string()))? + .to_json() } async fn execute_raw_grpc_request<'ctx, Ctx: ResolverContextLike<'ctx>>( - ctx: &EvaluationContext<'ctx, Ctx>, - req: Request, - operation: &ProtobufOperation, + ctx: &EvaluationContext<'ctx, Ctx>, + req: Request, + operation: &ProtobufOperation, ) -> Result> { - Ok( - execute_grpc_request(&ctx.req_ctx.h2_client, operation, req) - .await - .map_err(|e| EvaluationError::IOException(e.to_string()))?, - ) + Ok(execute_grpc_request(&ctx.req_ctx.h2_client, operation, req) + .await + .map_err(|e| EvaluationError::IOException(e.to_string()))?) 
} async fn execute_grpc_request_with_dl< - 'ctx, - Ctx: ResolverContextLike<'ctx>, - Dl: Loader<grpc::DataLoaderRequest, Value = Response<async_graphql::Value>, Error = Arc<anyhow::Error>>, + 'ctx, + Ctx: ResolverContextLike<'ctx>, + Dl: Loader< + grpc::DataLoaderRequest, + Value = Response<async_graphql::Value>, + Error = Arc<anyhow::Error>, + >, >( - ctx: &EvaluationContext<'ctx, Ctx>, - rendered: RenderedRequestTemplate, - data_loader: Option<&DataLoader<Dl>>, + ctx: &EvaluationContext<'ctx, Ctx>, + rendered: RenderedRequestTemplate, + data_loader: Option<&DataLoader<Dl>>, ) -> Result<Response<async_graphql::Value>> { - let headers = ctx - .req_ctx - .upstream - .batch - .clone() - .map(|s| s.headers) - .unwrap_or_default(); - let endpoint_key = grpc::DataLoaderRequest::new(rendered, headers); - - Ok( - data_loader - .unwrap() - .load_one(endpoint_key) - .await - .map_err(|e| EvaluationError::IOException(e.to_string()))? - .unwrap_or_default(), - ) + let headers = ctx + .req_ctx + .upstream + .batch + .clone() + .map(|s| s.headers) + .unwrap_or_default(); + let endpoint_key = grpc::DataLoaderRequest::new(rendered, headers); + + Ok(data_loader + .unwrap() + .load_one(endpoint_key) + .await + .map_err(|e| EvaluationError::IOException(e.to_string()))? + .unwrap_or_default()) } async fn execute_request_with_dl< - 'ctx, - Ctx: ResolverContextLike<'ctx>, - Dl: Loader<DataLoaderRequest, Value = Response<async_graphql::Value>, Error = Arc<anyhow::Error>>, + 'ctx, + Ctx: ResolverContextLike<'ctx>, + Dl: Loader<DataLoaderRequest, Value = Response<async_graphql::Value>, Error = Arc<anyhow::Error>>, >( - ctx: &EvaluationContext<'ctx, Ctx>, - req: Request, - data_loader: Option<&DataLoader<Dl>>, + ctx: &EvaluationContext<'ctx, Ctx>, + req: Request, + data_loader: Option<&DataLoader<Dl>>, ) -> Result<Response<async_graphql::Value>> { - let headers = ctx - .req_ctx - .upstream - .batch - .clone() - .map(|s| s.headers) - .unwrap_or_default(); - let endpoint_key = crate::http::DataLoaderRequest::new(req, headers); - - Ok( - data_loader - .unwrap() - .load_one(endpoint_key) - .await - .map_err(|e| EvaluationError::IOException(e.to_string()))? - .unwrap_or_default(), - ) + let headers = ctx + .req_ctx + .upstream + .batch + .clone() + .map(|s| s.headers) + .unwrap_or_default(); + let endpoint_key = crate::http::DataLoaderRequest::new(req, headers); + + Ok(data_loader + .unwrap() + .load_one(endpoint_key) + .await + .map_err(|e| EvaluationError::IOException(e.to_string()))?
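// load_one returns an Option; a missing batch entry falls back to the default (empty) response instead of erroring.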
+ .unwrap_or_default()) } fn parse_graphql_response<'ctx, Ctx: ResolverContextLike<'ctx>>( - ctx: &EvaluationContext<'ctx, Ctx>, - res: Response<async_graphql::Value>, - field_name: &str, + ctx: &EvaluationContext<'ctx, Ctx>, + res: Response<async_graphql::Value>, + field_name: &str, ) -> Result<async_graphql::Value> { - let res: async_graphql::Response = serde_json::from_value(res.body.into_json()?)?; + let res: async_graphql::Response = serde_json::from_value(res.body.into_json()?)?; - for error in res.errors { - ctx.add_error(error); - } + for error in res.errors { + ctx.add_error(error); + } - Ok(res.data.get_key(field_name).map(|v| v.to_owned()).unwrap_or_default()) + Ok(res + .data + .get_key(field_name) + .map(|v| v.to_owned()) + .unwrap_or_default()) } diff --git a/src/lambda/lambda.rs b/src/lambda/lambda.rs index fbcf9277c5e..522634aa4cc 100644 --- a/src/lambda/lambda.rs +++ b/src/lambda/lambda.rs @@ -6,139 +6,158 @@ use crate::{graphql, grpc, http}; #[derive(Clone)] pub struct Lambda<A> { - _output: PhantomData<fn() -> A>, - pub expression: Expression, + _output: PhantomData<fn() -> A>, + pub expression: Expression, } impl<A> Lambda<A> { - fn box_expr(self) -> Box<Expression> { - Box::new(self.expression) - } - pub fn new(expression: Expression) -> Self { - Self { _output: PhantomData, expression } - } - - pub fn eq(self, other: Self) -> Lambda<bool> { - Lambda::new(Expression::EqualTo(self.box_expr(), Box::new(other.expression))) - } - - pub fn to_js(self, script: String) -> Lambda { - Lambda::new(Expression::IO(IO::JS(self.box_expr(), script))) - } - - pub fn to_input_path(self, path: Vec<String>) -> Lambda { - Lambda::new(Expression::Input(self.box_expr(), path)) - } + fn box_expr(self) -> Box<Expression> { + Box::new(self.expression) + } + pub fn new(expression: Expression) -> Self { + Self { _output: PhantomData, expression } + } + + pub fn eq(self, other: Self) -> Lambda<bool> { + Lambda::new(Expression::EqualTo( + self.box_expr(), + Box::new(other.expression), + )) + } + + pub fn to_js(self, script: String) -> Lambda { + Lambda::new(Expression::IO(IO::JS(self.box_expr(), script))) + } + + pub fn to_input_path(self, path: Vec<String>) -> Lambda { + Lambda::new(Expression::Input(self.box_expr(), path)) + } } impl Lambda { - pub fn context() -> Self { - Lambda::new(Expression::Context(expression::Context::Value)) - } - - pub fn context_field(name: String) -> Self { - Lambda::new(Expression::Context(Context::Path(vec![name]))) - } - - pub fn context_path(path: Vec<String>) -> Self { - Lambda::new(Expression::Context(Context::Path(path))) - } - - pub fn from_request_template(req_template: http::RequestTemplate) -> Lambda { - Lambda::new(Expression::IO(IO::Http { req_template, group_by: None, dl_id: None })) - } - - pub fn from_graphql_request_template( - req_template: graphql::RequestTemplate, - field_name: String, - batch: bool, - ) -> Lambda { - Lambda::new(Expression::IO(IO::GraphQLEndpoint { - req_template, - field_name, - batch, - dl_id: None, - })) - } - - pub fn from_grpc_request_template(req_template: grpc::RequestTemplate) -> Lambda { - Lambda::new(Expression::IO(IO::Grpc { req_template, group_by: None, dl_id: None })) - } + pub fn context() -> Self { + Lambda::new(Expression::Context(expression::Context::Value)) + } + + pub fn context_field(name: String) -> Self { + Lambda::new(Expression::Context(Context::Path(vec![name]))) + } + + pub fn context_path(path: Vec<String>) -> Self { + Lambda::new(Expression::Context(Context::Path(path))) + } + + pub fn from_request_template(req_template: http::RequestTemplate) -> Lambda { + Lambda::new(Expression::IO(IO::Http { + req_template, + group_by: None, + dl_id: None, + })) + } + + 
pub fn from_graphql_request_template( + req_template: graphql::RequestTemplate, + field_name: String, + batch: bool, + ) -> Lambda { + Lambda::new(Expression::IO(IO::GraphQLEndpoint { + req_template, + field_name, + batch, + dl_id: None, + })) + } + + pub fn from_grpc_request_template( + req_template: grpc::RequestTemplate, + ) -> Lambda { + Lambda::new(Expression::IO(IO::Grpc { + req_template, + group_by: None, + dl_id: None, + })) + } } impl<A> From<A> for Lambda<A> where - serde_json::Value: From<A>, + serde_json::Value: From<A>, { - fn from(value: A) -> Self { - let json = serde_json::Value::from(value); - Lambda::new(Expression::Literal(json)) - } + fn from(value: A) -> Self { + let json = serde_json::Value::from(value); + Lambda::new(Expression::Literal(json)) + } } #[cfg(test)] mod tests { - use anyhow::Result; - use httpmock::Method::GET; - use httpmock::MockServer; - use serde::de::DeserializeOwned; - use serde_json::json; - - use crate::endpoint::Endpoint; - use crate::http::{RequestContext, RequestTemplate}; - use crate::lambda::{Concurrent, EmptyResolverContext, Eval, EvaluationContext, Lambda}; - - impl<B> Lambda<B> - where - B: DeserializeOwned, - { - async fn eval(self) -> Result<B> { - let req_ctx = RequestContext::default(); - let ctx = EvaluationContext::new(&req_ctx, &EmptyResolverContext); - let result = self.expression.eval(&ctx, &Concurrent::Sequential).await?; - let json = serde_json::to_value(result)?; - Ok(serde_json::from_value(json)?) - } - } - - #[tokio::test] - async fn test_equal_to_true() { - let lambda = Lambda::from(1.0).eq(Lambda::from(1.0)); - let result = lambda.eval().await.unwrap(); - assert!(result) - } - - #[tokio::test] - async fn test_equal_to_false() { - let lambda = Lambda::from(1.0).eq(Lambda::from(2.0)); - let result = lambda.eval().await.unwrap(); - assert!(!result) - } - - #[tokio::test] - async fn test_endpoint() { - let server = MockServer::start(); - - server.mock(|when, then| { - when.method(GET).path("/users"); - then - .status(200) - .header("content-type", "application/json") - .json_body(json!({ "name": "Hans" })); - }); - - let endpoint = RequestTemplate::try_from(Endpoint::new(server.url("/users").to_string())).unwrap(); - let result = Lambda::from_request_template(endpoint).eval().await.unwrap(); - - assert_eq!(result.as_object().unwrap().get("name").unwrap(), "Hans") - } - - #[cfg(feature = "unsafe-js")] - #[tokio::test] - async fn test_js() { - let result = Lambda::from(1.0).to_js("ctx + 100".to_string()).eval().await; - let f64 = result.unwrap().as_f64().unwrap(); - assert_eq!(f64, 101.0) - } + use anyhow::Result; + use httpmock::Method::GET; + use httpmock::MockServer; + use serde::de::DeserializeOwned; + use serde_json::json; + + use crate::endpoint::Endpoint; + use crate::http::{RequestContext, RequestTemplate}; + use crate::lambda::{Concurrent, EmptyResolverContext, Eval, EvaluationContext, Lambda}; + + impl<B> Lambda<B> + where + B: DeserializeOwned, + { + async fn eval(self) -> Result<B> { + let req_ctx = RequestContext::default(); + let ctx = EvaluationContext::new(&req_ctx, &EmptyResolverContext); + let result = self.expression.eval(&ctx, &Concurrent::Sequential).await?; + let json = serde_json::to_value(result)?; + Ok(serde_json::from_value(json)?)
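// Round-trips the evaluated value through serde_json so tests can deserialize it into any DeserializeOwned type.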
+ } + } + + #[tokio::test] + async fn test_equal_to_true() { + let lambda = Lambda::from(1.0).eq(Lambda::from(1.0)); + let result = lambda.eval().await.unwrap(); + assert!(result) + } + + #[tokio::test] + async fn test_equal_to_false() { + let lambda = Lambda::from(1.0).eq(Lambda::from(2.0)); + let result = lambda.eval().await.unwrap(); + assert!(!result) + } + + #[tokio::test] + async fn test_endpoint() { + let server = MockServer::start(); + + server.mock(|when, then| { + when.method(GET).path("/users"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ "name": "Hans" })); + }); + + let endpoint = + RequestTemplate::try_from(Endpoint::new(server.url("/users").to_string())).unwrap(); + let result = Lambda::from_request_template(endpoint) + .eval() + .await + .unwrap(); + + assert_eq!(result.as_object().unwrap().get("name").unwrap(), "Hans") + } + + #[cfg(feature = "unsafe-js")] + #[tokio::test] + async fn test_js() { + let result = Lambda::from(1.0) + .to_js("ctx + 100".to_string()) + .eval() + .await; + let f64 = result.unwrap().as_f64().unwrap(); + assert_eq!(f64, 101.0) + } } diff --git a/src/lambda/list.rs b/src/lambda/list.rs index d0b56de40b5..11fd823a202 100644 --- a/src/lambda/list.rs +++ b/src/lambda/list.rs @@ -5,58 +5,67 @@ use anyhow::Result; use async_graphql_value::ConstValue; use futures_util::future::join_all; -use super::{Concurrent, Eval, EvaluationContext, EvaluationError, Expression, ResolverContextLike}; +use super::{ + Concurrent, Eval, EvaluationContext, EvaluationError, Expression, ResolverContextLike, +}; #[derive(Clone, Debug)] pub enum List { - Concat(Vec), + Concat(Vec), } impl Eval for List { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin> + 'a + Send>> { - Box::pin(async move { - match self { - List::Concat(list) => join_all(list.iter().map(|expr| expr.eval(ctx, conc))) - .await - .into_iter() - .try_fold(async_graphql::Value::List(vec![]), |acc, result| match (acc, result?) { - (ConstValue::List(mut lhs), ConstValue::List(rhs)) => { - lhs.extend(rhs); - Ok(ConstValue::List(lhs)) + fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a EvaluationContext<'a, Ctx>, + conc: &'a Concurrent, + ) -> Pin> + 'a + Send>> { + Box::pin(async move { + match self { + List::Concat(list) => join_all(list.iter().map(|expr| expr.eval(ctx, conc))) + .await + .into_iter() + .try_fold(async_graphql::Value::List(vec![]), |acc, result| { + match (acc, result?) 
{ + (ConstValue::List(mut lhs), ConstValue::List(rhs)) => { + lhs.extend(rhs); + Ok(ConstValue::List(lhs)) + } + _ => Err(EvaluationError::ExprEvalError( + "element is not a list".into(), + ))?, + } + }), } - _ => Err(EvaluationError::ExprEvalError("element is not a list".into()))?, - }), - } - }) - } + }) + } } impl Eval for T where - T: AsRef<[Expression]> + Send + Sync, - C: FromIterator, + T: AsRef<[Expression]> + Send + Sync, + C: FromIterator, { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin> + 'a + Send>> { - Box::pin(async move { - let future_iter = self.as_ref().iter().map(|expr| expr.eval(ctx, conc)); - match *conc { - Concurrent::Parallel => join_all(future_iter).await.into_iter().collect::>(), - Concurrent::Sequential => { - let mut results = Vec::with_capacity(self.as_ref().len()); - for future in future_iter { - results.push(future.await?); - } - Ok(results.into_iter().collect()) - } - } - }) - } + fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a EvaluationContext<'a, Ctx>, + conc: &'a Concurrent, + ) -> Pin> + 'a + Send>> { + Box::pin(async move { + let future_iter = self.as_ref().iter().map(|expr| expr.eval(ctx, conc)); + match *conc { + Concurrent::Parallel => join_all(future_iter) + .await + .into_iter() + .collect::>(), + Concurrent::Sequential => { + let mut results = Vec::with_capacity(self.as_ref().len()); + for future in future_iter { + results.push(future.await?); + } + Ok(results.into_iter().collect()) + } + } + }) + } } diff --git a/src/lambda/logic.rs b/src/lambda/logic.rs index cc8ea5e3b42..43ceecf35f3 100644 --- a/src/lambda/logic.rs +++ b/src/lambda/logic.rs @@ -8,73 +8,75 @@ use super::{Concurrent, Eval, EvaluationContext, Expression, ResolverContextLike #[derive(Clone, Debug)] pub enum Logic { - If { - cond: Box, - then: Box, - els: Box, - }, - And(Vec), - Or(Vec), - Cond(Vec<(Box, Box)>), - DefaultTo(Box, Box), - IsEmpty(Box), - Not(Box), + If { + cond: Box, + then: Box, + els: Box, + }, + And(Vec), + Or(Vec), + Cond(Vec<(Box, Box)>), + DefaultTo(Box, Box), + IsEmpty(Box), + Not(Box), } impl Eval for Logic { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin> + 'a + Send>> { - Box::pin(async move { - Ok(match self { - Logic::Or(list) => { - let future_iter = list.iter().map(|expr| async move { expr.eval(ctx, conc).await }); + fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a EvaluationContext<'a, Ctx>, + conc: &'a Concurrent, + ) -> Pin> + 'a + Send>> { + Box::pin(async move { + Ok(match self { + Logic::Or(list) => { + let future_iter = list + .iter() + .map(|expr| async move { expr.eval(ctx, conc).await }); - conc - .fold(future_iter, false, |acc, val| Ok(acc || is_truthy(&val?))) - .await - .map(ConstValue::from)? - } - Logic::Cond(list) => { - for (cond, expr) in list.iter() { - if is_truthy(&cond.eval(ctx, conc).await?) { - return expr.eval(ctx, conc).await; - } - } - ConstValue::Null - } - Logic::DefaultTo(value, default) => { - let result = value.eval(ctx, conc).await?; - if is_empty(&result) { - default.eval(ctx, conc).await? - } else { - result - } - } - Logic::IsEmpty(expr) => is_empty(&expr.eval(ctx, conc).await?).into(), - Logic::Not(expr) => (!is_truthy(&expr.eval(ctx, conc).await?)).into(), + conc.fold(future_iter, false, |acc, val| Ok(acc || is_truthy(&val?))) + .await + .map(ConstValue::from)? 
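// Note: fold still awaits every sub-expression once the accumulator is true, so Or does not short-circuit.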
+ } + Logic::Cond(list) => { + for (cond, expr) in list.iter() { + if is_truthy(&cond.eval(ctx, conc).await?) { + return expr.eval(ctx, conc).await; + } + } + ConstValue::Null + } + Logic::DefaultTo(value, default) => { + let result = value.eval(ctx, conc).await?; + if is_empty(&result) { + default.eval(ctx, conc).await? + } else { + result + } + } + Logic::IsEmpty(expr) => is_empty(&expr.eval(ctx, conc).await?).into(), + Logic::Not(expr) => (!is_truthy(&expr.eval(ctx, conc).await?)).into(), - Logic::And(list) => { - let future_iter = list.iter().map(|expr| async move { expr.eval(ctx, conc).await }); + Logic::And(list) => { + let future_iter = list + .iter() + .map(|expr| async move { expr.eval(ctx, conc).await }); - conc - .fold(future_iter, true, |acc, val| Ok(acc && is_truthy(&val?))) - .await - .map(ConstValue::from)? - } - Logic::If { cond, then, els } => { - let cond = cond.eval(ctx, conc).await?; - if is_truthy(&cond) { - then.eval(ctx, conc).await? - } else { - els.eval(ctx, conc).await? - } - } - }) - }) - } + conc.fold(future_iter, true, |acc, val| Ok(acc && is_truthy(&val?))) + .await + .map(ConstValue::from)? + } + Logic::If { cond, then, els } => { + let cond = cond.eval(ctx, conc).await?; + if is_truthy(&cond) { + then.eval(ctx, conc).await? + } else { + els.eval(ctx, conc).await? + } + } + }) + }) + } } /// Check if a value is truthy @@ -83,54 +85,54 @@ impl Eval for Logic { /// 1. An empty string is considered falsy /// 2. A collection of bytes is truthy, even if the value in those bytes is 0. An empty collection is falsy. pub fn is_truthy(value: &async_graphql::Value) -> bool { - use async_graphql::{Number, Value}; - use hyper::body::Bytes; + use async_graphql::{Number, Value}; + use hyper::body::Bytes; - match value { - &Value::Null => false, - &Value::Enum(_) => true, - &Value::List(_) => true, - &Value::Object(_) => true, - Value::String(s) => !s.is_empty(), - &Value::Boolean(b) => b, - Value::Number(n) => n != &Number::from(0), - Value::Binary(b) => b != &Bytes::default(), - } + match value { + &Value::Null => false, + &Value::Enum(_) => true, + &Value::List(_) => true, + &Value::Object(_) => true, + Value::String(s) => !s.is_empty(), + &Value::Boolean(b) => b, + Value::Number(n) => n != &Number::from(0), + Value::Binary(b) => b != &Bytes::default(), + } } fn is_empty(value: &async_graphql::Value) -> bool { - match value { - ConstValue::Null => true, - ConstValue::Number(_) | ConstValue::Boolean(_) | ConstValue::Enum(_) => false, - ConstValue::Binary(bytes) => bytes.is_empty(), - ConstValue::List(list) => list.is_empty(), - ConstValue::Object(obj) => obj.is_empty(), - ConstValue::String(string) => string.is_empty(), - } + match value { + ConstValue::Null => true, + ConstValue::Number(_) | ConstValue::Boolean(_) | ConstValue::Enum(_) => false, + ConstValue::Binary(bytes) => bytes.is_empty(), + ConstValue::List(list) => list.is_empty(), + ConstValue::Object(obj) => obj.is_empty(), + ConstValue::String(string) => string.is_empty(), + } } #[cfg(test)] mod tests { - use async_graphql::{Name, Number, Value}; - use hyper::body::Bytes; - use indexmap::IndexMap; + use async_graphql::{Name, Number, Value}; + use hyper::body::Bytes; + use indexmap::IndexMap; - use crate::lambda::is_truthy; + use crate::lambda::is_truthy; - #[test] - fn test_is_truthy() { - assert!(is_truthy(&Value::Enum(Name::new("EXAMPLE")))); - assert!(is_truthy(&Value::List(vec![]))); - assert!(is_truthy(&Value::Object(IndexMap::default()))); - assert!(is_truthy(&Value::String("Hello".to_string()))); - 
assert!(is_truthy(&Value::Boolean(true))); - assert!(is_truthy(&Value::Number(Number::from(1)))); - assert!(is_truthy(&Value::Binary(Bytes::from_static(&[0, 1, 2])))); + #[test] + fn test_is_truthy() { + assert!(is_truthy(&Value::Enum(Name::new("EXAMPLE")))); + assert!(is_truthy(&Value::List(vec![]))); + assert!(is_truthy(&Value::Object(IndexMap::default()))); + assert!(is_truthy(&Value::String("Hello".to_string()))); + assert!(is_truthy(&Value::Boolean(true))); + assert!(is_truthy(&Value::Number(Number::from(1)))); + assert!(is_truthy(&Value::Binary(Bytes::from_static(&[0, 1, 2])))); - assert!(!is_truthy(&Value::Null)); - assert!(!is_truthy(&Value::String("".to_string()))); - assert!(!is_truthy(&Value::Boolean(false))); - assert!(!is_truthy(&Value::Number(Number::from(0)))); - assert!(!is_truthy(&Value::Binary(Bytes::default()))); - } + assert!(!is_truthy(&Value::Null)); + assert!(!is_truthy(&Value::String("".to_string()))); + assert!(!is_truthy(&Value::Boolean(false))); + assert!(!is_truthy(&Value::Number(Number::from(0)))); + assert!(!is_truthy(&Value::Binary(Bytes::default()))); + } } diff --git a/src/lambda/math.rs b/src/lambda/math.rs index 29123b44cf2..9cd9044cede 100644 --- a/src/lambda/math.rs +++ b/src/lambda/math.rs @@ -5,164 +5,163 @@ use std::pin::Pin; use anyhow::Result; use async_graphql_value::ConstValue; -use super::{Concurrent, Eval, EvaluationContext, EvaluationError, Expression, ResolverContextLike}; +use super::{ + Concurrent, Eval, EvaluationContext, EvaluationError, Expression, ResolverContextLike, +}; use crate::json::JsonLike; #[derive(Clone, Debug)] pub enum Math { - Mod(Box, Box), - Add(Box, Box), - Dec(Box), - Divide(Box, Box), - Inc(Box), - Multiply(Box, Box), - Negate(Box), - Product(Vec), - Subtract(Box, Box), - Sum(Vec), + Mod(Box, Box), + Add(Box, Box), + Dec(Box), + Divide(Box, Box), + Inc(Box), + Multiply(Box, Box), + Negate(Box), + Product(Vec), + Subtract(Box, Box), + Sum(Vec), } impl Eval for Math { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin> + 'a + Send>> { - Box::pin(async move { - Ok(match self { - Math::Mod(lhs, rhs) => { - let lhs = lhs.eval(ctx, conc).await?; - let rhs = rhs.eval(ctx, conc).await?; - - try_i64_operation(&lhs, &rhs, ops::Rem::rem) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Rem::rem)) - .ok_or(EvaluationError::ExprEvalError("mod".into()))? - } - Math::Add(lhs, rhs) => { - let lhs = lhs.eval(ctx, conc).await?; - let rhs = rhs.eval(ctx, conc).await?; - - try_f64_operation(&lhs, &rhs, ops::Add::add) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Add::add)) - .or_else(|| try_i64_operation(&lhs, &rhs, ops::Add::add)) - .ok_or(EvaluationError::ExprEvalError("add".into()))? - } - Math::Dec(val) => { - let val = val.eval(ctx, conc).await?; - - val - .as_f64_ok() - .ok() - .map(|val| (val - 1f64).into()) - .or_else(|| val.as_u64_ok().ok().map(|val| (val - 1u64).into())) - .or_else(|| val.as_i64_ok().ok().map(|val| (val - 1i64).into())) - .ok_or(EvaluationError::ExprEvalError("dec".into()))? - } - Math::Divide(lhs, rhs) => { - let lhs = lhs.eval(ctx, conc).await?; - let rhs = rhs.eval(ctx, conc).await?; - - try_f64_operation(&lhs, &rhs, ops::Div::div) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Div::div)) - .or_else(|| try_i64_operation(&lhs, &rhs, ops::Div::div)) - .ok_or(EvaluationError::ExprEvalError("divide".into()))? 
- } - Math::Inc(val) => { - let val = val.eval(ctx, conc).await?; - - val - .as_f64_ok() - .ok() - .map(|val| (val + 1f64).into()) - .or_else(|| val.as_u64_ok().ok().map(|val| (val + 1u64).into())) - .or_else(|| val.as_i64_ok().ok().map(|val| (val + 1i64).into())) - .ok_or(EvaluationError::ExprEvalError("dec".into()))? - } - Math::Multiply(lhs, rhs) => { - let lhs = lhs.eval(ctx, conc).await?; - let rhs = rhs.eval(ctx, conc).await?; - - try_f64_operation(&lhs, &rhs, ops::Mul::mul) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Mul::mul)) - .or_else(|| try_i64_operation(&lhs, &rhs, ops::Mul::mul)) - .ok_or(EvaluationError::ExprEvalError("multiply".into()))? - } - Math::Negate(val) => { - let val = val.eval(ctx, conc).await?; - - val - .as_f64_ok() - .ok() - .map(|val| (-val).into()) - .or_else(|| val.as_i64_ok().ok().map(|val| (-val).into())) - .ok_or(EvaluationError::ExprEvalError("neg".into()))? - } - Math::Product(exprs) => { - let results: Vec<_> = exprs.eval(ctx, conc).await?; - - results.into_iter().try_fold(1i64.into(), |lhs, rhs| { - try_f64_operation(&lhs, &rhs, ops::Mul::mul) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Mul::mul)) - .or_else(|| try_i64_operation(&lhs, &rhs, ops::Mul::mul)) - .ok_or(EvaluationError::ExprEvalError("product".into())) - })? - } - Math::Subtract(lhs, rhs) => { - let lhs = lhs.eval(ctx, conc).await?; - let rhs = rhs.eval(ctx, conc).await?; - - try_f64_operation(&lhs, &rhs, ops::Sub::sub) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Sub::sub)) - .or_else(|| try_i64_operation(&lhs, &rhs, ops::Sub::sub)) - .ok_or(EvaluationError::ExprEvalError("subtract".into()))? - } - Math::Sum(exprs) => { - let results: Vec<_> = exprs.eval(ctx, conc).await?; - - results.into_iter().try_fold(0i64.into(), |lhs, rhs| { - try_f64_operation(&lhs, &rhs, ops::Add::add) - .or_else(|| try_u64_operation(&lhs, &rhs, ops::Add::add)) - .or_else(|| try_i64_operation(&lhs, &rhs, ops::Add::add)) - .ok_or(EvaluationError::ExprEvalError("sum".into())) - })? - } - }) - }) - } + fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( + &'a self, + ctx: &'a EvaluationContext<'a, Ctx>, + conc: &'a Concurrent, + ) -> Pin> + 'a + Send>> { + Box::pin(async move { + Ok(match self { + Math::Mod(lhs, rhs) => { + let lhs = lhs.eval(ctx, conc).await?; + let rhs = rhs.eval(ctx, conc).await?; + + try_i64_operation(&lhs, &rhs, ops::Rem::rem) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Rem::rem)) + .ok_or(EvaluationError::ExprEvalError("mod".into()))? + } + Math::Add(lhs, rhs) => { + let lhs = lhs.eval(ctx, conc).await?; + let rhs = rhs.eval(ctx, conc).await?; + + try_f64_operation(&lhs, &rhs, ops::Add::add) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Add::add)) + .or_else(|| try_i64_operation(&lhs, &rhs, ops::Add::add)) + .ok_or(EvaluationError::ExprEvalError("add".into()))? + } + Math::Dec(val) => { + let val = val.eval(ctx, conc).await?; + + val.as_f64_ok() + .ok() + .map(|val| (val - 1f64).into()) + .or_else(|| val.as_u64_ok().ok().map(|val| (val - 1u64).into())) + .or_else(|| val.as_i64_ok().ok().map(|val| (val - 1i64).into())) + .ok_or(EvaluationError::ExprEvalError("dec".into()))? + } + Math::Divide(lhs, rhs) => { + let lhs = lhs.eval(ctx, conc).await?; + let rhs = rhs.eval(ctx, conc).await?; + + try_f64_operation(&lhs, &rhs, ops::Div::div) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Div::div)) + .or_else(|| try_i64_operation(&lhs, &rhs, ops::Div::div)) + .ok_or(EvaluationError::ExprEvalError("divide".into()))? 
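// The binary numeric ops try f64 first, then u64, then i64, taking the first representation both operands support (Mod skips f64).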
+ } + Math::Inc(val) => { + let val = val.eval(ctx, conc).await?; + + val.as_f64_ok() + .ok() + .map(|val| (val + 1f64).into()) + .or_else(|| val.as_u64_ok().ok().map(|val| (val + 1u64).into())) + .or_else(|| val.as_i64_ok().ok().map(|val| (val + 1i64).into())) + .ok_or(EvaluationError::ExprEvalError("dec".into()))? + } + Math::Multiply(lhs, rhs) => { + let lhs = lhs.eval(ctx, conc).await?; + let rhs = rhs.eval(ctx, conc).await?; + + try_f64_operation(&lhs, &rhs, ops::Mul::mul) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Mul::mul)) + .or_else(|| try_i64_operation(&lhs, &rhs, ops::Mul::mul)) + .ok_or(EvaluationError::ExprEvalError("multiply".into()))? + } + Math::Negate(val) => { + let val = val.eval(ctx, conc).await?; + + val.as_f64_ok() + .ok() + .map(|val| (-val).into()) + .or_else(|| val.as_i64_ok().ok().map(|val| (-val).into())) + .ok_or(EvaluationError::ExprEvalError("neg".into()))? + } + Math::Product(exprs) => { + let results: Vec<_> = exprs.eval(ctx, conc).await?; + + results.into_iter().try_fold(1i64.into(), |lhs, rhs| { + try_f64_operation(&lhs, &rhs, ops::Mul::mul) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Mul::mul)) + .or_else(|| try_i64_operation(&lhs, &rhs, ops::Mul::mul)) + .ok_or(EvaluationError::ExprEvalError("product".into())) + })? + } + Math::Subtract(lhs, rhs) => { + let lhs = lhs.eval(ctx, conc).await?; + let rhs = rhs.eval(ctx, conc).await?; + + try_f64_operation(&lhs, &rhs, ops::Sub::sub) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Sub::sub)) + .or_else(|| try_i64_operation(&lhs, &rhs, ops::Sub::sub)) + .ok_or(EvaluationError::ExprEvalError("subtract".into()))? + } + Math::Sum(exprs) => { + let results: Vec<_> = exprs.eval(ctx, conc).await?; + + results.into_iter().try_fold(0i64.into(), |lhs, rhs| { + try_f64_operation(&lhs, &rhs, ops::Add::add) + .or_else(|| try_u64_operation(&lhs, &rhs, ops::Add::add)) + .or_else(|| try_i64_operation(&lhs, &rhs, ops::Add::add)) + .ok_or(EvaluationError::ExprEvalError("sum".into())) + })? 
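// Sum folds from the identity 0 and Product from 1, with the same f64/u64/i64 promotion as the binary operators.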
+ } + }) + }) + } } fn try_f64_operation(lhs: &ConstValue, rhs: &ConstValue, f: F) -> Option where - F: Fn(f64, f64) -> f64, + F: Fn(f64, f64) -> f64, { - match (lhs, rhs) { - (ConstValue::Number(lhs), ConstValue::Number(rhs)) => { - lhs.as_f64().and_then(|lhs| rhs.as_f64().map(|rhs| f(lhs, rhs).into())) + match (lhs, rhs) { + (ConstValue::Number(lhs), ConstValue::Number(rhs)) => lhs + .as_f64() + .and_then(|lhs| rhs.as_f64().map(|rhs| f(lhs, rhs).into())), + _ => None, } - _ => None, - } } fn try_i64_operation(lhs: &ConstValue, rhs: &ConstValue, f: F) -> Option where - F: Fn(i64, i64) -> i64, + F: Fn(i64, i64) -> i64, { - match (lhs, rhs) { - (ConstValue::Number(lhs), ConstValue::Number(rhs)) => { - lhs.as_i64().and_then(|lhs| rhs.as_i64().map(|rhs| f(lhs, rhs).into())) + match (lhs, rhs) { + (ConstValue::Number(lhs), ConstValue::Number(rhs)) => lhs + .as_i64() + .and_then(|lhs| rhs.as_i64().map(|rhs| f(lhs, rhs).into())), + _ => None, } - _ => None, - } } fn try_u64_operation(lhs: &ConstValue, rhs: &ConstValue, f: F) -> Option where - F: Fn(u64, u64) -> u64, + F: Fn(u64, u64) -> u64, { - match (lhs, rhs) { - (ConstValue::Number(lhs), ConstValue::Number(rhs)) => { - lhs.as_u64().and_then(|lhs| rhs.as_u64().map(|rhs| f(lhs, rhs).into())) + match (lhs, rhs) { + (ConstValue::Number(lhs), ConstValue::Number(rhs)) => lhs + .as_u64() + .and_then(|lhs| rhs.as_u64().map(|rhs| f(lhs, rhs).into())), + _ => None, } - _ => None, - } } diff --git a/src/lambda/relation.rs b/src/lambda/relation.rs index 233e0d8738a..a90af9e0aaf 100644 --- a/src/lambda/relation.rs +++ b/src/lambda/relation.rs @@ -7,294 +7,323 @@ use anyhow::Result; use async_graphql_value::ConstValue; use futures_util::future::join_all; -use super::{Concurrent, Eval, EvaluationContext, EvaluationError, Expression, ResolverContextLike}; +use super::{ + Concurrent, Eval, EvaluationContext, EvaluationError, Expression, ResolverContextLike, +}; use crate::helpers::value::HashableConstValue; #[derive(Clone, Debug)] pub enum Relation { - Intersection(Vec), - Difference(Vec, Vec), - Equals(Box, Box), - Gt(Box, Box), - Gte(Box, Box), - Lt(Box, Box), - Lte(Box, Box), - Max(Vec), - Min(Vec), - PathEq(Box, Vec, Box), - PropEq(Box, String, Box), - SortPath(Box, Vec), - SymmetricDifference(Vec, Vec), - Union(Vec, Vec), + Intersection(Vec), + Difference(Vec, Vec), + Equals(Box, Box), + Gt(Box, Box), + Gte(Box, Box), + Lt(Box, Box), + Lte(Box, Box), + Max(Vec), + Min(Vec), + PathEq(Box, Vec, Box), + PropEq(Box, String, Box), + SortPath(Box, Vec), + SymmetricDifference(Vec, Vec), + Union(Vec, Vec), } impl Eval for Relation { - fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>( - &'a self, - ctx: &'a EvaluationContext<'a, Ctx>, - conc: &'a Concurrent, - ) -> Pin> + 'a + Send>> { - Box::pin(async move { - Ok(match self { - Relation::Intersection(exprs) => { - let results = join_all(exprs.iter().map(|expr| expr.eval(ctx, conc))).await; - - let mut results_iter = results.into_iter(); - - let set: HashSet<_> = match results_iter.next() { - Some(first) => match first? { - ConstValue::List(list) => list.into_iter().map(HashableConstValue).collect(), - _ => Err(EvaluationError::ExprEvalError("element is not a list".into()))?, - }, - None => Err(EvaluationError::ExprEvalError("element is not a list".into()))?, - }; - - let final_set = results_iter.try_fold(set, |mut acc, result| match result? 
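// Intersection seeds a HashSet from the first evaluated list, then intersects it with each remaining list; any non-list operand is an error.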
{
-            ConstValue::List(list) => {
-              let set: HashSet<_> = list.into_iter().map(HashableConstValue).collect();
-              acc = acc.intersection(&set).cloned().collect();
-              Ok::<_, anyhow::Error>(acc)
-            }
-            _ => Err(EvaluationError::ExprEvalError("element is not a list".into()))?,
-          })?;
-
-          final_set
-            .into_iter()
-            .map(|HashableConstValue(const_value)| const_value)
-            .collect()
-        }
-        Relation::Difference(lhs, rhs) => {
-          set_operation(ctx, conc, lhs, rhs, |lhs, rhs| {
-            lhs
-              .difference(&rhs)
-              .cloned()
-              .map(|HashableConstValue(const_value)| const_value)
-              .collect()
-          })
-          .await?
-        }
-        Relation::Equals(lhs, rhs) => (lhs.eval(ctx, conc).await? == rhs.eval(ctx, conc).await?).into(),
-        Relation::Gt(lhs, rhs) => {
-          let lhs = lhs.eval(ctx, conc).await?;
-          let rhs = rhs.eval(ctx, conc).await?;
-
-          (compare(&lhs, &rhs) == Some(Ordering::Greater)).into()
-        }
-        Relation::Gte(lhs, rhs) => {
-          let lhs = lhs.eval(ctx, conc).await?;
-          let rhs = rhs.eval(ctx, conc).await?;
-
-          matches!(compare(&lhs, &rhs), Some(Ordering::Greater) | Some(Ordering::Equal)).into()
-        }
-        Relation::Lt(lhs, rhs) => {
-          let lhs = lhs.eval(ctx, conc).await?;
-          let rhs = rhs.eval(ctx, conc).await?;
-
-          (compare(&lhs, &rhs) == Some(Ordering::Less)).into()
-        }
-        Relation::Lte(lhs, rhs) => {
-          let lhs = lhs.eval(ctx, conc).await?;
-          let rhs = rhs.eval(ctx, conc).await?;
-
-          matches!(compare(&lhs, &rhs), Some(Ordering::Less) | Some(Ordering::Equal)).into()
-        }
-        Relation::Max(exprs) => {
-          let mut results: Vec<_> = exprs.eval(ctx, conc).await?;
-
-          let last = results.pop().ok_or(EvaluationError::ExprEvalError(
-            "`max` cannot be called on empty list".into(),
-          ))?;
-
-          results.into_iter().try_fold(last, |mut largest, current| {
-            let ord = compare(&largest, &current);
-            largest = match ord {
-              Some(Ordering::Greater | Ordering::Equal) => largest,
-              Some(Ordering::Less) => current,
-              _ => Err(anyhow::anyhow!(
-                "`max` cannot be calculated for types that cannot be compared"
-              ))?,
-            };
-            Ok::<_, anyhow::Error>(largest)
-          })?
-        }
-        Relation::Min(exprs) => {
-          let mut results: Vec<_> = exprs.eval(ctx, conc).await?;
-
-          let last = results.pop().ok_or(EvaluationError::ExprEvalError(
-            "`min` cannot be called on empty list".into(),
-          ))?;
-
-          results.into_iter().try_fold(last, |mut largest, current| {
-            let ord = compare(&largest, &current);
-            largest = match ord {
-              Some(Ordering::Less | Ordering::Equal) => largest,
-              Some(Ordering::Greater) => current,
-              _ => Err(anyhow::anyhow!(
-                "`min` cannot be calculated for types that cannot be compared"
-              ))?,
-            };
-            Ok::<_, anyhow::Error>(largest)
-          })?
-        }
-        Relation::PathEq(lhs, path, rhs) => {
-          let lhs = lhs.eval(ctx, conc).await?;
-          let lhs =
-            get_path_for_const_value_owned(path, lhs).ok_or(anyhow::anyhow!("Could not find path: {path:?}"))?;
-
-          let rhs = rhs.eval(ctx, conc).await?;
-          let rhs =
-            get_path_for_const_value_owned(path, rhs).ok_or(anyhow::anyhow!("Could not find path: {path:?}"))?;
-
-          (lhs == rhs).into()
-        }
-        Relation::PropEq(lhs, prop, rhs) => {
-          let lhs = lhs.eval(ctx, conc).await?;
-          let lhs =
-            get_path_for_const_value_owned(&[prop], lhs).ok_or(anyhow::anyhow!("Could not find path: {prop:?}"))?;
-
-          let rhs = rhs.eval(ctx, conc).await?;
-          let rhs =
-            get_path_for_const_value_owned(&[prop], rhs).ok_or(anyhow::anyhow!("Could not find path: {prop:?}"))?;
-
-          (lhs == rhs).into()
-        }
-        Relation::SortPath(expr, path) => {
-          let value = expr.eval(ctx, conc).await?;
-          let values = match value {
-            ConstValue::List(list) => list,
-            _ => Err(EvaluationError::ExprEvalError(
-              "`sortPath` can only be applied to expressions that return list".into(),
-            ))?,
-          };
-
-          let is_comparable = is_list_comparable(&values);
-          let mut values: Vec<_> = values.into_iter().enumerate().collect();
-
-          if !is_comparable {
-            Err(anyhow::anyhow!("sortPath requires a list of comparable types"))?
-          }
-
-          let value_paths: Vec<_> = values
-            .iter()
-            .filter_map(|(_, val)| get_path_for_const_value_ref(path, val))
-            .cloned()
-            .collect();
-
-          if values.len() != value_paths.len() {
-            Err(anyhow::anyhow!(
-              "path is not valid for all the element in the list: {value_paths:?}"
-            ))?
-          }
-
-          values.sort_by(|(index1, _), (index2, _)| compare(&value_paths[*index1], &value_paths[*index2]).unwrap());
-
-          values.into_iter().map(|(_, val)| val).collect::<Vec<_>>().into()
-        }
-        Relation::SymmetricDifference(lhs, rhs) => {
-          set_operation(ctx, conc, lhs, rhs, |lhs, rhs| {
-            lhs
-              .symmetric_difference(&rhs)
-              .cloned()
-              .map(|HashableConstValue(const_value)| const_value)
-              .collect()
-          })
-          .await?
-        }
-        Relation::Union(lhs, rhs) => {
-          set_operation(ctx, conc, lhs, rhs, |lhs, rhs| {
-            lhs
-              .union(&rhs)
-              .cloned()
-              .map(|HashableConstValue(const_value)| const_value)
-              .collect()
-          })
-          .await?
-        }
-      })
-    })
-  }
+    fn eval<'a, Ctx: ResolverContextLike<'a> + Sync + Send>(
+        &'a self,
+        ctx: &'a EvaluationContext<'a, Ctx>,
+        conc: &'a Concurrent,
+    ) -> Pin<Box<dyn Future<Output = Result<ConstValue>> + 'a + Send>> {
+        Box::pin(async move {
+            Ok(match self {
+                Relation::Intersection(exprs) => {
+                    let results = join_all(exprs.iter().map(|expr| expr.eval(ctx, conc))).await;
+
+                    let mut results_iter = results.into_iter();
+
+                    let set: HashSet<_> = match results_iter.next() {
+                        Some(first) => match first? {
+                            ConstValue::List(list) => {
+                                list.into_iter().map(HashableConstValue).collect()
+                            }
+                            _ => Err(EvaluationError::ExprEvalError(
+                                "element is not a list".into(),
+                            ))?,
+                        },
+                        None => Err(EvaluationError::ExprEvalError(
+                            "element is not a list".into(),
+                        ))?,
+                    };
+
+                    let final_set =
+                        results_iter.try_fold(set, |mut acc, result| match result?
{
+                            ConstValue::List(list) => {
+                                let set: HashSet<_> =
+                                    list.into_iter().map(HashableConstValue).collect();
+                                acc = acc.intersection(&set).cloned().collect();
+                                Ok::<_, anyhow::Error>(acc)
+                            }
+                            _ => Err(EvaluationError::ExprEvalError(
+                                "element is not a list".into(),
+                            ))?,
+                        })?;
+
+                    final_set
+                        .into_iter()
+                        .map(|HashableConstValue(const_value)| const_value)
+                        .collect()
+                }
+                Relation::Difference(lhs, rhs) => {
+                    set_operation(ctx, conc, lhs, rhs, |lhs, rhs| {
+                        lhs.difference(&rhs)
+                            .cloned()
+                            .map(|HashableConstValue(const_value)| const_value)
+                            .collect()
+                    })
+                    .await?
+                }
+                Relation::Equals(lhs, rhs) => {
+                    (lhs.eval(ctx, conc).await? == rhs.eval(ctx, conc).await?).into()
+                }
+                Relation::Gt(lhs, rhs) => {
+                    let lhs = lhs.eval(ctx, conc).await?;
+                    let rhs = rhs.eval(ctx, conc).await?;
+
+                    (compare(&lhs, &rhs) == Some(Ordering::Greater)).into()
+                }
+                Relation::Gte(lhs, rhs) => {
+                    let lhs = lhs.eval(ctx, conc).await?;
+                    let rhs = rhs.eval(ctx, conc).await?;
+
+                    matches!(
+                        compare(&lhs, &rhs),
+                        Some(Ordering::Greater) | Some(Ordering::Equal)
+                    )
+                    .into()
+                }
+                Relation::Lt(lhs, rhs) => {
+                    let lhs = lhs.eval(ctx, conc).await?;
+                    let rhs = rhs.eval(ctx, conc).await?;
+
+                    (compare(&lhs, &rhs) == Some(Ordering::Less)).into()
+                }
+                Relation::Lte(lhs, rhs) => {
+                    let lhs = lhs.eval(ctx, conc).await?;
+                    let rhs = rhs.eval(ctx, conc).await?;
+
+                    matches!(
+                        compare(&lhs, &rhs),
+                        Some(Ordering::Less) | Some(Ordering::Equal)
+                    )
+                    .into()
+                }
+                Relation::Max(exprs) => {
+                    let mut results: Vec<_> = exprs.eval(ctx, conc).await?;
+
+                    let last = results.pop().ok_or(EvaluationError::ExprEvalError(
+                        "`max` cannot be called on empty list".into(),
+                    ))?;
+
+                    results.into_iter().try_fold(last, |mut largest, current| {
+                        let ord = compare(&largest, &current);
+                        largest = match ord {
+                            Some(Ordering::Greater | Ordering::Equal) => largest,
+                            Some(Ordering::Less) => current,
+                            _ => Err(anyhow::anyhow!(
+                                "`max` cannot be calculated for types that cannot be compared"
+                            ))?,
+                        };
+                        Ok::<_, anyhow::Error>(largest)
+                    })?
+                }
+                Relation::Min(exprs) => {
+                    let mut results: Vec<_> = exprs.eval(ctx, conc).await?;
+
+                    let last = results.pop().ok_or(EvaluationError::ExprEvalError(
+                        "`min` cannot be called on empty list".into(),
+                    ))?;
+
+                    results.into_iter().try_fold(last, |mut largest, current| {
+                        let ord = compare(&largest, &current);
+                        largest = match ord {
+                            Some(Ordering::Less | Ordering::Equal) => largest,
+                            Some(Ordering::Greater) => current,
+                            _ => Err(anyhow::anyhow!(
+                                "`min` cannot be calculated for types that cannot be compared"
+                            ))?,
+                        };
+                        Ok::<_, anyhow::Error>(largest)
+                    })?
+                }
+                Relation::PathEq(lhs, path, rhs) => {
+                    let lhs = lhs.eval(ctx, conc).await?;
+                    let lhs = get_path_for_const_value_owned(path, lhs)
+                        .ok_or(anyhow::anyhow!("Could not find path: {path:?}"))?;
+
+                    let rhs = rhs.eval(ctx, conc).await?;
+                    let rhs = get_path_for_const_value_owned(path, rhs)
+                        .ok_or(anyhow::anyhow!("Could not find path: {path:?}"))?;
+
+                    (lhs == rhs).into()
+                }
+                Relation::PropEq(lhs, prop, rhs) => {
+                    let lhs = lhs.eval(ctx, conc).await?;
+                    let lhs = get_path_for_const_value_owned(&[prop], lhs)
+                        .ok_or(anyhow::anyhow!("Could not find path: {prop:?}"))?;
+
+                    let rhs = rhs.eval(ctx, conc).await?;
+                    let rhs = get_path_for_const_value_owned(&[prop], rhs)
+                        .ok_or(anyhow::anyhow!("Could not find path: {prop:?}"))?;
+
+                    (lhs == rhs).into()
+                }
+                Relation::SortPath(expr, path) => {
+                    let value = expr.eval(ctx, conc).await?;
+                    let values = match value {
+                        ConstValue::List(list) => list,
+                        _ => Err(EvaluationError::ExprEvalError(
+                            "`sortPath` can only be applied to expressions that return a list".into(),
+                        ))?,
+                    };
+
+                    let is_comparable = is_list_comparable(&values);
+                    let mut values: Vec<_> = values.into_iter().enumerate().collect();
+
+                    if !is_comparable {
+                        Err(anyhow::anyhow!(
+                            "sortPath requires a list of comparable types"
+                        ))?
+                    }
+
+                    let value_paths: Vec<_> = values
+                        .iter()
+                        .filter_map(|(_, val)| get_path_for_const_value_ref(path, val))
+                        .cloned()
+                        .collect();
+
+                    if values.len() != value_paths.len() {
+                        Err(anyhow::anyhow!(
+                            "path is not valid for all the elements in the list: {value_paths:?}"
+                        ))?
+                    }
+
+                    values.sort_by(|(index1, _), (index2, _)| {
+                        compare(&value_paths[*index1], &value_paths[*index2]).unwrap()
+                    });
+
+                    values
+                        .into_iter()
+                        .map(|(_, val)| val)
+                        .collect::<Vec<_>>()
+                        .into()
+                }
+                Relation::SymmetricDifference(lhs, rhs) => {
+                    set_operation(ctx, conc, lhs, rhs, |lhs, rhs| {
+                        lhs.symmetric_difference(&rhs)
+                            .cloned()
+                            .map(|HashableConstValue(const_value)| const_value)
+                            .collect()
+                    })
+                    .await?
+                }
+                Relation::Union(lhs, rhs) => {
+                    set_operation(ctx, conc, lhs, rhs, |lhs, rhs| {
+                        lhs.union(&rhs)
+                            .cloned()
+                            .map(|HashableConstValue(const_value)| const_value)
+                            .collect()
+                    })
+                    .await?
+                }
+            })
+        })
+    }
 }

 fn is_list_comparable(list: &[ConstValue]) -> bool {
-  list
-    .iter()
-    .zip(list.iter().skip(1))
-    .all(|(lhs, rhs)| is_pair_comparable(lhs, rhs))
+    list.iter()
+        .zip(list.iter().skip(1))
+        .all(|(lhs, rhs)| is_pair_comparable(lhs, rhs))
 }

 fn compare(lhs: &ConstValue, rhs: &ConstValue) -> Option<Ordering> {
-  Some(match (lhs, rhs) {
-    (ConstValue::Null, ConstValue::Null) => Ordering::Equal,
-    (ConstValue::Boolean(lhs), ConstValue::Boolean(rhs)) => lhs.partial_cmp(rhs)?,
-    (ConstValue::Enum(lhs), ConstValue::Enum(rhs)) => lhs.partial_cmp(rhs)?,
-    (ConstValue::Number(lhs), ConstValue::Number(rhs)) => lhs
-      .as_f64()
-      .partial_cmp(&rhs.as_f64())
-      .or(lhs.as_i64().partial_cmp(&rhs.as_i64()))
-      .or(lhs.as_u64().partial_cmp(&rhs.as_u64()))?,
-    (ConstValue::Binary(lhs), ConstValue::Binary(rhs)) => lhs.partial_cmp(rhs)?,
-    (ConstValue::String(lhs), ConstValue::String(rhs)) => lhs.partial_cmp(rhs)?,
-    (ConstValue::List(lhs), ConstValue::List(rhs)) => lhs
-      .iter()
-      .zip(rhs.iter())
-      .find_map(|(lhs, rhs)| compare(lhs, rhs).filter(|ord| ord != &Ordering::Equal))
-      .unwrap_or(lhs.len().partial_cmp(&rhs.len())?),
-    _ => None?,
-  })
+    Some(match (lhs, rhs) {
+        (ConstValue::Null, ConstValue::Null) => Ordering::Equal,
+        (ConstValue::Boolean(lhs), ConstValue::Boolean(rhs)) => lhs.partial_cmp(rhs)?,
+        (ConstValue::Enum(lhs), ConstValue::Enum(rhs)) => lhs.partial_cmp(rhs)?,
+        (ConstValue::Number(lhs), ConstValue::Number(rhs)) => lhs
+            .as_f64()
+            .partial_cmp(&rhs.as_f64())
+            .or(lhs.as_i64().partial_cmp(&rhs.as_i64()))
+            .or(lhs.as_u64().partial_cmp(&rhs.as_u64()))?,
+        (ConstValue::Binary(lhs), ConstValue::Binary(rhs)) => lhs.partial_cmp(rhs)?,
+        (ConstValue::String(lhs), ConstValue::String(rhs)) => lhs.partial_cmp(rhs)?,
+        (ConstValue::List(lhs), ConstValue::List(rhs)) => lhs
+            .iter()
+            .zip(rhs.iter())
+            .find_map(|(lhs, rhs)| compare(lhs, rhs).filter(|ord| ord != &Ordering::Equal))
+            .unwrap_or(lhs.len().partial_cmp(&rhs.len())?),
+        _ => None?,
+    })
 }

 fn is_pair_comparable(lhs: &ConstValue, rhs: &ConstValue) -> bool {
-  matches!(
-    (lhs, rhs),
-    (ConstValue::Null, ConstValue::Null)
-      | (ConstValue::Boolean(_), ConstValue::Boolean(_))
-      | (ConstValue::Enum(_), ConstValue::Enum(_))
-      | (ConstValue::Number(_), ConstValue::Number(_))
-      | (ConstValue::Binary(_), ConstValue::Binary(_))
-      | (ConstValue::String(_), ConstValue::String(_))
-      | (ConstValue::List(_), ConstValue::List(_))
-  )
+    matches!(
+        (lhs, rhs),
+        (ConstValue::Null, ConstValue::Null)
+            | (ConstValue::Boolean(_), ConstValue::Boolean(_))
+            | (ConstValue::Enum(_), ConstValue::Enum(_))
+            | (ConstValue::Number(_), ConstValue::Number(_))
+            | (ConstValue::Binary(_), ConstValue::Binary(_))
+            | (ConstValue::String(_), ConstValue::String(_))
+            | (ConstValue::List(_), ConstValue::List(_))
+    )
 }

 #[allow(clippy::too_many_arguments)]
 async fn set_operation<'a, 'b, Ctx: ResolverContextLike<'a> + Sync + Send, F>(
-  ctx: &'a EvaluationContext<'a, Ctx>,
-  conc: &'a Concurrent,
-  lhs: &'a [Expression],
-  rhs: &'a [Expression],
-  operation: F,
+    ctx: &'a EvaluationContext<'a, Ctx>,
+    conc: &'a Concurrent,
+    lhs: &'a [Expression],
+    rhs: &'a [Expression],
+    operation: F,
 ) -> Result<ConstValue>
 where
-  F: Fn(HashSet<HashableConstValue>, HashSet<HashableConstValue>) -> Vec<ConstValue>,
+    F: Fn(HashSet<HashableConstValue>, HashSet<HashableConstValue>) -> Vec<ConstValue>,
 {
-  let (lhs, rhs) = futures_util::join!(
-    conc.foreach(lhs.iter().map(|e| e.eval(ctx, conc)), HashableConstValue),
-    conc.foreach(rhs.iter().map(|e| e.eval(ctx, conc)), HashableConstValue)
-  );
-  Ok(operation(HashSet::from_iter(lhs?), HashSet::from_iter(rhs?)).into())
+    let (lhs, rhs) = 
futures_util::join!(
+        conc.foreach(lhs.iter().map(|e| e.eval(ctx, conc)), HashableConstValue),
+        conc.foreach(rhs.iter().map(|e| e.eval(ctx, conc)), HashableConstValue)
+    );
+    Ok(operation(HashSet::from_iter(lhs?), HashSet::from_iter(rhs?)).into())
 }

 fn get_path_for_const_value_ref<'a>(
-  path: &[impl AsRef<str>],
-  mut const_value: &'a ConstValue,
+    path: &[impl AsRef<str>],
+    mut const_value: &'a ConstValue,
 ) -> Option<&'a ConstValue> {
-  for path in path.iter() {
-    const_value = match const_value {
-      ConstValue::Object(ref obj) => obj.get(path.as_ref())?,
-      _ => None?,
+    for path in path.iter() {
+        const_value = match const_value {
+            ConstValue::Object(ref obj) => obj.get(path.as_ref())?,
+            _ => None?,
+        }
    }
-  }
-  Some(const_value)
+    Some(const_value)
 }

-fn get_path_for_const_value_owned(path: &[impl AsRef<str>], mut const_value: ConstValue) -> Option<ConstValue> {
-  for path in path.iter() {
-    const_value = match const_value {
-      ConstValue::Object(mut obj) => obj.remove(path.as_ref())?,
-      _ => None?,
+fn get_path_for_const_value_owned(
+    path: &[impl AsRef<str>],
+    mut const_value: ConstValue,
+) -> Option<ConstValue> {
+    for path in path.iter() {
+        const_value = match const_value {
+            ConstValue::Object(mut obj) => obj.remove(path.as_ref())?,
+            _ => None?,
+        }
    }
-  }
-  Some(const_value)
+    Some(const_value)
 }

diff --git a/src/lambda/resolver_context_like.rs b/src/lambda/resolver_context_like.rs
index 4f978d81a13..5b4ff678114 100644
--- a/src/lambda/resolver_context_like.rs
+++ b/src/lambda/resolver_context_like.rs
@@ -4,44 +4,44 @@ use async_graphql::{Name, ServerError, Value};
 use indexmap::IndexMap;

 pub trait ResolverContextLike<'a> {
-  fn value(&'a self) -> Option<&'a Value>;
-  fn args(&'a self) -> Option<&'a IndexMap<Name, Value>>;
-  fn field(&'a self) -> Option<SelectionField>;
-  fn add_error(&'a self, error: ServerError);
+    fn value(&'a self) -> Option<&'a Value>;
+    fn args(&'a self) -> Option<&'a IndexMap<Name, Value>>;
+    fn field(&'a self) -> Option<SelectionField>;
+    fn add_error(&'a self, error: ServerError);
 }

 pub struct EmptyResolverContext;

 impl<'a> ResolverContextLike<'a> for EmptyResolverContext {
-  fn value(&'a self) -> Option<&'a Value> {
-    None
-  }
+    fn value(&'a self) -> Option<&'a Value> {
+        None
+    }

-  fn args(&'a self) -> Option<&'a IndexMap<Name, Value>> {
-    None
-  }
+    fn args(&'a self) -> Option<&'a IndexMap<Name, Value>> {
+        None
+    }

-  fn field(&'a self) -> Option<SelectionField> {
-    None
-  }
+    fn field(&'a self) -> Option<SelectionField> {
+        None
+    }

-  fn add_error(&'a self, _: ServerError) {}
+    fn add_error(&'a self, _: ServerError) {}
 }

 impl<'a> ResolverContextLike<'a> for ResolverContext<'a> {
-  fn value(&'a self) -> Option<&'a Value> {
-    self.parent_value.as_value()
-  }
+    fn value(&'a self) -> Option<&'a Value> {
+        self.parent_value.as_value()
+    }

-  fn args(&'a self) -> Option<&'a IndexMap<Name, Value>> {
-    Some(self.args.as_index_map())
-  }
+    fn args(&'a self) -> Option<&'a IndexMap<Name, Value>> {
+        Some(self.args.as_index_map())
+    }

-  fn field(&'a self) -> Option<SelectionField> {
-    Some(self.ctx.field())
-  }
+    fn field(&'a self) -> Option<SelectionField> {
+        Some(self.ctx.field())
+    }

-  fn add_error(&'a self, error: ServerError) {
-    self.ctx.add_error(error)
-  }
+    fn add_error(&'a self, error: ServerError) {
+        self.ctx.add_error(error)
+    }
 }

diff --git a/src/lib.rs b/src/lib.rs
index 64e4d82b340..8cdef8dd2a7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -36,44 +36,56 @@ use async_graphql_value::ConstValue;
 use http::Response;

 pub trait EnvIO: Send + Sync + 'static {
-  fn get(&self, key: &str) -> Option<String>;
+    fn get(&self, key: &str) -> Option<String>;
 }

 #[async_trait::async_trait]
 pub trait HttpIO: Sync + Send + 'static {
-  async fn execute(&self, request: reqwest::Request) -> 
anyhow::Result<Response<hyper::body::Bytes>>;
+    async fn execute(
+        &self,
+        request: reqwest::Request,
+    ) -> anyhow::Result<Response<hyper::body::Bytes>>;
 }

 impl HttpIO for Arc<dyn HttpIO> {
-  fn execute<'life0, 'async_trait>(
-    &'life0 self,
-    request: reqwest::Request,
-  ) -> core::pin::Pin<
-    Box<
-      dyn core::future::Future<Output = anyhow::Result<Response<hyper::body::Bytes>>>
-        + core::marker::Send
-        + 'async_trait,
-    >,
-  >
-  where
-    'life0: 'async_trait,
-    Self: 'async_trait,
-  {
-    self.deref().execute(request)
-  }
+    fn execute<'life0, 'async_trait>(
+        &'life0 self,
+        request: reqwest::Request,
+    ) -> core::pin::Pin<
+        Box<
+            dyn core::future::Future<Output = anyhow::Result<Response<hyper::body::Bytes>>>
+                + core::marker::Send
+                + 'async_trait,
+        >,
+    >
+    where
+        'life0: 'async_trait,
+        Self: 'async_trait,
+    {
+        self.deref().execute(request)
+    }
 }

 pub trait FileIO {
-  fn write<'a>(&'a self, file: &'a str, content: &'a [u8]) -> impl Future<Output = anyhow::Result<()>>;
-  fn read<'a>(&'a self, file_path: &'a str) -> impl Future<Output = anyhow::Result<String>>;
+    fn write<'a>(
+        &'a self,
+        file: &'a str,
+        content: &'a [u8],
+    ) -> impl Future<Output = anyhow::Result<()>>;
+    fn read<'a>(&'a self, file_path: &'a str) -> impl Future<Output = anyhow::Result<String>>;
 }

 #[async_trait::async_trait]
 pub trait Cache: Send + Sync {
-  type Key: Hash + Eq;
-  type Value;
-  async fn set<'a>(&'a self, key: Self::Key, value: Self::Value, ttl: NonZeroU64) -> anyhow::Result<Self::Value>;
-  async fn get<'a>(&'a self, key: &'a Self::Key) -> anyhow::Result<Self::Value>;
+    type Key: Hash + Eq;
+    type Value;
+    async fn set<'a>(
+        &'a self,
+        key: Self::Key,
+        value: Self::Value,
+        ttl: NonZeroU64,
+    ) -> anyhow::Result<Self::Value>;
+    async fn get<'a>(&'a self, key: &'a Self::Key) -> anyhow::Result<Self::Value>;
 }

 type EntityCache = dyn Cache<Key = u64, Value = ConstValue>;

diff --git a/src/main.rs b/src/main.rs
index d4b8afec9c1..eb14b121960 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,30 +8,30 @@ use tailcall::cli::CLIError;
 static GLOBAL: MiMalloc = MiMalloc;

 fn run_blocking() -> anyhow::Result<()> {
-  let rt = tokio::runtime::Runtime::new()?;
-  rt.block_on(async { tailcall::cli::run().await })
+    let rt = tokio::runtime::Runtime::new()?;
+    rt.block_on(async { tailcall::cli::run().await })
 }

 fn main() -> anyhow::Result<()> {
-  let result = run_blocking();
-  match result {
-    Ok(_) => {}
-    Err(error) => {
-      // Ensure all errors are converted to CLIErrors before being printed.
-      let cli_error = match error.downcast::<CLIError>() {
-        Ok(cli_error) => cli_error,
+    let result = run_blocking();
+    match result {
+        Ok(_) => {}
         Err(error) => {
-          let sources = error
-            .source()
-            .map(|error| vec![CLIError::new(error.to_string().as_str())])
-            .unwrap_or_default();
+            // Ensure all errors are converted to CLIErrors before being printed.
+            let cli_error = match error.downcast::<CLIError>() {
+                Ok(cli_error) => cli_error,
+                Err(error) => {
+                    let sources = error
+                        .source()
+                        .map(|error| vec![CLIError::new(error.to_string().as_str())])
+                        .unwrap_or_default();

-          CLIError::new(&error.to_string()).caused_by(sources)
+                    CLIError::new(&error.to_string()).caused_by(sources)
+                }
+            };
+            eprintln!("{}", cli_error.color(true));
+            std::process::exit(exitcode::CONFIG);
         }
-      };
-      eprintln!("{}", cli_error.color(true));
-      std::process::exit(exitcode::CONFIG);
     }
-  }
-  Ok(())
+    Ok(())
 }

diff --git a/src/mustache.rs b/src/mustache.rs
index c8d24dacb12..9aa82a8e0b3 100644
--- a/src/mustache.rs
+++ b/src/mustache.rs
@@ -7,397 +7,416 @@ pub struct Mustache(Vec<Segment>);

 #[derive(Debug, Clone, PartialEq, Hash)]
 pub enum Segment {
-  Literal(String),
-  Expression(Vec<String>),
+    Literal(String),
+    Expression(Vec<String>),
 }

 impl From<Vec<Segment>> for Mustache {
-  fn from(segments: Vec<Segment>) -> Self {
-    Mustache(segments)
-  }
+    fn from(segments: Vec<Segment>) -> Self {
+        Mustache(segments)
+    }
 }

 impl Mustache {
-  pub fn is_const(&self) -> bool {
-    match self {
-      Mustache(segments) => {
-        for s in segments {
-          if let Segment::Expression(_) = s {
-            return false;
-          }
-        }
-        true
-      }
-    }
-  }
+    pub fn is_const(&self) -> bool {
+        match self {
+            Mustache(segments) => {
+                for s in segments {
+                    if let Segment::Expression(_) = s {
+                        return false;
+                    }
+                }
+                true
+            }
+        }
+    }

-  // TODO: infallible function, no need to return Result
-  pub fn parse(str: &str) -> anyhow::Result<Mustache> {
-    let result = parse_mustache(str).finish();
-    match result {
-      Ok((_, mustache)) => Ok(mustache),
-      Err(_) => Ok(Mustache::from(vec![Segment::Literal(str.to_string())])),
-    }
-  }
+    // TODO: infallible function, no need to return Result
+    pub fn parse(str: &str) -> anyhow::Result<Mustache> {
+        let result = parse_mustache(str).finish();
+        match result {
+            Ok((_, mustache)) => Ok(mustache),
+            Err(_) => Ok(Mustache::from(vec![Segment::Literal(str.to_string())])),
+        }
+    }

-  pub fn render(&self, value: &impl PathString) -> String {
-    match self {
-      Mustache(segments) => segments
-        .iter()
-        .map(|segment| match segment {
-          Segment::Literal(text) => text.clone(),
-          Segment::Expression(parts) => value.path_string(parts).map(|a| a.to_string()).unwrap_or_default(),
-        })
-        .collect(),
-    }
-  }
+    pub fn render(&self, value: &impl PathString) -> String {
+        match self {
+            Mustache(segments) => segments
+                .iter()
+                .map(|segment| match segment {
+                    Segment::Literal(text) => text.clone(),
+                    Segment::Expression(parts) => value
+                        .path_string(parts)
+                        .map(|a| a.to_string())
+                        .unwrap_or_default(),
+                })
+                .collect(),
+        }
+    }

-  pub fn render_graphql(&self, value: &impl PathGraphql) -> String {
-    match self {
-      Mustache(segments) => segments
-        .iter()
-        .map(|segment| match segment {
-          Segment::Literal(text) => text.to_string(),
-          Segment::Expression(parts) => value.path_graphql(parts).unwrap_or_default(),
-        })
-        .collect(),
-    }
-  }
+    pub fn render_graphql(&self, value: &impl PathGraphql) -> String {
+        match self {
+            Mustache(segments) => segments
+                .iter()
+                .map(|segment| match segment {
+                    Segment::Literal(text) => text.to_string(),
+                    Segment::Expression(parts) => value.path_graphql(parts).unwrap_or_default(),
+                })
+                .collect(),
+        }
+    }

-  pub fn expression_segments(&self) -> Vec<&Vec<String>> {
-    match self {
-      Mustache(segments) => segments
-        .iter()
-        .filter_map(|seg| match seg {
-          Segment::Expression(parts) => Some(parts),
-          _ => None,
-        })
-        .collect(),
-    }
-  }
+    pub fn expression_segments(&self) -> Vec<&Vec<String>> {
+        match self {
+            Mustache(segments) => segments
+                .iter()
+                .filter_map(|seg| match seg {
Segment::Expression(parts) => Some(parts),
+                    _ => None,
+                })
+                .collect(),
+        }
+    }
 }

 fn parse_name(input: &str) -> IResult<&str, String> {
-  let spaces = nom::character::complete::multispace0;
-  let alpha = nom::character::complete::alpha1;
-  let alphanumeric_or_underscore = nom::multi::many0(nom::branch::alt((
-    nom::character::complete::alphanumeric1,
-    nom::bytes::complete::tag("_"),
-  )));
-
-  let parser = nom::sequence::tuple((spaces, alpha, alphanumeric_or_underscore, spaces));
-
-  nom::combinator::map(parser, |(_, a, b, _)| {
-    let b: String = b.into_iter().collect();
-    format!("{}{}", a, b)
-  })(input)
+    let spaces = nom::character::complete::multispace0;
+    let alpha = nom::character::complete::alpha1;
+    let alphanumeric_or_underscore = nom::multi::many0(nom::branch::alt((
+        nom::character::complete::alphanumeric1,
+        nom::bytes::complete::tag("_"),
+    )));
+
+    let parser = nom::sequence::tuple((spaces, alpha, alphanumeric_or_underscore, spaces));
+
+    nom::combinator::map(parser, |(_, a, b, _)| {
+        let b: String = b.into_iter().collect();
+        format!("{}{}", a, b)
+    })(input)
 }

 fn parse_expression(input: &str) -> IResult<&str, Vec<String>> {
-  nom::combinator::map(
-    nom::sequence::tuple((
-      nom::bytes::complete::tag("{{"),
-      nom::multi::separated_list1(nom::character::complete::char('.'), parse_name),
-      nom::bytes::complete::tag("}}"),
-    )),
-    |(_, vec, _)| vec,
-  )(input)
+    nom::combinator::map(
+        nom::sequence::tuple((
+            nom::bytes::complete::tag("{{"),
+            nom::multi::separated_list1(nom::character::complete::char('.'), parse_name),
+            nom::bytes::complete::tag("}}"),
+        )),
+        |(_, vec, _)| vec,
+    )(input)
 }

 fn parse_segment(input: &str) -> IResult<&str, Segment> {
-  let expression = nom::combinator::map(parse_expression, Segment::Expression);
-  let literal = nom::combinator::map(nom::bytes::complete::take_while1(|c| c != '{'), |r: &str| {
-    Segment::Literal(r.to_string())
-  });
+    let expression = nom::combinator::map(parse_expression, Segment::Expression);
+    let literal = nom::combinator::map(
+        nom::bytes::complete::take_while1(|c| c != '{'),
+        |r: &str| Segment::Literal(r.to_string()),
+    );

-  nom::branch::alt((expression, literal))(input)
+    nom::branch::alt((expression, literal))(input)
 }

 fn parse_mustache(input: &str) -> IResult<&str, Mustache> {
-  nom::combinator::map(nom::multi::many1(parse_segment), Mustache)(input)
+    nom::combinator::map(nom::multi::many1(parse_segment), Mustache)(input)
 }

 #[cfg(test)]
 mod tests {
-  mod parse {
-    use pretty_assertions::assert_eq;
-
-    use crate::mustache::{Mustache, Segment};
-
-    #[test]
-    fn test_single_literal() {
-      let s = r"hello/world";
-      let mustache: Mustache = Mustache::parse(s).unwrap();
-      assert_eq!(
-        mustache,
-        Mustache::from(vec![Segment::Literal("hello/world".to_string())])
-      );
-    }
-
-    #[test]
-    fn test_single_template() {
-      let s = r"{{hello.world}}";
-      let mustache: Mustache = Mustache::parse(s).unwrap();
-      assert_eq!(
-        mustache,
-        Mustache::from(vec![Segment::Expression(vec![
-          "hello".to_string(),
-          "world".to_string()
-        ])])
-      );
-    }
-
-    #[test]
-    fn test_mixed() {
-      let s = r"http://localhost:8090/{{foo.bar}}/api/{{hello.world}}/end";
-      let mustache: Mustache = Mustache::parse(s).unwrap();
-      assert_eq!(
-        mustache,
-        Mustache::from(vec![
-          Segment::Literal("http://localhost:8090/".to_string()),
-          Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
-          Segment::Literal("/api/".to_string()),
-          Segment::Expression(vec!["hello".to_string(), "world".to_string()]),
-          Segment::Literal("/end".to_string())
-        ])
-      );
-    }
-
-
#[test] - fn test_with_spaces() { - let s = "{{ foo . bar }}"; - let mustache: Mustache = Mustache::parse(s).unwrap(); - assert_eq!( - mustache, - Mustache::from(vec![Segment::Expression(vec!["foo".to_string(), "bar".to_string()])]) - ); - } + mod parse { + use pretty_assertions::assert_eq; + + use crate::mustache::{Mustache, Segment}; + + #[test] + fn test_single_literal() { + let s = r"hello/world"; + let mustache: Mustache = Mustache::parse(s).unwrap(); + assert_eq!( + mustache, + Mustache::from(vec![Segment::Literal("hello/world".to_string())]) + ); + } - #[test] - fn test_parse_expression_with_valid_input() { - let result = Mustache::parse("{{ foo.bar }} extra").unwrap(); - let expected = Mustache::from(vec![ - Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), - Segment::Literal(" extra".to_string()), - ]); - assert_eq!(result, expected); - } + #[test] + fn test_single_template() { + let s = r"{{hello.world}}"; + let mustache: Mustache = Mustache::parse(s).unwrap(); + assert_eq!( + mustache, + Mustache::from(vec![Segment::Expression(vec![ + "hello".to_string(), + "world".to_string() + ])]) + ); + } - #[test] - fn test_parse_expression_with_invalid_input() { - let result = Mustache::parse("foo.bar }}").unwrap(); - let expected = Mustache::from(vec![Segment::Literal("foo.bar }}".to_string())]); - assert_eq!(result, expected); - } + #[test] + fn test_mixed() { + let s = r"http://localhost:8090/{{foo.bar}}/api/{{hello.world}}/end"; + let mustache: Mustache = Mustache::parse(s).unwrap(); + assert_eq!( + mustache, + Mustache::from(vec![ + Segment::Literal("http://localhost:8090/".to_string()), + Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), + Segment::Literal("/api/".to_string()), + Segment::Expression(vec!["hello".to_string(), "world".to_string()]), + Segment::Literal("/end".to_string()) + ]) + ); + } - #[test] - fn test_parse_segments_mixed() { - let result = Mustache::parse("prefix {{foo.bar}} middle {{baz.qux}} suffix").unwrap(); - let expected = Mustache::from(vec![ - Segment::Literal("prefix ".to_string()), - Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), - Segment::Literal(" middle ".to_string()), - Segment::Expression(vec!["baz".to_string(), "qux".to_string()]), - Segment::Literal(" suffix".to_string()), - ]); - assert_eq!(result, expected); - } + #[test] + fn test_with_spaces() { + let s = "{{ foo . 
bar }}"; + let mustache: Mustache = Mustache::parse(s).unwrap(); + assert_eq!( + mustache, + Mustache::from(vec![Segment::Expression(vec![ + "foo".to_string(), + "bar".to_string() + ])]) + ); + } - #[test] - fn test_parse_segments_only_literal() { - let result = Mustache::parse("just a string").unwrap(); - let expected = Mustache(vec![Segment::Literal("just a string".to_string())]); - assert_eq!(result, expected); - } + #[test] + fn test_parse_expression_with_valid_input() { + let result = Mustache::parse("{{ foo.bar }} extra").unwrap(); + let expected = Mustache::from(vec![ + Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), + Segment::Literal(" extra".to_string()), + ]); + assert_eq!(result, expected); + } - #[test] - fn test_parse_segments_only_expression() { - let result = Mustache::parse("{{foo.bar}}").unwrap(); - let expected = Mustache(vec![Segment::Expression(vec!["foo".to_string(), "bar".to_string()])]); - assert_eq!(result, expected); - } + #[test] + fn test_parse_expression_with_invalid_input() { + let result = Mustache::parse("foo.bar }}").unwrap(); + let expected = Mustache::from(vec![Segment::Literal("foo.bar }}".to_string())]); + assert_eq!(result, expected); + } - #[test] - fn test_unfinished_expression() { - let s = r"{{hello.world"; - let mustache: Mustache = Mustache::parse(s).unwrap(); - assert_eq!( - mustache, - Mustache::from(vec![Segment::Literal("{{hello.world".to_string())]) - ); - } + #[test] + fn test_parse_segments_mixed() { + let result = Mustache::parse("prefix {{foo.bar}} middle {{baz.qux}} suffix").unwrap(); + let expected = Mustache::from(vec![ + Segment::Literal("prefix ".to_string()), + Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), + Segment::Literal(" middle ".to_string()), + Segment::Expression(vec!["baz".to_string(), "qux".to_string()]), + Segment::Literal(" suffix".to_string()), + ]); + assert_eq!(result, expected); + } - #[test] - fn test_new_number() { - let mustache = Mustache::parse("123").unwrap(); - assert_eq!(mustache, Mustache::from(vec![Segment::Literal("123".to_string())])); - } + #[test] + fn test_parse_segments_only_literal() { + let result = Mustache::parse("just a string").unwrap(); + let expected = Mustache(vec![Segment::Literal("just a string".to_string())]); + assert_eq!(result, expected); + } - #[test] - fn parse_env_name() { - let result = Mustache::parse("{{env.FOO}}").unwrap(); - assert_eq!( - result, - Mustache::from(vec![Segment::Expression(vec!["env".to_string(), "FOO".to_string()])]) - ); - } + #[test] + fn test_parse_segments_only_expression() { + let result = Mustache::parse("{{foo.bar}}").unwrap(); + let expected = Mustache(vec![Segment::Expression(vec![ + "foo".to_string(), + "bar".to_string(), + ])]); + assert_eq!(result, expected); + } - #[test] - fn parse_env_with_underscores() { - let result = Mustache::parse("{{env.FOO_BAR}}").unwrap(); - assert_eq!( - result, - Mustache::from(vec![Segment::Expression(vec![ - "env".to_string(), - "FOO_BAR".to_string() - ])]) - ); - } - } - mod render { - use std::borrow::Cow; - - use serde_json::json; - - use crate::mustache::{Mustache, Segment}; - use crate::path::PathString; - - #[test] - fn test_query_params_template() { - let s = r"/v1/templates?project-id={{value.projectId}}"; - let mustache: Mustache = Mustache::parse(s).unwrap(); - let ctx = json!(json!({"value": {"projectId": "123"}})); - let result = mustache.render(&ctx); - assert_eq!(result, "/v1/templates?project-id=123"); - } + #[test] + fn test_unfinished_expression() { + let s = 
r"{{hello.world"; + let mustache: Mustache = Mustache::parse(s).unwrap(); + assert_eq!( + mustache, + Mustache::from(vec![Segment::Literal("{{hello.world".to_string())]) + ); + } - #[test] - fn test_render_mixed() { - struct DummyPath; - - impl PathString for DummyPath { - fn path_string>(&self, parts: &[T]) -> Option> { - let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect(); - - if parts == ["foo", "bar"] { - Some(Cow::Borrowed("FOOBAR")) - } else if parts == ["baz", "qux"] { - Some(Cow::Borrowed("BAZQUX")) - } else { - None - } + #[test] + fn test_new_number() { + let mustache = Mustache::parse("123").unwrap(); + assert_eq!( + mustache, + Mustache::from(vec![Segment::Literal("123".to_string())]) + ); } - } - let mustache = Mustache::from(vec![ - Segment::Literal("prefix ".to_string()), - Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), - Segment::Literal(" middle ".to_string()), - Segment::Expression(vec!["baz".to_string(), "qux".to_string()]), - Segment::Literal(" suffix".to_string()), - ]); + #[test] + fn parse_env_name() { + let result = Mustache::parse("{{env.FOO}}").unwrap(); + assert_eq!( + result, + Mustache::from(vec![Segment::Expression(vec![ + "env".to_string(), + "FOO".to_string() + ])]) + ); + } - assert_eq!(mustache.render(&DummyPath), "prefix FOOBAR middle BAZQUX suffix"); + #[test] + fn parse_env_with_underscores() { + let result = Mustache::parse("{{env.FOO_BAR}}").unwrap(); + assert_eq!( + result, + Mustache::from(vec![Segment::Expression(vec![ + "env".to_string(), + "FOO_BAR".to_string() + ])]) + ); + } } + mod render { + use std::borrow::Cow; + + use serde_json::json; - #[test] - fn test_render_with_missing_path() { - struct DummyPath; + use crate::mustache::{Mustache, Segment}; + use crate::path::PathString; - impl PathString for DummyPath { - fn path_string>(&self, _: &[T]) -> Option> { - None + #[test] + fn test_query_params_template() { + let s = r"/v1/templates?project-id={{value.projectId}}"; + let mustache: Mustache = Mustache::parse(s).unwrap(); + let ctx = json!(json!({"value": {"projectId": "123"}})); + let result = mustache.render(&ctx); + assert_eq!(result, "/v1/templates?project-id=123"); } - } - let mustache = Mustache::from(vec![ - Segment::Literal("prefix ".to_string()), - Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), - Segment::Literal(" suffix".to_string()), - ]); + #[test] + fn test_render_mixed() { + struct DummyPath; + + impl PathString for DummyPath { + fn path_string>(&self, parts: &[T]) -> Option> { + let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect(); + + if parts == ["foo", "bar"] { + Some(Cow::Borrowed("FOOBAR")) + } else if parts == ["baz", "qux"] { + Some(Cow::Borrowed("BAZQUX")) + } else { + None + } + } + } + + let mustache = Mustache::from(vec![ + Segment::Literal("prefix ".to_string()), + Segment::Expression(vec!["foo".to_string(), "bar".to_string()]), + Segment::Literal(" middle ".to_string()), + Segment::Expression(vec!["baz".to_string(), "qux".to_string()]), + Segment::Literal(" suffix".to_string()), + ]); + + assert_eq!( + mustache.render(&DummyPath), + "prefix FOOBAR middle BAZQUX suffix" + ); + } - assert_eq!(mustache.render(&DummyPath), "prefix suffix"); - } + #[test] + fn test_render_with_missing_path() { + struct DummyPath; - #[test] - fn test_render_preserves_spaces() { - struct DummyPath; + impl PathString for DummyPath { + fn path_string>(&self, _: &[T]) -> Option> { + None + } + } - impl PathString for DummyPath { - fn path_string>(&self, parts: &[T]) -> Option> { 
-          let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect();
-
-          if parts == ["foo", "bar"] {
-            Some(Cow::Borrowed("FOOBAR"))
-          } else if parts == ["baz", "qux"] {
-            Some(Cow::Borrowed("BAZQUX"))
-          } else {
-            None
-          }
-        }
-      }
-
-      let mustache = Mustache::from(vec![
-        Segment::Literal("prefix ".to_string()),
-        Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
-        Segment::Literal(" middle ".to_string()),
-        Segment::Expression(vec!["baz".to_string(), "qux".to_string()]),
-        Segment::Literal(" suffix".to_string()),
-      ]);
-
-      assert_eq!(mustache.render(&DummyPath), "prefix FOOBAR middle BAZQUX suffix");
-    }
+    mod render {
+        use std::borrow::Cow;
+
+        use serde_json::json;
+
+        use crate::mustache::{Mustache, Segment};
+        use crate::path::PathString;
+
+        #[test]
+        fn test_query_params_template() {
+            let s = r"/v1/templates?project-id={{value.projectId}}";
+            let mustache: Mustache = Mustache::parse(s).unwrap();
+            let ctx = json!(json!({"value": {"projectId": "123"}}));
+            let result = mustache.render(&ctx);
+            assert_eq!(result, "/v1/templates?project-id=123");
+        }

-    #[test]
-    fn test_render_with_missing_path() {
-      struct DummyPath;
-
-      impl PathString for DummyPath {
-        fn path_string<T: AsRef<str>>(&self, _: &[T]) -> Option<Cow<'_, str>> {
-          None
-        }
-      }
-
-      let mustache = Mustache::from(vec![
-        Segment::Literal("prefix ".to_string()),
-        Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
-        Segment::Literal(" suffix".to_string()),
-      ]);
-
-      assert_eq!(mustache.render(&DummyPath), "prefix suffix");
-    }
+        #[test]
+        fn test_render_mixed() {
+            struct DummyPath;
+
+            impl PathString for DummyPath {
+                fn path_string<T: AsRef<str>>(&self, parts: &[T]) -> Option<Cow<'_, str>> {
+                    let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect();
+
+                    if parts == ["foo", "bar"] {
+                        Some(Cow::Borrowed("FOOBAR"))
+                    } else if parts == ["baz", "qux"] {
+                        Some(Cow::Borrowed("BAZQUX"))
+                    } else {
+                        None
+                    }
+                }
+            }
+
+            let mustache = Mustache::from(vec![
+                Segment::Literal("prefix ".to_string()),
+                Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
+                Segment::Literal(" middle ".to_string()),
+                Segment::Expression(vec!["baz".to_string(), "qux".to_string()]),
+                Segment::Literal(" suffix".to_string()),
+            ]);
+
+            assert_eq!(
+                mustache.render(&DummyPath),
+                "prefix FOOBAR middle BAZQUX suffix"
+            );
+        }

-    #[test]
-    fn test_render_preserves_spaces() {
-      struct DummyPath;
-
-      impl PathString for DummyPath {
-        fn path_string<T: AsRef<str>>(&self, parts: &[T]) -> Option<Cow<'_, str>> {
-          let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect();
-
-          if parts == ["foo"] {
-            Some(Cow::Borrowed("bar"))
-          } else {
-            None
-          }
-        }
-      }
-
-      let mustache = Mustache::from(vec![
-        Segment::Literal(" ".to_string()),
-        Segment::Expression(vec!["foo".to_string()]),
-        Segment::Literal(" ".to_string()),
-      ]);
-
-      assert_eq!(mustache.render(&DummyPath).as_str(), " bar ");
-    }
-  }
+        #[test]
+        fn test_render_with_missing_path() {
+            struct DummyPath;
+
+            impl PathString for DummyPath {
+                fn path_string<T: AsRef<str>>(&self, _: &[T]) -> Option<Cow<'_, str>> {
+                    None
+                }
+            }
+
+            let mustache = Mustache::from(vec![
+                Segment::Literal("prefix ".to_string()),
+                Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
+                Segment::Literal(" suffix".to_string()),
+            ]);
+
+            assert_eq!(mustache.render(&DummyPath), "prefix suffix");
+        }

+        #[test]
+        fn test_render_preserves_spaces() {
+            struct DummyPath;
+
+            impl PathString for DummyPath {
+                fn path_string<T: AsRef<str>>(&self, parts: &[T]) -> Option<Cow<'_, str>> {
+                    let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect();
+
+                    if parts == ["foo"] {
+                        Some(Cow::Borrowed("bar"))
+                    } else {
+                        None
+                    }
+                }
+            }
+
+            let mustache = Mustache::from(vec![
+                Segment::Literal(" ".to_string()),
+                Segment::Expression(vec!["foo".to_string()]),
+                Segment::Literal(" ".to_string()),
+            ]);
+
+            assert_eq!(mustache.render(&DummyPath).as_str(), " bar ");
+        }
+    }

-  mod render_graphql {
-    use crate::mustache::{Mustache, Segment};
-    use crate::path::PathGraphql;
-
-    #[test]
-    fn test_render_mixed() {
-      struct DummyPath;
-
-      impl PathGraphql for DummyPath {
-        fn path_graphql<T: AsRef<str>>(&self, parts: &[T]) -> Option<String> {
-          let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect();
-
-          if parts == ["foo", "bar"] {
-            Some("FOOBAR".to_owned())
-          } else if parts == ["baz", "qux"] {
-            Some("BAZQUX".to_owned())
-          } else {
-            None
-          }
-        }
-      }
-
-      let mustache = Mustache::from(vec![
-        Segment::Literal("prefix ".to_string()),
-        Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
-        Segment::Literal(" middle ".to_string()),
-        Segment::Expression(vec!["baz".to_string(), "qux".to_string()]),
-        Segment::Literal(" suffix".to_string()),
-      ]);
-
-      assert_eq!(
-        mustache.render_graphql(&DummyPath),
-        "prefix FOOBAR middle BAZQUX suffix"
-      );
-    }
+    mod render_graphql {
+        use crate::mustache::{Mustache, Segment};
+        use crate::path::PathGraphql;
+
+        #[test]
+        fn test_render_mixed() {
+            struct DummyPath;
+
+            impl PathGraphql for DummyPath {
+                fn path_graphql<T: AsRef<str>>(&self, parts: &[T]) -> Option<String> {
+                    let parts: Vec<&str> = parts.iter().map(AsRef::as_ref).collect();
+
+                    if parts == ["foo", "bar"] {
+                        Some("FOOBAR".to_owned())
+                    } else if parts == ["baz", "qux"] {
+                        Some("BAZQUX".to_owned())
+                    } else {
+                        None
+                    }
+                }
+            }
+
+            let mustache = Mustache::from(vec![
+                Segment::Literal("prefix ".to_string()),
+                Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
+                Segment::Literal(" middle ".to_string()),
+                Segment::Expression(vec!["baz".to_string(), "qux".to_string()]),
+                Segment::Literal(" suffix".to_string()),
+            ]);
+
+            assert_eq!(
+                mustache.render_graphql(&DummyPath),
+                "prefix FOOBAR middle BAZQUX suffix"
+            );
+        }

-    #[test]
-    fn test_render_with_missing_path() {
-      struct DummyPath;
-
-      impl PathGraphql for DummyPath {
-        fn path_graphql<T: AsRef<str>>(&self, _: &[T]) -> Option<String> {
-          None
-        }
-      }
-
-      let mustache = Mustache::from(vec![
-        Segment::Literal("prefix ".to_string()),
-        Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
-        Segment::Literal(" suffix".to_string()),
-      ]);
-
-      assert_eq!(mustache.render_graphql(&DummyPath), "prefix suffix");
+        #[test]
+        fn test_render_with_missing_path() {
+            struct DummyPath;
+
+            impl PathGraphql for DummyPath {
+                fn path_graphql<T: AsRef<str>>(&self, _: &[T]) -> Option<String> {
+                    None
+                }
+            }
+
+            let 
mustache = Mustache::from(vec![
+                Segment::Literal("prefix ".to_string()),
+                Segment::Expression(vec!["foo".to_string(), "bar".to_string()]),
+                Segment::Literal(" suffix".to_string()),
+            ]);
+
+            assert_eq!(mustache.render_graphql(&DummyPath), "prefix suffix");
+        }
+    }
 }

diff --git a/src/path.rs b/src/path.rs
index 6db0f19264f..278b3f117f4 100644
--- a/src/path.rs
+++ b/src/path.rs
@@ -15,7 +15,7 @@ use crate::lambda::{EvaluationContext, ResolverContextLike};
 /// This is typically used in evaluating mustache templates.
 ///
 pub trait PathString {
-  fn path_string<T: AsRef<str>>(&self, path: &[T]) -> Option<Cow<'_, str>>;
+    fn path_string<T: AsRef<str>>(&self, path: &[T]) -> Option<Cow<'_, str>>;
 }

 ///
@@ -23,301 +23,336 @@ pub trait PathString {
 /// The returned value is encoded as a GraphQL Value.
 ///
 pub trait PathGraphql {
-  fn path_graphql<T: AsRef<str>>(&self, path: &[T]) -> Option<String>;
+    fn path_graphql<T: AsRef<str>>(&self, path: &[T]) -> Option<String>;
 }

 impl PathString for serde_json::Value {
-  fn path_string<T: AsRef<str>>(&self, path: &[T]) -> Option<Cow<'_, str>> {
-    self.get_path(path).map(|a| match a {
-      serde_json::Value::String(s) => Cow::Borrowed(s.as_str()),
-      _ => Cow::Owned(a.to_string()),
-    })
-  }
+    fn path_string<T: AsRef<str>>(&self, path: &[T]) -> Option<Cow<'_, str>> {
+        self.get_path(path).map(|a| match a {
+            serde_json::Value::String(s) => Cow::Borrowed(s.as_str()),
+            _ => Cow::Owned(a.to_string()),
+        })
+    }
 }

 fn convert_value(value: &async_graphql::Value) -> Option<Cow<'_, str>> {
-  match value {
-    async_graphql::Value::String(s) => Some(Cow::Borrowed(s.as_str())),
-    async_graphql::Value::Number(n) => Some(Cow::Owned(n.to_string())),
-    async_graphql::Value::Boolean(b) => Some(Cow::Owned(b.to_string())),
-    async_graphql::Value::Object(map) => Some(json!(map).to_string().into()),
-    async_graphql::Value::List(list) => Some(json!(list).to_string().into()),
-    _ => None,
-  }
+    match value {
+        async_graphql::Value::String(s) => Some(Cow::Borrowed(s.as_str())),
+        async_graphql::Value::Number(n) => Some(Cow::Owned(n.to_string())),
+        async_graphql::Value::Boolean(b) => Some(Cow::Owned(b.to_string())),
+        async_graphql::Value::Object(map) => Some(json!(map).to_string().into()),
+        async_graphql::Value::List(list) => Some(json!(list).to_string().into()),
+        _ => None,
+    }
 }

 impl<'a, Ctx: ResolverContextLike<'a>> PathString for EvaluationContext<'a, Ctx> {
-  fn path_string<T: AsRef<str>>(&self, path: &[T]) -> Option<Cow<'_, str>> {
-    let ctx = self;
-
-    if path.is_empty() {
-      return None;
+    fn path_string<T: AsRef<str>>(&self, path: &[T]) -> Option<Cow<'_, str>> {
+        let ctx = self;
+
+        if path.is_empty() {
+            return None;
+        }
+
+        if path.len() == 1 {
+            return match path[0].as_ref() {
+                "value" => convert_value(ctx.path_value(&[] as &[T])?),
+                "args" => Some(json!(ctx.graphql_ctx.args()?).to_string().into()),
+                "vars" => Some(json!(ctx.vars()).to_string().into()),
+                _ => None,
+            };
+        }
+
+        path.split_first()
+            .and_then(|(head, tail)| match head.as_ref() {
+                "value" => convert_value(ctx.path_value(tail)?),
+                "args" => convert_value(ctx.arg(tail)?),
+                "headers" => ctx.header(tail[0].as_ref()).map(|v| v.into()),
+                "vars" => ctx.var(tail[0].as_ref()).map(|v| v.into()),
+                "env" => ctx.env_var(tail[0].as_ref()).map(|v| v.into()),
+                _ => None,
+            })
    }
-
-    if path.len() == 1 {
-      return match path[0].as_ref() {
-        "value" => convert_value(ctx.path_value(&[] as &[T])?),
-        "args" => Some(json!(ctx.graphql_ctx.args()?).to_string().into()),
-        "vars" => Some(json!(ctx.vars()).to_string().into()),
-        _ => None,
-      };
-    }
-
-    path.split_first().and_then(|(head, tail)| match head.as_ref() {
-      "value" => convert_value(ctx.path_value(tail)?),
-      "args" => convert_value(ctx.arg(tail)?),
-      "headers" => 
ctx.header(tail[0].as_ref()).map(|v| v.into()),
-      "vars" => ctx.var(tail[0].as_ref()).map(|v| v.into()),
-      "env" => ctx.env_var(tail[0].as_ref()).map(|v| v.into()),
-      _ => None,
-    })
-  }
 }

 impl<'a, Ctx: ResolverContextLike<'a>> PathGraphql for EvaluationContext<'a, Ctx> {
-  fn path_graphql<T: AsRef<str>>(&self, path: &[T]) -> Option<String> {
-    let ctx = self;
-
-    if path.len() < 2 {
-      return None;
+    fn path_graphql<T: AsRef<str>>(&self, path: &[T]) -> Option<String> {
+        let ctx = self;
+
+        if path.len() < 2 {
+            return None;
+        }
+
+        path.split_first()
+            .and_then(|(head, tail)| match head.as_ref() {
+                "value" => Some(ctx.path_value(tail)?.to_string()),
+                "args" => Some(ctx.arg(tail)?.to_string()),
+                "headers" => ctx.header(tail[0].as_ref()).map(|v| format!(r#""{v}""#)),
+                "vars" => ctx.var(tail[0].as_ref()).map(|v| format!(r#""{v}""#)),
+                "env" => ctx.env_var(tail[0].as_ref()).map(|v| format!(r#""{v}""#)),
+                _ => None,
+            })
    }
-
-    path.split_first().and_then(|(head, tail)| match head.as_ref() {
-      "value" => Some(ctx.path_value(tail)?.to_string()),
-      "args" => Some(ctx.arg(tail)?.to_string()),
-      "headers" => ctx.header(tail[0].as_ref()).map(|v| format!(r#""{v}""#)),
-      "vars" => ctx.var(tail[0].as_ref()).map(|v| format!(r#""{v}""#)),
-      "env" => ctx.env_var(tail[0].as_ref()).map(|v| format!(r#""{v}""#)),
-      _ => None,
-    })
-  }
 }

 #[cfg(test)]
 mod tests {
-  mod evaluation_context {
-    use std::borrow::Cow;
-    use std::collections::{BTreeMap, HashMap};
-    use std::sync::Arc;
-
-    use async_graphql::SelectionField;
-    use async_graphql_value::{ConstValue as Value, Name, Number};
-    use hyper::header::HeaderValue;
-    use hyper::HeaderMap;
-    use indexmap::IndexMap;
-    use once_cell::sync::Lazy;
-
-    use crate::http::RequestContext;
-    use crate::lambda::{EvaluationContext, ResolverContextLike};
-    use crate::path::{PathGraphql, PathString};
-    use crate::EnvIO;
-
-    struct Env {
-      env: HashMap<String, String>,
-    }
-
-    impl EnvIO for Env {
-      fn get(&self, key: &str) -> Option<String> {
-        self.env.get(key).cloned()
-      }
-    }
-
-    impl Env {
-      pub fn init(map: HashMap<String, String>) -> Self {
-        Self { env: map }
-      }
-    }
-
-    static TEST_VALUES: Lazy<Value> = Lazy::new(|| {
-      let mut root = IndexMap::new();
-      let mut nested = IndexMap::new();
-
-      nested.insert(Name::new("existing"), Value::String("nested-test".to_owned()));
-
-      root.insert(Name::new("str"), Value::String("str-test".to_owned()));
-      root.insert(Name::new("number"), Value::Number(Number::from(2)));
-      root.insert(Name::new("bool"), Value::Boolean(true));
-      root.insert(Name::new("nested"), Value::Object(nested));

-      Value::Object(root)
-    });

-    static TEST_ARGS: Lazy<IndexMap<Name, Value>> = Lazy::new(|| {
-      let mut root = IndexMap::new();
-      let mut nested = IndexMap::new();

-      nested.insert(Name::new("existing"), Value::String("nested-test".to_owned()));

-      root.insert(Name::new("root"), Value::String("root-test".to_owned()));
-      root.insert(Name::new("nested"), Value::Object(nested));

-      root
-    });

-    static TEST_HEADERS: Lazy<HeaderMap> = Lazy::new(|| {
-      let mut 
map = HeaderMap::new();
-
-      map.insert("x-existing", HeaderValue::from_static("header"));
-
-      map
-    });
-
-    static TEST_VARS: Lazy<BTreeMap<String, String>> = Lazy::new(|| {
-      let mut map = BTreeMap::new();
-
-      map.insert("existing".to_owned(), "var".to_owned());
-
-      map
-    });
-
-    static TEST_ENV_VARS: Lazy<HashMap<String, String>> = Lazy::new(|| {
-      let mut map = HashMap::new();
-
-      map.insert("existing".to_owned(), "env".to_owned());
-
-      map
-    });
-
-    struct MockGraphqlContext;
-
-    impl<'a> ResolverContextLike<'a> for MockGraphqlContext {
-      fn value(&'a self) -> Option<&'a Value> {
-        Some(&TEST_VALUES)
-      }
-
-      fn args(&'a self) -> Option<&'a IndexMap<Name, Value>> {
-        Some(&TEST_ARGS)
-      }
-
-      fn field(&'a self) -> Option<SelectionField> {
-        None
-      }
-
-      fn add_error(&'a self, _: async_graphql::ServerError) {}
-    }
-
-    static REQ_CTX: Lazy<RequestContext> = Lazy::new(|| {
-      let mut req_ctx = RequestContext::default().req_headers(TEST_HEADERS.clone());
-
-      req_ctx.server.vars = TEST_VARS.clone();
-      req_ctx.env_vars = Arc::new(Env::init(TEST_ENV_VARS.clone()));
-
-      req_ctx
-    });
-
-    static EVAL_CTX: Lazy<EvaluationContext<'static, MockGraphqlContext>> =
-      Lazy::new(|| EvaluationContext::new(&REQ_CTX, &MockGraphqlContext));
-
-    #[test]
-    fn path_to_string() {
-      // value
-      assert_eq!(EVAL_CTX.path_string(&["value", "bool"]), Some(Cow::Borrowed("true")));
-      assert_eq!(EVAL_CTX.path_string(&["value", "number"]), Some(Cow::Borrowed("2")));
-      assert_eq!(EVAL_CTX.path_string(&["value", "str"]), Some(Cow::Borrowed("str-test")));
-      assert_eq!(
-        EVAL_CTX.path_string(&["value", "nested"]),
-        Some(Cow::Borrowed("{\"existing\":\"nested-test\"}"))
-      );
-      assert_eq!(EVAL_CTX.path_string(&["value", "missing"]), None);
-      assert_eq!(EVAL_CTX.path_string(&["value", "nested", "missing"]), None);
-      assert_eq!(
-        EVAL_CTX.path_string(&["value"]),
-        Some(Cow::Borrowed(
-          r#"{"bool":true,"nested":{"existing":"nested-test"},"number":2,"str":"str-test"}"#
-        ))
-      );
-
-      // args
-      assert_eq!(
-        EVAL_CTX.path_string(&["args", "root"]),
-        Some(Cow::Borrowed("root-test"))
-      );
-      assert_eq!(
-        EVAL_CTX.path_string(&["args", "nested"]),
-        Some(Cow::Borrowed("{\"existing\":\"nested-test\"}"))
-      );
-      assert_eq!(EVAL_CTX.path_string(&["args", "missing"]), None);
-      assert_eq!(EVAL_CTX.path_string(&["args", "nested", "missing"]), None);
-      assert_eq!(
-        EVAL_CTX.path_string(&["args"]),
-        Some(Cow::Borrowed(
-          r#"{"nested":{"existing":"nested-test"},"root":"root-test"}"#
-        ))
-      );
-
-      // headers
-      assert_eq!(
-        EVAL_CTX.path_string(&["headers", "x-existing"]),
-        Some(Cow::Borrowed("header"))
-      );
-      assert_eq!(EVAL_CTX.path_string(&["headers", "x-missing"]), None);
-
-      // vars
-      assert_eq!(EVAL_CTX.path_string(&["vars", "existing"]), Some(Cow::Borrowed("var")));
-      assert_eq!(EVAL_CTX.path_string(&["vars", "missing"]), None);
-      assert_eq!(
-        EVAL_CTX.path_string(&["vars"]),
-        Some(Cow::Borrowed(r#"{"existing":"var"}"#))
-      );
-
-      // envs
-      assert_eq!(EVAL_CTX.path_string(&["env", "existing"]), Some(Cow::Borrowed("env")));
-      assert_eq!(EVAL_CTX.path_string(&["env", "x-missing"]), None);
-
-      // other value types
-      assert_eq!(EVAL_CTX.path_string(&["foo", "key"]), None);
-      assert_eq!(EVAL_CTX.path_string(&["bar", "key"]), None);
-      assert_eq!(EVAL_CTX.path_string(&["baz", "key"]), None);
-    }
-
-    #[test]
-    fn path_to_graphql_string() {
-      // value
-      assert_eq!(EVAL_CTX.path_graphql(&["value", "bool"]), Some("true".to_owned()));
-      assert_eq!(EVAL_CTX.path_graphql(&["value", "number"]), Some("2".to_owned()));
-      assert_eq!(
-        EVAL_CTX.path_graphql(&["value", "str"]),
-        Some("\"str-test\"".to_owned())
-      
);
-      assert_eq!(
-        EVAL_CTX.path_graphql(&["value", "nested"]),
-        Some("{existing: \"nested-test\"}".to_owned())
-      );
-      assert_eq!(EVAL_CTX.path_graphql(&["value", "missing"]), None);
-      assert_eq!(EVAL_CTX.path_graphql(&["value", "nested", "missing"]), None);
-
-      // args
-      assert_eq!(
-        EVAL_CTX.path_graphql(&["args", "root"]),
-        Some("\"root-test\"".to_owned())
-      );
-      assert_eq!(
-        EVAL_CTX.path_graphql(&["args", "nested"]),
-        Some("{existing: \"nested-test\"}".to_owned())
-      );
-      assert_eq!(EVAL_CTX.path_graphql(&["args", "missing"]), None);
-      assert_eq!(EVAL_CTX.path_graphql(&["args", "nested", "missing"]), None);
-
-      // headers
-      assert_eq!(
-        EVAL_CTX.path_graphql(&["headers", "x-existing"]),
-        Some("\"header\"".to_owned())
-      );
-      assert_eq!(EVAL_CTX.path_graphql(&["headers", "x-missing"]), None);
-
-      // vars
-      assert_eq!(EVAL_CTX.path_graphql(&["vars", "existing"]), Some("\"var\"".to_owned()));
-      assert_eq!(EVAL_CTX.path_graphql(&["vars", "missing"]), None);
-
-      // envs
-      assert_eq!(EVAL_CTX.path_graphql(&["env", "existing"]), Some("\"env\"".to_owned()));
-      assert_eq!(EVAL_CTX.path_graphql(&["env", "x-missing"]), None);
-
-      // other value types
-      assert_eq!(EVAL_CTX.path_graphql(&["foo", "key"]), None);
-      assert_eq!(EVAL_CTX.path_graphql(&["bar", "key"]), None);
-      assert_eq!(EVAL_CTX.path_graphql(&["baz", "key"]), None);
-    }
-  }
+    mod evaluation_context {
+        use std::borrow::Cow;
+        use std::collections::{BTreeMap, HashMap};
+        use std::sync::Arc;
+
+        use async_graphql::SelectionField;
+        use async_graphql_value::{ConstValue as Value, Name, Number};
+        use hyper::header::HeaderValue;
+        use hyper::HeaderMap;
+        use indexmap::IndexMap;
+        use once_cell::sync::Lazy;
+
+        use crate::http::RequestContext;
+        use crate::lambda::{EvaluationContext, ResolverContextLike};
+        use crate::path::{PathGraphql, PathString};
+        use crate::EnvIO;
+
+        struct Env {
+            env: HashMap<String, String>,
+        }
+
+        impl EnvIO for Env {
+            fn get(&self, key: &str) -> Option<String> {
+                self.env.get(key).cloned()
+            }
+        }
+
+        impl Env {
+            pub fn init(map: HashMap<String, String>) -> Self {
+                Self { env: map }
+            }
+        }
+
+        static TEST_VALUES: Lazy<Value> = Lazy::new(|| {
+            let mut root = IndexMap::new();
+            let mut nested = IndexMap::new();
+
+            nested.insert(
+                Name::new("existing"),
+                Value::String("nested-test".to_owned()),
+            );
+
+            root.insert(Name::new("str"), Value::String("str-test".to_owned()));
+            root.insert(Name::new("number"), Value::Number(Number::from(2)));
+            root.insert(Name::new("bool"), Value::Boolean(true));
+            root.insert(Name::new("nested"), Value::Object(nested));
+
+            Value::Object(root)
+        });
+
+        static TEST_ARGS: Lazy<IndexMap<Name, Value>> = Lazy::new(|| {
+            let mut root = IndexMap::new();
+            let mut nested = IndexMap::new();
+
+            nested.insert(
+                Name::new("existing"),
+                Value::String("nested-test".to_owned()),
+            );
+
+            root.insert(Name::new("root"), Value::String("root-test".to_owned()));
+            root.insert(Name::new("nested"), Value::Object(nested));
+
+            root
+        });
+
+        static TEST_HEADERS: Lazy<HeaderMap> = Lazy::new(|| {
+            let mut map = HeaderMap::new();
+
+            map.insert("x-existing", HeaderValue::from_static("header"));
+
+            map
+        });
+
+        static TEST_VARS: Lazy<BTreeMap<String, String>> = Lazy::new(|| {
+            let mut map = BTreeMap::new();
+
+            map.insert("existing".to_owned(), "var".to_owned());
+
+            map
+        });
+
+        static TEST_ENV_VARS: Lazy<HashMap<String, String>> = Lazy::new(|| {
+            let mut map = HashMap::new();
+
+            map.insert("existing".to_owned(), "env".to_owned());
+
+            map
+        });
+
+        struct MockGraphqlContext;
+
+        impl<'a> ResolverContextLike<'a> for MockGraphqlContext {
+            fn value(&'a self) -> Option<&'a Value> {
+                Some(&TEST_VALUES)
+            }
+
+            fn args(&'a self) -> Option<&'a IndexMap<Name, Value>> {
+                Some(&TEST_ARGS)
+            }
+
+            fn field(&'a self) -> Option<SelectionField> {
+                None
+            }
+
+            fn add_error(&'a self, _: async_graphql::ServerError) {}
+        }
+
+        static REQ_CTX: Lazy<RequestContext> = Lazy::new(|| {
+            let mut req_ctx = RequestContext::default().req_headers(TEST_HEADERS.clone());
+
+            req_ctx.server.vars = TEST_VARS.clone();
+            req_ctx.env_vars = Arc::new(Env::init(TEST_ENV_VARS.clone()));
+
+            req_ctx
+        });
+
+        static EVAL_CTX: Lazy<EvaluationContext<'static, MockGraphqlContext>> =
+            Lazy::new(|| EvaluationContext::new(&REQ_CTX, &MockGraphqlContext));
+
+        #[test]
+        fn path_to_string() {
+            // value
+            assert_eq!(
+                EVAL_CTX.path_string(&["value", "bool"]),
+                Some(Cow::Borrowed("true"))
+            );
+            assert_eq!(
+                EVAL_CTX.path_string(&["value", "number"]),
+
Some(Cow::Borrowed("2")) + ); + assert_eq!( + EVAL_CTX.path_string(&["value", "str"]), + Some(Cow::Borrowed("str-test")) + ); + assert_eq!( + EVAL_CTX.path_string(&["value", "nested"]), + Some(Cow::Borrowed("{\"existing\":\"nested-test\"}")) + ); + assert_eq!(EVAL_CTX.path_string(&["value", "missing"]), None); + assert_eq!(EVAL_CTX.path_string(&["value", "nested", "missing"]), None); + assert_eq!( + EVAL_CTX.path_string(&["value"]), + Some(Cow::Borrowed( + r#"{"bool":true,"nested":{"existing":"nested-test"},"number":2,"str":"str-test"}"# + )) + ); + + // args + assert_eq!( + EVAL_CTX.path_string(&["args", "root"]), + Some(Cow::Borrowed("root-test")) + ); + assert_eq!( + EVAL_CTX.path_string(&["args", "nested"]), + Some(Cow::Borrowed("{\"existing\":\"nested-test\"}")) + ); + assert_eq!(EVAL_CTX.path_string(&["args", "missing"]), None); + assert_eq!(EVAL_CTX.path_string(&["args", "nested", "missing"]), None); + assert_eq!( + EVAL_CTX.path_string(&["args"]), + Some(Cow::Borrowed( + r#"{"nested":{"existing":"nested-test"},"root":"root-test"}"# + )) + ); + + // headers + assert_eq!( + EVAL_CTX.path_string(&["headers", "x-existing"]), + Some(Cow::Borrowed("header")) + ); + assert_eq!(EVAL_CTX.path_string(&["headers", "x-missing"]), None); + + // vars + assert_eq!( + EVAL_CTX.path_string(&["vars", "existing"]), + Some(Cow::Borrowed("var")) + ); + assert_eq!(EVAL_CTX.path_string(&["vars", "missing"]), None); + assert_eq!( + EVAL_CTX.path_string(&["vars"]), + Some(Cow::Borrowed(r#"{"existing":"var"}"#)) + ); + + // envs + assert_eq!( + EVAL_CTX.path_string(&["env", "existing"]), + Some(Cow::Borrowed("env")) + ); + assert_eq!(EVAL_CTX.path_string(&["env", "x-missing"]), None); + + // other value types + assert_eq!(EVAL_CTX.path_string(&["foo", "key"]), None); + assert_eq!(EVAL_CTX.path_string(&["bar", "key"]), None); + assert_eq!(EVAL_CTX.path_string(&["baz", "key"]), None); + } + + #[test] + fn path_to_graphql_string() { + // value + assert_eq!( + EVAL_CTX.path_graphql(&["value", "bool"]), + Some("true".to_owned()) + ); + assert_eq!( + EVAL_CTX.path_graphql(&["value", "number"]), + Some("2".to_owned()) + ); + assert_eq!( + EVAL_CTX.path_graphql(&["value", "str"]), + Some("\"str-test\"".to_owned()) + ); + assert_eq!( + EVAL_CTX.path_graphql(&["value", "nested"]), + Some("{existing: \"nested-test\"}".to_owned()) + ); + assert_eq!(EVAL_CTX.path_graphql(&["value", "missing"]), None); + assert_eq!(EVAL_CTX.path_graphql(&["value", "nested", "missing"]), None); + + // args + assert_eq!( + EVAL_CTX.path_graphql(&["args", "root"]), + Some("\"root-test\"".to_owned()) + ); + assert_eq!( + EVAL_CTX.path_graphql(&["args", "nested"]), + Some("{existing: \"nested-test\"}".to_owned()) + ); + assert_eq!(EVAL_CTX.path_graphql(&["args", "missing"]), None); + assert_eq!(EVAL_CTX.path_graphql(&["args", "nested", "missing"]), None); + + // headers + assert_eq!( + EVAL_CTX.path_graphql(&["headers", "x-existing"]), + Some("\"header\"".to_owned()) + ); + assert_eq!(EVAL_CTX.path_graphql(&["headers", "x-missing"]), None); + + // vars + assert_eq!( + EVAL_CTX.path_graphql(&["vars", "existing"]), + Some("\"var\"".to_owned()) + ); + assert_eq!(EVAL_CTX.path_graphql(&["vars", "missing"]), None); + + // envs + assert_eq!( + EVAL_CTX.path_graphql(&["env", "existing"]), + Some("\"env\"".to_owned()) + ); + assert_eq!(EVAL_CTX.path_graphql(&["env", "x-missing"]), None); + + // other value types + assert_eq!(EVAL_CTX.path_graphql(&["foo", "key"]), None); + assert_eq!(EVAL_CTX.path_graphql(&["bar", "key"]), None); + 
assert_eq!(EVAL_CTX.path_graphql(&["baz", "key"]), None); + } } - } } diff --git a/src/print_schema.rs b/src/print_schema.rs index 91f767dacb3..35242e1b1dc 100644 --- a/src/print_schema.rs +++ b/src/print_schema.rs @@ -4,32 +4,34 @@ use async_graphql::SDLExportOptions; /// SDL returned from AsyncSchemaInner isn't standard /// We clean it up before returning. pub fn print_schema(schema: Schema) -> String { - let sdl = schema.sdl_with_options(SDLExportOptions::new().sorted_fields()); - let mut result = String::new(); - let mut prev_line_empty = false; + let sdl = schema.sdl_with_options(SDLExportOptions::new().sorted_fields()); + let mut result = String::new(); + let mut prev_line_empty = false; - for line in sdl.lines() { - let trimmed_line = line.trim(); - // Check if line contains the directives to be skipped - if trimmed_line.starts_with("directive @include") || trimmed_line.starts_with("directive @skip") { - continue; + for line in sdl.lines() { + let trimmed_line = line.trim(); + // Check if line contains the directives to be skipped + if trimmed_line.starts_with("directive @include") + || trimmed_line.starts_with("directive @skip") + { + continue; + } + if trimmed_line.is_empty() { + if !prev_line_empty { + result.push('\n'); + } + prev_line_empty = true; + } else { + let formatted_line = if line.starts_with('\t') { + line.replace('\t', " ") + } else { + line.to_string() + }; + result.push_str(&formatted_line); + result.push('\n'); + prev_line_empty = false; + } } - if trimmed_line.is_empty() { - if !prev_line_empty { - result.push('\n'); - } - prev_line_empty = true; - } else { - let formatted_line = if line.starts_with('\t') { - line.replace('\t', " ") - } else { - line.to_string() - }; - result.push_str(&formatted_line); - result.push('\n'); - prev_line_empty = false; - } - } - result.trim().to_string() + result.trim().to_string() } diff --git a/src/try_fold.rs b/src/try_fold.rs index 29752d8ede0..585a4278a00 100644 --- a/src/try_fold.rs +++ b/src/try_fold.rs @@ -9,312 +9,316 @@ type TryFoldFn<'a, I, O, E> = Box Valid + 'a>; pub struct TryFold<'a, I: 'a, O: 'a, E: 'a>(TryFoldFn<'a, I, O, E>); impl<'a, I, O: Clone + 'a, E> TryFold<'a, I, O, E> { - /// Try to fold the value with the input. - /// - /// # Parameters - /// - `input`: The input used in the folding operation. - /// - `value`: The value to be folded. - /// - /// # Returns - /// Returns a `Valid` value, which can be either a success with the folded value - /// or an error. - pub fn try_fold(&self, input: &I, state: O) -> Valid { - (self.0)(input, state) - } - - /// Combine two `TryFolding` implementors into a sequential operation. - /// - /// This method allows for chaining two `TryFolding` operations, where the result of the first operation - /// (if successful) will be used as the input for the second operation. - /// - /// # Parameters - /// - `other`: Another `TryFolding` implementor. - /// - /// # Returns - /// Returns a combined `And` structure that represents the sequential folding operation. - pub fn and(self, other: TryFold<'a, I, O, E>) -> Self { - TryFold(Box::new(move |input, state| { - self - .try_fold(input, state.clone()) - .fold(|state| other.try_fold(input, state), || other.try_fold(input, state)) - })) - } - - /// Create a new `TryFold` with a specified folding function. - /// - /// # Parameters - /// - `f`: The folding function. - /// - /// # Returns - /// Returns a new `TryFold` instance. 
diff --git a/src/try_fold.rs b/src/try_fold.rs
index 29752d8ede0..585a4278a00 100644
--- a/src/try_fold.rs
+++ b/src/try_fold.rs
@@ -9,312 +9,316 @@ type TryFoldFn<'a, I, O, E> = Box<dyn Fn(&I, O) -> Valid<O, E> + 'a>;
 pub struct TryFold<'a, I: 'a, O: 'a, E: 'a>(TryFoldFn<'a, I, O, E>);
 
 impl<'a, I, O: Clone + 'a, E> TryFold<'a, I, O, E> {
-  /// Try to fold the value with the input.
-  ///
-  /// # Parameters
-  /// - `input`: The input used in the folding operation.
-  /// - `value`: The value to be folded.
-  ///
-  /// # Returns
-  /// Returns a `Valid` value, which can be either a success with the folded value
-  /// or an error.
-  pub fn try_fold(&self, input: &I, state: O) -> Valid<O, E> {
-    (self.0)(input, state)
-  }
-
-  /// Combine two `TryFolding` implementors into a sequential operation.
-  ///
-  /// This method allows for chaining two `TryFolding` operations, where the result of the first operation
-  /// (if successful) will be used as the input for the second operation.
-  ///
-  /// # Parameters
-  /// - `other`: Another `TryFolding` implementor.
-  ///
-  /// # Returns
-  /// Returns a combined `And` structure that represents the sequential folding operation.
-  pub fn and(self, other: TryFold<'a, I, O, E>) -> Self {
-    TryFold(Box::new(move |input, state| {
-      self
-        .try_fold(input, state.clone())
-        .fold(|state| other.try_fold(input, state), || other.try_fold(input, state))
-    }))
-  }
-
-  /// Create a new `TryFold` with a specified folding function.
-  ///
-  /// # Parameters
-  /// - `f`: The folding function.
-  ///
-  /// # Returns
-  /// Returns a new `TryFold` instance.
-  pub fn new(f: impl Fn(&I, O) -> Valid<O, E> + 'a) -> Self {
-    TryFold(Box::new(f))
-  }
-
-  /// Transforms a TryFold<I, O, E> to TryFold<I, O1, E> by applying transformations.
-  /// Check `transform_valid` if you want to return a `Valid` instead of an `O1`.
-  ///
-  /// # Parameters
-  /// - `up`: A function that uses O and O1 to create a new O1.
-  /// - `down`: A function that uses O1 to create a new O.
-  ///
-  /// # Returns
-  /// Returns a new TryFold that applies the transformations.
-  ///
-  pub fn transform<O1>(
-    self,
-    up: impl Fn(O, O1) -> O1 + 'a,
-    down: impl Fn(O1) -> O + 'a,
-  ) -> TryFold<'a, I, O1, E> {
-    self.transform_valid(
-      move |o, o1| Valid::succeed(up(o, o1)),
-      move |o1| Valid::succeed(down(o1)),
-    )
-  }
-
-  /// Transforms a TryFold<I, O, E> to TryFold<I, O1, E> by applying transformations.
-  /// Check `transform` if you want to return an `O1` instead of a `Valid`.
-  ///
-  /// # Parameters
-  /// - `up`: A function that uses O and O1 to create a new Valid<O1, E>.
-  /// - `down`: A function that uses O1 to create a new Valid<O, E>.
-  ///
-  /// # Returns
-  /// Returns a new TryFold that applies the transformations.
-  ///
-  pub fn transform_valid<O1>(
-    self,
-    up: impl Fn(O, O1) -> Valid<O1, E> + 'a,
-    down: impl Fn(O1) -> Valid<O, E> + 'a,
-  ) -> TryFold<'a, I, O1, E> {
-    TryFold(Box::new(move |i, o1| {
-      down(o1.clone())
-        .and_then(|o| self.try_fold(i, o))
-        .and_then(|o| up(o, o1))
-    }))
-  }
-
-  pub fn update(self, f: impl Fn(O) -> O + 'a) -> TryFold<'a, I, O, E> {
-    self.transform(move |o, _| f(o), |o| o)
-  }
-
-  /// Create a `TryFold` that always succeeds with the provided state.
-  ///
-  /// # Parameters
-  /// - `state`: The state to succeed with.
-  ///
-  /// # Returns
-  /// Returns a `TryFold` that always succeeds with the provided state.
-  pub fn succeed(f: impl Fn(&I, O) -> O + 'a) -> Self {
-    TryFold(Box::new(move |i, o| Valid::succeed(f(i, o))))
-  }
-
-  /// Create a `TryFold` that doesn't do anything.
-  ///
-  /// # Returns
-  /// Returns a `TryFold` that doesn't do anything.
-  pub fn empty() -> Self {
-    TryFold::new(|_, o| Valid::succeed(o))
-  }
-
-  /// Create a `TryFold` that always fails with the provided error.
-  ///
-  /// # Parameters
-  /// - `e`: The error to fail with.
-  ///
-  /// # Returns
-  /// Returns a `TryFold` that always fails with the provided error.
-  pub fn fail(e: E) -> Self
-  where
-    E: Clone,
-  {
-    TryFold::new(move |_, _| Valid::fail(e.clone()))
-  }
-
-  /// Add trace logging to the fold operation.
-  ///
-  /// # Parameters
-  ///
-  /// * `msg` - The message to log when this fold operation is executed.
-  ///
-  /// # Returns
-  ///
-  /// Returns a new `TryFold` with trace logging added.
-  pub fn trace(self, msg: &'a str) -> Self {
-    TryFold::new(move |i, o| self.try_fold(i, o).trace(msg))
-  }
+    /// Try to fold the value with the input.
+    ///
+    /// # Parameters
+    /// - `input`: The input used in the folding operation.
+    /// - `value`: The value to be folded.
+    ///
+    /// # Returns
+    /// Returns a `Valid` value, which can be either a success with the folded value
+    /// or an error.
+    pub fn try_fold(&self, input: &I, state: O) -> Valid<O, E> {
+        (self.0)(input, state)
+    }
+
+    /// Combine two `TryFolding` implementors into a sequential operation.
+    ///
+    /// This method allows for chaining two `TryFolding` operations, where the result of the first operation
+    /// (if successful) will be used as the input for the second operation.
+    ///
+    /// # Parameters
+    /// - `other`: Another `TryFolding` implementor.
+    ///
+    /// # Returns
+    /// Returns a combined `And` structure that represents the sequential folding operation.
+    pub fn and(self, other: TryFold<'a, I, O, E>) -> Self {
+        TryFold(Box::new(move |input, state| {
+            self.try_fold(input, state.clone()).fold(
+                |state| other.try_fold(input, state),
+                || other.try_fold(input, state),
+            )
+        }))
+    }
+
+    /// Create a new `TryFold` with a specified folding function.
+    ///
+    /// # Parameters
+    /// - `f`: The folding function.
+    ///
+    /// # Returns
+    /// Returns a new `TryFold` instance.
+    pub fn new(f: impl Fn(&I, O) -> Valid<O, E> + 'a) -> Self {
+        TryFold(Box::new(f))
+    }
+
+    /// Transforms a TryFold<I, O, E> to TryFold<I, O1, E> by applying transformations.
+    /// Check `transform_valid` if you want to return a `Valid` instead of an `O1`.
+    ///
+    /// # Parameters
+    /// - `up`: A function that uses O and O1 to create a new O1.
+    /// - `down`: A function that uses O1 to create a new O.
+    ///
+    /// # Returns
+    /// Returns a new TryFold that applies the transformations.
+    ///
+    pub fn transform<O1>(
+        self,
+        up: impl Fn(O, O1) -> O1 + 'a,
+        down: impl Fn(O1) -> O + 'a,
+    ) -> TryFold<'a, I, O1, E> {
+        self.transform_valid(
+            move |o, o1| Valid::succeed(up(o, o1)),
+            move |o1| Valid::succeed(down(o1)),
+        )
+    }
+
+    /// Transforms a TryFold<I, O, E> to TryFold<I, O1, E> by applying transformations.
+    /// Check `transform` if you want to return an `O1` instead of a `Valid`.
+    ///
+    /// # Parameters
+    /// - `up`: A function that uses O and O1 to create a new Valid<O1, E>.
+    /// - `down`: A function that uses O1 to create a new Valid<O, E>.
+    ///
+    /// # Returns
+    /// Returns a new TryFold that applies the transformations.
+    ///
+    pub fn transform_valid<O1>(
+        self,
+        up: impl Fn(O, O1) -> Valid<O1, E> + 'a,
+        down: impl Fn(O1) -> Valid<O, E> + 'a,
+    ) -> TryFold<'a, I, O1, E> {
+        TryFold(Box::new(move |i, o1| {
+            down(o1.clone())
+                .and_then(|o| self.try_fold(i, o))
+                .and_then(|o| up(o, o1))
+        }))
+    }
+
+    pub fn update(self, f: impl Fn(O) -> O + 'a) -> TryFold<'a, I, O, E> {
+        self.transform(move |o, _| f(o), |o| o)
+    }
+
+    /// Create a `TryFold` that always succeeds with the provided state.
+    ///
+    /// # Parameters
+    /// - `state`: The state to succeed with.
+    ///
+    /// # Returns
+    /// Returns a `TryFold` that always succeeds with the provided state.
+    pub fn succeed(f: impl Fn(&I, O) -> O + 'a) -> Self {
+        TryFold(Box::new(move |i, o| Valid::succeed(f(i, o))))
+    }
+
+    /// Create a `TryFold` that doesn't do anything.
+    ///
+    /// # Returns
+    /// Returns a `TryFold` that doesn't do anything.
+    pub fn empty() -> Self {
+        TryFold::new(|_, o| Valid::succeed(o))
+    }
+
+    /// Create a `TryFold` that always fails with the provided error.
+    ///
+    /// # Parameters
+    /// - `e`: The error to fail with.
+    ///
+    /// # Returns
+    /// Returns a `TryFold` that always fails with the provided error.
+    pub fn fail(e: E) -> Self
+    where
+        E: Clone,
+    {
+        TryFold::new(move |_, _| Valid::fail(e.clone()))
+    }
+
+    /// Add trace logging to the fold operation.
+    ///
+    /// # Parameters
+    ///
+    /// * `msg` - The message to log when this fold operation is executed.
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `TryFold` with trace logging added.
+    pub fn trace(self, msg: &'a str) -> Self {
+        TryFold::new(move |i, o| self.try_fold(i, o).trace(msg))
+    }
 }
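As a quick orientation for the combinators above — `new` wraps a closure, `and` chains two folds over the same input, and `trace` tags any failure. A minimal sketch built on the crate's own types (the names `add_input` and `double` are invented for the example):

    use crate::try_fold::TryFold;
    use crate::valid::Valid;

    // Each step reads the shared input (&i32) and threads the state (i32).
    let add_input = TryFold::<i32, i32, String>::new(|i, state| Valid::succeed(i + state));
    let double = TryFold::<i32, i32, String>::new(|_, state| Valid::succeed(state * 2));

    // Runs add_input, then feeds its output into double: (2 + 3) * 2 == 10.
    let pipeline = add_input.and(double).trace("pipeline");
    assert_eq!(pipeline.try_fold(&2, 3).to_result().unwrap(), 10);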
 
 impl<'a, I, O: Clone, E> FromIterator<TryFold<'a, I, O, E>> for TryFold<'a, I, O, E> {
-  fn from_iter<T: IntoIterator<Item = TryFold<'a, I, O, E>>>(iter: T) -> Self {
-    let mut iter = iter.into_iter();
-    let head = iter.next();
-
-    if let Some(head) = head {
-      head.and(TryFold::from_iter(iter))
-    } else {
-      TryFold::empty()
+    fn from_iter<T: IntoIterator<Item = TryFold<'a, I, O, E>>>(iter: T) -> Self {
+        let mut iter = iter.into_iter();
+        let head = iter.next();
+
+        if let Some(head) = head {
+            head.and(TryFold::from_iter(iter))
+        } else {
+            TryFold::empty()
+        }
    }
-  }
 }
 
 #[cfg(test)]
 mod tests {
-  use std::cell::RefCell;
-
-  use super::TryFold;
-  use crate::valid::{Valid, ValidationError};
-
-  #[test]
-  fn test_and() {
-    let t1 = TryFold::<i32, i32, i32>::new(|a: &i32, b: i32| Valid::succeed(a + b));
-    let t2 = TryFold::<i32, i32, i32>::new(|a: &i32, b: i32| Valid::succeed(a * b));
-    let t = t1.and(t2);
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap();
-    let expected = 10;
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_one_failure() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b));
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b));
-    let t = t1.and(t2);
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(5);
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_both_failure() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b));
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b));
-    let t = t1.and(t2);
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(5).combine(ValidationError::new(6));
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_order() {
-    let calls = RefCell::new(Vec::new());
-    let t1 = TryFold::<i32, i32, i32>::new(|a: &i32, b: i32| {
-      calls.borrow_mut().push(1);
-      Valid::succeed(a + b)
-    }); // 2 + 3
-    let t2 = TryFold::new(|a: &i32, b: i32| {
-      calls.borrow_mut().push(2);
-      Valid::succeed(a * b)
-    }); // 2 * 3
-    let t3 = TryFold::new(|a: &i32, b: i32| {
-      calls.borrow_mut().push(3);
-      Valid::succeed(a * b * 100)
-    }); // 2 * 6
-    let _t = t1.and(t2).and(t3).try_fold(&2, 3);
-
-    assert_eq!(*calls.borrow(), vec![1, 2, 3]);
-  }
-
-  #[test]
-  fn test_1_3_failure_left() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b)); // 2 + 3
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b)); // 2 * 3
-    let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100)); // 2 * 6
-    let t = t1.and(t2).and(t3);
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(5).combine(ValidationError::new(600));
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_1_3_failure_right() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b)); // 2 + 3
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b)); // 2 * 3
-    let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100)); // 2 * 6
-    let t = t1.and(t2.and(t3));
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(5).combine(ValidationError::new(1200));
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_2_3_failure() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a + b));
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b));
-    let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100));
-    let t = t1.and(t2.and(t3));
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(10).combine(ValidationError::new(1000));
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_try_all() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a + b));
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b));
-    let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100));
-    let t = TryFold::from_iter(vec![t1, t2, t3]);
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(10).combine(ValidationError::new(1000));
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_try_all_1_3_fail() {
-    let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b));
-    let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b));
-    let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100));
-    let t = TryFold::from_iter(vec![t1, t2, t3]);
-
-    let actual = t.try_fold(&2, 3).to_result().unwrap_err();
-    let expected = ValidationError::new(5).combine(ValidationError::new(1200));
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_transform() {
-    let t: TryFold<'_, i32, String, ()> = TryFold::succeed(|a: &i32, b: i32| a + b)
-      .transform(|v: i32, _| v.to_string(), |v: String| v.parse::<i32>().unwrap());
-
-    let actual = t.try_fold(&2, "3".to_string()).to_result().unwrap();
-    let expected = "5".to_string();
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_transform_valid() {
-    let t: TryFold<'_, i32, String, ()> = TryFold::succeed(|a: &i32, b: i32| a + b).transform_valid(
-      |v: i32, _| Valid::succeed(v.to_string()),
-      |v: String| Valid::succeed(v.parse::<i32>().unwrap()),
-    );
-
-    let actual = t.try_fold(&2, "3".to_string()).to_result().unwrap();
-    let expected = "5".to_string();
-
-    assert_eq!(actual, expected)
-  }
-
-  #[test]
-  fn test_update() {
-    let t = TryFold::<i32, i32, i32>::succeed(|a: &i32, b: i32| a + b).update(|a| a + 1);
-    let actual = t.try_fold(&2, 3).to_result().unwrap();
-    let expected = 6;
-    assert_eq!(actual, expected);
-  }
+    use std::cell::RefCell;
+
+    use super::TryFold;
+    use crate::valid::{Valid, ValidationError};
+
+    #[test]
+    fn test_and() {
+        let t1 = TryFold::<i32, i32, i32>::new(|a: &i32, b: i32| Valid::succeed(a + b));
+        let t2 = TryFold::<i32, i32, i32>::new(|a: &i32, b: i32| Valid::succeed(a * b));
+        let t = t1.and(t2);
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap();
+        let expected = 10;
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_one_failure() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b));
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b));
+        let t = t1.and(t2);
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(5);
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_both_failure() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b));
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b));
+        let t = t1.and(t2);
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(5).combine(ValidationError::new(6));
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_order() {
+        let calls = RefCell::new(Vec::new());
+        let t1 = TryFold::<i32, i32, i32>::new(|a: &i32, b: i32| {
+            calls.borrow_mut().push(1);
+            Valid::succeed(a + b)
+        }); // 2 + 3
+        let t2 = TryFold::new(|a: &i32, b: i32| {
+            calls.borrow_mut().push(2);
+            Valid::succeed(a * b)
+        }); // 2 * 3
+        let t3 = TryFold::new(|a: &i32, b: i32| {
+            calls.borrow_mut().push(3);
+            Valid::succeed(a * b * 100)
+        }); // 2 * 6
+        let _t = t1.and(t2).and(t3).try_fold(&2, 3);
+
+        assert_eq!(*calls.borrow(), vec![1, 2, 3]);
+    }
+
+    #[test]
+    fn test_1_3_failure_left() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b)); // 2 + 3
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b)); // 2 * 3
+        let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100)); // 2 * 6
+        let t = t1.and(t2).and(t3);
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(5).combine(ValidationError::new(600));
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_1_3_failure_right() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b)); // 2 + 3
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b)); // 2 * 3
+        let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100)); // 2 * 6
+        let t = t1.and(t2.and(t3));
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(5).combine(ValidationError::new(1200));
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_2_3_failure() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a + b));
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b));
+        let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100));
+        let t = t1.and(t2.and(t3));
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(10).combine(ValidationError::new(1000));
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_try_all() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a + b));
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b));
+        let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100));
+        let t = TryFold::from_iter(vec![t1, t2, t3]);
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(10).combine(ValidationError::new(1000));
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_try_all_1_3_fail() {
+        let t1 = TryFold::new(|a: &i32, b: i32| Valid::fail(a + b));
+        let t2 = TryFold::new(|a: &i32, b: i32| Valid::succeed(a * b));
+        let t3 = TryFold::new(|a: &i32, b: i32| Valid::fail(a * b * 100));
+        let t = TryFold::from_iter(vec![t1, t2, t3]);
+
+        let actual = t.try_fold(&2, 3).to_result().unwrap_err();
+        let expected = ValidationError::new(5).combine(ValidationError::new(1200));
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_transform() {
+        let t: TryFold<'_, i32, String, ()> = TryFold::succeed(|a: &i32, b: i32| a + b).transform(
+            |v: i32, _| v.to_string(),
+            |v: String| v.parse::<i32>().unwrap(),
+        );
+
+        let actual = t.try_fold(&2, "3".to_string()).to_result().unwrap();
+        let expected = "5".to_string();
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_transform_valid() {
+        let t: TryFold<'_, i32, String, ()> = TryFold::succeed(|a: &i32, b: i32| a + b)
+            .transform_valid(
+                |v: i32, _| Valid::succeed(v.to_string()),
+                |v: String| Valid::succeed(v.parse::<i32>().unwrap()),
+            );
+
+        let actual = t.try_fold(&2, "3".to_string()).to_result().unwrap();
+        let expected = "5".to_string();
+
+        assert_eq!(actual, expected)
+    }
+
+    #[test]
+    fn test_update() {
+        let t = TryFold::<i32, i32, i32>::succeed(|a: &i32, b: i32| a + b).update(|a| a + 1);
+        let actual = t.try_fold(&2, 3).to_result().unwrap();
+        let expected = 6;
+        assert_eq!(actual, expected);
+    }
 }
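The `FromIterator` impl above is what makes `TryFold::from_iter(vec![...])` in these tests work: it folds a whole collection of validators into one chain of `and`s, with errors accumulating rather than short-circuiting. A sketch in the same style as the tests:

    use crate::try_fold::TryFold;
    use crate::valid::Valid;

    // Three independent checks over the same (&input, state) pair.
    let checks = vec![
        TryFold::<i32, i32, i32>::new(|a, b| Valid::succeed(a + b)),
        TryFold::new(|a, b| Valid::fail(a * b)),
        TryFold::new(|a, b| Valid::fail(a * b * 100)),
    ];

    // Equivalent to c1.and(c2).and(c3); both failures are collected.
    let combined = TryFold::from_iter(checks);
    assert!(combined.try_fold(&2, 3).to_result().is_err());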
diff --git a/src/valid/cause.rs b/src/valid/cause.rs
index e76d0551840..5cea7849535 100644
--- a/src/valid/cause.rs
+++ b/src/valid/cause.rs
@@ -6,53 +6,60 @@ use thiserror::Error;
 #[derive(Clone, PartialEq, Debug, Setters, Error)]
 pub struct Cause<E> {
-  pub message: E,
-  #[setters(strip_option)]
-  pub description: Option<E>,
-  #[setters(skip)]
-  pub trace: VecDeque<String>,
+    pub message: E,
+    #[setters(strip_option)]
+    pub description: Option<E>,
+    #[setters(skip)]
+    pub trace: VecDeque<String>,
 }
 
 impl<E: Display> Display for Cause<E> {
-  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-    write!(f, "[")?;
-    for (i, entry) in self.trace.iter().enumerate() {
-      if i > 0 {
-        write!(f, ", ")?;
-      }
-      write!(f, "{}", entry)?;
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "[")?;
+        for (i, entry) in self.trace.iter().enumerate() {
+            if i > 0 {
+                write!(f, ", ")?;
+            }
+            write!(f, "{}", entry)?;
+        }
+        write!(f, "] {}", self.message)?;
+        if let Some(desc) = self.description.as_ref() {
+            write!(f, ": {}", desc)?;
+        }
+        Ok(())
    }
-    write!(f, "] {}", self.message)?;
-    if let Some(desc) = self.description.as_ref() {
-      write!(f, ": {}", desc)?;
-    }
-    Ok(())
-  }
 }
 
 impl<E> Cause<E> {
-  pub fn new(e: E) -> Self {
-    Cause { message: e, description: None, trace: VecDeque::new() }
-  }
+    pub fn new(e: E) -> Self {
+        Cause { message: e, description: None, trace: VecDeque::new() }
+    }
 
-  pub fn transform<E1>(self, e: impl Fn(E) -> E1) -> Cause<E1> {
-    Cause { message: e(self.message), description: self.description.map(e), trace: self.trace }
-  }
+    pub fn transform<E1>(self, e: impl Fn(E) -> E1) -> Cause<E1> {
+        Cause {
+            message: e(self.message),
+            description: self.description.map(e),
+            trace: self.trace,
+        }
+    }
 
-  pub fn trace<T: Display>(mut self, trace: Vec<T>) -> Self {
-    self.trace = trace.iter().map(|t| t.to_string()).collect::<VecDeque<String>>();
-    self
-  }
+    pub fn trace<T: Display>(mut self, trace: Vec<T>) -> Self {
+        self.trace = trace
+            .iter()
+            .map(|t| t.to_string())
+            .collect::<VecDeque<String>>();
+        self
+    }
 }
 
 #[cfg(test)]
 mod tests {
-  #[test]
-  fn test_display() {
-    use super::Cause;
-    let cause = Cause::new("error")
-      .trace(vec!["trace0", "trace1"])
-      .description("description");
-    assert_eq!(cause.to_string(), "[trace0, trace1] error: description");
-  }
+    #[test]
+    fn test_display() {
+        use super::Cause;
+        let cause = Cause::new("error")
+            .trace(vec!["trace0", "trace1"])
+            .description("description");
+        assert_eq!(cause.to_string(), "[trace0, trace1] error: description");
+    }
 }
diff --git a/src/valid/error.rs b/src/valid/error.rs
index cda0ca0b17c..146df552ed2 100644
--- a/src/valid/error.rs
+++ b/src/valid/error.rs
@@ -8,146 +8,161 @@ use super::Cause;
 pub struct ValidationError<E>(Vec<Cause<E>>);
 
 impl<E: Display> Display for ValidationError<E> {
-  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-    f.write_str("Validation Error\n")?;
-    let errors = self.as_vec();
-    for error in errors {
-      f.write_str(format!("{} {}", '\u{2022}', error.message).as_str())?;
-      if !error.trace.is_empty() {
-        f.write_str(&(format!(" [{}]", error.trace.iter().cloned().collect::<Vec<String>>().join(", "))))?;
-      }
-      f.write_str("\n")?;
-    }
-
-    Ok(())
-  }
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("Validation Error\n")?;
+        let errors = self.as_vec();
+        for error in errors {
+            f.write_str(format!("{} {}", '\u{2022}', error.message).as_str())?;
+            if !error.trace.is_empty() {
+                f.write_str(
+                    &(format!(
+                        " [{}]",
+                        error
+                            .trace
+                            .iter()
+                            .cloned()
+                            .collect::<Vec<String>>()
+                            .join(", ")
+                    )),
+                )?;
+            }
+            f.write_str("\n")?;
+        }
+
+        Ok(())
+    }
 }
 
 impl<E> ValidationError<E> {
-  pub fn as_vec(&self) -> &Vec<Cause<E>> {
-    &self.0
-  }
-
-  pub fn combine(mut self, mut other: ValidationError<E>) -> ValidationError<E> {
-    self.0.append(&mut other.0);
-    self
-  }
-
-  pub fn empty() -> Self {
-    ValidationError(Vec::new())
-  }
-
-  pub fn new(e: E) -> Self {
-    ValidationError(vec![Cause::new(e)])
-  }
-
-  pub fn is_empty(&self) -> bool {
-    self.0.is_empty()
-  }
-
-  pub fn trace(self, message: &str) -> Self {
-    let mut errors = self.0;
-    for cause in errors.iter_mut() {
-      cause.trace.insert(0, message.to_owned());
-    }
-    Self(errors)
-  }
-
-  pub fn append(self, error: E) -> Self {
-    let mut errors = self.0;
-    errors.push(Cause::new(error));
-    Self(errors)
-  }
-
-  pub fn transform<E1>(self, f: &impl Fn(E) -> E1) -> ValidationError<E1> {
-    ValidationError(self.0.into_iter().map(|cause| cause.transform(f)).collect())
-  }
+    pub fn as_vec(&self) -> &Vec<Cause<E>> {
+        &self.0
+    }
+
+    pub fn combine(mut self, mut other: ValidationError<E>) -> ValidationError<E> {
+        self.0.append(&mut other.0);
+        self
+    }
+
+    pub fn empty() -> Self {
+        ValidationError(Vec::new())
+    }
+
+    pub fn new(e: E) -> Self {
+        ValidationError(vec![Cause::new(e)])
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    pub fn trace(self, message: &str) -> Self {
+        let mut errors = self.0;
+        for cause in errors.iter_mut() {
+            cause.trace.insert(0, message.to_owned());
+        }
+        Self(errors)
+    }
+
+    pub fn append(self, error: E) -> Self {
+        let mut errors = self.0;
+        errors.push(Cause::new(error));
+        Self(errors)
+    }
+
+    pub fn transform<E1>(self, f: &impl Fn(E) -> E1) -> ValidationError<E1> {
+        ValidationError(self.0.into_iter().map(|cause| cause.transform(f)).collect())
+    }
 }
 
 impl std::error::Error for ValidationError<String> {}
 
 impl<E> From<Cause<E>> for ValidationError<E> {
-  fn from(value: Cause<E>) -> Self {
-    ValidationError(vec![value])
-  }
+    fn from(value: Cause<E>) -> Self {
+        ValidationError(vec![value])
+    }
 }
 
 impl<E> From<Vec<Cause<E>>> for ValidationError<E> {
-  fn from(value: Vec<Cause<E>>) -> Self {
-    ValidationError(value)
-  }
+    fn from(value: Vec<Cause<E>>) -> Self {
+        ValidationError(value)
+    }
 }
 
 impl From<serde_path_to_error::Error<serde_json::Error>> for ValidationError<String> {
-  fn from(error: serde_path_to_error::Error<serde_json::Error>) -> Self {
-    let mut trace = Vec::new();
-    let segments = error.path().iter();
-    let len = segments.len();
-    for (i, segment) in segments.enumerate() {
-      match segment {
-        serde_path_to_error::Segment::Seq { index } => {
-          trace.push(format!("[{}]", index));
-        }
-        serde_path_to_error::Segment::Map { key } => {
-          trace.push(key.to_string());
-        }
-        serde_path_to_error::Segment::Enum { variant } => {
-          trace.push(variant.to_string());
-        }
-        serde_path_to_error::Segment::Unknown => {
-          trace.push("?".to_owned());
-        }
-      }
-      if i < len - 1 {
-        trace.push(".".to_owned());
-      }
-    }
-    let re = Regex::new(r" at line \d+ column \d+$").unwrap();
-    let message = re
-      .replace(format!("Parsing failed because of {}", error.inner()).as_str(), "")
-      .into_owned();
-
-    ValidationError(vec![Cause::new(message).trace(trace)])
-  }
+    fn from(error: serde_path_to_error::Error<serde_json::Error>) -> Self {
+        let mut trace = Vec::new();
+        let segments = error.path().iter();
+        let len = segments.len();
+        for (i, segment) in segments.enumerate() {
+            match segment {
+                serde_path_to_error::Segment::Seq { index } => {
+                    trace.push(format!("[{}]", index));
+                }
+                serde_path_to_error::Segment::Map { key } => {
+                    trace.push(key.to_string());
+                }
+                serde_path_to_error::Segment::Enum { variant } => {
+                    trace.push(variant.to_string());
+                }
+                serde_path_to_error::Segment::Unknown => {
+                    trace.push("?".to_owned());
+                }
+            }
+            if i < len - 1 {
+                trace.push(".".to_owned());
+            }
+        }
+        let re = Regex::new(r" at line \d+ column \d+$").unwrap();
+        let message = re
+            .replace(
+                format!("Parsing failed because of {}", error.inner()).as_str(),
+                "",
+            )
+            .into_owned();
+
+        ValidationError(vec![Cause::new(message).trace(trace)])
+    }
 }
 
 #[cfg(test)]
 mod tests {
-  use pretty_assertions::assert_eq;
-  use stripmargin::StripMargin;
-
-  use crate::valid::{Cause, ValidationError};
-
-  #[derive(Debug, PartialEq, serde::Deserialize)]
-  struct Foo {
-    a: i32,
-  }
-
-  #[test]
-  fn test_error_display_formatting() {
-    let error = ValidationError::from(vec![
-      Cause::new("1").trace(vec!["a", "b"]),
-      Cause::new("2"),
-      Cause::new("3"),
-    ]);
-    let expected_output = "\
-|Validation Error
-|• 1 [a, b]
-|• 2
-|• 3
-|"
-      .strip_margin();
-    assert_eq!(format!("{}", error), expected_output);
-  }
-
-  #[test]
-  fn test_from_serde_error() {
-    let foo = &mut serde_json::Deserializer::from_str("{ \"a\": true }");
-    let actual = ValidationError::from(serde_path_to_error::deserialize::<_, Foo>(foo).unwrap_err());
-    let expected =
-      ValidationError::new("Parsing failed because of invalid type: boolean `true`, expected i32".to_string())
-        .trace("a");
-    assert_eq!(actual, expected);
-  }
+    use pretty_assertions::assert_eq;
+    use stripmargin::StripMargin;
+
+    use crate::valid::{Cause, ValidationError};
+
+    #[derive(Debug, PartialEq, serde::Deserialize)]
+    struct Foo {
+        a: i32,
+    }
+
+    #[test]
+    fn test_error_display_formatting() {
+        let error = ValidationError::from(vec![
+            Cause::new("1").trace(vec!["a", "b"]),
+            Cause::new("2"),
+            Cause::new("3"),
+        ]);
+        let expected_output = "\
+|Validation Error
+|• 1 [a, b]
+|• 2
+|• 3
+|"
+        .strip_margin();
+        assert_eq!(format!("{}", error), expected_output);
+    }
+
+    #[test]
+    fn test_from_serde_error() {
+        let foo = &mut serde_json::Deserializer::from_str("{ \"a\": true }");
+        let actual =
+            ValidationError::from(serde_path_to_error::deserialize::<_, Foo>(foo).unwrap_err());
+        let expected = ValidationError::new(
+            "Parsing failed because of invalid type: boolean `true`, expected i32".to_string(),
+        )
+        .trace("a");
+        assert_eq!(actual, expected);
+    }
 }
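Taken together: a `Cause` carries one message plus its breadcrumb trace, and `ValidationError` is the collection that `Display` renders as the bulleted report tested above. A small sketch of how the pieces compose (the messages are invented for the example):

    use crate::valid::ValidationError;

    // Two independent failures merged into one report; trace() prefixes a
    // breadcrumb onto every cause the error already holds.
    let err = ValidationError::new("missing field")
        .combine(ValidationError::new("wrong type"))
        .trace("user");

    // Per the fmt impl above, this prints:
    // Validation Error
    // • missing field [user]
    // • wrong type [user]
    println!("{}", err);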
diff --git a/src/valid/valid.rs b/src/valid/valid.rs
index 58164942701..7cc557f59ea 100644
--- a/src/valid/valid.rs
+++ b/src/valid/valid.rs
@@ -5,311 +5,323 @@ use crate::valid::Cause;
 pub struct Valid<A, E>(Result<A, ValidationError<E>>);
 
 impl<A, E> Valid<A, E> {
-  pub fn fail(e: E) -> Valid<A, E> {
-    Valid(Err((vec![Cause::new(e)]).into()))
-  }
-
-  pub fn fail_with(message: E, description: E) -> Valid<A, E>
-  where
-    E: std::fmt::Debug,
-  {
-    Valid(Err((vec![Cause::new(message).description(description)]).into()))
-  }
-
-  pub fn from_validation_err(error: ValidationError<E>) -> Self {
-    Valid(Err(error))
-  }
-
-  pub fn from_vec_cause(error: Vec<Cause<E>>) -> Self {
-    Valid(Err(error.into()))
-  }
-
-  pub fn map<A1>(self, f: impl FnOnce(A) -> A1) -> Valid<A1, E> {
-    Valid(self.0.map(f))
-  }
-
-  pub fn foreach(self, mut f: impl FnMut(A)) -> Valid<A, E>
-  where
-    A: Clone,
-  {
-    match self.0 {
-      Ok(a) => {
-        f(a.clone());
-        Valid::succeed(a)
-      }
-      Err(e) => Valid(Err(e)),
-    }
-  }
-
-  pub fn succeed(a: A) -> Valid<A, E> {
-    Valid(Ok(a))
-  }
-
-  pub fn is_succeed(&self) -> bool {
-    self.0.is_ok()
-  }
-
-  pub fn and<A1>(self, other: Valid<A1, E>) -> Valid<A1, E> {
-    self.zip(other).map(|(_, a1)| a1)
-  }
-
-  pub fn zip<A1>(self, other: Valid<A1, E>) -> Valid<(A, A1), E> {
-    match self.0 {
-      Ok(a) => match other.0 {
-        Ok(a1) => Valid(Ok((a, a1))),
-        Err(e1) => Valid(Err(e1)),
-      },
-      Err(e1) => match other.0 {
-        Ok(_) => Valid(Err(e1)),
-        Err(e2) => Valid(Err(e1.combine(e2))),
-      },
-    }
-  }
-
-  pub fn trace(self, message: &str) -> Valid<A, E> {
-    let valid = self.0;
-    if let Err(error) = valid {
-      return Valid(Err(error.trace(message)));
-    }
-
-    Valid(valid)
-  }
-
-  pub fn fold<A1>(self, ok: impl FnOnce(A) -> Valid<A1, E>, err: impl FnOnce() -> Valid<A1, E>) -> Valid<A1, E> {
-    match self.0 {
-      Ok(a) => ok(a),
-      Err(e) => Valid::<A1, E>(Err(e)).and(err()),
-    }
-  }
-
-  pub fn from_iter<B>(iter: impl IntoIterator<Item = A>, f: impl Fn(A) -> Valid<B, E>) -> Valid<Vec<B>, E> {
-    let mut values: Vec<B> = Vec::new();
-    let mut errors: ValidationError<E> = ValidationError::empty();
-    for a in iter.into_iter() {
-      match f(a).to_result() {
-        Ok(b) => {
-          values.push(b);
-        }
-        Err(err) => {
-          errors = errors.combine(err);
-        }
-      }
-    }
-
-    if errors.is_empty() {
-      Valid::succeed(values)
-    } else {
-      Valid::from_validation_err(errors)
-    }
-  }
-
-  pub fn from_option(option: Option<A>, e: E) -> Valid<A, E> {
-    match option {
-      Some(a) => Valid::succeed(a),
-      None => Valid::fail(e),
-    }
-  }
-
-  pub fn to_result(self) -> Result<A, ValidationError<E>> {
-    self.0
-  }
-
-  pub fn and_then<B>(self, f: impl FnOnce(A) -> Valid<B, E>) -> Valid<B, E> {
-    match self.0 {
-      Ok(a) => f(a),
-      Err(e) => Valid(Err(e)),
-    }
-  }
-
-  pub fn unit(self) -> Valid<(), E> {
-    self.map(|_| ())
-  }
-
-  pub fn some(self) -> Valid<Option<A>, E> {
-    self.map(Some)
-  }
-
-  pub fn none() -> Valid<Option<A>, E> {
-    Valid::succeed(None)
-  }
-  pub fn map_to<B>(self, b: B) -> Valid<B, E> {
-    self.map(|_| b)
-  }
-  pub fn when(self, f: impl FnOnce() -> bool) -> Valid<(), E> {
-    if f() {
-      self.unit()
-    } else {
-      Valid::succeed(())
-    }
-  }
+    pub fn fail(e: E) -> Valid<A, E> {
+        Valid(Err((vec![Cause::new(e)]).into()))
+    }
+
+    pub fn fail_with(message: E, description: E) -> Valid<A, E>
+    where
+        E: std::fmt::Debug,
+    {
+        Valid(Err(
+            (vec![Cause::new(message).description(description)]).into()
+        ))
+    }
+
+    pub fn from_validation_err(error: ValidationError<E>) -> Self {
+        Valid(Err(error))
+    }
+
+    pub fn from_vec_cause(error: Vec<Cause<E>>) -> Self {
+        Valid(Err(error.into()))
+    }
+
+    pub fn map<A1>(self, f: impl FnOnce(A) -> A1) -> Valid<A1, E> {
+        Valid(self.0.map(f))
+    }
+
+    pub fn foreach(self, mut f: impl FnMut(A)) -> Valid<A, E>
+    where
+        A: Clone,
+    {
+        match self.0 {
+            Ok(a) => {
+                f(a.clone());
+                Valid::succeed(a)
+            }
+            Err(e) => Valid(Err(e)),
+        }
+    }
+
+    pub fn succeed(a: A) -> Valid<A, E> {
+        Valid(Ok(a))
+    }
+
+    pub fn is_succeed(&self) -> bool {
+        self.0.is_ok()
+    }
+
+    pub fn and<A1>(self, other: Valid<A1, E>) -> Valid<A1, E> {
+        self.zip(other).map(|(_, a1)| a1)
+    }
+
+    pub fn zip<A1>(self, other: Valid<A1, E>) -> Valid<(A, A1), E> {
+        match self.0 {
+            Ok(a) => match other.0 {
+                Ok(a1) => Valid(Ok((a, a1))),
+                Err(e1) => Valid(Err(e1)),
+            },
+            Err(e1) => match other.0 {
+                Ok(_) => Valid(Err(e1)),
+                Err(e2) => Valid(Err(e1.combine(e2))),
+            },
+        }
+    }
+
+    pub fn trace(self, message: &str) -> Valid<A, E> {
+        let valid = self.0;
+        if let Err(error) = valid {
+            return Valid(Err(error.trace(message)));
+        }
+
+        Valid(valid)
+    }
+
+    pub fn fold<A1>(
+        self,
+        ok: impl FnOnce(A) -> Valid<A1, E>,
+        err: impl FnOnce() -> Valid<A1, E>,
+    ) -> Valid<A1, E> {
+        match self.0 {
+            Ok(a) => ok(a),
+            Err(e) => Valid::<A1, E>(Err(e)).and(err()),
+        }
+    }
+
+    pub fn from_iter<B>(
+        iter: impl IntoIterator<Item = A>,
+        f: impl Fn(A) -> Valid<B, E>,
+    ) -> Valid<Vec<B>, E> {
+        let mut values: Vec<B> = Vec::new();
+        let mut errors: ValidationError<E> = ValidationError::empty();
+        for a in iter.into_iter() {
+            match f(a).to_result() {
+                Ok(b) => {
+                    values.push(b);
+                }
+                Err(err) => {
+                    errors = errors.combine(err);
+                }
+            }
+        }
+
+        if errors.is_empty() {
+            Valid::succeed(values)
+        } else {
+            Valid::from_validation_err(errors)
+        }
+    }
+
+    pub fn from_option(option: Option<A>, e: E) -> Valid<A, E> {
+        match option {
+            Some(a) => Valid::succeed(a),
+            None => Valid::fail(e),
+        }
+    }
+
+    pub fn to_result(self) -> Result<A, ValidationError<E>> {
+        self.0
+    }
+
+    pub fn and_then<B>(self, f: impl FnOnce(A) -> Valid<B, E>) -> Valid<B, E> {
+        match self.0 {
+            Ok(a) => f(a),
+            Err(e) => Valid(Err(e)),
+        }
+    }
+
+    pub fn unit(self) -> Valid<(), E> {
+        self.map(|_| ())
+    }
+
+    pub fn some(self) -> Valid<Option<A>, E> {
+        self.map(Some)
+    }
+
+    pub fn none() -> Valid<Option<A>, E> {
+        Valid::succeed(None)
+    }
+    pub fn map_to<B>(self, b: B) -> Valid<B, E> {
+        self.map(|_| b)
+    }
+    pub fn when(self, f: impl FnOnce() -> bool) -> Valid<(), E> {
+        if f() {
+            self.unit()
+        } else {
+            Valid::succeed(())
+        }
+    }
 }
 
 impl<A, E> From<Result<A, ValidationError<E>>> for Valid<A, E> {
-  fn from(value: Result<A, ValidationError<E>>) -> Self {
-    match value {
-      Ok(a) => Valid::succeed(a),
-      Err(e) => Valid::from_validation_err(e),
-    }
-  }
+    fn from(value: Result<A, ValidationError<E>>) -> Self {
+        match value {
+            Ok(a) => Valid::succeed(a),
+            Err(e) => Valid::from_validation_err(e),
+        }
+    }
 }
 
 #[cfg(test)]
 mod tests {
-  use super::{Cause, ValidationError};
-  use crate::valid::valid::Valid;
-
-  #[test]
-  fn test_ok() {
-    let result = Valid::<i32, i32>::succeed(1);
-    assert_eq!(result, Valid::succeed(1));
-  }
-
-  #[test]
-  fn test_fail() {
-    let result = Valid::<(), i32>::fail(1);
-    assert_eq!(result, Valid::fail(1));
-  }
-
-  #[test]
-  fn test_validate_or_both_ok() {
-    let result1 = Valid::<bool, i32>::succeed(true);
-    let result2 = Valid::<u8, i32>::succeed(3);
-
-    assert_eq!(result1.and(result2), Valid::succeed(3u8));
-  }
-
-  #[test]
-  fn test_validate_or_first_fail() {
-    let result1 = Valid::<bool, i32>::fail(-1);
-    let result2 = Valid::<u8, i32>::succeed(3);
-
-    assert_eq!(result1.and(result2), Valid::fail(-1));
-  }
-
-  #[test]
-  fn test_validate_or_second_fail() {
-    let result1 = Valid::<bool, i32>::succeed(true);
-    let result2 = Valid::<u8, i32>::fail(-2);
-
-    assert_eq!(result1.and(result2), Valid::fail(-2));
-  }
-
-  #[test]
-  fn test_validate_all() {
-    let input: Vec<i32> = [1, 2, 3].to_vec();
-    let result: Valid<Vec<i32>, i32> = Valid::from_iter(input, |a| Valid::fail(a * 2));
-    assert_eq!(
-      result,
-      Valid::from_vec_cause(vec![Cause::new(2), Cause::new(4), Cause::new(6)])
-    );
-  }
-
-  #[test]
-  fn test_validate_all_ques() {
-    let input: Vec<i32> = [1, 2, 3].to_vec();
-    let result: Valid<Vec<i32>, i32> = Valid::from_iter(input, |a| Valid::fail(a * 2));
-    assert_eq!(
-      result,
-      Valid::from_vec_cause(vec![Cause::new(2), Cause::new(4), Cause::new(6)])
-    );
-  }
-
-  #[test]
-  fn test_ok_ok_cause() {
-    let option: Option<i32> = None;
-    let result = Valid::from_option(option, 1);
-    assert_eq!(result, Valid::from_vec_cause(vec![Cause::new(1)]));
-  }
-
-  #[test]
-  fn test_trace() {
-    let result = Valid::<(), i32>::fail(1).trace("A").trace("B").trace("C");
-    let expected = Valid::from_vec_cause(vec![Cause {
-      message: 1,
-      description: None,
-      trace: vec!["C".to_string(), "B".to_string(), "A".to_string()].into(),
-    }]);
-    assert_eq!(result, expected);
-  }
-
-  #[test]
-  fn test_validate_fold_err() {
-    let valid = Valid::<(), i32>::fail(1);
-    let result = valid.fold(|_| Valid::<(), i32>::fail(2), || Valid::<(), i32>::fail(3));
-    assert_eq!(result, Valid::from_vec_cause(vec![Cause::new(1), Cause::new(3)]));
-  }
-
-  #[test]
-  fn test_validate_fold_ok() {
-    let valid = Valid::<i32, i32>::succeed(1);
-    let result = valid.fold(Valid::<i32, i32>::fail, || Valid::<i32, i32>::fail(2));
-    assert_eq!(result, Valid::fail(1));
-  }
-
-  #[test]
-  fn test_to_result() {
-    let result = Valid::<(), i32>::fail(1).to_result().unwrap_err();
-    assert_eq!(result, ValidationError::new(1));
-  }
-
-  #[test]
-  fn test_validate_both_ok() {
-    let result1 = Valid::<bool, i32>::succeed(true);
-    let result2 = Valid::<u8, i32>::succeed(3);
-
-    assert_eq!(result1.zip(result2), Valid::succeed((true, 3u8)));
-  }
-  #[test]
-  fn test_validate_both_first_fail() {
-    let result1 = Valid::<bool, i32>::fail(-1);
-    let result2 = Valid::<u8, i32>::succeed(3);
-
-    assert_eq!(result1.zip(result2), Valid::fail(-1));
-  }
-  #[test]
-  fn test_validate_both_second_fail() {
-    let result1 = Valid::<bool, i32>::succeed(true);
-    let result2 = Valid::<u8, i32>::fail(-2);
-
-    assert_eq!(result1.zip(result2), Valid::fail(-2));
-  }
-
-  #[test]
-  fn test_validate_both_both_fail() {
-    let result1 = Valid::<bool, i32>::fail(-1);
-    let result2 = Valid::<u8, i32>::fail(-2);
-
-    assert_eq!(
-      result1.zip(result2),
-      Valid::from_vec_cause(vec![Cause::new(-1), Cause::new(-2)])
-    );
-  }
-
-  #[test]
-  fn test_and_then_success() {
-    let result = Valid::<i32, i32>::succeed(1).and_then(|a| Valid::succeed(a + 1));
-    assert_eq!(result, Valid::succeed(2));
-  }
-
-  #[test]
-  fn test_and_then_fail() {
-    let result = Valid::<i32, i32>::succeed(1).and_then(|a| Valid::<i32, i32>::fail(a + 1));
-    assert_eq!(result, Valid::fail(2));
-  }
-
-  #[test]
-  fn test_foreach_succeed() {
-    let mut a = 0;
-    let result = Valid::<i32, i32>::succeed(1).foreach(|v| a = v);
-    assert_eq!(result, Valid::succeed(1));
-    assert_eq!(a, 1);
-  }
-
-  #[test]
-  fn test_foreach_fail() {
-    let mut a = 0;
-    let result = Valid::<i32, i32>::fail(1).foreach(|v| a = v);
-    assert_eq!(result, Valid::fail(1));
-    assert_eq!(a, 0);
-  }
+    use super::{Cause, ValidationError};
+    use crate::valid::valid::Valid;
+
+    #[test]
+    fn test_ok() {
+        let result = Valid::<i32, i32>::succeed(1);
+        assert_eq!(result, Valid::succeed(1));
+    }
+
+    #[test]
+    fn test_fail() {
+        let result = Valid::<(), i32>::fail(1);
+        assert_eq!(result, Valid::fail(1));
+    }
+
+    #[test]
+    fn test_validate_or_both_ok() {
+        let result1 = Valid::<bool, i32>::succeed(true);
+        let result2 = Valid::<u8, i32>::succeed(3);
+
+        assert_eq!(result1.and(result2), Valid::succeed(3u8));
+    }
+
+    #[test]
+    fn test_validate_or_first_fail() {
+        let result1 = Valid::<bool, i32>::fail(-1);
+        let result2 = Valid::<u8, i32>::succeed(3);
+
+        assert_eq!(result1.and(result2), Valid::fail(-1));
+    }
+
+    #[test]
+    fn test_validate_or_second_fail() {
+        let result1 = Valid::<bool, i32>::succeed(true);
+        let result2 = Valid::<u8, i32>::fail(-2);
+
+        assert_eq!(result1.and(result2), Valid::fail(-2));
+    }
+
+    #[test]
+    fn test_validate_all() {
+        let input: Vec<i32> = [1, 2, 3].to_vec();
+        let result: Valid<Vec<i32>, i32> = Valid::from_iter(input, |a| Valid::fail(a * 2));
+        assert_eq!(
+            result,
+            Valid::from_vec_cause(vec![Cause::new(2), Cause::new(4), Cause::new(6)])
+        );
+    }
+
+    #[test]
+    fn test_validate_all_ques() {
+        let input: Vec<i32> = [1, 2, 3].to_vec();
+        let result: Valid<Vec<i32>, i32> = Valid::from_iter(input, |a| Valid::fail(a * 2));
+        assert_eq!(
+            result,
+            Valid::from_vec_cause(vec![Cause::new(2), Cause::new(4), Cause::new(6)])
+        );
+    }
+
+    #[test]
+    fn test_ok_ok_cause() {
+        let option: Option<i32> = None;
+        let result = Valid::from_option(option, 1);
+        assert_eq!(result, Valid::from_vec_cause(vec![Cause::new(1)]));
+    }
+
+    #[test]
+    fn test_trace() {
+        let result = Valid::<(), i32>::fail(1).trace("A").trace("B").trace("C");
+        let expected = Valid::from_vec_cause(vec![Cause {
+            message: 1,
+            description: None,
+            trace: vec!["C".to_string(), "B".to_string(), "A".to_string()].into(),
+        }]);
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn test_validate_fold_err() {
+        let valid = Valid::<(), i32>::fail(1);
+        let result = valid.fold(|_| Valid::<(), i32>::fail(2), || Valid::<(), i32>::fail(3));
+        assert_eq!(
+            result,
+            Valid::from_vec_cause(vec![Cause::new(1), Cause::new(3)])
+        );
+    }
+
+    #[test]
+    fn test_validate_fold_ok() {
+        let valid = Valid::<i32, i32>::succeed(1);
+        let result = valid.fold(Valid::<i32, i32>::fail, || Valid::<i32, i32>::fail(2));
+        assert_eq!(result, Valid::fail(1));
+    }
+
+    #[test]
+    fn test_to_result() {
+        let result = Valid::<(), i32>::fail(1).to_result().unwrap_err();
+        assert_eq!(result, ValidationError::new(1));
+    }
+
+    #[test]
+    fn test_validate_both_ok() {
+        let result1 = Valid::<bool, i32>::succeed(true);
+        let result2 = Valid::<u8, i32>::succeed(3);
+
+        assert_eq!(result1.zip(result2), Valid::succeed((true, 3u8)));
+    }
+    #[test]
+    fn test_validate_both_first_fail() {
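The payoff of `Valid` over plain `Result` sits in `zip`/`and`/`from_iter`: both sides run and their errors are merged, which is exactly what the surrounding tests exercise. A compact sketch:

    use crate::valid::Valid;

    // zip keeps both failures instead of stopping at the first one.
    let both: Valid<(bool, u8), i32> =
        Valid::<bool, i32>::fail(-1).zip(Valid::<u8, i32>::fail(-2));
    assert!(both.to_result().is_err()); // the error holds causes -1 and -2

    // from_iter validates every element and reports every failure at once.
    let all: Valid<Vec<i32>, i32> = Valid::from_iter(vec![1, 2, 3], |n| {
        if n % 2 == 0 {
            Valid::succeed(n)
        } else {
            Valid::fail(n)
        }
    });
    assert!(all.to_result().is_err()); // causes: 1 and 3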
Valid::::fail(-1); + let result2 = Valid::::succeed(3); + + assert_eq!(result1.zip(result2), Valid::fail(-1)); + } + #[test] + fn test_validate_both_second_fail() { + let result1 = Valid::::succeed(true); + let result2 = Valid::::fail(-2); + + assert_eq!(result1.zip(result2), Valid::fail(-2)); + } + + #[test] + fn test_validate_both_both_fail() { + let result1 = Valid::::fail(-1); + let result2 = Valid::::fail(-2); + + assert_eq!( + result1.zip(result2), + Valid::from_vec_cause(vec![Cause::new(-1), Cause::new(-2)]) + ); + } + + #[test] + fn test_and_then_success() { + let result = Valid::::succeed(1).and_then(|a| Valid::succeed(a + 1)); + assert_eq!(result, Valid::succeed(2)); + } + + #[test] + fn test_and_then_fail() { + let result = Valid::::succeed(1).and_then(|a| Valid::::fail(a + 1)); + assert_eq!(result, Valid::fail(2)); + } + + #[test] + fn test_foreach_succeed() { + let mut a = 0; + let result = Valid::::succeed(1).foreach(|v| a = v); + assert_eq!(result, Valid::succeed(1)); + assert_eq!(a, 1); + } + + #[test] + fn test_foreach_fail() { + let mut a = 0; + let result = Valid::::fail(1).foreach(|v| a = v); + assert_eq!(result, Valid::fail(1)); + assert_eq!(a, 0); + } } diff --git a/tests/graphql_spec.rs b/tests/graphql_spec.rs index e91e8e5fbee..3528d836b24 100644 --- a/tests/graphql_spec.rs +++ b/tests/graphql_spec.rs @@ -26,203 +26,209 @@ static INIT: Once = Once::new(); #[derive(Debug, Clone, PartialEq)] enum Tag { - ClientSDL, - ServerSDL, - MergedSDL, + ClientSDL, + ServerSDL, + MergedSDL, } #[derive(Debug, Clone)] struct Source { - sdl: String, - tag: Tag, + sdl: String, + tag: Tag, } #[derive(Debug, Default, Setters)] struct GraphQLSpec { - path: PathBuf, - sources: Vec, - sdl_errors: Vec, - test_queries: Vec, - annotation: Option, + path: PathBuf, + sources: Vec, + sdl_errors: Vec, + test_queries: Vec, + annotation: Option, } #[derive(Debug)] enum Annotation { - Skip, - Only, - Fail, + Skip, + Only, + Fail, } impl GraphQLSpec { - fn find_source(&self, tag: Tag) -> String { - self.get_sources(tag).next().unwrap().to_string() - } - - fn get_sources(&self, tag: Tag) -> impl Iterator { - self - .sources - .iter() - .filter(move |s| s.tag == tag) - .map(|s| s.sdl.as_str()) - } + fn find_source(&self, tag: Tag) -> String { + self.get_sources(tag).next().unwrap().to_string() + } + + fn get_sources(&self, tag: Tag) -> impl Iterator { + self.sources + .iter() + .filter(move |s| s.tag == tag) + .map(|s| s.sdl.as_str()) + } } #[derive(Debug, Default, Deserialize, Serialize, PartialEq)] struct SDLError { - message: String, - trace: Vec, - description: Option, + message: String, + trace: Vec, + description: Option, } impl<'a> From> for SDLError { - fn from(value: Cause<&'a str>) -> Self { - SDLError { - message: value.message.to_string(), - trace: value.trace.iter().map(|e| e.to_string()).collect(), - description: None, + fn from(value: Cause<&'a str>) -> Self { + SDLError { + message: value.message.to_string(), + trace: value.trace.iter().map(|e| e.to_string()).collect(), + description: None, + } } - } } impl From> for SDLError { - fn from(value: Cause) -> Self { - SDLError { - message: value.message.to_string(), - trace: value.trace.iter().map(|e| e.to_string()).collect(), - description: value.description, + fn from(value: Cause) -> Self { + SDLError { + message: value.message.to_string(), + trace: value.trace.iter().map(|e| e.to_string()).collect(), + description: value.description, + } } - } } #[derive(Debug, Default)] struct GraphQLQuerySpec { - query: String, - expected: Value, + 
query: String, + expected: Value, } impl GraphQLSpec { - fn query(mut self, query: String, expected: Value) -> Self { - self.test_queries.push(GraphQLQuerySpec { query, expected }); - self - } - - fn new(path: PathBuf, content: &str) -> GraphQLSpec { - INIT.call_once(|| { - env_logger::builder() - .filter(Some("graphql_spec"), log::LevelFilter::Info) - .init(); - }); - - let mut spec = GraphQLSpec::default().path(path); - let mut server_sdl = Vec::new(); - for component in content.split("#>") { - if component.contains(SPEC_ONLY) { - spec = spec.annotation(Some(Annotation::Only)); - } - if component.contains(SPEC_SKIP) { - spec = spec.annotation(Some(Annotation::Skip)); - } - if component.contains(SPEC_FAIL) { - spec = spec.annotation(Some(Annotation::Fail)); - } - if component.contains(CLIENT_SDL) { - let trimmed = component.replace(CLIENT_SDL, "").trim().to_string(); - - // Extract all errors - if trimmed.contains("@error") { - let doc = async_graphql::parser::parse_schema(trimmed.as_str()).unwrap(); - for def in doc.definitions { - if let TypeSystemDefinition::Type(type_def) = def { - for dir in type_def.node.directives { - if dir.node.name.node == "error" { - spec - .sdl_errors - .push(SDLError::from_directive(&dir.node).to_result().unwrap()); - } - } + fn query(mut self, query: String, expected: Value) -> Self { + self.test_queries.push(GraphQLQuerySpec { query, expected }); + self + } + + fn new(path: PathBuf, content: &str) -> GraphQLSpec { + INIT.call_once(|| { + env_logger::builder() + .filter(Some("graphql_spec"), log::LevelFilter::Info) + .init(); + }); + + let mut spec = GraphQLSpec::default().path(path); + let mut server_sdl = Vec::new(); + for component in content.split("#>") { + if component.contains(SPEC_ONLY) { + spec = spec.annotation(Some(Annotation::Only)); } - } - } + if component.contains(SPEC_SKIP) { + spec = spec.annotation(Some(Annotation::Skip)); + } + if component.contains(SPEC_FAIL) { + spec = spec.annotation(Some(Annotation::Fail)); + } + if component.contains(CLIENT_SDL) { + let trimmed = component.replace(CLIENT_SDL, "").trim().to_string(); + + // Extract all errors + if trimmed.contains("@error") { + let doc = async_graphql::parser::parse_schema(trimmed.as_str()).unwrap(); + for def in doc.definitions { + if let TypeSystemDefinition::Type(type_def) = def { + for dir in type_def.node.directives { + if dir.node.name.node == "error" { + spec.sdl_errors.push( + SDLError::from_directive(&dir.node).to_result().unwrap(), + ); + } + } + } + } + } - spec.sources.push(Source { sdl: trimmed.clone(), tag: Tag::ClientSDL }); - } - if component.contains(SERVER_SDL) { - server_sdl.push(component.replace(SERVER_SDL, "").trim().to_string()); - for s in &server_sdl { - spec.sources.push(Source { sdl: s.to_string(), tag: Tag::ServerSDL }) - } - } - if component.contains(MERGED_SDL) { - let sdl = component.replace(MERGED_SDL, "").trim().to_string(); - spec.sources.push(Source { sdl, tag: Tag::MergedSDL }); - } - if component.contains(CLIENT_QUERY) { - let regex = Regex::new(r"@expect.*\) ").unwrap(); - let query_string = component.replace(CLIENT_QUERY, ""); - let parsed_query = async_graphql::parser::parse_query(query_string.clone()).unwrap(); - - let query_string = regex.replace_all(query_string.as_str(), ""); - let query_string = query_string.trim(); - for (_, q) in parsed_query.operations.iter() { - let expect = q.node.directives.iter().find(|d| d.node.name.node == "expect"); - assert!( - expect.is_some(), - "@expect directive is required in query:\n```\n{}\n```", - 
query_string - ); - if let Some(dir) = expect { - let expected = dir - .node - .arguments - .iter() - .find(|a| a.0.node == "json") - .map(|a| a.clone().1.node.into_json().unwrap()) - .unwrap(); - spec = spec.query(query_string.to_string(), expected); - } + spec.sources + .push(Source { sdl: trimmed.clone(), tag: Tag::ClientSDL }); + } + if component.contains(SERVER_SDL) { + server_sdl.push(component.replace(SERVER_SDL, "").trim().to_string()); + for s in &server_sdl { + spec.sources + .push(Source { sdl: s.to_string(), tag: Tag::ServerSDL }) + } + } + if component.contains(MERGED_SDL) { + let sdl = component.replace(MERGED_SDL, "").trim().to_string(); + spec.sources.push(Source { sdl, tag: Tag::MergedSDL }); + } + if component.contains(CLIENT_QUERY) { + let regex = Regex::new(r"@expect.*\) ").unwrap(); + let query_string = component.replace(CLIENT_QUERY, ""); + let parsed_query = + async_graphql::parser::parse_query(query_string.clone()).unwrap(); + + let query_string = regex.replace_all(query_string.as_str(), ""); + let query_string = query_string.trim(); + for (_, q) in parsed_query.operations.iter() { + let expect = q + .node + .directives + .iter() + .find(|d| d.node.name.node == "expect"); + assert!( + expect.is_some(), + "@expect directive is required in query:\n```\n{}\n```", + query_string + ); + if let Some(dir) = expect { + let expected = dir + .node + .arguments + .iter() + .find(|a| a.0.node == "json") + .map(|a| a.clone().1.node.into_json().unwrap()) + .unwrap(); + spec = spec.query(query_string.to_string(), expected); + } + } + } } - } + spec } - spec - } - - fn cargo_read(path: &str) -> std::io::Result> { - let mut dir_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - dir_path.push(path); - - let entries = fs::read_dir(dir_path.clone())?; - let mut files = Vec::new(); - let mut only_files = Vec::new(); - - for entry in entries { - let path = entry?.path(); - if path.is_file() && path.extension().unwrap_or_default() == "graphql" { - let contents = fs::read_to_string(path.clone())?; - let path_buf = path.clone(); - let spec = GraphQLSpec::new(path_buf, contents.as_str()); - - match spec.annotation { - Some(Annotation::Only) => only_files.push(spec), - Some(Annotation::Fail) | None => files.push(spec), - Some(Annotation::Skip) => { - log::warn!("{} ... skipped", spec.path.display()); - } + + fn cargo_read(path: &str) -> std::io::Result> { + let mut dir_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + dir_path.push(path); + + let entries = fs::read_dir(dir_path.clone())?; + let mut files = Vec::new(); + let mut only_files = Vec::new(); + + for entry in entries { + let path = entry?.path(); + if path.is_file() && path.extension().unwrap_or_default() == "graphql" { + let contents = fs::read_to_string(path.clone())?; + let path_buf = path.clone(); + let spec = GraphQLSpec::new(path_buf, contents.as_str()); + + match spec.annotation { + Some(Annotation::Only) => only_files.push(spec), + Some(Annotation::Fail) | None => files.push(spec), + Some(Annotation::Skip) => { + log::warn!("{} ... 
skipped", spec.path.display()); + } + } + } } - } - } - assert!( - !files.is_empty() || !only_files.is_empty(), - "No files found in {}", - dir_path.to_str().unwrap_or_default() - ); + assert!( + !files.is_empty() || !only_files.is_empty(), + "No files found in {}", + dir_path.to_str().unwrap_or_default() + ); - if !only_files.is_empty() { - Ok(only_files) - } else { - Ok(files) + if !only_files.is_empty() { + Ok(only_files) + } else { + Ok(files) + } } - } } const CLIENT_SDL: &str = "client-sdl"; @@ -236,163 +242,210 @@ const SPEC_FAIL: &str = "spec-fail"; // Check if SDL -> Config -> SDL is identity #[test] fn test_config_identity() -> std::io::Result<()> { - let specs = GraphQLSpec::cargo_read("tests/graphql"); - - for spec in specs? { - let content = spec.find_source(Tag::ServerSDL); - let content = content.as_str(); - let expected = content; - let config = Config::from_sdl(content).to_result().unwrap(); - let actual = config.to_sdl(); - - if spec.annotation.as_ref().is_some_and(|a| matches!(a, Annotation::Fail)) { - assert_ne!(actual, expected, "ServerSDLIdentity: {}", spec.path.display()); - } else { - assert_eq!(actual, expected, "ServerSDLIdentity: {}", spec.path.display()); - } + let specs = GraphQLSpec::cargo_read("tests/graphql"); + + for spec in specs? { + let content = spec.find_source(Tag::ServerSDL); + let content = content.as_str(); + let expected = content; + let config = Config::from_sdl(content).to_result().unwrap(); + let actual = config.to_sdl(); + + if spec + .annotation + .as_ref() + .is_some_and(|a| matches!(a, Annotation::Fail)) + { + assert_ne!( + actual, + expected, + "ServerSDLIdentity: {}", + spec.path.display() + ); + } else { + assert_eq!( + actual, + expected, + "ServerSDLIdentity: {}", + spec.path.display() + ); + } - log::info!("ServerSDLIdentity: {} ... ok", spec.path.display()); - } + log::info!("ServerSDLIdentity: {} ... ok", spec.path.display()); + } - Ok(()) + Ok(()) } // Check server SDL matches expected client SDL #[test] fn test_server_to_client_sdl() -> std::io::Result<()> { - let specs = GraphQLSpec::cargo_read("tests/graphql"); - - for spec in specs? { - let expected = spec.find_source(Tag::ClientSDL); - let expected = expected.as_str(); - let content = spec.find_source(Tag::ServerSDL); - let content = content.as_str(); - let config = Config::from_sdl(content).to_result().unwrap(); - let actual = print_schema::print_schema((Blueprint::try_from(&config).unwrap()).to_schema()); - - if spec.annotation.as_ref().is_some_and(|a| matches!(a, Annotation::Fail)) { - assert_ne!(actual, expected, "ClientSDL: {}", spec.path.display()); - } else { - assert_eq!(actual, expected, "ClientSDL: {}", spec.path.display()); - } + let specs = GraphQLSpec::cargo_read("tests/graphql"); + + for spec in specs? { + let expected = spec.find_source(Tag::ClientSDL); + let expected = expected.as_str(); + let content = spec.find_source(Tag::ServerSDL); + let content = content.as_str(); + let config = Config::from_sdl(content).to_result().unwrap(); + let actual = + print_schema::print_schema((Blueprint::try_from(&config).unwrap()).to_schema()); + + if spec + .annotation + .as_ref() + .is_some_and(|a| matches!(a, Annotation::Fail)) + { + assert_ne!(actual, expected, "ClientSDL: {}", spec.path.display()); + } else { + assert_eq!(actual, expected, "ClientSDL: {}", spec.path.display()); + } - log::info!("ClientSDL: {} ... ok", spec.path.display()); - } + log::info!("ClientSDL: {} ... 
ok", spec.path.display()); + } - Ok(()) + Ok(()) } // Check if execution gives expected response #[tokio::test] async fn test_execution() -> std::io::Result<()> { - let specs = GraphQLSpec::cargo_read("tests/graphql/passed"); - - let tasks: Vec<_> = specs? - .into_iter() - .map(|spec| { - tokio::spawn(async move { - let mut config = Config::from_sdl(spec.find_source(Tag::ServerSDL).as_str()) - .to_result() - .unwrap(); - config.server.query_validation = Some(false); - - let blueprint = Valid::from(Blueprint::try_from(&config)) - .trace(spec.path.to_str().unwrap_or_default()) - .to_result() - .unwrap(); - let h_client = Arc::new(init_http(&blueprint.upstream)); - let h2_client = Arc::new(init_http(&blueprint.upstream)); - let chrono_cache = init_chrono_cache(); - let server_ctx = AppContext::new( - blueprint, - h_client, - h2_client, - Arc::new(init_env()), - Arc::new(chrono_cache), - ); - let schema = &server_ctx.schema; - - for q in spec.test_queries { - let mut headers = HeaderMap::new(); - headers.insert(HeaderName::from_static("authorization"), HeaderValue::from_static("1")); - let req_ctx = Arc::new(RequestContext::from(&server_ctx).req_headers(headers)); - let req = Request::from(q.query.as_str()).data(req_ctx.clone()); - let res = schema.execute(req).await; - let json = serde_json::to_string(&res).unwrap(); - let expected = serde_json::to_string(&q.expected).unwrap(); - - if spec.annotation.as_ref().is_some_and(|a| matches!(a, Annotation::Fail)) { - assert_ne!(json, expected, "QueryExecution: {}", spec.path.display()); - } else { - assert_eq!(json, expected, "QueryExecution: {}", spec.path.display()); - } - - log::info!("QueryExecution: {} ... ok", spec.path.display()); - } - }) - }) - .collect(); + let specs = GraphQLSpec::cargo_read("tests/graphql/passed"); + + let tasks: Vec<_> = specs? + .into_iter() + .map(|spec| { + tokio::spawn(async move { + let mut config = Config::from_sdl(spec.find_source(Tag::ServerSDL).as_str()) + .to_result() + .unwrap(); + config.server.query_validation = Some(false); + + let blueprint = Valid::from(Blueprint::try_from(&config)) + .trace(spec.path.to_str().unwrap_or_default()) + .to_result() + .unwrap(); + let h_client = Arc::new(init_http(&blueprint.upstream)); + let h2_client = Arc::new(init_http(&blueprint.upstream)); + let chrono_cache = init_chrono_cache(); + let server_ctx = AppContext::new( + blueprint, + h_client, + h2_client, + Arc::new(init_env()), + Arc::new(chrono_cache), + ); + let schema = &server_ctx.schema; + + for q in spec.test_queries { + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("authorization"), + HeaderValue::from_static("1"), + ); + let req_ctx = Arc::new(RequestContext::from(&server_ctx).req_headers(headers)); + let req = Request::from(q.query.as_str()).data(req_ctx.clone()); + let res = schema.execute(req).await; + let json = serde_json::to_string(&res).unwrap(); + let expected = serde_json::to_string(&q.expected).unwrap(); + + if spec + .annotation + .as_ref() + .is_some_and(|a| matches!(a, Annotation::Fail)) + { + assert_ne!(json, expected, "QueryExecution: {}", spec.path.display()); + } else { + assert_eq!(json, expected, "QueryExecution: {}", spec.path.display()); + } + + log::info!("QueryExecution: {} ... 
ok", spec.path.display()); + } + }) + }) + .collect(); - join_all(tasks).await; + join_all(tasks).await; - Ok(()) + Ok(()) } // Standardize errors on Client SDL #[test] fn test_failures_in_client_sdl() -> std::io::Result<()> { - let specs = GraphQLSpec::cargo_read("tests/graphql/errors"); - - for spec in specs? { - let content = spec.find_source(Tag::ServerSDL); - let expected = spec.sdl_errors; - let content = content.as_str(); - let config = Config::from_sdl(content); - - let actual = config - .and_then(|config| Valid::from(Blueprint::try_from(&config))) - .to_result(); - match actual { - Err(cause) => { - let actual: Vec = cause.as_vec().iter().map(|e| e.to_owned().into()).collect(); - - if spec.annotation.as_ref().is_some_and(|a| matches!(a, Annotation::Fail)) { - assert_ne!(actual, expected, "Server SDL failure match: {}", spec.path.display()); - } else { - assert_eq!(actual, expected, "Server SDL failure mismatch: {}", spec.path.display()); - } + let specs = GraphQLSpec::cargo_read("tests/graphql/errors"); + + for spec in specs? { + let content = spec.find_source(Tag::ServerSDL); + let expected = spec.sdl_errors; + let content = content.as_str(); + let config = Config::from_sdl(content); + + let actual = config + .and_then(|config| Valid::from(Blueprint::try_from(&config))) + .to_result(); + match actual { + Err(cause) => { + let actual: Vec = + cause.as_vec().iter().map(|e| e.to_owned().into()).collect(); + + if spec + .annotation + .as_ref() + .is_some_and(|a| matches!(a, Annotation::Fail)) + { + assert_ne!( + actual, + expected, + "Server SDL failure match: {}", + spec.path.display() + ); + } else { + assert_eq!( + actual, + expected, + "Server SDL failure mismatch: {}", + spec.path.display() + ); + } - log::info!("ClientSDLError: {} ... ok", spec.path.display()); - } - _ => panic!("ClientSDLError: {}", spec.path.display()), + log::info!("ClientSDLError: {} ... ok", spec.path.display()); + } + _ => panic!("ClientSDLError: {}", spec.path.display()), + } } - } - Ok(()) + Ok(()) } #[test] fn test_merge_sdl() -> std::io::Result<()> { - let specs = GraphQLSpec::cargo_read("tests/graphql/merge"); - - for spec in specs? { - let expected = spec.find_source(Tag::MergedSDL); - let expected = expected.as_str(); - let content = spec - .get_sources(Tag::ServerSDL) - .map(|s| Config::from_sdl(s).to_result().unwrap()) - .collect::>(); - let config = content.iter().fold(Config::default(), |acc, c| acc.merge_right(c)); - let actual = config.to_sdl(); - - if spec.annotation.as_ref().is_some_and(|a| matches!(a, Annotation::Fail)) { - assert_ne!(actual, expected, "SDLMerge: {}", spec.path.display()); - } else { - assert_eq!(actual, expected, "SDLMerge: {}", spec.path.display()); - } + let specs = GraphQLSpec::cargo_read("tests/graphql/merge"); + + for spec in specs? { + let expected = spec.find_source(Tag::MergedSDL); + let expected = expected.as_str(); + let content = spec + .get_sources(Tag::ServerSDL) + .map(|s| Config::from_sdl(s).to_result().unwrap()) + .collect::>(); + let config = content + .iter() + .fold(Config::default(), |acc, c| acc.merge_right(c)); + let actual = config.to_sdl(); + + if spec + .annotation + .as_ref() + .is_some_and(|a| matches!(a, Annotation::Fail)) + { + assert_ne!(actual, expected, "SDLMerge: {}", spec.path.display()); + } else { + assert_eq!(actual, expected, "SDLMerge: {}", spec.path.display()); + } - log::info!("SDLMerge: {} ... ok", spec.path.display()); - } + log::info!("SDLMerge: {} ... 
ok", spec.path.display()); + } - Ok(()) + Ok(()) } diff --git a/tests/http_spec.rs b/tests/http_spec.rs index ab2719a4a60..f9dfcbaac80 100644 --- a/tests/http_spec.rs +++ b/tests/http_spec.rs @@ -30,52 +30,52 @@ static INIT: Once = Once::new(); #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] enum Annotation { - Skip, - Only, - Fail, + Skip, + Only, + Fail, } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(rename_all = "camelCase")] struct APIRequest { - #[serde(default)] - method: Method, - url: Url, - #[serde(default)] - headers: BTreeMap, - #[serde(default)] - body: serde_json::Value, + #[serde(default)] + method: Method, + url: Url, + #[serde(default)] + headers: BTreeMap, + #[serde(default)] + body: serde_json::Value, } #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] struct APIResponse { - #[serde(default = "default_status")] - status: u16, - #[serde(default)] - headers: BTreeMap, - #[serde(default)] - body: serde_json::Value, + #[serde(default = "default_status")] + status: u16, + #[serde(default)] + headers: BTreeMap, + #[serde(default)] + body: serde_json::Value, } pub struct Env { - env: HashMap, + env: HashMap, } impl EnvIO for Env { - fn get(&self, key: &str) -> Option { - self.env.get(key).cloned() - } + fn get(&self, key: &str) -> Option { + self.env.get(key).cloned() + } } impl Env { - pub fn init(map: HashMap) -> Self { - Self { env: map } - } + pub fn init(map: HashMap) -> Self { + Self { env: map } + } } fn default_status() -> u16 { - 200 + 200 } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] @@ -92,335 +92,353 @@ struct DownstreamResponse(APIResponse); #[derive(Serialize, Deserialize, Clone, Debug)] struct DownstreamAssertion { - request: DownstreamRequest, - response: DownstreamResponse, + request: DownstreamRequest, + response: DownstreamResponse, } #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] enum ConfigSource { - File(String), - Inline(Config), + File(String), + Inline(Config), } #[derive(Serialize, Deserialize, Clone, Debug)] struct Mock { - request: UpstreamRequest, - response: UpstreamResponse, + request: UpstreamRequest, + response: UpstreamResponse, } #[derive(Serialize, Deserialize, Clone, Setters, Debug)] #[serde(rename_all = "camelCase")] struct HttpSpec { - config: ConfigSource, - #[serde(skip)] - path: PathBuf, - name: String, - description: Option, + config: ConfigSource, + #[serde(skip)] + path: PathBuf, + name: String, + description: Option, - #[serde(default)] - mock: Vec, + #[serde(default)] + mock: Vec, - #[serde(default)] - env: HashMap, + #[serde(default)] + env: HashMap, - #[serde(default)] - expected_upstream_requests: Vec, - assert: Vec, + #[serde(default)] + expected_upstream_requests: Vec, + assert: Vec, - // Annotations for the runner - runner: Option, + // Annotations for the runner + runner: Option, } impl HttpSpec { - fn cargo_read(path: &str) -> anyhow::Result> { - let dir_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(path); - let mut files = Vec::new(); - - for entry in fs::read_dir(&dir_path)? 
-      let path = entry?.path();
-      if path.is_dir() {
-        continue;
-      }
-      let source = Source::detect(path.to_str().unwrap_or_default())?;
-      if path.is_file() && (source.ext() == "json" || source.ext() == "yml") {
-        let contents = fs::read_to_string(&path)?;
-        let spec: HttpSpec =
-          Self::from_source(source, contents).map_err(|err| err.context(path.to_str().unwrap().to_string()))?;
-
-        files.push(spec.path(path));
-      }
+    fn cargo_read(path: &str) -> anyhow::Result<Vec<HttpSpec>> {
+        let dir_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(path);
+        let mut files = Vec::new();
+
+        for entry in fs::read_dir(&dir_path)? {
+            let path = entry?.path();
+            if path.is_dir() {
+                continue;
+            }
+            let source = Source::detect(path.to_str().unwrap_or_default())?;
+            if path.is_file() && (source.ext() == "json" || source.ext() == "yml") {
+                let contents = fs::read_to_string(&path)?;
+                let spec: HttpSpec = Self::from_source(source, contents)
+                    .map_err(|err| err.context(path.to_str().unwrap().to_string()))?;
+
+                files.push(spec.path(path));
+            }
+        }
+
+        assert!(
+            !files.is_empty(),
+            "No files found in {}",
+            dir_path.to_str().unwrap_or_default()
+        );
+        Ok(files)
+    }

-    assert!(
-      !files.is_empty(),
-      "No files found in {}",
-      dir_path.to_str().unwrap_or_default()
-    );
-    Ok(files)
-  }
-
-  fn filter_specs(specs: Vec<HttpSpec>) -> Vec<HttpSpec> {
-    let mut only_specs = Vec::new();
-    let mut filtered_specs = Vec::new();
-
-    for spec in specs {
-      match spec.runner {
-        Some(Annotation::Skip) => log::warn!("{} {} ... skipped", spec.name, spec.path.display()),
-        Some(Annotation::Only) => only_specs.push(spec),
-        Some(Annotation::Fail) => filtered_specs.push(spec),
-        None => filtered_specs.push(spec),
-      }
+    fn filter_specs(specs: Vec<HttpSpec>) -> Vec<HttpSpec> {
+        let mut only_specs = Vec::new();
+        let mut filtered_specs = Vec::new();
+
+        for spec in specs {
+            match spec.runner {
+                Some(Annotation::Skip) => {
+                    log::warn!("{} {} ... skipped", spec.name, spec.path.display())
+                }
+                Some(Annotation::Only) => only_specs.push(spec),
+                Some(Annotation::Fail) => filtered_specs.push(spec),
+                None => filtered_specs.push(spec),
+            }
+        }
+
+        // If any spec has the Only annotation, use those; otherwise, use the filtered list.
+        if !only_specs.is_empty() {
+            only_specs
+        } else {
+            filtered_specs
+        }
+    }
+    fn from_source(source: Source, contents: String) -> anyhow::Result<HttpSpec> {
+        INIT.call_once(|| {
+            env_logger::builder()
+                .filter(Some("http_spec"), log::LevelFilter::Info)
+                .init();
+        });
+
+        let spec: HttpSpec = match source {
+            Source::Json => anyhow::Ok(serde_json::from_str(&contents)?),
+            Source::Yml => anyhow::Ok(serde_yaml::from_str(&contents)?),
+            _ => Err(anyhow!("only json and yaml are supported")),
+        }?;
+
+        anyhow::Ok(spec)
+    }

-    // If any spec has the Only annotation, use those; otherwise, use the filtered list.
-    if !only_specs.is_empty() {
-      only_specs
-    } else {
-      filtered_specs
+    async fn server_context(&self) -> Arc<AppContext<MockHttpClient, Env>> {
+        let http_client = init_http(&Upstream::default());
+        let config = match self.config.clone() {
+            ConfigSource::File(file) => {
+                let reader = ConfigReader::init(init_file(), http_client);
+                reader.read(&[file]).await.unwrap()
+            }
+            ConfigSource::Inline(config) => config,
+        };
+        let blueprint = Blueprint::try_from(&config).unwrap();
+        let client = Arc::new(MockHttpClient { spec: self.clone() });
+        let http2_client = Arc::new(MockHttpClient { spec: self.clone() });
+        let env = Arc::new(Env::init(self.env.clone()));
+        let chrono_cache = Arc::new(init_chrono_cache());
+        let server_context = AppContext::new(blueprint, client, http2_client, env, chrono_cache);
+        Arc::new(server_context)
     }
-  }

-  fn from_source(source: Source, contents: String) -> anyhow::Result<HttpSpec> {
-    INIT.call_once(|| {
-      env_logger::builder()
-        .filter(Some("http_spec"), log::LevelFilter::Info)
-        .init();
-    });
-
-    let spec: HttpSpec = match source {
-      Source::Json => anyhow::Ok(serde_json::from_str(&contents)?),
-      Source::Yml => anyhow::Ok(serde_yaml::from_str(&contents)?),
-      _ => Err(anyhow!("only json and yaml are supported")),
-    }?;
-
-    anyhow::Ok(spec)
-  }
-
-  async fn server_context(&self) -> Arc<AppContext<MockHttpClient, Env>> {
-    let http_client = init_http(&Upstream::default());
-    let config = match self.config.clone() {
-      ConfigSource::File(file) => {
-        let reader = ConfigReader::init(init_file(), http_client);
-        reader.read(&[file]).await.unwrap()
-      }
-      ConfigSource::Inline(config) => config,
-    };
-    let blueprint = Blueprint::try_from(&config).unwrap();
-    let client = Arc::new(MockHttpClient { spec: self.clone() });
-    let http2_client = Arc::new(MockHttpClient { spec: self.clone() });
-    let env = Arc::new(Env::init(self.env.clone()));
-    let chrono_cache = Arc::new(init_chrono_cache());
-    let server_context = AppContext::new(blueprint, client, http2_client, env, chrono_cache);
-    Arc::new(server_context)
-  }
 }

 #[derive(Clone)]
 struct MockHttpClient {
-  spec: HttpSpec,
+    spec: HttpSpec,
 }

 fn string_to_bytes(input: &str) -> Vec<u8> {
-  let mut bytes = Vec::new();
-  let mut chars = input.chars().peekable();
-
-  while let Some(c) = chars.next() {
-    match c {
-      '\\' => match chars.next() {
-        Some('0') => bytes.push(0),
-        Some('n') => bytes.push(b'\n'),
-        Some('t') => bytes.push(b'\t'),
-        Some('r') => bytes.push(b'\r'),
-        Some('\\') => bytes.push(b'\\'),
-        Some('\"') => bytes.push(b'\"'),
-        Some('x') => {
-          let mut hex = chars.next().unwrap().to_string();
-          hex.push(chars.next().unwrap());
-          let byte = u8::from_str_radix(&hex, 16).unwrap();
-          bytes.push(byte);
+    let mut bytes = Vec::new();
+    let mut chars = input.chars().peekable();
+
+    while let Some(c) = chars.next() {
+        match c {
+            '\\' => match chars.next() {
+                Some('0') => bytes.push(0),
+                Some('n') => bytes.push(b'\n'),
+                Some('t') => bytes.push(b'\t'),
+                Some('r') => bytes.push(b'\r'),
+                Some('\\') => bytes.push(b'\\'),
+                Some('\"') => bytes.push(b'\"'),
+                Some('x') => {
+                    let mut hex = chars.next().unwrap().to_string();
+                    hex.push(chars.next().unwrap());
+                    let byte = u8::from_str_radix(&hex, 16).unwrap();
+                    bytes.push(byte);
+                }
+                _ => panic!("Unsupported escape sequence"),
+            },
+            _ => bytes.push(c as u8),
         }
-        _ => panic!("Unsupported escape sequence"),
-      },
-      _ => bytes.push(c as u8),
     }
-  }
-  bytes
+    bytes
 }

 #[async_trait::async_trait]
 impl HttpIO for MockHttpClient {
-  async fn execute(&self, req: reqwest::Request) -> anyhow::Result<Response<Bytes>> {
-    let mocks = self.spec.mock.clone();
-
-    // Determine if the request is a GRPC request based on PORT
-    let is_grpc = req.url().as_str().contains("50051");
-
-    // Try to find a matching mock for the incoming request.
-    let mock = mocks
-      .iter()
-      .find(|Mock { request: mock_req, response: _ }| {
-        let method_match = req.method() == mock_req.0.method.clone().to_hyper();
-        let url_match = req.url().as_str() == mock_req.0.url.clone().as_str();
-        let req_body = match req.body() {
-          Some(body) => {
-            if let Some(bytes) = body.as_bytes() {
-              if let Ok(body_str) = std::str::from_utf8(bytes) {
-                Value::from(body_str)
-              } else {
-                Value::Null
-              }
-            } else {
-              Value::Null
-            }
-          }
-          None => Value::Null,
-        };
-        let body_match = req_body == mock_req.0.body;
-        method_match && url_match && (body_match || is_grpc)
-      })
-      .ok_or(anyhow!(
-        "No mock found for request: {:?} {} in {}",
-        req.method(),
-        req.url(),
-        format!("{}", self.spec.path.to_str().unwrap())
-      ))?;
-
-    // Clone the response from the mock to avoid borrowing issues.
-    let mock_response = mock.response.clone();
-
-    // Build the response with the status code from the mock.
-    let status_code = reqwest::StatusCode::from_u16(mock_response.0.status)?;
-
-    if status_code.is_client_error() || status_code.is_server_error() {
-      return Err(anyhow::format_err!("Status code error"));
-    }
+    async fn execute(&self, req: reqwest::Request) -> anyhow::Result<Response<Bytes>> {
+        let mocks = self.spec.mock.clone();
+
+        // Determine if the request is a gRPC request based on the port
+        let is_grpc = req.url().as_str().contains("50051");
+
+        // Try to find a matching mock for the incoming request.
+        let mock = mocks
+            .iter()
+            .find(|Mock { request: mock_req, response: _ }| {
+                let method_match = req.method() == mock_req.0.method.clone().to_hyper();
+                let url_match = req.url().as_str() == mock_req.0.url.clone().as_str();
+                let req_body = match req.body() {
+                    Some(body) => {
+                        if let Some(bytes) = body.as_bytes() {
+                            if let Ok(body_str) = std::str::from_utf8(bytes) {
+                                Value::from(body_str)
+                            } else {
+                                Value::Null
+                            }
+                        } else {
+                            Value::Null
+                        }
+                    }
+                    None => Value::Null,
+                };
+                let body_match = req_body == mock_req.0.body;
+                method_match && url_match && (body_match || is_grpc)
+            })
+            .ok_or(anyhow!(
+                "No mock found for request: {:?} {} in {}",
+                req.method(),
+                req.url(),
+                format!("{}", self.spec.path.to_str().unwrap())
+            ))?;

+        // Clone the response from the mock to avoid borrowing issues.
+        let mock_response = mock.response.clone();
+
+        // Build the response with the status code from the mock.
+        let status_code = reqwest::StatusCode::from_u16(mock_response.0.status)?;
+
+        if status_code.is_client_error() || status_code.is_server_error() {
+            return Err(anyhow::format_err!("Status code error"));
+        }

-    let mut response = Response { status: status_code, ..Default::default() };
+        let mut response = Response { status: status_code, ..Default::default() };

-    // Insert headers from the mock into the response.
-    for (key, value) in mock_response.0.headers {
-      let header_name = HeaderName::from_str(&key)?;
-      let header_value = HeaderValue::from_str(&value)?;
-      response.headers.insert(header_name, header_value);
-    }
+        // Insert headers from the mock into the response.
+ for (key, value) in mock_response.0.headers { + let header_name = HeaderName::from_str(&key)?; + let header_value = HeaderValue::from_str(&value)?; + response.headers.insert(header_name, header_value); + } - // Special Handling for GRPC - if is_grpc { - let body = string_to_bytes(mock_response.0.body.as_str().unwrap()); - response.body = Bytes::from_iter(body); - Ok(response) - } else { - let body = if let Value::String(x) = mock_response.0.body { - string_to_bytes(&x) - } else { - serde_json::to_vec(&mock_response.0.body)? - }; - response.body = Bytes::from_iter(body); - Ok(response) + // Special Handling for GRPC + if is_grpc { + let body = string_to_bytes(mock_response.0.body.as_str().unwrap()); + response.body = Bytes::from_iter(body); + Ok(response) + } else { + let body = if let Value::String(x) = mock_response.0.body { + string_to_bytes(&x) + } else { + serde_json::to_vec(&mock_response.0.body)? + }; + response.body = Bytes::from_iter(body); + Ok(response) + } } - } } async fn assert_downstream(spec: HttpSpec) { - for assertion in spec.assert.iter() { - if let Some(Annotation::Fail) = spec.runner { - let response = run(spec.clone(), &assertion).await.unwrap(); - let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - let result = panic::catch_unwind(AssertUnwindSafe(|| { - assert_eq!( - body, - serde_json::to_string(&assertion.response.0.body).unwrap(), - "File: {} {}", - spec.name, - spec.path.display() - ); - })); - - match result { - Ok(_) => { - panic!( - "Expected spec: {} {} to fail but it passed", - spec.name, - spec.path.display() - ); - } - Err(_) => { - log::info!("{} {} ... failed (expected)", spec.name, spec.path.display()); - } - } - } else { - let response = run(spec.clone(), &assertion) - .await - .context(spec.path.to_str().unwrap().to_string()) - .unwrap(); - let actual_status = response.status().clone().as_u16(); - let actual_headers = response.headers().clone(); - let actual_body = hyper::body::to_bytes(response.into_body()).await.unwrap(); - - // Assert Status - assert_eq!( - actual_status, - assertion.response.0.status, - "File: {} {}", - spec.name, - spec.path.display() - ); - - // Assert Body - assert_eq!( - to_json_pretty(actual_body).unwrap(), - serde_json::to_string_pretty(&assertion.response.0.body).unwrap(), - "File: {} {}", - spec.name, - spec.path.display() - ); - - // Assert Headers - for (key, value) in assertion.response.0.headers.iter() { - match actual_headers.get(key) { - None => panic!("Expected header {} to be present", key), - Some(actual_value) => assert_eq!(actual_value, value, "File: {} {}", spec.name, spec.path.display()), + for assertion in spec.assert.iter() { + if let Some(Annotation::Fail) = spec.runner { + let response = run(spec.clone(), &assertion).await.unwrap(); + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let result = panic::catch_unwind(AssertUnwindSafe(|| { + assert_eq!( + body, + serde_json::to_string(&assertion.response.0.body).unwrap(), + "File: {} {}", + spec.name, + spec.path.display() + ); + })); + + match result { + Ok(_) => { + panic!( + "Expected spec: {} {} to fail but it passed", + spec.name, + spec.path.display() + ); + } + Err(_) => { + log::info!( + "{} {} ... 
failed (expected)",
+                        spec.name,
+                        spec.path.display()
+                    );
+                }
+            }
+        } else {
+            let response = run(spec.clone(), &assertion)
+                .await
+                .context(spec.path.to_str().unwrap().to_string())
+                .unwrap();
+            let actual_status = response.status().clone().as_u16();
+            let actual_headers = response.headers().clone();
+            let actual_body = hyper::body::to_bytes(response.into_body()).await.unwrap();
+
+            // Assert Status
+            assert_eq!(
+                actual_status,
+                assertion.response.0.status,
+                "File: {} {}",
+                spec.name,
+                spec.path.display()
+            );
+
+            // Assert Body
+            assert_eq!(
+                to_json_pretty(actual_body).unwrap(),
+                serde_json::to_string_pretty(&assertion.response.0.body).unwrap(),
+                "File: {} {}",
+                spec.name,
+                spec.path.display()
+            );
+
+            // Assert Headers
+            for (key, value) in assertion.response.0.headers.iter() {
+                match actual_headers.get(key) {
+                    None => panic!("Expected header {} to be present", key),
+                    Some(actual_value) => assert_eq!(
+                        actual_value,
+                        value,
+                        "File: {} {}",
+                        spec.name,
+                        spec.path.display()
+                    ),
+                }
+            }
+        }
     }
-  }
-  log::info!("{} {} ... ok", spec.name, spec.path.display());
+    log::info!("{} {} ... ok", spec.name, spec.path.display());
 }

 fn to_json_pretty(bytes: Bytes) -> anyhow::Result<String> {
-  let body_str = String::from_utf8(bytes.to_vec())?;
-  let json: Value = serde_json::from_str(&body_str)?;
-  Ok(serde_json::to_string_pretty(&json)?)
+    let body_str = String::from_utf8(bytes.to_vec())?;
+    let json: Value = serde_json::from_str(&body_str)?;
+    Ok(serde_json::to_string_pretty(&json)?)
 }

 #[tokio::test]
 async fn http_spec_e2e() -> anyhow::Result<()> {
-  let spec = HttpSpec::cargo_read("tests/http")?;
-  let spec = HttpSpec::filter_specs(spec);
-  let tasks: Vec<_> = spec.into_iter().map(assert_downstream).collect();
-  join_all(tasks).await;
-  Ok(())
+    let spec = HttpSpec::cargo_read("tests/http")?;
+    let spec = HttpSpec::filter_specs(spec);
+    let tasks: Vec<_> = spec.into_iter().map(assert_downstream).collect();
+    join_all(tasks).await;
+    Ok(())
 }

-async fn run(spec: HttpSpec, downstream_assertion: &&DownstreamAssertion) -> anyhow::Result<Response<Body>> {
-  let query_string = serde_json::to_string(&downstream_assertion.request.0.body).expect("body is required");
-  let method = downstream_assertion.request.0.method.clone();
-  let headers = downstream_assertion.request.0.headers.clone();
-  let url = downstream_assertion.request.0.url.clone();
-  let server_context = spec.server_context().await;
-  let req = headers
-    .into_iter()
-    .fold(
-      Request::builder().method(method.to_hyper()).uri(url.as_str()),
-      |acc, (key, value)| acc.header(key, value),
-    )
-    .body(Body::from(query_string))?;
-
-  // TODO: reuse logic from server.rs to select the correct handler
-  if server_context.blueprint.server.enable_batch_requests {
-    handle_request::<GraphQLBatchRequest>(req, server_context).await
-  } else {
-    handle_request::<GraphQLRequest>(req, server_context).await
-  }
+async fn run(
+    spec: HttpSpec,
+    downstream_assertion: &&DownstreamAssertion,
+) -> anyhow::Result<Response<Body>> {
+    let query_string =
+        serde_json::to_string(&downstream_assertion.request.0.body).expect("body is required");
+    let method = downstream_assertion.request.0.method.clone();
+    let headers = downstream_assertion.request.0.headers.clone();
+    let url = downstream_assertion.request.0.url.clone();
+    let server_context = spec.server_context().await;
+    let req = headers
+        .into_iter()
+        .fold(
+            Request::builder()
+                .method(method.to_hyper())
+                .uri(url.as_str()),
+            |acc, (key, value)| acc.header(key, value),
+        )
+        .body(Body::from(query_string))?;
+
+    // TODO: reuse logic from server.rs to select the correct handler
+    if server_context.blueprint.server.enable_batch_requests {
+        handle_request::<GraphQLBatchRequest>(req, server_context).await
+    } else {
+        handle_request::<GraphQLRequest>(req, server_context).await
+    }
 }
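A note on the mock lookup in `execute` above: a mock matches when method and full URL agree, and the JSON body must also agree unless the request targets the gRPC port, where bodies are ignored. A condensed sketch of that rule over the `APIRequest` shape defined earlier; the standalone function and its name are hypothetical:

// Hypothetical distillation of the matching predicate used in `execute`,
// comparing the plain `APIRequest` fields instead of a live reqwest::Request.
fn matches(req: &APIRequest, mock: &APIRequest, is_grpc: bool) -> bool {
    let method_match = req.method == mock.method;
    let url_match = req.url == mock.url;
    let body_match = req.body == mock.body;
    method_match && url_match && (body_match || is_grpc)
}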
diff --git a/tests/server_spec.rs b/tests/server_spec.rs
index 329178877df..b56c93ed6a5 100644
--- a/tests/server_spec.rs
+++ b/tests/server_spec.rs
@@ -6,94 +6,103 @@ use tailcall::config::reader::ConfigReader;
 use tailcall::config::Upstream;

 async fn test_server(configs: &[&str], url: &str) {
-  let http_client = init_http(&Upstream::default());
-  let reader = ConfigReader::init(init_file(), http_client);
-  let config = reader.read(configs).await.unwrap();
-  let mut server = Server::new(config);
-  let server_up_receiver = server.server_up_receiver();
+    let http_client = init_http(&Upstream::default());
+    let reader = ConfigReader::init(init_file(), http_client);
+    let config = reader.read(configs).await.unwrap();
+    let mut server = Server::new(config);
+    let server_up_receiver = server.server_up_receiver();

-  tokio::spawn(async move {
-    server.start().await.unwrap();
-  });
+    tokio::spawn(async move {
+        server.start().await.unwrap();
+    });

-  server_up_receiver.await.expect("Server did not start up correctly");
+    server_up_receiver
+        .await
+        .expect("Server did not start up correctly");

-  // required since our cert is self signed
-  let client = Client::builder().danger_accept_invalid_certs(true).build().unwrap();
-  let query = json!({
-    "query": "{ greet }"
-  });
+    // required since our cert is self-signed
+    let client = Client::builder()
+        .danger_accept_invalid_certs(true)
+        .build()
+        .unwrap();
+    let query = json!({
+        "query": "{ greet }"
+    });

-  let mut tasks = vec![];
-  for _ in 0..100 {
-    let client = client.clone();
-    let url = url.to_owned();
-    let query = query.clone();
+    let mut tasks = vec![];
+    for _ in 0..100 {
+        let client = client.clone();
+        let url = url.to_owned();
+        let query = query.clone();

-    let task: tokio::task::JoinHandle<Result<serde_json::Value, anyhow::Error>> = tokio::spawn(async move {
-      let response = client.post(url).json(&query).send().await?;
-      let response_body: serde_json::Value = response.json().await?;
-      Ok(response_body)
-    });
-    tasks.push(task);
-  }
+        let task: tokio::task::JoinHandle<Result<serde_json::Value, anyhow::Error>> =
+            tokio::spawn(async move {
+                let response = client.post(url).json(&query).send().await?;
+                let response_body: serde_json::Value = response.json().await?;
+                Ok(response_body)
+            });
+        tasks.push(task);
+    }

-  for task in tasks {
-    let response_body = task
-      .await
-      .expect("Spawned task should success")
-      .expect("Request should success");
-    let expected_response = json!({
-      "data": {
-        "greet": "Hello World!"
-      }
-    });
-    assert_eq!(response_body, expected_response, "Unexpected response from server");
-  }
+    for task in tasks {
+        let response_body = task
+            .await
+            .expect("Spawned task should succeed")
+            .expect("Request should succeed");
+        let expected_response = json!({
+            "data": {
+                "greet": "Hello World!"
+ } + }); + assert_eq!( + response_body, expected_response, + "Unexpected response from server" + ); + } } #[tokio::test] async fn server_start() { - test_server( - &["tests/server/config/server-start.graphql"], - "http://localhost:8800/graphql", - ) - .await + test_server( + &["tests/server/config/server-start.graphql"], + "http://localhost:8800/graphql", + ) + .await } #[tokio::test] async fn server_start_http2_pcks8() { - test_server( - &["tests/server/config/server-start-http2-pkcs8.graphql"], - "https://localhost:8801/graphql", - ) - .await + test_server( + &["tests/server/config/server-start-http2-pkcs8.graphql"], + "https://localhost:8801/graphql", + ) + .await } #[tokio::test] async fn server_start_http2_rsa() { - test_server( - &["tests/server/config/server-start-http2-rsa.graphql"], - "https://localhost:8802/graphql", - ) - .await + test_server( + &["tests/server/config/server-start-http2-rsa.graphql"], + "https://localhost:8802/graphql", + ) + .await } #[tokio::test] async fn server_start_http2_nokey() { - let configs = &["tests/server/config/server-start-http2-nokey.graphql"]; - let http_client = init_http(&Upstream::default()); - let reader = ConfigReader::init(init_file(), http_client); - let config = reader.read(configs).await.unwrap(); - let server = Server::new(config); - assert!(server.start().await.is_err()) + let configs = &["tests/server/config/server-start-http2-nokey.graphql"]; + let http_client = init_http(&Upstream::default()); + let reader = ConfigReader::init(init_file(), http_client); + let config = reader.read(configs).await.unwrap(); + let server = Server::new(config); + assert!(server.start().await.is_err()) } #[tokio::test] async fn server_start_http2_ec() { - test_server( - &["tests/server/config/server-start-http2-ec.graphql"], - "https://localhost:8804/graphql", - ) - .await + test_server( + &["tests/server/config/server-start-http2-ec.graphql"], + "https://localhost:8804/graphql", + ) + .await }
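A closing note on the runner semantics shared by these test files: `filter_specs` gives `Only` precedence over everything else. If even one spec is marked `Only`, all other specs (including `Fail`) are dropped for that run, while `Skip` specs are merely logged and excluded. A minimal standalone sketch of that selection rule; `Spec` and `select` are illustrative stand-ins, not part of this change:

// Illustrative reduction of `filter_specs`: `Skip` drops out, and if any
// spec is marked `Only`, those specs alone are kept; otherwise everything
// that survived the `Skip` filter runs (including `Fail` specs).
#[derive(Clone, Debug)]
enum Annotation { Skip, Only, Fail } // mirrors the enum in http_spec.rs

struct Spec { runner: Option<Annotation> } // stand-in for HttpSpec

fn select(specs: Vec<Spec>) -> Vec<Spec> {
    let (only, rest): (Vec<_>, Vec<_>) = specs
        .into_iter()
        .filter(|s| !matches!(s.runner, Some(Annotation::Skip)))
        .partition(|s| matches!(s.runner, Some(Annotation::Only)));
    if only.is_empty() { rest } else { only }
}

The same precedence explains why a `Fail` spec still executes in a normal run (with its assertions inverted) but disappears entirely as soon as any sibling spec is pinned with `Only`.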