From 652bc31451ed094d12eb7fcda0f349f59a652404 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Wed, 18 Dec 2024 11:41:44 +0000 Subject: [PATCH 1/3] refactor: drop async_graphql engine from executing queries --- benches/handle_request_bench.rs | 7 ++-- generated/.tailcallrc.schema.json | 6 ---- src/core/blueprint/server.rs | 2 -- src/core/config/directives/server.rs | 11 ++---- src/core/http/request_handler.rs | 35 ++++++------------- src/core/jit/builder.rs | 20 +++++------ src/core/jit/fixtures/jp.rs | 6 ++-- src/core/jit/graphql_executor.rs | 3 ++ src/core/jit/request.rs | 2 +- src/core/jit/synth/synth.rs | 2 +- tests/core/parse.rs | 6 +--- .../snapshots/test-enable-jit.md_merged.snap | 2 +- .../test-required-fields.md_merged.snap | 2 +- 13 files changed, 35 insertions(+), 69 deletions(-) diff --git a/benches/handle_request_bench.rs b/benches/handle_request_bench.rs index f33c52a895..88035b056e 100644 --- a/benches/handle_request_bench.rs +++ b/benches/handle_request_bench.rs @@ -16,13 +16,11 @@ pub fn benchmark_handle_request(c: &mut Criterion) { let sdl = std::fs::read_to_string("./ci-benchmark/benchmark.graphql").unwrap(); let config_module: ConfigModule = Config::from_sdl(sdl.as_str()).to_result().unwrap().into(); - let mut blueprint = Blueprint::try_from(&config_module).unwrap(); - let mut blueprint_clone = blueprint.clone(); + let blueprint = Blueprint::try_from(&config_module).unwrap(); let endpoints = config_module.extensions().endpoint_set.clone(); let endpoints_clone = endpoints.clone(); - blueprint.server.enable_jit = false; let server_config = tokio_runtime .block_on(ServerConfig::new(blueprint.clone(), endpoints.clone())) .unwrap(); @@ -47,9 +45,8 @@ pub fn benchmark_handle_request(c: &mut Criterion) { }) }); - blueprint_clone.server.enable_jit = true; let server_config = tokio_runtime - .block_on(ServerConfig::new(blueprint_clone, endpoints_clone)) + .block_on(ServerConfig::new(blueprint, endpoints_clone)) .unwrap(); let server_config = Arc::new(server_config); diff --git a/generated/.tailcallrc.schema.json b/generated/.tailcallrc.schema.json index 55fbf80df1..70ca3e9258 100644 --- a/generated/.tailcallrc.schema.json +++ b/generated/.tailcallrc.schema.json @@ -476,12 +476,6 @@ "null" ] }, - "enableJIT": { - "type": [ - "boolean", - "null" - ] - }, "globalResponseTimeout": { "description": "`globalResponseTimeout` sets the maximum query duration before termination, acting as a safeguard against long-running queries.", "type": [ diff --git a/src/core/blueprint/server.rs b/src/core/blueprint/server.rs index 75d5b3e186..e856154dbc 100644 --- a/src/core/blueprint/server.rs +++ b/src/core/blueprint/server.rs @@ -14,7 +14,6 @@ use crate::core::config::{self, ConfigModule, HttpVersion, PrivateKey, Routes}; #[derive(Clone, Debug, Setters)] pub struct Server { - pub enable_jit: bool, pub enable_apollo_tracing: bool, pub enable_cache_control_header: bool, pub enable_set_cookie_header: bool, @@ -124,7 +123,6 @@ impl TryFrom for Server { )) .map( |(hostname, http, response_headers, script, experimental_headers, cors)| Server { - enable_jit: (config_server).enable_jit(), enable_apollo_tracing: (config_server).enable_apollo_tracing(), enable_cache_control_header: (config_server).enable_cache_control(), enable_set_cookie_header: (config_server).enable_set_cookies(), diff --git a/src/core/config/directives/server.rs b/src/core/config/directives/server.rs index ec006e2023..b9e1897769 100644 --- a/src/core/config/directives/server.rs +++ b/src/core/config/directives/server.rs @@ -29,10 +29,9 @@ use crate::core::macros::MergeRight; /// comprehensive set of server configurations. It dictates how the server /// behaves and helps tune tailcall for various use-cases. pub struct Server { - // The `enableJIT` option activates Just-In-Time (JIT) compilation. When set to true, it - // optimizes execution of each incoming request independently, resulting in significantly - // better performance in most cases, it's enabled by default. - #[serde(default, skip_serializing_if = "is_default", rename = "enableJIT")] + #[deprecated(note = "No longer used, TODO: drop it")] + #[serde(default, skip_serializing, rename = "enableJIT")] + #[schemars(skip)] pub enable_jit: Option, #[serde(default, skip_serializing_if = "is_default")] @@ -262,10 +261,6 @@ impl Server { self.pipeline_flush.unwrap_or(true) } - pub fn enable_jit(&self) -> bool { - self.enable_jit.unwrap_or(true) - } - pub fn get_routes(&self) -> Routes { self.routes.clone().unwrap_or_default() } diff --git a/src/core/http/request_handler.rs b/src/core/http/request_handler.rs index 712929ee0d..e7ab5efae2 100644 --- a/src/core/http/request_handler.rs +++ b/src/core/http/request_handler.rs @@ -119,30 +119,17 @@ async fn execute_query( request: T, req: Parts, ) -> anyhow::Result> { - let mut response = if app_ctx.blueprint.server.enable_jit { - let operation_id = request.operation_id(&req.headers); - let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id); - request - .execute_with_jit(exec) - .await - .set_cache_control( - app_ctx.blueprint.server.enable_cache_control_header, - req_ctx.get_min_max_age().unwrap_or(0), - req_ctx.is_cache_public().unwrap_or(true), - ) - .into_response()? - } else { - request - .data(req_ctx.clone()) - .execute(&app_ctx.schema) - .await - .set_cache_control( - app_ctx.blueprint.server.enable_cache_control_header, - req_ctx.get_min_max_age().unwrap_or(0), - req_ctx.is_cache_public().unwrap_or(true), - ) - .into_response()? - }; + let operation_id = request.operation_id(&req.headers); + let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id); + let mut response = request + .execute_with_jit(exec) + .await + .set_cache_control( + app_ctx.blueprint.server.enable_cache_control_header, + req_ctx.get_min_max_age().unwrap_or(0), + req_ctx.is_cache_public().unwrap_or(true), + ) + .into_response()?; update_response_headers(&mut response, req_ctx, app_ctx); Ok(response) diff --git a/src/core/jit/builder.rs b/src/core/jit/builder.rs index 26ab6d1f66..04f99c45cb 100644 --- a/src/core/jit/builder.rs +++ b/src/core/jit/builder.rs @@ -64,16 +64,16 @@ impl Conditions { } } -pub struct Builder { +pub struct Builder<'a> { pub index: Arc, pub arg_id: Counter, pub field_id: Counter, - pub document: ExecutableDocument, + pub document: &'a ExecutableDocument, } // TODO: make generic over Value (Input) type -impl Builder { - pub fn new(blueprint: &Blueprint, document: ExecutableDocument) -> Self { +impl<'a> Builder<'a> { + pub fn new(blueprint: &Blueprint, document: &'a ExecutableDocument) -> Self { let index = Arc::new(blueprint.index()); Self { document, @@ -372,7 +372,7 @@ mod tests { let config = Config::from_sdl(CONFIG).to_result().unwrap(); let blueprint = Blueprint::try_from(&config.into()).unwrap(); let document = async_graphql::parser::parse_query(query).unwrap(); - Builder::new(&blueprint, document).build(None).unwrap() + Builder::new(&blueprint, &document).build(None).unwrap() } #[tokio::test] @@ -640,25 +640,23 @@ mod tests { let config = Config::from_sdl(CONFIG).to_result().unwrap(); let blueprint = Blueprint::try_from(&config.into()).unwrap(); let document = async_graphql::parser::parse_query(query).unwrap(); - let error = Builder::new(&blueprint, document.clone()) - .build(None) - .unwrap_err(); + let error = Builder::new(&blueprint, &document).build(None).unwrap_err(); assert_eq!(error, BuildError::OperationNameRequired); - let error = Builder::new(&blueprint, document.clone()) + let error = Builder::new(&blueprint, &document) .build(Some("unknown")) .unwrap_err(); assert_eq!(error, BuildError::OperationNotFound("unknown".to_string())); - let plan = Builder::new(&blueprint, document.clone()) + let plan = Builder::new(&blueprint, &document) .build(Some("GetPosts")) .unwrap(); assert!(plan.is_query()); insta::assert_debug_snapshot!(plan.selection); - let plan = Builder::new(&blueprint, document.clone()) + let plan = Builder::new(&blueprint, &document) .build(Some("CreateNewPost")) .unwrap(); assert!(!plan.is_query()); diff --git a/src/core/jit/fixtures/jp.rs b/src/core/jit/fixtures/jp.rs index 2d0c99b45c..f743fdc281 100644 --- a/src/core/jit/fixtures/jp.rs +++ b/src/core/jit/fixtures/jp.rs @@ -88,10 +88,8 @@ impl<'a, Value: Deserialize<'a> + Clone + 'a + JsonLike<'a> + std::fmt::Debug> J fn plan(query: &str, variables: &Variables) -> OperationPlan { let config = ConfigModule::from(Config::from_sdl(Self::CONFIG).to_result().unwrap()); - let builder = Builder::new( - &Blueprint::try_from(&config).unwrap(), - async_graphql::parser::parse_query(query).unwrap(), - ); + let doc = async_graphql::parser::parse_query(query).unwrap(); + let builder = Builder::new(&Blueprint::try_from(&config).unwrap(), &doc); let plan = builder.build(None).unwrap(); let plan = transform::Skip::new(variables) diff --git a/src/core/jit/graphql_executor.rs b/src/core/jit/graphql_executor.rs index 31afb8af0f..e1c131c374 100644 --- a/src/core/jit/graphql_executor.rs +++ b/src/core/jit/graphql_executor.rs @@ -75,6 +75,8 @@ impl JITExecutor { &self, request: async_graphql::Request, ) -> impl Future>> + Send + '_ { + // TODO: hash considering only the query itself ignoring specified operation and + // variables that could differ for the same query let hash = Self::req_hash(&request); async move { @@ -135,6 +137,7 @@ impl JITExecutor { } } +// TODO: used only for introspection, simplify somehow? impl From> for async_graphql::Request { fn from(value: jit::Request) -> Self { let mut request = async_graphql::Request::new(value.query); diff --git a/src/core/jit/request.rs b/src/core/jit/request.rs index cbd3bc0faf..f37c4a721e 100644 --- a/src/core/jit/request.rs +++ b/src/core/jit/request.rs @@ -42,7 +42,7 @@ impl Request { blueprint: &Blueprint, ) -> Result> { let doc = async_graphql::parser::parse_query(&self.query)?; - let builder = Builder::new(blueprint, doc); + let builder = Builder::new(blueprint, &doc); let plan = builder.build(self.operation_name.as_deref())?; transform::CheckConst::new() diff --git a/src/core/jit/synth/synth.rs b/src/core/jit/synth/synth.rs index e24b6c50a0..221fa95da2 100644 --- a/src/core/jit/synth/synth.rs +++ b/src/core/jit/synth/synth.rs @@ -345,7 +345,7 @@ mod tests { let config = Config::from_sdl(CONFIG).to_result().unwrap(); let config = ConfigModule::from(config); - let builder = Builder::new(&Blueprint::try_from(&config).unwrap(), doc); + let builder = Builder::new(&Blueprint::try_from(&config).unwrap(), &doc); let plan = builder.build(None).unwrap(); let plan = plan .try_map(|v| { diff --git a/tests/core/parse.rs b/tests/core/parse.rs index a438a31e89..6246c40008 100644 --- a/tests/core/parse.rs +++ b/tests/core/parse.rs @@ -273,11 +273,7 @@ impl ExecutionSpec { env: HashMap, http: Arc, ) -> Arc { - let mut blueprint = Blueprint::try_from(config).unwrap(); - - if cfg!(feature = "force_jit") { - blueprint.server.enable_jit = true; - } + let blueprint = Blueprint::try_from(config).unwrap(); let script = blueprint.server.script.clone(); diff --git a/tests/core/snapshots/test-enable-jit.md_merged.snap b/tests/core/snapshots/test-enable-jit.md_merged.snap index 3d61f7941a..99a8a80fcb 100644 --- a/tests/core/snapshots/test-enable-jit.md_merged.snap +++ b/tests/core/snapshots/test-enable-jit.md_merged.snap @@ -3,7 +3,7 @@ source: tests/core/spec.rs expression: formatter snapshot_kind: text --- -schema @server(enableJIT: true, hostname: "0.0.0.0", port: 8000) @upstream { +schema @server(hostname: "0.0.0.0", port: 8000) @upstream { query: Query } diff --git a/tests/core/snapshots/test-required-fields.md_merged.snap b/tests/core/snapshots/test-required-fields.md_merged.snap index 32c7a45a63..5cb6c7f4f3 100644 --- a/tests/core/snapshots/test-required-fields.md_merged.snap +++ b/tests/core/snapshots/test-required-fields.md_merged.snap @@ -3,7 +3,7 @@ source: tests/core/spec.rs expression: formatter snapshot_kind: text --- -schema @server(enableJIT: true) @upstream { +schema @server @upstream { query: Query } From a18158195c0050e23b6a05571b907e6270e265a6 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:29:37 +0000 Subject: [PATCH 2/3] refactor: use jit executor to execute REST api --- src/core/async_graphql_hyper.rs | 11 ++++++++ src/core/http/request_handler.rs | 12 ++++++--- src/core/jit/error.rs | 6 +++++ src/core/jit/exec_const.rs | 45 +++++++++++++++++++++++++++++--- src/core/jit/graphql_error.rs | 4 +++ src/core/jit/graphql_executor.rs | 37 ++++++++++++++++---------- src/core/jit/model.rs | 2 +- src/core/jit/request.rs | 34 +++++++++++++++--------- src/core/rest/partial_request.rs | 7 ++--- 9 files changed, 119 insertions(+), 39 deletions(-) diff --git a/src/core/async_graphql_hyper.rs b/src/core/async_graphql_hyper.rs index 5518ea5da7..dd7be32c34 100644 --- a/src/core/async_graphql_hyper.rs +++ b/src/core/async_graphql_hyper.rs @@ -408,6 +408,17 @@ impl GraphQLArcResponse { pub fn into_response(self) -> Result> { self.build_response(StatusCode::OK, self.default_body()?) } + + /// Transforms a plain `GraphQLResponse` into a `Response`. + /// Differs as `to_response` by flattening the response's data + /// `{"data": {"user": {"name": "John"}}}` becomes `{"name": "John"}`. + pub fn into_rest_response(self) -> Result> { + if !self.response.is_ok() { + return self.build_response(StatusCode::INTERNAL_SERVER_ERROR, self.default_body()?); + } + + self.into_response() + } } #[cfg(test)] diff --git a/src/core/http/request_handler.rs b/src/core/http/request_handler.rs index e7ab5efae2..17a25233b4 100644 --- a/src/core/http/request_handler.rs +++ b/src/core/http/request_handler.rs @@ -241,20 +241,24 @@ async fn handle_rest_apis( *request.uri_mut() = request.uri().path().replace(API_URL_PREFIX, "").parse()?; let req_ctx = Arc::new(create_request_context(&request, app_ctx.as_ref())); if let Some(p_request) = app_ctx.endpoints.matches(&request) { + let (req, body) = request.into_parts(); let http_route = format!("{API_URL_PREFIX}{}", p_request.path.as_str()); req_counter.set_http_route(&http_route); let span = tracing::info_span!( "REST", - otel.name = format!("REST {} {}", request.method(), p_request.path.as_str()), + otel.name = format!("REST {} {}", req.method, p_request.path.as_str()), otel.kind = ?SpanKind::Server, - { HTTP_REQUEST_METHOD } = %request.method(), + { HTTP_REQUEST_METHOD } = %req.method, { HTTP_ROUTE } = http_route ); return async { - let graphql_request = p_request.into_request(request).await?; + let mut graphql_request = p_request.into_request(body).await?; + let operation_id = graphql_request.operation_id(&req.headers); + let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id) + .flatten_response(true); let mut response = graphql_request .data(req_ctx.clone()) - .execute(&app_ctx.schema) + .execute_with_jit(exec) .await .set_cache_control( app_ctx.blueprint.server.enable_cache_control_header, diff --git a/src/core/jit/error.rs b/src/core/jit/error.rs index 086b604502..25a6c388ee 100644 --- a/src/core/jit/error.rs +++ b/src/core/jit/error.rs @@ -56,6 +56,12 @@ pub enum Error { Unknown, } +impl From for Error { + fn from(value: async_graphql::ServerError) -> Self { + Self::ServerError(value) + } +} + impl ErrorExtensions for Error { fn extend(&self) -> super::graphql_error::Error { match self { diff --git a/src/core/jit/exec_const.rs b/src/core/jit/exec_const.rs index 1606f05631..b16b7fde5f 100644 --- a/src/core/jit/exec_const.rs +++ b/src/core/jit/exec_const.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_graphql_value::{ConstValue, Value}; +use derive_setters::Setters; use futures_util::future::join_all; use tailcall_valid::Validator; @@ -8,7 +9,6 @@ use super::context::Context; use super::exec::{Executor, IRExecutor}; use super::graphql_error::GraphQLError; use super::{transform, AnyResponse, BuildError, Error, OperationPlan, Request, Response, Result}; -use crate::core::app_context::AppContext; use crate::core::http::RequestContext; use crate::core::ir::model::IR; use crate::core::ir::{self, EmptyResolverContext, EvalContext}; @@ -16,15 +16,19 @@ use crate::core::jit::synth::Synth; use crate::core::jit::transform::InputResolver; use crate::core::json::{JsonLike, JsonLikeList}; use crate::core::Transform; +use crate::core::{app_context::AppContext, json::JsonObjectLike}; /// A specialized executor that executes with async_graphql::Value +#[derive(Setters)] pub struct ConstValueExecutor { pub plan: OperationPlan, + + flatten_response: bool, } impl From> for ConstValueExecutor { fn from(plan: OperationPlan) -> Self { - Self { plan } + Self { plan, flatten_response: false } } } @@ -56,6 +60,7 @@ impl ConstValueExecutor { let is_introspection_query = req_ctx.server.get_enable_introspection() && self.plan.is_introspection_query; + let flatten_response = self.flatten_response; let variables = &request.variables; // Attempt to skip unnecessary fields @@ -102,13 +107,45 @@ impl ConstValueExecutor { let async_req = async_graphql::Request::from(request).only_introspection(); let async_resp = app_ctx.execute(async_req).await; - resp.merge_with(&async_resp).into() + to_any_response(resp.merge_with(&async_resp), flatten_response) } else { - resp.into() + to_any_response(resp, flatten_response) } } } +fn to_any_response( + resp: Response, + flatten: bool, +) -> AnyResponse> { + if flatten { + if resp.errors.is_empty() { + AnyResponse { + body: Arc::new( + serde_json::to_vec(flatten_response(&resp.data)).unwrap_or_default(), + ), + is_ok: true, + cache_control: resp.cache_control, + } + } else { + AnyResponse { + body: Arc::new(serde_json::to_vec(&resp).unwrap_or_default()), + is_ok: false, + cache_control: resp.cache_control, + } + } + } else { + resp.into() + } +} + +fn flatten_response<'a, T: JsonLike<'a>>(data: &'a T) -> &'a T { + match data.as_object() { + Some(obj) if obj.len() == 1 => flatten_response(obj.iter().next().unwrap().1), + _ => data, + } +} + struct ConstValueExec<'a> { plan: &'a OperationPlan, req_context: &'a RequestContext, diff --git a/src/core/jit/graphql_error.rs b/src/core/jit/graphql_error.rs index 6d2e0a7132..ade6b4af61 100644 --- a/src/core/jit/graphql_error.rs +++ b/src/core/jit/graphql_error.rs @@ -53,6 +53,10 @@ impl From> for GraphQLError { return e.into(); } + if let super::Error::ServerError(e) = inner_value { + return e.into(); + } + let ext = inner_value.extend().extensions; let mut server_error = GraphQLError::new(inner_value.to_string(), Some(position)); server_error.extensions = ext; diff --git a/src/core/jit/graphql_executor.rs b/src/core/jit/graphql_executor.rs index e1c131c374..bf2c7546be 100644 --- a/src/core/jit/graphql_executor.rs +++ b/src/core/jit/graphql_executor.rs @@ -5,21 +5,23 @@ use std::sync::Arc; use async_graphql::{BatchRequest, Value}; use async_graphql_value::{ConstValue, Extensions}; +use derive_setters::Setters; use futures_util::stream::FuturesOrdered; use futures_util::StreamExt; use tailcall_hasher::TailcallHasher; use super::{AnyResponse, BatchResponse, Response}; -use crate::core::app_context::AppContext; +use crate::core::{app_context::AppContext, async_graphql_hyper::GraphQLRequest}; use crate::core::async_graphql_hyper::OperationId; use crate::core::http::RequestContext; use crate::core::jit::{self, ConstValueExecutor, OPHash, Pos, Positioned}; -#[derive(Clone)] +#[derive(Clone, Setters)] pub struct JITExecutor { app_ctx: Arc, req_ctx: Arc, operation_id: OperationId, + flatten_response: bool, } impl JITExecutor { @@ -28,7 +30,7 @@ impl JITExecutor { req_ctx: Arc, operation_id: OperationId, ) -> Self { - Self { app_ctx, req_ctx, operation_id } + Self { app_ctx, req_ctx, operation_id, flatten_response: false } } #[inline(always)] @@ -62,21 +64,20 @@ impl JITExecutor { } #[inline(always)] - fn req_hash(request: &async_graphql::Request) -> OPHash { + fn req_hash(request: &impl Hash) -> OPHash { let mut hasher = TailcallHasher::default(); - request.query.hash(&mut hasher); + request.hash(&mut hasher); OPHash::new(hasher.finish()) } } impl JITExecutor { - pub fn execute( - &self, - request: async_graphql::Request, - ) -> impl Future>> + Send + '_ { - // TODO: hash considering only the query itself ignoring specified operation and - // variables that could differ for the same query + pub fn execute(&self, request: T) -> impl Future>> + Send + '_ + where + jit::Request: TryFrom, + T: Hash + Send + 'static, + { let hash = Self::req_hash(&request); async move { @@ -84,7 +85,14 @@ impl JITExecutor { return response.clone(); } - let jit_request = jit::Request::from(request); + let jit_request = match jit::Request::try_from(request) { + Ok(request) => request, + Err(error) => { + return Response::::default() + .with_errors(vec![Positioned::new(error, Pos::default())]) + .into() + } + }; let exec = if let Some(op) = self.app_ctx.operation_plans.get(&hash) { ConstValueExecutor::from(op.value().clone()) } else { @@ -102,6 +110,7 @@ impl JITExecutor { exec }; + let exec = exec.flatten_response(self.flatten_response); let is_const = exec.plan.is_const; let is_protected = exec.plan.is_protected; @@ -125,10 +134,10 @@ impl JITExecutor { /// Execute a GraphQL batch query. pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse> { match batch_request { - BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await), + BatchRequest::Single(request) => BatchResponse::Single(self.execute(GraphQLRequest(request)).await), BatchRequest::Batch(requests) => { let futs = FuturesOrdered::from_iter( - requests.into_iter().map(|request| self.execute(request)), + requests.into_iter().map(|request| self.execute(GraphQLRequest(request))), ); let responses = futs.collect::>().await; BatchResponse::Batch(responses) diff --git a/src/core/jit/model.rs b/src/core/jit/model.rs index 9b2950a22a..d69e9d571f 100644 --- a/src/core/jit/model.rs +++ b/src/core/jit/model.rs @@ -590,7 +590,7 @@ mod test { let bp = Blueprint::try_from(&module).unwrap(); let request = Request::new(query); - let jit_request = jit::Request::from(request); + let jit_request = jit::Request::try_from(request).unwrap(); jit_request.create_plan(&bp).unwrap() } diff --git a/src/core/jit/request.rs b/src/core/jit/request.rs index f37c4a721e..4f0a81da42 100644 --- a/src/core/jit/request.rs +++ b/src/core/jit/request.rs @@ -1,38 +1,46 @@ use std::collections::HashMap; use std::ops::DerefMut; +use async_graphql::parser::types::ExecutableDocument; use async_graphql_value::ConstValue; -use serde::Deserialize; use tailcall_valid::Validator; use super::{transform, Builder, OperationPlan, Result, Variables}; -use crate::core::blueprint::Blueprint; use crate::core::transform::TransformerOps; use crate::core::Transform; +use crate::core::{async_graphql_hyper::GraphQLRequest, blueprint::Blueprint}; -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Clone)] pub struct Request { - #[serde(default)] pub query: String, - #[serde(default, rename = "operationName")] pub operation_name: Option, - #[serde(default)] pub variables: Variables, - #[serde(default)] pub extensions: HashMap, + pub parsed_query: ExecutableDocument, } // NOTE: This is hot code and should allocate minimal memory -impl From for Request { - fn from(mut value: async_graphql::Request) -> Self { +impl TryFrom for Request { + type Error = super::Error; + + fn try_from(mut value: async_graphql::Request) -> Result { let variables = std::mem::take(value.variables.deref_mut()); - Self { + Ok(Self { + parsed_query: value.parsed_query()?.clone(), query: value.query, operation_name: value.operation_name, variables: Variables::from_iter(variables.into_iter().map(|(k, v)| (k.to_string(), v))), extensions: value.extensions.0, - } + }) + } +} + +impl TryFrom for Request { + type Error = super::Error; + + fn try_from(value: GraphQLRequest) -> Result { + Self::try_from(value.0) } } @@ -41,8 +49,7 @@ impl Request { &self, blueprint: &Blueprint, ) -> Result> { - let doc = async_graphql::parser::parse_query(&self.query)?; - let builder = Builder::new(blueprint, &doc); + let builder = Builder::new(blueprint, &self.parsed_query); let plan = builder.build(self.operation_name.as_deref())?; transform::CheckConst::new() @@ -67,6 +74,7 @@ impl Request { operation_name: None, variables: Variables::new(), extensions: HashMap::new(), + parsed_query: async_graphql::parser::parse_query(query).unwrap(), } } diff --git a/src/core/rest/partial_request.rs b/src/core/rest/partial_request.rs index cfd8053483..aaa2a31658 100644 --- a/src/core/rest/partial_request.rs +++ b/src/core/rest/partial_request.rs @@ -1,9 +1,10 @@ use async_graphql::parser::types::ExecutableDocument; use async_graphql::{Name, Variables}; use async_graphql_value::ConstValue; +use hyper::Body; use super::path::Path; -use super::{Request, Result}; +use super::Result; use crate::core::async_graphql_hyper::GraphQLRequest; /// A partial GraphQLRequest that contains a parsed executable GraphQL document. @@ -16,10 +17,10 @@ pub struct PartialRequest<'a> { } impl PartialRequest<'_> { - pub async fn into_request(self, request: Request) -> Result { + pub async fn into_request(self, body: Body) -> Result { let mut variables = self.variables; if let Some(key) = self.body { - let bytes = hyper::body::to_bytes(request.into_body()).await?; + let bytes = hyper::body::to_bytes(body).await?; let body: ConstValue = serde_json::from_slice(&bytes)?; variables.insert(Name::new(key), body); } From d16e2e073734e62a3283877372c4162bfb322b39 Mon Sep 17 00:00:00 2001 From: meskill <8974488+meskill@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:35:40 +0000 Subject: [PATCH 3/3] remove async_graphql::Request usage --- src/core/async_graphql_hyper.rs | 197 +++++++++++++------------------ src/core/http/request_handler.rs | 3 +- src/core/jit/exec_const.rs | 4 +- src/core/jit/graphql_executor.rs | 129 ++++++++++---------- src/core/jit/model.rs | 14 ++- src/core/jit/request.rs | 32 +++-- src/core/jit/response.rs | 8 +- src/core/rest/endpoint.rs | 5 +- src/core/rest/operation.rs | 10 +- src/core/rest/partial_request.rs | 30 +++-- 10 files changed, 214 insertions(+), 218 deletions(-) diff --git a/src/core/async_graphql_hyper.rs b/src/core/async_graphql_hyper.rs index dd7be32c34..4b9df5e488 100644 --- a/src/core/async_graphql_hyper.rs +++ b/src/core/async_graphql_hyper.rs @@ -1,9 +1,10 @@ -use std::any::Any; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; use anyhow::Result; -use async_graphql::parser::types::{ExecutableDocument, OperationType}; -use async_graphql::{BatchResponse, Executor, Value}; +use async_graphql::parser::types::ExecutableDocument; +use async_graphql::{BatchResponse, Value}; +use async_graphql_value::ConstValue; use http::header::{HeaderMap, HeaderValue, CACHE_CONTROL, CONTENT_TYPE}; use http::{Response, StatusCode}; use hyper::Body; @@ -13,32 +14,17 @@ use tailcall_hasher::TailcallHasher; use super::jit::{BatchResponse as JITBatchResponse, JITExecutor}; +// TODO: replace usage with some other implementation. +// This one is used to calculate hash and use the value later +// as a key in the HashMap. But such use could lead to potential +// issues in case of hash collisions #[derive(PartialEq, Eq, Clone, Hash, Debug)] pub struct OperationId(u64); #[async_trait::async_trait] pub trait GraphQLRequestLike: Hash + Send { - fn data(self, data: D) -> Self; - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor; - async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse; - fn parse_query(&mut self) -> Option<&ExecutableDocument>; - - fn is_query(&mut self) -> bool { - self.parse_query() - .map(|a| { - let mut is_query = false; - for (_, operation) in a.operations.iter() { - is_query = operation.node.ty == OperationType::Query; - } - is_query - }) - .unwrap_or(false) - } - fn operation_id(&self, headers: &HeaderMap) -> OperationId { let mut hasher = TailcallHasher::default(); let state = &mut hasher; @@ -51,86 +37,101 @@ pub trait GraphQLRequestLike: Hash + Send { } } -#[derive(Debug, Deserialize)] -pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest); -impl GraphQLBatchRequest {} -impl Hash for GraphQLBatchRequest { - //TODO: Fix Hash implementation for BatchRequest, which should ideally batch - // execution of individual requests instead of the whole chunk of requests as - // one. - fn hash(&self, state: &mut H) { - for request in self.0.iter() { - request.query.hash(state); - request.operation_name.hash(state); - for (name, value) in request.variables.iter() { - name.hash(state); - value.to_string().hash(state); - } - } - } +#[derive(Debug, Hash, Serialize, Deserialize)] +#[serde(untagged)] +pub enum BatchWrapper { + Single(T), + Batch(Vec), } -#[async_trait::async_trait] -impl GraphQLRequestLike for GraphQLBatchRequest { - fn data(mut self, data: D) -> Self { - for request in self.0.iter_mut() { - request.data.insert(data.clone()); - } - self - } - async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { - GraphQLArcResponse::new(executor.execute_batch(self.0).await) - } - - /// Shortcut method to execute the request on the executor. - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - GraphQLResponse(executor.execute_batch(self.0).await) - } +pub type GraphQLBatchRequest = BatchWrapper; - fn parse_query(&mut self) -> Option<&ExecutableDocument> { - None +#[async_trait::async_trait] +impl GraphQLRequestLike for BatchWrapper { + async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { + GraphQLArcResponse::new(executor.execute_batch(self).await) } } -#[derive(Debug, Deserialize)] -pub struct GraphQLRequest(pub async_graphql::Request); +#[derive(Debug, Default, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct GraphQLRequest { + #[serde(default)] + pub query: String, + #[serde(default)] + pub operation_name: Option, + #[serde(default)] + pub variables: HashMap, + #[serde(default)] + pub extensions: HashMap, +} -impl GraphQLRequest {} impl Hash for GraphQLRequest { fn hash(&self, state: &mut H) { - self.0.query.hash(state); - self.0.operation_name.hash(state); - for (name, value) in self.0.variables.iter() { + self.query.hash(state); + self.operation_name.hash(state); + for (name, value) in self.variables.iter() { name.hash(state); value.to_string().hash(state); } } } + +impl GraphQLRequest { + pub fn new(query: impl Into) -> Self { + Self { query: query.into(), ..Default::default() } + } +} + #[async_trait::async_trait] impl GraphQLRequestLike for GraphQLRequest { - #[must_use] - fn data(mut self, data: D) -> Self { - self.0.data.insert(data); - self - } async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { - let response = executor.execute(self.0).await; + let response = executor.execute(self).await; GraphQLArcResponse::new(JITBatchResponse::Single(response)) } +} + +#[derive(Debug)] +pub struct ParsedGraphQLRequest { + pub query: String, + pub operation_name: Option, + pub variables: HashMap, + pub extensions: HashMap, + pub parsed_query: ExecutableDocument, +} + +impl Hash for ParsedGraphQLRequest { + fn hash(&self, state: &mut H) { + self.query.hash(state); + self.operation_name.hash(state); + for (name, value) in self.variables.iter() { + name.hash(state); + value.to_string().hash(state); + } + } +} + +impl TryFrom for ParsedGraphQLRequest { + type Error = async_graphql::parser::Error; - /// Shortcut method to execute the request on the schema. - async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - GraphQLResponse(executor.execute(self.0).await.into()) + fn try_from(req: GraphQLRequest) -> std::result::Result { + let parsed_query = async_graphql::parser::parse_query(&req.query)?; + + Ok(Self { + query: req.query, + operation_name: req.operation_name, + variables: req.variables, + extensions: req.extensions, + parsed_query, + }) } +} - fn parse_query(&mut self) -> Option<&ExecutableDocument> { - self.0.parsed_query().ok() +#[async_trait::async_trait] +impl GraphQLRequestLike for ParsedGraphQLRequest { + async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse { + let response = executor.execute(self).await; + GraphQLArcResponse::new(JITBatchResponse::Single(response)) } } @@ -148,42 +149,6 @@ impl From for GraphQLResponse { } } -impl From for GraphQLRequest { - fn from(query: GraphQLQuery) -> Self { - let mut request = async_graphql::Request::new(query.query); - - if let Some(operation_name) = query.operation_name { - request = request.operation_name(operation_name); - } - - if let Some(variables) = query.variables { - let value = serde_json::from_str(&variables).unwrap_or_default(); - let variables = async_graphql::Variables::from_json(value); - request = request.variables(variables); - } - - GraphQLRequest(request) - } -} - -#[derive(Debug)] -pub struct GraphQLQuery { - query: String, - operation_name: Option, - variables: Option, -} - -impl GraphQLQuery { - /// Shortcut method to execute the request on the schema. - pub async fn execute(self, executor: &E) -> GraphQLResponse - where - E: Executor, - { - let request: GraphQLRequest = self.into(); - request.execute(executor).await - } -} - static APPLICATION_JSON: Lazy = Lazy::new(|| HeaderValue::from_static("application/json")); diff --git a/src/core/http/request_handler.rs b/src/core/http/request_handler.rs index 17a25233b4..aa6886d129 100644 --- a/src/core/http/request_handler.rs +++ b/src/core/http/request_handler.rs @@ -252,12 +252,11 @@ async fn handle_rest_apis( { HTTP_ROUTE } = http_route ); return async { - let mut graphql_request = p_request.into_request(body).await?; + let graphql_request = p_request.into_request(body).await?; let operation_id = graphql_request.operation_id(&req.headers); let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id) .flatten_response(true); let mut response = graphql_request - .data(req_ctx.clone()) .execute_with_jit(exec) .await .set_cache_control( diff --git a/src/core/jit/exec_const.rs b/src/core/jit/exec_const.rs index b16b7fde5f..29990f2f88 100644 --- a/src/core/jit/exec_const.rs +++ b/src/core/jit/exec_const.rs @@ -9,14 +9,14 @@ use super::context::Context; use super::exec::{Executor, IRExecutor}; use super::graphql_error::GraphQLError; use super::{transform, AnyResponse, BuildError, Error, OperationPlan, Request, Response, Result}; +use crate::core::app_context::AppContext; use crate::core::http::RequestContext; use crate::core::ir::model::IR; use crate::core::ir::{self, EmptyResolverContext, EvalContext}; use crate::core::jit::synth::Synth; use crate::core::jit::transform::InputResolver; -use crate::core::json::{JsonLike, JsonLikeList}; +use crate::core::json::{JsonLike, JsonLikeList, JsonObjectLike}; use crate::core::Transform; -use crate::core::{app_context::AppContext, json::JsonObjectLike}; /// A specialized executor that executes with async_graphql::Value #[derive(Setters)] diff --git a/src/core/jit/graphql_executor.rs b/src/core/jit/graphql_executor.rs index bf2c7546be..4450de81a4 100644 --- a/src/core/jit/graphql_executor.rs +++ b/src/core/jit/graphql_executor.rs @@ -1,9 +1,8 @@ use std::collections::BTreeMap; -use std::future::Future; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use async_graphql::{BatchRequest, Value}; +use async_graphql::Value; use async_graphql_value::{ConstValue, Extensions}; use derive_setters::Setters; use futures_util::stream::FuturesOrdered; @@ -11,10 +10,10 @@ use futures_util::StreamExt; use tailcall_hasher::TailcallHasher; use super::{AnyResponse, BatchResponse, Response}; -use crate::core::{app_context::AppContext, async_graphql_hyper::GraphQLRequest}; -use crate::core::async_graphql_hyper::OperationId; +use crate::core::app_context::AppContext; +use crate::core::async_graphql_hyper::{BatchWrapper, GraphQLRequest, OperationId}; use crate::core::http::RequestContext; -use crate::core::jit::{self, ConstValueExecutor, OPHash, Pos, Positioned}; +use crate::core::jit::{self, ConstValueExecutor, OPHash}; #[derive(Clone, Setters)] pub struct JITExecutor { @@ -63,81 +62,87 @@ impl JITExecutor { out.unwrap_or_default() } + /// Calculates hash for the request considering + /// the request is const, i.e. doesn't depend on input. + /// That's basically use only the query itself to calculating the hash #[inline(always)] - fn req_hash(request: &impl Hash) -> OPHash { - let mut hasher = TailcallHasher::default(); - request.hash(&mut hasher); + fn const_execution_hash(request: &jit::Request) -> OPHash { + let hasher = &mut TailcallHasher::default(); + + request.query.hash(hasher); OPHash::new(hasher.finish()) } } impl JITExecutor { - pub fn execute(&self, request: T) -> impl Future>> + Send + '_ + pub async fn execute(&self, request: T) -> AnyResponse> where jit::Request: TryFrom, T: Hash + Send + 'static, { - let hash = Self::req_hash(&request); - - async move { - if let Some(response) = self.app_ctx.const_execution_cache.get(&hash) { - return response.clone(); - } - - let jit_request = match jit::Request::try_from(request) { - Ok(request) => request, - Err(error) => { - return Response::::default() - .with_errors(vec![Positioned::new(error, Pos::default())]) - .into() - } - }; - let exec = if let Some(op) = self.app_ctx.operation_plans.get(&hash) { - ConstValueExecutor::from(op.value().clone()) - } else { - let exec = match ConstValueExecutor::try_new(&jit_request, &self.app_ctx) { - Ok(exec) => exec, - Err(error) => { - return Response::::default() - .with_errors(vec![Positioned::new(error, Pos::default())]) - .into() - } - }; - self.app_ctx - .operation_plans - .insert(hash.clone(), exec.plan.clone()); - exec - }; - - let exec = exec.flatten_response(self.flatten_response); - let is_const = exec.plan.is_const; - let is_protected = exec.plan.is_protected; - - let response = if exec.plan.can_dedupe() { - self.dedupe_and_exec(exec, jit_request).await - } else { - self.exec(exec, jit_request).await + let jit_request = match jit::Request::try_from(request) { + Ok(request) => request, + Err(error) => return Response::::from(error).into(), + }; + + let const_execution_hash = Self::const_execution_hash(&jit_request); + + // check if the request is has been set to const_execution_cache + // and if yes serve the response from the cache since + // the query doesn't depend on input and could be calculated once + // WARN: make sure the value is set to cache only if the plan is actually + // is_const + if let Some(response) = self + .app_ctx + .const_execution_cache + .get(&const_execution_hash) + { + return response.clone(); + } + let exec = if let Some(op) = self.app_ctx.operation_plans.get(&const_execution_hash) { + ConstValueExecutor::from(op.value().clone()) + } else { + let exec = match ConstValueExecutor::try_new(&jit_request, &self.app_ctx) { + Ok(exec) => exec, + Err(error) => return Response::::from(error).into(), }; - - // Cache the response if it's constant and not wrapped with protected. - if is_const && !is_protected { - self.app_ctx - .const_execution_cache - .insert(hash, response.clone()); - } - - response + self.app_ctx + .operation_plans + .insert(const_execution_hash.clone(), exec.plan.clone()); + exec + }; + + let exec = exec.flatten_response(self.flatten_response); + let is_const = exec.plan.is_const; + let is_protected = exec.plan.is_protected; + + let response = if exec.plan.can_dedupe() { + self.dedupe_and_exec(exec, jit_request).await + } else { + self.exec(exec, jit_request).await + }; + + // Cache the response if it's constant and not wrapped with protected. + if is_const && !is_protected { + self.app_ctx + .const_execution_cache + .insert(const_execution_hash, response.clone()); } + + response } /// Execute a GraphQL batch query. - pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse> { + pub async fn execute_batch( + &self, + batch_request: BatchWrapper, + ) -> BatchResponse> { match batch_request { - BatchRequest::Single(request) => BatchResponse::Single(self.execute(GraphQLRequest(request)).await), - BatchRequest::Batch(requests) => { + BatchWrapper::Single(request) => BatchResponse::Single(self.execute(request).await), + BatchWrapper::Batch(requests) => { let futs = FuturesOrdered::from_iter( - requests.into_iter().map(|request| self.execute(GraphQLRequest(request))), + requests.into_iter().map(|request| self.execute(request)), ); let responses = futs.collect::>().await; BatchResponse::Batch(responses) diff --git a/src/core/jit/model.rs b/src/core/jit/model.rs index d69e9d571f..65fe69dff5 100644 --- a/src/core/jit/model.rs +++ b/src/core/jit/model.rs @@ -20,6 +20,12 @@ use crate::core::scalar::Scalar; #[derive(Debug, Deserialize, Clone)] pub struct Variables(HashMap); +impl From> for Variables { + fn from(value: HashMap) -> Self { + Self(value) + } +} + impl PathString for Variables { fn path_string<'a, T: AsRef>(&'a self, path: &'a [T]) -> Option> { self.get(path[0].as_ref()) @@ -293,6 +299,10 @@ impl Debug for Field { } } +// TODO: replace usage with some other implementation. +// This one is used to calculate hash and use the value later +// as a key in the HashMap. But such use could lead to potential +// issues in case of hash collisions #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct OPHash(u64); @@ -575,10 +585,10 @@ impl From for Positioned { #[cfg(test)] mod test { use async_graphql::parser::types::ConstDirective; - use async_graphql::Request; use async_graphql_value::ConstValue; use super::{Directive, OperationPlan}; + use crate::core::async_graphql_hyper::GraphQLRequest; use crate::core::blueprint::Blueprint; use crate::core::config::ConfigModule; use crate::core::jit; @@ -589,7 +599,7 @@ mod test { let module = ConfigModule::from(config); let bp = Blueprint::try_from(&module).unwrap(); - let request = Request::new(query); + let request = GraphQLRequest::new(query); let jit_request = jit::Request::try_from(request).unwrap(); jit_request.create_plan(&bp).unwrap() } diff --git a/src/core/jit/request.rs b/src/core/jit/request.rs index 4f0a81da42..261d98e53c 100644 --- a/src/core/jit/request.rs +++ b/src/core/jit/request.rs @@ -1,14 +1,14 @@ use std::collections::HashMap; -use std::ops::DerefMut; use async_graphql::parser::types::ExecutableDocument; use async_graphql_value::ConstValue; use tailcall_valid::Validator; use super::{transform, Builder, OperationPlan, Result, Variables}; +use crate::core::async_graphql_hyper::{GraphQLRequest, ParsedGraphQLRequest}; +use crate::core::blueprint::Blueprint; use crate::core::transform::TransformerOps; use crate::core::Transform; -use crate::core::{async_graphql_hyper::GraphQLRequest, blueprint::Blueprint}; #[derive(Debug, Clone)] pub struct Request { @@ -19,28 +19,26 @@ pub struct Request { pub parsed_query: ExecutableDocument, } -// NOTE: This is hot code and should allocate minimal memory -impl TryFrom for Request { +impl TryFrom for Request { type Error = super::Error; - fn try_from(mut value: async_graphql::Request) -> Result { - let variables = std::mem::take(value.variables.deref_mut()); + fn try_from(value: GraphQLRequest) -> Result { + let value = ParsedGraphQLRequest::try_from(value)?; - Ok(Self { - parsed_query: value.parsed_query()?.clone(), - query: value.query, - operation_name: value.operation_name, - variables: Variables::from_iter(variables.into_iter().map(|(k, v)| (k.to_string(), v))), - extensions: value.extensions.0, - }) + Self::try_from(value) } } -impl TryFrom for Request { +impl TryFrom for Request { type Error = super::Error; - - fn try_from(value: GraphQLRequest) -> Result { - Self::try_from(value.0) + fn try_from(value: ParsedGraphQLRequest) -> Result { + Ok(Self { + parsed_query: value.parsed_query, + query: value.query, + operation_name: value.operation_name, + variables: Variables::from(value.variables), + extensions: value.extensions, + }) } } diff --git a/src/core/jit/response.rs b/src/core/jit/response.rs index aabe67dd65..d12cb5bfda 100644 --- a/src/core/jit/response.rs +++ b/src/core/jit/response.rs @@ -4,7 +4,7 @@ use derive_setters::Setters; use serde::Serialize; use super::graphql_error::GraphQLError; -use super::Positioned; +use super::{Pos, Positioned}; use crate::core::async_graphql_hyper::CacheControl; use crate::core::jit; use crate::core::json::{JsonLike, JsonObjectLike}; @@ -33,6 +33,12 @@ impl Default for Response { } } +impl From for Response { + fn from(value: jit::Error) -> Self { + Response::default().with_errors(vec![Positioned::new(value, Pos::default())]) + } +} + impl Response { pub fn new(result: Result>) -> Self { match result { diff --git a/src/core/rest/endpoint.rs b/src/core/rest/endpoint.rs index 2234a9f885..8daf9ca31b 100644 --- a/src/core/rest/endpoint.rs +++ b/src/core/rest/endpoint.rs @@ -11,7 +11,6 @@ use super::path::{Path, Segment}; use super::query_params::QueryParams; use super::type_map::TypeMap; use super::{Request, Result}; -use crate::core::async_graphql_hyper::GraphQLRequest; use crate::core::directive::DirectiveCodec; use crate::core::http::Method; use crate::core::rest::typed_variables::{UrlParamType, N}; @@ -83,11 +82,11 @@ impl Endpoint { Ok(endpoints) } - pub fn into_request(self) -> GraphQLRequest { + pub fn into_request(self) -> async_graphql::Request { let variables = Self::get_default_variables(&self); let mut req = async_graphql::Request::new("").variables(variables); req.set_parsed_query(Self::remove_rest_directives(self.doc)); - GraphQLRequest(req) + req } fn get_default_variables(endpoint: &Endpoint) -> Variables { diff --git a/src/core/rest/operation.rs b/src/core/rest/operation.rs index ea65a9f89a..31598de4db 100644 --- a/src/core/rest/operation.rs +++ b/src/core/rest/operation.rs @@ -4,24 +4,26 @@ use async_graphql::dynamic::Schema; use tailcall_valid::{Cause, Valid, Validator}; use super::{Error, Result}; -use crate::core::async_graphql_hyper::{GraphQLRequest, GraphQLRequestLike}; use crate::core::blueprint::{Blueprint, SchemaModifiers}; use crate::core::http::RequestContext; #[derive(Debug)] pub struct OperationQuery { - query: GraphQLRequest, + query: async_graphql::Request, } impl OperationQuery { - pub fn new(query: GraphQLRequest, request_context: Arc) -> Result { + pub fn new( + query: async_graphql::Request, + request_context: Arc, + ) -> Result { let query = query.data(request_context); Ok(Self { query }) } async fn validate(self, schema: &Schema) -> Vec { schema - .execute(self.query.0) + .execute(self.query) .await .errors .iter() diff --git a/src/core/rest/partial_request.rs b/src/core/rest/partial_request.rs index aaa2a31658..d6838e9af0 100644 --- a/src/core/rest/partial_request.rs +++ b/src/core/rest/partial_request.rs @@ -1,11 +1,14 @@ +use std::collections::HashMap; +use std::ops::DerefMut; + use async_graphql::parser::types::ExecutableDocument; -use async_graphql::{Name, Variables}; +use async_graphql::Variables; use async_graphql_value::ConstValue; use hyper::Body; use super::path::Path; use super::Result; -use crate::core::async_graphql_hyper::GraphQLRequest; +use crate::core::async_graphql_hyper::ParsedGraphQLRequest; /// A partial GraphQLRequest that contains a parsed executable GraphQL document. #[derive(Debug)] @@ -17,17 +20,26 @@ pub struct PartialRequest<'a> { } impl PartialRequest<'_> { - pub async fn into_request(self, body: Body) -> Result { - let mut variables = self.variables; + pub async fn into_request(mut self, body: Body) -> Result { + let variables = std::mem::take(self.variables.deref_mut()); + let mut variables = + HashMap::from_iter(variables.into_iter().map(|(k, v)| (k.to_string(), v))); + if let Some(key) = self.body { let bytes = hyper::body::to_bytes(body).await?; let body: ConstValue = serde_json::from_slice(&bytes)?; - variables.insert(Name::new(key), body); + variables.insert(key.to_string(), body); } - let mut req = async_graphql::Request::new("").variables(variables); - req.set_parsed_query(self.doc.clone()); - - Ok(GraphQLRequest(req)) + Ok(ParsedGraphQLRequest { + // use path as query because query is used as part of the hashing + // and we need to have different hashed for different operations + // TODO: is there any way to make it more explicit here? + query: self.path.as_str().to_string(), + operation_name: None, + variables, + extensions: Default::default(), + parsed_query: self.doc.clone(), + }) } }