Skip to content

Commit

Permalink
Handle implicit resets at the right time
Browse files Browse the repository at this point in the history
A stream whose ref count reaches zero while open should
not immediately decrease the number of active streams,
otherwise MAX_CONCURRENT_STREAMS isn't respected anymore.
  • Loading branch information
nox committed Jan 20, 2025
1 parent d7c56f4 commit 38ebb1b
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 24 deletions.
4 changes: 3 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ where
///
/// [module]: index.html
pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), crate::Error>> {
ready!(self.inner.poll_pending_open(cx, self.pending.as_ref()))?;
ready!(self
.inner
.poll_pending_open(cx, self.pending.as_ref()))?;
self.pending = None;
Poll::Ready(Ok(()))
}
Expand Down
1 change: 1 addition & 0 deletions src/proto/streams/prioritize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ impl Prioritize {
}),
None => {
if let Some(reason) = stream.state.get_scheduled_reset() {
stream.state.did_schedule_reset();
stream.set_reset(reason, Initiator::Library);

let frame = frame::Reset::new(stream.id, reason);
Expand Down
2 changes: 1 addition & 1 deletion src/proto/streams/recv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ impl Recv {

tracing::trace!("enqueue_reset_expiration; {:?}", stream.id);

if counts.can_inc_num_reset_streams() {
if dbg!(counts.can_inc_num_reset_streams()) {
counts.inc_num_reset_streams();
self.pending_reset_expired.push(stream);
}
Expand Down
49 changes: 27 additions & 22 deletions src/proto/streams/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,20 @@ enum Inner {
// TODO: these states shouldn't count against concurrency limits:
ReservedLocal,
ReservedRemote,
Open { local: Peer, remote: Peer },
Open {
local: Peer,
remote: Peer,
},
HalfClosedLocal(Peer), // TODO: explicitly name this value
HalfClosedRemote(Peer),
/// This indicates to the connection that a reset frame must be sent out
/// once the send queue has been flushed.
///
/// Examples of when this could happen:
/// - User drops all references to a stream, so we want to CANCEL the it.
/// - Header block size was too large, so we want to REFUSE, possibly
/// after sending a 431 response frame.
ScheduledReset(Reason),
Closed(Cause),
}

Expand All @@ -75,15 +86,6 @@ enum Peer {
enum Cause {
EndStream,
Error(Error),

/// This indicates to the connection that a reset frame must be sent out
/// once the send queue has been flushed.
///
/// Examples of when this could happen:
/// - User drops all references to a stream, so we want to CANCEL the it.
/// - Header block size was too large, so we want to REFUSE, possibly
/// after sending a 431 response frame.
ScheduledLibraryReset(Reason),
}

impl State {
Expand Down Expand Up @@ -339,24 +341,29 @@ impl State {
/// Set the stream state to a scheduled reset.
pub fn set_scheduled_reset(&mut self, reason: Reason) {
debug_assert!(!self.is_closed());
self.inner = Closed(Cause::ScheduledLibraryReset(reason));
self.inner = ScheduledReset(reason)
}

pub fn get_scheduled_reset(&self) -> Option<Reason> {
match self.inner {
Closed(Cause::ScheduledLibraryReset(reason)) => Some(reason),
ScheduledReset(reason) => Some(reason),
_ => None,
}
}

pub fn is_scheduled_reset(&self) -> bool {
matches!(self.inner, Closed(Cause::ScheduledLibraryReset(..)))
matches!(self.inner, ScheduledReset(_))
}

pub fn did_schedule_reset(&mut self) {
debug_assert!(self.is_scheduled_reset());
self.inner = Closed(Cause::EndStream);
}

pub fn is_local_error(&self) -> bool {
match self.inner {
ScheduledReset(_) => true,
Closed(Cause::Error(ref e)) => e.is_local(),
Closed(Cause::ScheduledLibraryReset(..)) => true,
_ => false,
}
}
Expand Down Expand Up @@ -416,14 +423,14 @@ impl State {
pub fn is_recv_closed(&self) -> bool {
matches!(
self.inner,
Closed(..) | HalfClosedRemote(..) | ReservedLocal
ScheduledReset(_) | Closed(..) | HalfClosedRemote(..) | ReservedLocal
)
}

pub fn is_send_closed(&self) -> bool {
matches!(
self.inner,
Closed(..) | HalfClosedLocal(..) | ReservedRemote
ScheduledReset(_) | Closed(..) | HalfClosedLocal(..) | ReservedRemote
)
}

Expand All @@ -434,10 +441,8 @@ impl State {
pub fn ensure_recv_open(&self) -> Result<bool, proto::Error> {
// TODO: Is this correct?
match self.inner {
ScheduledReset(reason) => Err(proto::Error::library_go_away(reason)),
Closed(Cause::Error(ref e)) => Err(e.clone()),
Closed(Cause::ScheduledLibraryReset(reason)) => {
Err(proto::Error::library_go_away(reason))
}
Closed(Cause::EndStream) | HalfClosedRemote(..) | ReservedLocal => Ok(false),
_ => Ok(true),
}
Expand All @@ -446,9 +451,9 @@ impl State {
/// Returns a reason if the stream has been reset.
pub(super) fn ensure_reason(&self, mode: PollReset) -> Result<Option<Reason>, crate::Error> {
match self.inner {
Closed(Cause::Error(Error::Reset(_, reason, _)))
| Closed(Cause::Error(Error::GoAway(_, reason, _)))
| Closed(Cause::ScheduledLibraryReset(reason)) => Ok(Some(reason)),
ScheduledReset(reason)
| Closed(Cause::Error(Error::Reset(_, reason, _)))
| Closed(Cause::Error(Error::GoAway(_, reason, _))) => Ok(Some(reason)),
Closed(Cause::Error(ref e)) => Err(e.clone().into()),
Open {
local: Streaming, ..
Expand Down
137 changes: 137 additions & 0 deletions tests/h2-tests/tests/stream_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,3 +1218,140 @@ async fn reset_new_stream_before_send() {

join(srv, client).await;
}

#[tokio::test]
async fn explicit_reset_with_max_concurrent_stream() {
h2_support::trace_init!();

let (io, mut srv) = mock::new();

let mock = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().max_concurrent_streams(1))
.await;
assert_default_settings!(settings);

srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/"))
.await;
srv.send_frame(frames::headers(1).response(200)).await;

srv.recv_frame(frames::reset(1).cancel()).await;

srv.recv_frame(
frames::headers(3)
.request("POST", "https://www.example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(3).response(200)).await;
};

let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();

{
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();

let (resp, mut stream) = client.send_request(request, false).unwrap();

{
let resp = h2.drive(resp).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}

stream.send_reset(Reason::CANCEL);
};

{
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();

let (resp, _) = client.send_request(request, true).unwrap();

{
let resp = h2.drive(resp).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
};

h2.await.unwrap();
};

join(mock, h2).await;
}

#[tokio::test]
async fn implicit_cancel_with_max_concurrent_stream() {
h2_support::trace_init!();

let (io, mut srv) = mock::new();

let mock = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().max_concurrent_streams(1))
.await;
assert_default_settings!(settings);

srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/"))
.await;
srv.send_frame(frames::headers(1).response(200)).await;

srv.recv_frame(frames::reset(1).cancel()).await;

srv.recv_frame(
frames::headers(3)
.request("POST", "https://www.example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(3).response(200)).await;
};

let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();

{
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();

let (resp, stream) = client.send_request(request, false).unwrap();

{
let resp = h2.drive(resp).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}

// This implicitly resets the stream with CANCEL.
drop(stream);
};

{
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();

let (resp, _) = client.send_request(request, true).unwrap();

{
let resp = h2.drive(resp).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
};

h2.await.unwrap();
};

join(mock, h2).await;
}

0 comments on commit 38ebb1b

Please sign in to comment.