Skip to content

Commit

Permalink
feat(client): allow to specify a specific remote adress on client req…
Browse files Browse the repository at this point in the history
…uest
  • Loading branch information
joelwurtz committed Feb 13, 2025
1 parent cec4ef4 commit 4c8fc50
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 13 deletions.
12 changes: 7 additions & 5 deletions client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,13 @@ impl Client {
connect: &mut Connect<'_>,
timer: &mut Pin<Box<Sleep>>,
) -> Result<ConnectionExclusive, Error> {
self.resolver
.call(connect)
.timeout(timer.as_mut())
.await
.map_err(|_| TimeoutError::Resolve)??;
if !connect.is_resolved() {
self.resolver
.call(connect)
.timeout(timer.as_mut())
.await
.map_err(|_| TimeoutError::Resolve)??;
}

timer
.as_mut()
Expand Down
16 changes: 14 additions & 2 deletions client/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,16 @@ pub struct Connect<'a> {

impl<'a> Connect<'a> {
/// Create `Connect` instance by splitting the string by ':' and convert the second part to u16
pub fn new(uri: Uri<'a>) -> Self {
pub fn new(uri: Uri<'a>, address: Option<SocketAddr>) -> Self {
let (_, port) = parse_host(uri.hostname());

Self {
uri,
port: port.unwrap_or(0),
addr: Addrs::None,
addr: match address {
Some(address) => Addrs::One(address),
None => Addrs::None,
},
}
}

Expand Down Expand Up @@ -125,6 +128,15 @@ impl<'a> Connect<'a> {
Addrs::Multi(ref addrs) => AddrsIter::Multi(addrs.iter()),
}
}

/// Check if address is resolved.
pub fn is_resolved(&self) -> bool {
match self.addr {
Addrs::None => false,
Addrs::One(_) => true,
Addrs::Multi(ref addrs) => !addrs.is_empty(),
}
}
}

impl fmt::Display for Connect<'_> {
Expand Down
17 changes: 15 additions & 2 deletions client/src/middleware/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,28 @@ where
type Error = Error;

async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result<Self::Response, Self::Error> {
let ServiceRequest { req, client, timeout } = req;
let ServiceRequest {
req,
address,
client,
timeout,
} = req;
let mut headers = req.headers().clone();
let mut method = req.method().clone();
let mut uri = req.uri().clone();
let ext = req.extensions().clone();
let mut count = 0;

loop {
let mut res = self.service.call(ServiceRequest { req, client, timeout }).await?;
let mut res = self
.service
.call(ServiceRequest {
req,
address,
client,
timeout,
})
.await?;

if count == MAX {
return Ok(res);
Expand Down
14 changes: 13 additions & 1 deletion client/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::{marker::PhantomData, time::Duration};
use core::{marker::PhantomData, net::SocketAddr, time::Duration};

use futures_core::Stream;

Expand All @@ -21,6 +21,7 @@ pub struct RequestBuilder<'a, M = marker::Http> {
pub(crate) req: http::Request<BoxBody>,
pub(crate) err: Vec<Error>,
client: &'a Client,
address: Option<SocketAddr>,
timeout: Duration,
_marker: PhantomData<M>,
}
Expand Down Expand Up @@ -102,6 +103,7 @@ impl<'a, M> RequestBuilder<'a, M> {
Self {
req: req.map(BoxBody::new),
err: Vec::new(),
address: None,
client,
timeout: client.timeout_config.request_timeout,
_marker: PhantomData,
Expand All @@ -112,6 +114,7 @@ impl<'a, M> RequestBuilder<'a, M> {
RequestBuilder {
req: self.req,
err: self.err,
address: self.address,
client: self.client,
timeout: self.timeout,
_marker: PhantomData,
Expand All @@ -123,6 +126,7 @@ impl<'a, M> RequestBuilder<'a, M> {
let Self {
mut req,
err,
address,
client,
timeout,
..
Expand All @@ -136,6 +140,7 @@ impl<'a, M> RequestBuilder<'a, M> {
.service
.call(ServiceRequest {
req: &mut req,
address,
client,
timeout,
})
Expand Down Expand Up @@ -210,6 +215,13 @@ impl<'a, M> RequestBuilder<'a, M> {
self
}

/// Set specific address for this request.
#[inline]
pub fn address(mut self, addr: SocketAddr) -> Self {
self.address = Some(addr);
self
}

fn map_body<B, E>(mut self, b: B) -> RequestBuilder<'a, M>
where
B: Stream<Item = Result<Bytes, E>> + Send + 'static,
Expand Down
13 changes: 10 additions & 3 deletions client/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::{future::Future, pin::Pin, time::Duration};
use core::{future::Future, net::SocketAddr, pin::Pin, time::Duration};

use crate::{
body::BoxBody,
Expand Down Expand Up @@ -66,6 +66,7 @@ where
/// [RequestBuilder]: crate::request::RequestBuilder
pub struct ServiceRequest<'r, 'c> {
pub req: &'r mut Request<BoxBody>,
pub address: Option<SocketAddr>,
pub client: &'c Client,
pub timeout: Duration,
}
Expand All @@ -85,7 +86,12 @@ pub(crate) fn base_service() -> HttpService {
#[cfg(any(feature = "http1", feature = "http2", feature = "http3"))]
use crate::{error::TimeoutError, timeout::Timeout};

let ServiceRequest { req, client, timeout } = req;
let ServiceRequest {
req,
address,
client,
timeout,
} = req;

let uri = Uri::try_parse(req.uri())?;

Expand All @@ -94,7 +100,7 @@ pub(crate) fn base_service() -> HttpService {
#[allow(unused_mut)]
let mut version = req.version();

let mut connect = Connect::new(uri);
let mut connect = Connect::new(uri, address);

let _date = client.date_service.handle();

Expand Down Expand Up @@ -307,6 +313,7 @@ mod test {
req,
client: &self.0,
timeout: self.0.timeout_config.request_timeout,
address: None,
}
}
}
Expand Down

0 comments on commit 4c8fc50

Please sign in to comment.