diff --git a/neqo-transport/Cargo.toml b/neqo-transport/Cargo.toml index 24a0926db7..3db4053d7d 100644 --- a/neqo-transport/Cargo.toml +++ b/neqo-transport/Cargo.toml @@ -22,6 +22,7 @@ indexmap = { version = "2.2", default-features = false } # See https://github.co log = { workspace = true } neqo-common = { path = "../neqo-common" } neqo-crypto = { path = "../neqo-crypto" } +mtu = { version = "0.1.3", default-features = false } qlog = { workspace = true } smallvec = { version = "1.11", default-features = false } static_assertions = { version = "1.1", default-features = false } diff --git a/neqo-transport/src/cc/classic_cc.rs b/neqo-transport/src/cc/classic_cc.rs index 1130178bc0..c255aad0fb 100644 --- a/neqo-transport/src/cc/classic_cc.rs +++ b/neqo-transport/src/cc/classic_cc.rs @@ -614,6 +614,7 @@ mod tests { }; const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); + const MTU: usize = 1_500; const PTO: Duration = Duration::from_millis(100); const RTT: Duration = Duration::from_millis(98); const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(Duration::from_millis(98)); @@ -652,11 +653,11 @@ mod tests { match cc { CongestionControlAlgorithm::NewReno => Box::new(ClassicCongestionControl::new( NewReno::default(), - Pmtud::new(IP_ADDR), + Pmtud::new(IP_ADDR, MTU), )), CongestionControlAlgorithm::Cubic => Box::new(ClassicCongestionControl::new( Cubic::default(), - Pmtud::new(IP_ADDR), + Pmtud::new(IP_ADDR, MTU), )), } } @@ -894,13 +895,13 @@ mod tests { fn persistent_congestion_no_lost() { let lost = make_lost(&[]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost @@ -912,13 +913,13 @@ mod tests { fn persistent_congestion_one_lost() { let lost = make_lost(&[1]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost @@ -932,37 +933,37 @@ mod tests { // sample are not considered. So 0 is ignored. let lost = make_lost(&[0, PERSISTENT_CONG_THRESH + 1, PERSISTENT_CONG_THRESH + 2]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 1, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 0, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 1, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 1, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 0, 1, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 1, 0, &lost @@ -983,13 +984,13 @@ mod tests { lost[0].len(), ); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost )); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost @@ -1003,13 +1004,13 @@ mod tests { fn persistent_congestion_min() { let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); assert!(persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost )); assert!(persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost @@ -1022,7 +1023,7 @@ mod tests { #[test] fn persistent_congestion_no_prev_ack_newreno() { let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); - let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)); cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, lost.iter()); assert_eq!(cc.cwnd(), cc.cwnd_min()); } @@ -1030,7 +1031,7 @@ mod tests { #[test] fn persistent_congestion_no_prev_ack_cubic() { let lost = make_lost(&[1, PERSISTENT_CONG_THRESH + 2]); - let mut cc = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); cc.detect_persistent_congestion(Some(by_pto(0)), None, PTO, lost.iter()); assert_eq!(cc.cwnd(), cc.cwnd_min()); } @@ -1041,7 +1042,7 @@ mod tests { fn persistent_congestion_unsorted_newreno() { let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost @@ -1054,7 +1055,7 @@ mod tests { fn persistent_congestion_unsorted_cubic() { let lost = make_lost(&[PERSISTENT_CONG_THRESH + 2, 1]); assert!(!persistent_congestion_by_pto( - ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)), + ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)), 0, 0, &lost @@ -1065,7 +1066,7 @@ mod tests { fn app_limited_slow_start() { const BELOW_APP_LIMIT_PKTS: usize = 5; const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; - let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)); let cwnd = cc.congestion_window; let mut now = now(); let mut next_pn = 0; @@ -1149,7 +1150,7 @@ mod tests { const BELOW_APP_LIMIT_PKTS: usize = CWND_PKTS_CA - 2; const ABOVE_APP_LIMIT_PKTS: usize = BELOW_APP_LIMIT_PKTS + 1; - let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)); let mut now = now(); // Change state to congestion avoidance by introducing loss. @@ -1264,7 +1265,7 @@ mod tests { #[test] fn ecn_ce() { - let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)); let p_ce = SentPacket::new( PacketType::Short, 1, diff --git a/neqo-transport/src/cc/tests/cubic.rs b/neqo-transport/src/cc/tests/cubic.rs index 54c8c2c3c8..045a9b202d 100644 --- a/neqo-transport/src/cc/tests/cubic.rs +++ b/neqo-transport/src/cc/tests/cubic.rs @@ -32,6 +32,7 @@ use crate::{ }; const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); +const MTU: usize = 1_500; const RTT: Duration = Duration::from_millis(100); const fn cwnd_after_loss(cwnd: usize) -> usize { @@ -95,7 +96,7 @@ fn expected_tcp_acks(cwnd_rtt_start: usize, mtu: usize) -> u64 { #[test] fn tcp_phase() { - let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); // change to congestion avoidance state. cubic.set_ssthresh(1); @@ -202,7 +203,7 @@ fn tcp_phase() { #[test] fn cubic_phase() { - let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); let cwnd_initial_f64: f64 = convert_to_f64(cubic.cwnd_initial()); // Set last_max_cwnd to a higher number make sure that cc is the cubic phase (cwnd is calculated // by the cubic equation). @@ -257,7 +258,7 @@ fn assert_within + PartialOrd + Copy>(value: T, expected: T, #[test] fn congestion_event_slow_start() { - let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); _ = fill_cwnd(&mut cubic, 0, now()); ack_packet(&mut cubic, 0, now()); @@ -288,7 +289,7 @@ fn congestion_event_slow_start() { #[test] fn congestion_event_congestion_avoidance() { - let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. cubic.set_ssthresh(1); @@ -312,7 +313,7 @@ fn congestion_event_congestion_avoidance() { #[test] fn congestion_event_congestion_avoidance_2() { - let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. cubic.set_ssthresh(1); @@ -341,7 +342,7 @@ fn congestion_event_congestion_avoidance_2() { #[test] fn congestion_event_congestion_avoidance_no_overflow() { const PTO: Duration = Duration::from_millis(120); - let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR)); + let mut cubic = ClassicCongestionControl::new(Cubic::default(), Pmtud::new(IP_ADDR, MTU)); // Set ssthresh to something small to make sure that cc is in the congection avoidance phase. cubic.set_ssthresh(1); diff --git a/neqo-transport/src/cc/tests/new_reno.rs b/neqo-transport/src/cc/tests/new_reno.rs index 1ee8c74f67..4dca9427b8 100644 --- a/neqo-transport/src/cc/tests/new_reno.rs +++ b/neqo-transport/src/cc/tests/new_reno.rs @@ -23,6 +23,7 @@ use crate::{ }; const IP_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); +const MTU: usize = 1_500; const PTO: Duration = Duration::from_millis(100); const RTT: Duration = Duration::from_millis(98); const RTT_ESTIMATE: RttEstimate = RttEstimate::from_duration(RTT); @@ -39,7 +40,7 @@ fn cwnd_is_halved(cc: &ClassicCongestionControl) { #[test] fn issue_876() { - let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)); let time_now = now(); let time_before = time_now.checked_sub(Duration::from_millis(100)).unwrap(); let time_after = time_now + Duration::from_millis(150); @@ -150,7 +151,7 @@ fn issue_876() { #[test] // https://github.com/mozilla/neqo/pull/1465 fn issue_1465() { - let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR)); + let mut cc = ClassicCongestionControl::new(NewReno::default(), Pmtud::new(IP_ADDR, MTU)); let mut pn = 0; let mut now = now(); let max_datagram_size = cc.max_datagram_size(); diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index 49c289f60b..f53c03bd87 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -569,7 +569,9 @@ impl Path { qlog: NeqoQlog, now: Instant, ) -> Self { - let mut sender = PacketSender::new(cc, pacing, Pmtud::new(remote.ip()), now); + let iface_mtu = + mtu::interface_and_mtu(&(local, remote)).map_or_else(|_| usize::MAX, |(_, m)| m); + let mut sender = PacketSender::new(cc, pacing, Pmtud::new(remote.ip(), iface_mtu), now); sender.set_qlog(qlog.clone()); Self { local, diff --git a/neqo-transport/src/pmtud.rs b/neqo-transport/src/pmtud.rs index 8a2179e41d..928bd44fde 100644 --- a/neqo-transport/src/pmtud.rs +++ b/neqo-transport/src/pmtud.rs @@ -46,6 +46,7 @@ pub struct Pmtud { search_table: &'static [usize], header_size: usize, mtu: usize, + iface_mtu: usize, probe_index: usize, probe_count: usize, probe_state: Probe, @@ -71,13 +72,14 @@ impl Pmtud { } #[must_use] - pub const fn new(remote_ip: IpAddr) -> Self { + pub const fn new(remote_ip: IpAddr, iface_mtu: usize) -> Self { let search_table = Self::search_table(remote_ip); let probe_index = 0; Self { search_table, header_size: Self::header_size(remote_ip), mtu: search_table[probe_index], + iface_mtu, probe_index, probe_count: 0, probe_state: Probe::NotNeeded, @@ -303,7 +305,10 @@ impl Pmtud { /// Starts the next upward PMTUD probe. pub fn start(&mut self) { - if self.probe_index < SEARCH_TABLE_LEN - 1 { + if self.probe_index < SEARCH_TABLE_LEN - 1 // Not at the end of the search table + // Next size is <= iface MTU + && self.search_table[self.probe_index + 1] <= self.iface_mtu + { self.probe_state = Probe::Needed; // We need to send a probe self.probe_count = 0; // For the first time self.probe_index += 1; // At this size @@ -407,7 +412,7 @@ mod tests { fn find_pmtu(addr: IpAddr, mtu: usize) { fixture_init(); let now = now(); - let mut pmtud = Pmtud::new(addr); + let mut pmtud = Pmtud::new(addr, mtu); let mut stats = Stats::default(); let mut prot = CryptoDxState::test_default(); @@ -445,7 +450,7 @@ mod tests { fixture_init(); let now = now(); - let mut pmtud = Pmtud::new(addr); + let mut pmtud = Pmtud::new(addr, mtu); let mut stats = Stats::default(); let mut prot = CryptoDxState::test_default(); @@ -498,7 +503,7 @@ mod tests { fixture_init(); let now = now(); - let mut pmtud = Pmtud::new(addr); + let mut pmtud = Pmtud::new(addr, larger_mtu); let mut stats = Stats::default(); let mut prot = CryptoDxState::test_default(); @@ -570,7 +575,7 @@ mod tests { #[test] fn pmtud_on_packets_lost() { let now = now(); - let mut pmtud = Pmtud::new(V4); + let mut pmtud = Pmtud::new(V4, 1500); let mut stats = Stats::default(); // No packets lost, nothing should change. @@ -638,7 +643,7 @@ mod tests { #[test] fn pmtud_on_packets_lost_and_acked() { let now = now(); - let mut pmtud = Pmtud::new(V4); + let mut pmtud = Pmtud::new(V4, 1500); let mut stats = Stats::default(); // A packet of size 100 was ACKed, which is smaller than all probe sizes.