From ce06bfb09261750b720a8a98ab8b419981ffd25c Mon Sep 17 00:00:00 2001 From: Zhang Jingqiang Date: Sat, 16 Nov 2024 23:30:45 +0800 Subject: [PATCH] g3proxy: send small http body along with header --- Cargo.lock | 4 +- .../direct_fixed/http_forward/writer.rs | 8 +- .../direct_float/http_forward/writer.rs | 8 +- .../peer/http/http_forward/writer.rs | 15 +- .../peer/https/http_forward/writer.rs | 15 +- .../peer/socks5/http_forward/writer.rs | 8 +- .../peer/socks5s/http_forward/writer.rs | 8 +- .../escape/proxy_http/http_forward/writer.rs | 15 +- .../escape/proxy_https/http_forward/writer.rs | 15 +- .../src/module/http_forward/connection/mod.rs | 8 +- .../module/http_forward/connection/writer.rs | 31 ++- g3proxy/src/module/http_forward/mod.rs | 4 +- .../src/serve/http_proxy/task/forward/task.rs | 152 +++++++++++--- .../serve/http_rproxy/task/forward/task.rs | 193 +++++++++++++----- 14 files changed, 370 insertions(+), 114 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 282f5af94..58fc87b39 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2667,9 +2667,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libloading" diff --git a/g3proxy/src/escape/direct_fixed/http_forward/writer.rs b/g3proxy/src/escape/direct_fixed/http_forward/writer.rs index c622c1e3e..dbd84d365 100644 --- a/g3proxy/src/escape/direct_fixed/http_forward/writer.rs +++ b/g3proxy/src/escape/direct_fixed/http_forward/writer.rs @@ -104,7 +104,11 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { - send_req_header_to_origin(&mut self.inner, req).await + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/escape/direct_float/http_forward/writer.rs b/g3proxy/src/escape/direct_float/http_forward/writer.rs index 965757772..088dc0a96 100644 --- a/g3proxy/src/escape/direct_float/http_forward/writer.rs +++ b/g3proxy/src/escape/direct_float/http_forward/writer.rs @@ -109,11 +109,15 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if self.bind.is_expired() { Err(io::Error::other("connection has expired")) } else { - send_req_header_to_origin(&mut self.inner, req).await + send_req_header_to_origin(&mut self.inner, req, body).await } } } diff --git a/g3proxy/src/escape/proxy_float/peer/http/http_forward/writer.rs b/g3proxy/src/escape/proxy_float/peer/http/http_forward/writer.rs index db9bf0be6..6fa41e4da 100644 --- a/g3proxy/src/escape/proxy_float/peer/http/http_forward/writer.rs +++ b/g3proxy/src/escape/proxy_float/peer/http/http_forward/writer.rs @@ -114,7 +114,11 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if let Some(expire) = &self.config.expire_instant { let now = Instant::now(); if expire.checked_duration_since(now).is_none() { @@ -124,6 +128,7 @@ where send_req_header_via_proxy( &mut self.inner, req, + body, &self.upstream, &self.config.append_http_headers, None, @@ -205,13 +210,17 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if let Some(expire) = &self.config.expire_instant { let now = Instant::now(); if expire.checked_duration_since(now).is_none() { return Err(io::Error::other("connection has expired")); } } - send_req_header_to_origin(&mut self.inner, req).await + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/escape/proxy_float/peer/https/http_forward/writer.rs b/g3proxy/src/escape/proxy_float/peer/https/http_forward/writer.rs index 4c00f8c7c..b8ee63a67 100644 --- a/g3proxy/src/escape/proxy_float/peer/https/http_forward/writer.rs +++ b/g3proxy/src/escape/proxy_float/peer/https/http_forward/writer.rs @@ -105,7 +105,11 @@ where self.inner.reset_stats(Arc::new(wrapper_stats)); } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if let Some(expire) = &self.config.expire_instant { let now = Instant::now(); if expire.checked_duration_since(now).is_none() { @@ -115,6 +119,7 @@ where send_req_header_via_proxy( &mut self.inner, req, + body, &self.upstream, &self.config.append_http_headers, None, @@ -184,13 +189,17 @@ where self.inner.reset_stats(Arc::new(wrapper_stats)); } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if let Some(expire) = &self.config.expire_instant { let now = Instant::now(); if expire.checked_duration_since(now).is_none() { return Err(io::Error::other("connection has expired")); } } - send_req_header_to_origin(&mut self.inner, req).await + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/escape/proxy_float/peer/socks5/http_forward/writer.rs b/g3proxy/src/escape/proxy_float/peer/socks5/http_forward/writer.rs index df73f5d16..9daf6348a 100644 --- a/g3proxy/src/escape/proxy_float/peer/socks5/http_forward/writer.rs +++ b/g3proxy/src/escape/proxy_float/peer/socks5/http_forward/writer.rs @@ -109,13 +109,17 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if let Some(expire) = &self.config.expire_instant { let now = Instant::now(); if expire.checked_duration_since(now).is_none() { return Err(io::Error::other("connection has expired")); } } - send_req_header_to_origin(&mut self.inner, req).await + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/escape/proxy_float/peer/socks5s/http_forward/writer.rs b/g3proxy/src/escape/proxy_float/peer/socks5s/http_forward/writer.rs index 14944f537..8a5214e64 100644 --- a/g3proxy/src/escape/proxy_float/peer/socks5s/http_forward/writer.rs +++ b/g3proxy/src/escape/proxy_float/peer/socks5s/http_forward/writer.rs @@ -97,13 +97,17 @@ where self.inner.reset_stats(Arc::new(wrapper_stats)); } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { if let Some(expire) = &self.config.expire_instant { let now = Instant::now(); if expire.checked_duration_since(now).is_none() { return Err(io::Error::other("connection has expired")); } } - send_req_header_to_origin(&mut self.inner, req).await + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/escape/proxy_http/http_forward/writer.rs b/g3proxy/src/escape/proxy_http/http_forward/writer.rs index 118a57fdd..77a7a6e31 100644 --- a/g3proxy/src/escape/proxy_http/http_forward/writer.rs +++ b/g3proxy/src/escape/proxy_http/http_forward/writer.rs @@ -116,11 +116,16 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { let userid = self.pass_userid.as_deref(); send_req_header_via_proxy( &mut self.inner, req, + body, &self.upstream, &self.config.append_http_headers, userid, @@ -202,7 +207,11 @@ where } } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { - send_req_header_to_origin(&mut self.inner, req).await + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/escape/proxy_https/http_forward/writer.rs b/g3proxy/src/escape/proxy_https/http_forward/writer.rs index 1b621bda6..9d6178a2c 100644 --- a/g3proxy/src/escape/proxy_https/http_forward/writer.rs +++ b/g3proxy/src/escape/proxy_https/http_forward/writer.rs @@ -107,11 +107,16 @@ where self.inner.reset_stats(Arc::new(wrapper_stats)); } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { let userid = self.pass_userid.as_deref(); send_req_header_via_proxy( &mut self.inner, req, + body, &self.upstream, &self.config.append_http_headers, userid, @@ -181,7 +186,11 @@ where self.inner.reset_stats(Arc::new(wrapper_stats)); } - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { - send_req_header_to_origin(&mut self.inner, req).await + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()> { + send_req_header_to_origin(&mut self.inner, req, body).await } } diff --git a/g3proxy/src/module/http_forward/connection/mod.rs b/g3proxy/src/module/http_forward/connection/mod.rs index b12f88c72..9e2f3cd71 100644 --- a/g3proxy/src/module/http_forward/connection/mod.rs +++ b/g3proxy/src/module/http_forward/connection/mod.rs @@ -51,7 +51,11 @@ pub(crate) trait HttpForwardWrite: AsyncWrite { user_stats: Vec>, ); - async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()>; + async fn send_request_header( + &mut self, + req: &HttpProxyClientRequest, + body: Option<&[u8]>, + ) -> io::Result<()>; } #[async_trait] @@ -107,6 +111,6 @@ impl AsyncWrite for HttpForwardWriterForAdaptation<'_> { impl HttpRequestUpstreamWriter for HttpForwardWriterForAdaptation<'_> { async fn send_request_header(&mut self, req: &HttpProxyClientRequest) -> io::Result<()> { - self.inner.send_request_header(req).await + self.inner.send_request_header(req, None).await } } diff --git a/g3proxy/src/module/http_forward/connection/writer.rs b/g3proxy/src/module/http_forward/connection/writer.rs index bc9b55f7f..379fb112b 100644 --- a/g3proxy/src/module/http_forward/connection/writer.rs +++ b/g3proxy/src/module/http_forward/connection/writer.rs @@ -14,19 +14,21 @@ * limitations under the License. */ -use std::io; +use std::io::{self, IoSlice}; + +use g3_io_ext::LimitedWriteExt; +use g3_types::net::UpstreamAddr; use bytes::BufMut; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use g3_types::net::UpstreamAddr; - use super::HttpProxyClientRequest; use crate::module::http_header; pub(crate) async fn send_req_header_via_proxy( writer: &mut W, req: &HttpProxyClientRequest, + body: Option<&[u8]>, upstream: &UpstreamAddr, append_header_lines: &[String], pass_userid: Option<&str>, @@ -45,16 +47,35 @@ where } buf.put_slice(b"\r\n"); - writer.write_all(buf.as_ref()).await + send_request_header(writer, buf.as_slice(), body).await } pub(crate) async fn send_req_header_to_origin( writer: &mut W, req: &HttpProxyClientRequest, + body: Option<&[u8]>, ) -> io::Result<()> where W: AsyncWrite + Unpin, { let buf = req.serialize_for_origin(); - writer.write_all(buf.as_ref()).await + send_request_header(writer, buf.as_slice(), body).await +} + +async fn send_request_header( + writer: &mut W, + header: &[u8], + body: Option<&[u8]>, +) -> io::Result<()> +where + W: AsyncWrite + Unpin, +{ + if let Some(body) = body { + writer + .write_all_vectored([IoSlice::new(header), IoSlice::new(body)]) + .await?; + Ok(()) + } else { + writer.write_all(header).await + } } diff --git a/g3proxy/src/module/http_forward/mod.rs b/g3proxy/src/module/http_forward/mod.rs index c0bb8113c..f5a52dfde 100644 --- a/g3proxy/src/module/http_forward/mod.rs +++ b/g3proxy/src/module/http_forward/mod.rs @@ -22,8 +22,8 @@ mod task; pub(crate) use connection::{ send_req_header_to_origin, send_req_header_via_proxy, BoxHttpForwardConnection, - BoxHttpForwardReader, BoxHttpForwardWriter, HttpConnectionEofPoller, HttpForwardRead, - HttpForwardWrite, HttpForwardWriterForAdaptation, + BoxHttpForwardReader, HttpConnectionEofPoller, HttpForwardRead, HttpForwardWrite, + HttpForwardWriterForAdaptation, }; pub(crate) use context::{ BoxHttpForwardContext, DirectHttpForwardContext, FailoverHttpForwardContext, diff --git a/g3proxy/src/serve/http_proxy/task/forward/task.rs b/g3proxy/src/serve/http_proxy/task/forward/task.rs index 42a4c1ea0..adf19a57a 100644 --- a/g3proxy/src/serve/http_proxy/task/forward/task.rs +++ b/g3proxy/src/serve/http_proxy/task/forward/task.rs @@ -19,8 +19,9 @@ use std::time::Duration; use anyhow::anyhow; use futures_util::FutureExt; +use http::header; use log::debug; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::Instant; use g3_http::client::HttpForwardRemoteResponse; @@ -48,8 +49,8 @@ use crate::audit::AuditContext; use crate::config::server::ServerConfig; use crate::log::task::http_forward::TaskLogForHttpForward; use crate::module::http_forward::{ - BoxHttpForwardConnection, BoxHttpForwardContext, BoxHttpForwardReader, BoxHttpForwardWriter, - HttpForwardTaskNotes, HttpProxyClientResponse, + BoxHttpForwardConnection, BoxHttpForwardContext, BoxHttpForwardReader, HttpForwardTaskNotes, + HttpProxyClientResponse, }; use crate::module::http_header; use crate::module::tcp_connect::{ @@ -206,7 +207,7 @@ impl<'a> HttpProxyForwardTask<'a> { let http_user_agent = self .req .end_to_end_headers - .get(http::header::USER_AGENT) + .get(header::USER_AGENT) .map(|v| v.to_str()); TaskLogForHttpForward { upstream: &self.upstream, @@ -1008,7 +1009,14 @@ impl<'a> HttpProxyForwardTask<'a> { let ups_w = &mut ups_c.0; let ups_r = &mut ups_c.1; - self.send_request_header(ups_w).await?; + ups_w + .send_request_header(self.req, None) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; self.http_notes.mark_req_send_hdr(); self.http_notes.mark_req_no_body(); self.http_notes.retry_new_connection = false; @@ -1045,6 +1053,62 @@ impl<'a> HttpProxyForwardTask<'a> { } } + async fn run_with_small_body( + &mut self, + body: Vec, + clt_w: &mut W, + mut ups_c: BoxHttpForwardConnection, + ) -> ServerTaskResult> + where + W: AsyncWrite + Send + Unpin, + { + let ups_w = &mut ups_c.0; + let ups_r = &mut ups_c.1; + + self.http_notes.retry_new_connection = false; + ups_w + .send_request_header(self.req, Some(body.as_slice())) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + self.http_notes.mark_req_send_hdr(); + self.http_notes.mark_req_send_all(); + + let mut rsp_header = match tokio::time::timeout( + self.rsp_hdr_recv_timeout(), + self.recv_response_header(ups_r), + ) + .await + { + Ok(Ok(rsp_header)) => rsp_header, + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(ServerTaskError::UpstreamAppTimeout( + "timeout to receive response header", + )) + } + }; + self.http_notes.mark_rsp_recv_hdr(); + + self.send_response(clt_w, ups_r, &mut rsp_header, false, None) + .await?; + + self.task_notes.stage = ServerTaskStage::Finished; + if self.should_close { + if self.is_https { + // make sure we correctly shutdown tls connection, or the ticket won't be reused + // FIXME use async drop at escaper side when supported + let _ = ups_w.shutdown().await; + } + Ok(None) + } else { + Ok(Some(ups_c)) + } + } + async fn run_with_body( &mut self, clt_r: &mut R, @@ -1058,22 +1122,68 @@ impl<'a> HttpProxyForwardTask<'a> { let ups_w = &mut ups_c.0; let ups_r = &mut ups_c.1; - self.send_request_header(ups_w).await?; - self.http_notes.mark_req_send_hdr(); - self.http_notes.retry_new_connection = false; - let mut clt_body_reader = HttpBodyReader::new( clt_r, self.req.body_type().unwrap(), self.ctx.server_config.body_line_max_len, ); - let mut rsp_header: Option = None; - let mut clt_to_ups = LimitedCopy::new( - &mut clt_body_reader, - ups_w, - &self.ctx.server_config.tcp_copy, - ); + let mut clt_to_ups = if self.req.end_to_end_headers.contains_key(header::EXPECT) { + ups_w + .send_request_header(self.req, None) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + self.http_notes.mark_req_send_hdr(); + self.http_notes.retry_new_connection = false; + + LimitedCopy::new( + &mut clt_body_reader, + ups_w, + &self.ctx.server_config.tcp_copy, + ) + } else { + let mut pre_read_buf = vec![0u8; self.ctx.server_config.tcp_copy.buffer_size()]; + match clt_body_reader + .read_to_end(&mut pre_read_buf) + .now_or_never() + { + Some(Ok(0)) => return Err(ServerTaskError::ClosedByClient), + Some(Ok(n)) => { + pre_read_buf.truncate(n); + if clt_body_reader.finished() { + return self.run_with_small_body(pre_read_buf, clt_w, ups_c).await; + } + } + Some(Err(e)) => return Err(ServerTaskError::ClientTcpReadFailed(e)), + None => { + pre_read_buf.truncate(0); + } + } + + ups_w + .send_request_header(self.req, None) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + self.http_notes.mark_req_send_hdr(); + self.http_notes.retry_new_connection = false; + + LimitedCopy::with_data( + &mut clt_body_reader, + ups_w, + &self.ctx.server_config.tcp_copy, + pre_read_buf, + ) + }; + + let mut rsp_header: Option = None; let idle_duration = self.ctx.server_config.task_idle_check_duration; let mut idle_interval = @@ -1229,18 +1339,6 @@ impl<'a> HttpProxyForwardTask<'a> { } } - async fn send_request_header(&self, ups_w: &mut BoxHttpForwardWriter) -> ServerTaskResult<()> { - ups_w - .send_request_header(self.req) - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - ups_w - .flush() - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - Ok(()) - } - async fn recv_response_header( &mut self, ups_r: &mut BoxHttpForwardReader, diff --git a/g3proxy/src/serve/http_rproxy/task/forward/task.rs b/g3proxy/src/serve/http_rproxy/task/forward/task.rs index 0370e0861..eeb655de9 100644 --- a/g3proxy/src/serve/http_rproxy/task/forward/task.rs +++ b/g3proxy/src/serve/http_rproxy/task/forward/task.rs @@ -17,10 +17,19 @@ use std::sync::Arc; use futures_util::FutureExt; +use http::header; use log::debug; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::time::Instant; +use g3_http::client::HttpForwardRemoteResponse; +use g3_http::server::HttpProxyClientRequest; +use g3_http::{HttpBodyReader, HttpBodyType}; +use g3_io_ext::{ + GlobalLimitGroup, LimitedBufReadExt, LimitedCopy, LimitedCopyError, LimitedWriteExt, +}; +use g3_types::acl::AclAction; + use super::protocol::{HttpClientReader, HttpClientWriter, HttpRProxyRequest}; use super::{ CommonTaskContext, HttpForwardTaskCltWrapperStats, HttpForwardTaskStats, @@ -29,8 +38,8 @@ use super::{ use crate::config::server::ServerConfig; use crate::log::task::http_forward::TaskLogForHttpForward; use crate::module::http_forward::{ - BoxHttpForwardConnection, BoxHttpForwardContext, BoxHttpForwardReader, BoxHttpForwardWriter, - HttpForwardTaskNotes, HttpProxyClientResponse, + BoxHttpForwardConnection, BoxHttpForwardContext, BoxHttpForwardReader, HttpForwardTaskNotes, + HttpProxyClientResponse, }; use crate::module::tcp_connect::{ TcpConnectError, TcpConnectTaskConf, TcpConnectTaskNotes, TlsConnectTaskConf, @@ -40,13 +49,6 @@ use crate::serve::{ ServerStats, ServerTaskError, ServerTaskForbiddenError, ServerTaskNotes, ServerTaskResult, ServerTaskStage, }; -use g3_http::client::HttpForwardRemoteResponse; -use g3_http::server::HttpProxyClientRequest; -use g3_http::{HttpBodyReader, HttpBodyType}; -use g3_io_ext::{ - GlobalLimitGroup, LimitedBufReadExt, LimitedCopy, LimitedCopyError, LimitedWriteExt, -}; -use g3_types::acl::AclAction; pub(crate) struct HttpRProxyForwardTask<'a> { ctx: Arc, @@ -186,7 +188,7 @@ impl<'a> HttpRProxyForwardTask<'a> { let http_user_agent = self .req .end_to_end_headers - .get(http::header::USER_AGENT) + .get(header::USER_AGENT) .map(|v| v.to_str()); TaskLogForHttpForward { upstream: self.host.config.upstream(), @@ -609,17 +611,24 @@ impl<'a> HttpRProxyForwardTask<'a> { CDR: AsyncRead + Unpin, CDW: AsyncWrite + Unpin, { + if reused_connection { + if let Some(r) = ups_c.1.fill_wait_eof().now_or_never() { + return match r { + Ok(_) => Err(ServerTaskError::ClosedByUpstream), + Err(e) => Err(ServerTaskError::UpstreamReadFailed(e)), + }; + } + } ups_c .0 .prepare_new(&self.task_notes, self.host.config.upstream()); if self.req.body_type().is_none() { self.mark_relaying(); - self.run_without_body(clt_w, ups_c, reused_connection).await + self.run_without_body(clt_w, ups_c).await } else if let Some(br) = clt_r { self.mark_relaying(); - self.run_with_body(br, clt_w, ups_c, reused_connection) - .await + self.run_with_body(br, clt_w, ups_c).await } else { // there should be a body reader Err(ServerTaskError::InternalServerError( @@ -632,7 +641,6 @@ impl<'a> HttpRProxyForwardTask<'a> { &'f mut self, clt_w: &'f mut W, mut ups_c: BoxHttpForwardConnection, - reused_connection: bool, ) -> ServerTaskResult> where W: AsyncWrite + Unpin, @@ -640,16 +648,14 @@ impl<'a> HttpRProxyForwardTask<'a> { let ups_w = &mut ups_c.0; let ups_r = &mut ups_c.1; - if reused_connection { - if let Some(r) = ups_r.fill_wait_eof().now_or_never() { - return match r { - Ok(_) => Err(ServerTaskError::ClosedByUpstream), - Err(e) => Err(ServerTaskError::UpstreamReadFailed(e)), - }; - } - } - - self.send_request_header(ups_w).await?; + ups_w + .send_request_header(self.req, None) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; self.http_notes.mark_req_send_hdr(); self.http_notes.mark_req_no_body(); self.retry_new_connection = false; @@ -681,45 +687,132 @@ impl<'a> HttpRProxyForwardTask<'a> { } } - async fn run_with_body<'f, R, W>( + async fn run_with_small_body<'f, W>( &'f mut self, - clt_r: &'f mut R, + body: Vec, clt_w: &'f mut W, mut ups_c: BoxHttpForwardConnection, - reused_connection: bool, ) -> ServerTaskResult> where - R: AsyncBufRead + Unpin, W: AsyncWrite + Unpin, { let ups_w = &mut ups_c.0; let ups_r = &mut ups_c.1; - if reused_connection { - if let Some(r) = ups_r.fill_wait_eof().now_or_never() { - return match r { - Ok(_) => Err(ServerTaskError::ClosedByUpstream), - Err(e) => Err(ServerTaskError::UpstreamReadFailed(e)), - }; + self.retry_new_connection = false; + ups_w + .send_request_header(self.req, Some(body.as_slice())) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + self.http_notes.mark_req_send_hdr(); + self.http_notes.mark_req_send_all(); + + let mut rsp_header = match tokio::time::timeout( + self.ctx.server_config.timeout.recv_rsp_header, + self.recv_response_header(ups_r), + ) + .await + { + Ok(Ok(rsp_header)) => rsp_header, + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(ServerTaskError::UpstreamAppTimeout( + "timeout to receive response header", + )) } + }; + self.http_notes.mark_rsp_recv_hdr(); + + self.update_response_header(&mut rsp_header); + self.send_response(clt_w, ups_r, &rsp_header).await?; + + self.task_notes.stage = ServerTaskStage::Finished; + if self.should_close { + Ok(None) + } else { + Ok(Some(ups_c)) } + } - self.send_request_header(ups_w).await?; - self.http_notes.mark_req_send_hdr(); - self.retry_new_connection = false; + async fn run_with_body<'f, R, W>( + &'f mut self, + clt_r: &'f mut R, + clt_w: &'f mut W, + mut ups_c: BoxHttpForwardConnection, + ) -> ServerTaskResult> + where + R: AsyncBufRead + Unpin, + W: AsyncWrite + Unpin, + { + let ups_w = &mut ups_c.0; + let ups_r = &mut ups_c.1; let mut clt_body_reader = HttpBodyReader::new( clt_r, self.req.body_type().unwrap(), self.ctx.server_config.body_line_max_len, ); - let mut rsp_header: Option = None; - let mut clt_to_ups = LimitedCopy::new( - &mut clt_body_reader, - ups_w, - &self.ctx.server_config.tcp_copy, - ); + let mut clt_to_ups = if self.req.end_to_end_headers.contains_key(header::EXPECT) { + ups_w + .send_request_header(self.req, None) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + self.http_notes.mark_req_send_hdr(); + self.retry_new_connection = false; + + LimitedCopy::new( + &mut clt_body_reader, + ups_w, + &self.ctx.server_config.tcp_copy, + ) + } else { + let mut pre_read_buf = vec![0u8; self.ctx.server_config.tcp_copy.buffer_size()]; + match clt_body_reader + .read_to_end(&mut pre_read_buf) + .now_or_never() + { + Some(Ok(0)) => return Err(ServerTaskError::ClosedByClient), + Some(Ok(n)) => { + pre_read_buf.truncate(n); + if clt_body_reader.finished() { + return self.run_with_small_body(pre_read_buf, clt_w, ups_c).await; + } + } + Some(Err(e)) => return Err(ServerTaskError::ClientTcpReadFailed(e)), + None => { + pre_read_buf.truncate(0); + } + } + + ups_w + .send_request_header(self.req, None) + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + ups_w + .flush() + .await + .map_err(ServerTaskError::UpstreamWriteFailed)?; + self.http_notes.mark_req_send_hdr(); + self.retry_new_connection = false; + + LimitedCopy::with_data( + &mut clt_body_reader, + ups_w, + &self.ctx.server_config.tcp_copy, + pre_read_buf.to_vec(), + ) + }; + + let mut rsp_header: Option = None; let idle_duration = self.ctx.server_config.task_idle_check_duration; let mut idle_interval = @@ -872,18 +965,6 @@ impl<'a> HttpRProxyForwardTask<'a> { } } - async fn send_request_header(&self, ups_w: &mut BoxHttpForwardWriter) -> ServerTaskResult<()> { - ups_w - .send_request_header(self.req) - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - ups_w - .flush() - .await - .map_err(ServerTaskError::UpstreamWriteFailed)?; - Ok(()) - } - async fn recv_response_header( &mut self, ups_r: &mut BoxHttpForwardReader,