diff --git a/Cargo.toml b/Cargo.toml index dc5d3347..94ddf612 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "io", "postgres", "postgres-codegen", + "reverse-proxy", "router", "server", "service", diff --git a/client/src/body.rs b/client/src/body.rs index ce9850e5..107edd0d 100644 --- a/client/src/body.rs +++ b/client/src/body.rs @@ -49,7 +49,7 @@ impl fmt::Debug for ResponseBody<'_> { } impl ResponseBody<'_> { - pub(crate) fn into_owned(self) -> ResponseBody<'static> { + pub fn into_owned(self) -> ResponseBody<'static> { match self { #[cfg(feature = "http1")] Self::H1(body) => ResponseBody::H1Owned(body.map_conn(Into::into)), @@ -101,7 +101,7 @@ impl Stream for ResponseBody<'_> { } /// type erased stream body. -pub struct BoxBody(Pin> + Send + 'static>>); +pub struct BoxBody(Pin> + 'static>>); impl Default for BoxBody { fn default() -> Self { @@ -113,7 +113,7 @@ impl BoxBody { #[inline] pub fn new(body: B) -> Self where - B: Stream> + Send + 'static, + B: Stream> + 'static, E: Into, { Self(Box::pin(BoxStreamMapErr { body })) diff --git a/client/src/client.rs b/client/src/client.rs index 14e1ce04..a3b1b606 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -77,7 +77,7 @@ impl Client { #[inline] pub fn request(&self, req: http::Request) -> RequestBuilder<'_> where - B: Stream> + Send + 'static, + B: Stream> + 'static, BodyError: From, { RequestBuilder::new(req, self) @@ -391,8 +391,6 @@ mod test { #[tokio::test] async fn connect_google() { let res = Client::builder() - .middleware(crate::middleware::FollowRedirect::new) - .middleware(crate::middleware::Decompress::new) .openssl() .finish() .get("https://www.google.com/") diff --git a/client/src/request.rs b/client/src/request.rs index d73ef6a6..e2cda070 100644 --- a/client/src/request.rs +++ b/client/src/request.rs @@ -81,7 +81,7 @@ impl<'a> RequestBuilder<'a, marker::Http> { #[inline] pub fn stream(self, body: B) -> Self where - B: Stream> + Send + 'static, + B: Stream> + 'static, 
E: Into, { self.map_body(body) @@ -96,7 +96,7 @@ impl<'a> RequestBuilder<'a, marker::Http> { impl<'a, M> RequestBuilder<'a, M> { pub(crate) fn new(req: http::Request, client: &'a Client) -> Self where - B: Stream> + Send + 'static, + B: Stream> + 'static, E: Into, { Self { @@ -212,7 +212,7 @@ impl<'a, M> RequestBuilder<'a, M> { fn map_body(mut self, b: B) -> RequestBuilder<'a, M> where - B: Stream> + Send + 'static, + B: Stream> + 'static, E: Into, { self.req = self.req.map(|_| BoxBody::new(b)); diff --git a/client/src/response.rs b/client/src/response.rs index be22e079..90f05c9f 100644 --- a/client/src/response.rs +++ b/client/src/response.rs @@ -9,7 +9,7 @@ use futures_core::stream::Stream; use tokio::time::{Instant, Sleep}; use tracing::debug; use xitca_http::{bytes::BytesMut, http}; - +use xitca_http::http::response::Parts; use crate::{ body::ResponseBody, error::{Error, TimeoutError}, @@ -85,6 +85,11 @@ impl<'a, const PAYLOAD_LIMIT: usize> Response<'a, PAYLOAD_LIMIT> { timeout: dur, } } + /// Split the response into its head (`Parts`) and its body. Response is consumed. + #[inline] + pub fn into_parts(self) -> (Parts, ResponseBody<'a>) { + self.res.into_parts() + } /// Collect response body as String. Response is consumed. #[inline] diff --git a/client/src/service.rs b/client/src/service.rs index 70465bc4..47094aae 100644 --- a/client/src/service.rs +++ b/client/src/service.rs @@ -11,14 +11,14 @@ use crate::{ uri::Uri, }; -type BoxFuture<'f, T, E> = Pin> + Send + 'f>>; +type BoxFuture<'f, T, E> = Pin> + 'f>>; /// trait for composable http services. Used for middleware,resolver and tls connector.
pub trait Service { type Response; type Error; - fn call(&self, req: Req) -> impl Future> + Send; + fn call(&self, req: Req) -> impl Future>; } pub trait ServiceDyn { @@ -48,8 +48,7 @@ where impl Service for Box where - Req: Send, - I: ServiceDyn + ?Sized + Send + Sync, + I: ServiceDyn + ?Sized, { type Response = I::Response; type Error = I::Error; @@ -72,7 +71,7 @@ pub struct ServiceRequest<'r, 'c> { /// type alias for object safe wrapper of type implement [Service] trait. pub type HttpService = - Box ServiceDyn, Response = Response<'c>, Error = Error> + Send + Sync>; + Box ServiceDyn, Response = Response<'c>, Error = Error>>; pub(crate) fn base_service() -> HttpService { struct HttpService; diff --git a/reverse-proxy/Cargo.toml b/reverse-proxy/Cargo.toml new file mode 100644 index 00000000..aa7f0403 --- /dev/null +++ b/reverse-proxy/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "redirectionio-xitca-proxy" +version = "0.1.0" +authors = ["Joel Wurtz "] +edition = "2021" +description = "Xitca web reverse HTTP and Websocket proxy" + +[dependencies] +bytes = "1.9.0" +lazy_static = "1.5.0" +xitca-client = { version = "0.1.0", features = ["dangerous", "openssl"] } +xitca-http = "0.7.0" +xitca-web = "0.7.0" + +[dev-dependencies] +tokio = { version = "1.42.0", features = ["full"] } + +[[example]] +name = "http_proxy" diff --git a/reverse-proxy/examples/http_proxy.rs b/reverse-proxy/examples/http_proxy.rs new file mode 100644 index 00000000..46ac0794 --- /dev/null +++ b/reverse-proxy/examples/http_proxy.rs @@ -0,0 +1,21 @@ +use redirectionio_xitca_proxy::{HttpPeer, Proxy}; +use std::net::ToSocketAddrs; +use xitca_web::App; + +#[tokio::main] +async fn main() -> std::io::Result<()> { + let address = "github.com:443" + .to_socket_addrs() + .expect("error getting addresses") + .next() + .expect("cannot get address"); + + App::new() + .at("", Proxy::new(HttpPeer::new(address, "github.com:443").tls(true))) + .serve() + .bind("127.0.0.1:8080")? 
+ .run() + .await?; + + Ok(()) +} diff --git a/reverse-proxy/src/error.rs b/reverse-proxy/src/error.rs new file mode 100644 index 00000000..516f32b1 --- /dev/null +++ b/reverse-proxy/src/error.rs @@ -0,0 +1,31 @@ +use crate::forwarder::ForwardError; +use std::error::Error; +use std::fmt; +use xitca_web::error::Error as XitcaError; + +#[derive(Debug)] +pub enum ProxyError { + CannotReadRequestBody(XitcaError), + ForwardError(ForwardError), + NoPeer, +} + +impl fmt::Display for ProxyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::CannotReadRequestBody(e) => write!(f, "error when reading request body: {}", e), + Self::ForwardError(e) => write!(f, "error when forwarding request: {}", e), + Self::NoPeer => f.write_str("no peer found"), + } + } +} + +impl Error for ProxyError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::CannotReadRequestBody(err) => Some(err), + Self::ForwardError(err) => Some(err), + Self::NoPeer => None, + } + } +} diff --git a/reverse-proxy/src/forwarder/error.rs b/reverse-proxy/src/forwarder/error.rs new file mode 100644 index 00000000..9856573f --- /dev/null +++ b/reverse-proxy/src/forwarder/error.rs @@ -0,0 +1,32 @@ +use std::{error::Error, fmt}; + +/// Errors that can result from using a connector service. 
+#[derive(Debug)] +pub enum ForwardError { + /// Failed to build a request from origin + UriError(xitca_web::http::Error), +} + +impl fmt::Display for ForwardError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UriError(_) => f.write_str("could not build request from origin"), + } + } +} + +impl Error for ForwardError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::UriError(err) => Some(err), + } + } +} + +impl ForwardError { + pub fn into_error_status(self) -> xitca_web::error::ErrorStatus { + match self { + Self::UriError(_) => xitca_web::error::ErrorStatus::bad_request(), + } + } +} diff --git a/reverse-proxy/src/forwarder/forward_header.rs b/reverse-proxy/src/forwarder/forward_header.rs new file mode 100644 index 00000000..25265710 --- /dev/null +++ b/reverse-proxy/src/forwarder/forward_header.rs @@ -0,0 +1,314 @@ +use xitca_web::http::uri::Scheme; +use xitca_web::http::{header, HeaderMap, HeaderName}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ForwardedFor { + by: String, + strategy: ForwardedHeaderStrategy, + override_proto: Option, +} + +pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); +pub const X_FORWARDED_PROTO: HeaderName = HeaderName::from_static("x-forwarded-proto"); +pub const X_FORWARDED_HOST: HeaderName = HeaderName::from_static("x-forwarded-host"); +pub const X_FORWARDED_BY: HeaderName = HeaderName::from_static("x-forwarded-by"); + +impl Default for ForwardedFor { + fn default() -> Self { + Self { + by: "actix-proxy".to_string(), + strategy: ForwardedHeaderStrategy::Auto, + override_proto: None, + } + } +} + +impl ForwardedFor { + pub fn new_none() -> Self { + Self { + by: "".to_string(), + strategy: ForwardedHeaderStrategy::None, + override_proto: None, + } + } + + pub fn new_auto(by: &str, override_proto: Option) -> Self { + Self { + by: by.to_string(), + strategy: ForwardedHeaderStrategy::Auto, + override_proto, + } + } + + pub fn 
new_legacy(by: &str, override_proto: Option) -> Self { + Self { + by: by.to_string(), + strategy: ForwardedHeaderStrategy::Legacy, + override_proto, + } + } + + pub fn new_rfc7239(by: &str, override_proto: Option) -> Self { + Self { + by: by.to_string(), + strategy: ForwardedHeaderStrategy::RFC7239, + override_proto, + } + } + + pub(crate) fn apply(&self, headers: &mut HeaderMap, client_ip: &str, host: String, proto: Scheme) { + let proto = self.override_proto.clone().unwrap_or(proto); + + self.strategy.apply(headers, self.by.as_str(), client_ip, host, proto); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ForwardedHeaderStrategy { + None, + Auto, + Legacy, + RFC7239, +} + +impl ForwardedHeaderStrategy { + pub fn apply(&self, headers: &mut HeaderMap, by: &str, client_ip: &str, host: String, proto: Scheme) { + match self { + Self::None => (), + Self::Legacy => Self::apply_legacy(headers, by, client_ip, host, proto), + Self::RFC7239 => Self::apply_forwarded(headers, by, client_ip, host, proto), + Self::Auto => { + if headers.contains_key(X_FORWARDED_FOR) { + if headers.contains_key(X_FORWARDED_PROTO) + || headers.contains_key(X_FORWARDED_HOST) + || headers.contains_key("x-forwarded-by") + { + // Cannot transition use legacy headers + Self::apply_legacy(headers, by, client_ip, host, proto); + + return; + } + + // Transition to RFC7239 + // Transform x forwarded for to Forwarded + let mut forwarded_for_value = "".to_string(); + + for value in headers.get_all(X_FORWARDED_FOR).iter() { + if forwarded_for_value.is_empty() { + forwarded_for_value = format!("for={}", value.to_str().unwrap()); + } else { + forwarded_for_value = format!("{}, for={}", forwarded_for_value, value.to_str().unwrap()); + } + } + + headers.insert(header::FORWARDED, forwarded_for_value.parse().unwrap()); + headers.remove(X_FORWARDED_FOR); + } + + // Apply forwarded + Self::apply_forwarded(headers, by, client_ip, host, proto); + } + } + } + + fn apply_legacy(headers: &mut HeaderMap, by: &str, 
client_ip: &str, host: String, proto: Scheme) { + let by_value = match headers.get("x-forwarded-by") { + Some(value) => format!("{}, {}", value.to_str().unwrap(), by), + None => by.to_string(), + }; + + let mut forwarded_for_existing_value = headers + .get_all(X_FORWARDED_FOR) + .iter() + .map(|value| value.to_str().unwrap().to_string()) + .collect::>() + .join(", "); + + if !forwarded_for_existing_value.is_empty() { + forwarded_for_existing_value.push_str(", "); + forwarded_for_existing_value.push_str(client_ip); + + headers.insert(X_FORWARDED_FOR, forwarded_for_existing_value.parse().unwrap()); + } else { + headers.insert(X_FORWARDED_FOR, client_ip.parse().unwrap()); + } + + headers.insert(X_FORWARDED_BY, by_value.parse().unwrap()); + + if !headers.contains_key(X_FORWARDED_PROTO) { + let proto_value = match headers.get(X_FORWARDED_PROTO) { + Some(value) => format!("{}, {}", value.to_str().unwrap(), proto.as_str()), + None => proto.to_string(), + }; + + headers.insert(X_FORWARDED_PROTO, proto_value.parse().unwrap()); + } + + if !headers.contains_key(X_FORWARDED_HOST) { + headers.insert(X_FORWARDED_HOST, host.parse().unwrap()); + } + } + + fn apply_forwarded(headers: &mut HeaderMap, by: &str, client_ip: &str, host: String, proto: Scheme) { + let forwarded_for_value = format!("for={client_ip};by={by};host={host};proto={proto}"); + + if headers.contains_key(header::FORWARDED) { + let forwarded_existing_value = headers + .get_all(header::FORWARDED) + .iter() + .map(|value| value.to_str().unwrap()) + .collect::>() + .join(", "); + let forwarded_value = format!("{forwarded_existing_value}, {forwarded_for_value}"); + + headers.insert(header::FORWARDED, forwarded_value.parse().unwrap()); + } else { + headers.insert(header::FORWARDED, forwarded_for_value.parse().unwrap()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_forward_legacy() { + let mut headers = HeaderMap::new(); + + headers.insert(X_FORWARDED_FOR, "127.0.0.1".parse().unwrap()); + 
headers.insert(X_FORWARDED_PROTO, "https".parse().unwrap()); + headers.insert(X_FORWARDED_HOST, "localhost".parse().unwrap()); + + let strategy = ForwardedHeaderStrategy::Legacy; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!(headers.get(X_FORWARDED_FOR).unwrap().to_str().unwrap(), "127.0.0.1, 192.168.0.1"); + assert_eq!(headers.get(X_FORWARDED_PROTO).unwrap().to_str().unwrap(), "https"); + assert_eq!(headers.get(X_FORWARDED_HOST).unwrap().to_str().unwrap(), "localhost"); + } + + #[test] + fn test_forward_none() { + let mut headers = HeaderMap::new(); + + let strategy = ForwardedHeaderStrategy::None; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!(headers.get(header::FORWARDED), None); + assert_eq!(headers.get(X_FORWARDED_FOR), None); + assert_eq!(headers.get(X_FORWARDED_PROTO), None); + assert_eq!(headers.get(X_FORWARDED_HOST), None); + } + + #[test] + fn test_forward_rfc7239() { + let mut headers = HeaderMap::new(); + + let strategy = ForwardedHeaderStrategy::RFC7239; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!( + headers.get(header::FORWARDED).unwrap().to_str().unwrap(), + "for=192.168.0.1;by=by;host=test.com;proto=http" + ); + } + + #[test] + fn test_forward_legacy_multiple() { + let mut headers = HeaderMap::new(); + + headers.append(X_FORWARDED_FOR, "127.0.0.1".parse().unwrap()); + headers.append(X_FORWARDED_FOR, "127.0.0.2".parse().unwrap()); + headers.append(X_FORWARDED_PROTO, "https".parse().unwrap()); + headers.append(X_FORWARDED_HOST, "localhost".parse().unwrap()); + + let strategy = ForwardedHeaderStrategy::Legacy; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!( + headers.get(X_FORWARDED_FOR).unwrap().to_str().unwrap(), + "127.0.0.1, 127.0.0.2, 192.168.0.1" + ); + 
assert_eq!(headers.get(X_FORWARDED_PROTO).unwrap().to_str().unwrap(), "https"); + assert_eq!(headers.get(X_FORWARDED_HOST).unwrap().to_str().unwrap(), "localhost"); + } + + #[test] + fn test_forward_legacy_no_proto() { + let mut headers = HeaderMap::new(); + + headers.insert(X_FORWARDED_FOR, "127.0.0.1".parse().unwrap()); + headers.insert(X_FORWARDED_HOST, "localhost".parse().unwrap()); + + let strategy = ForwardedHeaderStrategy::Legacy; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!(headers.get(X_FORWARDED_FOR).unwrap().to_str().unwrap(), "127.0.0.1, 192.168.0.1"); + assert_eq!(headers.get(X_FORWARDED_PROTO).unwrap().to_str().unwrap(), "http"); + assert_eq!(headers.get(X_FORWARDED_HOST).unwrap().to_str().unwrap(), "localhost"); + assert!(!headers.contains_key(header::FORWARDED)); + } + + #[test] + fn test_forward_legacy_no_host() { + let mut headers = HeaderMap::new(); + + headers.insert(X_FORWARDED_FOR, "127.0.0.1".parse().unwrap()); + + let strategy = ForwardedHeaderStrategy::Legacy; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!(headers.get(X_FORWARDED_FOR).unwrap().to_str().unwrap(), "127.0.0.1, 192.168.0.1"); + assert_eq!(headers.get(X_FORWARDED_PROTO).unwrap().to_str().unwrap(), "http"); + assert_eq!(headers.get(X_FORWARDED_HOST).unwrap().to_str().unwrap(), "test.com"); + assert!(!headers.contains_key(header::FORWARDED)); + } + + #[test] + fn test_forward_auto_forwarded() { + let mut headers = HeaderMap::new(); + + let strategy = ForwardedHeaderStrategy::Auto; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert!(headers.contains_key(header::FORWARDED)); + assert_eq!( + headers.get(header::FORWARDED).unwrap().to_str().unwrap(), + "for=192.168.0.1;by=by;host=test.com;proto=http" + ); + } + + #[test] + fn test_forward_auto_forwarded_transition() { + let mut headers = HeaderMap::new(); + + 
headers.insert(X_FORWARDED_FOR, "127.0.0.1".parse().unwrap()); + + let strategy = ForwardedHeaderStrategy::Auto; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert!(headers.contains_key(header::FORWARDED)); + assert_eq!( + headers.get(header::FORWARDED).unwrap().to_str().unwrap(), + "for=127.0.0.1, for=192.168.0.1;by=by;host=test.com;proto=http" + ); + } + + #[test] + fn test_forward_auto_legacy() { + let mut headers = HeaderMap::new(); + + headers.insert(X_FORWARDED_FOR, "127.0.0.1".parse().unwrap()); + headers.insert(X_FORWARDED_PROTO, "https".parse().unwrap()); + headers.insert(X_FORWARDED_HOST, "localhost".parse().unwrap()); + + let strategy = ForwardedHeaderStrategy::Auto; + strategy.apply(&mut headers, "by", "192.168.0.1", "test.com".to_string(), Scheme::HTTP); + + assert_eq!(headers.get(X_FORWARDED_FOR).unwrap().to_str().unwrap(), "127.0.0.1, 192.168.0.1"); + assert_eq!(headers.get(X_FORWARDED_PROTO).unwrap().to_str().unwrap(), "https"); + assert_eq!(headers.get(X_FORWARDED_HOST).unwrap().to_str().unwrap(), "localhost"); + assert!(!headers.contains_key(header::FORWARDED)); + } +} diff --git a/reverse-proxy/src/forwarder/mod.rs b/reverse-proxy/src/forwarder/mod.rs new file mode 100644 index 00000000..1a5632df --- /dev/null +++ b/reverse-proxy/src/forwarder/mod.rs @@ -0,0 +1,5 @@ +// mod client; +mod error; + +pub(crate) mod forward_header; +pub use crate::forwarder::error::ForwardError; diff --git a/reverse-proxy/src/lib.rs b/reverse-proxy/src/lib.rs new file mode 100644 index 00000000..a14d23ab --- /dev/null +++ b/reverse-proxy/src/lib.rs @@ -0,0 +1,15 @@ +#[macro_use] +extern crate lazy_static; + +mod error; +mod forwarder; +mod peer; +mod peer_resolver; +mod proxy; +mod service; + +pub use forwarder::forward_header::ForwardedFor; +pub use peer::HttpPeer; +pub use peer_resolver::HttpPeerResolve; +pub use proxy::Proxy; +pub use error::ProxyError; diff --git a/reverse-proxy/src/peer.rs 
b/reverse-proxy/src/peer.rs new file mode 100644 index 00000000..c9d2f10a --- /dev/null +++ b/reverse-proxy/src/peer.rs @@ -0,0 +1,120 @@ +use crate::forwarder::forward_header::ForwardedFor; +use std::collections::HashSet; +use std::net::SocketAddr; +use std::time::Duration; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ViaHeader { + pub(crate) add_in_request: bool, + pub(crate) add_in_response: bool, + pub(crate) name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HttpPeer { + pub(crate) forward_for: ForwardedFor, + pub(crate) address: SocketAddr, + pub(crate) sni_host: String, + pub(crate) request_host: String, + pub(crate) via: Option, + pub(crate) allow_invalid_certificates: bool, + pub(crate) supported_encodings: Option>, + pub(crate) force_close: bool, + pub(crate) tls: bool, + pub(crate) request_body_size_limit: usize, + pub(crate) timeout: Option, +} + +impl HttpPeer { + /// Create a new `HttpPeer` with the given address and request host. + pub fn new(address: SocketAddr, request_host: &str) -> Self { + Self { + forward_for: ForwardedFor::default(), + address, + sni_host: request_host.to_string(), + request_host: request_host.to_string(), + via: None, + allow_invalid_certificates: false, + supported_encodings: None, + force_close: false, + tls: address.port() == 443, + request_body_size_limit: 1024 * 1024 * 16, // 16MB + timeout: None, + } + } + + /// Set how the `Forwarded` or `X-Forwarded-*` headers are added to the request forwarded to the peer. + pub fn forward_for(mut self, forward_for: ForwardedFor) -> Self { + self.forward_for = forward_for; + self + } + + /// Set the SNI host to the given value. This is the host that will be used for the TLS handshake. + pub fn sni_host(mut self, sni_host: String) -> Self { + self.sni_host = sni_host; + self + } + + /// Set the `Host` header to the given value. This is the host that will be used for the HTTP request.
+ pub fn request_host(mut self, request_host: String) -> Self { + self.request_host = request_host; + self + } + + /// Set the `Via` header to the given value. This header can be added to the request forwarded to the peer and/or to the response returned to the client. + pub fn via(mut self, via: &str, in_request: bool, in_response: bool) -> Self { + self.via = Some(ViaHeader { + add_in_request: in_request, + add_in_response: in_response, + name: via.to_string(), + }); + + self + } + + /// Set whether invalid certificates should be allowed. + pub fn allow_invalid_certificates(mut self, allow_invalid_certificates: bool) -> Self { + self.allow_invalid_certificates = allow_invalid_certificates; + self + } + + /// A set of supported encodings that this server can handle. This may be useful if you need to + /// update the response body and you cannot handle some encodings. + pub fn supported_encodings(mut self, supported_encodings: HashSet) -> Self { + self.supported_encodings = Some(supported_encodings); + self + } + + /// Set whether the connection should be closed after the request has been forwarded to the peer. + pub fn force_close(mut self, force_close: bool) -> Self { + self.force_close = force_close; + self + } + + /// Set whether the connection should be encrypted using TLS when connecting to the peer. + pub fn tls(mut self, tls: bool) -> Self { + self.tls = tls; + self + } + + /// Set the timeout for the request to the peer. + /// If the timeout is reached, the request will be aborted and an error will be returned to the client. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Set the maximum size of the request body that can be stored in memory. + /// + /// This limit should not be used in most cases as the request body is streamed and not stored in memory.
+ /// However, there are some cases where the request body needs to be stored in memory: think + /// of an old client sending HTTP/1.0 requests with a body. In this case the body will be stored + /// in memory before being forwarded to the peer. + /// + /// This limit prevents the server from storing too much data in memory. + /// The default value is 16MB. + pub fn request_body_size_limit(mut self, request_body_size_limit: usize) -> Self { + self.request_body_size_limit = request_body_size_limit; + self + } +} diff --git a/reverse-proxy/src/peer_resolver.rs b/reverse-proxy/src/peer_resolver.rs new file mode 100644 index 00000000..dcd094e4 --- /dev/null +++ b/reverse-proxy/src/peer_resolver.rs @@ -0,0 +1,20 @@ +use crate::peer::HttpPeer; +use std::convert::Infallible; +use std::rc::Rc; +use xitca_client::Service; +use xitca_web::http::{request::Parts, Request}; + +pub type HttpPeerResolve = dyn Service, Error = Infallible, Response = Option>>; + +#[derive(Clone)] +pub(crate) enum HttpPeerResolver { + Static(Rc), +} + +impl HttpPeerResolver { + pub async fn resolve(&self, _: &Parts) -> Option> { + match self { + Self::Static(peer) => Some(peer.clone()), + } + } +} diff --git a/reverse-proxy/src/proxy.rs b/reverse-proxy/src/proxy.rs new file mode 100644 index 00000000..0ebe4c2e --- /dev/null +++ b/reverse-proxy/src/proxy.rs @@ -0,0 +1,45 @@ +use crate::peer_resolver::HttpPeerResolver; +use crate::service::ProxyService; +use crate::HttpPeer; +use std::convert::Infallible; +use std::rc::Rc; +use xitca_http::util::service::router::{PathGen, RouteGen}; +use xitca_web::service::Service; + +pub struct Proxy { + peer: HttpPeer, +} + +impl PathGen for Proxy { + fn path_gen(&mut self, prefix: &str) -> String { + let mut prefix = String::from(prefix); + prefix.push_str("*p"); + prefix + } +} + +impl RouteGen for Proxy { + type Route = R; + + fn route_gen(route: R) -> Self::Route { + route + } +} + +impl Proxy { + pub fn new(peer: HttpPeer) -> Self { + Self { peer } + } +} +
+impl Service for Proxy {
+    type Response = ProxyService;
+    type Error = Infallible;
+
+    async fn call(&self, _: ()) -> Result<Self::Response, Self::Error> {
+        Ok(ProxyService {
+            peer_resolver: Rc::new(HttpPeerResolver::Static(Rc::new(self.peer.clone()))),
+            client: Rc::new(xitca_client::ClientBuilder::new().openssl().finish()),
+        })
+    }
+}
diff --git a/reverse-proxy/src/service.rs b/reverse-proxy/src/service.rs
new file mode 100644
index 00000000..258d4a92
--- /dev/null
+++ b/reverse-proxy/src/service.rs
@@ -0,0 +1,254 @@
+use crate::forwarder::{ForwardError};
+use crate::peer_resolver::HttpPeerResolver;
+use crate::HttpPeer;
+use bytes::Bytes;
+use std::collections::HashSet;
+use std::rc::Rc;
+use std::str::FromStr;
+use xitca_client::{Client};
+use xitca_http::body::BoxBody;
+use xitca_http::BodyError;
+use xitca_http::http::header::AsHeaderName;
+use xitca_http::http::{StatusCode, Version};
+use xitca_http::util::service::RouterError;
+use xitca_web::error::ErrorStatus;
+use xitca_web::http::uri::Scheme;
+use xitca_web::http::{header, HeaderMap, HeaderName, Request, Uri, WebResponse};
+use xitca_web::service::Service;
+use xitca_web::{BodyStream, WebContext};
+
+lazy_static! {
+    /// Hop-by-hop headers: never copied to the upstream request nor (unless the response is a
+    /// 101) to the downstream response.
+    static ref HOP_HEADERS: HashSet<HeaderName> = [
+        header::CONNECTION,
+        HeaderName::from_static("proxy-connection"),
+        HeaderName::from_static("keep-alive"),
+        header::PROXY_AUTHENTICATE,
+        header::PROXY_AUTHORIZATION,
+        header::TE,
+        header::TRAILER,
+        header::TRANSFER_ENCODING,
+        header::UPGRADE,
+    ]
+    .into_iter()
+    .collect();
+}
+
+/// Request handler returned by `Proxy`: resolves a peer, forwards the request head to it and
+/// relays the upstream response.
+#[derive(Clone)]
+pub struct ProxyService {
+    pub(crate) peer_resolver: Rc<HttpPeerResolver>,
+    pub(crate) client: Rc<Client>,
+}
+
+impl<'r, C, B> Service<WebContext<'r, C, B>> for ProxyService
+where
+    C: 'static,
+    B: BodyStream + Default + 'static,
+{
+    type Response = WebResponse<BoxBody>;
+    type Error = RouterError<ErrorStatus>;
+
+    /// Forward the request head to the resolved peer and relay the peer's response.
+    ///
+    /// The request body is not forwarded yet (see the TODO below). Errors are reported as an
+    /// `ErrorStatus`: 503 when no peer is resolved or the upstream request fails, the status of
+    /// `ForwardError` when the upstream URI cannot be built, and 500 when the configured
+    /// request host is not a valid header value.
+    async fn call(&self, mut ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
+        let downstream_request = ctx.take_request();
+        let (downstream_request_head, _downstream_body) = downstream_request.into_parts();
+
+        // A resolver is allowed to return no peer: answer with an error status instead of
+        // panicking on `unwrap`.
+        let Some(peer) = self.peer_resolver.resolve(&downstream_request_head).await else {
+            return Err(status_error(StatusCode::SERVICE_UNAVAILABLE));
+        };
+
+        // @TODO forwarding `_downstream_body` does not work when the body is empty: its size is
+        // unknown, which induces chunked encoding, and eof is then called on it, which panics.
+        // let mut upstream_request = Request::new(_downstream_body);
+        let mut upstream_request = Request::new(BoxBody::default());
+        *upstream_request.method_mut() = downstream_request_head.method;
+
+        let path_and_query = match downstream_request_head.uri.path_and_query() {
+            Some(path_and_query) => path_and_query.as_str(),
+            None => downstream_request_head.uri.path(),
+        };
+        *upstream_request.uri_mut() = Uri::builder()
+            .path_and_query(path_and_query)
+            // @TODO only works for http 1.1, need to update lib to be able to separate request host from sni_host
+            .authority(peer.sni_host.as_str())
+            .scheme(if peer.tls { Scheme::HTTPS } else { Scheme::HTTP })
+            .build()
+            .map_err(|err| RouterError::Service(ForwardError::UriError(err).into_error_status()))?;
+
+        let upstream_request_connection_headers = get_connection_headers(&downstream_request_head.headers);
+
+        for (name, value) in downstream_request_head.headers.iter() {
+            // Hop-by-hop headers, headers named by `Connection` and the original `Host` stay
+            // on this hop.
+            let skip = HOP_HEADERS.contains(name)
+                || upstream_request_connection_headers.contains(name.as_str())
+                || name == header::HOST;
+
+            if !skip {
+                upstream_request.headers_mut().append(name.clone(), value.clone());
+            }
+        }
+
+        if contain_value(&downstream_request_head.headers, header::TE, "trailers") {
+            upstream_request
+                .headers_mut()
+                .insert(header::TE, "trailers".parse().expect("static header value is valid"));
+        }
+
+        if let Some(via) = peer.via.as_ref().filter(|via| via.add_in_request) {
+            append_via(upstream_request.headers_mut(), downstream_request_head.version, &via.name);
+        }
+
+        let current_host = downstream_request_head
+            .headers
+            .get(header::HOST)
+            .map(|v| v.to_str().unwrap_or_default().to_string())
+            .or_else(|| downstream_request_head.uri.host().map(|v| v.to_string()))
+            .unwrap_or_else(|| "localhost".to_string());
+
+        // @TODO only works for http 1.1, need to update lib to be able to separate request host from sni_host
+        // `request_host` is peer configuration: an invalid value is a server-side error, not a
+        // reason to panic.
+        match peer.request_host.parse() {
+            Ok(request_host) => {
+                upstream_request.headers_mut().insert(header::HOST, request_host);
+            }
+            Err(_) => return Err(status_error(StatusCode::INTERNAL_SERVER_ERROR)),
+        }
+
+        // @TODO Get them from forwarded headers if available
+        let addr = ctx.req().body().socket_addr().ip().to_string();
+        let scheme = downstream_request_head.uri.scheme().cloned().unwrap_or(Scheme::HTTP);
+
+        peer.forward_for
+            .apply(upstream_request.headers_mut(), addr.as_str(), current_host, scheme);
+
+        // @TODO Handle upgrade request
+
+        // @TODO Handle invalid certificates
+        let mut upstream_request = self.client.request(upstream_request);
+
+        // @TODO Need to set the correct address
+        // upstream_request = upstream_request.address(peer.address);
+
+        if let Some(timeout) = peer.timeout {
+            upstream_request = upstream_request.timeout(timeout);
+        }
+
+        // @TODO check bug with http 1.0
+
+        let upstream_response = match upstream_request.send().await {
+            Ok(res) => res,
+            Err(err) => {
+                // Diagnostics belong on stderr, not on the embedding application's stdout.
+                eprintln!("error: {:?}", err);
+                // @TODO handle better error
+                return Err(status_error(StatusCode::SERVICE_UNAVAILABLE));
+            }
+        };
+
+        let (parts, body) = upstream_response.into_parts();
+
+        // @TODO `into_owned` on the body is a bad thing: we lose the ability to keep a pool of
+        // connections for the client (each request will be a new connection). However, without it,
+        // since we don't have the client in the response, we can't have a pool of connections either.
+        let mut response = WebResponse::new(BoxBody::new(body.into_owned()));
+        *response.status_mut() = parts.status;
+
+        map_headers(peer, response.headers_mut(), parts.version, parts.status, &parts.headers);
+
+        Ok(response)
+    }
+}
+
+/// Wrap an HTTP status into the error type returned by `ProxyService`.
+fn status_error(status: StatusCode) -> RouterError<ErrorStatus> {
+    RouterError::Service(ErrorStatus::from(status))
+}
+
+/// Protocol version token used in a `Via` header, or `None` for an unknown HTTP version.
+fn via_protocol_version(version: Version) -> Option<&'static str> {
+    match version {
+        Version::HTTP_09 => Some("0.9"),
+        Version::HTTP_10 => Some("1.0"),
+        Version::HTTP_11 => Some("1.1"),
+        Version::HTTP_2 => Some("2.0"),
+        Version::HTTP_3 => Some("3.0"),
+        _ => None,
+    }
+}
+
+/// Append `Via: HTTP/<version> <name>` to `headers`.
+///
+/// Nothing is appended when the HTTP version is unknown, or when `name` (taken from the peer
+/// configuration) does not produce a valid header value; the request is never failed for it.
+fn append_via(headers: &mut HeaderMap, version: Version, name: &str) {
+    let Some(protocol) = via_protocol_version(version) else {
+        return;
+    };
+
+    if let Ok(value) = format!("HTTP/{} {}", protocol, name).parse() {
+        headers.append(header::VIA, value);
+    }
+}
+
+/// Copy the upstream response headers into `downstream_headers`.
+///
+/// Unless the response is `101 Switching Protocols`, hop-by-hop headers, headers named by the
+/// upstream `Connection` header and `Content-Length` are dropped. A `Via` header is appended
+/// when the peer asks for it on responses.
+fn map_headers(peer: Rc<HttpPeer>, downstream_headers: &mut HeaderMap, version: Version, status: StatusCode, upstream_headers: &HeaderMap) {
+    let response_connection_headers = get_connection_headers(upstream_headers);
+    let switching_protocols = status == StatusCode::SWITCHING_PROTOCOLS;
+
+    for (name, value) in upstream_headers {
+        let skip = !switching_protocols
+            && (HOP_HEADERS.contains(name)
+                || name == header::CONTENT_LENGTH
+                || response_connection_headers.contains(name.as_str()));
+
+        if !skip {
+            downstream_headers.append(name, value.clone());
+        }
+    }
+
+    if let Some(via) = peer.via.as_ref().filter(|via| via.add_in_response) {
+        append_via(downstream_headers, version, &via.name);
+    }
+}
+
+/// Lower-case names of the headers listed in the `Connection` header(s) of `header_map`.
+///
+/// Values that are not valid strings and tokens that are not valid header names are ignored.
+fn get_connection_headers(header_map: &HeaderMap) -> HashSet<String> {
+    header_map
+        .get_all(header::CONNECTION)
+        .iter()
+        .filter_map(|conn_value| conn_value.to_str().ok())
+        .flat_map(|conn_value| conn_value.split(','))
+        .filter_map(|token| HeaderName::from_str(token.trim()).ok())
+        // `HeaderName` is already normalized to lower case.
+        .map(|header_name| header_name.as_str().to_owned())
+        .collect()
+}
+
+/// Whether any `key` header of `map` lists `value` (ASCII case-insensitive).
+///
+/// Header values are comma-separated lists, so `TE: gzip, trailers` matches `trailers`.
+fn contain_value(map: &HeaderMap, key: impl AsHeaderName, value: &str) -> bool {
+    map.get_all(key)
+        .iter()
+        .filter_map(|val| val.to_str().ok())
+        .flat_map(|vs| vs.split(','))
+        .any(|token| token.trim().eq_ignore_ascii_case(value))
+}