From 98bbe59a95b2b1d2b8eed4a8d3d568fc7165552d Mon Sep 17 00:00:00 2001 From: Kershaw Date: Thu, 6 Jul 2023 17:56:21 +0200 Subject: [PATCH] WebTransportReceiveStreamStats Impl (#1446) * WebTransportReceiveStreamStats Impl * address comments --- neqo-http3/src/connection_client.rs | 17 +++++-- .../tests/webtransport/mod.rs | 8 ++- .../tests/webtransport/streams.rs | 51 ++++++++++++------- .../extended_connect/webtransport_streams.rs | 37 +++++++++++++- neqo-http3/src/lib.rs | 9 +++- neqo-transport/src/connection/mod.rs | 9 +++- neqo-transport/src/lib.rs | 6 +-- 7 files changed, 107 insertions(+), 30 deletions(-) diff --git a/neqo-http3/src/connection_client.rs b/neqo-http3/src/connection_client.rs index 48c9e1baf7..8d0d78922a 100644 --- a/neqo-http3/src/connection_client.rs +++ b/neqo-http3/src/connection_client.rs @@ -21,9 +21,9 @@ use neqo_common::{ use neqo_crypto::{agent::CertificateInfo, AuthenticationStatus, ResumptionToken, SecretAgentInfo}; use neqo_qpack::Stats as QpackStats; use neqo_transport::{ - send_stream::SendStreamStats, streams::SendOrder, AppError, Connection, ConnectionEvent, - ConnectionId, ConnectionIdGenerator, DatagramTracking, Output, Stats as TransportStats, - StreamId, StreamType, Version, ZeroRttState, + streams::SendOrder, AppError, Connection, ConnectionEvent, ConnectionId, ConnectionIdGenerator, + DatagramTracking, Output, RecvStreamStats, SendStreamStats, Stats as TransportStats, StreamId, + StreamType, Version, ZeroRttState, }; use std::{ cell::RefCell, @@ -788,6 +788,17 @@ impl Http3Client { .stats(&mut self.conn) } + /// Returns the current `RecvStreamStats` of a `WebTransportRecvStream`. + /// # Errors + /// `InvalidStreamId` if the stream does not exist. + pub fn webtransport_recv_stream_stats(&mut self, stream_id: StreamId) -> Res { + self.base_handler + .recv_streams + .get_mut(&stream_id) + .ok_or(Error::InvalidStreamId)? + .stats(&mut self.conn) + } + /// This function combines `process_input` and `process_output` function. pub fn process(&mut self, dgram: Option, now: Instant) -> Output { qtrace!([self], "Process."); diff --git a/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs b/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs index f06fb0953f..4ac5f72b0f 100644 --- a/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs +++ b/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs @@ -13,8 +13,8 @@ use neqo_common::event::Provider; use crate::{ features::extended_connect::SessionCloseReason, Error, Header, Http3Client, Http3ClientEvent, Http3OrWebTransportStream, Http3Parameters, Http3Server, Http3ServerEvent, Http3State, - SendStreamStats, WebTransportEvent, WebTransportRequest, WebTransportServerEvent, - WebTransportSessionAcceptAction, + RecvStreamStats, SendStreamStats, WebTransportEvent, WebTransportRequest, + WebTransportServerEvent, WebTransportSessionAcceptAction, }; use neqo_crypto::AuthenticationStatus; use neqo_transport::{ConnectionParameters, StreamId, StreamType}; @@ -315,6 +315,10 @@ impl WtTest { self.client.webtransport_send_stream_stats(wt_stream_id) } + fn recv_stream_stats(&mut self, wt_stream_id: StreamId) -> Result { + self.client.webtransport_recv_stream_stats(wt_stream_id) + } + fn receive_data_client( &mut self, expected_stream_id: StreamId, diff --git a/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs b/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs index bd99c3ef0c..a50c45d518 100644 --- a/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs +++ b/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs @@ -16,25 +16,28 @@ fn wt_client_stream_uni() { let mut wt = WtTest::new(); let wt_session = wt.create_wt_session(); let wt_stream = wt.create_wt_stream_client(wt_session.stream_id(), StreamType::UniDi); - let stats = wt.send_stream_stats(wt_stream).unwrap(); - assert_eq!(stats.bytes_written(), 0); - assert_eq!(stats.bytes_sent(), 0); - assert_eq!(stats.bytes_acked(), 0); + let send_stats = wt.send_stream_stats(wt_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), 0); + assert_eq!(send_stats.bytes_sent(), 0); + assert_eq!(send_stats.bytes_acked(), 0); wt.send_data_client(wt_stream, BUF_CLIENT); wt.receive_data_server(wt_stream, true, BUF_CLIENT, false); - let stats = wt.send_stream_stats(wt_stream).unwrap(); - assert_eq!(stats.bytes_written(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_sent(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_acked(), BUF_CLIENT.len() as u64); + let send_stats = wt.send_stream_stats(wt_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_sent(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_acked(), BUF_CLIENT.len() as u64); // Send data again to test if the stats has the expected values. wt.send_data_client(wt_stream, BUF_CLIENT); wt.receive_data_server(wt_stream, false, BUF_CLIENT, false); - let stats = wt.send_stream_stats(wt_stream).unwrap(); - assert_eq!(stats.bytes_written(), (BUF_CLIENT.len() * 2) as u64); - assert_eq!(stats.bytes_sent(), (BUF_CLIENT.len() * 2) as u64); - assert_eq!(stats.bytes_acked(), (BUF_CLIENT.len() * 2) as u64); + let send_stats = wt.send_stream_stats(wt_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), (BUF_CLIENT.len() * 2) as u64); + assert_eq!(send_stats.bytes_sent(), (BUF_CLIENT.len() * 2) as u64); + assert_eq!(send_stats.bytes_acked(), (BUF_CLIENT.len() * 2) as u64); + + let recv_stats = wt.recv_stream_stats(wt_stream); + assert_eq!(recv_stats.unwrap_err(), Error::InvalidStreamId); } #[test] @@ -49,10 +52,14 @@ fn wt_client_stream_bidi() { let mut wt_server_stream = wt.receive_data_server(wt_client_stream, true, BUF_CLIENT, false); wt.send_data_server(&mut wt_server_stream, BUF_SERVER); wt.receive_data_client(wt_client_stream, false, BUF_SERVER, false); - let stats = wt.send_stream_stats(wt_client_stream).unwrap(); - assert_eq!(stats.bytes_written(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_sent(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_acked(), BUF_CLIENT.len() as u64); + let send_stats = wt.send_stream_stats(wt_client_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_sent(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_acked(), BUF_CLIENT.len() as u64); + + let recv_stats = wt.recv_stream_stats(wt_client_stream).unwrap(); + assert_eq!(recv_stats.bytes_received(), BUF_SERVER.len() as u64); + assert_eq!(recv_stats.bytes_read(), BUF_SERVER.len() as u64); } #[test] @@ -64,8 +71,12 @@ fn wt_server_stream_uni() { let mut wt_server_stream = WtTest::create_wt_stream_server(&mut wt_session, StreamType::UniDi); wt.send_data_server(&mut wt_server_stream, BUF_SERVER); wt.receive_data_client(wt_server_stream.stream_id(), true, BUF_SERVER, false); - let stats = wt.send_stream_stats(wt_server_stream.stream_id()); - assert_eq!(stats.unwrap_err(), Error::InvalidStreamId); + let send_stats = wt.send_stream_stats(wt_server_stream.stream_id()); + assert_eq!(send_stats.unwrap_err(), Error::InvalidStreamId); + + let recv_stats = wt.recv_stream_stats(wt_server_stream.stream_id()).unwrap(); + assert_eq!(recv_stats.bytes_received(), BUF_SERVER.len() as u64); + assert_eq!(recv_stats.bytes_read(), BUF_SERVER.len() as u64); } #[test] @@ -84,6 +95,10 @@ fn wt_server_stream_bidi() { assert_eq!(stats.bytes_written(), BUF_CLIENT.len() as u64); assert_eq!(stats.bytes_sent(), BUF_CLIENT.len() as u64); assert_eq!(stats.bytes_acked(), BUF_CLIENT.len() as u64); + + let recv_stats = wt.recv_stream_stats(wt_server_stream.stream_id()).unwrap(); + assert_eq!(recv_stats.bytes_received(), BUF_SERVER.len() as u64); + assert_eq!(recv_stats.bytes_read(), BUF_SERVER.len() as u64); } #[test] diff --git a/neqo-http3/src/features/extended_connect/webtransport_streams.rs b/neqo-http3/src/features/extended_connect/webtransport_streams.rs index f463fd8b2f..ca918dce9e 100644 --- a/neqo-http3/src/features/extended_connect/webtransport_streams.rs +++ b/neqo-http3/src/features/extended_connect/webtransport_streams.rs @@ -10,7 +10,7 @@ use crate::{ SendStream, SendStreamEvents, Stream, }; use neqo_common::Encoder; -use neqo_transport::{send_stream::SendStreamStats, Connection, StreamId}; +use neqo_transport::{Connection, RecvStreamStats, SendStreamStats, StreamId}; use std::cell::RefCell; use std::rc::Rc; @@ -75,6 +75,35 @@ impl RecvStream for WebTransportRecvStream { } Ok((amount, fin)) } + + fn stats(&mut self, conn: &mut Connection) -> Res { + const TYPE_LEN_UNI: usize = Encoder::varint_len(WEBTRANSPORT_UNI_STREAM); + const TYPE_LEN_BIDI: usize = Encoder::varint_len(WEBTRANSPORT_STREAM); + + let stream_header_size = if self.stream_id.is_server_initiated() { + let id_len = if self.stream_id.is_uni() { + TYPE_LEN_UNI + } else { + TYPE_LEN_BIDI + }; + (id_len + Encoder::varint_len(self.session_id.as_u64())) as u64 + } else { + 0 + }; + + let stats = conn.recv_stream_stats(self.stream_id)?; + if stream_header_size == 0 { + return Ok(stats); + } + + let subtract_non_app_bytes = + |count: u64| -> u64 { count.saturating_sub(stream_header_size) }; + + let bytes_received = subtract_non_app_bytes(stats.bytes_received()); + let bytes_read = subtract_non_app_bytes(stats.bytes_read()); + + Ok(RecvStreamStats::new(bytes_received, bytes_read)) + } } #[derive(Debug, PartialEq)] @@ -225,10 +254,14 @@ impl SendStream for WebTransportSendStream { 0 }; + let stats = conn.send_stream_stats(self.stream_id)?; + if stream_header_size == 0 { + return Ok(stats); + } + let subtract_non_app_bytes = |count: u64| -> u64 { count.saturating_sub(stream_header_size) }; - let stats = conn.stream_stats(self.stream_id)?; let bytes_written = subtract_non_app_bytes(stats.bytes_written()); let bytes_sent = subtract_non_app_bytes(stats.bytes_sent()); let bytes_acked = subtract_non_app_bytes(stats.bytes_acked()); diff --git a/neqo-http3/src/lib.rs b/neqo-http3/src/lib.rs index d89b86af13..76be301a8e 100644 --- a/neqo-http3/src/lib.rs +++ b/neqo-http3/src/lib.rs @@ -161,8 +161,10 @@ mod settings; mod stream_type_reader; use neqo_qpack::Error as QpackError; -use neqo_transport::{send_stream::SendStreamStats, AppError, Connection, Error as TransportError}; pub use neqo_transport::{streams::SendOrder, Output, StreamId}; +use neqo_transport::{ + AppError, Connection, Error as TransportError, RecvStreamStats, SendStreamStats, +}; use std::fmt::Debug; use crate::priority::PriorityHandler; @@ -470,6 +472,11 @@ trait RecvStream: Stream { fn webtransport(&self) -> Option>> { None } + + /// This function is only implemented by `WebTransportRecvStream`. + fn stats(&mut self, _conn: &mut Connection) -> Res { + Err(Error::Unavailable) + } } trait HttpRecvStream: RecvStream { diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index e07ca0fc4b..0a388ea70a 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -38,6 +38,7 @@ use crate::{ }, }; +use crate::recv_stream::RecvStreamStats; pub use crate::send_stream::{RetransmissionPriority, SendStreamStats, TransmissionPriority}; use crate::{ crypto::{Crypto, CryptoDxState, CryptoSpace}, @@ -2963,10 +2964,16 @@ impl Connection { self.streams.set_fairness(stream_id, fairness) } - pub fn stream_stats(&self, stream_id: StreamId) -> Res { + pub fn send_stream_stats(&self, stream_id: StreamId) -> Res { self.streams.get_send_stream(stream_id).map(|s| s.stats()) } + pub fn recv_stream_stats(&mut self, stream_id: StreamId) -> Res { + let stream = self.streams.get_recv_stream_mut(stream_id)?; + + Ok(stream.stats()) + } + /// Send data on a stream. /// Returns how many bytes were successfully sent. Could be less /// than total, based on receiver credit space available, etc. diff --git a/neqo-transport/src/lib.rs b/neqo-transport/src/lib.rs index 643ec218ef..daff7e73c2 100644 --- a/neqo-transport/src/lib.rs +++ b/neqo-transport/src/lib.rs @@ -28,7 +28,7 @@ mod quic_datagrams; mod recovery; mod recv_stream; mod rtt; -pub mod send_stream; +mod send_stream; mod sender; pub mod server; mod stats; @@ -53,8 +53,8 @@ pub use self::stats::Stats; pub use self::stream_id::{StreamId, StreamType}; pub use self::version::Version; -pub use self::recv_stream::RECV_BUFFER_SIZE; -pub use self::send_stream::SEND_BUFFER_SIZE; +pub use self::recv_stream::{RecvStreamStats, RECV_BUFFER_SIZE}; +pub use self::send_stream::{SendStreamStats, SEND_BUFFER_SIZE}; pub type TransportError = u64; const ERROR_APPLICATION_CLOSE: TransportError = 12;