From 2fc08186fe7c960f34138daa281df9f740e6eaf5 Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Tue, 7 Jan 2025 17:11:09 +0100 Subject: [PATCH] feat(client): add support for upgrade request --- client/src/client.rs | 59 ++++++++++++++++++++++ client/src/error.rs | 8 +++ client/src/http_tunnel.rs | 6 +-- client/src/lib.rs | 1 + client/src/request.rs | 2 +- client/src/upgrade.rs | 102 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 174 insertions(+), 4 deletions(-) create mode 100644 client/src/upgrade.rs diff --git a/client/src/client.rs b/client/src/client.rs index 14e1ce040..a1847baba 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -20,6 +20,7 @@ use crate::{ service::HttpService, timeout::{Timeout, TimeoutConfig}, tls::connector::Connector, + upgrade::UpgradeRequest, uri::Uri, }; @@ -184,6 +185,64 @@ impl Client { self.request_builder(url, Method::CONNECT).mutate_marker() } + #[cfg(feature = "http1")] + /// Start a new upgrade request. + /// + /// # Example + /// ```rust + /// use xitca_client::{Client, bytes::Bytes, http::Method}; + /// + /// async fn _main() -> Result<(), xitca_client::error::Error> { + /// // construct a new client and initialize connect request. + /// let client = Client::new(); + /// let mut upgrade_response = client + /// .upgrade("http://localhost:8080", Method::GET) + /// .protocol("protocol1, protocol2") + /// .send().await? + /// ; + /// + /// if let Some(upgrade) = upgrade_response.headers.get(xitca_client::http::header::UPGRADE) { + /// // check which protocol it was upgraded to + /// } + /// + /// // upgrade_response is a response that contains the http request head and tunnel connection. + /// + /// // import Stream trait and call it's method on tunnel to receive bytes. + /// use futures::StreamExt; + /// if let Some(Ok(_)) = upgrade_response.tunnel().next().await { + /// // received bytes data. + /// } + /// + /// // import Sink trait and call it's method on tunnel to send bytes data. + /// use futures::SinkExt; + /// // send bytes data. + /// upgrade_response.tunnel().send(b"996").await?; + /// + /// // tunnel support split sending/receiving task into different parts to enable concurrent bytes data handling. + /// let (_head, mut tunnel) = upgrade_response.into_parts(); + /// let (mut write, mut read) = tunnel.split(); + /// + /// // read part can operate with Stream trait implement. + /// if let Some(Ok(_)) = read.next().await { + /// // received bytes data. + /// } + /// + /// // write part can operate with Sink trait implement. + /// write.send(b"996").await?; + /// + /// Ok(()) + /// # } + /// ``` + pub fn upgrade(&self, url: U, method: Method) -> UpgradeRequest<'_> + where + uri::Uri: TryFrom, + Error: From<>::Error>, + { + self.request_builder(url, method) + .version(Version::HTTP_11) + .mutate_marker() + } + #[cfg(all(feature = "websocket", feature = "http1"))] /// Start a new websocket request. /// diff --git a/client/src/error.rs b/client/src/error.rs index 93c94994b..78f7603f0 100644 --- a/client/src/error.rs +++ b/client/src/error.rs @@ -23,6 +23,7 @@ pub enum Error { #[cfg(any(feature = "rustls", feature = "rustls-ring-crypto"))] Rustls(_rustls::RustlsError), Parse(ParseError), + InvalidHeaderValue(http::header::InvalidHeaderValue), } impl fmt::Display for Error { @@ -64,6 +65,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: http::header::InvalidHeaderValue) -> Self { + Self::InvalidHeaderValue(e) + } +} + /// a collection of multiple errors chained together. #[derive(Debug)] pub struct ErrorMultiple(Vec); @@ -210,6 +217,7 @@ mod _openssl { #[cfg(any(feature = "rustls", feature = "rustls-ring-crypto"))] pub(crate) use _rustls::*; +use xitca_http::http; #[cfg(any(feature = "rustls", feature = "rustls-ring-crypto"))] mod _rustls { diff --git a/client/src/http_tunnel.rs b/client/src/http_tunnel.rs index ca081ee14..8b2ec3c1d 100644 --- a/client/src/http_tunnel.rs +++ b/client/src/http_tunnel.rs @@ -51,13 +51,13 @@ impl HttpTunnelRequest<'_> { } pub struct HttpTunnel { - body: ResponseBody, - io: TunnelIo, + pub(crate) body: ResponseBody, + pub(crate) io: TunnelIo, } // io type to bridge AsyncIo trait and h2 body's poll based read/write apis. #[derive(Default)] -struct TunnelIo { +pub(crate) struct TunnelIo { write_buf: BytesMut, #[cfg(feature = "http2")] adaptor: TunnelIoAdaptor, diff --git a/client/src/lib.rs b/client/src/lib.rs index f38794ef3..e6ae010b2 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -43,6 +43,7 @@ mod service; mod timeout; mod tls; mod tunnel; +mod upgrade; mod uri; #[cfg(feature = "http1")] diff --git a/client/src/request.rs b/client/src/request.rs index c45df7d6c..9490e56d9 100644 --- a/client/src/request.rs +++ b/client/src/request.rs @@ -19,7 +19,7 @@ use crate::{ /// builder type for [http::Request] with extended functionalities. pub struct RequestBuilder<'a, M = marker::Http> { pub(crate) req: http::Request, - err: Vec, + pub(crate) err: Vec, client: &'a Client, timeout: Duration, _marker: PhantomData, diff --git a/client/src/upgrade.rs b/client/src/upgrade.rs new file mode 100644 index 000000000..acb1a19a0 --- /dev/null +++ b/client/src/upgrade.rs @@ -0,0 +1,102 @@ +//! http upgrade handling. + +use super::{ + error::{Error, ErrorResponse}, + http::{ + header::{self, HeaderValue}, + response::Parts, + StatusCode, + }, + http_tunnel::HttpTunnel, + request::RequestBuilder, + tunnel::Tunnel, +}; +use std::ops::{Deref, DerefMut}; + +pub type UpgradeRequest<'a> = RequestBuilder<'a, marker::Upgrade>; +pub type UpgradeRequestWithProtocol<'a> = RequestBuilder<'a, marker::UpgradeWithProtocol>; + +mod marker { + pub struct Upgrade; + pub struct UpgradeWithProtocol; +} + +pub struct UpgradeResponse { + pub parts: Parts, + pub tunnel: Tunnel, +} + +impl<'a> UpgradeRequest<'a> { + pub fn protocol(mut self, proto: V) -> UpgradeRequestWithProtocol<'a> + where + V: TryInto, + >::Error: Into, + { + match proto.try_into() { + Ok(v) => { + self.req.headers_mut().insert(header::UPGRADE, v); + } + Err(e) => { + self.push_error(e.into()); + } + }; + + self.mutate_marker() + } +} + +impl UpgradeRequestWithProtocol<'_> { + /// Send the request and wait for response asynchronously. + pub async fn send(mut self) -> Result { + self.headers_mut() + .insert(header::CONNECTION, HeaderValue::from_static("upgrade")); + + let res = self._send().await?; + + let status = res.status(); + let expect_status = StatusCode::SWITCHING_PROTOCOLS; + if status != expect_status { + return Err(Error::from(ErrorResponse { + expect_status, + status, + description: "upgrade tunnel can't be established", + })); + } + + let (parts, body) = res.into_inner().into_parts(); + + Ok(UpgradeResponse { + parts, + tunnel: Tunnel::new(HttpTunnel { + body, + io: Default::default(), + }), + }) + } +} + +impl UpgradeResponse { + #[inline] + pub fn into_parts(self) -> (Parts, Tunnel) { + (self.parts, self.tunnel) + } + + #[inline] + pub fn tunnel(&mut self) -> &mut Tunnel { + &mut self.tunnel + } +} + +impl Deref for UpgradeResponse { + type Target = Parts; + + fn deref(&self) -> &Self::Target { + &self.parts + } +} + +impl DerefMut for UpgradeResponse { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.parts + } +}