From b3f049e0c0935d61e10e0ae3ca2da15dc92170c0 Mon Sep 17 00:00:00 2001 From: jesup Date: Mon, 19 Jun 2023 15:40:56 -0400 Subject: [PATCH 1/4] Implement stream fairness when sending data (#1443) * Add support for WebTransport SendOrder to neqo. No attempt at fairness within a sendorder (or for unordered streams) is attempted * Add support for WebTransport SendOrder to neqo. No attempt at fairness within a sendorder (or for unordered streams) is attempted * Change comment for stream_sendorder errors * Add unit test for sendorder * Responses to comments * Fix other test failures -- one relies on unordered streams being processed in stream_id order; the other was a bug in remove_terminal() * Switch to a sorted Vec for the not-sendordered streams for efficiency. * Fix lints * Lint fix * resolve comments * move sendorder HashSets to Vecs * use let vec = if let ... * sendorder_fixes * fix rebase problem * Implement fairness for streams * Only use fairness for WebTransport streams. Clean up error reporting for sendorder Also remove tabs from these files * resolve most comments * resolve Kershaw's comments * Responses to martin * Final comment resolution * remove tabs * more lint fixes * Implement fairness for streams * Only use fairness for WebTransport streams. Clean up error reporting for sendorder Also remove tabs from these files * lint fix * lint fixes * more lint fixes * lint fixes for tests * lint fixes * Implement fairness for streams * resolve most comments * resolve Kershaw's comments * Responses to martin --- neqo-http3/src/connection.rs | 47 +- neqo-http3/src/connection_client.rs | 28 +- .../extended_connect/webtransport_session.rs | 12 +- .../extended_connect/webtransport_streams.rs | 10 + neqo-http3/src/lib.rs | 8 +- neqo-http3/src/send_message.rs | 12 +- neqo-transport/src/connection/mod.rs | 24 +- neqo-transport/src/connection/tests/stream.rs | 200 +++++++- neqo-transport/src/send_stream.rs | 450 +++++++++++++++--- neqo-transport/src/stream_id.rs | 6 + neqo-transport/src/streams.rs | 63 ++- 11 files changed, 781 insertions(+), 79 deletions(-) diff --git a/neqo-http3/src/connection.rs b/neqo-http3/src/connection.rs index bb5b6451c4..f2d0f28806 100644 --- a/neqo-http3/src/connection.rs +++ b/neqo-http3/src/connection.rs @@ -31,8 +31,8 @@ use neqo_common::{qdebug, qerror, qinfo, qtrace, qwarn, Decoder, Header, Message use neqo_qpack::decoder::QPackDecoder; use neqo_qpack::encoder::QPackEncoder; use neqo_transport::{ - AppError, Connection, ConnectionError, DatagramTracking, State, StreamId, StreamType, - ZeroRttState, + streams::SendOrder, AppError, Connection, ConnectionError, DatagramTracking, State, StreamId, + StreamType, ZeroRttState, }; use std::cell::RefCell; use std::collections::{BTreeSet, HashMap}; @@ -232,7 +232,7 @@ possible if there is no buffered data. If a stream has buffered data it will be registered in the `streams_with_pending_data` queue and actual sending will be performed in the `process_sending` function call. (This is done in this way, i.e. data is buffered first and then sent, for 2 reasons: in this way, sending will happen in a -single function, therefore error handling and clean up is easier and the QUIIC layer may not be +single function, therefore error handling and clean up is easier and the QUIC layer may not be able to accept all data and being able to buffer data is required in any case.) The `send` and `send_data` functions may detect that the stream is closed and all outstanding data @@ -626,7 +626,7 @@ impl Http3Connection { } } - /// This is called when 0RTT has been reseted to clear `send_streams`, `recv_streams` and settings. + /// This is called when 0RTT has been reset to clear `send_streams`, `recv_streams` and settings. pub fn handle_zero_rtt_rejected(&mut self) -> Res<()> { if self.state == Http3State::ZeroRtt { self.state = Http3State::Initializing; @@ -735,6 +735,14 @@ impl Http3Connection { conn.stream_stop_sending(stream_id, Error::HttpStreamCreation.code())?; return Ok(ReceiveOutput::NoOutput); } + // set incoming WebTransport streams to be fair (share bandwidth) + conn.stream_fairness(stream_id, true).ok(); + qinfo!( + [self], + "A new WebTransport stream {} for session {}.", + stream_id, + session_id + ); } NewStreamType::Unknown => { conn.stream_stop_sending(stream_id, Error::HttpStreamCreation.code())?; @@ -920,7 +928,7 @@ impl Http3Connection { ); // Call immediately send so that at least headers get sent. This will make Firefox faster, since - // it can send request body immediatly in most cases and does not need to do a complete process loop. + // it can send request body immediately in most cases and does not need to do a complete process loop. self.send_streams .get_mut(&stream_id) .ok_or(Error::InvalidStreamId)? @@ -995,6 +1003,32 @@ impl Http3Connection { Ok(()) } + /// Set the stream `SendOrder`. + /// # Errors + /// Returns `InvalidStreamId` if the stream id doesn't exist + pub fn stream_set_sendorder( + conn: &mut Connection, + stream_id: StreamId, + sendorder: Option, + ) -> Res<()> { + conn.stream_sendorder(stream_id, sendorder) + .map_err(|_| Error::InvalidStreamId) + } + + /// Set the stream Fairness. Fair streams will share bandwidth with other + /// streams of the same sendOrder group (or the unordered group). Unfair streams + /// will give bandwidth preferentially to the lowest streamId with data to send. + /// # Errors + /// Returns `InvalidStreamId` if the stream id doesn't exist + pub fn stream_set_fairness( + conn: &mut Connection, + stream_id: StreamId, + fairness: bool, + ) -> Res<()> { + conn.stream_fairness(stream_id, fairness) + .map_err(|_| Error::InvalidStreamId) + } + pub fn cancel_fetch( &mut self, stream_id: StreamId, @@ -1238,6 +1272,9 @@ impl Http3Connection { let stream_id = conn .stream_create(stream_type) .map_err(|e| Error::map_stream_create_errors(&e))?; + // Set outgoing WebTransport streams to be fair (share bandwidth) + // This really can't fail, panics if it does + conn.stream_fairness(stream_id, true).unwrap(); self.webtransport_create_stream_internal( wt, diff --git a/neqo-http3/src/connection_client.rs b/neqo-http3/src/connection_client.rs index d2ed1f526d..48c9e1baf7 100644 --- a/neqo-http3/src/connection_client.rs +++ b/neqo-http3/src/connection_client.rs @@ -21,9 +21,9 @@ use neqo_common::{ use neqo_crypto::{agent::CertificateInfo, AuthenticationStatus, ResumptionToken, SecretAgentInfo}; use neqo_qpack::Stats as QpackStats; use neqo_transport::{ - send_stream::SendStreamStats, AppError, Connection, ConnectionEvent, ConnectionId, - ConnectionIdGenerator, DatagramTracking, Output, Stats as TransportStats, StreamId, StreamType, - Version, ZeroRttState, + send_stream::SendStreamStats, streams::SendOrder, AppError, Connection, ConnectionEvent, + ConnectionId, ConnectionIdGenerator, DatagramTracking, Output, Stats as TransportStats, + StreamId, StreamType, Version, ZeroRttState, }; use std::{ cell::RefCell, @@ -755,6 +755,28 @@ impl Http3Client { - u64::try_from(Encoder::varint_len(session_id.as_u64())).unwrap()) } + /// Sets the `SendOrder` for a given stream + /// # Errors + /// It may return `InvalidStreamId` if a stream does not exist anymore. + /// # Panics + /// This cannot panic. + pub fn webtransport_set_sendorder( + &mut self, + stream_id: StreamId, + sendorder: SendOrder, + ) -> Res<()> { + Http3Connection::stream_set_sendorder(&mut self.conn, stream_id, Some(sendorder)) + } + + /// Sets the `Fairness` for a given stream + /// # Errors + /// It may return `InvalidStreamId` if a stream does not exist anymore. + /// # Panics + /// This cannot panic. + pub fn webtransport_set_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> { + Http3Connection::stream_set_fairness(&mut self.conn, stream_id, fairness) + } + /// Returns the current `SendStreamStats` of a `WebTransportSendStream`. /// # Errors /// `InvalidStreamId` if the stream does not exist. diff --git a/neqo-http3/src/features/extended_connect/webtransport_session.rs b/neqo-http3/src/features/extended_connect/webtransport_session.rs index 4a412dd27e..c446fd3843 100644 --- a/neqo-http3/src/features/extended_connect/webtransport_session.rs +++ b/neqo-http3/src/features/extended_connect/webtransport_session.rs @@ -17,7 +17,7 @@ use crate::{ }; use neqo_common::{qtrace, Encoder, Header, MessageType, Role}; use neqo_qpack::{QPackDecoder, QPackEncoder}; -use neqo_transport::{Connection, DatagramTracking, StreamId}; +use neqo_transport::{streams::SendOrder, Connection, DatagramTracking, StreamId}; use std::any::Any; use std::cell::RefCell; use std::collections::BTreeSet; @@ -486,6 +486,16 @@ impl SendStream for Rc> { self.borrow_mut().has_data_to_send() } + fn set_sendorder(&mut self, _conn: &mut Connection, _sendorder: Option) -> Res<()> { + // Not relevant on session + Ok(()) + } + + fn set_fairness(&mut self, _conn: &mut Connection, _fairness: bool) -> Res<()> { + // Not relevant on session + Ok(()) + } + fn stream_writable(&self) {} fn done(&self) -> bool { diff --git a/neqo-http3/src/features/extended_connect/webtransport_streams.rs b/neqo-http3/src/features/extended_connect/webtransport_streams.rs index 2c8da3f8d3..f463fd8b2f 100644 --- a/neqo-http3/src/features/extended_connect/webtransport_streams.rs +++ b/neqo-http3/src/features/extended_connect/webtransport_streams.rs @@ -185,6 +185,16 @@ impl SendStream for WebTransportSendStream { } } + fn set_sendorder(&mut self, conn: &mut Connection, sendorder: Option) -> Res<()> { + conn.stream_sendorder(self.stream_id, sendorder) + .map_err(|_| crate::Error::InvalidStreamId) + } + + fn set_fairness(&mut self, conn: &mut Connection, fairness: bool) -> Res<()> { + conn.stream_fairness(self.stream_id, fairness) + .map_err(|_| crate::Error::InvalidStreamId) + } + fn handle_stop_sending(&mut self, close_type: CloseType) { self.set_done(close_type); } diff --git a/neqo-http3/src/lib.rs b/neqo-http3/src/lib.rs index a7827058a1..d89b86af13 100644 --- a/neqo-http3/src/lib.rs +++ b/neqo-http3/src/lib.rs @@ -162,7 +162,7 @@ mod stream_type_reader; use neqo_qpack::Error as QpackError; use neqo_transport::{send_stream::SendStreamStats, AppError, Connection, Error as TransportError}; -pub use neqo_transport::{Output, StreamId}; +pub use neqo_transport::{streams::SendOrder, Output, StreamId}; use std::fmt::Debug; use crate::priority::PriorityHandler; @@ -545,13 +545,15 @@ trait HttpRecvStreamEvents: RecvStreamEvents { trait SendStream: Stream { /// # Errors - /// Error my occure during sending data, e.g. protocol error, etc. + /// Error my occur during sending data, e.g. protocol error, etc. fn send(&mut self, conn: &mut Connection) -> Res<()>; fn has_data_to_send(&self) -> bool; fn stream_writable(&self); fn done(&self) -> bool; + fn set_sendorder(&mut self, conn: &mut Connection, sendorder: Option) -> Res<()>; + fn set_fairness(&mut self, conn: &mut Connection, fairness: bool) -> Res<()>; /// # Errors - /// Error my occure during sending data, e.g. protocol error, etc. + /// Error my occur during sending data, e.g. protocol error, etc. fn send_data(&mut self, _conn: &mut Connection, _buf: &[u8]) -> Res; /// # Errors diff --git a/neqo-http3/src/send_message.rs b/neqo-http3/src/send_message.rs index 3fecae4708..deb0cf3c34 100644 --- a/neqo-http3/src/send_message.rs +++ b/neqo-http3/src/send_message.rs @@ -13,7 +13,7 @@ use crate::{ use neqo_common::{qdebug, qinfo, qtrace, Encoder, Header, MessageType}; use neqo_qpack::encoder::QPackEncoder; -use neqo_transport::{Connection, StreamId}; +use neqo_transport::{streams::SendOrder, Connection, StreamId}; use std::any::Any; use std::cell::RefCell; use std::cmp::min; @@ -271,6 +271,16 @@ impl SendStream for SendMessage { self.stream.has_buffered_data() } + fn set_sendorder(&mut self, _conn: &mut Connection, _sendorder: Option) -> Res<()> { + // Not relevant for SendMessage + Ok(()) + } + + fn set_fairness(&mut self, _conn: &mut Connection, _fairness: bool) -> Res<()> { + // Not relevant for SendMessage + Ok(()) + } + fn close(&mut self, conn: &mut Connection) -> Res<()> { self.state.fin()?; if !self.stream.has_buffered_data() { diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 5494804f75..e07ca0fc4b 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -55,7 +55,7 @@ use crate::{ rtt::GRANULARITY, stats::{Stats, StatsCell}, stream_id::StreamType, - streams::Streams, + streams::{SendOrder, Streams}, tparams::{ self, TransportParameter, TransportParameterId, TransportParameters, TransportParametersHandler, @@ -1911,9 +1911,13 @@ impl Connection { } } + // datagrams are best-effort and unreliable. Let streams starve them for now // Check if there is a Datagram to be written self.quic_datagrams .write_frames(builder, tokens, &mut self.stats.borrow_mut()); + if builder.is_full() { + return Ok(()); + } let stats = &mut self.stats.borrow_mut().frame_tx; @@ -2941,6 +2945,24 @@ impl Connection { Ok(()) } + /// Set the SendOrder of a stream. Re-enqueues to keep the ordering correct + /// # Errors + /// Returns InvalidStreamId if the stream id doesn't exist + pub fn stream_sendorder( + &mut self, + stream_id: StreamId, + sendorder: Option, + ) -> Res<()> { + self.streams.set_sendorder(stream_id, sendorder) + } + + /// Set the Fairness of a stream + /// # Errors + /// Returns InvalidStreamId if the stream id doesn't exist + pub fn stream_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> { + self.streams.set_fairness(stream_id, fairness) + } + pub fn stream_stats(&self, stream_id: StreamId) -> Res { self.streams.get_send_stream(stream_id).map(|s| s.stats()) } diff --git a/neqo-transport/src/connection/tests/stream.rs b/neqo-transport/src/connection/tests/stream.rs index ec2f171555..036a3adff9 100644 --- a/neqo-transport/src/connection/tests/stream.rs +++ b/neqo-transport/src/connection/tests/stream.rs @@ -11,11 +11,14 @@ use super::{ use crate::{ events::ConnectionEvent, recv_stream::RECV_BUFFER_SIZE, + send_stream::OrderGroup, send_stream::{SendStreamState, SEND_BUFFER_SIZE}, + streams::{SendOrder, StreamOrder}, tparams::{self, TransportParameter}, tracking::DEFAULT_ACK_PACKET_TOLERANCE, - Connection, ConnectionError, ConnectionParameters, Error, StreamType, + Connection, ConnectionError, ConnectionParameters, Error, StreamId, StreamType, }; +use std::collections::HashMap; use neqo_common::{event::Provider, qdebug}; use std::{cmp::max, convert::TryFrom, mem}; @@ -111,6 +114,201 @@ fn transfer() { assert!(fin3); } +#[derive(PartialEq, Eq, PartialOrd, Ord)] +struct IdEntry { + sendorder: StreamOrder, + stream_id: StreamId, +} + +// tests stream sendorder priorization +fn sendorder_test(order_of_sendorder: &[Option]) { + let mut client = default_client(); + let mut server = default_server(); + connect_force_idle(&mut client, &mut server); + + qdebug!("---- client sends"); + // open all streams and set the sendorders + let mut ordered = Vec::new(); + let mut streams = Vec::::new(); + for sendorder in order_of_sendorder { + let id = client.stream_create(StreamType::UniDi).unwrap(); + streams.push(id); + ordered.push((id, *sendorder)); + // must be set before sendorder + client.streams.set_fairness(id, true).ok(); + client.streams.set_sendorder(id, *sendorder).ok(); + } + // Write some data to all the streams + for stream_id in streams { + client.stream_send(stream_id, &[6; 100]).unwrap(); + } + + // Sending this much takes a few datagrams. + // Note: this test uses an RTT of 0 which simplifies things (no pacing) + let mut datagrams = Vec::new(); + let mut out = client.process_output(now()); + while let Some(d) = out.dgram() { + datagrams.push(d); + out = client.process_output(now()); + } + assert_eq!(*client.state(), State::Confirmed); + + qdebug!("---- server receives"); + for (_, d) in datagrams.into_iter().enumerate() { + let out = server.process(Some(d), now()); + qdebug!("Output={:0x?}", out.as_dgram_ref()); + } + assert_eq!(*server.state(), State::Confirmed); + + let stream_ids = server + .events() + .filter_map(|evt| match evt { + ConnectionEvent::RecvStreamReadable { stream_id, .. } => Some(stream_id), + _ => None, + }) + .enumerate() + .map(|(a, b)| (b, a)) + .collect::>(); + + // streams should arrive in priority order, not order of creation, if sendorder prioritization + // is working correctly + + // 'ordered' has the send order currently. Re-sort it by sendorder, but + // if two items from the same sendorder exist, secondarily sort by the ordering in + // the stream_ids vector (HashMap) + ordered.sort_unstable_by_key(|(stream_id, sendorder)| { + ( + StreamOrder { + sendorder: *sendorder, + }, + stream_ids[stream_id], + ) + }); + // make sure everything now is in the same order, since we modified the order of + // same-sendorder items to match the ordering of those we saw in reception + for (i, (stream_id, _sendorder)) in ordered.iter().enumerate() { + assert_eq!(i, stream_ids[stream_id]); + } +} + +#[test] +fn sendorder_0() { + sendorder_test(&[None, Some(1), Some(2), Some(3)]); +} +#[test] +fn sendorder_1() { + sendorder_test(&[Some(3), Some(2), Some(1), None]); +} +#[test] +fn sendorder_2() { + sendorder_test(&[Some(3), None, Some(2), Some(1)]); +} +#[test] +fn sendorder_3() { + sendorder_test(&[Some(1), Some(2), None, Some(3)]); +} +#[test] +fn sendorder_4() { + sendorder_test(&[ + Some(1), + Some(2), + Some(1), + None, + Some(3), + Some(1), + Some(3), + None, + ]); +} + +// Tests stream sendorder priorization +// Converts Vecs of u64's into StreamIds +fn fairness_test(source: S, number_iterates: usize, truncate_to: usize, result_array: &R) +where + S: IntoIterator, + S::Item: Into, + R: IntoIterator + std::fmt::Debug, + R::Item: Into, + Vec: PartialEq, +{ + // test the OrderGroup code used for fairness + let mut group: OrderGroup = OrderGroup::default(); + for stream_id in source { + group.insert(stream_id.into()); + } + { + let mut iterator1 = group.iter(); + // advance_by() would help here + let mut n = number_iterates; + while n > 0 { + iterator1.next(); + n -= 1; + } + // let iterator1 go out of scope + } + group.truncate(truncate_to); + + let iterator2 = group.iter(); + let result: Vec = iterator2.map(StreamId::as_u64).collect(); + assert_eq!(result, *result_array); +} + +#[test] +fn ordergroup_0() { + let source: [u64; 0] = []; + let result: [u64; 0] = []; + fairness_test(source, 1, usize::MAX, &result); +} + +#[test] +fn ordergroup_1() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [1, 2, 3, 4, 5, 0]; + fairness_test(source, 1, usize::MAX, &result); +} + +#[test] +fn ordergroup_2() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [2, 3, 4, 5, 0, 1]; + fairness_test(source, 2, usize::MAX, &result); +} + +#[test] +fn ordergroup_3() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [0, 1, 2, 3, 4, 5]; + fairness_test(source, 10, usize::MAX, &result); +} + +#[test] +fn ordergroup_4() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [0, 1, 2, 3, 4, 5]; + fairness_test(source, 0, usize::MAX, &result); +} + +#[test] +fn ordergroup_5() { + let source: [u64; 1] = [0]; + let result: [u64; 1] = [0]; + fairness_test(source, 1, usize::MAX, &result); +} + +#[test] +fn ordergroup_6() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 6] = [5, 0, 1, 2, 3, 4]; + fairness_test(source, 5, usize::MAX, &result); +} + +#[test] +fn ordergroup_7() { + let source: [u64; 6] = [0, 1, 2, 3, 4, 5]; + let result: [u64; 3] = [0, 1, 2]; + fairness_test(source, 5, 3, &result); +} + #[test] // Send fin even if a peer closes a reomte bidi send stream before sending any data. fn report_fin_when_stream_closed_wo_data() { diff --git a/neqo-transport/src/send_stream.rs b/neqo-transport/src/send_stream.rs index b51ee92ca6..4a2bf08002 100644 --- a/neqo-transport/src/send_stream.rs +++ b/neqo-transport/src/send_stream.rs @@ -18,6 +18,7 @@ use std::{ use indexmap::IndexMap; use smallvec::SmallVec; +use std::hash::{Hash, Hasher}; use neqo_common::{qdebug, qerror, qinfo, qtrace, Encoder, Role}; @@ -29,6 +30,7 @@ use crate::{ recovery::{RecoveryToken, StreamRecoveryToken}, stats::FrameStats, stream_id::StreamId, + streams::SendOrder, tparams::{self, TransportParameters}, AppError, Error, Res, }; @@ -613,9 +615,24 @@ pub struct SendStream { priority: TransmissionPriority, retransmission_priority: RetransmissionPriority, retransmission_offset: u64, + sendorder: Option, bytes_sent: u64, + fair: bool, } +impl Hash for SendStream { + fn hash(&self, state: &mut H) { + self.stream_id.hash(state) + } +} + +impl PartialEq for SendStream { + fn eq(&self, other: &Self) -> bool { + self.stream_id == other.stream_id + } +} +impl Eq for SendStream {} + impl SendStream { pub fn new( stream_id: StreamId, @@ -633,7 +650,9 @@ impl SendStream { priority: TransmissionPriority::default(), retransmission_priority: RetransmissionPriority::default(), retransmission_offset: 0, + sendorder: None, bytes_sent: 0, + fair: false, }; if ss.avail() > 0 { ss.conn_events.send_stream_writable(stream_id); @@ -641,6 +660,49 @@ impl SendStream { ss } + pub fn write_frames( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec, + stats: &mut FrameStats, + ) { + qtrace!("write STREAM frames at priority {:?}", priority); + if !self.write_reset_frame(priority, builder, tokens, stats) { + self.write_blocked_frame(priority, builder, tokens, stats); + self.write_stream_frame(priority, builder, tokens, stats); + } + } + + // return false if the builder is full and the caller should stop iterating + pub fn write_frames_with_early_return( + &mut self, + priority: TransmissionPriority, + builder: &mut PacketBuilder, + tokens: &mut Vec, + stats: &mut FrameStats, + ) -> bool { + if !self.write_reset_frame(priority, builder, tokens, stats) { + self.write_blocked_frame(priority, builder, tokens, stats); + if builder.is_full() { + return false; + } + self.write_stream_frame(priority, builder, tokens, stats); + if builder.is_full() { + return false; + } + } + true + } + + pub fn set_fairness(&mut self, make_fair: bool) { + self.fair = make_fair; + } + + pub fn is_fair(&self) -> bool { + self.fair + } + pub fn set_priority( &mut self, transmission: TransmissionPriority, @@ -650,6 +712,14 @@ impl SendStream { self.retransmission_priority = retransmission; } + pub fn sendorder(&self) -> Option { + self.sendorder + } + + pub fn set_sendorder(&mut self, sendorder: Option) { + self.sendorder = sendorder; + } + /// If all data has been buffered or written, how much was sent. pub fn final_size(&self) -> Option { match &self.state { @@ -769,7 +839,7 @@ impl SendStream { } /// Maybe write a `STREAM` frame. - fn write_stream_frame( + pub fn write_stream_frame( &mut self, priority: TransmissionPriority, builder: &mut PacketBuilder, @@ -1193,61 +1263,278 @@ impl ::std::fmt::Display for SendStream { } #[derive(Debug, Default)] -pub(crate) struct SendStreams(IndexMap); +pub struct OrderGroup { + // This vector is sorted by StreamId + vec: Vec, + + // Since we need to remember where we were, we'll store the iterator next + // position in the object. This means there can only be a single iterator active + // at a time! + next: usize, + // This is used when an iterator is created to set the start/stop point for the + // iteration. The iterator must iterate from this entry to the end, and then + // wrap and iterate from 0 until before the initial value of next. + // This value may need to be updated after insertion and removal; in theory we should + // track the target entry across modifications, but in practice it should be good + // enough to simply leave it alone unless it points past the end of the + // Vec, and re-initialize to 0 in that case. +} + +pub struct OrderGroupIter<'a> { + group: &'a mut OrderGroup, + // We store the next position in the OrderGroup. + // Otherwise we'd need an explicit "done iterating" call to be made, or implement Drop to + // copy the value back. + // This is where next was when we iterated for the first time; when we get back to that we stop. + started_at: Option, +} + +impl OrderGroup { + pub fn iter(&mut self) -> OrderGroupIter { + // Ids may have been deleted since we last iterated + if self.next >= self.vec.len() { + self.next = 0; + } + OrderGroupIter { + started_at: None, + group: self, + } + } + + pub fn stream_ids(&self) -> &Vec { + &self.vec + } + + pub fn clear(&mut self) { + self.vec.clear(); + } + + pub fn push(&mut self, stream_id: StreamId) { + self.vec.push(stream_id); + } + + #[cfg(test)] + pub fn truncate(&mut self, position: usize) { + self.vec.truncate(position); + } + + fn update_next(&mut self) -> usize { + let next = self.next; + self.next = (self.next + 1) % self.vec.len(); + next + } + + pub fn insert(&mut self, stream_id: StreamId) { + match self.vec.binary_search(&stream_id) { + Ok(_) => panic!("Duplicate stream_id {}", stream_id), // element already in vector @ `pos` + Err(pos) => self.vec.insert(pos, stream_id), + } + } + + pub fn remove(&mut self, stream_id: StreamId) { + match self.vec.binary_search(&stream_id) { + Ok(pos) => { + self.vec.remove(pos); + } + Err(_) => panic!("Missing stream_id {}", stream_id), // element already in vector @ `pos` + } + } +} + +impl<'a> Iterator for OrderGroupIter<'a> { + type Item = StreamId; + fn next(&mut self) -> Option { + // Stop when we would return the started_at element on the next + // call. Note that this must take into account wrapping. + if self.started_at == Some(self.group.next) || self.group.vec.is_empty() { + return None; + } + self.started_at = self.started_at.or(Some(self.group.next)); + let orig = self.group.update_next(); + Some(self.group.vec[orig]) + } +} + +#[derive(Debug, Default)] +pub(crate) struct SendStreams { + map: IndexMap, + + // What we really want is a Priority Queue that we can do arbitrary + // removes from (so we can reprioritize). BinaryHeap doesn't work, + // because there's no remove(). BTreeMap doesn't work, since you can't + // duplicate keys. PriorityQueue does have what we need, except for an + // ordered iterator that doesn't consume the queue. So we roll our own. + + // Added complication: We want to have Fairness for streams of the same + // 'group' (for WebTransport), but for H3 (and other non-WT streams) we + // tend to get better pageload performance by prioritizing by creation order. + // + // Two options are to walk the 'map' first, ignoring WebTransport + // streams, then process the unordered and ordered WebTransport + // streams. The second is to have a sorted Vec for unfair streams (and + // use a normal iterator for that), and then chain the iterators for + // the unordered and ordered WebTranport streams. The first works very + // well for H3, and for WebTransport nodes are visited twice on every + // processing loop. The second adds insertion and removal costs, but + // avoids a CPU penalty for WebTransport streams. For now we'll do #1. + // + // So we use a sorted Vec<> for the regular streams (that's usually all of + // them), and then a BTreeMap of an entry for each SendOrder value, and + // for each of those entries a Vec of the stream_ids at that + // sendorder. In most cases (such as stream-per-frame), there will be + // a single stream at a given sendorder. + + // These both store stream_ids, which need to be looked up in 'map'. + // This avoids the complexity of trying to hold references to the + // Streams which are owned by the IndexMap. + sendordered: BTreeMap, + regular: OrderGroup, // streams with no SendOrder set, sorted in stream_id order +} impl SendStreams { pub fn get(&self, id: StreamId) -> Res<&SendStream> { - self.0.get(&id).ok_or(Error::InvalidStreamId) + self.map.get(&id).ok_or(Error::InvalidStreamId) } pub fn get_mut(&mut self, id: StreamId) -> Res<&mut SendStream> { - self.0.get_mut(&id).ok_or(Error::InvalidStreamId) + self.map.get_mut(&id).ok_or(Error::InvalidStreamId) } pub fn exists(&self, id: StreamId) -> bool { - self.0.contains_key(&id) + self.map.contains_key(&id) } pub fn insert(&mut self, id: StreamId, stream: SendStream) { - self.0.insert(id, stream); + self.map.insert(id, stream); + } + + fn group_mut(&mut self, sendorder: Option) -> &mut OrderGroup { + if let Some(order) = sendorder { + self.sendordered.entry(order).or_default() + } else { + &mut self.regular + } + } + + pub fn set_sendorder(&mut self, stream_id: StreamId, sendorder: Option) -> Res<()> { + self.set_fairness(stream_id, true)?; + if let Some(stream) = self.map.get_mut(&stream_id) { + // don't grab stream here; causes borrow errors + let old_sendorder = stream.sendorder(); + if old_sendorder != sendorder { + // we have to remove it from the list it was in, and reinsert it with the new + // sendorder key + let mut group = self.group_mut(old_sendorder); + group.remove(stream_id); + self.get_mut(stream_id).unwrap().set_sendorder(sendorder); + group = self.group_mut(sendorder); + group.insert(stream_id); + qtrace!( + "ordering of stream_ids: {:?}", + self.sendordered.values().collect::>() + ); + } + Ok(()) + } else { + Err(Error::InvalidStreamId) + } + } + + pub fn set_fairness(&mut self, stream_id: StreamId, make_fair: bool) -> Res<()> { + let stream: &mut SendStream = self.map.get_mut(&stream_id).ok_or(Error::InvalidStreamId)?; + let was_fair = stream.fair; + stream.set_fairness(make_fair); + if !was_fair && make_fair { + // Move to the regular OrderGroup. + + // We know sendorder can't have been set, since + // set_sendorder() will call this routine if it's not + // already set as fair. + + // This normally is only called when a new stream is created. If + // so, because of how we allocate StreamIds, it should always have + // the largest value. This means we can just append it to the + // regular vector. However, if we were ever to change this + // invariant, things would break subtly. + + // To be safe we can try to insert at the end and if not + // fall back to binary-search insertion + if matches!(self.regular.stream_ids().last(), Some(last) if stream_id > *last) { + self.regular.push(stream_id); + } else { + self.regular.insert(stream_id); + } + } else if was_fair && !make_fair { + // remove from the OrderGroup + let group = if let Some(sendorder) = stream.sendorder { + self.sendordered.get_mut(&sendorder).unwrap() + } else { + &mut self.regular + }; + group.remove(stream_id); + } + Ok(()) } pub fn acked(&mut self, token: &SendStreamRecoveryToken) { - if let Some(ss) = self.0.get_mut(&token.id) { + if let Some(ss) = self.map.get_mut(&token.id) { ss.mark_as_acked(token.offset, token.length, token.fin); } } pub fn reset_acked(&mut self, id: StreamId) { - if let Some(ss) = self.0.get_mut(&id) { + if let Some(ss) = self.map.get_mut(&id) { ss.reset_acked() } } pub fn lost(&mut self, token: &SendStreamRecoveryToken) { - if let Some(ss) = self.0.get_mut(&token.id) { + if let Some(ss) = self.map.get_mut(&token.id) { ss.mark_as_lost(token.offset, token.length, token.fin); } } pub fn reset_lost(&mut self, stream_id: StreamId) { - if let Some(ss) = self.0.get_mut(&stream_id) { + if let Some(ss) = self.map.get_mut(&stream_id) { ss.reset_lost(); } } pub fn blocked_lost(&mut self, stream_id: StreamId, limit: u64) { - if let Some(ss) = self.0.get_mut(&stream_id) { + if let Some(ss) = self.map.get_mut(&stream_id) { ss.blocked_lost(limit); } } pub fn clear(&mut self) { - self.0.clear() - } - - pub fn clear_terminal(&mut self) { - self.0.retain(|_, stream| !stream.is_terminal()) + self.map.clear(); + self.sendordered.clear(); + self.regular.clear(); + } + + pub fn remove_terminal(&mut self) { + let map: &mut IndexMap = &mut self.map; + let regular: &mut OrderGroup = &mut self.regular; + let sendordered: &mut BTreeMap = &mut self.sendordered; + + // Take refs to all the items we need to modify instead of &mut + // self to keep the compiler happy (if we use self.map.retain it + // gets upset due to borrows) + map.retain(|stream_id, stream| { + if stream.is_terminal() { + if stream.is_fair() { + match stream.sendorder() { + None => regular.remove(*stream_id), + Some(sendorder) => { + sendordered.get_mut(&sendorder).unwrap().remove(*stream_id) + } + }; + } + // if unfair, we're done + return false; + } + true + }); } pub(crate) fn write_frames( @@ -1258,16 +1545,73 @@ impl SendStreams { stats: &mut FrameStats, ) { qtrace!("write STREAM frames at priority {:?}", priority); - for stream in self.0.values_mut() { - if !stream.write_reset_frame(priority, builder, tokens, stats) { - stream.write_blocked_frame(priority, builder, tokens, stats); - stream.write_stream_frame(priority, builder, tokens, stats); + // WebTransport data (which is Normal) may have a SendOrder + // priority attached. The spec states (6.3 write-chunk 6.1): + + // First, we send any streams without Fairness defined, with + // ordering defined by StreamId. (Http3 streams used for + // e.g. pageload benefit from being processed in order of creation + // so the far side can start acting on a datum/request sooner. All + // WebTransport streams MUST have fairness set.) Then we send + // streams with fairness set (including all WebTransport streams) + // as follows: + + // If stream.[[SendOrder]] is null then this sending MUST NOT + // starve except for flow control reasons or error. If + // stream.[[SendOrder]] is not null then this sending MUST starve + // until all bytes queued for sending on WebTransportSendStreams + // with a non-null and higher [[SendOrder]], that are neither + // errored nor blocked by flow control, have been sent. + + // So data without SendOrder goes first. Then the highest priority + // SendOrdered streams. + // + // Fairness is implemented by a round-robining or "statefully + // iterating" within a single sendorder/unordered vector. We do + // this by recording where we stopped in the previous pass, and + // starting there the next pass. If we store an index into the + // vec, this means we can't use a chained iterator, since we want + // to retain our place-in-the-vector. If we rotate the vector, + // that would let us use the chained iterator, but would require + // more expensive searches for insertion and removal (since the + // sorted order would be lost). + + // Iterate the map, but only those without fairness, then iterate + // OrderGroups, then iterate each group + qdebug!("processing streams... unfair:"); + for stream in self.map.values_mut() { + if !stream.is_fair() { + qdebug!(" {}", stream); + if !stream.write_frames_with_early_return(priority, builder, tokens, stats) { + break; + } + } + } + qdebug!("fair streams:"); + let stream_ids = self.regular.iter().chain( + self.sendordered + .values_mut() + .rev() + .flat_map(|group| group.iter()), + ); + for stream_id in stream_ids { + match self.map.get_mut(&stream_id).unwrap().sendorder() { + Some(order) => qdebug!(" {} ({})", stream_id, order), + None => qdebug!(" None"), + } + if !self + .map + .get_mut(&stream_id) + .unwrap() + .write_frames_with_early_return(priority, builder, tokens, stats) + { + break; } } } pub fn update_initial_limit(&mut self, remote: &TransportParameters) { - for (id, ss) in self.0.iter_mut() { + for (id, ss) in self.map.iter_mut() { let limit = if id.is_bidi() { assert!(!id.is_remote_initiated(Role::Client)); remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE) @@ -1284,7 +1628,7 @@ impl<'a> IntoIterator for &'a mut SendStreams { type IntoIter = indexmap::map::IterMut<'a, StreamId, SendStream>; fn into_iter(self) -> indexmap::map::IterMut<'a, StreamId, SendStream> { - self.0.iter_mut() + self.map.iter_mut() } } @@ -1450,16 +1794,16 @@ mod tests { // Fill the buffer assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE); assert!(matches!(txb.next_bytes(), - Some((0, x)) if x.len()==SEND_BUFFER_SIZE - && x.iter().all(|ch| *ch == 1))); + Some((0, x)) if x.len()==SEND_BUFFER_SIZE + && x.iter().all(|ch| *ch == 1))); // Mark almost all as sent. Get what's left let one_byte_from_end = SEND_BUFFER_SIZE as u64 - 1; txb.mark_as_sent(0, one_byte_from_end as usize); assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 1 - && start == one_byte_from_end - && x.iter().all(|ch| *ch == 1))); + Some((start, x)) if x.len() == 1 + && start == one_byte_from_end + && x.iter().all(|ch| *ch == 1))); // Mark all as sent. Get nothing txb.mark_as_sent(0, SEND_BUFFER_SIZE); @@ -1468,18 +1812,18 @@ mod tests { // Mark as lost. Get it again txb.mark_as_lost(one_byte_from_end, 1); assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 1 - && start == one_byte_from_end - && x.iter().all(|ch| *ch == 1))); + Some((start, x)) if x.len() == 1 + && start == one_byte_from_end + && x.iter().all(|ch| *ch == 1))); // Mark a larger range lost, including beyond what's in the buffer even. // Get a little more let five_bytes_from_end = SEND_BUFFER_SIZE as u64 - 5; txb.mark_as_lost(five_bytes_from_end, 100); assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 5 - && start == five_bytes_from_end - && x.iter().all(|ch| *ch == 1))); + Some((start, x)) if x.len() == 5 + && start == five_bytes_from_end + && x.iter().all(|ch| *ch == 1))); // Contig acked range at start means it can be removed from buffer // Impl of vecdeque should now result in a split buffer when more data @@ -1488,9 +1832,9 @@ mod tests { assert_eq!(txb.send(&[2; 30]), 30); // Just get 5 even though there is more assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 5 - && start == five_bytes_from_end - && x.iter().all(|ch| *ch == 1))); + Some((start, x)) if x.len() == 5 + && start == five_bytes_from_end + && x.iter().all(|ch| *ch == 1))); assert_eq!(txb.retired, five_bytes_from_end); assert_eq!(txb.buffered(), 35); @@ -1498,9 +1842,9 @@ mod tests { // when called again txb.mark_as_sent(five_bytes_from_end, 5); assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 30 - && start == SEND_BUFFER_SIZE as u64 - && x.iter().all(|ch| *ch == 2))); + Some((start, x)) if x.len() == 30 + && start == SEND_BUFFER_SIZE as u64 + && x.iter().all(|ch| *ch == 2))); } #[test] @@ -1512,8 +1856,8 @@ mod tests { // Fill the buffer assert_eq!(txb.send(&[1; SEND_BUFFER_SIZE * 2]), SEND_BUFFER_SIZE); assert!(matches!(txb.next_bytes(), - Some((0, x)) if x.len()==SEND_BUFFER_SIZE - && x.iter().all(|ch| *ch == 1))); + Some((0, x)) if x.len()==SEND_BUFFER_SIZE + && x.iter().all(|ch| *ch == 1))); // As above let forty_bytes_from_end = SEND_BUFFER_SIZE as u64 - 40; @@ -1531,18 +1875,18 @@ mod tests { txb.mark_as_sent(forty_bytes_from_end, 10); let thirty_bytes_from_end = forty_bytes_from_end + 10; assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 30 - && start == thirty_bytes_from_end - && x.iter().all(|ch| *ch == 1))); + Some((start, x)) if x.len() == 30 + && start == thirty_bytes_from_end + && x.iter().all(|ch| *ch == 1))); // Mark a range 'A' in second slice as sent. Should still return the same let range_a_start = SEND_BUFFER_SIZE as u64 + 30; let range_a_end = range_a_start + 10; txb.mark_as_sent(range_a_start, 10); assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 30 - && start == thirty_bytes_from_end - && x.iter().all(|ch| *ch == 1))); + Some((start, x)) if x.len() == 30 + && start == thirty_bytes_from_end + && x.iter().all(|ch| *ch == 1))); // Ack entire first slice and into second slice let ten_bytes_past_end = SEND_BUFFER_SIZE as u64 + 10; @@ -1550,17 +1894,17 @@ mod tests { // Get up to marked range A assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 20 - && start == ten_bytes_past_end - && x.iter().all(|ch| *ch == 2))); + Some((start, x)) if x.len() == 20 + && start == ten_bytes_past_end + && x.iter().all(|ch| *ch == 2))); txb.mark_as_sent(ten_bytes_past_end, 20); // Get bit after earlier marked range A assert!(matches!(txb.next_bytes(), - Some((start, x)) if x.len() == 60 - && start == range_a_end - && x.iter().all(|ch| *ch == 2))); + Some((start, x)) if x.len() == 60 + && start == range_a_end + && x.iter().all(|ch| *ch == 2))); // No more bytes. txb.mark_as_sent(range_a_end, 60); diff --git a/neqo-transport/src/stream_id.rs b/neqo-transport/src/stream_id.rs index c82b09d8c4..51df2ca9fb 100644 --- a/neqo-transport/src/stream_id.rs +++ b/neqo-transport/src/stream_id.rs @@ -107,6 +107,12 @@ impl From for StreamId { } } +impl From<&u64> for StreamId { + fn from(val: &u64) -> Self { + Self::new(*val) + } +} + impl PartialEq for StreamId { fn eq(&self, other: &u64) -> bool { self.as_u64() == *other diff --git a/neqo-transport/src/streams.rs b/neqo-transport/src/streams.rs index 398104bf3d..735e602feb 100644 --- a/neqo-transport/src/streams.rs +++ b/neqo-transport/src/streams.rs @@ -5,7 +5,6 @@ // except according to those terms. // Stream management for a connection. - use crate::{ fc::{LocalStreamLimits, ReceiverFlowControl, RemoteStreamLimits, SenderFlowControl}, frame::Frame, @@ -19,8 +18,42 @@ use crate::{ ConnectionEvents, Error, Res, }; use neqo_common::{qtrace, qwarn, Role}; +use std::cmp::Ordering; use std::{cell::RefCell, rc::Rc}; +pub type SendOrder = i64; + +#[derive(Copy, Clone)] +pub struct StreamOrder { + pub sendorder: Option, +} + +// We want highest to lowest, with None being higher than any value +impl Ord for StreamOrder { + fn cmp(&self, other: &Self) -> Ordering { + if self.sendorder.is_some() && other.sendorder.is_some() { + // We want reverse order (high to low) when both values are specified. + other.sendorder.cmp(&self.sendorder) + } else { + self.sendorder.cmp(&other.sendorder) + } + } +} + +impl PartialOrd for StreamOrder { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for StreamOrder { + fn eq(&self, other: &Self) -> bool { + self.sendorder == other.sendorder + } +} + +impl Eq for StreamOrder {} + pub struct Streams { role: Role, tps: Rc>, @@ -66,8 +99,7 @@ impl Streams { } pub fn zero_rtt_rejected(&mut self) { - self.send.clear(); - self.recv.clear(); + self.clear_streams(); debug_assert_eq!( self.remote_stream_limits[StreamType::BiDi].max_active(), self.tps @@ -295,7 +327,9 @@ impl Streams { } pub fn cleanup_closed_streams(&mut self) { - self.send.clear_terminal(); + // filter the list, removing closed streams + self.send.remove_terminal(); + let send = &self.send; let (removed_bidi, removed_uni) = self.recv.clear_terminal(send, self.role); @@ -377,6 +411,14 @@ impl Streams { )) } + pub fn set_sendorder(&mut self, stream_id: StreamId, sendorder: Option) -> Res<()> { + self.send.set_sendorder(stream_id, sendorder) + } + + pub fn set_fairness(&mut self, stream_id: StreamId, fairness: bool) -> Res<()> { + self.send.set_fairness(stream_id, fairness) + } + pub fn stream_create(&mut self, st: StreamType) -> Res { match self.local_stream_limits.take_stream_id(st) { None => Err(Error::StreamLimitError), @@ -386,15 +428,14 @@ impl Streams { StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, }; let send_limit = self.tps.borrow().remote().get_integer(send_limit_tp); - self.send.insert( + let stream = SendStream::new( new_id, - SendStream::new( - new_id, - send_limit, - Rc::clone(&self.sender_fc), - self.events.clone(), - ), + send_limit, + Rc::clone(&self.sender_fc), + self.events.clone(), ); + self.send.insert(new_id, stream); + if st == StreamType::BiDi { // From the local perspective, this is a local- originated BiDi stream. From the // remote perspective, this is a remote-originated BiDi stream. Therefore, look at From a0158f624129e570a392e7b1b530dbaebf0ab367 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Thu, 22 Jun 2023 08:00:33 +0100 Subject: [PATCH 2/4] Cast the f64 version from the integer version (#1445) --- neqo-transport/src/cc/mod.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/neqo-transport/src/cc/mod.rs b/neqo-transport/src/cc/mod.rs index 50a3e73c82..5cd5676747 100644 --- a/neqo-transport/src/cc/mod.rs +++ b/neqo-transport/src/cc/mod.rs @@ -7,26 +7,26 @@ // Congestion control #![deny(clippy::pedantic)] -use crate::path::PATH_MTU_V6; -use crate::tracking::SentPacket; -use crate::Error; +use crate::{path::PATH_MTU_V6, tracking::SentPacket, Error}; use neqo_common::qlog::NeqoQlog; -use std::fmt::{Debug, Display}; -use std::str::FromStr; -use std::time::{Duration, Instant}; +use std::{ + fmt::{Debug, Display}, + str::FromStr, + time::{Duration, Instant}, +}; mod classic_cc; mod cubic; mod new_reno; -pub use classic_cc::ClassicCongestionControl; -pub use classic_cc::{CWND_INITIAL, CWND_INITIAL_PKTS, CWND_MIN}; +pub use classic_cc::{ClassicCongestionControl, CWND_INITIAL, CWND_INITIAL_PKTS, CWND_MIN}; pub use cubic::Cubic; pub use new_reno::NewReno; pub const MAX_DATAGRAM_SIZE: usize = PATH_MTU_V6; -pub const MAX_DATAGRAM_SIZE_F64: f64 = 1337.0; +#[allow(clippy::cast_precision_loss)] +pub const MAX_DATAGRAM_SIZE_F64: f64 = MAX_DATAGRAM_SIZE as f64; pub trait CongestionControl: Display + Debug { fn set_qlog(&mut self, qlog: NeqoQlog); From c5b41dac5c35130a5abe20d2828690f02c3de980 Mon Sep 17 00:00:00 2001 From: Kershaw Date: Fri, 23 Jun 2023 09:59:24 +0200 Subject: [PATCH 3/4] Implement RecvStreamStats (#1444) * Implement RecvStreamStats * address comments --- neqo-transport/src/recv_stream.rs | 228 ++++++++++++++++++++++++++---- 1 file changed, 201 insertions(+), 27 deletions(-) diff --git a/neqo-transport/src/recv_stream.rs b/neqo-transport/src/recv_stream.rs index 19e233f8ca..fbd2fad7bb 100644 --- a/neqo-transport/src/recv_stream.rs +++ b/neqo-transport/src/recv_stream.rs @@ -122,6 +122,7 @@ impl RecvStreams { pub struct RxStreamOrderer { data_ranges: BTreeMap>, // (start_offset, data) retired: u64, // Number of bytes the application has read + received: u64, // The number of bytes has stored in `data_ranges` } impl RxStreamOrderer { @@ -231,6 +232,7 @@ impl RxStreamOrderer { } if !to_add.is_empty() { + self.received += u64::try_from(to_add.len()).unwrap(); if extend { let (_, buf) = self .data_ranges @@ -280,6 +282,10 @@ impl RxStreamOrderer { self.retired } + fn received(&self) -> u64 { + self.received + } + /// Data bytes buffered. Could be more than bytes_readable if there are /// ranges missing. fn buffered(&self) -> u64 { @@ -359,19 +365,29 @@ enum RecvStreamState { session_fc: Rc>>, recv_buf: RxStreamOrderer, }, - DataRead, + DataRead { + final_received: u64, + final_read: u64, + }, AbortReading { fc: ReceiverFlowControl, session_fc: Rc>>, final_size_reached: bool, frame_needed: bool, err: AppError, + final_received: u64, + final_read: u64, }, WaitForReset { fc: ReceiverFlowControl, session_fc: Rc>>, + final_received: u64, + final_read: u64, + }, + ResetRecvd { + final_received: u64, + final_read: u64, }, - ResetRecvd, // Defined by spec but we don't use it: ResetRead } @@ -393,10 +409,10 @@ impl RecvStreamState { Self::Recv { .. } => "Recv", Self::SizeKnown { .. } => "SizeKnown", Self::DataRecvd { .. } => "DataRecvd", - Self::DataRead => "DataRead", + Self::DataRead { .. } => "DataRead", Self::AbortReading { .. } => "AbortReading", Self::WaitForReset { .. } => "WaitForReset", - Self::ResetRecvd => "ResetRecvd", + Self::ResetRecvd { .. } => "ResetRecvd", } } @@ -405,10 +421,10 @@ impl RecvStreamState { Self::Recv { recv_buf, .. } | Self::SizeKnown { recv_buf, .. } | Self::DataRecvd { recv_buf, .. } => Some(recv_buf), - Self::DataRead + Self::DataRead { .. } | Self::AbortReading { .. } | Self::WaitForReset { .. } - | Self::ResetRecvd => None, + | Self::ResetRecvd { .. } => None, } } @@ -429,7 +445,7 @@ impl RecvStreamState { *final_size_reached |= fin; (fc, session_fc, old_final_size_reached, true) } - Self::DataRead | Self::ResetRecvd => { + Self::DataRead { .. } | Self::ResetRecvd { .. } => { return Ok(()); } }; @@ -456,6 +472,40 @@ impl RecvStreamState { } } +// See https://www.w3.org/TR/webtransport/#receive-stream-stats +#[derive(Debug, Clone, Copy)] +pub struct RecvStreamStats { + // An indicator of progress on how many of the server application’s bytes + // intended for this stream have been received so far. + // Only sequential bytes up to, but not including, the first missing byte, + // are counted. This number can only increase. + pub bytes_received: u64, + // The total number of bytes the application has successfully read from this + // stream. This number can only increase, and is always less than or equal + // to bytes_received. + pub bytes_read: u64, +} + +impl RecvStreamStats { + #[must_use] + pub fn new(bytes_received: u64, bytes_read: u64) -> Self { + Self { + bytes_received, + bytes_read, + } + } + + #[must_use] + pub fn bytes_received(&self) -> u64 { + self.bytes_received + } + + #[must_use] + pub fn bytes_read(&self) -> u64 { + self.bytes_read + } +} + /// Implement a QUIC receive stream. #[derive(Debug)] pub struct RecvStream { @@ -497,11 +547,11 @@ impl RecvStream { // is cause to stop keep-alives. RecvStreamState::DataRecvd { .. } | RecvStreamState::AbortReading { .. } - | RecvStreamState::ResetRecvd => { + | RecvStreamState::ResetRecvd { .. } => { self.keep_alive = None; } // Once all the data is read, generate an event. - RecvStreamState::DataRead => { + RecvStreamState::DataRead { .. } => { self.conn_events.recv_stream_complete(self.stream_id); } _ => {} @@ -510,6 +560,40 @@ impl RecvStream { self.state = new_state; } + pub fn stats(&self) -> RecvStreamStats { + match &self.state { + RecvStreamState::Recv { recv_buf, .. } + | RecvStreamState::SizeKnown { recv_buf, .. } + | RecvStreamState::DataRecvd { recv_buf, .. } => { + let received = recv_buf.received(); + let read = recv_buf.retired(); + RecvStreamStats::new(received, read) + } + RecvStreamState::AbortReading { + final_received, + final_read, + .. + } + | RecvStreamState::WaitForReset { + final_received, + final_read, + .. + } + | RecvStreamState::DataRead { + final_received, + final_read, + } + | RecvStreamState::ResetRecvd { + final_received, + final_read, + } => { + let received = *final_received; + let read = *final_read; + RecvStreamStats::new(received, read) + } + } + } + pub fn inbound_stream_frame(&mut self, fin: bool, offset: u64, data: &[u8]) -> Res<()> { // We should post a DataReadable event only once when we change from no-data-ready to // data-ready. Therefore remember the state before processing a new frame. @@ -564,10 +648,10 @@ impl RecvStream { } } RecvStreamState::DataRecvd { .. } - | RecvStreamState::DataRead + | RecvStreamState::DataRead { .. } | RecvStreamState::AbortReading { .. } | RecvStreamState::WaitForReset { .. } - | RecvStreamState::ResetRecvd => { + | RecvStreamState::ResetRecvd { .. } => { qtrace!("data received when we are in state {}", self.state.name()) } } @@ -582,15 +666,50 @@ impl RecvStream { pub fn reset(&mut self, application_error_code: AppError, final_size: u64) -> Res<()> { self.state.flow_control_consume_data(final_size, true)?; match &mut self.state { - RecvStreamState::Recv { fc, session_fc, .. } - | RecvStreamState::SizeKnown { fc, session_fc, .. } - | RecvStreamState::AbortReading { fc, session_fc, .. } - | RecvStreamState::WaitForReset { fc, session_fc } => { + RecvStreamState::Recv { + fc, + session_fc, + recv_buf, + } + | RecvStreamState::SizeKnown { + fc, + session_fc, + recv_buf, + } => { // make flow control consumes new data that not really exist. Self::flow_control_retire_data(final_size - fc.retired(), fc, session_fc); self.conn_events .recv_stream_reset(self.stream_id, application_error_code); - self.set_state(RecvStreamState::ResetRecvd); + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::ResetRecvd { + final_received: received, + final_read: read, + }); + } + RecvStreamState::AbortReading { + fc, + session_fc, + final_received, + final_read, + .. + } + | RecvStreamState::WaitForReset { + fc, + session_fc, + final_received, + final_read, + } => { + // make flow control consumes new data that not really exist. + Self::flow_control_retire_data(final_size - fc.retired(), fc, session_fc); + self.conn_events + .recv_stream_reset(self.stream_id, application_error_code); + let received = *final_received; + let read = *final_read; + self.set_state(RecvStreamState::ResetRecvd { + final_received: received, + final_read: read, + }); } _ => { // Ignore reset if in DataRecvd, DataRead, or ResetRecvd @@ -629,7 +748,7 @@ impl RecvStream { pub fn is_terminal(&self) -> bool { matches!( self.state, - RecvStreamState::ResetRecvd | RecvStreamState::DataRead + RecvStreamState::ResetRecvd { .. } | RecvStreamState::DataRead { .. } ) } @@ -669,7 +788,12 @@ impl RecvStream { Self::flow_control_retire_data(u64::try_from(bytes_read).unwrap(), fc, session_fc); let fin_read = if data_recvd_state { if recv_buf.buffered() == 0 { - self.set_state(RecvStreamState::DataRead); + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::DataRead { + final_received: received, + final_read: read, + }); true } else { false @@ -679,38 +803,59 @@ impl RecvStream { }; Ok((bytes_read, fin_read)) } - RecvStreamState::DataRead + RecvStreamState::DataRead { .. } | RecvStreamState::AbortReading { .. } | RecvStreamState::WaitForReset { .. } - | RecvStreamState::ResetRecvd => Err(Error::NoMoreData), + | RecvStreamState::ResetRecvd { .. } => Err(Error::NoMoreData), } } pub fn stop_sending(&mut self, err: AppError) { qtrace!("stop_sending called when in state {}", self.state.name()); match &mut self.state { - RecvStreamState::Recv { fc, session_fc, .. } - | RecvStreamState::SizeKnown { fc, session_fc, .. } => { + RecvStreamState::Recv { + fc, + session_fc, + recv_buf, + } + | RecvStreamState::SizeKnown { + fc, + session_fc, + recv_buf, + } => { // Retire data Self::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc); let fc_copy = mem::take(fc); let session_fc_copy = mem::take(session_fc); + let received = recv_buf.received(); + let read = recv_buf.retired(); self.set_state(RecvStreamState::AbortReading { fc: fc_copy, session_fc: session_fc_copy, final_size_reached: matches!(self.state, RecvStreamState::SizeKnown { .. }), frame_needed: true, err, + final_received: received, + final_read: read, }) } - RecvStreamState::DataRecvd { fc, session_fc, .. } => { + RecvStreamState::DataRecvd { + fc, + session_fc, + recv_buf, + } => { Self::flow_control_retire_data(fc.consumed() - fc.retired(), fc, session_fc); - self.set_state(RecvStreamState::DataRead); + let received = recv_buf.received(); + let read = recv_buf.retired(); + self.set_state(RecvStreamState::DataRead { + final_received: received, + final_read: read, + }); } - RecvStreamState::DataRead + RecvStreamState::DataRead { .. } | RecvStreamState::AbortReading { .. } | RecvStreamState::WaitForReset { .. } - | RecvStreamState::ResetRecvd => { + | RecvStreamState::ResetRecvd { .. } => { // Already in terminal state } } @@ -765,19 +910,28 @@ impl RecvStream { fc, session_fc, final_size_reached, + final_received, + final_read, .. } = &mut self.state { + let received = *final_received; + let read = *final_read; if *final_size_reached { // We already know the final_size of the stream therefore we // do not need to wait for RESET. - self.set_state(RecvStreamState::ResetRecvd); + self.set_state(RecvStreamState::ResetRecvd { + final_received: received, + final_read: read, + }); } else { let fc_copy = mem::take(fc); let session_fc_copy = mem::take(session_fc); self.set_state(RecvStreamState::WaitForReset { fc: fc_copy, session_fc: session_fc_copy, + final_received: received, + final_read: read, }); } } @@ -1050,6 +1204,12 @@ mod tests { assert_eq!(0, s.read(&mut buf[..])); } + fn check_stats(stream: &RecvStream, expected_received: u64, expected_read: u64) { + let stream_stats = stream.stats(); + assert_eq!(expected_received, stream_stats.bytes_received()); + assert_eq!(expected_read, stream_stats.bytes_read()); + } + #[test] fn stream_rx() { let conn_events = ConnectionEvents::default(); @@ -1064,11 +1224,15 @@ mod tests { // test receiving a contig frame and reading it works s.inbound_stream_frame(false, 0, &[1; 10]).unwrap(); assert!(s.data_ready()); + check_stats(&s, 10, 0); + let mut buf = vec![0u8; 100]; assert_eq!(s.read(&mut buf).unwrap(), (10, false)); assert_eq!(s.state.recv_buf().unwrap().retired(), 10); assert_eq!(s.state.recv_buf().unwrap().buffered(), 0); + check_stats(&s, 10, 10); + // test receiving a noncontig frame s.inbound_stream_frame(false, 12, &[2; 12]).unwrap(); assert!(!s.data_ready()); @@ -1076,12 +1240,16 @@ mod tests { assert_eq!(s.state.recv_buf().unwrap().retired(), 10); assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + check_stats(&s, 22, 10); + // another frame that overlaps the first s.inbound_stream_frame(false, 14, &[3; 8]).unwrap(); assert!(!s.data_ready()); assert_eq!(s.state.recv_buf().unwrap().retired(), 10); assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + check_stats(&s, 22, 10); + // fill in the gap, but with a FIN s.inbound_stream_frame(true, 10, &[4; 6]).unwrap_err(); assert!(!s.data_ready()); @@ -1089,12 +1257,16 @@ mod tests { assert_eq!(s.state.recv_buf().unwrap().retired(), 10); assert_eq!(s.state.recv_buf().unwrap().buffered(), 12); + check_stats(&s, 22, 10); + // fill in the gap s.inbound_stream_frame(false, 10, &[5; 10]).unwrap(); assert!(s.data_ready()); assert_eq!(s.state.recv_buf().unwrap().retired(), 10); assert_eq!(s.state.recv_buf().unwrap().buffered(), 14); + check_stats(&s, 24, 10); + // a legit FIN s.inbound_stream_frame(true, 24, &[6; 18]).unwrap(); assert_eq!(s.state.recv_buf().unwrap().retired(), 10); @@ -1102,6 +1274,8 @@ mod tests { assert!(s.data_ready()); assert_eq!(s.read(&mut buf).unwrap(), (32, true)); + check_stats(&s, 42, 42); + // Stream now no longer readable (is in DataRead state) s.read(&mut buf).unwrap_err(); } From 98bbe59a95b2b1d2b8eed4a8d3d568fc7165552d Mon Sep 17 00:00:00 2001 From: Kershaw Date: Thu, 6 Jul 2023 17:56:21 +0200 Subject: [PATCH 4/4] WebTransportReceiveStreamStats Impl (#1446) * WebTransportReceiveStreamStats Impl * address comments --- neqo-http3/src/connection_client.rs | 17 +++++-- .../tests/webtransport/mod.rs | 8 ++- .../tests/webtransport/streams.rs | 51 ++++++++++++------- .../extended_connect/webtransport_streams.rs | 37 +++++++++++++- neqo-http3/src/lib.rs | 9 +++- neqo-transport/src/connection/mod.rs | 9 +++- neqo-transport/src/lib.rs | 6 +-- 7 files changed, 107 insertions(+), 30 deletions(-) diff --git a/neqo-http3/src/connection_client.rs b/neqo-http3/src/connection_client.rs index 48c9e1baf7..8d0d78922a 100644 --- a/neqo-http3/src/connection_client.rs +++ b/neqo-http3/src/connection_client.rs @@ -21,9 +21,9 @@ use neqo_common::{ use neqo_crypto::{agent::CertificateInfo, AuthenticationStatus, ResumptionToken, SecretAgentInfo}; use neqo_qpack::Stats as QpackStats; use neqo_transport::{ - send_stream::SendStreamStats, streams::SendOrder, AppError, Connection, ConnectionEvent, - ConnectionId, ConnectionIdGenerator, DatagramTracking, Output, Stats as TransportStats, - StreamId, StreamType, Version, ZeroRttState, + streams::SendOrder, AppError, Connection, ConnectionEvent, ConnectionId, ConnectionIdGenerator, + DatagramTracking, Output, RecvStreamStats, SendStreamStats, Stats as TransportStats, StreamId, + StreamType, Version, ZeroRttState, }; use std::{ cell::RefCell, @@ -788,6 +788,17 @@ impl Http3Client { .stats(&mut self.conn) } + /// Returns the current `RecvStreamStats` of a `WebTransportRecvStream`. + /// # Errors + /// `InvalidStreamId` if the stream does not exist. + pub fn webtransport_recv_stream_stats(&mut self, stream_id: StreamId) -> Res { + self.base_handler + .recv_streams + .get_mut(&stream_id) + .ok_or(Error::InvalidStreamId)? + .stats(&mut self.conn) + } + /// This function combines `process_input` and `process_output` function. pub fn process(&mut self, dgram: Option, now: Instant) -> Output { qtrace!([self], "Process."); diff --git a/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs b/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs index f06fb0953f..4ac5f72b0f 100644 --- a/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs +++ b/neqo-http3/src/features/extended_connect/tests/webtransport/mod.rs @@ -13,8 +13,8 @@ use neqo_common::event::Provider; use crate::{ features::extended_connect::SessionCloseReason, Error, Header, Http3Client, Http3ClientEvent, Http3OrWebTransportStream, Http3Parameters, Http3Server, Http3ServerEvent, Http3State, - SendStreamStats, WebTransportEvent, WebTransportRequest, WebTransportServerEvent, - WebTransportSessionAcceptAction, + RecvStreamStats, SendStreamStats, WebTransportEvent, WebTransportRequest, + WebTransportServerEvent, WebTransportSessionAcceptAction, }; use neqo_crypto::AuthenticationStatus; use neqo_transport::{ConnectionParameters, StreamId, StreamType}; @@ -315,6 +315,10 @@ impl WtTest { self.client.webtransport_send_stream_stats(wt_stream_id) } + fn recv_stream_stats(&mut self, wt_stream_id: StreamId) -> Result { + self.client.webtransport_recv_stream_stats(wt_stream_id) + } + fn receive_data_client( &mut self, expected_stream_id: StreamId, diff --git a/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs b/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs index bd99c3ef0c..a50c45d518 100644 --- a/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs +++ b/neqo-http3/src/features/extended_connect/tests/webtransport/streams.rs @@ -16,25 +16,28 @@ fn wt_client_stream_uni() { let mut wt = WtTest::new(); let wt_session = wt.create_wt_session(); let wt_stream = wt.create_wt_stream_client(wt_session.stream_id(), StreamType::UniDi); - let stats = wt.send_stream_stats(wt_stream).unwrap(); - assert_eq!(stats.bytes_written(), 0); - assert_eq!(stats.bytes_sent(), 0); - assert_eq!(stats.bytes_acked(), 0); + let send_stats = wt.send_stream_stats(wt_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), 0); + assert_eq!(send_stats.bytes_sent(), 0); + assert_eq!(send_stats.bytes_acked(), 0); wt.send_data_client(wt_stream, BUF_CLIENT); wt.receive_data_server(wt_stream, true, BUF_CLIENT, false); - let stats = wt.send_stream_stats(wt_stream).unwrap(); - assert_eq!(stats.bytes_written(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_sent(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_acked(), BUF_CLIENT.len() as u64); + let send_stats = wt.send_stream_stats(wt_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_sent(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_acked(), BUF_CLIENT.len() as u64); // Send data again to test if the stats has the expected values. wt.send_data_client(wt_stream, BUF_CLIENT); wt.receive_data_server(wt_stream, false, BUF_CLIENT, false); - let stats = wt.send_stream_stats(wt_stream).unwrap(); - assert_eq!(stats.bytes_written(), (BUF_CLIENT.len() * 2) as u64); - assert_eq!(stats.bytes_sent(), (BUF_CLIENT.len() * 2) as u64); - assert_eq!(stats.bytes_acked(), (BUF_CLIENT.len() * 2) as u64); + let send_stats = wt.send_stream_stats(wt_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), (BUF_CLIENT.len() * 2) as u64); + assert_eq!(send_stats.bytes_sent(), (BUF_CLIENT.len() * 2) as u64); + assert_eq!(send_stats.bytes_acked(), (BUF_CLIENT.len() * 2) as u64); + + let recv_stats = wt.recv_stream_stats(wt_stream); + assert_eq!(recv_stats.unwrap_err(), Error::InvalidStreamId); } #[test] @@ -49,10 +52,14 @@ fn wt_client_stream_bidi() { let mut wt_server_stream = wt.receive_data_server(wt_client_stream, true, BUF_CLIENT, false); wt.send_data_server(&mut wt_server_stream, BUF_SERVER); wt.receive_data_client(wt_client_stream, false, BUF_SERVER, false); - let stats = wt.send_stream_stats(wt_client_stream).unwrap(); - assert_eq!(stats.bytes_written(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_sent(), BUF_CLIENT.len() as u64); - assert_eq!(stats.bytes_acked(), BUF_CLIENT.len() as u64); + let send_stats = wt.send_stream_stats(wt_client_stream).unwrap(); + assert_eq!(send_stats.bytes_written(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_sent(), BUF_CLIENT.len() as u64); + assert_eq!(send_stats.bytes_acked(), BUF_CLIENT.len() as u64); + + let recv_stats = wt.recv_stream_stats(wt_client_stream).unwrap(); + assert_eq!(recv_stats.bytes_received(), BUF_SERVER.len() as u64); + assert_eq!(recv_stats.bytes_read(), BUF_SERVER.len() as u64); } #[test] @@ -64,8 +71,12 @@ fn wt_server_stream_uni() { let mut wt_server_stream = WtTest::create_wt_stream_server(&mut wt_session, StreamType::UniDi); wt.send_data_server(&mut wt_server_stream, BUF_SERVER); wt.receive_data_client(wt_server_stream.stream_id(), true, BUF_SERVER, false); - let stats = wt.send_stream_stats(wt_server_stream.stream_id()); - assert_eq!(stats.unwrap_err(), Error::InvalidStreamId); + let send_stats = wt.send_stream_stats(wt_server_stream.stream_id()); + assert_eq!(send_stats.unwrap_err(), Error::InvalidStreamId); + + let recv_stats = wt.recv_stream_stats(wt_server_stream.stream_id()).unwrap(); + assert_eq!(recv_stats.bytes_received(), BUF_SERVER.len() as u64); + assert_eq!(recv_stats.bytes_read(), BUF_SERVER.len() as u64); } #[test] @@ -84,6 +95,10 @@ fn wt_server_stream_bidi() { assert_eq!(stats.bytes_written(), BUF_CLIENT.len() as u64); assert_eq!(stats.bytes_sent(), BUF_CLIENT.len() as u64); assert_eq!(stats.bytes_acked(), BUF_CLIENT.len() as u64); + + let recv_stats = wt.recv_stream_stats(wt_server_stream.stream_id()).unwrap(); + assert_eq!(recv_stats.bytes_received(), BUF_SERVER.len() as u64); + assert_eq!(recv_stats.bytes_read(), BUF_SERVER.len() as u64); } #[test] diff --git a/neqo-http3/src/features/extended_connect/webtransport_streams.rs b/neqo-http3/src/features/extended_connect/webtransport_streams.rs index f463fd8b2f..ca918dce9e 100644 --- a/neqo-http3/src/features/extended_connect/webtransport_streams.rs +++ b/neqo-http3/src/features/extended_connect/webtransport_streams.rs @@ -10,7 +10,7 @@ use crate::{ SendStream, SendStreamEvents, Stream, }; use neqo_common::Encoder; -use neqo_transport::{send_stream::SendStreamStats, Connection, StreamId}; +use neqo_transport::{Connection, RecvStreamStats, SendStreamStats, StreamId}; use std::cell::RefCell; use std::rc::Rc; @@ -75,6 +75,35 @@ impl RecvStream for WebTransportRecvStream { } Ok((amount, fin)) } + + fn stats(&mut self, conn: &mut Connection) -> Res { + const TYPE_LEN_UNI: usize = Encoder::varint_len(WEBTRANSPORT_UNI_STREAM); + const TYPE_LEN_BIDI: usize = Encoder::varint_len(WEBTRANSPORT_STREAM); + + let stream_header_size = if self.stream_id.is_server_initiated() { + let id_len = if self.stream_id.is_uni() { + TYPE_LEN_UNI + } else { + TYPE_LEN_BIDI + }; + (id_len + Encoder::varint_len(self.session_id.as_u64())) as u64 + } else { + 0 + }; + + let stats = conn.recv_stream_stats(self.stream_id)?; + if stream_header_size == 0 { + return Ok(stats); + } + + let subtract_non_app_bytes = + |count: u64| -> u64 { count.saturating_sub(stream_header_size) }; + + let bytes_received = subtract_non_app_bytes(stats.bytes_received()); + let bytes_read = subtract_non_app_bytes(stats.bytes_read()); + + Ok(RecvStreamStats::new(bytes_received, bytes_read)) + } } #[derive(Debug, PartialEq)] @@ -225,10 +254,14 @@ impl SendStream for WebTransportSendStream { 0 }; + let stats = conn.send_stream_stats(self.stream_id)?; + if stream_header_size == 0 { + return Ok(stats); + } + let subtract_non_app_bytes = |count: u64| -> u64 { count.saturating_sub(stream_header_size) }; - let stats = conn.stream_stats(self.stream_id)?; let bytes_written = subtract_non_app_bytes(stats.bytes_written()); let bytes_sent = subtract_non_app_bytes(stats.bytes_sent()); let bytes_acked = subtract_non_app_bytes(stats.bytes_acked()); diff --git a/neqo-http3/src/lib.rs b/neqo-http3/src/lib.rs index d89b86af13..76be301a8e 100644 --- a/neqo-http3/src/lib.rs +++ b/neqo-http3/src/lib.rs @@ -161,8 +161,10 @@ mod settings; mod stream_type_reader; use neqo_qpack::Error as QpackError; -use neqo_transport::{send_stream::SendStreamStats, AppError, Connection, Error as TransportError}; pub use neqo_transport::{streams::SendOrder, Output, StreamId}; +use neqo_transport::{ + AppError, Connection, Error as TransportError, RecvStreamStats, SendStreamStats, +}; use std::fmt::Debug; use crate::priority::PriorityHandler; @@ -470,6 +472,11 @@ trait RecvStream: Stream { fn webtransport(&self) -> Option>> { None } + + /// This function is only implemented by `WebTransportRecvStream`. + fn stats(&mut self, _conn: &mut Connection) -> Res { + Err(Error::Unavailable) + } } trait HttpRecvStream: RecvStream { diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index e07ca0fc4b..0a388ea70a 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -38,6 +38,7 @@ use crate::{ }, }; +use crate::recv_stream::RecvStreamStats; pub use crate::send_stream::{RetransmissionPriority, SendStreamStats, TransmissionPriority}; use crate::{ crypto::{Crypto, CryptoDxState, CryptoSpace}, @@ -2963,10 +2964,16 @@ impl Connection { self.streams.set_fairness(stream_id, fairness) } - pub fn stream_stats(&self, stream_id: StreamId) -> Res { + pub fn send_stream_stats(&self, stream_id: StreamId) -> Res { self.streams.get_send_stream(stream_id).map(|s| s.stats()) } + pub fn recv_stream_stats(&mut self, stream_id: StreamId) -> Res { + let stream = self.streams.get_recv_stream_mut(stream_id)?; + + Ok(stream.stats()) + } + /// Send data on a stream. /// Returns how many bytes were successfully sent. Could be less /// than total, based on receiver credit space available, etc. diff --git a/neqo-transport/src/lib.rs b/neqo-transport/src/lib.rs index 643ec218ef..daff7e73c2 100644 --- a/neqo-transport/src/lib.rs +++ b/neqo-transport/src/lib.rs @@ -28,7 +28,7 @@ mod quic_datagrams; mod recovery; mod recv_stream; mod rtt; -pub mod send_stream; +mod send_stream; mod sender; pub mod server; mod stats; @@ -53,8 +53,8 @@ pub use self::stats::Stats; pub use self::stream_id::{StreamId, StreamType}; pub use self::version::Version; -pub use self::recv_stream::RECV_BUFFER_SIZE; -pub use self::send_stream::SEND_BUFFER_SIZE; +pub use self::recv_stream::{RecvStreamStats, RECV_BUFFER_SIZE}; +pub use self::send_stream::{SendStreamStats, SEND_BUFFER_SIZE}; pub type TransportError = u64; const ERROR_APPLICATION_CLOSE: TransportError = 12;