diff --git a/Cargo.toml b/Cargo.toml index 8f5d6f19..6512284c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,9 @@ sessions = ["async-session", "cookies"] unstable = [] [dependencies] -async-h1 = { version = "2.3.0", optional = true } +# async-h1 = { version = "2.3.0", optional = true } +# FIXME: for proposal purpose only +async-h1 = { git = "https://github.com/pbzweihander/async-h1.git", branch = "cancellation", optional = true } async-session = { version = "3.0", optional = true } async-sse = "4.0.1" async-std = { version = "1.6.5", features = ["unstable"] } @@ -48,6 +50,8 @@ pin-project-lite = "0.2.0" route-recognizer = "0.2.0" serde = "1.0.117" serde_json = "1.0.59" +stopper = "0.2.0" +waitgroup = "0.1.2" [dev-dependencies] async-std = { version = "1.6.5", features = ["unstable", "attributes"] } diff --git a/src/lib.rs b/src/lib.rs index 3245f23f..6acd12e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,8 @@ pub use server::Server; pub use http_types::{self as http, Body, Error, Status, StatusCode}; +pub use stopper; + /// Create a new Tide server. /// /// # Examples diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 7b86a013..5b444ff3 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -9,6 +9,10 @@ use async_std::net::{self, SocketAddr, TcpStream}; use async_std::prelude::*; use async_std::{io, task}; +use futures_util::future::Either; + +use waitgroup::{WaitGroup, Worker}; + /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::net::TcpListener]. It is implemented as an /// enum in order to allow creation of a tide::listener::TcpListener @@ -44,16 +48,30 @@ impl TcpListener { } } -fn handle_tcp(app: Server, stream: TcpStream) { +fn handle_tcp( + app: Server, + stream: TcpStream, + wait_group_worker: Worker, +) { task::spawn(async move { + let _wait_group_worker = wait_group_worker; + let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); - let fut = async_h1::accept(stream, |mut req| async { - req.set_local_addr(local_addr); - req.set_peer_addr(peer_addr); - app.respond(req).await - }); + let opts = async_h1::ServerOptions { + stopper: app.stopper.clone(), + ..Default::default() + }; + let fut = async_h1::accept_with_opts( + stream, + |mut req| async { + req.set_local_addr(local_addr); + req.set_peer_addr(peer_addr); + app.respond(req).await + }, + opts, + ); if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); @@ -98,7 +116,13 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = listener.incoming(); + let incoming = listener.incoming(); + let mut incoming = if let Some(stopper) = server.stopper.clone() { + Either::Left(stopper.stop_stream(incoming)) + } else { + Either::Right(incoming) + }; + let wait_group = WaitGroup::new(); while let Some(stream) = incoming.next().await { match stream { @@ -111,10 +135,13 @@ where } Ok(stream) => { - handle_tcp(server.clone(), stream); + handle_tcp(server.clone(), stream, wait_group.worker()); } }; } + + wait_group.wait().await; + Ok(()) } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index d99a21d3..9b6c6e4d 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -10,6 +10,10 @@ use async_std::path::PathBuf; use async_std::prelude::*; use async_std::{io, task}; +use futures_util::future::Either; + +use waitgroup::{WaitGroup, Worker}; + /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an /// enum in order to allow creation of a tide::listener::UnixListener @@ -45,16 +49,30 @@ impl UnixListener { } } -fn handle_unix(app: Server, stream: UnixStream) { +fn handle_unix( + app: Server, + stream: UnixStream, + wait_group_worker: Worker, +) { task::spawn(async move { + let _wait_group_worker = wait_group_worker; + let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); - let fut = async_h1::accept(stream, |mut req| async { - req.set_local_addr(local_addr.as_ref()); - req.set_peer_addr(peer_addr.as_ref()); - app.respond(req).await - }); + let opts = async_h1::ServerOptions { + stopper: app.stopper.clone(), + ..Default::default() + }; + let fut = async_h1::accept_with_opts( + stream, + |mut req| async { + req.set_local_addr(local_addr.as_ref()); + req.set_peer_addr(peer_addr.as_ref()); + app.respond(req).await + }, + opts, + ); if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); @@ -96,7 +114,13 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = listener.incoming(); + let incoming = listener.incoming(); + let mut incoming = if let Some(stopper) = server.stopper.clone() { + Either::Left(stopper.stop_stream(incoming)) + } else { + Either::Right(incoming) + }; + let wait_group = WaitGroup::new(); while let Some(stream) = incoming.next().await { match stream { @@ -109,10 +133,13 @@ where } Ok(stream) => { - handle_unix(server.clone(), stream); + handle_unix(server.clone(), stream, wait_group.worker()); } }; } + + wait_group.wait().await; + Ok(()) } diff --git a/src/server.rs b/src/server.rs index 1e6f8c1a..0dc96ea1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,8 @@ use async_std::io; use async_std::sync::Arc; +use stopper::Stopper; + #[cfg(feature = "cookies")] use crate::cookies; use crate::listener::{Listener, ToListener}; @@ -38,6 +40,7 @@ pub struct Server { /// We don't use a Mutex around the Vec here because adding a middleware during execution should be an error. #[allow(clippy::rc_buffer)] middleware: Arc>>>, + pub(crate) stopper: Option, } impl Server<()> { @@ -113,6 +116,7 @@ where Arc::new(log::LogMiddleware::new()), ]), state, + stopper: None, } } @@ -286,6 +290,7 @@ where router, state, middleware, + stopper: _, } = self.clone(); let method = req.method().to_owned(); @@ -317,6 +322,29 @@ where pub fn state(&self) -> &State { &self.state } + + /// Stops the server when given `stopper` stops. + /// + /// # Example + /// + /// ```rust + /// use tide::stopper::Stopper; + /// + /// let mut app = tide::new(); + /// + /// let stopper = Stopper::new(); + /// + /// app.with_stopper(stopper.clone()); + /// + /// // Runs server... + /// + /// // When something happens + /// stopper.stop(); + /// ``` + pub fn with_stopper(&mut self, stopper: Stopper) -> &mut Self { + self.stopper = Some(stopper); + self + } } impl std::fmt::Debug for Server { @@ -331,6 +359,7 @@ impl Clone for Server { router: self.router.clone(), state: self.state.clone(), middleware: self.middleware.clone(), + stopper: self.stopper.clone(), } } } diff --git a/tests/cancellation.rs b/tests/cancellation.rs new file mode 100644 index 00000000..2996f61e --- /dev/null +++ b/tests/cancellation.rs @@ -0,0 +1,57 @@ +mod test_utils; +use async_std::prelude::*; +use async_std::task; +use std::time::Duration; + +use tide::stopper::Stopper; +use tide::Response; + +#[async_std::test] +async fn cancellation() -> Result<(), http_types::Error> { + let port = test_utils::find_port().await; + let stopper = Stopper::new(); + let stopper_ = stopper.clone(); + + let server = task::spawn(async move { + let mut app = tide::new(); + app.with_stopper(stopper_); + app.at("/").get(|_| async { + task::sleep(Duration::from_secs(1)).await; + Ok(Response::new(200)) + }); + app.listen(("localhost", port)).await?; + tide::Result::Ok(()) + }); + + let client1 = task::spawn(async move { + task::sleep(Duration::from_millis(100)).await; + let res = surf::get(format!("http://localhost:{}", port)) + .await + .unwrap(); + assert_eq!(res.status(), 200); + async_std::future::pending().await + }); + + let client2 = task::spawn(async move { + task::sleep(Duration::from_millis(200)).await; + let res = surf::get(format!("http://localhost:{}", port)) + .await + .unwrap(); + assert_eq!(res.status(), 200); + async_std::future::pending().await + }); + + let stop = task::spawn(async move { + task::sleep(Duration::from_millis(300)).await; + stopper.stop(); + Ok(()) + }); + + server + .try_join(stop) + .race(client1.try_join(client2)) + .timeout(Duration::from_secs(2)) + .await??; + + Ok(()) +}