Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use jit engine to handle REST API #3225

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 92 additions & 116 deletions src/core/async_graphql_hyper.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::any::Any;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

use anyhow::Result;
use async_graphql::parser::types::{ExecutableDocument, OperationType};
use async_graphql::{BatchResponse, Executor, Value};
use async_graphql::parser::types::ExecutableDocument;
use async_graphql::{BatchResponse, Value};
use async_graphql_value::ConstValue;
use http::header::{HeaderMap, HeaderValue, CACHE_CONTROL, CONTENT_TYPE};
use http::{Response, StatusCode};
use hyper::Body;
Expand All @@ -13,32 +14,17 @@ use tailcall_hasher::TailcallHasher;

use super::jit::{BatchResponse as JITBatchResponse, JITExecutor};

// TODO: replace usage with some other implementation.
// This one is used to calculate hash and use the value later
// as a key in the HashMap. But such use could lead to potential
// issues in case of hash collisions
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub struct OperationId(u64);

#[async_trait::async_trait]
pub trait GraphQLRequestLike: Hash + Send {
fn data<D: Any + Clone + Send + Sync>(self, data: D) -> Self;
async fn execute<E>(self, executor: &E) -> GraphQLResponse
where
E: Executor;

async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse;

fn parse_query(&mut self) -> Option<&ExecutableDocument>;

fn is_query(&mut self) -> bool {
self.parse_query()
.map(|a| {
let mut is_query = false;
for (_, operation) in a.operations.iter() {
is_query = operation.node.ty == OperationType::Query;
}
is_query
})
.unwrap_or(false)
}

fn operation_id(&self, headers: &HeaderMap) -> OperationId {
let mut hasher = TailcallHasher::default();
let state = &mut hasher;
Expand All @@ -51,86 +37,101 @@ pub trait GraphQLRequestLike: Hash + Send {
}
}

#[derive(Debug, Deserialize)]
pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
impl GraphQLBatchRequest {}
impl Hash for GraphQLBatchRequest {
//TODO: Fix Hash implementation for BatchRequest, which should ideally batch
// execution of individual requests instead of the whole chunk of requests as
// one.
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
for request in self.0.iter() {
request.query.hash(state);
request.operation_name.hash(state);
for (name, value) in request.variables.iter() {
name.hash(state);
value.to_string().hash(state);
}
}
}
#[derive(Debug, Hash, Serialize, Deserialize)]
#[serde(untagged)]
pub enum BatchWrapper<T> {
Single(T),
Batch(Vec<T>),
}
#[async_trait::async_trait]
impl GraphQLRequestLike for GraphQLBatchRequest {
fn data<D: Any + Clone + Send + Sync>(mut self, data: D) -> Self {
for request in self.0.iter_mut() {
request.data.insert(data.clone());
}
self
}

async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse {
GraphQLArcResponse::new(executor.execute_batch(self.0).await)
}

/// Shortcut method to execute the request on the executor.
async fn execute<E>(self, executor: &E) -> GraphQLResponse
where
E: Executor,
{
GraphQLResponse(executor.execute_batch(self.0).await)
}
pub type GraphQLBatchRequest = BatchWrapper<GraphQLRequest>;

fn parse_query(&mut self) -> Option<&ExecutableDocument> {
None
#[async_trait::async_trait]
impl GraphQLRequestLike for BatchWrapper<GraphQLRequest> {
async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse {
GraphQLArcResponse::new(executor.execute_batch(self).await)
}
}

#[derive(Debug, Deserialize)]
pub struct GraphQLRequest(pub async_graphql::Request);
#[derive(Debug, Default, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GraphQLRequest {
#[serde(default)]
pub query: String,
#[serde(default)]
pub operation_name: Option<String>,
#[serde(default)]
pub variables: HashMap<String, ConstValue>,
#[serde(default)]
pub extensions: HashMap<String, ConstValue>,
}

impl GraphQLRequest {}
impl Hash for GraphQLRequest {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.query.hash(state);
self.0.operation_name.hash(state);
for (name, value) in self.0.variables.iter() {
self.query.hash(state);
self.operation_name.hash(state);
for (name, value) in self.variables.iter() {
name.hash(state);
value.to_string().hash(state);
}
}
}

impl GraphQLRequest {
pub fn new(query: impl Into<String>) -> Self {
Self { query: query.into(), ..Default::default() }
}
}

#[async_trait::async_trait]
impl GraphQLRequestLike for GraphQLRequest {
#[must_use]
fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
self.0.data.insert(data);
self
}
async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse {
let response = executor.execute(self.0).await;
let response = executor.execute(self).await;
GraphQLArcResponse::new(JITBatchResponse::Single(response))
}
}

#[derive(Debug)]
pub struct ParsedGraphQLRequest {
pub query: String,
pub operation_name: Option<String>,
pub variables: HashMap<String, ConstValue>,
pub extensions: HashMap<String, ConstValue>,
pub parsed_query: ExecutableDocument,
}

/// Shortcut method to execute the request on the schema.
async fn execute<E>(self, executor: &E) -> GraphQLResponse
where
E: Executor,
{
GraphQLResponse(executor.execute(self.0).await.into())
impl Hash for ParsedGraphQLRequest {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.query.hash(state);
self.operation_name.hash(state);
for (name, value) in self.variables.iter() {
name.hash(state);
value.to_string().hash(state);
}
}
}

impl TryFrom<GraphQLRequest> for ParsedGraphQLRequest {
type Error = async_graphql::parser::Error;

fn parse_query(&mut self) -> Option<&ExecutableDocument> {
self.0.parsed_query().ok()
fn try_from(req: GraphQLRequest) -> std::result::Result<Self, Self::Error> {
let parsed_query = async_graphql::parser::parse_query(&req.query)?;

Ok(Self {
query: req.query,
operation_name: req.operation_name,
variables: req.variables,
extensions: req.extensions,
parsed_query,
})
}
}

#[async_trait::async_trait]
impl GraphQLRequestLike for ParsedGraphQLRequest {
async fn execute_with_jit(self, executor: JITExecutor) -> GraphQLArcResponse {
let response = executor.execute(self).await;
GraphQLArcResponse::new(JITBatchResponse::Single(response))
}
}

Expand All @@ -148,42 +149,6 @@ impl From<async_graphql::Response> for GraphQLResponse {
}
}

impl From<GraphQLQuery> for GraphQLRequest {
fn from(query: GraphQLQuery) -> Self {
let mut request = async_graphql::Request::new(query.query);

if let Some(operation_name) = query.operation_name {
request = request.operation_name(operation_name);
}

if let Some(variables) = query.variables {
let value = serde_json::from_str(&variables).unwrap_or_default();
let variables = async_graphql::Variables::from_json(value);
request = request.variables(variables);
}

GraphQLRequest(request)
}
}

#[derive(Debug)]
pub struct GraphQLQuery {
query: String,
operation_name: Option<String>,
variables: Option<String>,
}

impl GraphQLQuery {
/// Shortcut method to execute the request on the schema.
pub async fn execute<E>(self, executor: &E) -> GraphQLResponse
where
E: Executor,
{
let request: GraphQLRequest = self.into();
request.execute(executor).await
}
}

static APPLICATION_JSON: Lazy<HeaderValue> =
Lazy::new(|| HeaderValue::from_static("application/json"));

Expand Down Expand Up @@ -408,6 +373,17 @@ impl GraphQLArcResponse {
pub fn into_response(self) -> Result<Response<hyper::Body>> {
self.build_response(StatusCode::OK, self.default_body()?)
}

/// Transforms a plain `GraphQLResponse` into a `Response<Body>`.
/// Differs as `to_response` by flattening the response's data
/// `{"data": {"user": {"name": "John"}}}` becomes `{"name": "John"}`.
pub fn into_rest_response(self) -> Result<Response<hyper::Body>> {
if !self.response.is_ok() {
return self.build_response(StatusCode::INTERNAL_SERVER_ERROR, self.default_body()?);
}

self.into_response()
}
}

#[cfg(test)]
Expand Down
13 changes: 8 additions & 5 deletions src/core/http/request_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,20 +241,23 @@ async fn handle_rest_apis(
*request.uri_mut() = request.uri().path().replace(API_URL_PREFIX, "").parse()?;
let req_ctx = Arc::new(create_request_context(&request, app_ctx.as_ref()));
if let Some(p_request) = app_ctx.endpoints.matches(&request) {
let (req, body) = request.into_parts();
let http_route = format!("{API_URL_PREFIX}{}", p_request.path.as_str());
req_counter.set_http_route(&http_route);
let span = tracing::info_span!(
"REST",
otel.name = format!("REST {} {}", request.method(), p_request.path.as_str()),
otel.name = format!("REST {} {}", req.method, p_request.path.as_str()),
otel.kind = ?SpanKind::Server,
{ HTTP_REQUEST_METHOD } = %request.method(),
{ HTTP_REQUEST_METHOD } = %req.method,
{ HTTP_ROUTE } = http_route
);
return async {
let graphql_request = p_request.into_request(request).await?;
let graphql_request = p_request.into_request(body).await?;
let operation_id = graphql_request.operation_id(&req.headers);
let exec = JITExecutor::new(app_ctx.clone(), req_ctx.clone(), operation_id)
.flatten_response(true);
let mut response = graphql_request
.data(req_ctx.clone())
.execute(&app_ctx.schema)
.execute_with_jit(exec)
.await
.set_cache_control(
app_ctx.blueprint.server.enable_cache_control_header,
Expand Down
6 changes: 6 additions & 0 deletions src/core/jit/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@
Unknown,
}

impl From<async_graphql::ServerError> for Error {
fn from(value: async_graphql::ServerError) -> Self {
Self::ServerError(value)
}

Check warning on line 62 in src/core/jit/error.rs

View check run for this annotation

Codecov / codecov/patch

src/core/jit/error.rs#L60-L62

Added lines #L60 - L62 were not covered by tests
}

impl ErrorExtensions for Error {
fn extend(&self) -> super::graphql_error::Error {
match self {
Expand Down
Loading
Loading