diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index cd3d565f..5c6566b0 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -8,7 +8,7 @@ use core::{ }; use futures_core::stream::Stream; -use std::io; +use std::{io, rc::Rc}; use tokio_util::sync::CancellationToken; use tracing::trace; use xitca_io::io::{AsyncIo, Interest, Ready}; @@ -41,6 +41,7 @@ use super::proto::{ encode::CONTINUE, error::ProtoError, }; +use crate::util::futures::WaitOrPending; type ExtRequest = crate::http::Request>; @@ -91,6 +92,7 @@ struct Dispatcher<'a, St, S, ReqB, W, D, const HEADER_LIMIT: usize, const READ_B service: &'a S, _phantom: PhantomData, cancellation_token: CancellationToken, + request_guard: Rc<()>, } // timer state is transformed in following order: @@ -186,6 +188,7 @@ where service, _phantom: PhantomData, cancellation_token, + request_guard: Rc::new(()), } } @@ -208,9 +211,18 @@ where // TODO: add timeout for drain write? self.io.drain_write().await?; + // shutdown io if connection is closed. if self.ctx.is_connection_closed() { return self.io.shutdown().await.map_err(Into::into); } + + // shutdown io if there is no more read buf + if self.io.read_buf.is_empty() + && self.cancellation_token.is_cancelled() + && Rc::strong_count(&self.request_guard) == 1 + { + return self.io.shutdown().await.map_err(Into::into); + } } } @@ -220,14 +232,17 @@ where match self .io .read() - .select(self.cancellation_token.cancelled()) + .select(WaitOrPending::new( + self.cancellation_token.cancelled(), + self.cancellation_token.is_cancelled(), + )) .timeout(self.timer.get()) .await { Err(_) => return Err(self.timer.map_to_err()), Ok(SelectOutput::A(Ok(_))) => {} Ok(SelectOutput::A(Err(_))) => return Err(Error::KeepAliveExpire), - Ok(SelectOutput::B(())) => self.ctx.set_close(), + Ok(SelectOutput::B(())) => {} } while let Some((req, decoder)) = self.ctx.decode_head::(&mut self.io.read_buf)? { @@ -235,6 +250,7 @@ where let (mut body_reader, body) = BodyReader::from_coding(decoder); let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); + let _guard = self.request_guard.clone(); let (parts, body) = match self .service diff --git a/http/src/h1/dispatcher_uring.rs b/http/src/h1/dispatcher_uring.rs index 5dc54e6e..911e5ed3 100644 --- a/http/src/h1/dispatcher_uring.rs +++ b/http/src/h1/dispatcher_uring.rs @@ -23,16 +23,6 @@ use xitca_io::{ use xitca_service::Service; use xitca_unsafe_collection::futures::{Select, SelectOutput}; -use crate::{ - body::NoneBody, - bytes::Bytes, - config::HttpServiceConfig, - date::DateTime, - h1::{body::RequestBody, error::Error}, - http::{response::Response, StatusCode}, - util::timer::{KeepAlive, Timeout}, -}; - use super::{ dispatcher::{status_only, Timer}, proto::{ @@ -42,6 +32,16 @@ use super::{ error::ProtoError, }, }; +use crate::util::futures::WaitOrPending; +use crate::{ + body::NoneBody, + bytes::Bytes, + config::HttpServiceConfig, + date::DateTime, + h1::{body::RequestBody, error::Error}, + http::{response::Response, StatusCode}, + util::timer::{KeepAlive, Timeout}, +}; type ExtRequest = crate::http::Request>; @@ -56,6 +56,7 @@ pub(super) struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_L notify: Notify, _phantom: PhantomData, cancellation_token: CancellationToken, + request_guard: Rc<()>, } #[derive(Default)] @@ -132,13 +133,14 @@ where notify: Notify::new(), _phantom: PhantomData, cancellation_token, + request_guard: Rc::new(()), } } pub(super) async fn run(mut self) -> Result<(), Error> { loop { match self._run().await { - Ok(shutdown) => shutdown, + Ok(_) => {} Err(Error::KeepAliveExpire) => { trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down"); return Ok(()); @@ -156,6 +158,14 @@ where if self.ctx.is_connection_closed() { return self.io.shutdown(Shutdown::Both).map_err(Into::into); } + + // shutdown io if there is no more read buf + if self.read_buf.is_empty() + && self.cancellation_token.is_cancelled() + && Rc::strong_count(&self.request_guard) == 1 + { + return self.io.shutdown(Shutdown::Both).map_err(Into::into); + } } } @@ -165,26 +175,32 @@ where let read = match self .read_buf .read_io(&*self.io) - .select(self.cancellation_token.cancelled()) + .select(WaitOrPending::new( + self.cancellation_token.cancelled(), + self.cancellation_token.is_cancelled(), + )) .timeout(self.timer.get()) .await { Err(_) => return Err(self.timer.map_to_err()), Ok(SelectOutput::A(Ok(read))) => read, Ok(SelectOutput::A(Err(_))) => return Err(Error::KeepAliveExpire), - Ok(SelectOutput::B(())) => { - self.ctx.set_close(); - - return Ok(()); - } + Ok(SelectOutput::B(())) => 0, }; if read == 0 { - self.ctx.set_close(); + if !self.cancellation_token.is_cancelled() { + println!("set close"); + self.ctx.set_close(); + } else { + println!("cancelled"); + } + return Ok(()); } while let Some((req, decoder)) = self.ctx.decode_head::(&mut self.read_buf)? { + println!("decode head"); self.timer.reset_state(); let (waiter, body) = if decoder.is_eof() { @@ -204,6 +220,7 @@ where let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); + let _guard = self.request_guard.clone(); let (parts, body) = self.service.call(req).await.map_err(Error::Service)?.into_parts(); let mut encoder = self.ctx.encode_head(parts, &body, &mut *self.write_buf)?; diff --git a/http/src/h2/proto/dispatcher.rs b/http/src/h2/proto/dispatcher.rs index c7775d51..212ac72f 100644 --- a/http/src/h2/proto/dispatcher.rs +++ b/http/src/h2/proto/dispatcher.rs @@ -7,12 +7,13 @@ use core::{ task::{ready, Context, Poll}, time::Duration, }; - +use std::rc::Rc; use ::h2::{ server::{Connection, SendResponse}, Ping, PingPong, }; use futures_core::stream::Stream; +use tokio_util::sync::CancellationToken; use tracing::trace; use xitca_io::io::{AsyncRead, AsyncWrite}; use xitca_service::Service; @@ -30,6 +31,7 @@ use crate::{ }, util::{futures::Queue, timer::KeepAlive}, }; +use crate::util::futures::WaitOrPending; /// Http/2 dispatcher pub(crate) struct Dispatcher<'a, TlsSt, S, ReqB> { @@ -39,6 +41,7 @@ pub(crate) struct Dispatcher<'a, TlsSt, S, ReqB> { ka_dur: Duration, service: &'a S, date: &'a DateTimeHandle, + cancellation_token: CancellationToken, _req_body: PhantomData, } @@ -60,6 +63,7 @@ where ka_dur: Duration, service: &'a S, date: &'a DateTimeHandle, + cancellation_token: CancellationToken, ) -> Self { Self { io, @@ -69,6 +73,7 @@ where service, date, _req_body: PhantomData, + cancellation_token, } } @@ -80,6 +85,7 @@ where ka_dur, service, date, + cancellation_token, .. } = self; @@ -101,7 +107,11 @@ where let mut queue = Queue::new(); loop { - match io.accept().select(try_poll_queue(&mut queue, &mut ping_pong)).await { + if queue.is_empty() && cancellation_token.is_cancelled() { + break; + } + + match io.accept().select(try_poll_queue(&mut queue, &mut ping_pong, cancellation_token.clone())).await { SelectOutput::A(Some(Ok((req, tx)))) => { // Convert http::Request body type to crate::h2::Body // and reconstruct as HttpRequest. @@ -137,6 +147,7 @@ where async fn try_poll_queue( queue: &mut Queue, ping_ping: &mut H2PingPong<'_>, + cancellation_token: CancellationToken, ) -> SelectOutput<(), Result<(), ::h2::Error>> where F: Future>, @@ -146,7 +157,14 @@ where { loop { if queue.is_empty() { - return SelectOutput::B(ping_ping.await); + return match ping_ping.select(WaitOrPending::new(cancellation_token.cancelled(), cancellation_token.is_cancelled())).await { + SelectOutput::A(res) => SelectOutput::B(res), + SelectOutput::B(_) => { + println!("cancelled"); + + SelectOutput::A(()) + }, + } } match queue.next2().await { diff --git a/http/src/h2/service.rs b/http/src/h2/service.rs index 20dbb44a..1fec5319 100644 --- a/http/src/h2/service.rs +++ b/http/src/h2/service.rs @@ -48,7 +48,7 @@ where async fn call( &self, - ((io, addr), _): ((St, SocketAddr), CancellationToken), + ((io, addr), cancellation_token): ((St, SocketAddr), CancellationToken), ) -> Result { // tls accept timer. let timer = self.keep_alive(); @@ -78,6 +78,7 @@ where self.config.keep_alive_timeout, &self.service, self.date.get(), + cancellation_token ); dispatcher.run().await?; diff --git a/http/src/service.rs b/http/src/service.rs index 896e2a33..8ef922f5 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -147,6 +147,7 @@ where self.config.keep_alive_timeout, &self.service, self.date.get(), + cancellation_token, ) .run() .await diff --git a/http/src/util/futures.rs b/http/src/util/futures.rs index 5bfe864c..7ab38347 100644 --- a/http/src/util/futures.rs +++ b/http/src/util/futures.rs @@ -1,5 +1,7 @@ +use pin_project_lite::pin_project; #[cfg(any(feature = "http2", feature = "http3"))] pub(crate) use queue::*; +use std::future::Future; #[cfg(any(feature = "http2", feature = "http3"))] mod queue { @@ -43,3 +45,39 @@ mod queue { } } } + +// A future that resolve only one time when the future is ready +pin_project! { + pub(crate) struct WaitOrPending { + #[pin] + future: F, + is_pending: bool, + } +} + +impl WaitOrPending { + pub fn new(future: F, is_pending: bool) -> Self { + Self { future, is_pending } + } +} + +impl Future for WaitOrPending { + type Output = F::Output; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + if self.is_pending { + return std::task::Poll::Pending; + } + + let this = self.as_mut().project(); + + match this.future.poll(cx) { + std::task::Poll::Ready(f) => { + *this.is_pending = true; + + std::task::Poll::Ready(f) + } + poll => poll, + } + } +} diff --git a/server/src/server/future.rs b/server/src/server/future.rs index c3769c1f..2c60ef4e 100644 --- a/server/src/server/future.rs +++ b/server/src/server/future.rs @@ -47,9 +47,11 @@ impl ServerFuture { match *self { Self::Init { ref server, .. } => Ok(ServerHandle { tx: server.tx_cmd.clone(), + cancellation_token: server.cancellation_token.clone(), }), Self::Running(ref inner) => Ok(ServerHandle { tx: inner.server.tx_cmd.clone(), + cancellation_token: inner.server.cancellation_token.clone(), }), Self::Error(_) => match mem::take(self) { Self::Error(e) => Err(e), diff --git a/server/src/server/handle.rs b/server/src/server/handle.rs index 2fc90d35..cadfb585 100644 --- a/server/src/server/handle.rs +++ b/server/src/server/handle.rs @@ -1,10 +1,11 @@ -use tokio::sync::mpsc::UnboundedSender; - use super::Command; +use tokio::sync::mpsc::UnboundedSender; +use tokio_util::sync::CancellationToken; #[derive(Clone)] pub struct ServerHandle { pub(super) tx: UnboundedSender, + pub(super) cancellation_token: CancellationToken, } impl ServerHandle { @@ -17,5 +18,6 @@ impl ServerHandle { }; let _ = self.tx.send(cmd); + self.cancellation_token.cancel(); } } diff --git a/test/tests/h1.rs b/test/tests/h1.rs index d88aa8ab..95d92fc9 100644 --- a/test/tests/h1.rs +++ b/test/tests/h1.rs @@ -251,26 +251,35 @@ async fn h1_shutdown_during_request() -> Result<(), Error> { let mut buf = [0; 128]; - // simulate writing req during shutdown. - let (req_buf_1, req_buf2) = SIMPLE_GET_REQ.split_at(SIMPLE_GET_REQ.len() / 2); - // write first half of the request. - stream.write_all(req_buf_1)?; + stream.write_all(SLEEP_GET_REQ_PART_1)?; + tokio::time::sleep(Duration::from_millis(100)).await; // shutdown server. handle.try_handle()?.stop(true); + tokio::time::sleep(Duration::from_millis(100)).await; - // sleep a bit to make sure event are processed. + // write second half of the request. + stream.write_all(SLEEP_GET_REQ_PART_2)?; tokio::time::sleep(Duration::from_millis(100)).await; // write second half of the request. - stream.write_all(req_buf2)?; + stream.write_all(SLEEP_GET_REQ_PART_3)?; + tokio::time::sleep(Duration::from_millis(100)).await; // read response. loop { let n = stream.read(&mut buf)?; - if buf[..n].ends_with(b"GET Response") { + if n == 0 { + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected eof", + )))?; + } + + // it should be chunked so it ends with 0\r\n\r\n + if buf[..n].ends_with(b"0\r\n\r\n") { break; } } @@ -319,8 +328,17 @@ async fn handle(req: Request>) -> Result { + tokio::time::sleep(Duration::from_millis(200)).await; + + Ok(Response::new(ResponseBody::stream(BoxBody::new(req.into_body())))) + } _ => todo!(), } } const SIMPLE_GET_REQ: &[u8] = b"GET / HTTP/1.1\r\ncontent-length: 0\r\n\r\n"; + +const SLEEP_GET_REQ_PART_1: &[u8] = b"GET /sleep HTTP/1.1\r\n"; +const SLEEP_GET_REQ_PART_2: &[u8] = b"content-length: 10\r\n\r\n01"; +const SLEEP_GET_REQ_PART_3: &[u8] = b"23456789"; diff --git a/test/tests/h2.rs b/test/tests/h2.rs index bcf33262..d7bd89c7 100644 --- a/test/tests/h2.rs +++ b/test/tests/h2.rs @@ -27,7 +27,17 @@ async fn h2_get() -> Result<(), Error> { assert_eq!("GET Response", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); + + // let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?; + // assert_eq!(res.status().as_u16(), 200); + // assert!(!res.can_close_connection()); + // let body = res.string().await?; + // assert_eq!("GET Response", body); + + let mut res = c.get(&server_url).version(Version::HTTP_2).send().await; + + println!("{:?}", res.err().unwrap()); handle.await?;