Skip to content

Commit

Permalink
feat: persisted documents (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
dotansimha authored Oct 19, 2023
1 parent 0ea4455 commit 464f030
Show file tree
Hide file tree
Showing 24 changed files with 1,094 additions and 62 deletions.
18 changes: 12 additions & 6 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use std::{fs::read_to_string, path::Path};

use crate::plugins::{cors::CorsPluginConfig, http_get_plugin::HttpGetPluginConfig};
use crate::plugins::{
cors::CorsPluginConfig, http_get_plugin::HttpGetPluginConfig,
persisted_documents::config::PersistedOperationsPluginConfig,
};

#[derive(Deserialize, Debug, Clone)]
pub struct ConductorConfig {
Expand All @@ -21,7 +24,7 @@ pub struct EndpointDefinition {
pub plugins: Option<Vec<PluginDefinition>>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum PluginDefinition {
#[serde(rename = "cors")]
Expand All @@ -32,6 +35,9 @@ pub enum PluginDefinition {

#[serde(rename = "http_get")]
HttpGetPlugin(HttpGetPluginConfig),

#[serde(rename = "persisted_operations")]
PersistedOperationsPlugin(PersistedOperationsPluginConfig),
}

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -69,7 +75,7 @@ pub struct LoggerConfig {
pub level: Level,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct ServerConfig {
#[serde(default = "default_server_port")]
pub port: u16,
Expand Down Expand Up @@ -97,7 +103,7 @@ fn default_server_host() -> String {
"127.0.0.1".to_string()
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum SourceDefinition {
#[serde(rename = "graphql")]
Expand All @@ -107,7 +113,7 @@ pub enum SourceDefinition {
},
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct GraphQLSourceConfig {
pub endpoint: String,
}
Expand Down
2 changes: 1 addition & 1 deletion src/endpoint/endpoint_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl IntoResponse for EndpointError {
let (status_code, error_message) = match self {
EndpointError::UpstreamError(e) => (
StatusCode::BAD_GATEWAY,
format!("Invalid GraphQL variables JSON format: {:?}", e),
format!("Invalid response from the GraphQL upstream: {:}", e),
),
};

Expand Down
20 changes: 16 additions & 4 deletions src/graphql_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,22 @@ impl ParsedGraphQLRequest {
.map_err(ExtractGraphQLOperationError::GraphQLParserError)
}

pub fn is_mutation(&self) -> bool {
for definition in &self.parsed_operation.definitions {
if let Definition::Operation(OperationDefinition::Mutation(_)) = definition {
return true;
pub fn is_running_mutation(&self) -> bool {
if let Some(operation_name) = &self.request.operation_name {
for definition in &self.parsed_operation.definitions {
if let Definition::Operation(OperationDefinition::Mutation(mutation)) = definition {
if let Some(mutation_name) = &mutation.name {
if *mutation_name == *operation_name {
return true;
}
}
}
}
} else {
for definition in &self.parsed_operation.definitions {
if let Definition::Operation(OperationDefinition::Mutation(_)) = definition {
return true;
}
}
}

Expand Down
34 changes: 12 additions & 22 deletions src/http_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use http::{
header::{ACCEPT, CONTENT_TYPE},
HeaderMap, StatusCode,
};
use hyper::body::to_bytes;
use mime::Mime;
use mime::{APPLICATION_JSON, APPLICATION_WWW_FORM_URLENCODED};
use serde::de::Error as DeError;
use serde_json::{from_slice, from_str, Error as SerdeError, Map, Value};
use serde_json::{from_str, Error as SerdeError, Map, Value};
use std::collections::HashMap;

use crate::{
Expand Down Expand Up @@ -41,8 +40,9 @@ pub enum ExtractGraphQLOperationError {
InvalidVariablesJsonFormat(SerdeError),
InvalidExtensionsJsonFormat(SerdeError),
EmptyExtraction,
FailedToReadRequestBody(hyper::Error),
FailedToReadRequestBody,
GraphQLParserError(ParseError),
PersistedOperationNotFound,
}

impl ExtractGraphQLOperationError {
Expand All @@ -66,8 +66,11 @@ impl ExtractGraphQLOperationError {
ExtractGraphQLOperationError::EmptyExtraction => {
"Failed to location a GraphQL query in request".to_string()
}
ExtractGraphQLOperationError::FailedToReadRequestBody(e) => {
format!("Failed to read response body: {}", e)
ExtractGraphQLOperationError::FailedToReadRequestBody => {
"Failed to read request body".to_string()
}
ExtractGraphQLOperationError::PersistedOperationNotFound => {
"persisted operation not found in store".to_string()
}
ExtractGraphQLOperationError::GraphQLParserError(e) => e.to_string(),
});
Expand All @@ -87,9 +90,7 @@ impl ExtractGraphQLOperationError {
}
}

pub async fn extract_graphql_from_post_request<'a>(
flow_ctx: &mut FlowContext<'a>,
) -> ExtractionResult {
pub async fn extract_graphql_from_post_request(flow_ctx: &mut FlowContext<'_>) -> ExtractionResult {
// Extract the content-type and default to application/json when it's not set
// see https://graphql.github.io/graphql-over-http/draft/#sec-POST
let headers = flow_ctx.downstream_http_request.headers();
Expand All @@ -104,20 +105,9 @@ pub async fn extract_graphql_from_post_request<'a>(
);
}

let body_bytes = to_bytes(flow_ctx.downstream_http_request.body_mut()).await;
let body = flow_ctx.json_body::<GraphQLRequest>().await;

match body_bytes {
Ok(bytes) => (
Some(content_type),
accept,
from_slice(bytes.as_ref()).map_err(ExtractGraphQLOperationError::InvalidBodyJsonFormat),
),
Err(e) => (
Some(content_type),
accept,
Err(ExtractGraphQLOperationError::FailedToReadRequestBody(e)),
),
}
(Some(content_type), accept, body)
}

pub fn parse_and_extract_json_map_value(value: &str) -> Result<Map<String, Value>, SerdeError> {
Expand Down Expand Up @@ -154,7 +144,7 @@ pub fn extract_graphql_from_get_request(flow_ctx: &mut FlowContext) -> Extractio
.into_owned()
.collect()
})
.unwrap_or_else(HashMap::new);
.unwrap_or_default();

match params.get("query") {
Some(operation) => {
Expand Down
16 changes: 9 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod http_utils;
pub mod plugins;
pub mod source;
pub mod test;
pub mod utils;

use std::sync::Arc;

Expand Down Expand Up @@ -172,14 +173,15 @@ pub(crate) fn create_router_from_config(config_object: ConductorConfig) -> IntoM
plugin_manager.clone(),
);

debug!("creating router route");

http_router = http_router
.route(endpoint_config.path.as_str(), any(http_request_handler))
.route_layer(Extension(endpoint_runtime));

debug!("creating router child router");
let mut nested_router = Router::new()
.route("/*catch_all", any(http_request_handler))
.route("/", any(http_request_handler))
.layer(Extension(endpoint_runtime));
debug!("calling on_endpoint_creation on route");
http_router = plugin_manager.on_endpoint_creation(http_router);
(http_router, nested_router) =
plugin_manager.on_endpoint_creation(http_router, nested_router);
http_router = http_router.nest(endpoint_config.path.as_str(), nested_router)
}

http_router.into_make_service()
Expand Down
9 changes: 7 additions & 2 deletions src/plugins/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ use super::flow_context::FlowContext;

#[async_trait::async_trait]
pub trait Plugin: Sync + Send {
fn on_endpoint_creation(&self, _router: Router<()>) -> axum::Router<()> {
_router
fn on_endpoint_creation(
&self,
_root_router: Router<()>,
_endpoint_router: Router<()>,
) -> (axum::Router<()>, axum::Router<()>) {
(_root_router, _endpoint_router)
}

// An HTTP request send from the client to Conductor
async fn on_downstream_http_request(&self, _ctx: &mut FlowContext) {}
// A final HTTP response send from Conductor to the client
Expand Down
18 changes: 11 additions & 7 deletions src/plugins/cors.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
use std::time::Duration;

use http::{HeaderValue, Method};
use serde::{Deserialize, Deserializer, Serialize};
use serde::{Deserialize, Deserializer};
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, info};

use super::core::Plugin;

pub struct CorsPlugin(pub CorsPluginConfig);

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum CorsListStringConfig {
#[serde(deserialize_with = "deserialize_wildcard")]
Wildcard,
List(Vec<String>),
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum CorsStringConfig {
#[serde(deserialize_with = "deserialize_wildcard")]
Expand All @@ -38,7 +38,7 @@ where
Helper::deserialize(deserializer).map(|_| ())
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct CorsPluginConfig {
allow_credentials: Option<bool>,
allowed_methods: Option<CorsListStringConfig>,
Expand All @@ -61,7 +61,11 @@ impl CorsPluginConfig {

#[async_trait::async_trait]
impl Plugin for CorsPlugin {
fn on_endpoint_creation(&self, router: axum::Router<()>) -> axum::Router<()> {
fn on_endpoint_creation(
&self,
root_router: axum::Router<()>,
endpoint_router: axum::Router<()>,
) -> (axum::Router<()>, axum::Router<()>) {
info!("CORS plugin registered, modifying route...");
debug!("using object config for CORS plugin, config: {:?}", self.0);

Expand All @@ -77,7 +81,7 @@ impl Plugin for CorsPlugin {

debug!("CORS layer configuration: {:?}", layer);

router.route_layer(layer)
(root_router, endpoint_router.route_layer(layer))
}
false => {
let mut layer = CorsLayer::new();
Expand Down Expand Up @@ -130,7 +134,7 @@ impl Plugin for CorsPlugin {

debug!("CORS layer configuration: {:?}", layer);

router.route_layer(layer)
(root_router, endpoint_router.route_layer(layer))
}
}
}
Expand Down
38 changes: 36 additions & 2 deletions src/plugins/flow_context.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
use axum::{body::BoxBody, response::IntoResponse};
use http::{Request, Response};
use serde::de::DeserializeOwned;
use serde_json::from_slice;

use crate::{endpoint::endpoint_runtime::EndpointRuntime, graphql_utils::ParsedGraphQLRequest};
use hyper::Body;
use crate::{
endpoint::endpoint_runtime::EndpointRuntime, graphql_utils::ParsedGraphQLRequest,
http_utils::ExtractGraphQLOperationError,
};
use hyper::{body::to_bytes, Body};

#[derive(Debug)]
pub struct FlowContext<'a> {
pub endpoint: Option<&'a EndpointRuntime>,
pub downstream_graphql_request: Option<ParsedGraphQLRequest>,
pub downstream_http_request: &'a mut Request<Body>,
pub short_circuit_response: Option<Response<BoxBody>>,
pub downstream_request_body_bytes: Option<Result<tokio_util::bytes::Bytes, hyper::Error>>,
}

impl<'a> FlowContext<'a> {
Expand All @@ -19,6 +25,33 @@ impl<'a> FlowContext<'a> {
downstream_http_request: request,
short_circuit_response: None,
endpoint: Some(endpoint),
downstream_request_body_bytes: None,
}
}

pub async fn consume_body(&mut self) -> &Result<tokio_util::bytes::Bytes, hyper::Error> {
if self.downstream_request_body_bytes.is_none() {
self.downstream_request_body_bytes =
Some(to_bytes(self.downstream_http_request.body_mut()).await);
}

return self.downstream_request_body_bytes.as_ref().unwrap();
}

pub async fn json_body<T>(&mut self) -> Result<T, ExtractGraphQLOperationError>
where
T: DeserializeOwned,
{
let body_bytes = self.consume_body().await;

match body_bytes {
Ok(bytes) => {
let json = from_slice::<T>(bytes)
.map_err(ExtractGraphQLOperationError::InvalidBodyJsonFormat)?;

Ok(json)
}
Err(_e) => Err(ExtractGraphQLOperationError::FailedToReadRequestBody),
}
}

Expand All @@ -29,6 +62,7 @@ impl<'a> FlowContext<'a> {
downstream_http_request: request,
short_circuit_response: None,
endpoint: None,
downstream_request_body_bytes: None,
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/plugins/http_get_plugin.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use http::StatusCode;
use http::{Method, StatusCode};

use crate::{
graphql_utils::{GraphQLResponse, ParsedGraphQLRequest},
http_utils::{extract_graphql_from_get_request, ExtractGraphQLOperationError},
};

use super::{core::Plugin, flow_context::FlowContext};
use serde::{Deserialize, Serialize};
use serde::Deserialize;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone)]
pub struct HttpGetPluginConfig {
allow: bool,
mutations: Option<bool>,
}

Expand Down Expand Up @@ -42,9 +41,11 @@ impl Plugin for HttpGetPlugin {
}

async fn on_downstream_graphql_request(&self, ctx: &mut FlowContext) {
if self.0.mutations.is_none() || self.0.mutations == Some(false) {
if ctx.downstream_http_request.method() == Method::GET
&& (self.0.mutations.is_none() || self.0.mutations == Some(false))
{
if let Some(gql_req) = &ctx.downstream_graphql_request {
if gql_req.is_mutation() {
if gql_req.is_running_mutation() {
ctx.short_circuit(
GraphQLResponse::new_error("mutations are not allowed over GET")
.into_response(StatusCode::METHOD_NOT_ALLOWED),
Expand Down
1 change: 1 addition & 0 deletions src/plugins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pub mod flow_context;
pub mod graphiql_plugin;
pub mod http_get_plugin;
pub mod match_content_type;
pub mod persisted_documents;
pub mod plugin_manager;
Loading

0 comments on commit 464f030

Please sign in to comment.