From 5abc0e821c9c3653f92c88b920aa947e2f45f65c Mon Sep 17 00:00:00 2001 From: Oscar Beaumont Date: Wed, 17 Jul 2024 15:45:31 +0800 Subject: [PATCH] OpenAPI working --- examples/axum/Cargo.toml | 3 +- examples/axum/src/api.rs | 6 +- examples/axum/src/api/chat.rs | 6 +- examples/axum/src/api/store.rs | 25 +++ examples/axum/src/main.rs | 17 +- middleware/openapi/Cargo.toml | 3 + middleware/openapi/src/lib.rs | 245 +++++++++++++++++++++++++--- middleware/openapi/src/swagger.html | 34 ++++ middleware/tracing/src/lib.rs | 4 +- rspc/src/procedure/procedure.rs | 3 +- rspc/src/router.rs | 6 +- 11 files changed, 306 insertions(+), 46 deletions(-) create mode 100644 examples/axum/src/api/store.rs create mode 100644 middleware/openapi/src/swagger.html diff --git a/examples/axum/Cargo.toml b/examples/axum/Cargo.toml index 33e8673e..71a3d58d 100644 --- a/examples/axum/Cargo.toml +++ b/examples/axum/Cargo.toml @@ -13,7 +13,8 @@ thiserror = "1.0.62" async-stream = "0.3.5" tracing = "0.1.40" tracing-subscriber = "0.3.18" -rspc-tracing = { version = "0.0.0", path = "../../middleware/tracing" } # TODO: Remove? +rspc-tracing = { version = "0.0.0", path = "../../middleware/tracing" } +rspc-openapi = { version = "0.0.0", path = "../../middleware/openapi" } serde = { version = "1", features = ["derive"] } specta = { version = "=2.0.0-rc.15", features = ["derive"] } # TODO: Drop requirement on `derive` specta-util = "0.0.2" # TODO: We need this for `TypeCollection` which is cringe diff --git a/examples/axum/src/api.rs b/examples/axum/src/api.rs index 8d13dbaa..9d84bb26 100644 --- a/examples/axum/src/api.rs +++ b/examples/axum/src/api.rs @@ -8,13 +8,14 @@ use specta_typescript::Typescript; use specta_util::TypeCollection; use thiserror::Error; -mod chat; +pub(crate) mod chat; +pub(crate) mod store; #[derive(Debug, Error)] pub enum Error {} // `Clone` is only required for usage with Websockets -#[derive(Default, Clone)] +#[derive(Clone)] pub struct Context { pub chat: chat::Ctx, } @@ -40,6 +41,7 @@ pub fn mount() -> Router { ::builder().query(|_, _: ()| async { Ok(env!("CARGO_PKG_VERSION")) }) }) .merge("chat", chat::mount()) + .merge("store", store::mount()) // TODO: I dislike this API .ext({ let mut types = TypeCollection::default(); diff --git a/examples/axum/src/api/chat.rs b/examples/axum/src/api/chat.rs index 860bc9f5..8b0b3cc8 100644 --- a/examples/axum/src/api/chat.rs +++ b/examples/axum/src/api/chat.rs @@ -15,11 +15,11 @@ pub struct Ctx { chat: broadcast::Sender, } -impl Default for Ctx { - fn default() -> Self { +impl Ctx { + pub fn new(chat: broadcast::Sender) -> Self { Self { author: Arc::new(Mutex::new("Anonymous".into())), - chat: broadcast::channel(100).0, + chat, } } } diff --git a/examples/axum/src/api/store.rs b/examples/axum/src/api/store.rs new file mode 100644 index 00000000..3a8584de --- /dev/null +++ b/examples/axum/src/api/store.rs @@ -0,0 +1,25 @@ +use rspc_openapi::OpenAPI; + +use super::{BaseProcedure, Router}; + +pub fn mount() -> Router { + Router::new() + .procedure("get", { + ::builder() + .with(OpenAPI::get("/api/get").build()) + .mutation(|ctx, _: ()| async move { + // TODO + + Ok("Hello From rspc!") + }) + }) + .procedure("set", { + ::builder() + .with(OpenAPI::post("/api/set").build()) + .mutation(|ctx, value: String| async move { + // TODO + + Ok(()) + }) + }) +} diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs index 50d009dc..85c5ac6c 100644 --- a/examples/axum/src/main.rs +++ b/examples/axum/src/main.rs @@ -1,6 +1,7 @@ use std::net::Ipv6Addr; use axum::{routing::get, Router}; +use tokio::sync::broadcast; use tracing::info; mod api; @@ -11,18 +12,18 @@ async fn main() { let router = api::mount().build().unwrap(); + let chat_tx = broadcast::channel(100).0; + let ctx_fn = move || api::Context { + chat: api::chat::Ctx::new(chat_tx.clone()), + }; + let app = Router::new() .route("/", get(|| async { "Hello, World!" })) .nest( "/rspc", - rspc_axum::Endpoint::new(router, || api::Context { - chat: Default::default(), - }) - .with_endpoints() - .with_websocket() - .with_batching() - .build(), - ); + rspc_axum::Endpoint::new(router.clone(), ctx_fn.clone()), + ) + .nest("/", rspc_openapi::mount(router, ctx_fn)); info!("Listening on http://[::1]:3000"); let listener = tokio::net::TcpListener::bind((Ipv6Addr::UNSPECIFIED, 3000)) diff --git a/middleware/openapi/Cargo.toml b/middleware/openapi/Cargo.toml index faa338fc..bfd03069 100644 --- a/middleware/openapi/Cargo.toml +++ b/middleware/openapi/Cargo.toml @@ -6,6 +6,9 @@ publish = false # TODO: Crate metadata & publish [dependencies] rspc = { path = "../../rspc" } +axum = { version = "0.7.5", default-features = false } +serde_json = "1.0.120" +futures = "0.3.30" # /bin/sh RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features [package.metadata."docs.rs"] diff --git a/middleware/openapi/src/lib.rs b/middleware/openapi/src/lib.rs index 8f3173ac..d6ab5334 100644 --- a/middleware/openapi/src/lib.rs +++ b/middleware/openapi/src/lib.rs @@ -5,37 +5,230 @@ html_favicon_url = "https://github.com/oscartbeaumont/rspc/raw/main/docs/public/logo.png" )] -use std::{borrow::Cow, collections::HashMap}; +use std::{borrow::Cow, collections::HashMap, hash::Hash, sync::Arc}; -use rspc::middleware::Middleware; +use axum::{ + body::Bytes, + extract::Query, + http::StatusCode, + response::Html, + routing::{get, post}, + Json, +}; +use futures::StreamExt; +use rspc::{ + middleware::Middleware, + procedure::{Procedure, ProcedureInput}, + BuiltRouter, +}; +use serde_json::json; -#[derive(Default)] -pub struct OpenAPIState(HashMap, ()>); +// TODO: Properly handle inputs from query params +// TODO: Properly handle responses from query params +// TODO: Support input's coming from URL. Eg. `/todos/{id}` like tRPC-OpenAPI +// TODO: Support `application/x-www-form-urlencoded` bodies like tRPC-OpenAPI +// TODO: Probs put SwaggerUI behind a feature flag + +pub struct OpenAPI { + method: &'static str, + path: Cow<'static, str>, +} + +impl OpenAPI { + // TODO + // pub fn new(method: Method, path: impl Into>) {} + + pub fn get(path: impl Into>) -> Self { + Self { + method: "GET", + path: path.into(), + } + } + + pub fn post(path: impl Into>) -> Self { + Self { + method: "GET", + path: path.into(), + } + } + + pub fn put(path: impl Into>) -> Self { + Self { + method: "GET", + path: path.into(), + } + } + + pub fn patch(path: impl Into>) -> Self { + Self { + method: "GET", + path: path.into(), + } + } + + pub fn delete(path: impl Into>) -> Self { + Self { + method: "GET", + path: path.into(), + } + } + + // TODO: Configure other OpenAPI stuff like auth??? + + pub fn build( + self, + ) -> Middleware + where + TError: 'static, + TThisCtx: Send + 'static, + TThisInput: Send + 'static, + TThisResult: Send + 'static, + { + // TODO: Can we have a middleware with only a `setup` function to avoid the extra future boxing??? + Middleware::new(|ctx, input, next| async move { next.exec(ctx, input).await }).setup( + move |state, meta| { + state + .get_mut_or_init::(Default::default) + .0 + .insert((self.method, self.path), meta.name().to_string()); + }, + ) + } +} -// TODO: Configure other OpenAPI stuff like auth +// The state that is stored into rspc. +// A map of (method, path) to procedure name. +#[derive(Default)] +struct OpenAPIState(HashMap<(&'static str, Cow<'static, str>), String>); -// TODO: Make convert this into a builder like: Endpoint::get("/todo").some_other_stuff().build() -pub fn openapi( - // method: Method, - path: impl Into>, -) -> Middleware +// TODO: Axum should be behind feature flag +// TODO: Can we decouple webserver from OpenAPI while keeping something maintainable???? +pub fn mount( + router: BuiltRouter, + ctx_fn: impl Fn() -> TCtx + Clone + Send + Sync + 'static, +) -> axum::Router where - TError: 'static, - TThisCtx: Send + 'static, - TThisInput: Send + 'static, - TThisResult: Send + 'static, + S: Clone + Send + Sync + 'static, + TCtx: Send + 'static, { - let path = path.into(); - Middleware::new(|ctx, input, next| async move { - let _result = next.exec(ctx, input).await; - _result - }) - .setup(|state, meta| { - state - .get_mut_or_init::(Default::default) - .0 - .insert(path, ()); - }) + let mut r = axum::Router::new(); + + let mut paths: HashMap<_, HashMap<_, _>> = HashMap::new(); + if let Some(endpoints) = router.state.get::() { + for ((method, path), procedure_name) in endpoints.0.iter() { + let procedure = router + .procedures + .get(&Cow::Owned(procedure_name.clone())) + .expect("unreachable: a procedure was registered that doesn't exist") + .clone(); + let ctx_fn = ctx_fn.clone(); + + paths + .entry(path.clone()) + .or_default() + .insert(method.to_lowercase(), procedure.clone()); + + r = r.route( + path, + match *method { + "GET" => { + // TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable. + get(move |query: Query>| async move { + let ctx = (ctx_fn)(); + + handle_procedure( + ctx, + &mut serde_json::Deserializer::from_str( + query.get("input").map(|v| &**v).unwrap_or("null"), + ), + procedure, + ) + .await + }) + } + "POST" => { + // TODO: By moving `procedure` into the closure we hang onto the types for the duration of the program which is probs undesirable. + post(move |body: Bytes| async move { + let ctx = (ctx_fn)(); + + handle_procedure( + ctx, + &mut serde_json::Deserializer::from_slice(&body), + procedure, + ) + .await + }) + } + // "PUT" => axum::routing::put, + // "PATCH" => axum::routing::patch, + // "DELETE" => axum::routing::delete, + _ => panic!("Unsupported method"), + }, + ); + } + } + + let schema = Arc::new(json!({ + "openapi": "3.0.3", + "info": { + "title": "rspc OpenAPI", + "description": "This is a demo of rspc OpenAPI", + "version": "0.0.0" + }, + "paths": paths.into_iter() + .map(|(path, procedures)| { + let mut methods = HashMap::new(); + for (method, procedure) in procedures { + methods.insert(method.to_string(), json!({ + "operationId": procedure.ty().key.to_string(), + "responses": { + "200": { + "description": "Successful operation" + } + } + })); + } + + (path, methods) + }) + .collect::>() + })); // TODO: Maybe convert to string now cause it will be more efficient to clone + + r.route( + // TODO: Allow the user to configure this URL & turn it off + "/api/docs", + get(|| async { Html(include_str!("swagger.html")) }), + ) + .route( + // TODO: Allow the user to configure this URL & turn it off + "/api/openapi.json", + get(move || async move { Json((*schema).clone()) }), + ) } -// TODO: Convert into API endpoint +// Used for `GET` and `POST` endpoints +async fn handle_procedure<'de, TCtx>( + ctx: TCtx, + input: impl ProcedureInput<'de>, + procedure: Procedure, +) -> Result, (StatusCode, Json)> { + let mut stream = procedure.exec(ctx, input).map_err(|err| { + // TODO: Error code by matching off `InternalError` + (StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string())) + })?; + + // TODO: Support for streaming + while let Some(value) = stream.next().await { + // TODO: We should probs deserialize into buffer instead of value??? + return match value.map(|v| v.serialize(serde_json::value::Serializer)) { + Ok(Ok(value)) => Ok(Json(value)), + Ok(Err(err)) => { + // TODO: Error code by matching off `InternalError` + Err((StatusCode::INTERNAL_SERVER_ERROR, Json(err.to_string()))) + } + Err(err) => panic!("{err:?}"), // TODO: Error handling -> How to serialize `TError`??? -> Should this be done in procedure? + }; + } + + Ok(Json(serde_json::Value::Null)) +} diff --git a/middleware/openapi/src/swagger.html b/middleware/openapi/src/swagger.html new file mode 100644 index 00000000..113b1cc5 --- /dev/null +++ b/middleware/openapi/src/swagger.html @@ -0,0 +1,34 @@ + + + + + + + SwaggerUI + + + +
+ + + + + diff --git a/middleware/tracing/src/lib.rs b/middleware/tracing/src/lib.rs index 7b315f1a..47181764 100644 --- a/middleware/tracing/src/lib.rs +++ b/middleware/tracing/src/lib.rs @@ -38,9 +38,7 @@ where let start = std::time::Instant::now(); let result = next.exec(ctx, input).await; info!( - "{} {} took {:?} with input {input_str:?} and returned {:?}", - next.meta().kind().to_string().to_uppercase(), // TODO: Maybe adding color? - next.meta().name(), + "took {:?} with input {input_str:?} and returned {:?}", start.elapsed(), DebugWrapper(&result, PhantomData::) ); diff --git a/rspc/src/procedure/procedure.rs b/rspc/src/procedure/procedure.rs index b99549e4..2e9ae8d1 100644 --- a/rspc/src/procedure/procedure.rs +++ b/rspc/src/procedure/procedure.rs @@ -107,7 +107,7 @@ impl Procedure { /// ```rust /// todo!(); # TODO: Example /// ``` - pub fn types(&self) -> &ProcedureTypeDefinition { + pub fn ty(&self) -> &ProcedureTypeDefinition { &self.ty } @@ -147,6 +147,7 @@ impl Procedure { #[derive(Debug, Clone, PartialEq)] pub struct ProcedureTypeDefinition { + // TODO: Should `key` move onto `Procedure` instead?s pub key: Cow<'static, str>, pub kind: ProcedureKind, pub input: DataType, diff --git a/rspc/src/router.rs b/rspc/src/router.rs index 79eb5fc8..1803392a 100644 --- a/rspc/src/router.rs +++ b/rspc/src/router.rs @@ -3,6 +3,7 @@ use std::{ collections::BTreeMap, fmt, path::{Path, PathBuf}, + sync::Arc, }; use specta::{Language, TypeMap}; @@ -127,15 +128,16 @@ impl Router { } Ok(BuiltRouter { - state, + state: Arc::new(state), types, procedures, }) } } +#[derive(Clone)] pub struct BuiltRouter { - pub state: State, + pub state: Arc, pub types: TypeMap, pub procedures: BTreeMap, Procedure>, }