diff --git a/crates/core/src/exec/connection.rs b/crates/core/src/exec/connection.rs index 0e963e63..96015d0b 100644 --- a/crates/core/src/exec/connection.rs +++ b/crates/core/src/exec/connection.rs @@ -43,10 +43,10 @@ where }; match res { - ExecutorResult::Task(task) => { + ExecutorResult::Task(task, shutdown_tx) => { let task_id = task.id; self.streams.insert(task); - self.subscriptions.shutdown(task_id); + self.subscriptions.insert(task_id, shutdown_tx); } ExecutorResult::Future(fut) => { self.streams.insert(fut.into()); @@ -92,8 +92,7 @@ pub async fn run_connection< loop { if !batch.is_empty() { - let batch = batch.drain(..batch.len()).collect::>(); - if let Err(_err) = socket.send(batch).await { + if let Err(_err) = socket.send(std::mem::take(&mut batch)).await { #[cfg(feature = "tracing")] tracing::error!("Error sending message to websocket: {}", _err); } @@ -159,7 +158,7 @@ pub async fn run_connection< StreamYield::Item(resp) => batch.push(resp), StreamYield::Finished(f) => { if let Some(stream) = f.take(Pin::new(&mut conn.streams)) { - conn.subscriptions._internal_remove(stream.id); + conn.subscriptions.remove(stream.id); } } } @@ -171,7 +170,7 @@ pub async fn run_connection< } StreamYield::Finished(f) => { if let Some(stream) = f.take(Pin::new(&mut conn.streams)) { - conn.subscriptions._internal_remove(stream.id); + conn.subscriptions.remove(stream.id); } } } diff --git a/crates/core/src/exec/execute.rs b/crates/core/src/exec/execute.rs index 0f9005cb..681ffe6b 100644 --- a/crates/core/src/exec/execute.rs +++ b/crates/core/src/exec/execute.rs @@ -1,5 +1,7 @@ use std::{pin::Pin, sync::Arc}; +use futures::channel::oneshot; + use crate::{ body::Body, error::ExecError, @@ -11,7 +13,7 @@ use crate::{ Router, }; -use super::{task, SubscriptionMap}; +use super::SubscriptionMap; /// TODO /// @@ -23,7 +25,7 @@ pub enum ExecutorResult { /// A future that will resolve to a response. Future(RequestFuture), /// A task that should be queued onto an async runtime. - Task(Task), + Task(Task, oneshot::Sender<()>), } // TODO: Move this into `build_router.rs` and turn it into a module with all the other `exec::*` types @@ -62,11 +64,19 @@ impl Router { } Some(_) => match get_subscription(self, ctx, data) { None => Err(ExecError::OperationNotFound), - Some(stream) => Ok(ExecutorResult::Task(Task { - id, - stream, - status: task::Status::ShouldBePolled { done: false }, - })), + Some(stream) => { + let (tx, rx) = oneshot::channel(); + + Ok(ExecutorResult::Task( + Task { + id, + stream, + done: false, + shutdown_rx: Some(rx), + }, + tx, + )) + } }, } .unwrap_or_else(|e| { diff --git a/crates/core/src/exec/subscription_map.rs b/crates/core/src/exec/subscription_map.rs index 9398bcfa..0279f873 100644 --- a/crates/core/src/exec/subscription_map.rs +++ b/crates/core/src/exec/subscription_map.rs @@ -27,8 +27,18 @@ impl SubscriptionMap { } } + pub fn insert(&mut self, id: u32, tx: oneshot::Sender<()>) { + self.map.insert(id, tx); + } + // We remove but don't shutdown. This should be used when we know the subscription is shutdown. - pub(crate) fn _internal_remove(&mut self, id: u32) { - self.map.remove(&id); + pub(crate) fn remove(&mut self, id: u32) { + if let Some(tx) = self.map.remove(&id) { + #[cfg(debug_assertions)] + #[allow(clippy::panic)] + if !tx.is_canceled() { + panic!("Subscription was not shutdown before being removed!"); + } + }; } } diff --git a/crates/core/src/exec/task.rs b/crates/core/src/exec/task.rs index cc81e1f8..4fc634cf 100644 --- a/crates/core/src/exec/task.rs +++ b/crates/core/src/exec/task.rs @@ -1,6 +1,6 @@ use std::{fmt, pin::Pin, task::Poll}; -use futures::{ready, Stream}; +use futures::{channel::oneshot, ready, stream::FusedStream, FutureExt, Stream}; use crate::body::Body; use crate::exec; @@ -9,12 +9,6 @@ use super::{arc_ref::ArcRef, request_future::RequestFuture}; // TODO: Should this be called `Task` or `StreamWrapper`? Will depend on it's final form. -// TODO: Replace with FusedStream in dev if possible??? -pub enum Status { - ShouldBePolled { done: bool }, - DoNotPoll, -} - // TODO: docs pub struct Task { pub(crate) id: u32, @@ -22,7 +16,8 @@ pub struct Task { // rspc's whole middleware system only uses `Stream`'s cause it makes life easier so we change to & from a `Future` at the start/end. pub(crate) stream: ArcRef>>, // Mark when the stream is done. This means `self.reference` returned `None` but we still had to yield the complete message so we haven't returned `None` yet. - pub(crate) status: Status, + pub(crate) done: bool, + pub(crate) shutdown_rx: Option>, } impl fmt::Debug for Task { @@ -40,17 +35,13 @@ impl Stream for Task { mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - match &self.status { - #[allow(clippy::panic)] - Status::DoNotPoll => { - #[cfg(debug_assertions)] - panic!("`StreamWrapper` polled after completion") - } - Status::ShouldBePolled { done } => { - if *done { - self.status = Status::DoNotPoll; - return Poll::Ready(None); - } + if self.done { + return Poll::Ready(None); + } + + if let Some(shutdown_rx) = self.shutdown_rx.as_mut() { + if shutdown_rx.poll_unpin(cx).is_ready() { + self.done = true; } } @@ -64,8 +55,8 @@ impl Stream for Task { }, None => { let id = self.id; - cx.waker().wake_by_ref(); // We want the stream to be called again so we can return `None` and close it - self.status = Status::ShouldBePolled { done: true }; + self.done = true; + cx.waker().wake_by_ref(); exec::Response { id, inner: exec::ResponseInner::Complete, @@ -80,12 +71,19 @@ impl Stream for Task { } } +impl FusedStream for Task { + fn is_terminated(&self) -> bool { + self.done + } +} + impl From for Task { fn from(value: RequestFuture) -> Self { Self { id: value.id, stream: value.stream, - status: Status::ShouldBePolled { done: false }, + done: false, + shutdown_rx: None, } } } diff --git a/crates/core/src/router.rs b/crates/core/src/router.rs index b2dc3262..024a8195 100644 --- a/crates/core/src/router.rs +++ b/crates/core/src/router.rs @@ -13,13 +13,7 @@ use specta::{ TypeMap, }; -use crate::{ - error::ExportError, - internal::ProcedureDef, - middleware::ProcedureKind, - procedure_store::{ProcedureTodo, ProceduresDef}, - router_builder::ProcedureMap, -}; +use crate::{error::ExportError, procedure_store::ProceduresDef, router_builder::ProcedureMap}; // TODO: Break this out into it's own file /// ExportConfig is used to configure how rspc will export your types. diff --git a/crates/httpz/src/httpz_endpoint.rs b/crates/httpz/src/httpz_endpoint.rs index 68eb55f8..f58070a1 100644 --- a/crates/httpz/src/httpz_endpoint.rs +++ b/crates/httpz/src/httpz_endpoint.rs @@ -128,7 +128,7 @@ where ExecutorResult::Future(fut) => fut.await, ExecutorResult::Response(response) => response, #[allow(clippy::panic)] - ExecutorResult::Task(_) => { + ExecutorResult::Task(_, _) => { #[cfg(debug_assertions)] panic!("rspc: unexpected HTTP endpoint returned 'Task'"); } @@ -220,7 +220,7 @@ where responses.push(resp); } #[allow(clippy::panic)] - ExecutorResult::Task(_) => { + ExecutorResult::Task(_, _) => { #[cfg(debug_assertions)] panic!("rspc: unexpected HTTP endpoint returned 'Task'"); }