diff --git a/neqo-crypto/src/agent.rs b/neqo-crypto/src/agent.rs index 3d5a8b9f35..c04accd775 100644 --- a/neqo-crypto/src/agent.rs +++ b/neqo-crypto/src/agent.rs @@ -875,14 +875,13 @@ impl Client { arg: *mut c_void, ) -> ssl::SECStatus { let mut info: MaybeUninit = MaybeUninit::uninit(); - if ssl::SSL_GetResumptionTokenInfo( + let info_res = &ssl::SSL_GetResumptionTokenInfo( token, len, info.as_mut_ptr(), c_uint::try_from(mem::size_of::()).unwrap(), - ) - .is_err() - { + ); + if info_res.is_err() { // Ignore the token. return ssl::SECSuccess; } diff --git a/neqo-transport/src/cc/classic_cc.rs b/neqo-transport/src/cc/classic_cc.rs index 6914e91f67..23f0d04bd2 100644 --- a/neqo-transport/src/cc/classic_cc.rs +++ b/neqo-transport/src/cc/classic_cc.rs @@ -17,9 +17,9 @@ use crate::{ cc::MAX_DATAGRAM_SIZE, packet::PacketNumber, qlog::{self, QlogMetric}, + recovery::SentPacket, rtt::RttEstimate, sender::PACING_BURST_SIZE, - tracking::SentPacket, }; #[rustfmt::skip] // to keep `::` and thus prevent conflict with `crate::qlog` use ::qlog::events::{quic::CongestionStateUpdated, EventData}; @@ -167,8 +167,8 @@ impl CongestionControl for ClassicCongestionControl { qdebug!( "packet_acked this={:p}, pn={}, ps={}, ignored={}, lost={}, rtt_est={:?}", self, - pkt.pn, - pkt.size, + pkt.pn(), + pkt.len(), i32::from(!pkt.cc_outstanding()), i32::from(pkt.lost()), rtt_est, @@ -176,12 +176,12 @@ impl CongestionControl for ClassicCongestionControl { if !pkt.cc_outstanding() { continue; } - if pkt.pn < self.first_app_limited { + if pkt.pn() < self.first_app_limited { is_app_limited = false; } // BIF is set to 0 on a path change, but in case that was because of a simple rebinding // event, we may still get ACKs for packets sent before the rebinding. - self.bytes_in_flight = self.bytes_in_flight.saturating_sub(pkt.size); + self.bytes_in_flight = self.bytes_in_flight.saturating_sub(pkt.len()); if !self.after_recovery_start(pkt) { // Do not increase congestion window for packets sent before @@ -194,7 +194,7 @@ impl CongestionControl for ClassicCongestionControl { qlog::metrics_updated(&mut self.qlog, &[QlogMetric::InRecovery(false)]); } - new_acked += pkt.size; + new_acked += pkt.len(); } if is_app_limited { @@ -269,12 +269,12 @@ impl CongestionControl for ClassicCongestionControl { qdebug!( "packet_lost this={:p}, pn={}, ps={}", self, - pkt.pn, - pkt.size + pkt.pn(), + pkt.len() ); // BIF is set to 0 on a path change, but in case that was because of a simple rebinding // event, we may still declare packets lost that were sent before the rebinding. - self.bytes_in_flight = self.bytes_in_flight.saturating_sub(pkt.size); + self.bytes_in_flight = self.bytes_in_flight.saturating_sub(pkt.len()); } qlog::metrics_updated( &mut self.qlog, @@ -308,13 +308,13 @@ impl CongestionControl for ClassicCongestionControl { fn discard(&mut self, pkt: &SentPacket) { if pkt.cc_outstanding() { - assert!(self.bytes_in_flight >= pkt.size); - self.bytes_in_flight -= pkt.size; + assert!(self.bytes_in_flight >= pkt.len()); + self.bytes_in_flight -= pkt.len(); qlog::metrics_updated( &mut self.qlog, &[QlogMetric::BytesInFlight(self.bytes_in_flight)], ); - qtrace!([self], "Ignore pkt with size {}", pkt.size); + qtrace!([self], "Ignore pkt with size {}", pkt.len()); } } @@ -329,7 +329,7 @@ impl CongestionControl for ClassicCongestionControl { fn on_packet_sent(&mut self, pkt: &SentPacket) { // Record the recovery time and exit any transient state. if self.state.transient() { - self.recovery_start = Some(pkt.pn); + self.recovery_start = Some(pkt.pn()); self.state.update(); } @@ -341,15 +341,15 @@ impl CongestionControl for ClassicCongestionControl { // window. Assume that all in-flight packets up to this one are NOT app-limited. // However, subsequent packets might be app-limited. Set `first_app_limited` to the // next packet number. - self.first_app_limited = pkt.pn + 1; + self.first_app_limited = pkt.pn() + 1; } - self.bytes_in_flight += pkt.size; + self.bytes_in_flight += pkt.len(); qdebug!( "packet_sent this={:p}, pn={}, ps={}", self, - pkt.pn, - pkt.size + pkt.pn(), + pkt.len() ); qlog::metrics_updated( &mut self.qlog, @@ -448,20 +448,20 @@ impl ClassicCongestionControl { let cutoff = max(first_rtt_sample_time, prev_largest_acked_sent); for p in lost_packets .iter() - .skip_while(|p| Some(p.time_sent) < cutoff) + .skip_while(|p| Some(p.time_sent()) < cutoff) { - if p.pn != last_pn + 1 { + if p.pn() != last_pn + 1 { // Not a contiguous range of lost packets, start over. start = None; } - last_pn = p.pn; + last_pn = p.pn(); if !p.cc_in_flight() { // Not interesting, keep looking. continue; } if let Some(t) = start { let elapsed = p - .time_sent + .time_sent() .checked_duration_since(t) .expect("time is monotonic"); if elapsed > pc_period { @@ -476,7 +476,7 @@ impl ClassicCongestionControl { return true; } } else { - start = Some(p.time_sent); + start = Some(p.time_sent()); } } false @@ -490,7 +490,7 @@ impl ClassicCongestionControl { // state and update the variable `self.recovery_start`. Before the // first recovery, all packets were sent after the recovery event, // allowing to reduce the cwnd on congestion events. - !self.state.transient() && self.recovery_start.map_or(true, |pn| packet.pn >= pn) + !self.state.transient() && self.recovery_start.map_or(true, |pn| packet.pn() >= pn) } /// Handle a congestion event. @@ -560,8 +560,8 @@ mod tests { CongestionControl, CongestionControlAlgorithm, CWND_INITIAL_PKTS, MAX_DATAGRAM_SIZE, }, packet::{PacketNumber, PacketType}, + recovery::SentPacket, rtt::RttEstimate, - tracking::SentPacket, }; const PTO: Duration = Duration::from_millis(100); @@ -923,13 +923,13 @@ mod tests { fn persistent_congestion_ack_eliciting() { let mut lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); lost[0] = SentPacket::new( - lost[0].pt, - lost[0].pn, - lost[0].ecn_mark, - lost[0].time_sent, + lost[0].packet_type(), + lost[0].pn(), + lost[0].ecn_mark(), + lost[0].time_sent(), false, Vec::new(), - lost[0].size, + lost[0].len(), ); assert!(!persistent_congestion_by_pto( ClassicCongestionControl::new(NewReno::default()), diff --git a/neqo-transport/src/cc/mod.rs b/neqo-transport/src/cc/mod.rs index 2adffbc0c4..e85413b491 100644 --- a/neqo-transport/src/cc/mod.rs +++ b/neqo-transport/src/cc/mod.rs @@ -14,7 +14,7 @@ use std::{ use neqo_common::qlog::NeqoQlog; -use crate::{path::PATH_MTU_V6, rtt::RttEstimate, tracking::SentPacket, Error}; +use crate::{path::PATH_MTU_V6, recovery::SentPacket, rtt::RttEstimate, Error}; mod classic_cc; mod cubic; diff --git a/neqo-transport/src/cc/tests/cubic.rs b/neqo-transport/src/cc/tests/cubic.rs index 8ff591cb47..4d8c436cc4 100644 --- a/neqo-transport/src/cc/tests/cubic.rs +++ b/neqo-transport/src/cc/tests/cubic.rs @@ -25,8 +25,8 @@ use crate::{ CongestionControl, MAX_DATAGRAM_SIZE, MAX_DATAGRAM_SIZE_F64, }, packet::PacketType, + recovery::SentPacket, rtt::RttEstimate, - tracking::SentPacket, }; const RTT: Duration = Duration::from_millis(100); diff --git a/neqo-transport/src/cc/tests/new_reno.rs b/neqo-transport/src/cc/tests/new_reno.rs index 0cc560bf2b..a82e4995f4 100644 --- a/neqo-transport/src/cc/tests/new_reno.rs +++ b/neqo-transport/src/cc/tests/new_reno.rs @@ -17,8 +17,8 @@ use crate::{ MAX_DATAGRAM_SIZE, }, packet::PacketType, + recovery::SentPacket, rtt::RttEstimate, - tracking::SentPacket, }; const PTO: Duration = Duration::from_millis(100); @@ -133,14 +133,14 @@ fn issue_876() { // and ack it. cwnd increases slightly cc.on_packets_acked(&sent_packets[6..], &RTT_ESTIMATE, time_now); - assert_eq!(cc.acked_bytes(), sent_packets[6].size); + assert_eq!(cc.acked_bytes(), sent_packets[6].len()); cwnd_is_halved(&cc); assert_eq!(cc.bytes_in_flight(), 5 * MAX_DATAGRAM_SIZE - 2); // Packet from before is lost. Should not hurt cwnd. cc.on_packets_lost(Some(time_now), None, PTO, &sent_packets[1..2]); assert!(!cc.recovery_packet()); - assert_eq!(cc.acked_bytes(), sent_packets[6].size); + assert_eq!(cc.acked_bytes(), sent_packets[6].len()); cwnd_is_halved(&cc); assert_eq!(cc.bytes_in_flight(), 4 * MAX_DATAGRAM_SIZE); } diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 732cd31cf4..fcc41dab83 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -46,7 +46,7 @@ use crate::{ path::{Path, PathRef, Paths}, qlog, quic_datagrams::{DatagramTracking, QuicDatagrams}, - recovery::{LossRecovery, RecoveryToken, SendProfile}, + recovery::{LossRecovery, RecoveryToken, SendProfile, SentPacket}, recv_stream::RecvStreamStats, rtt::{RttEstimate, GRANULARITY}, send_stream::SendStream, @@ -57,7 +57,7 @@ use crate::{ self, TransportParameter, TransportParameterId, TransportParameters, TransportParametersHandler, }, - tracking::{AckTracker, PacketNumberSpace, RecvdPackets, SentPacket}, + tracking::{AckTracker, PacketNumberSpace, RecvdPackets}, version::{Version, WireVersion}, AppError, CloseReason, Error, Res, StreamId, }; @@ -2370,7 +2370,7 @@ impl Connection { packets.len(), mtu ); - initial.size += mtu - packets.len(); + initial.track_padding(mtu - packets.len()); // These zeros aren't padding frames, they are an invalid all-zero coalesced // packet, which is why we don't increase `frame_tx.padding` count here. packets.resize(mtu, 0); @@ -2894,7 +2894,7 @@ impl Connection { /// to retransmit the frame as needed. fn handle_lost_packets(&mut self, lost_packets: &[SentPacket]) { for lost in lost_packets { - for token in &lost.tokens { + for token in lost.tokens() { qdebug!([self], "Lost: {:?}", token); match token { RecoveryToken::Ack(_) => {} @@ -2930,13 +2930,13 @@ impl Connection { fn handle_ack( &mut self, space: PacketNumberSpace, - largest_acknowledged: u64, + largest_acknowledged: PacketNumber, ack_ranges: R, ack_ecn: Option, ack_delay: u64, now: Instant, ) where - R: IntoIterator> + Debug, + R: IntoIterator> + Debug, R::IntoIter: ExactSizeIterator, { qdebug!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges); @@ -2954,7 +2954,7 @@ impl Connection { now, ); for acked in acked_packets { - for token in &acked.tokens { + for token in acked.tokens() { match token { RecoveryToken::Stream(stream_token) => self.streams.acked(stream_token), RecoveryToken::Ack(at) => self.acks.acked(at), diff --git a/neqo-transport/src/ecn.rs b/neqo-transport/src/ecn.rs index 20eb4da003..cd550e589c 100644 --- a/neqo-transport/src/ecn.rs +++ b/neqo-transport/src/ecn.rs @@ -9,7 +9,7 @@ use std::ops::{AddAssign, Deref, DerefMut, Sub}; use enum_map::EnumMap; use neqo_common::{qdebug, qinfo, qwarn, IpTosEcn}; -use crate::{packet::PacketNumber, tracking::SentPacket}; +use crate::{packet::PacketNumber, recovery::SentPacket}; /// The number of packets to use for testing a path for ECN capability. pub const ECN_TEST_COUNT: usize = 10; @@ -159,7 +159,7 @@ impl EcnInfo { // > Validating ECN counts from reordered ACK frames can result in failure. An endpoint MUST // > NOT fail ECN validation as a result of processing an ACK frame that does not increase // > the largest acknowledged packet number. - let largest_acked = acked_packets.first().expect("must be there").pn; + let largest_acked = acked_packets.first().expect("must be there").pn(); if largest_acked <= self.largest_acked { return; } @@ -186,7 +186,7 @@ impl EcnInfo { // > ECT(0) marking. let newly_acked_sent_with_ect0: u64 = acked_packets .iter() - .filter(|p| p.ecn_mark == IpTosEcn::Ect0) + .filter(|p| p.ecn_mark() == IpTosEcn::Ect0) .count() .try_into() .unwrap(); diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index 0e4c82b1ca..037a04408a 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -25,11 +25,11 @@ use crate::{ ecn::{EcnCount, EcnInfo}, frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID}, packet::PacketBuilder, - recovery::RecoveryToken, + recovery::{RecoveryToken, SentPacket}, rtt::RttEstimate, sender::PacketSender, stats::FrameStats, - tracking::{PacketNumberSpace, SentPacket}, + tracking::PacketNumberSpace, Stats, }; @@ -954,12 +954,12 @@ impl Path { qinfo!( [self], "discarding a packet without an RTT estimate; guessing RTT={:?}", - now - sent.time_sent + now - sent.time_sent() ); stats.rtt_init_guess = true; self.rtt.update( &mut self.qlog, - now - sent.time_sent, + now - sent.time_sent(), Duration::new(0, 0), false, now, diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index 715ba85e81..7d28d94d96 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -27,9 +27,9 @@ use crate::{ frame::{CloseError, Frame}, packet::{DecryptedPacket, PacketNumber, PacketType, PublicPacket}, path::PathRef, + recovery::SentPacket, stream_id::StreamType as NeqoStreamType, tparams::{self, TransportParametersHandler}, - tracking::SentPacket, version::{Version, VersionConfig, WireVersion}, }; @@ -254,7 +254,8 @@ pub fn packet_dropped(qlog: &mut NeqoQlog, public_packet: &PublicPacket) { pub fn packets_lost(qlog: &mut NeqoQlog, pkts: &[SentPacket]) { qlog.add_event_with_stream(|stream| { for pkt in pkts { - let header = PacketHeader::with_type(pkt.pt.into(), Some(pkt.pn), None, None, None); + let header = + PacketHeader::with_type(pkt.packet_type().into(), Some(pkt.pn()), None, None, None); let ev_data = EventData::PacketLost(PacketLost { header: Some(header), diff --git a/neqo-transport/src/recovery.rs b/neqo-transport/src/recovery/mod.rs similarity index 90% rename from neqo-transport/src/recovery.rs rename to neqo-transport/src/recovery/mod.rs index 22a635d9f3..b181bd88a0 100644 --- a/neqo-transport/src/recovery.rs +++ b/neqo-transport/src/recovery/mod.rs @@ -6,31 +6,30 @@ // Tracking of sent packets and detecting their loss. +mod sent; +mod token; + use std::{ cmp::{max, min}, - collections::BTreeMap, - mem, + convert::TryFrom, ops::RangeInclusive, time::{Duration, Instant}, }; use neqo_common::{qdebug, qinfo, qlog::NeqoQlog, qtrace, qwarn}; +pub use sent::SentPacket; +use sent::SentPackets; use smallvec::{smallvec, SmallVec}; +pub use token::{RecoveryToken, StreamRecoveryToken}; use crate::{ - ackrate::AckRate, - cid::ConnectionIdEntry, - crypto::CryptoRecoveryToken, ecn::EcnCount, packet::PacketNumber, path::{Path, PathRef}, qlog::{self, QlogMetric}, - quic_datagrams::DatagramTracking, rtt::RttEstimate, - send_stream::SendStreamRecoveryToken, stats::{Stats, StatsCell}, - stream_id::{StreamId, StreamType}, - tracking::{AckToken, PacketNumberSpace, PacketNumberSpaceSet, SentPacket}, + tracking::{PacketNumberSpace, PacketNumberSpaceSet}, }; pub(crate) const PACKET_THRESHOLD: u64 = 3; @@ -49,54 +48,6 @@ pub(crate) const MIN_OUTSTANDING_UNACK: usize = 16; /// The scale we use for the fast PTO feature. pub const FAST_PTO_SCALE: u8 = 100; -#[derive(Debug, Clone)] -#[allow(clippy::module_name_repetitions)] -pub enum StreamRecoveryToken { - Stream(SendStreamRecoveryToken), - ResetStream { - stream_id: StreamId, - }, - StopSending { - stream_id: StreamId, - }, - - MaxData(u64), - DataBlocked(u64), - - MaxStreamData { - stream_id: StreamId, - max_data: u64, - }, - StreamDataBlocked { - stream_id: StreamId, - limit: u64, - }, - - MaxStreams { - stream_type: StreamType, - max_streams: u64, - }, - StreamsBlocked { - stream_type: StreamType, - limit: u64, - }, -} - -#[derive(Debug, Clone)] -#[allow(clippy::module_name_repetitions)] -pub enum RecoveryToken { - Stream(StreamRecoveryToken), - Ack(AckToken), - Crypto(CryptoRecoveryToken), - HandshakeDone, - KeepAlive, // Special PING. - NewToken(usize), - NewConnectionId(ConnectionIdEntry<[u8; 16]>), - RetireConnectionId(u64), - AckFrequency(AckRate), - Datagram(DatagramTracking), -} - /// `SendProfile` tells a sender how to send packets. #[derive(Debug)] pub struct SendProfile { @@ -181,7 +132,8 @@ pub(crate) struct LossRecoverySpace { /// This might be less than the number of ACK-eliciting packets, /// because PTO packets don't count. in_flight_outstanding: usize, - sent_packets: BTreeMap, + /// The packets that we have sent and are tracking. + sent_packets: SentPackets, /// The time that the first out-of-order packet was sent. /// This is `None` if there were no out-of-order packets detected. /// When set to `Some(T)`, time-based loss detection should be enabled. @@ -196,7 +148,7 @@ impl LossRecoverySpace { largest_acked_sent_time: None, last_ack_eliciting: None, in_flight_outstanding: 0, - sent_packets: BTreeMap::default(), + sent_packets: SentPackets::default(), first_ooo_time: None, } } @@ -221,9 +173,9 @@ impl LossRecoverySpace { pub fn pto_packets(&mut self, count: usize) -> impl Iterator { self.sent_packets .iter_mut() - .filter_map(|(pn, sent)| { + .filter_map(|sent| { if sent.pto() { - qtrace!("PTO: marking packet {} lost ", pn); + qtrace!("PTO: marking packet {} lost ", sent.pn()); Some(&*sent) } else { None @@ -256,16 +208,16 @@ impl LossRecoverySpace { pub fn on_packet_sent(&mut self, sent_packet: SentPacket) { if sent_packet.ack_eliciting() { - self.last_ack_eliciting = Some(sent_packet.time_sent); + self.last_ack_eliciting = Some(sent_packet.time_sent()); self.in_flight_outstanding += 1; } else if self.space != PacketNumberSpace::ApplicationData && self.last_ack_eliciting.is_none() { // For Initial and Handshake spaces, make sure that we have a PTO baseline // always. See `LossRecoverySpace::pto_base_time()` for details. - self.last_ack_eliciting = Some(sent_packet.time_sent); + self.last_ack_eliciting = Some(sent_packet.time_sent()); } - self.sent_packets.insert(sent_packet.pn, sent_packet); + self.sent_packets.track(sent_packet); } /// If we are only sending ACK frames, send a PING frame after 2 PTOs so that @@ -285,56 +237,42 @@ impl LossRecoverySpace { .map_or(false, |t| now > t + (pto * n_pto)) } + fn remove_outstanding(&mut self, count: usize) { + debug_assert!(self.in_flight_outstanding >= count); + self.in_flight_outstanding -= count; + if self.in_flight_outstanding == 0 { + qtrace!("remove_packet outstanding == 0 for space {}", self.space); + } + } + fn remove_packet(&mut self, p: &SentPacket) { if p.ack_eliciting() { - debug_assert!(self.in_flight_outstanding > 0); - self.in_flight_outstanding -= 1; - if self.in_flight_outstanding == 0 { - qtrace!("remove_packet outstanding == 0 for space {}", self.space); - } + self.remove_outstanding(1); } } - /// Remove all acknowledged packets. + /// Remove all newly acknowledged packets. /// Returns all the acknowledged packets, with the largest packet number first. /// ...and a boolean indicating if any of those packets were ack-eliciting. /// This operates more efficiently because it assumes that the input is sorted /// in the order that an ACK frame is (from the top). fn remove_acked(&mut self, acked_ranges: R, stats: &mut Stats) -> (Vec, bool) where - R: IntoIterator>, + R: IntoIterator>, R::IntoIter: ExactSizeIterator, { - let acked_ranges = acked_ranges.into_iter(); - let mut keep = Vec::with_capacity(acked_ranges.len()); - - let mut acked = Vec::new(); + let acked = self.sent_packets.take_ranges(acked_ranges); let mut eliciting = false; - for range in acked_ranges { - let first_keep = *range.end() + 1; - if let Some((&first, _)) = self.sent_packets.range(range).next() { - let mut tail = self.sent_packets.split_off(&first); - if let Some((&next, _)) = tail.range(first_keep..).next() { - keep.push(tail.split_off(&next)); - } - for (_, p) in tail.into_iter().rev() { - self.remove_packet(&p); - eliciting |= p.ack_eliciting(); - if p.lost() { - stats.late_ack += 1; - } - if p.pto_fired() { - stats.pto_ack += 1; - } - acked.push(p); - } + for p in &acked { + self.remove_packet(p); + eliciting |= p.ack_eliciting(); + if p.lost() { + stats.late_ack += 1; + } + if p.pto_fired() { + stats.pto_ack += 1; } } - - for mut k in keep.into_iter().rev() { - self.sent_packets.append(&mut k); - } - (acked, eliciting) } @@ -343,12 +281,12 @@ impl LossRecoverySpace { /// and when keys are dropped. fn remove_ignored(&mut self) -> impl Iterator { self.in_flight_outstanding = 0; - mem::take(&mut self.sent_packets).into_values() + std::mem::take(&mut self.sent_packets).drain_all() } /// Remove the primary path marking on any packets this is tracking. fn migrate(&mut self) { - for pkt in self.sent_packets.values_mut() { + for pkt in self.sent_packets.iter_mut() { pkt.clear_primary_path(); } } @@ -357,26 +295,9 @@ impl LossRecoverySpace { /// We try to keep these around until a probe is sent for them, so it is /// important that `cd` is set to at least the current PTO time; otherwise we /// might remove all in-flight packets and stop sending probes. - #[allow(clippy::option_if_let_else)] // Hard enough to read as-is. fn remove_old_lost(&mut self, now: Instant, cd: Duration) { - let mut it = self.sent_packets.iter(); - // If the first item is not expired, do nothing. - if it.next().map_or(false, |(_, p)| p.expired(now, cd)) { - // Find the index of the first unexpired packet. - let to_remove = if let Some(first_keep) = - it.find_map(|(i, p)| if p.expired(now, cd) { None } else { Some(*i) }) - { - // Some packets haven't expired, so keep those. - let keep = self.sent_packets.split_off(&first_keep); - mem::replace(&mut self.sent_packets, keep) - } else { - // All packets are expired. - mem::take(&mut self.sent_packets) - }; - for (_, p) in to_remove { - self.remove_packet(&p); - } - } + let removed = self.sent_packets.remove_expired(now, cd); + self.remove_outstanding(removed); } /// Detect lost packets. @@ -402,44 +323,39 @@ impl LossRecoverySpace { let largest_acked = self.largest_acked; - // Lost for retrans/CC purposes - let mut lost_pns = SmallVec::<[_; 8]>::new(); - - for (pn, packet) in self + for packet in self .sent_packets .iter_mut() // BTreeMap iterates in order of ascending PN - .take_while(|(&k, _)| k < largest_acked.unwrap_or(PacketNumber::MAX)) + .take_while(|p| p.pn() < largest_acked.unwrap_or(PacketNumber::MAX)) { // Packets sent before now - loss_delay are deemed lost. - if packet.time_sent + loss_delay <= now { + if packet.time_sent() + loss_delay <= now { qtrace!( "lost={}, time sent {:?} is before lost_delay {:?}", - pn, - packet.time_sent, + packet.pn(), + packet.time_sent(), loss_delay ); - } else if largest_acked >= Some(*pn + PACKET_THRESHOLD) { + } else if largest_acked >= Some(packet.pn() + PACKET_THRESHOLD) { qtrace!( "lost={}, is >= {} from largest acked {:?}", - pn, + packet.pn(), PACKET_THRESHOLD, largest_acked ); } else { if largest_acked.is_some() { - self.first_ooo_time = Some(packet.time_sent); + self.first_ooo_time = Some(packet.time_sent()); } // No more packets can be declared lost after this one. break; }; if packet.declare_lost(now) { - lost_pns.push(*pn); + lost_packets.push(packet.clone()); } } - - lost_packets.extend(lost_pns.iter().map(|pn| self.sent_packets[pn].clone())); } } @@ -629,8 +545,8 @@ impl LossRecovery { } pub fn on_packet_sent(&mut self, path: &PathRef, mut sent_packet: SentPacket) { - let pn_space = PacketNumberSpace::from(sent_packet.pt); - qdebug!([self], "packet {}-{} sent", pn_space, sent_packet.pn); + let pn_space = PacketNumberSpace::from(sent_packet.packet_type()); + qdebug!([self], "packet {}-{} sent", pn_space, sent_packet.pn()); if let Some(space) = self.spaces.get_mut(pn_space) { path.borrow_mut().packet_sent(&mut sent_packet); space.on_packet_sent(sent_packet); @@ -639,7 +555,7 @@ impl LossRecovery { [self], "ignoring {}-{} from dropped space", pn_space, - sent_packet.pn + sent_packet.pn() ); } } @@ -671,14 +587,14 @@ impl LossRecovery { &mut self, primary_path: &PathRef, pn_space: PacketNumberSpace, - largest_acked: u64, + largest_acked: PacketNumber, acked_ranges: R, ack_ecn: Option, ack_delay: Duration, now: Instant, ) -> (Vec, Vec) where - R: IntoIterator>, + R: IntoIterator>, R::IntoIter: ExactSizeIterator, { qdebug!( @@ -707,11 +623,11 @@ impl LossRecovery { // If the largest acknowledged is newly acked and any newly acked // packet was ack-eliciting, update the RTT. (-recovery 5.1) - space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent); + space.largest_acked_sent_time = Some(largest_acked_pkt.time_sent()); if any_ack_eliciting && largest_acked_pkt.on_primary_path() { self.rtt_sample( primary_path.borrow_mut().rtt_mut(), - largest_acked_pkt.time_sent, + largest_acked_pkt.time_sent(), now, ack_delay, ); @@ -1019,6 +935,7 @@ impl ::std::fmt::Display for LossRecovery { mod tests { use std::{ cell::RefCell, + convert::TryInto, ops::{Deref, DerefMut, RangeInclusive}, rc::Rc, time::{Duration, Instant}, @@ -1034,7 +951,7 @@ mod tests { cc::CongestionControlAlgorithm, cid::{ConnectionId, ConnectionIdEntry}, ecn::EcnCount, - packet::PacketType, + packet::{PacketNumber, PacketType}, path::{Path, PathRef}, rtt::RttEstimate, stats::{Stats, StatsCell}, @@ -1061,8 +978,8 @@ mod tests { pub fn on_ack_received( &mut self, pn_space: PacketNumberSpace, - largest_acked: u64, - acked_ranges: Vec>, + largest_acked: PacketNumber, + acked_ranges: Vec>, ack_ecn: Option, ack_delay: Duration, now: Instant, @@ -1235,8 +1152,8 @@ mod tests { ); } - fn add_sent(lrs: &mut LossRecoverySpace, packet_numbers: &[u64]) { - for &pn in packet_numbers { + fn add_sent(lrs: &mut LossRecoverySpace, max_pn: PacketNumber) { + for pn in 0..=max_pn { lrs.on_packet_sent(SentPacket::new( PacketType::Short, pn, @@ -1249,15 +1166,18 @@ mod tests { } } - fn match_acked(acked: &[SentPacket], expected: &[u64]) { - assert!(acked.iter().map(|p| &p.pn).eq(expected)); + fn match_acked(acked: &[SentPacket], expected: &[PacketNumber]) { + assert_eq!( + acked.iter().map(SentPacket::pn).collect::>(), + expected + ); } #[test] fn remove_acked() { let mut lrs = LossRecoverySpace::new(PacketNumberSpace::ApplicationData); let mut stats = Stats::default(); - add_sent(&mut lrs, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + add_sent(&mut lrs, 10); let (acked, _) = lrs.remove_acked(vec![], &mut stats); assert!(acked.is_empty()); let (acked, _) = lrs.remove_acked(vec![7..=8, 2..=4], &mut stats); @@ -1265,7 +1185,7 @@ mod tests { let (acked, _) = lrs.remove_acked(vec![8..=11], &mut stats); match_acked(&acked, &[10, 9]); let (acked, _) = lrs.remove_acked(vec![0..=2], &mut stats); - match_acked(&acked, &[1]); + match_acked(&acked, &[1, 0]); let (acked, _) = lrs.remove_acked(vec![5..=6], &mut stats); match_acked(&acked, &[6, 5]); } @@ -1517,7 +1437,7 @@ mod tests { Vec::new(), ON_SENT_SIZE, ); - let pn_space = PacketNumberSpace::from(sent_pkt.pt); + let pn_space = PacketNumberSpace::from(sent_pkt.packet_type()); lr.on_packet_sent(sent_pkt); lr.on_ack_received( pn_space, @@ -1630,7 +1550,7 @@ mod tests { lr.on_packet_sent(SentPacket::new( PacketType::Initial, - 1, + 0, IpTosEcn::default(), now(), true, diff --git a/neqo-transport/src/recovery/sent.rs b/neqo-transport/src/recovery/sent.rs new file mode 100644 index 0000000000..1a7be9dd4a --- /dev/null +++ b/neqo-transport/src/recovery/sent.rs @@ -0,0 +1,379 @@ +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// A collection for sent packets. + +use std::{ + collections::BTreeMap, + ops::RangeInclusive, + time::{Duration, Instant}, +}; + +use neqo_common::IpTosEcn; + +use crate::{ + packet::{PacketNumber, PacketType}, + recovery::RecoveryToken, +}; + +#[derive(Debug, Clone)] +pub struct SentPacket { + pt: PacketType, + pn: PacketNumber, + ecn_mark: IpTosEcn, + ack_eliciting: bool, + time_sent: Instant, + primary_path: bool, + tokens: Vec, + + time_declared_lost: Option, + /// After a PTO, this is true when the packet has been released. + pto: bool, + + len: usize, +} + +impl SentPacket { + pub fn new( + pt: PacketType, + pn: PacketNumber, + ecn_mark: IpTosEcn, + time_sent: Instant, + ack_eliciting: bool, + tokens: Vec, + len: usize, + ) -> Self { + Self { + pt, + pn, + ecn_mark, + time_sent, + ack_eliciting, + primary_path: true, + tokens, + time_declared_lost: None, + pto: false, + len, + } + } + + /// The type of this packet. + pub fn packet_type(&self) -> PacketType { + self.pt + } + + /// The number of the packet. + pub fn pn(&self) -> PacketNumber { + self.pn + } + + /// The ECN mark of the packet. + pub fn ecn_mark(&self) -> IpTosEcn { + self.ecn_mark + } + + /// The time that this packet was sent. + pub fn time_sent(&self) -> Instant { + self.time_sent + } + + /// Returns `true` if the packet will elicit an ACK. + pub fn ack_eliciting(&self) -> bool { + self.ack_eliciting + } + + /// Returns `true` if the packet was sent on the primary path. + pub fn on_primary_path(&self) -> bool { + self.primary_path + } + + /// The length of the packet that was sent. + pub fn len(&self) -> usize { + self.len + } + + /// Access the recovery tokens that this holds. + pub fn tokens(&self) -> &[RecoveryToken] { + &self.tokens + } + + /// Clears the flag that had this packet on the primary path. + /// Used when migrating to clear out state. + pub fn clear_primary_path(&mut self) { + self.primary_path = false; + } + + /// For Initial packets, it is possible that the packet builder needs to amend the length. + pub fn track_padding(&mut self, padding: usize) { + debug_assert_eq!(self.pt, PacketType::Initial); + self.len += padding; + } + + /// Whether the packet has been declared lost. + pub fn lost(&self) -> bool { + self.time_declared_lost.is_some() + } + + /// Whether accounting for the loss or acknowledgement in the + /// congestion controller is pending. + /// Returns `true` if the packet counts as being "in flight", + /// and has not previously been declared lost. + /// Note that this should count packets that contain only ACK and PADDING, + /// but we don't send PADDING, so we don't track that. + pub fn cc_outstanding(&self) -> bool { + self.ack_eliciting() && self.on_primary_path() && !self.lost() + } + + /// Whether the packet should be tracked as in-flight. + pub fn cc_in_flight(&self) -> bool { + self.ack_eliciting() && self.on_primary_path() + } + + /// Declare the packet as lost. Returns `true` if this is the first time. + pub fn declare_lost(&mut self, now: Instant) -> bool { + if self.lost() { + false + } else { + self.time_declared_lost = Some(now); + true + } + } + + /// Ask whether this tracked packet has been declared lost for long enough + /// that it can be expired and no longer tracked. + pub fn expired(&self, now: Instant, expiration_period: Duration) -> bool { + self.time_declared_lost + .map_or(false, |loss_time| (loss_time + expiration_period) <= now) + } + + /// Whether the packet contents were cleared out after a PTO. + pub fn pto_fired(&self) -> bool { + self.pto + } + + /// On PTO, we need to get the recovery tokens so that we can ensure that + /// the frames we sent can be sent again in the PTO packet(s). Do that just once. + pub fn pto(&mut self) -> bool { + if self.pto || self.lost() { + false + } else { + self.pto = true; + true + } + } +} + +/// A collection for packets that we have sent that haven't been acknowledged. +#[derive(Debug, Default)] +pub struct SentPackets { + /// The collection. + packets: BTreeMap, +} + +impl SentPackets { + pub fn len(&self) -> usize { + self.packets.len() + } + + pub fn track(&mut self, packet: SentPacket) { + self.packets.insert(packet.pn, packet); + } + + pub fn iter_mut(&mut self) -> impl Iterator { + self.packets.values_mut() + } + + /// Take values from a specified ranges of packet numbers. + /// The values returned will be reversed, so that the most recent packet appears first. + /// This is because ACK frames arrive with ranges starting from the largest acknowledged + /// and we want to match that. + pub fn take_ranges(&mut self, acked_ranges: R) -> Vec + where + R: IntoIterator>, + R::IntoIter: ExactSizeIterator, + { + let mut result = Vec::new(); + // Remove all packets. We will add them back as we don't need them. + let mut packets = std::mem::take(&mut self.packets); + for range in acked_ranges { + // For each acked range, split off the acknowledged part, + // then split off the part that hasn't been acknowledged. + // This order works better when processing ranges that + // have already been processed, which is common. + let mut acked = packets.split_off(range.start()); + let keep = acked.split_off(&(*range.end() + 1)); + self.packets.extend(keep); + result.extend(acked.into_values().rev()); + } + self.packets.extend(packets); + result + } + + /// Empty out the packets, but keep the offset. + pub fn drain_all(&mut self) -> impl Iterator { + std::mem::take(&mut self.packets).into_values() + } + + /// See `LossRecoverySpace::remove_old_lost` for details on `now` and `cd`. + /// Returns the number of ack-eliciting packets removed. + pub fn remove_expired(&mut self, now: Instant, cd: Duration) -> usize { + let mut it = self.packets.iter(); + // If the first item is not expired, do nothing (the most common case). + if it.next().map_or(false, |(_, p)| p.expired(now, cd)) { + // Find the index of the first unexpired packet. + let to_remove = if let Some(first_keep) = + it.find_map(|(i, p)| if p.expired(now, cd) { None } else { Some(*i) }) + { + // Some packets haven't expired, so keep those. + let keep = self.packets.split_off(&first_keep); + std::mem::replace(&mut self.packets, keep) + } else { + // All packets are expired. + std::mem::take(&mut self.packets) + }; + to_remove + .into_values() + .filter(SentPacket::ack_eliciting) + .count() + } else { + 0 + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + cell::OnceCell, + convert::TryFrom, + time::{Duration, Instant}, + }; + + use neqo_common::IpTosEcn; + + use super::{SentPacket, SentPackets}; + use crate::packet::{PacketNumber, PacketType}; + + const PACKET_GAP: Duration = Duration::from_secs(1); + fn start_time() -> Instant { + thread_local!(static STARTING_TIME: OnceCell = const { OnceCell::new() }); + STARTING_TIME.with(|t| *t.get_or_init(Instant::now)) + } + + fn pkt(n: u32) -> SentPacket { + SentPacket::new( + PacketType::Short, + PacketNumber::from(n), + IpTosEcn::default(), + start_time() + (PACKET_GAP * n), + true, + Vec::new(), + 100, + ) + } + + fn pkts() -> SentPackets { + let mut pkts = SentPackets::default(); + pkts.track(pkt(0)); + pkts.track(pkt(1)); + pkts.track(pkt(2)); + assert_eq!(pkts.len(), 3); + pkts + } + + trait HasPacketNumber { + fn pn(&self) -> PacketNumber; + } + impl HasPacketNumber for SentPacket { + fn pn(&self) -> PacketNumber { + self.pn + } + } + impl HasPacketNumber for &'_ SentPacket { + fn pn(&self) -> PacketNumber { + self.pn + } + } + impl HasPacketNumber for &'_ mut SentPacket { + fn pn(&self) -> PacketNumber { + self.pn + } + } + + fn remove_one(pkts: &mut SentPackets, idx: PacketNumber) { + assert_eq!(pkts.len(), 3); + let store = pkts.take_ranges([idx..=idx]); + let mut it = store.into_iter(); + assert_eq!(idx, it.next().unwrap().pn()); + assert!(it.next().is_none()); + std::mem::drop(it); + assert_eq!(pkts.len(), 2); + } + + fn assert_zero_and_two<'a, 'b: 'a>( + mut it: impl Iterator + 'a, + ) { + assert_eq!(it.next().unwrap().pn(), 0); + assert_eq!(it.next().unwrap().pn(), 2); + assert!(it.next().is_none()); + } + + #[test] + fn iterate_skipped() { + let mut pkts = pkts(); + for (i, p) in pkts.packets.values().enumerate() { + assert_eq!(i, usize::try_from(p.pn).unwrap()); + } + remove_one(&mut pkts, 1); + + // Validate the merged result multiple ways. + assert_zero_and_two(pkts.iter_mut()); + + { + // Reverse the expectations here as this iterator reverses its output. + let store = pkts.take_ranges([0..=2]); + let mut it = store.into_iter(); + assert_eq!(it.next().unwrap().pn(), 2); + assert_eq!(it.next().unwrap().pn(), 0); + assert!(it.next().is_none()); + }; + + // The None values are still there in this case, so offset is 0. + assert_eq!(pkts.packets.len(), 0); + assert_eq!(pkts.len(), 0); + } + + #[test] + fn drain() { + let mut pkts = pkts(); + remove_one(&mut pkts, 1); + + assert_zero_and_two(pkts.drain_all()); + assert_eq!(pkts.len(), 0); + } + + #[test] + fn remove_expired() { + let mut pkts = pkts(); + remove_one(&mut pkts, 0); + + for p in pkts.iter_mut() { + p.declare_lost(p.time_sent); // just to keep things simple. + } + + // Expire up to pkt(1). + let count = pkts.remove_expired(start_time() + PACKET_GAP, Duration::new(0, 0)); + assert_eq!(count, 1); + assert_eq!(pkts.len(), 1); + } + + #[test] + fn first_skipped_ok() { + let mut pkts = SentPackets::default(); + pkts.track(pkt(4)); // This is fine. + assert_eq!(pkts.len(), 1); + } +} diff --git a/neqo-transport/src/recovery/token.rs b/neqo-transport/src/recovery/token.rs new file mode 100644 index 0000000000..93f84268cd --- /dev/null +++ b/neqo-transport/src/recovery/token.rs @@ -0,0 +1,63 @@ +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::{ + ackrate::AckRate, + cid::ConnectionIdEntry, + crypto::CryptoRecoveryToken, + quic_datagrams::DatagramTracking, + send_stream::SendStreamRecoveryToken, + stream_id::{StreamId, StreamType}, + tracking::AckToken, +}; + +#[derive(Debug, Clone)] +#[allow(clippy::module_name_repetitions)] +pub enum StreamRecoveryToken { + Stream(SendStreamRecoveryToken), + ResetStream { + stream_id: StreamId, + }, + StopSending { + stream_id: StreamId, + }, + + MaxData(u64), + DataBlocked(u64), + + MaxStreamData { + stream_id: StreamId, + max_data: u64, + }, + StreamDataBlocked { + stream_id: StreamId, + limit: u64, + }, + + MaxStreams { + stream_type: StreamType, + max_streams: u64, + }, + StreamsBlocked { + stream_type: StreamType, + limit: u64, + }, +} + +#[derive(Debug, Clone)] +#[allow(clippy::module_name_repetitions)] +pub enum RecoveryToken { + Stream(StreamRecoveryToken), + Ack(AckToken), + Crypto(CryptoRecoveryToken), + HandshakeDone, + KeepAlive, // Special PING. + NewToken(usize), + NewConnectionId(ConnectionIdEntry<[u8; 16]>), + RetireConnectionId(u64), + AckFrequency(AckRate), + Datagram(DatagramTracking), +} diff --git a/neqo-transport/src/sender.rs b/neqo-transport/src/sender.rs index abb14d0a25..22abef4dc7 100644 --- a/neqo-transport/src/sender.rs +++ b/neqo-transport/src/sender.rs @@ -18,8 +18,8 @@ use neqo_common::qlog::NeqoQlog; use crate::{ cc::{ClassicCongestionControl, CongestionControl, CongestionControlAlgorithm, Cubic, NewReno}, pace::Pacer, + recovery::SentPacket, rtt::RttEstimate, - tracking::SentPacket, }; /// The number of packets we allow to burst from the pacer. @@ -114,7 +114,7 @@ impl PacketSender { pub fn on_packet_sent(&mut self, pkt: &SentPacket, rtt: Duration) { self.pacer - .spend(pkt.time_sent, rtt, self.cc.cwnd(), pkt.size); + .spend(pkt.time_sent(), rtt, self.cc.cwnd(), pkt.len()); self.cc.on_packet_sent(pkt); } diff --git a/neqo-transport/src/tracking.rs b/neqo-transport/src/tracking.rs index 6643d516e3..7c97b55f27 100644 --- a/neqo-transport/src/tracking.rs +++ b/neqo-transport/src/tracking.rs @@ -133,117 +133,6 @@ impl std::fmt::Debug for PacketNumberSpaceSet { } } -#[derive(Debug, Clone)] -pub struct SentPacket { - pub pt: PacketType, - pub pn: PacketNumber, - pub ecn_mark: IpTosEcn, - ack_eliciting: bool, - pub time_sent: Instant, - primary_path: bool, - pub tokens: Vec, - - time_declared_lost: Option, - /// After a PTO, this is true when the packet has been released. - pto: bool, - - pub size: usize, -} - -impl SentPacket { - pub fn new( - pt: PacketType, - pn: PacketNumber, - ecn_mark: IpTosEcn, - time_sent: Instant, - ack_eliciting: bool, - tokens: Vec, - size: usize, - ) -> Self { - Self { - pt, - pn, - ecn_mark, - time_sent, - ack_eliciting, - primary_path: true, - tokens, - time_declared_lost: None, - pto: false, - size, - } - } - - /// Returns `true` if the packet will elicit an ACK. - pub fn ack_eliciting(&self) -> bool { - self.ack_eliciting - } - - /// Returns `true` if the packet was sent on the primary path. - pub fn on_primary_path(&self) -> bool { - self.primary_path - } - - /// Clears the flag that had this packet on the primary path. - /// Used when migrating to clear out state. - pub fn clear_primary_path(&mut self) { - self.primary_path = false; - } - - /// Whether the packet has been declared lost. - pub fn lost(&self) -> bool { - self.time_declared_lost.is_some() - } - - /// Whether accounting for the loss or acknowledgement in the - /// congestion controller is pending. - /// Returns `true` if the packet counts as being "in flight", - /// and has not previously been declared lost. - /// Note that this should count packets that contain only ACK and PADDING, - /// but we don't send PADDING, so we don't track that. - pub fn cc_outstanding(&self) -> bool { - self.ack_eliciting() && self.on_primary_path() && !self.lost() - } - - /// Whether the packet should be tracked as in-flight. - pub fn cc_in_flight(&self) -> bool { - self.ack_eliciting() && self.on_primary_path() - } - - /// Declare the packet as lost. Returns `true` if this is the first time. - pub fn declare_lost(&mut self, now: Instant) -> bool { - if self.lost() { - false - } else { - self.time_declared_lost = Some(now); - true - } - } - - /// Ask whether this tracked packet has been declared lost for long enough - /// that it can be expired and no longer tracked. - pub fn expired(&self, now: Instant, expiration_period: Duration) -> bool { - self.time_declared_lost - .map_or(false, |loss_time| (loss_time + expiration_period) <= now) - } - - /// Whether the packet contents were cleared out after a PTO. - pub fn pto_fired(&self) -> bool { - self.pto - } - - /// On PTO, we need to get the recovery tokens so that we can ensure that - /// the frames we sent can be sent again in the PTO packet(s). Do that just once. - pub fn pto(&mut self) -> bool { - if self.pto || self.lost() { - false - } else { - self.pto = true; - true - } - } -} - impl std::fmt::Display for PacketNumberSpace { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.write_str(match self {