From 0327d9f77b54e80e19223d36d7355a4e1382a60a Mon Sep 17 00:00:00 2001 From: Yuchen Wu Date: Mon, 4 Nov 2024 08:37:01 -0800 Subject: [PATCH] Fix RecvStream::is_end_stream(): return true only when END_STREAM is received Before this change, it returned true on other types of disconnection as well. Fixes #806 --- src/proto/streams/recv.rs | 2 +- src/proto/streams/state.rs | 7 ++----- tests/h2-support/src/util.rs | 4 +++- tests/h2-tests/tests/stream_states.rs | 5 +++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index a70527e2..2b8f3f0c 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -557,7 +557,7 @@ impl Recv { } pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { - if !stream.state.is_recv_closed() { + if !stream.state.is_end_stream() { return false; } diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 5256f09c..c47c94e1 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -413,11 +413,8 @@ impl State { matches!(self.inner, Closed(_)) } - pub fn is_recv_closed(&self) -> bool { - matches!( - self.inner, - Closed(..) | HalfClosedRemote(..) | ReservedLocal - ) + pub fn is_end_stream(&self) -> bool { + matches!(self.inner, Closed(Cause::EndStream)) } pub fn is_send_closed(&self) -> bool { diff --git a/tests/h2-support/src/util.rs b/tests/h2-support/src/util.rs index 02b6450d..8c88637b 100644 --- a/tests/h2-support/src/util.rs +++ b/tests/h2-support/src/util.rs @@ -1,5 +1,6 @@ use bytes::{BufMut, Bytes}; use futures::ready; +use std::borrow::BorrowMut; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -8,7 +9,8 @@ pub fn byte_str(s: &str) -> h2::frame::BytesStr { h2::frame::BytesStr::try_from(Bytes::copy_from_slice(s.as_bytes())).unwrap() } -pub async fn concat(mut body: h2::RecvStream) -> Result { +pub async fn concat>(mut body: B) -> Result { + let body = body.borrow_mut(); let mut vec = Vec::new(); while let Some(chunk) = body.data().await { vec.put(chunk?); diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index 9a377d79..9d85ff35 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -338,13 +338,14 @@ async fn errors_if_recv_frame_exceeds_max_frame_size() { let req = async move { let resp = client.get("https://example.com/").await.expect("response"); assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - let res = util::concat(body).await; + let mut body = resp.into_parts().1; + let res = util::concat(&mut body).await; let err = res.unwrap_err(); assert_eq!( err.to_string(), "connection error detected: frame with invalid size" ); + assert!(!body.is_end_stream()); }; // client should see a conn error