Skip to content

Commit

Permalink
feat(shutdown): add test, simplify on how to set close
Browse files Browse the repository at this point in the history
  • Loading branch information
joelwurtz committed Jan 31, 2025
1 parent 573b5cb commit f0bd80e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 33 deletions.
32 changes: 9 additions & 23 deletions http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,44 +191,32 @@ where

async fn run(mut self) -> Result<(), Error<S::Error, BE>> {
loop {
let shutdown = match self._run().await {
Ok(shutdown) => shutdown,
match self._run().await {
Ok(_) => {}
Err(Error::KeepAliveExpire) => {
trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down");
return Ok(());
}
Err(Error::RequestTimeout) => {
self.request_error(|| status_only(StatusCode::REQUEST_TIMEOUT));

false
}
Err(Error::RequestTimeout) => self.request_error(|| status_only(StatusCode::REQUEST_TIMEOUT)),
Err(Error::Proto(ProtoError::HeaderTooLarge)) => {
self.request_error(|| status_only(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE));

false
}
Err(Error::Proto(_)) => {
self.request_error(|| status_only(StatusCode::BAD_REQUEST));

false
self.request_error(|| status_only(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE))
}
Err(Error::Proto(_)) => self.request_error(|| status_only(StatusCode::BAD_REQUEST)),
Err(e) => return Err(e),
};

// TODO: add timeout for drain write?
self.io.drain_write().await?;

if shutdown || self.ctx.is_connection_closed() {
if self.ctx.is_connection_closed() {
return self.io.shutdown().await.map_err(Into::into);
}
}
}

async fn _run(&mut self) -> Result<bool, Error<S::Error, BE>> {
async fn _run(&mut self) -> Result<(), Error<S::Error, BE>> {
self.timer.update(self.ctx.date().now());

let mut cancelled = false;

match self
.io
.read()
Expand All @@ -239,9 +227,7 @@ where
Err(_) => return Err(self.timer.map_to_err()),
Ok(SelectOutput::A(Ok(_))) => {}
Ok(SelectOutput::A(Err(_))) => return Err(Error::KeepAliveExpire),
Ok(SelectOutput::B(())) => {
cancelled = true;
}
Ok(SelectOutput::B(())) => self.ctx.set_close(),
}

while let Some((req, decoder)) = self.ctx.decode_head::<READ_BUF_LIMIT>(&mut self.io.read_buf)? {
Expand Down Expand Up @@ -297,7 +283,7 @@ where
}
}

Ok(cancelled)
Ok(())
}

fn encode_head(&mut self, parts: Parts, body: &impl Stream) -> Result<TransferCoding, ProtoError> {
Expand Down
24 changes: 19 additions & 5 deletions http/src/h1/dispatcher_uring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ use std::{io, net::Shutdown, rc::Rc};

use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use tokio_util::sync::CancellationToken;
use tracing::trace;
use xitca_io::{
bytes::BytesMut,
io_uring::{write_all, AsyncBufRead, AsyncBufWrite, BoundedBuf},
};
use xitca_service::Service;
use xitca_unsafe_collection::futures::SelectOutput;
use xitca_unsafe_collection::futures::{Select, SelectOutput};

use crate::{
body::NoneBody,
Expand Down Expand Up @@ -54,6 +55,7 @@ pub(super) struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_L
write_buf: BufOwned,
notify: Notify<BufOwned>,
_phantom: PhantomData<ReqB>,
cancellation_token: CancellationToken,
}

#[derive(Default)]
Expand Down Expand Up @@ -118,6 +120,7 @@ where
config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
service: &'a S,
date: &'a D,
cancellation_token: CancellationToken,
) -> Self {
Self {
io: Rc::new(io),
Expand All @@ -128,13 +131,14 @@ where
write_buf: BufOwned::new(),
notify: Notify::new(),
_phantom: PhantomData,
cancellation_token,
}
}

pub(super) async fn run(mut self) -> Result<(), Error<S::Error, BE>> {
loop {
match self._run().await {
Ok(_) => {}
Ok(shutdown) => shutdown,
Err(Error::KeepAliveExpire) => {
trace!(target: "h1_dispatcher", "Connection keep-alive expired. Shutting down");
return Ok(());
Expand All @@ -145,7 +149,7 @@ where
}
Err(Error::Proto(_)) => self.request_error(|| status_only(StatusCode::BAD_REQUEST)),
Err(e) => return Err(e),
}
};

self.write_buf.write_io(&*self.io).await?;

Expand All @@ -158,12 +162,22 @@ where
async fn _run(&mut self) -> Result<(), Error<S::Error, BE>> {
self.timer.update(self.ctx.date().now());

let read = self
let read = match self
.read_buf
.read_io(&*self.io)
.select(self.cancellation_token.cancelled())
.timeout(self.timer.get())
.await
.map_err(|_| self.timer.map_to_err())??;
{
Err(_) => return Err(self.timer.map_to_err()),
Ok(SelectOutput::A(Ok(read))) => read,
Ok(SelectOutput::A(Err(_))) => return Err(Error::KeepAliveExpire),
Ok(SelectOutput::B(())) => {
self.ctx.set_close();

return Ok(());
}
};

if read == 0 {
self.ctx.set_close();
Expand Down
18 changes: 13 additions & 5 deletions http/src/h1/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ where
type Error = HttpServiceError<S::Error, BE>;
async fn call(
&self,
((io, addr), _cancellation_token): ((TcpStream, SocketAddr), CancellationToken),
((io, addr), cancellation_token): ((TcpStream, SocketAddr), CancellationToken),
) -> Result<Self::Response, Self::Error> {
let accept_dur = self.config.tls_accept_timeout;
let deadline = self.date.get().now() + accept_dur;
Expand All @@ -130,10 +130,18 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher_uring::Dispatcher::new(io, addr, timer, self.config, &self.service, self.date.get())
.run()
.await
.map_err(Into::into)
super::dispatcher_uring::Dispatcher::new(
io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
cancellation_token,
)
.run()
.await
.map_err(Into::into)
}
}

Expand Down
37 changes: 37 additions & 0 deletions test/tests/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,43 @@ async fn h1_keepalive() -> Result<(), Error> {
Ok(())
}

#[tokio::test]
async fn h1_shutdown_during_request() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;

let mut stream = TcpStream::connect(handle.addr())?;

let mut buf = [0; 128];

// simulate writing req during shutdown.
let (req_buf_1, req_buf2) = SIMPLE_GET_REQ.split_at(SIMPLE_GET_REQ.len() / 2);

// write first half of the request.
stream.write_all(req_buf_1)?;

// shutdown server.
handle.try_handle()?.stop(true);

// sleep a bit to make sure event are processed.
tokio::time::sleep(Duration::from_millis(100)).await;

// write second half of the request.
stream.write_all(req_buf2)?;

// read response.
loop {
let n = stream.read(&mut buf)?;

if buf[..n].ends_with(b"GET Response") {
break;
}
}

handle.await?;

Ok(())
}

async fn handle(req: Request<RequestExt<h1::RequestBody>>) -> Result<Response<ResponseBody>, Error> {
match (req.method(), req.uri().path()) {
(&Method::GET, "/") | (&Method::HEAD, "/") => Ok(Response::new(Bytes::from("GET Response").into())),
Expand Down

0 comments on commit f0bd80e

Please sign in to comment.