diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 5b9fd8aea..b65951c21 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -254,6 +254,11 @@ impl Headers { &mut self.header_block.pseudo } + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.header_block.pseudo.is_informational() + } + pub fn fields(&self) -> &HeaderMap { &self.header_block.fields } @@ -599,6 +604,12 @@ impl Pseudo { pub fn set_authority(&mut self, authority: BytesStr) { self.authority = Some(authority); } + + /// Whether it has status 1xx + pub(crate) fn is_informational(&self) -> bool { + self.status + .map_or(false, |status| status.is_informational()) + } } // ===== impl EncodingHeaderBlock ===== diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 682200d45..7a6ff8ad2 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -161,7 +161,7 @@ impl Recv { counts: &mut Counts, ) -> Result<(), RecvHeaderBlockError>> { tracing::trace!("opening stream; init_window={}", self.init_window_sz); - let is_initial = stream.state.recv_open(frame.is_end_stream())?; + let is_initial = stream.state.recv_open(&frame)?; if is_initial { // TODO: be smarter about this logic @@ -226,15 +226,17 @@ impl Recv { let stream_id = frame.stream_id(); let (pseudo, fields) = frame.into_parts(); - let message = counts - .peer() - .convert_poll_message(pseudo, fields, stream_id)?; - - // Push the frame onto the stream's recv buffer - stream - .pending_recv - .push_back(&mut self.buffer, Event::Headers(message)); - stream.notify_recv(); + if !pseudo.is_informational() { + let message = counts + .peer() + .convert_poll_message(pseudo, fields, stream_id)?; + + // Push the frame onto the stream's recv buffer + stream + .pending_recv + .push_back(&mut self.buffer, Event::Headers(message)); + stream.notify_recv(); + } // Only servers can receive a headers frame that initiates the stream. // This is verified in `Streams` before calling this function. diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 45ec82f90..08d4dba00 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -2,7 +2,7 @@ use std::io; use crate::codec::UserError::*; use crate::codec::{RecvError, UserError}; -use crate::frame::Reason; +use crate::frame::{self, Reason}; use crate::proto::{self, PollReset}; use self::Inner::*; @@ -132,10 +132,13 @@ impl State { /// Opens the receive-half of the stream when a HEADERS frame is received. /// + /// is_informational: whether received a 1xx status code + /// /// Returns true if this transitions the state to Open. - pub fn recv_open(&mut self, eos: bool) -> Result { + pub fn recv_open(&mut self, frame: &frame::Headers) -> Result { let remote = Streaming; let mut initial = false; + let eos = frame.is_end_stream(); self.inner = match self.inner { Idle => { @@ -172,6 +175,9 @@ impl State { HalfClosedLocal(AwaitingHeaders) => { if eos { Closed(Cause::EndStream) + } else if frame.is_informational() { + tracing::trace!("skipping 1xx response headers"); + HalfClosedLocal(AwaitingHeaders) } else { HalfClosedLocal(remote) } diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 35b4beacf..23b5dbb50 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1215,6 +1215,48 @@ async fn allow_empty_data_for_head() { join(srv, h2).await; } +#[tokio::test] +async fn early_hints() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(103)).await; + srv.send_frame(frames::headers(1).response(200).field("content-length", 2)) + .await; + srv.send_frame(frames::data(1, "ok").eos()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .unwrap(); + tokio::spawn(async { + h2.await.expect("connection failed"); + }); + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + let (ha, mut body) = response.await.unwrap().into_parts(); + eprintln!("{:?}", ha); + assert_eq!(body.data().await.unwrap().unwrap(), "ok"); + }; + + join(srv, h2).await; +} + const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];