From 69fd6bbf35389121aaf2b1db6e540919631dfbe5 Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Thu, 26 Dec 2024 13:54:25 +0100 Subject: [PATCH] feat(client): allow to set a specific sni hostname per request --- client/src/client.rs | 2 +- client/src/connect.rs | 13 +++++++++++-- client/src/connection.rs | 22 ++++++++++++++++------ client/src/request.rs | 13 +++++++++++++ client/src/service/http.rs | 11 ++++++----- 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/client/src/client.rs b/client/src/client.rs index bf1724cf..aab82640 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -328,7 +328,7 @@ impl Client { let (conn, version) = self .connector - .call((connect.hostname(), conn)) + .call((connect.sni_hostname(), conn)) .timeout(timer.as_mut()) .await .map_err(|_| TimeoutError::TlsHandshake)??; diff --git a/client/src/connect.rs b/client/src/connect.rs index fc0bda86..10123e28 100644 --- a/client/src/connect.rs +++ b/client/src/connect.rs @@ -2,7 +2,7 @@ use core::{fmt, iter, net::SocketAddr}; use std::collections::vec_deque::{self, VecDeque}; -use crate::uri::Uri; +use crate::{request::SniHostname, uri::Uri}; pub trait Address { /// Get hostname part. @@ -80,17 +80,19 @@ pub struct Connect<'a> { pub(crate) uri: Uri<'a>, pub(crate) port: u16, pub(crate) addr: Addrs, + pub(crate) sni_hostname: Option<&'a SniHostname>, } impl<'a> Connect<'a> { /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 - pub fn new(uri: Uri<'a>) -> Self { + pub fn new(uri: Uri<'a>, sni_hostname: Option<&'a SniHostname>) -> Self { let (_, port) = parse_host(uri.hostname()); Self { uri, port: port.unwrap_or(0), addr: Addrs::None, + sni_hostname, } } @@ -112,6 +114,13 @@ impl<'a> Connect<'a> { self.uri.hostname() } + /// Get sni hostname. + pub fn sni_hostname(&self) -> &str { + self.sni_hostname + .map(|s| s.0.as_str()) + .unwrap_or_else(|| self.hostname()) + } + /// Get request port. pub fn port(&self) -> u16 { Address::port(&self.uri).unwrap_or(self.port) diff --git a/client/src/connection.rs b/client/src/connection.rs index 321dfa14..41735b3a 100644 --- a/client/src/connection.rs +++ b/client/src/connection.rs @@ -2,7 +2,7 @@ use core::hash::{Hash, Hasher}; use xitca_http::http::uri::{Authority, PathAndQuery}; -use super::{tls::TlsStream, uri::Uri}; +use super::{connect::Connect, request::SniHostname, tls::TlsStream, uri::Uri}; /// exclusive connection for http1 and in certain case they can be upgraded to [ConnectionShared] pub type ConnectionExclusive = TlsStream; @@ -34,10 +34,17 @@ impl From for ConnectionShared { #[doc(hidden)] #[derive(PartialEq, Eq, Debug, Clone, Hash)] pub enum ConnectionKey { - Regular(Authority), + Regular(AuthorityWithSni), Unix(AuthorityWithPath), } +#[doc(hidden)] +#[derive(PartialEq, Eq, Debug, Clone, Hash)] +pub struct AuthorityWithSni { + authority: Authority, + sni: Option, +} + #[doc(hidden)] #[derive(Eq, Debug, Clone)] pub struct AuthorityWithPath { @@ -58,10 +65,13 @@ impl Hash for AuthorityWithPath { } } -impl From<&Uri<'_>> for ConnectionKey { - fn from(uri: &Uri<'_>) -> Self { - match *uri { - Uri::Tcp(uri) | Uri::Tls(uri) => ConnectionKey::Regular(uri.authority().unwrap().clone()), +impl From<&Connect<'_>> for ConnectionKey { + fn from(connect: &Connect<'_>) -> Self { + match connect.uri { + Uri::Tcp(uri) | Uri::Tls(uri) => ConnectionKey::Regular(AuthorityWithSni { + authority: uri.authority().unwrap().clone(), + sni: connect.sni_hostname.cloned(), + }), Uri::Unix(uri) => ConnectionKey::Unix(AuthorityWithPath { authority: uri.authority().unwrap().clone(), path_and_query: uri.path_and_query().unwrap().clone(), diff --git a/client/src/request.rs b/client/src/request.rs index 9490e56d..b74833cc 100644 --- a/client/src/request.rs +++ b/client/src/request.rs @@ -1,6 +1,7 @@ use core::{marker::PhantomData, time::Duration}; use futures_core::Stream; +use xitca_unsafe_collection::bytes::BytesStr; use crate::{ body::{BodyError, BoxBody, Once}, @@ -210,6 +211,15 @@ impl<'a, M> RequestBuilder<'a, M> { self } + /// Set SNI hostname of this request. + #[inline] + pub fn sni_hostname(mut self, sni_hostname: &str) -> Self { + self.req + .extensions_mut() + .insert(SniHostname(BytesStr::from(sni_hostname))); + self + } + fn map_body(mut self, b: B) -> RequestBuilder<'a, M> where B: Stream> + Send + 'static, @@ -219,3 +229,6 @@ impl<'a, M> RequestBuilder<'a, M> { self } } + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SniHostname(pub(crate) BytesStr); diff --git a/client/src/service/http.rs b/client/src/service/http.rs index 07653b97..6c468913 100644 --- a/client/src/service/http.rs +++ b/client/src/service/http.rs @@ -33,13 +33,14 @@ pub(crate) fn base_service() -> HttpService { #[allow(unused_mut)] let mut version = req.version(); - let mut connect = Connect::new(uri); + let sni_hostname = req.extensions().get(); + let mut connect = Connect::new(uri, sni_hostname); let _date = client.date_service.handle(); loop { match version { - Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await { + Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect).await { shared::AcquireOutput::Conn(mut _conn) => { let mut _timer = Box::pin(tokio::time::sleep(timeout)); *req.version_mut() = version; @@ -94,7 +95,7 @@ pub(crate) fn base_service() -> HttpService { if let Ok(Ok(conn)) = crate::h3::proto::connect( &client.h3_client, connect.addrs(), - connect.hostname(), + connect.sni_hostname(), ) .timeout(timer.as_mut()) .await @@ -136,7 +137,7 @@ pub(crate) fn base_service() -> HttpService { #[cfg(feature = "http1")] { - client.exclusive_pool.try_add(&connect.uri, conn); + client.exclusive_pool.try_add(&connect, conn); // downgrade request version to what alpn protocol suggested from make_exclusive. version = alpn_version; } @@ -151,7 +152,7 @@ pub(crate) fn base_service() -> HttpService { _ => unreachable!("outer match didn't handle version correctly."), }, }, - version => match client.exclusive_pool.acquire(&connect.uri).await { + version => match client.exclusive_pool.acquire(&connect).await { exclusive::AcquireOutput::Conn(mut _conn) => { *req.version_mut() = version;