Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Paths::primary() is now what used to be Paths::primary_fallible() #1833

Merged
merged 10 commits into from
Apr 17, 2024
104 changes: 66 additions & 38 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ impl Debug for Connection {
"{:?} Connection: {:?} {:?}",
self.role,
self.state,
self.paths.primary_fallible()
self.paths.primary()
)
}
}
Expand Down Expand Up @@ -592,7 +592,11 @@ impl Connection {
fn make_resumption_token(&mut self) -> ResumptionToken {
debug_assert_eq!(self.role, Role::Client);
debug_assert!(self.crypto.has_resumption_token());
let rtt = self.paths.primary().borrow().rtt().estimate();
let rtt = self.paths.primary().map_or_else(
|| RttEstimate::default().estimate(),
|p| p.borrow().rtt().estimate(),
);

self.crypto
.create_resumption_token(
self.new_token.take_token(),
Expand All @@ -611,7 +615,7 @@ impl Connection {
/// a value of this approximate order. Don't use this for loss recovery,
/// only use it where a more precise value is not important.
fn pto(&self) -> Duration {
self.paths.primary_fallible().map_or_else(
self.paths.primary().map_or_else(
|| RttEstimate::default().pto(PacketNumberSpace::ApplicationData),
|p| p.borrow().rtt().pto(PacketNumberSpace::ApplicationData),
)
Expand Down Expand Up @@ -746,7 +750,12 @@ impl Connection {
if !init_token.is_empty() {
self.address_validation = AddressValidationInfo::NewToken(init_token.to_vec());
}
self.paths.primary().borrow_mut().rtt_mut().set_initial(rtt);
self.paths
.primary()
.ok_or(Error::InternalError)?
.borrow_mut()
.rtt_mut()
.set_initial(rtt);
self.set_initial_limits();
// Start up TLS, which has the effect of setting up all the necessary
// state for 0-RTT. This only stages the CRYPTO frames.
Expand Down Expand Up @@ -786,7 +795,7 @@ impl Connection {
// If we are able, also send a NEW_TOKEN frame.
// This should be recording all remote addresses that are valid,
// but there are just 0 or 1 in the current implementation.
if let Some(path) = self.paths.primary_fallible() {
if let Some(path) = self.paths.primary() {
if let Some(token) = self
.address_validation
.generate_new_token(path.borrow().remote_address(), now)
Expand Down Expand Up @@ -858,7 +867,7 @@ impl Connection {
#[must_use]
pub fn stats(&self) -> Stats {
let mut v = self.stats.borrow().clone();
if let Some(p) = self.paths.primary_fallible() {
if let Some(p) = self.paths.primary() {
let p = p.borrow();
v.rtt = p.rtt().estimate();
v.rttvar = p.rtt().rttvar();
Expand Down Expand Up @@ -895,14 +904,14 @@ impl Connection {
State::WaitInitial => {
// We don't have any state yet, so don't bother with
// the closing state, just send one CONNECTION_CLOSE.
if let Some(path) = path.or_else(|| self.paths.primary_fallible()) {
if let Some(path) = path.or_else(|| self.paths.primary()) {
self.state_signaling
.close(path, error.clone(), frame_type, msg);
}
self.set_state(State::Closed(error));
}
_ => {
if let Some(path) = path.or_else(|| self.paths.primary_fallible()) {
if let Some(path) = path.or_else(|| self.paths.primary()) {
self.state_signaling
.close(path, error.clone(), frame_type, msg);
if matches!(v, Error::KeysExhausted) {
Expand Down Expand Up @@ -962,7 +971,7 @@ impl Connection {
let res = self.crypto.states.check_key_update(now);
self.absorb_error(now, res);

if let Some(path) = self.paths.primary_fallible() {
if let Some(path) = self.paths.primary() {
let lost = self.loss_recovery.timeout(&path, now);
self.handle_lost_packets(&lost);
qlog::packets_lost(&mut self.qlog, &lost);
Expand Down Expand Up @@ -1016,7 +1025,7 @@ impl Connection {
delays.push(ack_time);
}

if let Some(p) = self.paths.primary_fallible() {
if let Some(p) = self.paths.primary() {
let path = p.borrow();
let rtt = path.rtt();
let pto = rtt.pto(PacketNumberSpace::ApplicationData);
Expand Down Expand Up @@ -1125,7 +1134,13 @@ impl Connection {
}
// At this point, we should only have the connection ID that we generated.
// Update to the one that the server prefers.
let path = self.paths.primary();
let Some(path) = self.paths.primary() else {
self.stats
.borrow_mut()
.pkt_dropped("Retry without an existing path");
return;
};

path.borrow_mut().set_remote_cid(packet.scid());

let retry_scid = ConnectionId::from(packet.scid());
Expand Down Expand Up @@ -1153,8 +1168,9 @@ impl Connection {
fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) {
if self.crypto.discard(space) {
qdebug!([self], "Drop packet number space {}", space);
let primary = self.paths.primary();
self.loss_recovery.discard(&primary, space, now);
if let Some(path) = self.paths.primary() {
self.loss_recovery.discard(&path, space, now);
}
self.acks.drop_space(space);
}
}
Expand Down Expand Up @@ -1229,8 +1245,9 @@ impl Connection {
assert_ne!(self.version, version);

qinfo!([self], "Version negotiation: trying {:?}", version);
let local_addr = self.paths.primary().borrow().local_address();
let remote_addr = self.paths.primary().borrow().remote_address();
let path = self.paths.primary().ok_or(Error::NoAvailablePath)?;
let local_addr = path.borrow().local_address();
let remote_addr = path.borrow().remote_address();
let conn_params = self
.conn_params
.clone()
Expand Down Expand Up @@ -1632,10 +1649,15 @@ impl Connection {
if let Some(cid) = self.connection_ids.next() {
self.paths.make_permanent(path, None, cid);
Ok(())
} else if self.paths.primary().borrow().remote_cid().is_empty() {
self.paths
.make_permanent(path, None, ConnectionIdEntry::empty_remote());
Ok(())
} else if let Some(primary) = self.paths.primary() {
if primary.borrow().remote_cid().is_empty() {
self.paths
.make_permanent(path, None, ConnectionIdEntry::empty_remote());
Ok(())
} else {
qtrace!([self], "Unable to make path permanent: {}", path.borrow());
Err(Error::InvalidMigration)
}
} else {
qtrace!([self], "Unable to make path permanent: {}", path.borrow());
Err(Error::InvalidMigration)
Expand Down Expand Up @@ -1728,8 +1750,10 @@ impl Connection {
// Pointless migration is pointless.
return Err(Error::InvalidMigration);
}
let local = local.unwrap_or_else(|| self.paths.primary().borrow().local_address());
let remote = remote.unwrap_or_else(|| self.paths.primary().borrow().remote_address());

let path = self.paths.primary().ok_or(Error::InvalidMigration)?;
let local = local.unwrap_or_else(|| path.borrow().local_address());
let remote = remote.unwrap_or_else(|| path.borrow().remote_address());

if mem::discriminant(&local.ip()) != mem::discriminant(&remote.ip()) {
// Can't mix address families.
Expand Down Expand Up @@ -1782,7 +1806,12 @@ impl Connection {
// has to use the existing address. So only pay attention to a preferred
// address from the same family as is currently in use. More thought will
// be needed to work out how to get addresses from a different family.
let prev = self.paths.primary().borrow().remote_address();
let prev = self
.paths
.primary()
.ok_or(Error::NoAvailablePath)?
.borrow()
.remote_address();
let remote = match prev.ip() {
IpAddr::V4(_) => addr.ipv4().map(SocketAddr::V4),
IpAddr::V6(_) => addr.ipv6().map(SocketAddr::V6),
Expand Down Expand Up @@ -2331,7 +2360,9 @@ impl Connection {
fn client_start(&mut self, now: Instant) -> Res<()> {
qdebug!([self], "client_start");
debug_assert_eq!(self.role, Role::Client);
qlog::client_connection_started(&mut self.qlog, &self.paths.primary());
if let Some(path) = self.paths.primary() {
qlog::client_connection_started(&mut self.qlog, &path);
}
qlog::client_version_information_initiated(&mut self.qlog, self.conn_params.get_versions());

self.handshake(now, self.version, PacketNumberSpace::Initial, None)?;
Expand All @@ -2354,7 +2385,7 @@ impl Connection {
pub fn close(&mut self, now: Instant, app_error: AppError, msg: impl AsRef<str>) {
let error = ConnectionError::Application(app_error);
let timeout = self.get_closing_period_time(now);
if let Some(path) = self.paths.primary_fallible() {
if let Some(path) = self.paths.primary() {
self.state_signaling.close(path, error.clone(), 0, msg);
self.set_state(State::Closing { error, timeout });
} else {
Expand Down Expand Up @@ -2412,10 +2443,8 @@ impl Connection {
// That's OK, they can try guessing this.
ConnectionIdEntry::random_srt()
};
self.paths
.primary()
.borrow_mut()
.set_reset_token(reset_token);
let path = self.paths.primary().ok_or(Error::NoAvailablePath)?;
path.borrow_mut().set_reset_token(reset_token);

let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY));
let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) {
Expand All @@ -2427,11 +2456,8 @@ impl Connection {
} else {
None
};
self.paths.primary().borrow_mut().set_ack_delay(
max_ad,
min_ad,
self.conn_params.get_ack_ratio(),
);
path.borrow_mut()
.set_ack_delay(max_ad, min_ad, self.conn_params.get_ack_ratio());

let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT);
self.cid_manager.set_limit(max_active_cids);
Expand Down Expand Up @@ -2871,7 +2897,7 @@ impl Connection {
{
qdebug!([self], "Rx ACK space={}, ranges={:?}", space, ack_ranges);

let Some(path) = self.paths.primary_fallible() else {
let Some(path) = self.paths.primary() else {
return;
};
let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received(
Expand Down Expand Up @@ -2917,8 +2943,10 @@ impl Connection {
qdebug!([self], "0-RTT rejected");

// Tell 0-RTT packets that they were "lost".
let dropped = self.loss_recovery.drop_0rtt(&self.paths.primary(), now);
self.handle_lost_packets(&dropped);
if let Some(path) = self.paths.primary() {
let dropped = self.loss_recovery.drop_0rtt(&path, now);
self.handle_lost_packets(&dropped);
}

self.streams.zero_rtt_rejected();

Expand All @@ -2937,7 +2965,7 @@ impl Connection {
// Remove the randomized client CID from the list of acceptable CIDs.
self.cid_manager.remove_odcid();
// Mark the path as validated, if it isn't already.
let path = self.paths.primary();
let path = self.paths.primary().ok_or(Error::NoAvailablePath)?;
path.borrow_mut().set_valid(now);
// Generate a qlog event that the server connection started.
qlog::server_connection_started(&mut self.qlog, &path);
Expand Down Expand Up @@ -3205,7 +3233,7 @@ impl Connection {
else {
return Err(Error::NotAvailable);
};
let path = self.paths.primary_fallible().ok_or(Error::NotAvailable)?;
let path = self.paths.primary().ok_or(Error::NotAvailable)?;
let mtu = path.borrow().mtu();
let encoder = Encoder::with_capacity(mtu);

Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/connection/tests/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ fn datagram_fill() {

// Work out how much space we have for a datagram.
let space = {
let p = client.paths.primary();
let p = client.paths.primary().unwrap();
let path = p.borrow();
// Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number,
// 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD.
Expand Down
6 changes: 3 additions & 3 deletions neqo-transport/src/connection/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ fn fill_stream(c: &mut Connection, stream: StreamId) {
fn fill_cwnd(c: &mut Connection, stream: StreamId, mut now: Instant) -> (Vec<Datagram>, Instant) {
// Train wreck function to get the remaining congestion window on the primary path.
fn cwnd(c: &Connection) -> usize {
c.paths.primary().borrow().sender().cwnd_avail()
c.paths.primary().unwrap().borrow().sender().cwnd_avail()
}

qtrace!("fill_cwnd starting cwnd: {}", cwnd(c));
Expand Down Expand Up @@ -475,10 +475,10 @@ where

// Get the current congestion window for the connection.
fn cwnd(c: &Connection) -> usize {
c.paths.primary().borrow().sender().cwnd()
c.paths.primary().unwrap().borrow().sender().cwnd()
}
fn cwnd_avail(c: &Connection) -> usize {
c.paths.primary().borrow().sender().cwnd_avail()
c.paths.primary().unwrap().borrow().sender().cwnd_avail()
}

fn induce_persistent_congestion(
Expand Down
35 changes: 17 additions & 18 deletions neqo-transport/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,8 @@ impl Paths {
})
}

/// Get a reference to the primary path. This will assert if there is no primary
/// path, which happens at a server prior to receiving a valid Initial packet
/// from a client. So be careful using this method.
pub fn primary(&self) -> PathRef {
self.primary_fallible().unwrap()
}

/// Get a reference to the primary path. Use this prior to handshake completion.
pub fn primary_fallible(&self) -> Option<PathRef> {
/// Get a reference to the primary path, if one exists.
pub fn primary(&self) -> Option<PathRef> {
self.primary.clone()
}

Expand Down Expand Up @@ -243,7 +236,10 @@ impl Paths {
/// Returns `true` if the path was migrated.
pub fn migrate(&mut self, path: &PathRef, force: bool, now: Instant) -> bool {
debug_assert!(!self.is_temporary(path));
let baseline = self.primary().borrow().ecn_info.baseline();
let baseline = self.primary().map_or_else(
|| EcnInfo::default().baseline(),
|p| p.borrow().ecn_info.baseline(),
);
path.borrow_mut().set_ecn_baseline(baseline);
if force || path.borrow().is_valid() {
path.borrow_mut().set_valid(now);
Expand Down Expand Up @@ -310,7 +306,6 @@ impl Paths {
/// Set the identified path to be primary.
/// This panics if `make_permanent` hasn't been called.
pub fn handle_migration(&mut self, path: &PathRef, remote: SocketAddr, now: Instant) {
qtrace!([self.primary().borrow()], "handle_migration");
// The update here needs to match the checks in `Path::received_on`.
// Here, we update the remote port number to match the source port on the
// datagram that was received. This ensures that we send subsequent
Expand Down Expand Up @@ -428,10 +423,10 @@ impl Paths {
stats.retire_connection_id += 1;
}

// Write out any ACK_FREQUENCY frames.
self.primary()
.borrow_mut()
.write_cc_frames(builder, tokens, stats);
if let Some(path) = self.primary() {
// Write out any ACK_FREQUENCY frames.
path.borrow_mut().write_cc_frames(builder, tokens, stats);
}
}

pub fn lost_retire_cid(&mut self, lost: u64) {
Expand All @@ -443,11 +438,15 @@ impl Paths {
}

pub fn lost_ack_frequency(&mut self, lost: &AckRate) {
self.primary().borrow_mut().lost_ack_frequency(lost);
if let Some(path) = self.primary() {
path.borrow_mut().lost_ack_frequency(lost);
}
}

pub fn acked_ack_frequency(&mut self, acked: &AckRate) {
self.primary().borrow_mut().acked_ack_frequency(acked);
if let Some(path) = self.primary() {
path.borrow_mut().acked_ack_frequency(acked);
}
}

/// Get an estimate of the RTT on the primary path.
Expand All @@ -457,7 +456,7 @@ impl Paths {
// make a new RTT esimate and interrogate that.
// That is more expensive, but it should be rare and breaking encapsulation
// is worse, especially as this is only used in tests.
self.primary_fallible()
self.primary()
.map_or(RttEstimate::default().estimate(), |p| {
p.borrow().rtt().estimate()
})
Expand Down