Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
oscartbeaumont committed Oct 31, 2023
1 parent 0961e29 commit dfb67f3
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 88 deletions.
55 changes: 55 additions & 0 deletions crates/core/src/body2/cursed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use std::{
cell::Cell,
future::poll_fn,
task::{Poll, Waker},
};

use crate::Body;

// TODO: Make this private
pub enum YieldMsg {
YieldBody,
YieldBodyResult(serde_json::Value),
}

thread_local! {
// TODO: Make this private
pub static CURSED_OP: Cell<Option<YieldMsg>> = const { Cell::new(None) };
}

// TODO: Make private
pub async fn inner() -> Body {
let mut state = false;
poll_fn(|_| match state {
false => {
CURSED_OP.set(Some(YieldMsg::YieldBody));
state = true;
return Poll::Pending;
}
true => {
let y = CURSED_OP
.take()
.expect("Expected response from outer future!");
return Poll::Ready(match y {
YieldMsg::YieldBody => unreachable!(),
YieldMsg::YieldBodyResult(body) => Body::Value(body),
});
}
})
.await
}

// TODO: Use this instead
// // Called on `Poll::Pending` from inner
// pub fn outer(waker: &Waker) {
// if let Some(op) = CURSED_OP.take() {
// match op {
// YieldMsg::YieldBody => {
// // TODO: Get proper value
// CURSED_OP.set(Some(YieldMsg::YieldBodyResult(serde_json::Value::Null)));
// waker.wake_by_ref();
// }
// YieldMsg::YieldBodyResult(_) => unreachable!(),
// }
// }
// }
19 changes: 17 additions & 2 deletions crates/core/src/body2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use futures::Stream;

use crate::error::ExecError;

pub mod cursed;

// It is expected that the type remains the same for all items of a single stream! It's ok for panic's if this is violated.
//
// TODO: Can this be `pub(crate)`??? -> Right now `Layer` is the problem
Expand All @@ -15,8 +17,6 @@ pub enum ValueOrBytes {
pub(crate) type StreamItem = Result<ValueOrBytes, ExecError>;
pub(crate) type ErasedBody = Pin<Box<dyn Stream<Item = StreamItem> + Send>>;

// pub(crate) type ErasedBody = BodyInternal<Pin<Box<dyn Stream<Item = Result<Value, ExecError>> + Send>>>;

#[derive(Debug)]
#[non_exhaustive]
pub enum Body {
Expand All @@ -31,5 +31,20 @@ pub enum Body {
#[derive(Debug)] // TODO: Better debug impl
pub struct StreamBody {}

impl Stream for StreamBody {
type Item = Result<Value, ExecError>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
todo!()
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, None)
}
}

#[derive(Debug)] // TODO: Better debug impl
pub struct BytesBody {}
3 changes: 2 additions & 1 deletion crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod router;
mod router_builder;
mod util;

pub use body2::cursed; // TODO: Hide `cursed` under `internal`
pub use body2::{Body, BytesBody, StreamBody, ValueOrBytes};

// TODO: Reduce API surface in this??
Expand Down Expand Up @@ -50,7 +51,7 @@ pub mod internal {
pub use super::procedure_store::{build, ProcedureDef, ProcedureTodo, ProceduresDef};

pub use super::middleware::{
new_mw_ctx, MiddlewareContext, MwV2Result, ProcedureKind, RequestContext,
new_mw_ctx, IntoMiddlewareResult, MiddlewareContext, ProcedureKind, RequestContext,
};

pub use super::router_builder::{
Expand Down
8 changes: 4 additions & 4 deletions crates/core/src/middleware/mw_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use serde_json::Value;

use crate::Body;
use crate::{cursed, Body};

pub fn new_mw_ctx<TNCtx>(
input: serde_json::Value,
Expand All @@ -27,6 +27,7 @@ pub struct MiddlewareContext<TNewCtx> {

// For response
new_ctx: Arc<Mutex<Option<TNewCtx>>>,
// chan: futures::mpsc::Sender<Body>,
// new_span: Option<tracing::Span>
}

Expand All @@ -38,14 +39,13 @@ impl<TNewCtx> MiddlewareContext<TNewCtx> {
}

// TODO: Refactor return type
pub fn next(self, ctx: TNewCtx) -> Body {
pub async fn next(self, ctx: TNewCtx) -> Body {
self.new_ctx
.lock()
.unwrap_or_else(PoisonError::into_inner)
.replace(ctx);

// TODO
Body::Value(serde_json::Value::Null)
cursed::inner().await
}
}

Expand Down
23 changes: 20 additions & 3 deletions crates/core/src/middleware/mw_result.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
// TODO: Rename
pub trait MwV2Result {}
use serde_json::Value;

impl MwV2Result for () {}
use crate::error::ExecError;

/// TODO
pub trait IntoMiddlewareResult {
// TODO: Support streams and bytes
fn into_result(self) -> Result<Value, ExecError>;
}

impl IntoMiddlewareResult for () {
fn into_result(self) -> Result<Value, ExecError> {
Ok(Value::Null)
}
}

impl IntoMiddlewareResult for Value {
fn into_result(self) -> Result<Value, ExecError> {
Ok(self)
}
}
14 changes: 11 additions & 3 deletions examples/axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
use async_stream::stream;
use axum::routing::get;
use futures::{Stream, StreamExt};
use rspc::{ExportConfig, Rspc};
use rspc::{Body, ExportConfig, Rspc};
use serde::Serialize;
use specta::Type;
use tokio::{sync::broadcast, time::sleep};
Expand Down Expand Up @@ -45,9 +45,17 @@ async fn main() {
R
// TODO: Old cringe syntax
.with(|mw, ctx| async move {
let y = mw.next(((), ctx));
// Some processing

()
let y = mw.next(((), ctx)).await;

println!("{:?}", y);

match y {
Body::Value(v) => v,
Body::Stream(v) => todo!(),
_ => todo!(),
}
})
// Passthrough
// .with(|mw, ctx| async move { mw.next::<middleware::Any, _>(ctx)? })
Expand Down
104 changes: 35 additions & 69 deletions src/internal/middleware/middleware_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ mod private {
use specta::{ts, TypeMap};

use rspc_core::{
cursed::{self, YieldMsg, CURSED_OP},
error::ExecError,
internal::{new_mw_ctx, Layer, ProcedureDef, RequestContext},
internal::{
new_mw_ctx, IntoMiddlewareResult, Layer, PinnedOption, ProcedureDef, RequestContext,
},
ValueOrBytes,
};

Expand Down Expand Up @@ -53,7 +56,7 @@ mod private {
req: RequestContext,
) -> Result<Self::Stream<'_>, ExecError> {
let new_ctx = Arc::new(Mutex::new(None));
let fut = self.mw.run_me(
let fut = self.mw.execute(
ctx,
new_mw_ctx(
input.clone(), // TODO: This probs won't fly if we accept file upload
Expand All @@ -68,6 +71,8 @@ mod private {
new_ctx,
input: Some(input),
req: Some(req),
stream: PinnedOption::None,
is_stream_done: false,
})
}
}
Expand Down Expand Up @@ -103,22 +108,12 @@ mod private {
// TODO: Avoid `Option` and instead encode into enum
input: Option<Value>,
req: Option<RequestContext>,
},
// We are in this state where we are executing the current middleware on the stream
Execute {

// The actual data stream from the resolver function or next middleware
#[pin]
stream: TNextLayer::Stream<'a>,
stream: PinnedOption<TNextLayer::Stream<'a>>,
// We use this so we can keep polling `resp_fut` for the final message and once it is done and this bool is set, shutdown.
is_stream_done: bool,

// The currently executing future returned by the `resp_fn` (publicly `.map`) function
// Be aware this will go `None` -> `Some` -> `None`, etc for a subscription
// #[pin]
// resp_fut: PinnedOption<<<TMiddleware::Result as MwV2Result>::Resp as Executable2>::Fut>,
// The `.map` function returned by the user from the execution of the current middleware
// This allows a middleware to map the values being returned from the stream
// resp_fn: <TMiddleware::Result as MwV2Result>::Resp,
},
// The stream is internally done but it returned `Poll::Ready` for the shutdown message so the caller thinks it's still active
// This will yield `Poll::Ready(None)` and transition into the `Self::Done` phase.
Expand Down Expand Up @@ -147,65 +142,36 @@ mod private {
new_ctx,
input,
req,
stream,
is_stream_done,
} => {
let result = ready!(fut.poll(cx));

let ctx = new_ctx
.lock()
.unwrap_or_else(PoisonError::into_inner)
.take()
.unwrap();

match next.call(ctx, input.take().unwrap(), req.take().unwrap()) {
Ok(stream) => {
self.as_mut().set(Self::Execute {
stream,
is_stream_done: false,
// resp_fut: PinnedOption::None,
// resp_fn: None, // TODO: Fully remove this
});
}

Err(err) => {
let result = match fut.poll(cx) {
Poll::Ready(result) => {
self.as_mut().set(Self::PendingDone);
return Poll::Ready(Some(Err(err)));
return Poll::Ready(Some(
result.into_result().map(ValueOrBytes::Value),
));
}
}
}
MiddlewareLayerFutureProj::Execute {
mut stream,
is_stream_done,
} => {
// if let PinnedOptionProj::Some { v } = resp_fut.as_mut().project() {
// let result = ready!(v.poll(cx));
// cx.waker().wake_by_ref(); // No wakers set so we set one
// resp_fut.set(PinnedOption::None);
// return Poll::Ready(Some(Ok(result)));
// }

if *is_stream_done {
self.as_mut().set(Self::Done);
return Poll::Ready(None);
}

match ready!(stream.as_mut().poll_next(cx)) {
Some(result) => match result {
Ok(result) => {
return Poll::Ready(Some(Ok(result)));
Poll::Pending => {
// cursed::outer(cx.waker());

if let Some(op) = CURSED_OP.take() {
match op {
YieldMsg::YieldBody => {
// TODO: Get proper value
CURSED_OP.set(Some(YieldMsg::YieldBodyResult(
serde_json::Value::Null,
)));

cx.waker().wake_by_ref();
}
YieldMsg::YieldBodyResult(_) => unreachable!(),
}
}
// TODO: The `.map` function is skipped for errors. Maybe it should be possible to map them when desired?
// TODO: We also shut down the whole stream on a single error. Is this desired?
Err(err) => {
self.as_mut().set(Self::PendingDone);
return Poll::Ready(Some(Err(err)));
}
},
// The underlying stream has shutdown so we will resolve `resp_fut` and then terminate ourselves
None => {
*is_stream_done = true;
continue;

return Poll::Pending;
}
}
};
}
MiddlewareLayerFutureProj::PendingDone => {
self.as_mut().set(Self::Done);
Expand All @@ -225,7 +191,7 @@ mod private {

fn size_hint(&self) -> (usize, Option<usize>) {
match &self {
Self::Execute { stream: c, .. } => c.size_hint(),
// Self::Execute { stream: c, .. } => c.size_hint(), // TODO: Bring this back
_ => (0, None),
}
}
Expand Down
11 changes: 5 additions & 6 deletions src/internal/middleware/mw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::future::Future;

use rspc_core::internal::{MiddlewareContext, MwV2Result};
use rspc_core::internal::{IntoMiddlewareResult, MiddlewareContext};

// `TNewCtx` is sadly require to constain the impl at the bottom of this file. If you can remove it your a god.
pub trait MiddlewareFn<TLCtx, TNewCtx>:
Expand All @@ -9,23 +9,22 @@ where
TLCtx: Send + Sync + 'static,
{
type Fut: Future<Output = Self::Result> + Send + 'static;
type Result: MwV2Result;
type Result: IntoMiddlewareResult;

// TODO: Rename
fn run_me(&self, ctx: TLCtx, mw: MiddlewareContext<TNewCtx>) -> Self::Fut;
fn execute(&self, ctx: TLCtx, mw: MiddlewareContext<TNewCtx>) -> Self::Fut;
}

impl<TLCtx, TNewCtx, F, Fu> MiddlewareFn<TLCtx, TNewCtx> for F
where
TLCtx: Send + Sync + 'static,
F: Fn(MiddlewareContext<TNewCtx>, TLCtx) -> Fu + Send + Sync + 'static,
Fu: Future + Send + 'static,
Fu::Output: MwV2Result + Send + 'static,
Fu::Output: IntoMiddlewareResult + Send + 'static,
{
type Fut = Fu;
type Result = Fu::Output;

fn run_me(&self, ctx: TLCtx, mw: MiddlewareContext<TNewCtx>) -> Self::Fut {
fn execute(&self, ctx: TLCtx, mw: MiddlewareContext<TNewCtx>) -> Self::Fut {
self(mw, ctx)
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ pub mod internal;

// TODO: Only reexport certain types
pub use rspc_core::error::*;

pub use rspc_core::Body;

0 comments on commit dfb67f3

Please sign in to comment.