diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs index a214892b..85c4e2b4 100644 --- a/src/proto/streams/counts.rs +++ b/src/proto/streams/counts.rs @@ -229,7 +229,7 @@ impl Counts { } } - if stream.is_counted { + if !stream.state.is_scheduled_reset() && stream.is_counted { tracing::trace!("dec_num_streams; stream={:?}", stream.id); // Decrement the number of active streams. self.dec_num_streams(&mut stream); diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 74c34542..b67f39fe 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -685,8 +685,11 @@ impl Prioritize { } pub fn clear_pending_send(&mut self, store: &mut Store, counts: &mut Counts) { - while let Some(stream) = self.pending_send.pop(store) { + while let Some(mut stream) = self.pending_send.pop(store) { let is_pending_reset = stream.is_pending_reset_expiration(); + if let Some(reason) = stream.state.get_scheduled_reset() { + stream.set_reset(reason, Initiator::Library); + } counts.transition_after(stream, is_pending_reset); } } diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index facd367e..8ec2cf31 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1218,3 +1218,140 @@ async fn reset_new_stream_before_send() { join(srv, client).await; } + +#[tokio::test] +async fn explicit_reset_with_max_concurrent_stream() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let mock = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().max_concurrent_streams(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + + srv.recv_frame(frames::reset(1).cancel()).await; + + srv.recv_frame( + frames::headers(3) + .request("POST", "https://www.example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, mut stream) = client.send_request(request, false).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + stream.send_reset(Reason::CANCEL); + }; + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, _) = client.send_request(request, true).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + }; + + h2.await.unwrap(); + }; + + join(mock, h2).await; +} + +#[tokio::test] +async fn implicit_cancel_with_max_concurrent_stream() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let mock = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().max_concurrent_streams(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + + srv.recv_frame(frames::reset(1).cancel()).await; + + srv.recv_frame( + frames::headers(3) + .request("POST", "https://www.example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, stream) = client.send_request(request, false).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + // This implicitly resets the stream with CANCEL. + drop(stream); + }; + + { + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (resp, _) = client.send_request(request, true).unwrap(); + + { + let resp = h2.drive(resp).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + }; + + h2.await.unwrap(); + }; + + join(mock, h2).await; +}