From 1488a4eec6000e0e4cbaf953b587a333149b4139 Mon Sep 17 00:00:00 2001 From: Max Leonard Inden Date: Sun, 26 Jan 2025 18:58:43 +0100 Subject: [PATCH] fix(transport/packet): don't (mutably) borrow data multiple times --- neqo-transport/src/packet/mod.rs | 48 +++++++++++++++++--------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 73b47bccd8..a476b1a655 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -563,12 +563,12 @@ pub struct PublicPacket<'a> { /// The packet type. packet_type: PacketType, /// The recovered destination connection ID. - dcid: ConnectionIdRef<'a>, + dcid: ConnectionId, /// The source connection ID, if this is a long header packet. - scid: Option>, + scid: Option, /// Any token that is included in the packet (Retry always has a token; Initial sometimes /// does). This is empty when there is no token. - token: &'a [u8], + token: Vec, /// The size of the header, not including the packet number. header_len: usize, /// Protocol version, if present in header. @@ -624,9 +624,9 @@ impl<'a> PublicPacket<'a> { pub fn decode( data: &'a mut [u8], dcid_decoder: &dyn ConnectionIdDecoder, - ) -> Res<(Self, &'a mut [u8])> { + ) -> Res<(PublicPacket<'a>, &'a mut [u8])> { let mut decoder = Decoder::new(data); - let first = Self::opt(decoder.decode_uint::())?; + let first = PublicPacket::opt(decoder.decode_uint::())?; if first & 0x80 == PACKET_BIT_SHORT { // Conveniently, this also guarantees that there is enough space @@ -634,17 +634,18 @@ impl<'a> PublicPacket<'a> { if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { return Err(Error::InvalidPacket); } - let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?; + let dcid = PublicPacket::opt(dcid_decoder.decode_cid(&mut decoder))?.into(); if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { return Err(Error::InvalidPacket); } let header_len = decoder.offset(); + return Ok(( - Self { + PublicPacket { packet_type: PacketType::Short, dcid, scid: None, - token: &[], + token: vec![], header_len, version: None, data, @@ -654,18 +655,18 @@ impl<'a> PublicPacket<'a> { } // Generic long header. - let version = Self::opt(decoder.decode_uint())?; - let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); - let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + let version = PublicPacket::opt(decoder.decode_uint())?; + let dcid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into(); + let scid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into(); // Version negotiation. if version == 0 { return Ok(( - Self { + PublicPacket { packet_type: PacketType::VersionNegotiation, dcid, scid: Some(scid), - token: &[], + token: vec![], header_len: decoder.offset(), version: None, data, @@ -677,11 +678,11 @@ impl<'a> PublicPacket<'a> { // Check that this is a long header from a supported version. let Ok(version) = Version::try_from(version) else { return Ok(( - Self { + PublicPacket { packet_type: PacketType::OtherVersion, dcid, scid: Some(scid), - token: &[], + token: vec![], header_len: decoder.offset(), version: Some(version), data, @@ -696,11 +697,12 @@ impl<'a> PublicPacket<'a> { let packet_type = PacketType::from_byte((first >> 4) & 3, version); // The type-specific code includes a token. This consumes the remainder of the packet. - let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?; + let (token, header_len) = PublicPacket::decode_long(&mut decoder, packet_type, version)?; + let token = token.to_vec(); let end = data.len() - decoder.remaining(); let (data, remainder) = data.split_at_mut(end); Ok(( - Self { + PublicPacket { packet_type, dcid, scid: Some(scid), @@ -751,22 +753,24 @@ impl<'a> PublicPacket<'a> { } #[must_use] - pub const fn dcid(&self) -> ConnectionIdRef<'a> { - self.dcid + pub fn dcid(&self) -> ConnectionIdRef { + self.dcid.as_cid_ref() } /// # Panics /// /// This will panic if called for a short header packet. #[must_use] - pub fn scid(&self) -> ConnectionIdRef<'a> { + pub fn scid(&self) -> ConnectionIdRef { self.scid + .as_ref() .expect("should only be called for long header packets") + .as_cid_ref() } #[must_use] - pub const fn token(&self) -> &'a [u8] { - self.token + pub fn token(&self) -> &[u8] { + &self.token } #[must_use]