From 3d0dd274589cf7f12af3c9c04a2f8fcb8594503e Mon Sep 17 00:00:00 2001 From: Ilion Beyst Date: Sat, 12 Aug 2023 16:47:42 +0100 Subject: [PATCH] Client-side support for Expect: 100-continue This patch modifies hyper client to behave correctly when the `Expect: 100-continue` header is set on a request. The request body will now only be sent once a 100 Continue response is receved from the server. --- src/proto/h1/conn.rs | 18 +++++++++++++++ src/proto/h1/dispatch.rs | 9 ++++++++ src/proto/h1/io.rs | 4 ++++ src/proto/h1/mod.rs | 2 ++ src/proto/h1/role.rs | 36 +++++++++++++++++++++++++++++ tests/client.rs | 49 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 118 insertions(+) diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 3807adb46a..800f1ebecc 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -75,6 +75,8 @@ where // We assume a modern world where the remote speaks HTTP/1.1. // If they tell us otherwise, we'll downgrade in `read_head`. version: Version::HTTP_11, + #[cfg(feature = "client")] + awaiting_100_continue: false, }, _marker: PhantomData, } @@ -103,6 +105,18 @@ where self.io.set_read_buf_exact_size(sz); } + #[cfg(feature = "client")] + pub(crate) fn set_awaiting_100_continue(&mut self, awaiting: bool) { + self.state.awaiting_100_continue = awaiting; + } + + pub(crate) fn is_awaiting_100_continue(&self) -> bool { + #[cfg(feature = "client")] + return self.state.awaiting_100_continue; + #[cfg(not(feature = "client"))] + return false; + } + pub(crate) fn set_write_strategy_flatten(&mut self) { self.io.set_write_strategy_flatten(); } @@ -219,6 +233,8 @@ where h09_responses: self.state.h09_responses, #[cfg(feature = "ffi")] on_informational: &mut self.state.on_informational, + #[cfg(feature = "client")] + awaiting_100_continue: &mut self.state.awaiting_100_continue, } )) { Ok(msg) => msg, @@ -842,6 +858,8 @@ struct State { upgrade: Option, /// Either HTTP/1.0 or 1.1 connection version: Version, + #[cfg(feature = "client")] + awaiting_100_continue: bool, } #[derive(Debug)] diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index c29c15dcae..6f779ff3f2 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -310,6 +310,11 @@ where { if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) { let (head, body) = msg.map_err(crate::Error::new_user_service)?; + let expect_100_continue = T::is_client() + && head.headers.get(http::header::EXPECT).map_or(false, |h| { + h.as_bytes().eq_ignore_ascii_case(b"100-continue") + }); + let body_type = if body.is_end_stream() { self.body_rx.set(None); @@ -324,10 +329,14 @@ where btype }; self.conn.write_head(head, body_type); + #[cfg(feature = "client")] + self.conn.set_awaiting_100_continue(expect_100_continue); } else { self.close(); return Poll::Ready(Ok(())); } + } else if self.conn.is_awaiting_100_continue() { + ready!(self.poll_read(cx))?; } else if !self.conn.can_buffer_body() { ready!(self.poll_flush(cx))?; } else { diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index d0c0cba3bf..887346ec68 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -197,6 +197,8 @@ where h09_responses: parse_ctx.h09_responses, #[cfg(feature = "ffi")] on_informational: parse_ctx.on_informational, + #[cfg(feature = "client")] + awaiting_100_continue: parse_ctx.awaiting_100_continue, }, )? { Some(msg) => { @@ -734,6 +736,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; assert!(buffered .parse::(cx, parse_ctx) diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 86561c3764..f9e3e1dba8 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -92,6 +92,8 @@ pub(crate) struct ParseContext<'a> { h09_responses: bool, #[cfg(feature = "ffi")] on_informational: &'a mut Option, + #[cfg(feature = "client")] + awaiting_100_continue: &'a mut bool, } /// Passed to Http1Transaction::encode diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index c30a4948f9..a5381ad08a 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -1073,6 +1073,10 @@ impl Http1Transaction for Client { })); } + if head.subject == StatusCode::CONTINUE { + *ctx.awaiting_100_continue = false; + } + #[cfg(feature = "ffi")] if head.subject.is_informational() { if let Some(callback) = ctx.on_informational { @@ -1574,6 +1578,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .unwrap() @@ -1605,6 +1611,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); assert_eq!(raw.len(), 0); @@ -1631,6 +1639,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; Server::parse(&mut raw, ctx).unwrap_err(); } @@ -1655,6 +1665,8 @@ mod tests { h09_responses: true, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); assert_eq!(raw, H09_RESPONSE); @@ -1681,6 +1693,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; Client::parse(&mut raw, ctx).unwrap_err(); assert_eq!(raw, H09_RESPONSE); @@ -1711,6 +1725,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); assert_eq!(raw.len(), 0); @@ -1738,6 +1754,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; Client::parse(&mut raw, ctx).unwrap_err(); } @@ -1760,6 +1778,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }; let parsed_message = Server::parse(&mut raw, ctx).unwrap().unwrap(); let orig_headers = parsed_message @@ -1803,6 +1823,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .expect("parse ok") @@ -1827,6 +1849,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .expect_err(comment) @@ -2060,6 +2084,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, } ) .expect("parse ok") @@ -2084,6 +2110,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .expect("parse ok") @@ -2108,6 +2136,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .expect_err("parse should err") @@ -2627,6 +2657,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .expect("parse ok") @@ -2715,6 +2747,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .unwrap() @@ -2759,6 +2793,8 @@ mod tests { h09_responses: false, #[cfg(feature = "ffi")] on_informational: &mut None, + #[cfg(feature = "client")] + awaiting_100_continue: &mut false, }, ) .unwrap() diff --git a/tests/client.rs b/tests/client.rs index 92955af360..f747411664 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -2318,6 +2318,55 @@ mod conn { assert!(error.is_user()); } + #[tokio::test] + async fn test_await_100_continue() { + let (listener, addr) = setup_tk_test_server().await; + + let server = async move { + let mut sock = listener.accept().await.unwrap().0; + let mut buf = [0; 4096]; + let n = sock.read(&mut buf).await.expect("read 1"); + + // we should have received just the headers + let expected = "PUT /a HTTP/1.1\r\nexpect: 100-continue\r\ncontent-length: 8\r\n\r\n"; + assert_eq!(s(&buf[..n]), expected); + + sock.write_all(b"HTTP/1.1 100 Continue\r\n\r\n") + .await + .unwrap(); + + let n = sock.read(&mut buf).await.expect("read 2"); + + // the next read should hold the body + let expected = "baguette"; + assert_eq!(s(&buf[..n]), expected); + + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") + .await + .unwrap(); + }; + + let client = async move { + let tcp = tcp_connect(&addr).await.expect("connect"); + let (mut client, conn) = conn::http1::handshake(tcp).await.expect("handshake"); + + tokio::task::spawn(async move { + conn.await.expect("http conn"); + }); + + let req = Request::builder() + .method(Method::PUT) + .header(http::header::EXPECT, "100-continue") + .uri("/a") + .body(String::from("baguette")) + .unwrap(); + let res = client.send_request(req).await.expect("send_request"); + assert_eq!(res.status(), hyper::StatusCode::OK); + }; + + future::join(server, client).await; + } + async fn drain_til_eof(mut sock: T) -> io::Result<()> { let mut buf = [0u8; 1024]; loop {