diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index f1a8a110..cd3d565f 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -191,44 +191,32 @@ where async fn run(mut self) -> Result<(), Error> { loop { - let shutdown = match self._run().await { - Ok(shutdown) => shutdown, + match self._run().await { + Ok(_) => {} Err(Error::KeepAliveExpire) => { trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down"); return Ok(()); } - Err(Error::RequestTimeout) => { - self.request_error(|| status_only(StatusCode::REQUEST_TIMEOUT)); - - false - } + Err(Error::RequestTimeout) => self.request_error(|| status_only(StatusCode::REQUEST_TIMEOUT)), Err(Error::Proto(ProtoError::HeaderTooLarge)) => { - self.request_error(|| status_only(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)); - - false - } - Err(Error::Proto(_)) => { - self.request_error(|| status_only(StatusCode::BAD_REQUEST)); - - false + self.request_error(|| status_only(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)) } + Err(Error::Proto(_)) => self.request_error(|| status_only(StatusCode::BAD_REQUEST)), Err(e) => return Err(e), }; // TODO: add timeout for drain write? self.io.drain_write().await?; - if shutdown || self.ctx.is_connection_closed() { + if self.ctx.is_connection_closed() { return self.io.shutdown().await.map_err(Into::into); } } } - async fn _run(&mut self) -> Result> { + async fn _run(&mut self) -> Result<(), Error> { self.timer.update(self.ctx.date().now()); - let mut cancelled = false; - match self .io .read() @@ -239,9 +227,7 @@ where Err(_) => return Err(self.timer.map_to_err()), Ok(SelectOutput::A(Ok(_))) => {} Ok(SelectOutput::A(Err(_))) => return Err(Error::KeepAliveExpire), - Ok(SelectOutput::B(())) => { - cancelled = true; - } + Ok(SelectOutput::B(())) => self.ctx.set_close(), } while let Some((req, decoder)) = self.ctx.decode_head::(&mut self.io.read_buf)? { @@ -297,7 +283,7 @@ where } } - Ok(cancelled) + Ok(()) } fn encode_head(&mut self, parts: Parts, body: &impl Stream) -> Result { diff --git a/http/src/h1/dispatcher_uring.rs b/http/src/h1/dispatcher_uring.rs index 2a49b42e..5dc54e6e 100644 --- a/http/src/h1/dispatcher_uring.rs +++ b/http/src/h1/dispatcher_uring.rs @@ -14,13 +14,14 @@ use std::{io, net::Shutdown, rc::Rc}; use futures_core::stream::Stream; use pin_project_lite::pin_project; +use tokio_util::sync::CancellationToken; use tracing::trace; use xitca_io::{ bytes::BytesMut, io_uring::{write_all, AsyncBufRead, AsyncBufWrite, BoundedBuf}, }; use xitca_service::Service; -use xitca_unsafe_collection::futures::SelectOutput; +use xitca_unsafe_collection::futures::{Select, SelectOutput}; use crate::{ body::NoneBody, @@ -54,6 +55,7 @@ pub(super) struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_L write_buf: BufOwned, notify: Notify, _phantom: PhantomData, + cancellation_token: CancellationToken, } #[derive(Default)] @@ -118,6 +120,7 @@ where config: HttpServiceConfig, service: &'a S, date: &'a D, + cancellation_token: CancellationToken, ) -> Self { Self { io: Rc::new(io), @@ -128,13 +131,14 @@ where write_buf: BufOwned::new(), notify: Notify::new(), _phantom: PhantomData, + cancellation_token, } } pub(super) async fn run(mut self) -> Result<(), Error> { loop { match self._run().await { - Ok(_) => {} + Ok(shutdown) => shutdown, Err(Error::KeepAliveExpire) => { trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down"); return Ok(()); @@ -145,7 +149,7 @@ where } Err(Error::Proto(_)) => self.request_error(|| status_only(StatusCode::BAD_REQUEST)), Err(e) => return Err(e), - } + }; self.write_buf.write_io(&*self.io).await?; @@ -158,12 +162,22 @@ where async fn _run(&mut self) -> Result<(), Error> { self.timer.update(self.ctx.date().now()); - let read = self + let read = match self .read_buf .read_io(&*self.io) + .select(self.cancellation_token.cancelled()) .timeout(self.timer.get()) .await - .map_err(|_| self.timer.map_to_err())??; + { + Err(_) => return Err(self.timer.map_to_err()), + Ok(SelectOutput::A(Ok(read))) => read, + Ok(SelectOutput::A(Err(_))) => return Err(Error::KeepAliveExpire), + Ok(SelectOutput::B(())) => { + self.ctx.set_close(); + + return Ok(()); + } + }; if read == 0 { self.ctx.set_close(); diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index 80d041d6..701b23a1 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -117,7 +117,7 @@ where type Error = HttpServiceError; async fn call( &self, - ((io, addr), _cancellation_token): ((TcpStream, SocketAddr), CancellationToken), + ((io, addr), cancellation_token): ((TcpStream, SocketAddr), CancellationToken), ) -> Result { let accept_dur = self.config.tls_accept_timeout; let deadline = self.date.get().now() + accept_dur; @@ -130,10 +130,18 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher_uring::Dispatcher::new(io, addr, timer, self.config, &self.service, self.date.get()) - .run() - .await - .map_err(Into::into) + super::dispatcher_uring::Dispatcher::new( + io, + addr, + timer, + self.config, + &self.service, + self.date.get(), + cancellation_token, + ) + .run() + .await + .map_err(Into::into) } } diff --git a/test/tests/h1.rs b/test/tests/h1.rs index 41fcf879..d88aa8ab 100644 --- a/test/tests/h1.rs +++ b/test/tests/h1.rs @@ -243,6 +243,43 @@ async fn h1_keepalive() -> Result<(), Error> { Ok(()) } +#[tokio::test] +async fn h1_shutdown_during_request() -> Result<(), Error> { + let mut handle = test_h1_server(fn_service(handle))?; + + let mut stream = TcpStream::connect(handle.addr())?; + + let mut buf = [0; 128]; + + // simulate writing req during shutdown. + let (req_buf_1, req_buf2) = SIMPLE_GET_REQ.split_at(SIMPLE_GET_REQ.len() / 2); + + // write first half of the request. + stream.write_all(req_buf_1)?; + + // shutdown server. + handle.try_handle()?.stop(true); + + // sleep a bit to make sure event are processed. + tokio::time::sleep(Duration::from_millis(100)).await; + + // write second half of the request. + stream.write_all(req_buf2)?; + + // read response. + loop { + let n = stream.read(&mut buf)?; + + if buf[..n].ends_with(b"GET Response") { + break; + } + } + + handle.await?; + + Ok(()) +} + async fn handle(req: Request>) -> Result, Error> { match (req.method(), req.uri().path()) { (&Method::GET, "/") | (&Method::HEAD, "/") => Ok(Response::new(Bytes::from("GET Response").into())),