Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(client): allow to set a specific sni hostname per request #1171

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ impl Client {

let (conn, version) = self
.connector
.call((connect.hostname(), conn))
.call((connect.sni_hostname(), conn))
.timeout(timer.as_mut())
.await
.map_err(|_| TimeoutError::TlsHandshake)??;
Expand Down
13 changes: 11 additions & 2 deletions client/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::{fmt, iter, net::SocketAddr};

use std::collections::vec_deque::{self, VecDeque};

use crate::uri::Uri;
use crate::{request::SniHostname, uri::Uri};

pub trait Address {
/// Get hostname part.
Expand Down Expand Up @@ -80,17 +80,19 @@ pub struct Connect<'a> {
pub(crate) uri: Uri<'a>,
pub(crate) port: u16,
pub(crate) addr: Addrs,
pub(crate) sni_hostname: Option<&'a SniHostname>,
}

impl<'a> Connect<'a> {
/// Create `Connect` instance by splitting the string by ':' and convert the second part to u16
pub fn new(uri: Uri<'a>) -> Self {
pub fn new(uri: Uri<'a>, sni_hostname: Option<&'a SniHostname>) -> Self {
let (_, port) = parse_host(uri.hostname());

Self {
uri,
port: port.unwrap_or(0),
addr: Addrs::None,
sni_hostname,
}
}

Expand All @@ -112,6 +114,13 @@ impl<'a> Connect<'a> {
self.uri.hostname()
}

/// Get sni hostname.
pub fn sni_hostname(&self) -> &str {
self.sni_hostname
.map(|s| s.0.as_str())
.unwrap_or_else(|| self.hostname())
}

/// Get request port.
pub fn port(&self) -> u16 {
Address::port(&self.uri).unwrap_or(self.port)
Expand Down
22 changes: 16 additions & 6 deletions client/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::hash::{Hash, Hasher};

use xitca_http::http::uri::{Authority, PathAndQuery};

use super::{tls::TlsStream, uri::Uri};
use super::{connect::Connect, request::SniHostname, tls::TlsStream, uri::Uri};

/// exclusive connection for http1 and in certain case they can be upgraded to [ConnectionShared]
pub type ConnectionExclusive = TlsStream;
Expand Down Expand Up @@ -34,10 +34,17 @@ impl From<crate::h3::Connection> for ConnectionShared {
#[doc(hidden)]
#[derive(PartialEq, Eq, Debug, Clone, Hash)]
pub enum ConnectionKey {
Regular(Authority),
Regular(AuthorityWithSni),
Unix(AuthorityWithPath),
}

#[doc(hidden)]
#[derive(PartialEq, Eq, Debug, Clone, Hash)]
pub struct AuthorityWithSni {
authority: Authority,
sni: Option<SniHostname>,
}

#[doc(hidden)]
#[derive(Eq, Debug, Clone)]
pub struct AuthorityWithPath {
Expand All @@ -58,10 +65,13 @@ impl Hash for AuthorityWithPath {
}
}

impl From<&Uri<'_>> for ConnectionKey {
fn from(uri: &Uri<'_>) -> Self {
match *uri {
Uri::Tcp(uri) | Uri::Tls(uri) => ConnectionKey::Regular(uri.authority().unwrap().clone()),
impl From<&Connect<'_>> for ConnectionKey {
fn from(connect: &Connect<'_>) -> Self {
match connect.uri {
Uri::Tcp(uri) | Uri::Tls(uri) => ConnectionKey::Regular(AuthorityWithSni {
authority: uri.authority().unwrap().clone(),
sni: connect.sni_hostname.cloned(),
}),
Uri::Unix(uri) => ConnectionKey::Unix(AuthorityWithPath {
authority: uri.authority().unwrap().clone(),
path_and_query: uri.path_and_query().unwrap().clone(),
Expand Down
13 changes: 13 additions & 0 deletions client/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use core::{marker::PhantomData, time::Duration};

use futures_core::Stream;
use xitca_unsafe_collection::bytes::BytesStr;

use crate::{
body::{BodyError, BoxBody, Once},
Expand Down Expand Up @@ -210,6 +211,15 @@ impl<'a, M> RequestBuilder<'a, M> {
self
}

/// Set SNI hostname of this request.
#[inline]
pub fn sni_hostname(mut self, sni_hostname: &str) -> Self {
self.req
.extensions_mut()
.insert(SniHostname(BytesStr::from(sni_hostname)));
self
}

fn map_body<B, E>(mut self, b: B) -> RequestBuilder<'a, M>
where
B: Stream<Item = Result<Bytes, E>> + Send + 'static,
Expand All @@ -219,3 +229,6 @@ impl<'a, M> RequestBuilder<'a, M> {
self
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SniHostname(pub(crate) BytesStr);
11 changes: 6 additions & 5 deletions client/src/service/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ pub(crate) fn base_service() -> HttpService {
#[allow(unused_mut)]
let mut version = req.version();

let mut connect = Connect::new(uri);
let sni_hostname = req.extensions().get();
let mut connect = Connect::new(uri, sni_hostname);

let _date = client.date_service.handle();

loop {
match version {
Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await {
Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect).await {
shared::AcquireOutput::Conn(mut _conn) => {
let mut _timer = Box::pin(tokio::time::sleep(timeout));
*req.version_mut() = version;
Expand Down Expand Up @@ -94,7 +95,7 @@ pub(crate) fn base_service() -> HttpService {
if let Ok(Ok(conn)) = crate::h3::proto::connect(
&client.h3_client,
connect.addrs(),
connect.hostname(),
connect.sni_hostname(),
)
.timeout(timer.as_mut())
.await
Expand Down Expand Up @@ -136,7 +137,7 @@ pub(crate) fn base_service() -> HttpService {

#[cfg(feature = "http1")]
{
client.exclusive_pool.try_add(&connect.uri, conn);
client.exclusive_pool.try_add(&connect, conn);
// downgrade request version to what alpn protocol suggested from make_exclusive.
version = alpn_version;
}
Expand All @@ -151,7 +152,7 @@ pub(crate) fn base_service() -> HttpService {
_ => unreachable!("outer match didn't handle version correctly."),
},
},
version => match client.exclusive_pool.acquire(&connect.uri).await {
version => match client.exclusive_pool.acquire(&connect).await {
exclusive::AcquireOutput::Conn(mut _conn) => {
*req.version_mut() = version;

Expand Down
Loading