From dfb67f3c8de9236733f3c86460517b6a01b49992 Mon Sep 17 00:00:00 2001 From: Oscar Beaumont Date: Tue, 31 Oct 2023 23:15:45 +0800 Subject: [PATCH] wip --- crates/core/src/body2/cursed.rs | 55 +++++++++++ crates/core/src/body2/mod.rs | 19 +++- crates/core/src/lib.rs | 3 +- crates/core/src/middleware/mw_ctx.rs | 8 +- crates/core/src/middleware/mw_result.rs | 23 ++++- examples/axum/src/main.rs | 14 ++- src/internal/middleware/middleware_layer.rs | 104 +++++++------------- src/internal/middleware/mw.rs | 11 +-- src/lib.rs | 2 + 9 files changed, 151 insertions(+), 88 deletions(-) create mode 100644 crates/core/src/body2/cursed.rs diff --git a/crates/core/src/body2/cursed.rs b/crates/core/src/body2/cursed.rs new file mode 100644 index 00000000..6c9649f4 --- /dev/null +++ b/crates/core/src/body2/cursed.rs @@ -0,0 +1,55 @@ +use std::{ + cell::Cell, + future::poll_fn, + task::{Poll, Waker}, +}; + +use crate::Body; + +// TODO: Make this private +pub enum YieldMsg { + YieldBody, + YieldBodyResult(serde_json::Value), +} + +thread_local! { + // TODO: Make this private + pub static CURSED_OP: Cell> = const { Cell::new(None) }; +} + +// TODO: Make private +pub async fn inner() -> Body { + let mut state = false; + poll_fn(|_| match state { + false => { + CURSED_OP.set(Some(YieldMsg::YieldBody)); + state = true; + return Poll::Pending; + } + true => { + let y = CURSED_OP + .take() + .expect("Expected response from outer future!"); + return Poll::Ready(match y { + YieldMsg::YieldBody => unreachable!(), + YieldMsg::YieldBodyResult(body) => Body::Value(body), + }); + } + }) + .await +} + +// TODO: Use this instead +// // Called on `Poll::Pending` from inner +// pub fn outer(waker: &Waker) { +// if let Some(op) = CURSED_OP.take() { +// match op { +// YieldMsg::YieldBody => { +// // TODO: Get proper value +// CURSED_OP.set(Some(YieldMsg::YieldBodyResult(serde_json::Value::Null))); +// waker.wake_by_ref(); +// } +// YieldMsg::YieldBodyResult(_) => unreachable!(), +// } +// } +// } diff --git a/crates/core/src/body2/mod.rs b/crates/core/src/body2/mod.rs index 519f2aa9..14108db4 100644 --- a/crates/core/src/body2/mod.rs +++ b/crates/core/src/body2/mod.rs @@ -4,6 +4,8 @@ use futures::Stream; use crate::error::ExecError; +pub mod cursed; + // It is expected that the type remains the same for all items of a single stream! It's ok for panic's if this is violated. // // TODO: Can this be `pub(crate)`??? -> Right now `Layer` is the problem @@ -15,8 +17,6 @@ pub enum ValueOrBytes { pub(crate) type StreamItem = Result; pub(crate) type ErasedBody = Pin + Send>>; -// pub(crate) type ErasedBody = BodyInternal> + Send>>>; - #[derive(Debug)] #[non_exhaustive] pub enum Body { @@ -31,5 +31,20 @@ pub enum Body { #[derive(Debug)] // TODO: Better debug impl pub struct StreamBody {} +impl Stream for StreamBody { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + todo!() + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } +} + #[derive(Debug)] // TODO: Better debug impl pub struct BytesBody {} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 2c1a3705..0989840f 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -21,6 +21,7 @@ mod router; mod router_builder; mod util; +pub use body2::cursed; // TODO: Hide `cursed` under `internal` pub use body2::{Body, BytesBody, StreamBody, ValueOrBytes}; // TODO: Reduce API surface in this?? @@ -50,7 +51,7 @@ pub mod internal { pub use super::procedure_store::{build, ProcedureDef, ProcedureTodo, ProceduresDef}; pub use super::middleware::{ - new_mw_ctx, MiddlewareContext, MwV2Result, ProcedureKind, RequestContext, + new_mw_ctx, IntoMiddlewareResult, MiddlewareContext, ProcedureKind, RequestContext, }; pub use super::router_builder::{ diff --git a/crates/core/src/middleware/mw_ctx.rs b/crates/core/src/middleware/mw_ctx.rs index 82e1c35e..04ff93fb 100644 --- a/crates/core/src/middleware/mw_ctx.rs +++ b/crates/core/src/middleware/mw_ctx.rs @@ -5,7 +5,7 @@ use std::{ use serde_json::Value; -use crate::Body; +use crate::{cursed, Body}; pub fn new_mw_ctx( input: serde_json::Value, @@ -27,6 +27,7 @@ pub struct MiddlewareContext { // For response new_ctx: Arc>>, + // chan: futures::mpsc::Sender, // new_span: Option } @@ -38,14 +39,13 @@ impl MiddlewareContext { } // TODO: Refactor return type - pub fn next(self, ctx: TNewCtx) -> Body { + pub async fn next(self, ctx: TNewCtx) -> Body { self.new_ctx .lock() .unwrap_or_else(PoisonError::into_inner) .replace(ctx); - // TODO - Body::Value(serde_json::Value::Null) + cursed::inner().await } } diff --git a/crates/core/src/middleware/mw_result.rs b/crates/core/src/middleware/mw_result.rs index 7132ad6f..437d8014 100644 --- a/crates/core/src/middleware/mw_result.rs +++ b/crates/core/src/middleware/mw_result.rs @@ -1,4 +1,21 @@ -// TODO: Rename -pub trait MwV2Result {} +use serde_json::Value; -impl MwV2Result for () {} +use crate::error::ExecError; + +/// TODO +pub trait IntoMiddlewareResult { + // TODO: Support streams and bytes + fn into_result(self) -> Result; +} + +impl IntoMiddlewareResult for () { + fn into_result(self) -> Result { + Ok(Value::Null) + } +} + +impl IntoMiddlewareResult for Value { + fn into_result(self) -> Result { + Ok(self) + } +} diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs index 281fcc2b..21489c13 100644 --- a/examples/axum/src/main.rs +++ b/examples/axum/src/main.rs @@ -10,7 +10,7 @@ use std::{ use async_stream::stream; use axum::routing::get; use futures::{Stream, StreamExt}; -use rspc::{ExportConfig, Rspc}; +use rspc::{Body, ExportConfig, Rspc}; use serde::Serialize; use specta::Type; use tokio::{sync::broadcast, time::sleep}; @@ -45,9 +45,17 @@ async fn main() { R // TODO: Old cringe syntax .with(|mw, ctx| async move { - let y = mw.next(((), ctx)); + // Some processing - () + let y = mw.next(((), ctx)).await; + + println!("{:?}", y); + + match y { + Body::Value(v) => v, + Body::Stream(v) => todo!(), + _ => todo!(), + } }) // Passthrough // .with(|mw, ctx| async move { mw.next::(ctx)? }) diff --git a/src/internal/middleware/middleware_layer.rs b/src/internal/middleware/middleware_layer.rs index af246d6e..2dbe2853 100644 --- a/src/internal/middleware/middleware_layer.rs +++ b/src/internal/middleware/middleware_layer.rs @@ -13,8 +13,11 @@ mod private { use specta::{ts, TypeMap}; use rspc_core::{ + cursed::{self, YieldMsg, CURSED_OP}, error::ExecError, - internal::{new_mw_ctx, Layer, ProcedureDef, RequestContext}, + internal::{ + new_mw_ctx, IntoMiddlewareResult, Layer, PinnedOption, ProcedureDef, RequestContext, + }, ValueOrBytes, }; @@ -53,7 +56,7 @@ mod private { req: RequestContext, ) -> Result, ExecError> { let new_ctx = Arc::new(Mutex::new(None)); - let fut = self.mw.run_me( + let fut = self.mw.execute( ctx, new_mw_ctx( input.clone(), // TODO: This probs won't fly if we accept file upload @@ -68,6 +71,8 @@ mod private { new_ctx, input: Some(input), req: Some(req), + stream: PinnedOption::None, + is_stream_done: false, }) } } @@ -103,22 +108,12 @@ mod private { // TODO: Avoid `Option` and instead encode into enum input: Option, req: Option, - }, - // We are in this state where we are executing the current middleware on the stream - Execute { + // The actual data stream from the resolver function or next middleware #[pin] - stream: TNextLayer::Stream<'a>, + stream: PinnedOption>, // We use this so we can keep polling `resp_fut` for the final message and once it is done and this bool is set, shutdown. is_stream_done: bool, - - // The currently executing future returned by the `resp_fn` (publicly `.map`) function - // Be aware this will go `None` -> `Some` -> `None`, etc for a subscription - // #[pin] - // resp_fut: PinnedOption<<::Resp as Executable2>::Fut>, - // The `.map` function returned by the user from the execution of the current middleware - // This allows a middleware to map the values being returned from the stream - // resp_fn: ::Resp, }, // The stream is internally done but it returned `Poll::Ready` for the shutdown message so the caller thinks it's still active // This will yield `Poll::Ready(None)` and transition into the `Self::Done` phase. @@ -147,65 +142,36 @@ mod private { new_ctx, input, req, + stream, + is_stream_done, } => { - let result = ready!(fut.poll(cx)); - - let ctx = new_ctx - .lock() - .unwrap_or_else(PoisonError::into_inner) - .take() - .unwrap(); - - match next.call(ctx, input.take().unwrap(), req.take().unwrap()) { - Ok(stream) => { - self.as_mut().set(Self::Execute { - stream, - is_stream_done: false, - // resp_fut: PinnedOption::None, - // resp_fn: None, // TODO: Fully remove this - }); - } - - Err(err) => { + let result = match fut.poll(cx) { + Poll::Ready(result) => { self.as_mut().set(Self::PendingDone); - return Poll::Ready(Some(Err(err))); + return Poll::Ready(Some( + result.into_result().map(ValueOrBytes::Value), + )); } - } - } - MiddlewareLayerFutureProj::Execute { - mut stream, - is_stream_done, - } => { - // if let PinnedOptionProj::Some { v } = resp_fut.as_mut().project() { - // let result = ready!(v.poll(cx)); - // cx.waker().wake_by_ref(); // No wakers set so we set one - // resp_fut.set(PinnedOption::None); - // return Poll::Ready(Some(Ok(result))); - // } - - if *is_stream_done { - self.as_mut().set(Self::Done); - return Poll::Ready(None); - } - - match ready!(stream.as_mut().poll_next(cx)) { - Some(result) => match result { - Ok(result) => { - return Poll::Ready(Some(Ok(result))); + Poll::Pending => { + // cursed::outer(cx.waker()); + + if let Some(op) = CURSED_OP.take() { + match op { + YieldMsg::YieldBody => { + // TODO: Get proper value + CURSED_OP.set(Some(YieldMsg::YieldBodyResult( + serde_json::Value::Null, + ))); + + cx.waker().wake_by_ref(); + } + YieldMsg::YieldBodyResult(_) => unreachable!(), + } } - // TODO: The `.map` function is skipped for errors. Maybe it should be possible to map them when desired? - // TODO: We also shut down the whole stream on a single error. Is this desired? - Err(err) => { - self.as_mut().set(Self::PendingDone); - return Poll::Ready(Some(Err(err))); - } - }, - // The underlying stream has shutdown so we will resolve `resp_fut` and then terminate ourselves - None => { - *is_stream_done = true; - continue; + + return Poll::Pending; } - } + }; } MiddlewareLayerFutureProj::PendingDone => { self.as_mut().set(Self::Done); @@ -225,7 +191,7 @@ mod private { fn size_hint(&self) -> (usize, Option) { match &self { - Self::Execute { stream: c, .. } => c.size_hint(), + // Self::Execute { stream: c, .. } => c.size_hint(), // TODO: Bring this back _ => (0, None), } } diff --git a/src/internal/middleware/mw.rs b/src/internal/middleware/mw.rs index 72573ea4..56a61765 100644 --- a/src/internal/middleware/mw.rs +++ b/src/internal/middleware/mw.rs @@ -1,6 +1,6 @@ use std::future::Future; -use rspc_core::internal::{MiddlewareContext, MwV2Result}; +use rspc_core::internal::{IntoMiddlewareResult, MiddlewareContext}; // `TNewCtx` is sadly require to constain the impl at the bottom of this file. If you can remove it your a god. pub trait MiddlewareFn: @@ -9,10 +9,9 @@ where TLCtx: Send + Sync + 'static, { type Fut: Future + Send + 'static; - type Result: MwV2Result; + type Result: IntoMiddlewareResult; - // TODO: Rename - fn run_me(&self, ctx: TLCtx, mw: MiddlewareContext) -> Self::Fut; + fn execute(&self, ctx: TLCtx, mw: MiddlewareContext) -> Self::Fut; } impl MiddlewareFn for F @@ -20,12 +19,12 @@ where TLCtx: Send + Sync + 'static, F: Fn(MiddlewareContext, TLCtx) -> Fu + Send + Sync + 'static, Fu: Future + Send + 'static, - Fu::Output: MwV2Result + Send + 'static, + Fu::Output: IntoMiddlewareResult + Send + 'static, { type Fut = Fu; type Result = Fu::Output; - fn run_me(&self, ctx: TLCtx, mw: MiddlewareContext) -> Self::Fut { + fn execute(&self, ctx: TLCtx, mw: MiddlewareContext) -> Self::Fut { self(mw, ctx) } } diff --git a/src/lib.rs b/src/lib.rs index 30effc72..82c19ee9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,3 +27,5 @@ pub mod internal; // TODO: Only reexport certain types pub use rspc_core::error::*; + +pub use rspc_core::Body;