Skip to content

Commit

Permalink
feat(client): add middleware to retry http connection if closed by se…
Browse files Browse the repository at this point in the history
…rver
  • Loading branch information
joelwurtz committed Feb 2, 2025
1 parent 4d339df commit 14ba245
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 28 deletions.
6 changes: 3 additions & 3 deletions client/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ impl ResponseBody {
}
}

pub(crate) fn can_destroy_on_drop(&mut self) -> bool {
pub(crate) fn can_destroy_on_drop(&self) -> bool {
#[cfg(feature = "http1")]
if let Self::H1(ref mut body) = *self {
return body.conn_mut().is_destroy_on_drop();
if let Self::H1(ref body) = *self {
return body.conn().is_destroy_on_drop();
}

false
Expand Down
2 changes: 2 additions & 0 deletions client/src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! middleware offer extended functionality to http client.
mod redirect;
mod retry_closed_connection;

#[cfg(feature = "compress")]
mod decompress;
Expand All @@ -9,3 +10,4 @@ mod decompress;
pub use decompress::Decompress;

pub use redirect::FollowRedirect;
pub use retry_closed_connection::RetryClosedConnection;
80 changes: 80 additions & 0 deletions client/src/middleware/retry_closed_connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::io;

use crate::{
error::Error,
response::Response,
service::{Service, ServiceRequest},
};

/// middleware for retrying closed connection
pub struct RetryClosedConnection<S> {
service: S,
}

impl<S> RetryClosedConnection<S> {
pub fn new(service: S) -> Self {
Self { service }
}
}

impl<'r, 'c, S> Service<ServiceRequest<'r, 'c>> for RetryClosedConnection<S>
where
S: for<'r2, 'c2> Service<ServiceRequest<'r2, 'c2>, Response = Response, Error = Error> + Send + Sync,
{
type Response = Response;
type Error = Error;

async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result<Self::Response, Self::Error> {
let ServiceRequest { req, client, timeout } = req;
let headers = req.headers().clone();
let method = req.method().clone();
let uri = req.uri().clone();

loop {
let res = self.service.call(ServiceRequest { req, client, timeout }).await;

match res {
Err(Error::Io(err)) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::Io(err));
}
}
#[cfg(feature = "http1")]
Err(Error::H1(crate::h1::Error::Io(err))) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::H1(crate::h1::Error::Io(err)));
}
}
#[cfg(feature = "http2")]
Err(Error::H2(crate::h2::Error::H2(err))) => {
if !err.is_go_away() {
return Err(Error::H2(crate::h2::Error::H2(err)));
}

let reason = err.reason().unwrap();

if reason != h2::Reason::NO_ERROR {
return Err(Error::H2(crate::h2::Error::H2(err)));
}
}
#[cfg(feature = "http2")]
Err(Error::H2(crate::h2::Error::Io(err))) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::H2(crate::h2::Error::Io(err)));
}
}
#[cfg(feature = "http3")]
Err(Error::H3(crate::h3::Error::Io(err))) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::H3(crate::h3::Error::Io(err)));
}
}
res => return res,
}

*req.uri_mut() = uri.clone();
*req.method_mut() = method.clone();
*req.headers_mut() = headers.clone();
}
}
}
4 changes: 2 additions & 2 deletions client/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ impl<const PAYLOAD_LIMIT: usize> Response<PAYLOAD_LIMIT> {
/// Public API for test purpose.
///
/// Used for testing server implementation to make sure it follows spec.
pub fn can_close_connection(&mut self) -> bool {
self.res.body_mut().can_destroy_on_drop()
pub fn can_close_connection(&self) -> bool {
self.res.body().can_destroy_on_drop()
}
}

Expand Down
46 changes: 41 additions & 5 deletions test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use std::{
error, fmt, fs,
future::Future,
io,
net::SocketAddr,
net::TcpListener,
net::{SocketAddr, TcpListener, ToSocketAddrs},
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -36,9 +35,18 @@ where
T::Response: ReadyService + Service<Req>,
Req: TryFrom<NetStream> + 'static,
{
let lst = TcpListener::bind("127.0.0.1:0")?;
test_server_with_addr(service, "127.0.0.1:0")
}

let addr = lst.local_addr()?;
pub fn test_server_with_addr<T, Req, A>(service: T, addr: A) -> Result<TestServerHandle, Error>
where
T: Service + Send + Sync + 'static,
T::Response: ReadyService + Service<Req>,
Req: TryFrom<NetStream> + 'static,
A: ToSocketAddrs,
{
let lst = TcpListener::bind(addr)?;
let local_addr = lst.local_addr()?;

let handle = Builder::new()
.worker_threads(1)
Expand All @@ -47,7 +55,35 @@ where
.listen::<_, _, _, Req>("test_server", lst, service)
.build();

Ok(TestServerHandle { addr, handle })
Ok(TestServerHandle {
addr: local_addr,
handle,
})
}

/// A specialized http/1 server on top of [test_server]
pub fn test_h1_server_with_addr<T, B, E, A>(service: T, addr: A) -> Result<TestServerHandle, Error>
where
T: Service + Send + Sync + 'static,
T::Response: ReadyService + Service<Request<RequestExt<h1::RequestBody>>, Response = HResponse<B>> + 'static,
<T::Response as Service<Request<RequestExt<h1::RequestBody>>>>::Error: fmt::Debug,
T::Error: error::Error + 'static,
B: Stream<Item = Result<Bytes, E>> + 'static,
E: fmt::Debug + 'static,
A: ToSocketAddrs,
{
#[cfg(not(feature = "io-uring"))]
{
test_server_with_addr::<_, (TcpStream, SocketAddr), A>(service.enclosed(HttpServiceBuilder::h1()), addr)
}

#[cfg(feature = "io-uring")]
{
test_server_with_addr::<_, (xitca_io::net::io_uring::TcpStream, SocketAddr), A>(
service.enclosed(HttpServiceBuilder::h1().io_uring()),
addr,
)
}
}

/// A specialized http/1 server on top of [test_server]
Expand Down
91 changes: 80 additions & 11 deletions test/tests/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
time::Duration,
};

use xitca_client::Client;
use xitca_client::{middleware::RetryClosedConnection, Client};
use xitca_http::{
body::{BoxBody, ResponseBody},
bytes::{Bytes, BytesMut},
Expand All @@ -16,7 +16,7 @@ use xitca_http::{
},
};
use xitca_service::fn_service;
use xitca_test::{test_h1_server, Error};
use xitca_test::{test_h1_server, test_h1_server_with_addr, Error};

#[tokio::test]
async fn h1_get() -> Result<(), Error> {
Expand All @@ -27,7 +27,7 @@ async fn h1_get() -> Result<(), Error> {
let c = Client::new();

for _ in 0..3 {
let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -49,14 +49,14 @@ async fn h1_get_without_body_reading() -> Result<(), Error> {

let c = Client::builder().set_pool_capacity(1).finish();

let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());

// drop the response body without reading it.
drop(res);

let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -68,6 +68,75 @@ async fn h1_get_without_body_reading() -> Result<(), Error> {
Ok(())
}

#[tokio::test]
async fn h1_get_connection_closed_by_server() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;
let ip_port = handle.ip_port_string();

let server_url = format!("http://{}/", ip_port);

let c = Client::builder()
.middleware(RetryClosedConnection::new)
.set_pool_capacity(1)
.finish();

let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(false);
handle.await?;

let mut handle = test_h1_server_with_addr(fn_service(crate::handle), ip_port)?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;

assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(true);
handle.await?;

Ok(())
}

#[tokio::test]
async fn h1_get_connection_closed_by_server_then_refused() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;
let ip_port = handle.ip_port_string();

let server_url = format!("http://{}/", ip_port);

let c = Client::builder()
.middleware(RetryClosedConnection::new)
.set_pool_capacity(1)
.finish();

let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(false);
handle.await?;

let res = c.get(&server_url).version(Version::HTTP_11).send().await;

assert!(res.is_err());
assert!(matches!(res, Err(xitca_client::error::Error::Io(_))));

let Err(xitca_client::error::Error::Io(err)) = res else {
panic!("not expected");
};
assert_eq!(err.kind(), std::io::ErrorKind::ConnectionRefused);

Ok(())
}

#[tokio::test]
async fn h1_head() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;
Expand All @@ -77,7 +146,7 @@ async fn h1_head() -> Result<(), Error> {
let c = Client::new();

for _ in 0..3 {
let mut res = c.head(&server_url).version(Version::HTTP_11).send().await?;
let res = c.head(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand Down Expand Up @@ -106,7 +175,7 @@ async fn h1_post() -> Result<(), Error> {
}
let body_len = body.len();

let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.limit::<{ 12 * 1024 }>().string().await?;
Expand Down Expand Up @@ -134,7 +203,7 @@ async fn h1_drop_body_read() -> Result<(), Error> {
body.extend_from_slice(b"Hello,World!");
}

let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(res.can_close_connection());
}
Expand All @@ -160,7 +229,7 @@ async fn h1_partial_body_read() -> Result<(), Error> {
body.extend_from_slice(b"Hello,World!");
}

let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(res.can_close_connection());
}
Expand All @@ -180,7 +249,7 @@ async fn h1_close_connection() -> Result<(), Error> {

let c = Client::new();

let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(res.can_close_connection());

Expand Down Expand Up @@ -217,7 +286,7 @@ async fn h1_request_too_large() -> Result<(), Error> {
req.headers_mut()
.insert("large-header", HeaderValue::try_from(body).unwrap());

let mut res = req.send().await?;
let res = req.send().await?;
assert_eq!(res.status().as_u16(), 431);
assert!(res.can_close_connection());

Expand Down
8 changes: 4 additions & 4 deletions test/tests/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn h2_get() -> Result<(), Error> {
let c = Client::new();

for _ in 0..3 {
let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?;
let res = c.get(&server_url).version(Version::HTTP_2).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -46,7 +46,7 @@ async fn h2_no_host_header() -> Result<(), Error> {
let mut req = c.get(&server_url).version(Version::HTTP_2);
req.headers_mut().insert(header::HOST, "localhost".parse().unwrap());

let mut res = req.send().await?;
let res = req.send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -73,7 +73,7 @@ async fn h2_post() -> Result<(), Error> {
for _ in 0..1024 * 1024 {
body.extend_from_slice(b"Hello,World!");
}
let mut res = c.post(&server_url).version(Version::HTTP_2).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_2).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let _ = res.body().await;
Expand Down Expand Up @@ -142,7 +142,7 @@ async fn h2_keepalive() -> Result<(), Error> {
.block_on(async move {
let c = Client::new();

let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?;
let res = c.get(&server_url).version(Version::HTTP_2).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand Down
Loading

0 comments on commit 14ba245

Please sign in to comment.