Skip to content

Commit

Permalink
feat(http): set request authority and scheme for h1
Browse files Browse the repository at this point in the history
  • Loading branch information
joelwurtz committed Feb 5, 2025
1 parent 7645226 commit d05e3b5
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 29 deletions.
4 changes: 2 additions & 2 deletions client/src/h1/proto/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<const HEADER_LIMIT: usize> DerefMut for Context<'_, '_, HEADER_LIMIT> {
}

impl<'c, 'd, const HEADER_LIMIT: usize> Context<'c, 'd, HEADER_LIMIT> {
pub(crate) fn new(date: &'c DateTimeHandle<'d>) -> Self {
Self(context::Context::new(date))
pub(crate) fn new(date: &'c DateTimeHandle<'d>, is_tls: bool) -> Self {
Self(context::Context::new(date, is_tls))
}
}
6 changes: 5 additions & 1 deletion client/src/h1/proto/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ where
}
}

let is_tls = req
.uri()
.scheme()
.is_some_and(|scheme| scheme == "https" || scheme == "wss");
// TODO: make const generic params configurable.
let mut ctx = Context::<128>::new(&date);
let mut ctx = Context::<128>::new(&date, is_tls);

// encode request head and return transfer encoding for request body
let encoder = ctx.encode_head(&mut buf, req)?;
Expand Down
2 changes: 1 addition & 1 deletion http/benches/h1_decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl DateTime for DT {
fn decode(c: &mut Criterion) {
let dt = DT::dummy_date_time();

let mut ctx = Context::<_, 8>::new(&dt);
let mut ctx = Context::<_, 8>::new(&dt, false);

let req = b"\
GET /HFQR/xitca-web HTTP/1.1\r\n\
Expand Down
7 changes: 5 additions & 2 deletions http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub(crate) async fn run<
config: HttpServiceConfig<HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
) -> Result<(), Error<S::Error, BE>>
where
S: Service<ExtRequest<ReqB>, Response = Response<ResB>>,
Expand All @@ -77,7 +78,7 @@ where
EitherBuf::Right(WriteBuf::<WRITE_BUF_LIMIT>::default())
};

Dispatcher::new(io, addr, timer, config, service, date, write_buf)
Dispatcher::new(io, addr, timer, config, service, date, write_buf, is_tls)
.run()
.await
}
Expand Down Expand Up @@ -166,6 +167,7 @@ where
W: H1BufWrite,
D: DateTime,
{
#[allow(clippy::too_many_arguments)]
fn new<const WRITE_BUF_LIMIT: usize>(
io: &'a mut St,
addr: SocketAddr,
Expand All @@ -174,11 +176,12 @@ where
service: &'a S,
date: &'a D,
write_buf: W,
is_tls: bool,
) -> Self {
Self {
io: BufferedIo::new(io, write_buf),
timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
ctx: Context::with_addr(addr, date),
ctx: Context::with_addr(addr, date, is_tls),
service,
_phantom: PhantomData,
}
Expand Down
3 changes: 2 additions & 1 deletion http/src/h1/dispatcher_uring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ where
config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
) -> Self {
Self {
io: Rc::new(io),
timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
ctx: Context::<_, H_LIMIT>::with_addr(addr, date),
ctx: Context::<_, H_LIMIT>::with_addr(addr, date, is_tls),
service,
read_buf: BufOwned::new(),
write_buf: BufOwned::new(),
Expand Down
8 changes: 5 additions & 3 deletions http/src/h1/proto/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct Context<'a, D, const HEADER_LIMIT: usize> {
// http extensions reused by next request.
exts: Extensions,
date: &'a D,
pub(crate) is_tls: bool,
}

// A set of state for current request that are used after request's ownership is passed
Expand Down Expand Up @@ -49,21 +50,22 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> {
///
/// [DateTime]: crate::date::DateTime
#[inline]
pub fn new(date: &'a D) -> Self {
Self::with_addr(crate::unspecified_socket_addr(), date)
pub fn new(date: &'a D, is_tls: bool) -> Self {
Self::with_addr(crate::unspecified_socket_addr(), date, is_tls)
}

/// Context is constructed with [SocketAddr] and reference of certain type that impl [DateTime] trait.
///
/// [DateTime]: crate::date::DateTime
#[inline]
pub fn with_addr(addr: SocketAddr, date: &'a D) -> Self {
pub fn with_addr(addr: SocketAddr, date: &'a D, is_tls: bool) -> Self {
Self {
addr,
state: ContextState::new(),
header: None,
exts: Extensions::new(),
date,
is_tls,
}
}

Expand Down
55 changes: 52 additions & 3 deletions http/src/h1/proto/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::mem::MaybeUninit;

use http::uri::{Authority, Scheme};
use httparse::Status;

use crate::{
Expand Down Expand Up @@ -71,7 +72,7 @@ impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
// split the headers from buffer.
let slice = buf.split_to(len).freeze();

let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?;
let mut uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?.into_parts();

// pop a cached headermap or construct a new one.
let mut headers = self.take_headers();
Expand All @@ -87,6 +88,25 @@ impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {

let extensions = self.take_extensions();

// Try to set authority from host header if not present in request path
if uri.authority.is_none() {
// @TODO if it's a tls connection we could set the sni server name as authority instead
if let Some(host) = headers.get(http::header::HOST) {
uri.authority = Some(Authority::try_from(host.as_bytes())?);
}
}

// If authority is set, this will set the correct scheme depending on the tls acceptor used in the service.
if uri.authority.is_some() && uri.scheme.is_none() {
uri.scheme = if self.is_tls {
Some(Scheme::HTTPS)
} else {
Some(Scheme::HTTP)
};
}

let uri = Uri::from_parts(uri)?;

*req.method_mut() = method;
*req.version_mut() = version;
*req.uri_mut() = uri;
Expand Down Expand Up @@ -173,7 +193,7 @@ mod test {

#[test]
fn connection_multiple_value() {
let mut ctx = Context::<_, 4>::new(&());
let mut ctx = Context::<_, 4>::new(&(), false);

let head = b"\
GET / HTTP/1.1\r\n\
Expand Down Expand Up @@ -211,7 +231,7 @@ mod test {

#[test]
fn transfer_encoding() {
let mut ctx = Context::<_, 4>::new(&());
let mut ctx = Context::<_, 4>::new(&(), false);

let head = b"\
GET / HTTP/1.1\r\n\
Expand Down Expand Up @@ -311,4 +331,33 @@ mod test {
"transfer coding is not decoded to chunked"
);
}

#[test]
fn test_host_with_scheme() {
let mut ctx = Context::<_, 4>::new(&(), true);

let head = b"\
GET / HTTP/1.1\r\n\
Host: example.com\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);

let (req, _) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();

assert_eq!(req.uri().scheme(), Some(&Scheme::HTTPS));
assert_eq!(req.uri().authority(), Some(&Authority::from_static("example.com")));
assert_eq!(req.headers().get(http::header::HOST).unwrap(), "example.com");

let head = b"\
GET / HTTP/1.1\r\n\
\r\n\
";
let mut buf = BytesMut::from(&head[..]);

let (req, _) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap();

assert_eq!(req.uri().scheme(), None);
assert_eq!(req.uri().authority(), None);
}
}
4 changes: 2 additions & 2 deletions http/src/h1/proto/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ mod test {

#[test]
fn append_header() {
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler);
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false);

let mut res = Response::new(BoxBody::new(Once::new(Bytes::new())));

Expand Down Expand Up @@ -287,7 +287,7 @@ mod test {

#[test]
fn multi_set_cookie() {
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler);
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false);

let mut res = Response::new(BoxBody::new(Once::new(Bytes::new())));

Expand Down
6 changes: 6 additions & 0 deletions http/src/h1/proto/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ impl From<http::uri::InvalidUri> for ProtoError {
}
}

impl From<http::uri::InvalidUriParts> for ProtoError {
fn from(_: http::uri::InvalidUriParts) -> Self {
Self::Uri
}
}

impl From<http::status::InvalidStatusCode> for ProtoError {
fn from(_: http::status::InvalidStatusCode) -> Self {
Self::Status
Expand Down
35 changes: 26 additions & 9 deletions http/src/h1/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
error::{HttpServiceError, TimeoutError},
http::{Request, RequestExt, Response},
service::HttpService,
tls::IsTls,
util::timer::Timeout,
};

Expand All @@ -21,7 +22,7 @@ impl<St, S, B, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, co
Service<(St, SocketAddr)> for H1Service<St, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<St>,
A: Service<St> + IsTls,
St: AsyncIo,
A::Response: AsyncIo,
B: Stream<Item = Result<Bytes, BE>>,
Expand All @@ -41,9 +42,17 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get())
.await
.map_err(Into::into)
super::dispatcher::run(
&mut io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(Into::into)
}
}

Expand Down Expand Up @@ -94,7 +103,7 @@ impl<S, B, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const
Service<(TcpStream, SocketAddr)> for H1UringService<S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<TcpStream>,
A: Service<TcpStream> + IsTls,
A::Response: AsyncBufRead + AsyncBufWrite + 'static,
B: Stream<Item = Result<Bytes, BE>>,
HttpServiceError<S::Error, BE>: From<A::Error>,
Expand All @@ -113,10 +122,18 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher_uring::Dispatcher::new(io, addr, timer, self.config, &self.service, self.date.get())
.run()
.await
.map_err(Into::into)
super::dispatcher_uring::Dispatcher::new(
io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.run()
.await
.map_err(Into::into)
}
}

Expand Down
5 changes: 4 additions & 1 deletion http/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use super::{
date::{DateTime, DateTimeService},
error::{HttpServiceError, TimeoutError},
http::{Request, RequestExt, Response},
tls::IsTls,
util::timer::{KeepAlive, Timeout},
version::AsVersion,
};
Expand Down Expand Up @@ -73,7 +74,7 @@ impl<S, ResB, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, con
for HttpService<ServerStream, S, RequestBody, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<ResB>>,
A: Service<TcpStream>,
A: Service<TcpStream> + IsTls,
A::Response: AsyncIo + AsVersion,
HttpServiceError<S::Error, BE>: From<A::Error>,
S::Error: fmt::Debug,
Expand Down Expand Up @@ -120,6 +121,7 @@ where
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(From::from),
Expand Down Expand Up @@ -168,6 +170,7 @@ where
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(From::from)
Expand Down
13 changes: 13 additions & 0 deletions http/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ pub use error::TlsError;

use xitca_service::Service;

/// A trait to check if an acceptor will create a Tls stream.
pub trait IsTls {
fn is_tls(&self) -> bool {
true
}
}

/// A NoOp Tls Acceptor pass through input Stream type.
#[derive(Copy, Clone)]
pub struct NoOpTlsAcceptorBuilder;
Expand All @@ -42,3 +49,9 @@ impl<St> Service<St> for NoOpTlsAcceptorService {
Ok(io)
}
}

impl IsTls for NoOpTlsAcceptorService {
fn is_tls(&self) -> bool {
false
}
}
4 changes: 3 additions & 1 deletion http/src/tls/native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use xitca_service::Service;

use crate::{http::Version, version::AsVersion};

use super::error::TlsError;
use super::{error::TlsError, IsTls};

/// A wrapper type for [TlsStream](native_tls::TlsStream).
///
Expand Down Expand Up @@ -93,6 +93,8 @@ impl<St: AsyncIo> Service<St> for TlsAcceptorService {
}
}

impl IsTls for TlsAcceptorService {}

impl<S: AsyncIo> AsyncIo for TlsStream<S> {
#[inline]
fn ready(&mut self, interest: Interest) -> impl Future<Output = io::Result<Ready>> + Send {
Expand Down
Loading

0 comments on commit d05e3b5

Please sign in to comment.