Skip to content

Commit

Permalink
Make ConnectionIdRef Copy (#1561)
Browse files Browse the repository at this point in the history
This is just a reference, so Copy is cheap.
This reduces the number of unnecessary references that are taken.
  • Loading branch information
martinthomson authored Jan 17, 2024
1 parent 2fcf7a7 commit 9d58e64
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 74 deletions.
37 changes: 18 additions & 19 deletions neqo-transport/src/cid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,23 @@

// Representation and management of connection IDs.

use crate::frame::FRAME_TYPE_NEW_CONNECTION_ID;
use crate::packet::PacketBuilder;
use crate::recovery::RecoveryToken;
use crate::stats::FrameStats;
use crate::{Error, Res};
use crate::{
frame::FRAME_TYPE_NEW_CONNECTION_ID, packet::PacketBuilder, recovery::RecoveryToken,
stats::FrameStats, Error, Res,
};

use neqo_common::{hex, hex_with_len, qinfo, Decoder, Encoder};
use neqo_crypto::random;

use smallvec::SmallVec;
use std::borrow::Borrow;
use std::cell::{Ref, RefCell};
use std::cmp::max;
use std::cmp::min;
use std::convert::AsRef;
use std::convert::TryFrom;
use std::ops::Deref;
use std::rc::Rc;
use std::{
borrow::Borrow,
cell::{Ref, RefCell},
cmp::{max, min},
convert::{AsRef, TryFrom},
ops::Deref,
rc::Rc,
};

pub const MAX_CONNECTION_ID_LEN: usize = 20;
pub const LOCAL_ACTIVE_CID_LIMIT: usize = 8;
Expand Down Expand Up @@ -88,8 +87,8 @@ impl<T: AsRef<[u8]> + ?Sized> From<&T> for ConnectionId {
}
}

impl<'a> From<&ConnectionIdRef<'a>> for ConnectionId {
fn from(cidref: &ConnectionIdRef<'a>) -> Self {
impl<'a> From<ConnectionIdRef<'a>> for ConnectionId {
fn from(cidref: ConnectionIdRef<'a>) -> Self {
Self::from(SmallVec::from(cidref.cid))
}
}
Expand Down Expand Up @@ -120,7 +119,7 @@ impl<'a> PartialEq<ConnectionIdRef<'a>> for ConnectionId {
}
}

#[derive(Hash, Eq, PartialEq)]
#[derive(Hash, Eq, PartialEq, Clone, Copy)]
pub struct ConnectionIdRef<'a> {
cid: &'a [u8],
}
Expand Down Expand Up @@ -340,8 +339,8 @@ impl<SRT: Clone + PartialEq> ConnectionIdStore<SRT> {
self.cids.retain(|c| c.seqno != seqno);
}

pub fn contains(&self, cid: &ConnectionIdRef) -> bool {
self.cids.iter().any(|c| &c.cid == cid)
pub fn contains(&self, cid: ConnectionIdRef) -> bool {
self.cids.iter().any(|c| c.cid == cid)
}

pub fn next(&mut self) -> Option<ConnectionIdEntry<SRT>> {
Expand Down Expand Up @@ -479,7 +478,7 @@ impl ConnectionIdManager {
}
}

pub fn is_valid(&self, cid: &ConnectionIdRef) -> bool {
pub fn is_valid(&self, cid: ConnectionIdRef) -> bool {
self.connection_ids.contains(cid)
}

Expand Down
6 changes: 3 additions & 3 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ impl Connection {
dcid: Option<&ConnectionId>,
now: Instant,
) -> Res<PreprocessResult> {
if dcid.map_or(false, |d| d != packet.dcid()) {
if dcid.map_or(false, |d| d != &packet.dcid()) {
self.stats
.borrow_mut()
.pkt_dropped("Coalesced packet has different DCID");
Expand Down Expand Up @@ -1266,7 +1266,7 @@ impl Connection {
if versions.is_empty()
|| versions.contains(&self.version().wire_version())
|| versions.contains(&0)
|| packet.scid() != self.odcid().unwrap()
|| &packet.scid() != self.odcid().unwrap()
|| matches!(
self.address_validation,
AddressValidationInfo::Retry { .. }
Expand Down Expand Up @@ -1373,7 +1373,7 @@ impl Connection {
self.handle_migration(path, d, migrate, now);
} else if self.role != Role::Client
&& (packet.packet_type() == PacketType::Handshake
|| (packet.dcid().len() >= 8 && packet.dcid() == &self.local_initial_source_cid))
|| (packet.dcid().len() >= 8 && packet.dcid() == self.local_initial_source_cid))
{
// We only allow one path during setup, so apply handshake
// path validation to this path.
Expand Down
12 changes: 6 additions & 6 deletions neqo-transport/src/connection/tests/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ fn migration(mut client: Connection) {
let probe = client.process_output(now).dgram().unwrap();
assert_v4_path(&probe, true); // Contains PATH_CHALLENGE.
assert_eq!(client.stats().frame_tx.path_challenge, 1);
let probe_cid = ConnectionId::from(&get_cid(&probe));
let probe_cid = ConnectionId::from(get_cid(&probe));

let resp = server.process(Some(&probe), now).dgram().unwrap();
assert_v4_path(&resp, true);
Expand Down Expand Up @@ -814,7 +814,7 @@ fn retire_all() {
.unwrap();
connect_force_idle(&mut client, &mut server);

let original_cid = ConnectionId::from(&get_cid(&send_something(&mut client, now())));
let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now())));

server.test_frame_writer = Some(Box::new(RetireAll { cid_gen }));
let ncid = send_something(&mut server, now());
Expand Down Expand Up @@ -852,7 +852,7 @@ fn retire_prior_to_migration_failure() {
.unwrap();
connect_force_idle(&mut client, &mut server);

let original_cid = ConnectionId::from(&get_cid(&send_something(&mut client, now())));
let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now())));

client
.migrate(Some(addr_v4()), Some(addr_v4()), false, now())
Expand All @@ -862,7 +862,7 @@ fn retire_prior_to_migration_failure() {
let probe = client.process_output(now()).dgram().unwrap();
assert_v4_path(&probe, true);
assert_eq!(client.stats().frame_tx.path_challenge, 1);
let probe_cid = ConnectionId::from(&get_cid(&probe));
let probe_cid = ConnectionId::from(get_cid(&probe));
assert_ne!(original_cid, probe_cid);

// Have the server receive the probe, but separately have it decide to
Expand Down Expand Up @@ -907,7 +907,7 @@ fn retire_prior_to_migration_success() {
.unwrap();
connect_force_idle(&mut client, &mut server);

let original_cid = ConnectionId::from(&get_cid(&send_something(&mut client, now())));
let original_cid = ConnectionId::from(get_cid(&send_something(&mut client, now())));

client
.migrate(Some(addr_v4()), Some(addr_v4()), false, now())
Expand All @@ -917,7 +917,7 @@ fn retire_prior_to_migration_success() {
let probe = client.process_output(now()).dgram().unwrap();
assert_v4_path(&probe, true);
assert_eq!(client.stats().frame_tx.path_challenge, 1);
let probe_cid = ConnectionId::from(&get_cid(&probe));
let probe_cid = ConnectionId::from(get_cid(&probe));
assert_ne!(original_cid, probe_cid);

// Have the server receive the probe, but separately have it decide to
Expand Down
7 changes: 3 additions & 4 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,13 +673,12 @@ impl<'a> PublicPacket<'a> {
self.packet_type
}

pub fn dcid(&self) -> &ConnectionIdRef<'a> {
&self.dcid
pub fn dcid(&self) -> ConnectionIdRef<'a> {
self.dcid
}

pub fn scid(&self) -> &ConnectionIdRef<'a> {
pub fn scid(&self) -> ConnectionIdRef<'a> {
self.scid
.as_ref()
.expect("should only be called for long header packets")
}

Expand Down
44 changes: 23 additions & 21 deletions neqo-transport/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,29 @@
#![deny(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]

use std::cell::RefCell;
use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::rc::Rc;
use std::time::{Duration, Instant};

use crate::ackrate::{AckRate, PeerAckDelay};
use crate::cc::CongestionControlAlgorithm;
use crate::cid::{ConnectionId, ConnectionIdRef, ConnectionIdStore, RemoteConnectionIdEntry};
use crate::frame::{
FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID,
use std::{
cell::RefCell,
convert::TryFrom,
fmt::{self, Display},
mem,
net::{IpAddr, SocketAddr},
rc::Rc,
time::{Duration, Instant},
};

use crate::{
ackrate::{AckRate, PeerAckDelay},
cc::CongestionControlAlgorithm,
cid::{ConnectionId, ConnectionIdRef, ConnectionIdStore, RemoteConnectionIdEntry},
frame::{FRAME_TYPE_PATH_CHALLENGE, FRAME_TYPE_PATH_RESPONSE, FRAME_TYPE_RETIRE_CONNECTION_ID},
packet::PacketBuilder,
recovery::RecoveryToken,
rtt::RttEstimate,
sender::PacketSender,
stats::FrameStats,
tracking::{PacketNumberSpace, SentPacket},
Error, Res,
};
use crate::packet::PacketBuilder;
use crate::recovery::RecoveryToken;
use crate::rtt::RttEstimate;
use crate::sender::PacketSender;
use crate::stats::FrameStats;
use crate::tracking::{PacketNumberSpace, SentPacket};
use crate::{Error, Res};

use neqo_common::{hex, qdebug, qinfo, qlog::NeqoQlog, qtrace, Datagram, Encoder};
use neqo_crypto::random;
Expand Down Expand Up @@ -664,7 +666,7 @@ impl Path {

/// Set the remote connection ID based on the peer's choice.
/// This is only valid during the handshake.
pub fn set_remote_cid(&mut self, cid: &ConnectionIdRef) {
pub fn set_remote_cid(&mut self, cid: ConnectionIdRef) {
self.remote_cid
.as_mut()
.unwrap()
Expand Down
46 changes: 25 additions & 21 deletions neqo-transport/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,25 @@ use neqo_crypto::{
use qlog::streamer::QlogStreamer;

pub use crate::addr_valid::ValidateAddress;
use crate::addr_valid::{AddressValidation, AddressValidationResult};
use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef};
use crate::connection::{Connection, Output, State};
use crate::packet::{PacketBuilder, PacketType, PublicPacket};
use crate::{ConnectionParameters, Res, Version};

use std::cell::RefCell;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fs::OpenOptions;
use std::mem;
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::rc::{Rc, Weak};
use std::time::{Duration, Instant};
use crate::{
addr_valid::{AddressValidation, AddressValidationResult},
cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionIdRef},
connection::{Connection, Output, State},
packet::{PacketBuilder, PacketType, PublicPacket},
ConnectionParameters, Res, Version,
};

use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
fs::OpenOptions,
mem,
net::SocketAddr,
ops::{Deref, DerefMut},
path::PathBuf,
rc::{Rc, Weak},
time::{Duration, Instant},
};

pub enum InitialResult {
Accept,
Expand Down Expand Up @@ -303,7 +307,7 @@ impl Server {
out.dgram()
}

fn connection(&self, cid: &ConnectionIdRef) -> Option<StateRef> {
fn connection(&self, cid: ConnectionIdRef) -> Option<StateRef> {
self.connections.borrow().get(&cid[..]).map(Rc::clone)
}

Expand Down Expand Up @@ -383,7 +387,7 @@ impl Server {
}
}

fn create_qlog_trace(&self, odcid: &ConnectionIdRef<'_>) -> NeqoQlog {
fn create_qlog_trace(&self, odcid: ConnectionIdRef<'_>) -> NeqoQlog {
if let Some(qlog_dir) = &self.qlog_dir {
let mut qlog_path = qlog_dir.to_path_buf();

Expand Down Expand Up @@ -449,7 +453,7 @@ impl Server {
c.set_retry_cids(odcid, initial.src_cid, initial.dst_cid);
}
c.set_validation(Rc::clone(&self.address_validation));
c.set_qlog(self.create_qlog_trace(&attempt_key.odcid.as_cid_ref()));
c.set_qlog(self.create_qlog_trace(attempt_key.odcid.as_cid_ref()));
if let Some(cfg) = &self.ech_config {
if c.server_enable_ech(cfg.config, &cfg.public_name, &cfg.sk, &cfg.pk)
.is_err()
Expand Down Expand Up @@ -504,7 +508,7 @@ impl Server {
qwarn!([self], "Unable to create connection");
if e == crate::Error::VersionNegotiation {
crate::qlog::server_version_information_failed(
&mut self.create_qlog_trace(&attempt_key.odcid.as_cid_ref()),
&mut self.create_qlog_trace(attempt_key.odcid.as_cid_ref()),
self.conn_params.get_versions().all(),
initial.version.wire_version(),
)
Expand Down Expand Up @@ -578,8 +582,8 @@ impl Server {

qdebug!([self], "Unsupported version: {:x}", packet.wire_version());
let vn = PacketBuilder::version_negotiation(
packet.scid(),
packet.dcid(),
&packet.scid()[..],
&packet.dcid()[..],
packet.wire_version(),
self.conn_params.get_versions().all(),
);
Expand Down

0 comments on commit 9d58e64

Please sign in to comment.