diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 8522507a69..f4e2847609 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -46,7 +46,7 @@ use crate::{ quic_datagrams::{DatagramTracking, QuicDatagrams}, recovery::{LossRecovery, RecoveryToken, SendProfile}, recv_stream::RecvStreamStats, - rtt::GRANULARITY, + rtt::{RttEstimate, GRANULARITY}, send_stream::SendStream, stats::{Stats, StatsCell}, stream_id::StreamType, @@ -610,11 +610,10 @@ impl Connection { /// a value of this approximate order. Don't use this for loss recovery, /// only use it where a more precise value is not important. fn pto(&self) -> Duration { - self.paths - .primary() - .borrow() - .rtt() - .pto(PacketNumberSpace::ApplicationData) + self.paths.primary_fallible().map_or_else( + || RttEstimate::default().pto(PacketNumberSpace::ApplicationData), + |p| p.borrow().rtt().pto(PacketNumberSpace::ApplicationData), + ) } fn create_resumption_token(&mut self, now: Instant) { @@ -962,9 +961,11 @@ impl Connection { let res = self.crypto.states.check_key_update(now); self.absorb_error(now, res); - let lost = self.loss_recovery.timeout(&self.paths.primary(), now); - self.handle_lost_packets(&lost); - qlog::packets_lost(&mut self.qlog, &lost); + if let Some(path) = self.paths.primary_fallible() { + let lost = self.loss_recovery.timeout(&path, now); + self.handle_lost_packets(&lost); + qlog::packets_lost(&mut self.qlog, &lost); + } if self.release_resumption_token_timer.is_some() { self.create_resumption_token(now); @@ -2861,8 +2862,11 @@ impl Connection { { qdebug!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges); + let Some(path) = self.paths.primary_fallible() else { + return; + }; let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received( - &self.paths.primary(), + &path, space, largest_acknowledged, ack_ranges, diff --git a/neqo-transport/tests/server.rs b/neqo-transport/tests/server.rs index 7388e0fee7..745ee79520 100644 --- a/neqo-transport/tests/server.rs +++ b/neqo-transport/tests/server.rs @@ -477,6 +477,65 @@ fn bad_client_initial() { assert_eq!(res, Output::None); } +#[test] +fn bad_client_initial_connection_close() { + let mut client = default_client(); + let mut server = default_server(); + + let dgram = client.process(None, now()).dgram().expect("a datagram"); + let (header, d_cid, s_cid, payload) = decode_initial_header(&dgram, Role::Client); + let (aead, hp) = initial_aead_and_hp(d_cid, Role::Client); + let (_, pn) = remove_header_protection(&hp, header, payload); + + let mut payload_enc = Encoder::with_capacity(1200); + payload_enc.encode(&[0x1c, 0x01, 0x00, 0x00]); // Add a CONNECTION_CLOSE frame. + + // Make a new header with a 1 byte packet number length. + let mut header_enc = Encoder::new(); + header_enc + .encode_byte(0xc0) // Initial with 1 byte packet number. + .encode_uint(4, Version::default().wire_version()) + .encode_vec(1, d_cid) + .encode_vec(1, s_cid) + .encode_vvec(&[]) + .encode_varint(u64::try_from(payload_enc.len() + aead.expansion() + 1).unwrap()) + .encode_byte(u8::try_from(pn).unwrap()); + + let mut ciphertext = header_enc.as_ref().to_vec(); + ciphertext.resize(header_enc.len() + payload_enc.len() + aead.expansion(), 0); + let v = aead + .encrypt( + pn, + header_enc.as_ref(), + payload_enc.as_ref(), + &mut ciphertext[header_enc.len()..], + ) + .unwrap(); + assert_eq!(header_enc.len() + v.len(), ciphertext.len()); + // Pad with zero to get up to 1200. + ciphertext.resize(1200, 0); + + apply_header_protection( + &hp, + &mut ciphertext, + (header_enc.len() - 1)..header_enc.len(), + ); + let bad_dgram = Datagram::new( + dgram.source(), + dgram.destination(), + dgram.tos(), + dgram.ttl(), + ciphertext, + ); + + // The server should ignore this and go to Draining. + let mut now = now(); + let response = server.process(Some(&bad_dgram), now); + now += response.callback(); + let response = server.process(None, now); + assert_eq!(response, Output::None); +} + #[test] fn version_negotiation_ignored() { let mut server = default_server();