diff --git a/src/config/mod.rs b/src/config/mod.rs index 2edd0134..b93c1306 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,7 +1,10 @@ -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use std::{fs::read_to_string, path::Path}; -use crate::plugins::{cors::CorsPluginConfig, http_get_plugin::HttpGetPluginConfig}; +use crate::plugins::{ + cors::CorsPluginConfig, http_get_plugin::HttpGetPluginConfig, + persisted_documents::config::PersistedOperationsPluginConfig, +}; #[derive(Deserialize, Debug, Clone)] pub struct ConductorConfig { @@ -21,7 +24,7 @@ pub struct EndpointDefinition { pub plugins: Option>, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] #[serde(tag = "type")] pub enum PluginDefinition { #[serde(rename = "cors")] @@ -32,6 +35,9 @@ pub enum PluginDefinition { #[serde(rename = "http_get")] HttpGetPlugin(HttpGetPluginConfig), + + #[serde(rename = "persisted_operations")] + PersistedOperationsPlugin(PersistedOperationsPluginConfig), } #[derive(Debug, Clone, Copy)] @@ -69,7 +75,7 @@ pub struct LoggerConfig { pub level: Level, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] pub struct ServerConfig { #[serde(default = "default_server_port")] pub port: u16, @@ -97,7 +103,7 @@ fn default_server_host() -> String { "127.0.0.1".to_string() } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] #[serde(tag = "type")] pub enum SourceDefinition { #[serde(rename = "graphql")] @@ -107,7 +113,7 @@ pub enum SourceDefinition { }, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] pub struct GraphQLSourceConfig { pub endpoint: String, } diff --git a/src/endpoint/endpoint_runtime.rs b/src/endpoint/endpoint_runtime.rs index ceae9c38..f4c45cdb 100644 --- a/src/endpoint/endpoint_runtime.rs +++ b/src/endpoint/endpoint_runtime.rs @@ -25,7 +25,7 @@ impl IntoResponse for EndpointError { let (status_code, error_message) = match self { EndpointError::UpstreamError(e) => ( StatusCode::BAD_GATEWAY, - format!("Invalid GraphQL variables JSON format: {:?}", e), + format!("Invalid response from the GraphQL upstream: {:}", e), ), }; diff --git a/src/graphql_utils.rs b/src/graphql_utils.rs index cef355c6..52211df2 100644 --- a/src/graphql_utils.rs +++ b/src/graphql_utils.rs @@ -65,10 +65,22 @@ impl ParsedGraphQLRequest { .map_err(ExtractGraphQLOperationError::GraphQLParserError) } - pub fn is_mutation(&self) -> bool { - for definition in &self.parsed_operation.definitions { - if let Definition::Operation(OperationDefinition::Mutation(_)) = definition { - return true; + pub fn is_running_mutation(&self) -> bool { + if let Some(operation_name) = &self.request.operation_name { + for definition in &self.parsed_operation.definitions { + if let Definition::Operation(OperationDefinition::Mutation(mutation)) = definition { + if let Some(mutation_name) = &mutation.name { + if *mutation_name == *operation_name { + return true; + } + } + } + } + } else { + for definition in &self.parsed_operation.definitions { + if let Definition::Operation(OperationDefinition::Mutation(_)) = definition { + return true; + } } } diff --git a/src/http_utils.rs b/src/http_utils.rs index a8cc976a..6263e7a9 100644 --- a/src/http_utils.rs +++ b/src/http_utils.rs @@ -3,11 +3,10 @@ use http::{ header::{ACCEPT, CONTENT_TYPE}, HeaderMap, StatusCode, }; -use hyper::body::to_bytes; use mime::Mime; use mime::{APPLICATION_JSON, APPLICATION_WWW_FORM_URLENCODED}; use serde::de::Error as DeError; -use serde_json::{from_slice, from_str, Error as SerdeError, Map, Value}; +use serde_json::{from_str, Error as SerdeError, Map, Value}; use std::collections::HashMap; use crate::{ @@ -41,8 +40,9 @@ pub enum ExtractGraphQLOperationError { InvalidVariablesJsonFormat(SerdeError), InvalidExtensionsJsonFormat(SerdeError), EmptyExtraction, - FailedToReadRequestBody(hyper::Error), + FailedToReadRequestBody, GraphQLParserError(ParseError), + PersistedOperationNotFound, } impl ExtractGraphQLOperationError { @@ -66,8 +66,11 @@ impl ExtractGraphQLOperationError { ExtractGraphQLOperationError::EmptyExtraction => { "Failed to location a GraphQL query in request".to_string() } - ExtractGraphQLOperationError::FailedToReadRequestBody(e) => { - format!("Failed to read response body: {}", e) + ExtractGraphQLOperationError::FailedToReadRequestBody => { + "Failed to read request body".to_string() + } + ExtractGraphQLOperationError::PersistedOperationNotFound => { + "persisted operation not found in store".to_string() } ExtractGraphQLOperationError::GraphQLParserError(e) => e.to_string(), }); @@ -87,9 +90,7 @@ impl ExtractGraphQLOperationError { } } -pub async fn extract_graphql_from_post_request<'a>( - flow_ctx: &mut FlowContext<'a>, -) -> ExtractionResult { +pub async fn extract_graphql_from_post_request(flow_ctx: &mut FlowContext<'_>) -> ExtractionResult { // Extract the content-type and default to application/json when it's not set // see https://graphql.github.io/graphql-over-http/draft/#sec-POST let headers = flow_ctx.downstream_http_request.headers(); @@ -104,20 +105,9 @@ pub async fn extract_graphql_from_post_request<'a>( ); } - let body_bytes = to_bytes(flow_ctx.downstream_http_request.body_mut()).await; + let body = flow_ctx.json_body::().await; - match body_bytes { - Ok(bytes) => ( - Some(content_type), - accept, - from_slice(bytes.as_ref()).map_err(ExtractGraphQLOperationError::InvalidBodyJsonFormat), - ), - Err(e) => ( - Some(content_type), - accept, - Err(ExtractGraphQLOperationError::FailedToReadRequestBody(e)), - ), - } + (Some(content_type), accept, body) } pub fn parse_and_extract_json_map_value(value: &str) -> Result, SerdeError> { @@ -154,7 +144,7 @@ pub fn extract_graphql_from_get_request(flow_ctx: &mut FlowContext) -> Extractio .into_owned() .collect() }) - .unwrap_or_else(HashMap::new); + .unwrap_or_default(); match params.get("query") { Some(operation) => { diff --git a/src/lib.rs b/src/lib.rs index b6b5e897..7095aefc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod http_utils; pub mod plugins; pub mod source; pub mod test; +pub mod utils; use std::sync::Arc; @@ -172,14 +173,15 @@ pub(crate) fn create_router_from_config(config_object: ConductorConfig) -> IntoM plugin_manager.clone(), ); - debug!("creating router route"); - - http_router = http_router - .route(endpoint_config.path.as_str(), any(http_request_handler)) - .route_layer(Extension(endpoint_runtime)); - + debug!("creating router child router"); + let mut nested_router = Router::new() + .route("/*catch_all", any(http_request_handler)) + .route("/", any(http_request_handler)) + .layer(Extension(endpoint_runtime)); debug!("calling on_endpoint_creation on route"); - http_router = plugin_manager.on_endpoint_creation(http_router); + (http_router, nested_router) = + plugin_manager.on_endpoint_creation(http_router, nested_router); + http_router = http_router.nest(endpoint_config.path.as_str(), nested_router) } http_router.into_make_service() diff --git a/src/plugins/core.rs b/src/plugins/core.rs index e1e8a7f5..f967b993 100644 --- a/src/plugins/core.rs +++ b/src/plugins/core.rs @@ -9,9 +9,14 @@ use super::flow_context::FlowContext; #[async_trait::async_trait] pub trait Plugin: Sync + Send { - fn on_endpoint_creation(&self, _router: Router<()>) -> axum::Router<()> { - _router + fn on_endpoint_creation( + &self, + _root_router: Router<()>, + _endpoint_router: Router<()>, + ) -> (axum::Router<()>, axum::Router<()>) { + (_root_router, _endpoint_router) } + // An HTTP request send from the client to Conductor async fn on_downstream_http_request(&self, _ctx: &mut FlowContext) {} // A final HTTP response send from Conductor to the client diff --git a/src/plugins/cors.rs b/src/plugins/cors.rs index b28249e6..53d35495 100644 --- a/src/plugins/cors.rs +++ b/src/plugins/cors.rs @@ -1,7 +1,7 @@ use std::time::Duration; use http::{HeaderValue, Method}; -use serde::{Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Deserializer}; use tower_http::cors::{Any, CorsLayer}; use tracing::{debug, info}; @@ -9,7 +9,7 @@ use super::core::Plugin; pub struct CorsPlugin(pub CorsPluginConfig); -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] #[serde(untagged)] pub enum CorsListStringConfig { #[serde(deserialize_with = "deserialize_wildcard")] @@ -17,7 +17,7 @@ pub enum CorsListStringConfig { List(Vec), } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] #[serde(untagged)] pub enum CorsStringConfig { #[serde(deserialize_with = "deserialize_wildcard")] @@ -38,7 +38,7 @@ where Helper::deserialize(deserializer).map(|_| ()) } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] pub struct CorsPluginConfig { allow_credentials: Option, allowed_methods: Option, @@ -61,7 +61,11 @@ impl CorsPluginConfig { #[async_trait::async_trait] impl Plugin for CorsPlugin { - fn on_endpoint_creation(&self, router: axum::Router<()>) -> axum::Router<()> { + fn on_endpoint_creation( + &self, + root_router: axum::Router<()>, + endpoint_router: axum::Router<()>, + ) -> (axum::Router<()>, axum::Router<()>) { info!("CORS plugin registered, modifying route..."); debug!("using object config for CORS plugin, config: {:?}", self.0); @@ -77,7 +81,7 @@ impl Plugin for CorsPlugin { debug!("CORS layer configuration: {:?}", layer); - router.route_layer(layer) + (root_router, endpoint_router.route_layer(layer)) } false => { let mut layer = CorsLayer::new(); @@ -130,7 +134,7 @@ impl Plugin for CorsPlugin { debug!("CORS layer configuration: {:?}", layer); - router.route_layer(layer) + (root_router, endpoint_router.route_layer(layer)) } } } diff --git a/src/plugins/flow_context.rs b/src/plugins/flow_context.rs index 4229afc8..063fbe14 100644 --- a/src/plugins/flow_context.rs +++ b/src/plugins/flow_context.rs @@ -1,8 +1,13 @@ use axum::{body::BoxBody, response::IntoResponse}; use http::{Request, Response}; +use serde::de::DeserializeOwned; +use serde_json::from_slice; -use crate::{endpoint::endpoint_runtime::EndpointRuntime, graphql_utils::ParsedGraphQLRequest}; -use hyper::Body; +use crate::{ + endpoint::endpoint_runtime::EndpointRuntime, graphql_utils::ParsedGraphQLRequest, + http_utils::ExtractGraphQLOperationError, +}; +use hyper::{body::to_bytes, Body}; #[derive(Debug)] pub struct FlowContext<'a> { @@ -10,6 +15,7 @@ pub struct FlowContext<'a> { pub downstream_graphql_request: Option, pub downstream_http_request: &'a mut Request, pub short_circuit_response: Option>, + pub downstream_request_body_bytes: Option>, } impl<'a> FlowContext<'a> { @@ -19,6 +25,33 @@ impl<'a> FlowContext<'a> { downstream_http_request: request, short_circuit_response: None, endpoint: Some(endpoint), + downstream_request_body_bytes: None, + } + } + + pub async fn consume_body(&mut self) -> &Result { + if self.downstream_request_body_bytes.is_none() { + self.downstream_request_body_bytes = + Some(to_bytes(self.downstream_http_request.body_mut()).await); + } + + return self.downstream_request_body_bytes.as_ref().unwrap(); + } + + pub async fn json_body(&mut self) -> Result + where + T: DeserializeOwned, + { + let body_bytes = self.consume_body().await; + + match body_bytes { + Ok(bytes) => { + let json = from_slice::(bytes) + .map_err(ExtractGraphQLOperationError::InvalidBodyJsonFormat)?; + + Ok(json) + } + Err(_e) => Err(ExtractGraphQLOperationError::FailedToReadRequestBody), } } @@ -29,6 +62,7 @@ impl<'a> FlowContext<'a> { downstream_http_request: request, short_circuit_response: None, endpoint: None, + downstream_request_body_bytes: None, } } diff --git a/src/plugins/http_get_plugin.rs b/src/plugins/http_get_plugin.rs index e3112b61..9f22f6b4 100644 --- a/src/plugins/http_get_plugin.rs +++ b/src/plugins/http_get_plugin.rs @@ -1,4 +1,4 @@ -use http::StatusCode; +use http::{Method, StatusCode}; use crate::{ graphql_utils::{GraphQLResponse, ParsedGraphQLRequest}, @@ -6,11 +6,10 @@ use crate::{ }; use super::{core::Plugin, flow_context::FlowContext}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone)] pub struct HttpGetPluginConfig { - allow: bool, mutations: Option, } @@ -42,9 +41,11 @@ impl Plugin for HttpGetPlugin { } async fn on_downstream_graphql_request(&self, ctx: &mut FlowContext) { - if self.0.mutations.is_none() || self.0.mutations == Some(false) { + if ctx.downstream_http_request.method() == Method::GET + && (self.0.mutations.is_none() || self.0.mutations == Some(false)) + { if let Some(gql_req) = &ctx.downstream_graphql_request { - if gql_req.is_mutation() { + if gql_req.is_running_mutation() { ctx.short_circuit( GraphQLResponse::new_error("mutations are not allowed over GET") .into_response(StatusCode::METHOD_NOT_ALLOWED), diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index 28a1fd22..79cda00f 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -4,4 +4,5 @@ pub mod flow_context; pub mod graphiql_plugin; pub mod http_get_plugin; pub mod match_content_type; +pub mod persisted_documents; pub mod plugin_manager; diff --git a/src/plugins/persisted_documents/config.rs b/src/plugins/persisted_documents/config.rs new file mode 100644 index 00000000..2527636a --- /dev/null +++ b/src/plugins/persisted_documents/config.rs @@ -0,0 +1,96 @@ +use serde::Deserialize; + +use crate::utils::serde_utils::LocalFileReference; + +use super::store::fs::PersistedDocumentsFileFormat; + +#[derive(Deserialize, Debug, Clone)] +pub struct PersistedOperationsPluginConfig { + pub store: PersistedOperationsPluginStoreConfig, + pub allow_non_persisted: Option, + pub protocols: Vec, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "source")] +pub enum PersistedOperationsPluginStoreConfig { + #[serde(rename = "file")] + File { + #[serde(rename = "path")] + file: LocalFileReference, + format: PersistedDocumentsFileFormat, + }, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct ApolloPersistedQueryManifest { + pub format: String, + pub version: i32, + pub operations: Vec, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct ApolloPersistedQueryManifestRecord { + pub id: String, + pub body: String, + pub name: String, + #[serde(rename = "type")] + pub operation_type: String, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "type")] +pub enum PersistedOperationsProtocolConfig { + #[serde(rename = "apollo_manifest_extensions")] + ApolloManifestExtensions, + #[serde(rename = "document_id")] + DocumentId { + #[serde(default = "document_id_default_field_name")] + field_name: String, + }, + #[serde(rename = "http_get")] + HttpGet { + #[serde(default = "PersistedOperationHttpGetParameterLocation::document_id_default")] + document_id_from: PersistedOperationHttpGetParameterLocation, + #[serde(default = "PersistedOperationHttpGetParameterLocation::variables_default")] + variables_from: PersistedOperationHttpGetParameterLocation, + #[serde(default = "PersistedOperationHttpGetParameterLocation::operation_name_default")] + operation_name_from: PersistedOperationHttpGetParameterLocation, + }, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "source")] +pub enum PersistedOperationHttpGetParameterLocation { + // TODO: This doesn't work when parsed from config + #[serde(rename = "search_query")] + Query { name: String }, + #[serde(rename = "path")] + Path { position: usize }, + #[serde(rename = "header")] + Header { name: String }, +} + +impl PersistedOperationHttpGetParameterLocation { + pub fn document_id_default() -> Self { + PersistedOperationHttpGetParameterLocation::Query { + name: document_id_default_field_name(), + } + } + + pub fn variables_default() -> Self { + PersistedOperationHttpGetParameterLocation::Query { + name: "variables".to_string(), + } + } + + pub fn operation_name_default() -> Self { + PersistedOperationHttpGetParameterLocation::Query { + name: "operationName".to_string(), + } + } +} + +fn document_id_default_field_name() -> String { + "documentId".to_string() +} diff --git a/src/plugins/persisted_documents/mod.rs b/src/plugins/persisted_documents/mod.rs new file mode 100644 index 00000000..bdcc1921 --- /dev/null +++ b/src/plugins/persisted_documents/mod.rs @@ -0,0 +1,4 @@ +pub mod config; +pub mod plugin; +pub mod protocols; +pub mod store; diff --git a/src/plugins/persisted_documents/plugin.rs b/src/plugins/persisted_documents/plugin.rs new file mode 100644 index 00000000..96501cc2 --- /dev/null +++ b/src/plugins/persisted_documents/plugin.rs @@ -0,0 +1,300 @@ +use crate::{ + graphql_utils::{GraphQLRequest, ParsedGraphQLRequest}, + http_utils::ExtractGraphQLOperationError, + plugins::{ + core::Plugin, + flow_context::FlowContext, + persisted_documents::{ + config::PersistedOperationsPluginStoreConfig, + protocols::{ + apollo_manifest::ApolloManifestPersistedDocumentsProtocol, + document_id::DocumentIdPersistedDocumentsProtocol, + get_handler::PersistedDocumentsGetHandler, + }, + store::fs::PersistedDocumentsFilesystemStore, + }, + }, +}; + +use super::{ + config::{PersistedOperationsPluginConfig, PersistedOperationsProtocolConfig}, + protocols::PersistedDocumentsProtocol, + store::PersistedDocumentsStore, +}; +use async_trait::async_trait; +use tracing::{debug, error, info, warn}; + +pub struct PersistedOperationsPlugin { + config: PersistedOperationsPluginConfig, + incoming_message_handlers: Vec>, + store: Box, +} + +type ErrorMessage = String; + +#[derive(Debug)] +pub enum PersistedOperationsPluginError { + StoreCreationError(ErrorMessage), +} + +impl PersistedOperationsPlugin { + pub fn new_from_config( + config: PersistedOperationsPluginConfig, + ) -> Result { + debug!("creating persisted operations plugin"); + + let store: Box = match &config.store { + PersistedOperationsPluginStoreConfig::File { file, format } => { + let fs_store = PersistedDocumentsFilesystemStore::new_from_file_contents( + &file.contents, + format, + ) + .map_err(|pe| PersistedOperationsPluginError::StoreCreationError(pe.to_string()))?; + + Box::new(fs_store) + } + }; + + let incoming_message_handlers: Vec> = config + .protocols + .iter() + .map(|protocol| match protocol { + PersistedOperationsProtocolConfig::DocumentId { field_name } => { + debug!("adding persisted documents protocol of type document_id with field_name: {}", field_name); + + Box::new(DocumentIdPersistedDocumentsProtocol { + field_name: field_name.to_string(), + }) as Box + } + PersistedOperationsProtocolConfig::ApolloManifestExtensions => { + debug!("adding persisted documents protocol of type apollo_manifest (extensions) with field_name"); + + Box::new(ApolloManifestPersistedDocumentsProtocol {}) + as Box + } + PersistedOperationsProtocolConfig::HttpGet { + document_id_from, + variables_from, + operation_name_from, + } => { + debug!( + "adding persisted documents protocol of type get HTTP with the following sources: {:?}, {:?}, {:?}", + document_id_from, variables_from, operation_name_from + ); + + Box::new(PersistedDocumentsGetHandler { + document_id_from: document_id_from.clone(), + variables_from: variables_from.clone(), + operation_name_from: operation_name_from.clone(), + }) as Box + } + }) + .collect(); + + Ok(Self { + config, + store, + incoming_message_handlers, + }) + } +} + +#[async_trait] +impl Plugin for PersistedOperationsPlugin { + async fn on_downstream_http_request(&self, ctx: &mut FlowContext) { + if ctx.downstream_graphql_request.is_some() { + return; + } + + for extractor in &self.incoming_message_handlers { + debug!( + "trying to extract persisted operation from incoming request, extractor: {:?}", + extractor + ); + if let Some(extracted) = extractor.as_ref().try_extraction(ctx).await { + info!( + "extracted persisted operation from incoming request: {:?}", + extracted + ); + + if let Some(op) = self.store.get_document(&extracted.hash).await { + debug!("found persisted operation with id {:?}", extracted.hash); + + match ParsedGraphQLRequest::create_and_parse(GraphQLRequest { + operation: op.clone(), + operation_name: extracted.operation_name, + variables: extracted.variables, + extensions: extracted.extensions, + }) { + Ok(parsed) => { + debug!( + "extracted persisted operation is valid, updating request context: {:?}", parsed + ); + + ctx.downstream_graphql_request = Some(parsed); + return; + } + Err(e) => { + warn!("failed to parse GraphQL request from a store object with key {:?}, error: {:?}", e, extracted.hash); + + ctx.short_circuit(e.into_response(None)); + return; + } + } + } else { + warn!("persisted operation with id {:?} not found", extracted.hash); + } + } + } + + if self.config.allow_non_persisted != Some(true) { + error!("non-persisted operations are not allowed, short-circute with an error"); + + ctx.short_circuit( + ExtractGraphQLOperationError::PersistedOperationNotFound.into_response(None), + ); + return; + } + } + + async fn on_downstream_graphql_request(&self, ctx: &mut FlowContext) { + for item in self.incoming_message_handlers.iter() { + if let Some(response) = item.as_ref().should_prevent_execution(ctx) { + warn!( + "persisted operation execution was prevented, due to falsy value returned from should_prevent_execution from extractor {:?}",item + ); + ctx.short_circuit(response); + } + } + } +} + +#[tokio::test] +async fn persisted_documents_plugin() { + use crate::endpoint::endpoint_runtime::EndpointRuntime; + use http::Request; + use hyper::Body; + use serde_json::json; + + let config = PersistedOperationsPluginConfig { + store: PersistedOperationsPluginStoreConfig::File { file: crate::utils::serde_utils::LocalFileReference { + path: "dummy.json".to_string(), + contents: json!({ + "key1": "query { hello }", + "key2": "query { hello2 }", + }).to_string(), + }, format: crate::plugins::persisted_documents::store::fs::PersistedDocumentsFileFormat::JsonKeyValue }, + allow_non_persisted: Some(false), + protocols: vec![ + PersistedOperationsProtocolConfig::DocumentId { + field_name: "documentId".to_string(), + }, + PersistedOperationsProtocolConfig::ApolloManifestExtensions, + PersistedOperationsProtocolConfig::HttpGet { + document_id_from: crate::plugins::persisted_documents::config::PersistedOperationHttpGetParameterLocation::Query { + name: "documentId".to_string(), + }, + variables_from: crate::plugins::persisted_documents::config::PersistedOperationHttpGetParameterLocation::Query { + name: "variables".to_string(), + }, + operation_name_from: crate::plugins::persisted_documents::config::PersistedOperationHttpGetParameterLocation::Query { + name: "operationName".to_string(), + }, + }, + ], + }; + let plugin = + PersistedOperationsPlugin::new_from_config(config).expect("failed to create plugin"); + let endpoint = EndpointRuntime::mocked_endpoint(); + + // Try to use POST with "documentId" in the body + let mut req = Request::builder() + .method("POST") + .body(Body::from( + json!({ + "documentId": "key1" + }) + .to_string(), + )) + .unwrap(); + let mut ctx = FlowContext::new(&endpoint, &mut req); + plugin.on_downstream_http_request(&mut ctx).await; + assert_eq!(ctx.is_short_circuit(), false); + assert_eq!(ctx.downstream_graphql_request.is_some(), true); + assert_eq!( + ctx.downstream_graphql_request + .unwrap() + .parsed_operation + .to_string(), + "query {\n hello\n}\n" + ); + + // Try to use POST with "extensions" in the body + let mut req = Request::builder() + .method("POST") + .body(Body::from( + json!({ + "extensions": { + "persistedQuery": { + "sha256Hash": "key2" + } + } + }) + .to_string(), + )) + .unwrap(); + let mut ctx = FlowContext::new(&endpoint, &mut req); + plugin.on_downstream_http_request(&mut ctx).await; + assert_eq!(ctx.is_short_circuit(), false); + assert_eq!(ctx.downstream_graphql_request.is_some(), true); + assert_eq!( + ctx.downstream_graphql_request + .unwrap() + .parsed_operation + .to_string(), + "query {\n hello2\n}\n" + ); + + // Try to use GET with query params + let mut req = Request::builder() + .method("GET") + .uri("http://localhost:8080/graphql?documentId=key2") + .body(Body::empty()) + .unwrap(); + let mut ctx = FlowContext::new(&endpoint, &mut req); + plugin.on_downstream_http_request(&mut ctx).await; + assert_eq!(ctx.is_short_circuit(), false); + assert_eq!(ctx.downstream_graphql_request.is_some(), true); + assert_eq!( + ctx.downstream_graphql_request + .unwrap() + .parsed_operation + .to_string(), + "query {\n hello2\n}\n" + ); + + // Try to use a non-existing key in store + let mut req = Request::builder() + .method("GET") + .uri("http://localhost:8080/graphql?documentId=does_not_exists") + .body(Body::empty()) + .unwrap(); + let mut ctx = FlowContext::new(&endpoint, &mut req); + plugin.on_downstream_http_request(&mut ctx).await; + assert_eq!(ctx.is_short_circuit(), true); + assert_eq!(ctx.downstream_graphql_request.is_none(), true); + + // Try run a POST query with regular GraphQL (not allowed) + let mut req = Request::builder() + .method("POST") + .body(Body::from( + json!({ + "query": "query { __typename }" + }) + .to_string(), + )) + .unwrap(); + let mut ctx = FlowContext::new(&endpoint, &mut req); + plugin.on_downstream_http_request(&mut ctx).await; + assert_eq!(ctx.is_short_circuit(), true); +} diff --git a/src/plugins/persisted_documents/protocols/apollo_manifest.rs b/src/plugins/persisted_documents/protocols/apollo_manifest.rs new file mode 100644 index 00000000..3f6148fb --- /dev/null +++ b/src/plugins/persisted_documents/protocols/apollo_manifest.rs @@ -0,0 +1,61 @@ +use crate::plugins::flow_context::FlowContext; +use http::Method; +use serde::Deserialize; +use serde_json::{Map, Value}; +use tracing::{debug, info}; + +use super::{ExtractedPersistedDocument, PersistedDocumentsProtocol}; + +#[derive(Debug)] +pub struct ApolloManifestPersistedDocumentsProtocol; + +#[derive(Deserialize, Debug)] + +struct ApolloPersistedOperationsIncomingMessage { + variables: Option>, + #[serde(rename = "operationName")] + operation_name: Option, + extensions: Extensions, +} + +#[derive(Deserialize, Debug)] +struct Extensions { + #[serde(rename = "persistedQuery")] + persisted_query: PersistedQuery, + #[serde(flatten)] + other: Map, +} + +#[derive(Deserialize, Debug)] +struct PersistedQuery { + #[serde(rename = "sha256Hash")] + hash: String, +} + +#[async_trait::async_trait] +impl PersistedDocumentsProtocol for ApolloManifestPersistedDocumentsProtocol { + async fn try_extraction(&self, ctx: &mut FlowContext) -> Option { + if ctx.downstream_http_request.method() == Method::POST { + debug!("request http method is post, trying to extract from body..."); + + if let Ok(message) = ctx + .json_body::() + .await + { + info!( + "succuessfully extracted incoming persisted operation from request: {:?}", + message + ); + + return Some(ExtractedPersistedDocument { + hash: message.extensions.persisted_query.hash, + variables: message.variables, + operation_name: message.operation_name, + extensions: Some(message.extensions.other), + }); + } + } + + None + } +} diff --git a/src/plugins/persisted_documents/protocols/document_id.rs b/src/plugins/persisted_documents/protocols/document_id.rs new file mode 100644 index 00000000..281c0f20 --- /dev/null +++ b/src/plugins/persisted_documents/protocols/document_id.rs @@ -0,0 +1,53 @@ +use crate::plugins::flow_context::FlowContext; +use http::Method; +use serde_json::Value; +use tracing::{debug, info}; + +use super::{ExtractedPersistedDocument, PersistedDocumentsProtocol}; + +#[derive(Debug)] +pub struct DocumentIdPersistedDocumentsProtocol { + pub field_name: String, +} + +#[async_trait::async_trait] +impl PersistedDocumentsProtocol for DocumentIdPersistedDocumentsProtocol { + async fn try_extraction(&self, ctx: &mut FlowContext) -> Option { + if ctx.downstream_http_request.method() == Method::POST { + debug!("request http method is post, trying to extract from body..."); + + if let Ok(root_object) = ctx.json_body::().await { + debug!( + "found valid JSON body in request, trying to extract the document id using field_name: {}", + self.field_name + ); + + if let Some(op_id) = root_object + .get(self.field_name.as_str()) + .and_then(|v| v.as_str()) + .map(|v| v.to_string()) + { + info!("succuessfully extracted incoming persisted operation from request",); + + return Some(ExtractedPersistedDocument { + hash: op_id, + variables: root_object + .get("variables") + .and_then(|v| v.as_object()) + .cloned(), + operation_name: root_object + .get("operationName") + .and_then(|v| v.as_str()) + .map(|v| v.to_string()), + extensions: root_object + .get("extensions") + .and_then(|v| v.as_object()) + .cloned(), + }); + } + } + } + + None + } +} diff --git a/src/plugins/persisted_documents/protocols/get_handler.rs b/src/plugins/persisted_documents/protocols/get_handler.rs new file mode 100644 index 00000000..acea3d9e --- /dev/null +++ b/src/plugins/persisted_documents/protocols/get_handler.rs @@ -0,0 +1,152 @@ +use std::collections::HashMap; + +use axum::body::BoxBody; +use http::{HeaderMap, Method, Response, StatusCode, Uri}; +use tracing::{debug, info}; + +use crate::{ + graphql_utils::GraphQLResponse, + plugins::{ + flow_context::FlowContext, + persisted_documents::config::PersistedOperationHttpGetParameterLocation, + }, +}; + +use super::{ExtractedPersistedDocument, PersistedDocumentsProtocol}; + +#[derive(Debug)] +pub struct PersistedDocumentsGetHandler { + pub document_id_from: PersistedOperationHttpGetParameterLocation, + pub variables_from: PersistedOperationHttpGetParameterLocation, + pub operation_name_from: PersistedOperationHttpGetParameterLocation, +} + +fn extract_header(header_map: &HeaderMap, header_name: &String) -> Option { + header_map + .get(header_name) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()) +} + +fn extract_query_param(uri: &Uri, param_name: &String) -> Option { + let params: HashMap = uri + .query() + .map(|v| { + url::form_urlencoded::parse(v.as_bytes()) + .into_owned() + .collect() + }) + .unwrap_or_default(); + + params.get(param_name).cloned() +} + +fn extract_path_position(uri: &Uri, position: usize) -> Option { + uri.path() + .split('/') + .collect::>() + .get(position) + .map(|v| v.to_string()) +} + +impl PersistedDocumentsGetHandler { + fn maybe_document_id(&self, ctx: &FlowContext) -> Option { + debug!( + "trying to extract document id hash from source {:?}", + self.operation_name_from + ); + + match &self.document_id_from { + PersistedOperationHttpGetParameterLocation::Header { name } => { + extract_header(ctx.downstream_http_request.headers(), name) + } + PersistedOperationHttpGetParameterLocation::Query { name } => { + extract_query_param(ctx.downstream_http_request.uri(), name) + } + PersistedOperationHttpGetParameterLocation::Path { position } => { + extract_path_position(ctx.downstream_http_request.uri(), *position) + } + } + } + + fn maybe_variables(&self, ctx: &FlowContext) -> Option { + debug!( + "trying to extract variables from source {:?}", + self.operation_name_from + ); + + match &self.variables_from { + PersistedOperationHttpGetParameterLocation::Header { name } => { + extract_header(ctx.downstream_http_request.headers(), name) + } + PersistedOperationHttpGetParameterLocation::Query { name } => { + extract_query_param(ctx.downstream_http_request.uri(), name) + } + PersistedOperationHttpGetParameterLocation::Path { position } => { + extract_path_position(ctx.downstream_http_request.uri(), *position) + } + } + } + + fn maybe_operation_name(&self, ctx: &FlowContext) -> Option { + debug!( + "trying to extract operationName from source {:?}", + self.operation_name_from + ); + + match &self.operation_name_from { + PersistedOperationHttpGetParameterLocation::Header { name } => { + extract_header(ctx.downstream_http_request.headers(), name) + } + PersistedOperationHttpGetParameterLocation::Query { name } => { + extract_query_param(ctx.downstream_http_request.uri(), name) + } + PersistedOperationHttpGetParameterLocation::Path { position } => { + extract_path_position(ctx.downstream_http_request.uri(), *position) + } + } + } +} + +#[async_trait::async_trait] +impl PersistedDocumentsProtocol for PersistedDocumentsGetHandler { + async fn try_extraction(&self, ctx: &mut FlowContext) -> Option { + if ctx.downstream_http_request.method() == http::Method::GET { + debug!("request http method is get, trying to extract from body..."); + + if let Some(op_id) = self.maybe_document_id(ctx) { + info!("succuessfully extracted incoming persisted operation from request",); + + return Some(ExtractedPersistedDocument { + hash: op_id, + variables: self + .maybe_variables(ctx) + .and_then(|v| serde_json::from_str(&v).ok()), + operation_name: self.maybe_operation_name(ctx), + extensions: None, + }); + } + } + + None + } + + fn should_prevent_execution(&self, ctx: &mut FlowContext) -> Option> { + if ctx.downstream_http_request.method() == Method::GET { + if let Some(gql_req) = &ctx.downstream_graphql_request { + if gql_req.is_running_mutation() { + debug!( + "trying to execute mutation from the persisted document, preventing because of GET request", + ); + + return Some( + GraphQLResponse::new_error("mutations are not allowed over GET") + .into_response(StatusCode::METHOD_NOT_ALLOWED), + ); + } + } + } + + None + } +} diff --git a/src/plugins/persisted_documents/protocols/mod.rs b/src/plugins/persisted_documents/protocols/mod.rs new file mode 100644 index 00000000..5363b2c5 --- /dev/null +++ b/src/plugins/persisted_documents/protocols/mod.rs @@ -0,0 +1,27 @@ +pub mod apollo_manifest; +pub mod document_id; +pub mod get_handler; + +use std::fmt::Debug; + +use axum::body::BoxBody; +use http::Response; +use serde_json::{Map, Value}; + +use crate::plugins::flow_context::FlowContext; + +#[derive(Debug)] +pub struct ExtractedPersistedDocument { + pub hash: String, + pub variables: Option>, + pub operation_name: Option, + pub extensions: Option>, +} + +#[async_trait::async_trait] +pub trait PersistedDocumentsProtocol: Sync + Send + Debug { + async fn try_extraction(&self, ctx: &mut FlowContext) -> Option; + fn should_prevent_execution(&self, _ctx: &mut FlowContext) -> Option> { + None + } +} diff --git a/src/plugins/persisted_documents/store/fs.rs b/src/plugins/persisted_documents/store/fs.rs new file mode 100644 index 00000000..a3679039 --- /dev/null +++ b/src/plugins/persisted_documents/store/fs.rs @@ -0,0 +1,189 @@ +use serde::Deserialize; +use std::collections::HashMap; +use tracing::{debug, info}; + +use crate::plugins::persisted_documents::config::ApolloPersistedQueryManifest; + +use super::PersistedDocumentsStore; + +#[derive(Debug)] +pub struct PersistedDocumentsFilesystemStore { + known_documents: HashMap, +} + +#[derive(Deserialize, Debug, Clone)] +pub enum PersistedDocumentsFileFormat { + #[serde(rename = "apollo_persisted_query_manifest")] + ApolloPersistedQueryManifest, + #[serde(rename = "json_key_value")] + JsonKeyValue, +} + +#[async_trait::async_trait] +impl PersistedDocumentsStore for PersistedDocumentsFilesystemStore { + async fn has_document(&self, hash: &str) -> bool { + self.known_documents.contains_key(hash) + } + + async fn get_document(&self, hash: &str) -> Option<&String> { + self.known_documents.get(hash) + } +} + +impl PersistedDocumentsFilesystemStore { + pub fn new_from_file_contents( + contents: &String, + file_format: &PersistedDocumentsFileFormat, + ) -> Result { + debug!( + "creating persisted operations store from a local FS file, the expected file format is: {:?}", + file_format + ); + + let result = match file_format { + PersistedDocumentsFileFormat::ApolloPersistedQueryManifest => { + let parsed = + serde_json::from_str::(contents.as_str())?; + + Self { + known_documents: parsed.operations.into_iter().fold( + HashMap::new(), + |mut acc, record| { + acc.insert(record.id, record.body); + acc + }, + ), + } + } + PersistedDocumentsFileFormat::JsonKeyValue => Self { + known_documents: serde_json::from_str(contents.as_str())?, + }, + }; + + info!( + "loaded persisted documents store from file, total records: {:?}", + result.known_documents.len() + ); + + Ok(result) + } +} + +#[tokio::test] +async fn fs_store_apollo_manifest_value() { + use serde_json::json; + + // valid JSON structure with empty array + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &json!({ + "format": "apollo", + "version": 1, + "operations": [] + }) + .to_string(), + &PersistedDocumentsFileFormat::ApolloPersistedQueryManifest, + ) + .expect("expected valid apollo manifest store") + .known_documents + .len(), + 0 + ); + + // valid store mapping + let store = PersistedDocumentsFilesystemStore::new_from_file_contents( + &json!({ + "format": "apollo", + "version": 1, + "operations": [ + { + "id": "key1", + "body": "query test { __typename }", + "name": "test", + "type": "query" + } + ] + }) + .to_string(), + &PersistedDocumentsFileFormat::ApolloPersistedQueryManifest, + ) + .expect("expected valid apollo manifest store"); + assert_eq!(store.known_documents.len(), 1); + assert_eq!(store.has_document("key1").await, true); + assert_eq!( + store.get_document("key1").await.cloned(), + Some("query test { __typename }".to_string()) + ); + + // Invalid JSON + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &"{".to_string(), + &PersistedDocumentsFileFormat::ApolloPersistedQueryManifest, + ) + .is_err(), + true + ); + + // invalid JSON structure + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &json!({}).to_string(), + &PersistedDocumentsFileFormat::ApolloPersistedQueryManifest, + ) + .is_err(), + true + ); +} + +#[tokio::test] +async fn fs_store_json_key_value() { + use serde_json::json; + + // Valid empty JSON map + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &json!({}).to_string(), + &PersistedDocumentsFileFormat::JsonKeyValue, + ) + .expect("failed to create store from json key value") + .known_documents + .len(), + 0 + ); + + // Valid JSON map + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &json!({ + "key1": "query { __typename }" + }) + .to_string(), + &PersistedDocumentsFileFormat::JsonKeyValue, + ) + .expect("failed to create store from json key value") + .known_documents + .len(), + 1 + ); + + // Invalid object structure + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &json!([]).to_string(), + &PersistedDocumentsFileFormat::JsonKeyValue, + ) + .is_err(), + true + ); + + // Invalid JSON + assert_eq!( + PersistedDocumentsFilesystemStore::new_from_file_contents( + &"{".to_string(), + &PersistedDocumentsFileFormat::JsonKeyValue, + ) + .is_err(), + true + ); +} diff --git a/src/plugins/persisted_documents/store/mod.rs b/src/plugins/persisted_documents/store/mod.rs new file mode 100644 index 00000000..16fc562c --- /dev/null +++ b/src/plugins/persisted_documents/store/mod.rs @@ -0,0 +1,7 @@ +pub mod fs; + +#[async_trait::async_trait] +pub trait PersistedDocumentsStore: Sync + Send { + async fn has_document(&self, hash: &str) -> bool; + async fn get_document(&self, hash: &str) -> Option<&String>; +} diff --git a/src/plugins/plugin_manager.rs b/src/plugins/plugin_manager.rs index 296015b2..3d459b61 100644 --- a/src/plugins/plugin_manager.rs +++ b/src/plugins/plugin_manager.rs @@ -9,6 +9,7 @@ use crate::{ use super::{ cors::CorsPlugin, flow_context::FlowContext, graphiql_plugin::GraphiQLPlugin, http_get_plugin::HttpGetPlugin, match_content_type::MatchContentTypePlugin, + persisted_documents::plugin::PersistedOperationsPlugin, }; #[derive(Debug, Default)] @@ -29,6 +30,10 @@ impl PluginManager { PluginDefinition::HttpGetPlugin(config) => { instance.register_plugin(HttpGetPlugin(config.clone())) } + PluginDefinition::PersistedOperationsPlugin(config) => instance.register_plugin( + PersistedOperationsPlugin::new_from_config(config.clone()) + .expect("failed to initalize persisted operations plugin"), + ), }); } @@ -107,14 +112,20 @@ impl PluginManager { } #[tracing::instrument(level = "trace")] - pub fn on_endpoint_creation<'a>(&self, router: Router<()>) -> Router<()> { - let p = &self.plugins; - let mut modified_router = router; + pub fn on_endpoint_creation<'a>( + &self, + root_router: Router<()>, + endpoint_router: Router<()>, + ) -> (Router<()>, Router<()>) { + let p: &Vec> = &self.plugins; + let mut modified_root_router = root_router; + let mut modified_endpoint_router = endpoint_router; for plugin in p.iter() { - modified_router = plugin.on_endpoint_creation(modified_router); + (modified_root_router, modified_endpoint_router) = + plugin.on_endpoint_creation(modified_root_router, modified_endpoint_router); } - modified_router + (modified_root_router, modified_endpoint_router) } } diff --git a/src/source/base_source.rs b/src/source/base_source.rs index d1d3f152..bbacc9b8 100644 --- a/src/source/base_source.rs +++ b/src/source/base_source.rs @@ -1,3 +1,4 @@ +use std::fmt::Display; use std::pin::Pin; use axum::response::Result; @@ -29,6 +30,18 @@ pub enum SourceError { InvalidPlannedRequest(hyper::http::Error), } +impl Display for SourceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SourceError::UnexpectedHTTPStatusError(status) => { + write!(f, "Unexpected HTTP status: {}", status) + } + SourceError::NetworkError(e) => write!(f, "Network error: {}", e), + SourceError::InvalidPlannedRequest(e) => write!(f, "Invalid planned request: {}", e), + } + } +} + impl GraphQLRequest { pub async fn into_hyper_request( &self, diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 00000000..4fe9e752 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod serde_utils; diff --git a/src/utils/serde_utils.rs b/src/utils/serde_utils.rs new file mode 100644 index 00000000..1b3b3e7d --- /dev/null +++ b/src/utils/serde_utils.rs @@ -0,0 +1,43 @@ +use std::{fmt, path::Path}; + +use serde::{de::Visitor, Deserialize}; +use std::fs::read_to_string; +use tracing::debug; + +struct LocalFileReferenceVisitor {} + +impl<'de> Visitor<'de> for LocalFileReferenceVisitor { + type Value = LocalFileReference; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("expected a valid local file path") + } + + fn visit_str(self, file_path: &str) -> Result + where + E: serde::de::Error, + { + debug!("loading local file reference from path {:?}", file_path); + let contents = read_to_string(Path::new(file_path)).expect("Failed to read file"); + + Ok(LocalFileReference { + path: file_path.to_string(), + contents, + }) + } +} + +#[derive(Debug, Clone)] +pub struct LocalFileReference { + pub path: String, + pub contents: String, +} + +impl<'de> Deserialize<'de> for LocalFileReference { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(LocalFileReferenceVisitor {}) + } +} diff --git a/temp/config.yaml b/temp/config.yaml index a8143cec..ed9d6220 100644 --- a/temp/config.yaml +++ b/temp/config.yaml @@ -15,6 +15,26 @@ endpoints: - path: /graphql from: countries + - path: /persisted + from: countries + plugins: + - type: persisted_operations + store: + source: file + path: ./temp/persisted_operations_store.json + format: json_key_value + protocols: + - type: apollo_manifest_extensions + - type: document_id + - type: http_get + document_id_from: + source: search_query + name: docId + variables_from: + source: header + name: "X-GraphQL-Variables" + allow_non_persisted: true + - path: /test from: countries plugins: