diff --git a/quic/s2n-quic-platform/src/features.rs b/quic/s2n-quic-platform/src/features.rs index 55eda2f61f..23bfc81b1c 100644 --- a/quic/s2n-quic-platform/src/features.rs +++ b/quic/s2n-quic-platform/src/features.rs @@ -6,6 +6,11 @@ type c_int = std::os::raw::c_int; pub mod gro; pub mod gso; +pub mod pktinfo; +pub mod pktinfo_v4; +pub mod pktinfo_v6; +pub mod tos; pub mod tos_v4; pub mod tos_v6; + pub use gso::Gso; diff --git a/quic/s2n-quic-platform/src/features/pktinfo.rs b/quic/s2n-quic-platform/src/features/pktinfo.rs new file mode 100644 index 0000000000..e14daae82c --- /dev/null +++ b/quic/s2n-quic-platform/src/features/pktinfo.rs @@ -0,0 +1,4 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub const IS_SUPPORTED: bool = super::pktinfo_v4::IS_SUPPORTED || super::pktinfo_v6::IS_SUPPORTED; diff --git a/quic/s2n-quic-platform/src/features/pktinfo_v4.rs b/quic/s2n-quic-platform/src/features/pktinfo_v4.rs new file mode 100644 index 0000000000..887f41dc4d --- /dev/null +++ b/quic/s2n-quic-platform/src/features/pktinfo_v4.rs @@ -0,0 +1,105 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::c_int; +use s2n_quic_core::inet::IpV4Address; + +#[cfg(s2n_quic_platform_pktinfo)] +mod pktinfo_enabled { + use super::*; + use crate::message::cmsg; + use libc::{IPPROTO_IP, IP_PKTINFO}; + + pub const LEVEL: Option = Some(IPPROTO_IP as _); + pub const TYPE: Option = Some(IP_PKTINFO as _); + pub const SOCKOPT: Option<(c_int, c_int)> = Some((IPPROTO_IP as _, IP_PKTINFO as _)); + pub const CMSG_SPACE: usize = crate::message::cmsg::size_of_cmsg::(); + + pub type Cmsg = libc::in_pktinfo; + + #[inline] + pub const fn is_match(level: c_int, ty: c_int) -> bool { + level == IPPROTO_IP as c_int && ty == IP_PKTINFO as c_int + } + + /// # Safety + /// + /// * The provided bytes must be aligned to `cmsghdr` + #[inline] + pub unsafe fn decode(bytes: &[u8]) -> Option<(IpV4Address, u32)> { + let pkt_info = cmsg::decode::value_from_bytes::(bytes)?; + + // read from both fields in case only one is set and not the other + // + // from https://man7.org/linux/man-pages/man7/ip.7.html: + // + // > ipi_spec_dst is the local address + // > of the packet and ipi_addr is the destination address in + // > the packet header. + let local_address = match (pkt_info.ipi_addr.s_addr, pkt_info.ipi_spec_dst.s_addr) { + (0, v) => v.to_ne_bytes(), + (v, _) => v.to_ne_bytes(), + }; + + let address = IpV4Address::new(local_address); + let interface = pkt_info.ipi_ifindex as _; + + Some((address, interface)) + } + + #[inline] + pub fn encode(addr: &IpV4Address, local_interface: Option) -> Cmsg { + let mut pkt_info = unsafe { core::mem::zeroed::() }; + pkt_info.ipi_spec_dst.s_addr = u32::from_ne_bytes((*addr).into()); + if let Some(interface) = local_interface { + pkt_info.ipi_ifindex = interface as _; + } + pkt_info + } +} + +#[cfg(any(not(s2n_quic_platform_pktinfo), test))] +mod pktinfo_disabled { + #![cfg_attr(test, allow(dead_code))] + use super::*; + + pub const LEVEL: Option = None; + pub const TYPE: Option = None; + pub const SOCKOPT: Option<(c_int, c_int)> = None; + pub const CMSG_SPACE: usize = 0; + + pub type Cmsg = c_int; + + #[inline] + pub const fn is_match(level: c_int, ty: c_int) -> bool { + let _ = level; + let _ = ty; + false + } + + /// # Safety + /// + /// * The provided bytes must be aligned to `cmsghdr` + pub unsafe fn decode(bytes: &[u8]) -> Option<(IpV4Address, u32)> { + let _ = bytes; + None + } + + #[inline] + pub fn encode(addr: &IpV4Address, local_interface: Option) -> Cmsg { + let _ = addr; + let _ = local_interface; + unimplemented!("this platform does not support pktinfo") + } +} + +mod pktinfo_impl { + #[cfg(not(s2n_quic_platform_pktinfo))] + pub use super::pktinfo_disabled::*; + #[cfg(s2n_quic_platform_pktinfo)] + pub use super::pktinfo_enabled::*; +} + +pub use pktinfo_impl::*; + +pub const IS_SUPPORTED: bool = cfg!(s2n_quic_platform_pktinfo); diff --git a/quic/s2n-quic-platform/src/features/pktinfo_v6.rs b/quic/s2n-quic-platform/src/features/pktinfo_v6.rs new file mode 100644 index 0000000000..c20df14fdf --- /dev/null +++ b/quic/s2n-quic-platform/src/features/pktinfo_v6.rs @@ -0,0 +1,94 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::c_int; +use s2n_quic_core::inet::IpV6Address; + +#[cfg(s2n_quic_platform_pktinfo)] +mod pktinfo_enabled { + use super::*; + use crate::message::cmsg; + use libc::{IPPROTO_IPV6, IPV6_PKTINFO, IPV6_RECVPKTINFO}; + + pub const LEVEL: Option = Some(IPPROTO_IPV6 as _); + pub const TYPE: Option = Some(IPV6_PKTINFO as _); + pub const SOCKOPT: Option<(c_int, c_int)> = Some((IPPROTO_IPV6 as _, IPV6_RECVPKTINFO)); + pub const CMSG_SPACE: usize = crate::message::cmsg::size_of_cmsg::(); + + pub type Cmsg = libc::in6_pktinfo; + + #[inline] + pub const fn is_match(level: c_int, ty: c_int) -> bool { + level == IPPROTO_IPV6 as c_int && ty == IPV6_PKTINFO as c_int + } + + /// # Safety + /// + /// * The provided bytes must be aligned to `cmsghdr` + pub unsafe fn decode(bytes: &[u8]) -> Option<(IpV6Address, u32)> { + let pkt_info = cmsg::decode::value_from_bytes::(bytes)?; + + let local_address = pkt_info.ipi6_addr.s6_addr; + + let address = IpV6Address::new(local_address); + let interface = pkt_info.ipi6_ifindex as _; + + Some((address, interface)) + } + + #[inline] + pub fn encode(addr: &IpV6Address, local_interface: Option) -> Cmsg { + let mut pkt_info = unsafe { core::mem::zeroed::() }; + pkt_info.ipi6_addr.s6_addr = (*addr).into(); + if let Some(interface) = local_interface { + pkt_info.ipi6_ifindex = interface as _; + } + pkt_info + } +} + +#[cfg(any(not(s2n_quic_platform_pktinfo), test))] +mod pktinfo_disabled { + #![cfg_attr(test, allow(dead_code))] + use super::*; + + pub const LEVEL: Option = None; + pub const TYPE: Option = None; + pub const SOCKOPT: Option<(c_int, c_int)> = None; + pub const CMSG_SPACE: usize = 0; + + pub type Cmsg = c_int; + + #[inline] + pub const fn is_match(level: c_int, ty: c_int) -> bool { + let _ = level; + let _ = ty; + false + } + + /// # Safety + /// + /// * The provided bytes must be aligned to `cmsghdr` + pub unsafe fn decode(bytes: &[u8]) -> Option<(IpV6Address, u32)> { + let _ = bytes; + None + } + + #[inline] + pub fn encode(addr: &IpV6Address, local_interface: Option) -> Cmsg { + let _ = addr; + let _ = local_interface; + unimplemented!("this platform does not support pktinfo") + } +} + +mod pktinfo_impl { + #[cfg(not(s2n_quic_platform_pktinfo))] + pub use super::pktinfo_disabled::*; + #[cfg(s2n_quic_platform_pktinfo)] + pub use super::pktinfo_enabled::*; +} + +pub use pktinfo_impl::*; + +pub const IS_SUPPORTED: bool = cfg!(s2n_quic_platform_pktinfo); diff --git a/quic/s2n-quic-platform/src/features/tos.rs b/quic/s2n-quic-platform/src/features/tos.rs new file mode 100644 index 0000000000..3bd10e54b7 --- /dev/null +++ b/quic/s2n-quic-platform/src/features/tos.rs @@ -0,0 +1,23 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::c_int; +use s2n_quic_core::inet::ExplicitCongestionNotification; + +pub const IS_SUPPORTED: bool = super::tos_v4::IS_SUPPORTED || super::tos_v6::IS_SUPPORTED; + +#[inline] +pub const fn is_match(level: c_int, ty: c_int) -> bool { + super::tos_v4::is_match(level, ty) || super::tos_v6::is_match(level, ty) +} + +#[inline] +pub fn decode(bytes: &[u8]) -> Option { + let value = match bytes.len() { + 1 => bytes[0], + 4 => u32::from_ne_bytes(bytes.try_into().unwrap()) as u8, + _ => return None, + }; + + Some(ExplicitCongestionNotification::new(value)) +} diff --git a/quic/s2n-quic-platform/src/message/cmsg.rs b/quic/s2n-quic-platform/src/message/cmsg.rs index 85f6807f4a..5a77c7c85b 100644 --- a/quic/s2n-quic-platform/src/message/cmsg.rs +++ b/quic/s2n-quic-platform/src/message/cmsg.rs @@ -4,9 +4,18 @@ #![allow(clippy::unnecessary_cast)] // some platforms encode lengths as `u32` so we cast everything to be safe use crate::features; -use core::mem::{align_of, size_of}; +use core::mem::size_of; use libc::cmsghdr; -use s2n_quic_core::inet::{AncillaryData, ExplicitCongestionNotification}; + +pub mod decode; +pub mod encode; +pub mod storage; + +#[cfg(test)] +mod tests; + +pub use encode::Encoder; +pub use storage::Storage; pub const fn size_of_cmsg() -> usize { unsafe { libc::CMSG_SPACE(size_of::() as _) as _ } @@ -36,10 +45,7 @@ pub const MAX_LEN: usize = { let segment_offload_size = const_max(gso_size, gro_size); // rather than taking the max, we add these in case the OS gives us both - #[cfg(s2n_quic_platform_pktinfo)] - let pktinfo_size = size_of_cmsg::() + size_of_cmsg::(); - #[cfg(not(s2n_quic_platform_pktinfo))] - let pktinfo_size = 0; + let pktinfo_size = features::pktinfo_v4::CMSG_SPACE + features::pktinfo_v6::CMSG_SPACE; // This is currently needed due to how we detect if CMSG data has been written or not. // @@ -49,472 +55,5 @@ pub const MAX_LEN: usize = { tos_size + segment_offload_size + pktinfo_size + padding }; -#[repr(align(8))] // the storage needs to be aligned to the same as `cmsghdr` -#[derive(Clone, Debug)] -pub struct Storage([u8; L]); - -impl Default for Storage { - #[inline] - fn default() -> Self { - Self([0; L]) - } -} - -impl Storage { - #[inline] - pub fn len(&self) -> usize { - self.0.len() - } - - #[inline] - #[allow(dead_code)] // clippy wants this to exist but we don't use it - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - #[inline] - pub fn as_mut_ptr(&mut self) -> *mut u8 { - self.0.as_mut_ptr() - } -} - -#[derive(Clone, Copy, Debug)] -pub struct OutOfSpace; - -pub struct SliceEncoder<'a> { - storage: &'a mut [u8], - cursor: usize, -} - -impl<'a> SliceEncoder<'a> { - #[inline] - pub fn new(storage: &'a mut [u8]) -> Self { - Self { storage, cursor: 0 } - } - - #[inline] - pub fn len(&self) -> usize { - self.cursor - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.cursor == 0 - } -} - -impl<'a> Encoder for SliceEncoder<'a> { - #[inline] - fn encode_cmsg( - &mut self, - level: libc::c_int, - ty: libc::c_int, - value: T, - ) -> Result { - unsafe { - debug_assert!( - align_of::() <= align_of::(), - "alignment of T should be less than or equal to cmsghdr" - ); - - // CMSG_SPACE() returns the number of bytes an ancillary element - // with payload of the passed data length occupies. - let element_len = size_of_cmsg::(); - debug_assert_ne!(element_len, 0); - debug_assert_eq!(libc::CMSG_SPACE(size_of::() as _) as usize, element_len); - - let new_cursor = self.cursor.checked_add(element_len).ok_or(OutOfSpace)?; - - self.storage - .len() - .checked_sub(new_cursor) - .ok_or(OutOfSpace)?; - - let cmsg_ptr = { - // Safety: the msg_control buffer should always be allocated to MAX_LEN - let msg_controllen = self.cursor; - let msg_control = self.storage.as_mut_ptr().add(msg_controllen as _); - msg_control as *mut cmsghdr - }; - - { - let cmsg = &mut *cmsg_ptr; - - // interpret the start of cmsg as a cmsghdr - // Safety: the cmsg slice should already be zero-initialized and aligned - - // Indicate the type of cmsg - cmsg.cmsg_level = level; - cmsg.cmsg_type = ty; - - // CMSG_LEN() returns the value to store in the cmsg_len member - // of the cmsghdr structure, taking into account any necessary - // alignment. It takes the data length as an argument. - cmsg.cmsg_len = libc::CMSG_LEN(size_of::() as _) as _; - } - - { - // Write the actual value in the data space of the cmsg - // Safety: we asserted we had enough space in the cmsg buffer above - // CMSG_DATA() returns a pointer to the data portion of a - // cmsghdr. The pointer returned cannot be assumed to be - // suitably aligned for accessing arbitrary payload data types. - // Applications should not cast it to a pointer type matching the - // payload, but should instead use memcpy(3) to copy data to or - // from a suitably declared object. - let data_ptr = cmsg_ptr.add(1); - - debug_assert_eq!(data_ptr as *mut u8, libc::CMSG_DATA(cmsg_ptr) as *mut u8); - - core::ptr::copy_nonoverlapping( - &value as *const T as *const u8, - data_ptr as *mut u8, - size_of::(), - ); - } - - // add the values as a usize to make sure we work cross-platform - self.cursor = new_cursor; - debug_assert!( - self.cursor <= self.storage.len(), - "msg should not exceed max allocated" - ); - - Ok(self.cursor) - } - } -} - -pub trait Encoder { - /// Encodes the given value as a control message in the cmsg buffer. - /// - /// The msghdr.msg_control should be zero-initialized and aligned and contain enough - /// room for the value to be written. - fn encode_cmsg( - &mut self, - level: libc::c_int, - ty: libc::c_int, - value: T, - ) -> Result; -} - -impl Encoder for libc::msghdr { - #[inline] - fn encode_cmsg( - &mut self, - level: libc::c_int, - ty: libc::c_int, - value: T, - ) -> Result { - let storage = unsafe { &mut *(self.msg_control as *mut Storage) }; - - let mut encoder = SliceEncoder { - storage: &mut storage.0, - cursor: self.msg_controllen as _, - }; - - let cursor = encoder.encode_cmsg(level, ty, value)?; - - self.msg_controllen = cursor as _; - - Ok(cursor) - } -} - -/// Decodes all recognized control messages in the given `msghdr` into `AncillaryData` -#[inline] -pub fn decode(msghdr: &libc::msghdr) -> AncillaryData { - let mut result = AncillaryData::default(); - - let iter = unsafe { Iter::from_msghdr(msghdr) }; - - for (cmsg, value) in iter { - unsafe { - match (cmsg.cmsg_level, cmsg.cmsg_type) { - (level, ty) - if features::tos_v4::is_match(level, ty) - || features::tos_v6::is_match(level, ty) => - { - // IP_TOS cmsgs should be 1 byte, but occasionally are reported as 4 bytes - let value = match value.len() { - 1 => decode_value::(value), - 4 => decode_value::(value) as u8, - len => { - if cfg!(test) { - panic!( - "invalid size for ECN marking. len: {len}, value: {value:?}" - ); - } - continue; - } - }; - - result.ecn = ExplicitCongestionNotification::new(value); - } - #[cfg(s2n_quic_platform_pktinfo)] - (libc::IPPROTO_IP, libc::IP_PKTINFO) => { - let pkt_info = decode_value::(value); - - // read from both fields in case only one is set and not the other - // - // from https://man7.org/linux/man-pages/man7/ip.7.html: - // - // > ipi_spec_dst is the local address - // > of the packet and ipi_addr is the destination address in - // > the packet header. - let local_address = - match (pkt_info.ipi_addr.s_addr, pkt_info.ipi_spec_dst.s_addr) { - (0, v) => v.to_ne_bytes(), - (v, _) => v.to_ne_bytes(), - }; - - // The port should be specified by a different layer that has that information - let port = 0; - let local_address = - s2n_quic_core::inet::SocketAddressV4::new(local_address, port); - result.local_address = local_address.into(); - result.local_interface = Some(pkt_info.ipi_ifindex as _); - } - #[cfg(s2n_quic_platform_pktinfo)] - (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { - let pkt_info = decode_value::(value); - let local_address = pkt_info.ipi6_addr.s6_addr; - // The port should be specified by a different layer that has that information - let port = 0; - let local_address = - s2n_quic_core::inet::SocketAddressV6::new(local_address, port); - result.local_address = local_address.into(); - result.local_interface = Some(pkt_info.ipi6_ifindex as _); - } - (level, ty) if features::gso::is_match(level, ty) => { - // ignore GSO settings when reading - continue; - } - (level, ty) if features::gro::is_match(level, ty) => { - let segment_size = decode_value::(value); - result.segment_size = segment_size as _; - } - (level, ty) if cfg!(test) => { - // if we're getting an unexpected cmsg we should know about it in testing - panic!("unexpected cmsghdr {{ level: {level}, type: {ty}, value: {value:?} }}"); - } - _ => {} - } - } - } - - result -} - -/// Decodes a value of type `T` from the given `cmsghdr` -/// # Safety -/// -/// `cmsghdr` must refer to a cmsg containing a payload of type `T` -#[inline] -pub unsafe fn decode_value(value: &[u8]) -> T { - use core::mem; - - debug_assert!(mem::align_of::() <= mem::align_of::()); - debug_assert!(value.len() >= size_of::()); - - let mut v = mem::zeroed::(); - - core::ptr::copy_nonoverlapping(value.as_ptr(), &mut v as *mut T as *mut u8, size_of::()); - - v -} - -pub struct Iter<'a> { - cursor: *const u8, - len: usize, - contents: core::marker::PhantomData<&'a [u8]>, -} - -impl<'a> Iter<'a> { - /// Creates a new cmsg::Iter used for iterating over control message headers in the given - /// slice of bytes. - /// - /// # Safety - /// - /// * `contents` must be aligned to cmsghdr - #[inline] - pub unsafe fn new(contents: &'a [u8]) -> Iter<'a> { - let cursor = contents.as_ptr(); - let len = contents.len(); - - debug_assert_eq!( - cursor.align_offset(align_of::()), - 0, - "contents must be aligned to cmsghdr" - ); - - Self { - cursor, - len, - contents: Default::default(), - } - } - - /// Creates a new cmsg::Iter used for iterating over control message headers in the given - /// msghdr. - /// - /// # Safety - /// - /// * `contents` must be aligned to cmsghdr - /// * `msghdr` must point to a valid control buffer - #[inline] - pub unsafe fn from_msghdr(msghdr: &'a libc::msghdr) -> Self { - let ptr = msghdr.msg_control as *const u8; - let len = msghdr.msg_controllen as usize; - let slice = core::slice::from_raw_parts(ptr, len); - Self::new(slice) - } -} - -impl<'a> Iterator for Iter<'a> { - type Item = (&'a cmsghdr, &'a [u8]); - - #[inline] - fn next(&mut self) -> Option { - unsafe { - let cursor = self.cursor; - - // make sure we can decode a cmsghdr - self.len.checked_sub(size_of::())?; - let cmsg = &*(cursor as *const cmsghdr); - let data_ptr = cursor.add(size_of::()); - - let cmsg_len = cmsg.cmsg_len as usize; - - // make sure we have capacity to decode the provided cmsg_len - self.len.checked_sub(cmsg_len)?; - - // the cmsg_len includes the header itself so it needs to be subtracted off - let data_len = cmsg_len.checked_sub(size_of::())?; - // construct a slice with the provided data len - let data = core::slice::from_raw_parts(data_ptr, data_len); - - // empty messages are invalid - if data.is_empty() { - return None; - } - - // calculate the next message and update the cursor/len - { - let space = libc::CMSG_SPACE(data_len as _) as usize; - debug_assert!( - space >= data_len, - "space ({space}) should be at least of size len ({data_len})" - ); - self.len = self.len.saturating_sub(space); - self.cursor = cursor.add(space); - } - - Some((cmsg, data)) - } - } -} - #[cfg(test)] -mod tests { - use super::*; - use bolero::{check, TypeGenerator}; - use libc::c_int; - - /// Ensures the cmsg iterator doesn't crash or segfault - #[test] - #[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(17))] - fn iter_test() { - check!().for_each(|cmsg| { - // the bytes needs to be aligned to a cmsghdr - let offset = cmsg.as_ptr().align_offset(align_of::()); - - if let Some(cmsg) = cmsg.get(offset..) { - for (cmsghdr, value) in unsafe { Iter::new(cmsg) } { - let _ = cmsghdr; - let _ = value; - } - } - }); - } - - #[derive(Clone, Copy, Debug, TypeGenerator)] - struct Op { - level: c_int, - ty: c_int, - value: Value, - } - - #[derive(Clone, Copy, Debug, TypeGenerator)] - enum Value { - U8(u8), - U16(u16), - U32(u32), - // alignment can't exceed that of cmsghdr - U64([u32; 2]), - U128([u32; 4]), - } - - impl Value { - fn check_value(&self, bytes: &[u8]) { - let expected_len = match self { - Self::U8(_) => 1, - Self::U16(_) => 2, - Self::U32(_) => 4, - Self::U64(_) => 8, - Self::U128(_) => 16, - }; - assert_eq!(expected_len, bytes.len()); - } - } - - fn round_trip(ops: &[Op]) { - let mut storage = Storage::<{ MAX_LEN }>::default(); - let mut encoder = SliceEncoder { - storage: &mut storage.0, - cursor: 0, - }; - - let mut expected_encoded_count = 0; - - for op in ops { - let res = match op.value { - Value::U8(value) => encoder.encode_cmsg(op.level, op.ty, value), - Value::U16(value) => encoder.encode_cmsg(op.level, op.ty, value), - Value::U32(value) => encoder.encode_cmsg(op.level, op.ty, value), - Value::U64(value) => encoder.encode_cmsg(op.level, op.ty, value), - Value::U128(value) => encoder.encode_cmsg(op.level, op.ty, value), - }; - - match res { - Ok(_) => expected_encoded_count += 1, - Err(_) => break, - } - } - - let cursor = encoder.cursor; - let mut actual_decoded_count = 0; - let mut iter = unsafe { Iter::new(&storage.0[..cursor]) }; - - for (op, (cmsghdr, value)) in ops.iter().zip(&mut iter) { - assert_eq!(op.level, cmsghdr.cmsg_level); - assert_eq!(op.ty, cmsghdr.cmsg_type); - op.value.check_value(value); - actual_decoded_count += 1; - } - - assert_eq!(expected_encoded_count, actual_decoded_count); - assert!(iter.next().is_none()); - } - - #[cfg(not(kani))] - type Ops = Vec; - #[cfg(kani)] - type Ops = s2n_quic_core::testing::InlineVec; - - #[test] - #[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(9))] - fn round_trip_test() { - check!().with_type::().for_each(|ops| round_trip(ops)); - } -} +mod tests_ {} diff --git a/quic/s2n-quic-platform/src/message/cmsg/decode.rs b/quic/s2n-quic-platform/src/message/cmsg/decode.rs new file mode 100644 index 0000000000..d380903bc1 --- /dev/null +++ b/quic/s2n-quic-platform/src/message/cmsg/decode.rs @@ -0,0 +1,222 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::Storage; +use crate::features; +use core::mem::{align_of, size_of}; +use libc::cmsghdr; +use s2n_quic_core::{ensure, inet::AncillaryData}; + +/// Decodes a value of type `T` from the given `cmsghdr` +/// # Safety +/// +/// `cmsghdr` must refer to a cmsg containing a payload of type `T` +#[inline] +pub unsafe fn value_from_bytes(value: &[u8]) -> Option { + use core::mem; + + ensure!(value.len() == size_of::(), None); + + debug_assert!(mem::align_of::() <= mem::align_of::()); + + let mut v = mem::zeroed::(); + + core::ptr::copy_nonoverlapping(value.as_ptr(), &mut v as *mut T as *mut u8, size_of::()); + + Some(v) +} + +/// Decodes all recognized control messages in the given `iter` into `AncillaryData` +#[inline] +pub fn collect(iter: Iter) -> AncillaryData { + let mut data = AncillaryData::default(); + + for (cmsg, value) in iter { + unsafe { + // SAFETY: `Iter` ensures values are aligned + collect_item(&mut data, cmsg, value); + } + } + + data +} + +#[inline] +unsafe fn collect_item(data: &mut AncillaryData, cmsg: &cmsghdr, value: &[u8]) { + macro_rules! decode_error { + ($error:expr) => { + #[cfg(all(test, feature = "tracing", not(any(kani, miri, fuzz))))] + tracing::debug!( + error = $error, + level = cmsg.cmsg_level, + r#type = cmsg.cmsg_type, + value = ?value, + ); + } + } + + match (cmsg.cmsg_level, cmsg.cmsg_type) { + (level, ty) if features::tos::is_match(level, ty) => { + if let Some(ecn) = features::tos::decode(value) { + data.ecn = ecn; + } else { + decode_error!("invalid TOS value"); + } + } + (level, ty) if features::pktinfo_v4::is_match(level, ty) => { + if let Some((local_address, local_interface)) = features::pktinfo_v4::decode(value) { + // The port should be specified by a different layer that has that information + let port = 0; + let local_address = s2n_quic_core::inet::SocketAddressV4::new(local_address, port); + data.local_address = local_address.into(); + data.local_interface = Some(local_interface); + } else { + decode_error!("invalid pktinfo_v4 value"); + } + } + (level, ty) if features::pktinfo_v6::is_match(level, ty) => { + if let Some((local_address, local_interface)) = features::pktinfo_v6::decode(value) { + // The port should be specified by a different layer that has that information + let port = 0; + let local_address = s2n_quic_core::inet::SocketAddressV6::new(local_address, port); + data.local_address = local_address.into(); + data.local_interface = Some(local_interface); + } else { + decode_error!("invalid pktinfo_v6 value"); + } + } + (level, ty) if features::gso::is_match(level, ty) => { + // ignore GSO settings when reading + } + (level, ty) if features::gro::is_match(level, ty) => { + if let Some(segment_size) = value_from_bytes::(value) { + data.segment_size = segment_size as _; + } else { + decode_error!("invalid gro value"); + } + } + _ => { + decode_error!("unexpected cmsghdr"); + } + } +} + +pub struct Iter<'a> { + cursor: *const u8, + len: usize, + contents: core::marker::PhantomData<&'a [u8]>, +} + +impl<'a> Iter<'a> { + /// Creates a new cmsg::Iter used for iterating over control message headers in the given + /// [`Storage`]. + #[inline] + pub fn new(contents: &'a Storage) -> Iter<'a> { + let cursor = contents.as_ptr(); + let len = contents.len(); + + Self { + cursor, + len, + contents: Default::default(), + } + } + + /// Creates a new cmsg::Iter used for iterating over control message headers in the given slice + /// of bytes. + /// + /// # Safety + /// + /// * `contents` must be aligned to cmsghdr + #[inline] + pub unsafe fn from_bytes(contents: &'a [u8]) -> Self { + let cursor = contents.as_ptr(); + let len = contents.len(); + + debug_assert_eq!( + cursor.align_offset(align_of::()), + 0, + "contents must be aligned to cmsghdr" + ); + + Self { + cursor, + len, + contents: Default::default(), + } + } + + /// Creates a new cmsg::Iter used for iterating over control message headers in the given + /// msghdr. + /// + /// # Safety + /// + /// * `contents` must be aligned to cmsghdr + /// * `msghdr` must point to a valid control buffer + #[inline] + pub unsafe fn from_msghdr(msghdr: &'a libc::msghdr) -> Self { + let cursor = msghdr.msg_control as *const u8; + let len = msghdr.msg_controllen as usize; + + debug_assert_eq!( + cursor.align_offset(align_of::()), + 0, + "contents must be aligned to cmsghdr" + ); + + Self { + cursor, + len, + contents: Default::default(), + } + } + + #[inline] + pub fn collect(self) -> AncillaryData { + collect(self) + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = (&'a cmsghdr, &'a [u8]); + + #[inline] + fn next(&mut self) -> Option { + unsafe { + let cursor = self.cursor; + + // make sure we can decode a cmsghdr + self.len.checked_sub(size_of::())?; + let cmsg = &*(cursor as *const cmsghdr); + let data_ptr = cursor.add(size_of::()); + + let cmsg_len = cmsg.cmsg_len as usize; + + // make sure we have capacity to decode the provided cmsg_len + self.len.checked_sub(cmsg_len)?; + + // the cmsg_len includes the header itself so it needs to be subtracted off + let data_len = cmsg_len.checked_sub(size_of::())?; + // construct a slice with the provided data len + let data = core::slice::from_raw_parts(data_ptr, data_len); + + // empty messages are invalid + if data.is_empty() { + return None; + } + + // calculate the next message and update the cursor/len + { + let space = libc::CMSG_SPACE(data_len as _) as usize; + debug_assert!( + space >= data_len, + "space ({space}) should be at least of size len ({data_len})" + ); + self.len = self.len.saturating_sub(space); + self.cursor = cursor.add(space); + } + + Some((cmsg, data)) + } + } +} diff --git a/quic/s2n-quic-platform/src/message/cmsg/encode.rs b/quic/s2n-quic-platform/src/message/cmsg/encode.rs new file mode 100644 index 0000000000..da3dd8ba42 --- /dev/null +++ b/quic/s2n-quic-platform/src/message/cmsg/encode.rs @@ -0,0 +1,97 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::features; +use s2n_quic_core::inet::{ExplicitCongestionNotification, SocketAddress}; + +#[derive(Clone, Copy, Debug)] +pub struct Error; + +pub trait Encoder { + /// Encodes the given value as a control message in the cmsg buffer. + /// + /// The msghdr.msg_control should be zero-initialized and aligned and contain enough + /// room for the value to be written. + fn encode_cmsg( + &mut self, + level: libc::c_int, + ty: libc::c_int, + value: T, + ) -> Result; + + /// Encodes ECN markings into the cmsg encoder + #[inline] + fn encode_ecn( + &mut self, + ecn: ExplicitCongestionNotification, + remote_address: &SocketAddress, + ) -> Result { + // no need to encode for the default case + if ecn == ExplicitCongestionNotification::NotEct { + return Ok(0); + } + + // the remote address needs to be unmapped in order to set the appropriate cmsg + match remote_address.unmap() { + SocketAddress::IpV4(_) => { + if let (Some(level), Some(ty)) = (features::tos_v4::LEVEL, features::tos_v4::TYPE) { + return self.encode_cmsg(level, ty, ecn as u8 as features::tos_v4::Cmsg); + } + } + SocketAddress::IpV6(_) => { + if let (Some(level), Some(ty)) = (features::tos_v6::LEVEL, features::tos_v6::TYPE) { + return self.encode_cmsg(level, ty, ecn as u8 as features::tos_v6::Cmsg); + } + } + } + + Ok(0) + } + + /// Encodes GSO segment_size into the cmsg encoder + #[inline] + fn encode_gso(&mut self, segment_size: u16) -> Result { + if let (Some(level), Some(ty)) = (features::gso::LEVEL, features::gso::TYPE) { + let segment_size = segment_size as features::gso::Cmsg; + self.encode_cmsg(level, ty, segment_size) + } else { + panic!("platform does not support GSO"); + } + } + + #[inline] + fn encode_local_address(&mut self, address: &SocketAddress) -> Result { + use s2n_quic_core::inet::Unspecified; + + match address { + SocketAddress::IpV4(addr) => { + use features::pktinfo_v4 as pktinfo; + if let (Some(level), Some(ty)) = (pktinfo::LEVEL, pktinfo::TYPE) { + let ip = addr.ip(); + + if ip.is_unspecified() { + return Ok(0); + } + + let value = pktinfo::encode(ip, None); + return self.encode_cmsg(level, ty, value); + } + } + SocketAddress::IpV6(addr) => { + use features::pktinfo_v6 as pktinfo; + if let (Some(level), Some(ty)) = (pktinfo::LEVEL, pktinfo::TYPE) { + let ip = addr.ip(); + + if ip.is_unspecified() { + return Ok(0); + } + + let value = pktinfo::encode(ip, None); + return self.encode_cmsg(level, ty, value); + } + } + } + + Ok(0) + } +} diff --git a/quic/s2n-quic-platform/src/message/cmsg/storage.rs b/quic/s2n-quic-platform/src/message/cmsg/storage.rs new file mode 100644 index 0000000000..59803d9e59 --- /dev/null +++ b/quic/s2n-quic-platform/src/message/cmsg/storage.rs @@ -0,0 +1,185 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::{encode, size_of_cmsg}; +use core::{ + mem::{align_of, size_of}, + ops::{Deref, DerefMut}, +}; +use libc::cmsghdr; + +#[repr(align(8))] // the storage needs to be aligned to the same as `cmsghdr` +#[derive(Clone, Debug)] +pub struct Storage([u8; L]); + +impl Storage { + #[inline] + pub fn encoder(&mut self) -> Encoder { + Encoder { + storage: self, + cursor: 0, + } + } + + #[inline] + pub fn iter(&self) -> super::decode::Iter { + super::decode::Iter::new(self) + } +} + +impl Default for Storage { + #[inline] + fn default() -> Self { + Self([0; L]) + } +} + +impl Deref for Storage { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + &self.0 + } +} + +impl DerefMut for Storage { + #[inline] + fn deref_mut(&mut self) -> &mut [u8] { + &mut self.0 + } +} + +pub struct Encoder<'a, const L: usize> { + storage: &'a mut Storage, + cursor: usize, +} + +impl<'a, const L: usize> Encoder<'a, L> { + #[inline] + pub fn new(storage: &'a mut Storage) -> Self { + Self { storage, cursor: 0 } + } + + #[inline] + pub fn len(&self) -> usize { + self.cursor + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.cursor == 0 + } + + #[inline] + pub fn seek(&mut self, len: usize) { + self.cursor += len; + debug_assert!(self.cursor <= L); + } + + #[inline] + pub fn iter(&self) -> super::decode::Iter { + unsafe { + // SAFETY: bytes are aligned with Storage type + super::decode::Iter::from_bytes(self) + } + } +} + +impl<'a, const L: usize> Deref for Encoder<'a, L> { + type Target = [u8]; + + #[inline] + fn deref(&self) -> &[u8] { + &self.storage[..self.cursor] + } +} + +impl<'a, const L: usize> DerefMut for Encoder<'a, L> { + #[inline] + fn deref_mut(&mut self) -> &mut [u8] { + &mut self.storage[..self.cursor] + } +} + +impl<'a, const L: usize> super::Encoder for Encoder<'a, L> { + #[inline] + fn encode_cmsg( + &mut self, + level: libc::c_int, + ty: libc::c_int, + value: T, + ) -> Result { + unsafe { + debug_assert!( + align_of::() <= align_of::(), + "alignment of T should be less than or equal to cmsghdr" + ); + + // CMSG_SPACE() returns the number of bytes an ancillary element + // with payload of the passed data length occupies. + let element_len = size_of_cmsg::(); + debug_assert_ne!(element_len, 0); + debug_assert_eq!(libc::CMSG_SPACE(size_of::() as _) as usize, element_len); + + let new_cursor = self.cursor.checked_add(element_len).ok_or(encode::Error)?; + + self.storage + .len() + .checked_sub(new_cursor) + .ok_or(encode::Error)?; + + let cmsg_ptr = { + // Safety: the msg_control buffer should always be allocated to MAX_LEN + let msg_controllen = self.cursor; + let msg_control = self.storage.as_mut_ptr().add(msg_controllen as _); + msg_control as *mut cmsghdr + }; + + { + let cmsg = &mut *cmsg_ptr; + + // interpret the start of cmsg as a cmsghdr + // Safety: the cmsg slice should already be zero-initialized and aligned + + // Indicate the type of cmsg + cmsg.cmsg_level = level; + cmsg.cmsg_type = ty; + + // CMSG_LEN() returns the value to store in the cmsg_len member + // of the cmsghdr structure, taking into account any necessary + // alignment. It takes the data length as an argument. + cmsg.cmsg_len = libc::CMSG_LEN(size_of::() as _) as _; + } + + { + // Write the actual value in the data space of the cmsg + // Safety: we asserted we had enough space in the cmsg buffer above + // CMSG_DATA() returns a pointer to the data portion of a + // cmsghdr. The pointer returned cannot be assumed to be + // suitably aligned for accessing arbitrary payload data types. + // Applications should not cast it to a pointer type matching the + // payload, but should instead use memcpy(3) to copy data to or + // from a suitably declared object. + let data_ptr = cmsg_ptr.add(1); + + debug_assert_eq!(data_ptr as *mut u8, libc::CMSG_DATA(cmsg_ptr) as *mut u8); + + core::ptr::copy_nonoverlapping( + &value as *const T as *const u8, + data_ptr as *mut u8, + size_of::(), + ); + } + + // add the values as a usize to make sure we work cross-platform + self.cursor = new_cursor; + debug_assert!( + self.cursor <= self.storage.len(), + "msg should not exceed max allocated" + ); + + Ok(self.cursor) + } + } +} diff --git a/quic/s2n-quic-platform/src/message/cmsg/tests.rs b/quic/s2n-quic-platform/src/message/cmsg/tests.rs new file mode 100644 index 0000000000..3e920a3e25 --- /dev/null +++ b/quic/s2n-quic-platform/src/message/cmsg/tests.rs @@ -0,0 +1,123 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use bolero::{check, TypeGenerator}; +use core::mem::align_of; +use libc::c_int; + +#[inline] +fn aligned_iter(bytes: &[u8], f: impl FnOnce(decode::Iter)) { + // the bytes needs to be aligned to a cmsghdr + let offset = bytes.as_ptr().align_offset(align_of::()); + + if let Some(bytes) = bytes.get(offset..) { + let iter = unsafe { + // SAFETY: bytes are aligned above + decode::Iter::from_bytes(bytes) + }; + + f(iter) + } +} + +/// Ensures the cmsg iterator doesn't crash or segfault +#[test] +#[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(17))] +fn iter_test() { + check!().for_each(|bytes| { + aligned_iter(bytes, |iter| { + for (cmsghdr, value) in iter { + let _ = cmsghdr; + let _ = value; + } + }) + }); +} + +/// Ensures the `decode::Iter::collect` doesn't crash or segfault +#[test] +#[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(17))] +fn collect_test() { + check!().for_each(|bytes| { + aligned_iter(bytes, |iter| { + let _ = iter.collect(); + }) + }); +} + +#[derive(Clone, Copy, Debug, TypeGenerator)] +struct Op { + level: c_int, + ty: c_int, + value: Value, +} + +#[derive(Clone, Copy, Debug, TypeGenerator)] +enum Value { + U8(u8), + U16(u16), + U32(u32), + // alignment can't exceed that of cmsghdr + U64([u32; 2]), + U128([u32; 4]), +} + +impl Value { + fn check_value(&self, bytes: &[u8]) { + let expected_len = match self { + Self::U8(_) => 1, + Self::U16(_) => 2, + Self::U32(_) => 4, + Self::U64(_) => 8, + Self::U128(_) => 16, + }; + assert_eq!(expected_len, bytes.len()); + } +} + +fn round_trip(ops: &[Op]) { + let mut storage = Storage::<32>::default(); + let mut encoder = storage.encoder(); + + let mut expected_encoded_count = 0; + + for op in ops { + let res = match op.value { + Value::U8(value) => encoder.encode_cmsg(op.level, op.ty, value), + Value::U16(value) => encoder.encode_cmsg(op.level, op.ty, value), + Value::U32(value) => encoder.encode_cmsg(op.level, op.ty, value), + Value::U64(value) => encoder.encode_cmsg(op.level, op.ty, value), + Value::U128(value) => encoder.encode_cmsg(op.level, op.ty, value), + }; + + match res { + Ok(_) => expected_encoded_count += 1, + Err(_) => break, + } + } + + let mut actual_decoded_count = 0; + let mut iter = encoder.iter(); + + for (op, (cmsghdr, value)) in ops.iter().zip(&mut iter) { + assert_eq!(op.level, cmsghdr.cmsg_level); + assert_eq!(op.ty, cmsghdr.cmsg_type); + op.value.check_value(value); + actual_decoded_count += 1; + } + + assert_eq!(expected_encoded_count, actual_decoded_count); + assert!(iter.next().is_none()); +} + +#[cfg(not(kani))] +type Ops = Vec; +#[cfg(kani)] +type Ops = s2n_quic_core::testing::InlineVec; + +#[test] +#[cfg_attr(kani, kani::proof, kani::solver(cadical), kani::unwind(9))] +fn round_trip_test() { + check!().with_type::().for_each(|ops| round_trip(ops)); +} diff --git a/quic/s2n-quic-platform/src/message/msg.rs b/quic/s2n-quic-platform/src/message/msg.rs index fb94cb291f..ba19feda10 100644 --- a/quic/s2n-quic-platform/src/message/msg.rs +++ b/quic/s2n-quic-platform/src/message/msg.rs @@ -12,8 +12,8 @@ use core::{ use libc::{iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6}; use s2n_quic_core::{ inet::{ - datagram, ExplicitCongestionNotification, IpV4Address, IpV6Address, SocketAddress, - SocketAddressV4, SocketAddressV6, + datagram, IpV4Address, IpV6Address, SocketAddress, SocketAddressV4, SocketAddressV6, + Unspecified, }, io::tx, path::{self, Handle as _}, @@ -35,7 +35,7 @@ impl MessageTrait for msghdr { type Handle = Handle; const SUPPORTS_GSO: bool = features::gso::IS_SUPPORTED; - const SUPPORTS_ECN: bool = cfg!(s2n_quic_platform_tos); + const SUPPORTS_ECN: bool = features::tos::IS_SUPPORTED; const SUPPORTS_FLOW_LABELS: bool = true; #[inline] @@ -60,10 +60,8 @@ impl MessageTrait for msghdr { #[inline] fn set_segment_size(&mut self, size: usize) { - let level = features::gso::LEVEL.expect("gso is unsupported"); - let ty = features::gso::TYPE.expect("gso is unsupported"); - self.encode_cmsg(level, ty, size as features::gso::Cmsg) - .unwrap(); + debug_assert!(size <= u16::MAX as usize); + self.cmsg_encoder().encode_gso(size as _).unwrap(); } #[inline] @@ -145,7 +143,7 @@ impl MessageTrait for msghdr { let (mut header, cmsg) = self.header()?; // only copy the port if we are told the IP address - if cfg!(s2n_quic_platform_pktinfo) { + if !header.path.local_address.ip().is_unspecified() { header.path.local_address.set_port(local_address.port()); } else { header.path.local_address = *local_address; @@ -195,7 +193,9 @@ impl MessageTrait for msghdr { let handle = *message.path_handle(); handle.update_msg_hdr(self); - self.set_ecn(message.ecn(), &handle.remote_address.0); + self.cmsg_encoder() + .encode_ecn(message.ecn(), &handle.remote_address.0) + .unwrap(); Ok(len) } @@ -286,7 +286,7 @@ fn layout( struct Header { pub iovec: iovec, pub msg_name: sockaddr_in6, - pub cmsg: cmsg::Storage, + pub cmsg: cmsg::Storage<{ cmsg::MAX_LEN }>, } impl Header { @@ -310,7 +310,7 @@ impl Header { debug_assert_eq!( entry .msg_control - .align_offset(core::mem::align_of::()), + .align_offset(core::mem::align_of::>()), 0 ); } diff --git a/quic/s2n-quic-platform/src/message/msg/ext.rs b/quic/s2n-quic-platform/src/message/msg/ext.rs index 009bd57c8e..ffc54061a5 100644 --- a/quic/s2n-quic-platform/src/message/msg/ext.rs +++ b/quic/s2n-quic-platform/src/message/msg/ext.rs @@ -3,20 +3,26 @@ use super::*; -pub trait Ext: cmsg::Encoder { +pub trait Ext { + type Encoder<'a>: cmsg::Encoder + where + Self: 'a; + fn header(&self) -> Option<(datagram::Header, datagram::AncillaryData)>; - fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress); + fn cmsg_encoder(&mut self) -> Self::Encoder<'_>; fn remote_address(&self) -> Option; fn set_remote_address(&mut self, remote_address: &SocketAddress); } impl Ext for msghdr { + type Encoder<'a> = MsghdrEncoder<'a>; + #[inline] fn header(&self) -> Option<(datagram::Header, datagram::AncillaryData)> { let addr = self.remote_address()?; let mut path = Handle::from_remote_address(addr.into()); - let ancillary_data = cmsg::decode(self); + let ancillary_data = unsafe { cmsg::decode::Iter::from_msghdr(self) }.collect(); let ecn = ancillary_data.ecn; path.with_ancillary_data(ancillary_data); @@ -27,26 +33,8 @@ impl Ext for msghdr { } #[inline] - fn set_ecn(&mut self, ecn: ExplicitCongestionNotification, remote_address: &SocketAddress) { - if ecn == ExplicitCongestionNotification::NotEct { - return; - } - - // the remote address needs to be unmapped in order to set the appropriate cmsg - match remote_address.unmap() { - SocketAddress::IpV4(_) => { - use features::tos_v4 as tos; - if let (Some(level), Some(ty)) = (tos::LEVEL, tos::TYPE) { - self.encode_cmsg(level, ty, ecn as tos::Cmsg).unwrap(); - } - } - SocketAddress::IpV6(_) => { - use features::tos_v6 as tos; - if let (Some(level), Some(ty)) = (tos::LEVEL, tos::TYPE) { - self.encode_cmsg(level, ty, ecn as tos::Cmsg).unwrap(); - } - } - }; + fn cmsg_encoder(&mut self) -> Self::Encoder<'_> { + MsghdrEncoder { msghdr: self } } #[inline] @@ -91,3 +79,30 @@ impl Ext for msghdr { } } } + +pub struct MsghdrEncoder<'a> { + msghdr: &'a mut msghdr, +} + +impl<'a> Encoder for MsghdrEncoder<'a> { + #[inline] + fn encode_cmsg( + &mut self, + level: libc::c_int, + ty: libc::c_int, + value: T, + ) -> Result { + let storage = + unsafe { &mut *(self.msghdr.msg_control as *mut cmsg::Storage<{ cmsg::MAX_LEN }>) }; + + let mut encoder = storage.encoder(); + encoder.seek(self.msghdr.msg_controllen as _); + + let msg_len = encoder.encode_cmsg(level, ty, value)?; + + // update the cursor + self.msghdr.msg_controllen = encoder.len() as _; + + Ok(msg_len) + } +} diff --git a/quic/s2n-quic-platform/src/message/msg/handle.rs b/quic/s2n-quic-platform/src/message/msg/handle.rs index 26851eebc2..66a0ce7922 100644 --- a/quic/s2n-quic-platform/src/message/msg/handle.rs +++ b/quic/s2n-quic-platform/src/message/msg/handle.rs @@ -2,10 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 use super::ext::Ext as _; -use crate::message::cmsg::Encoder; +use crate::{features, message::cmsg::Encoder}; use libc::msghdr; use s2n_quic_core::{ - inet::{AncillaryData, SocketAddress, SocketAddressV4}, + inet::{AncillaryData, SocketAddressV4}, path::{self, LocalAddress, RemoteAddress}, }; @@ -32,42 +32,10 @@ impl Handle { msghdr.set_remote_address(&self.remote_address.0); - #[cfg(s2n_quic_platform_pktinfo)] - match self.local_address.0 { - SocketAddress::IpV4(addr) => { - use s2n_quic_core::inet::Unspecified; - - let ip = addr.ip(); - - if ip.is_unspecified() { - return; - } - - let mut pkt_info = unsafe { core::mem::zeroed::() }; - pkt_info.ipi_spec_dst.s_addr = u32::from_ne_bytes((*ip).into()); - - msghdr - .encode_cmsg(libc::IPPROTO_IP, libc::IP_PKTINFO, pkt_info) - .unwrap(); - } - SocketAddress::IpV6(addr) => { - use s2n_quic_core::inet::Unspecified; - - let ip = addr.ip(); - - if ip.is_unspecified() { - return; - } - - let mut pkt_info = unsafe { core::mem::zeroed::() }; - - pkt_info.ipi6_addr.s6_addr = (*ip).into(); - - msghdr - .encode_cmsg(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pkt_info) - .unwrap(); - } - } + msghdr + .cmsg_encoder() + .encode_local_address(&self.local_address.0) + .unwrap(); } } @@ -100,7 +68,7 @@ impl path::Handle for Handle { let mut eq = true; // only compare local addresses if the OS returns them - if cfg!(s2n_quic_platform_pktinfo) { + if features::pktinfo::IS_SUPPORTED { eq &= self.local_address.eq(&other.local_address); } diff --git a/quic/s2n-quic-platform/src/message/msg/tests.rs b/quic/s2n-quic-platform/src/message/msg/tests.rs index 6d7de4f9f8..06b48149ab 100644 --- a/quic/s2n-quic-platform/src/message/msg/tests.rs +++ b/quic/s2n-quic-platform/src/message/msg/tests.rs @@ -23,7 +23,7 @@ fn test_msghdr(f: F) { msghdr.msg_iov = &mut iovec; - let mut msg_control = cmsg::Storage::default(); + let mut msg_control = >::default(); msghdr.msg_controllen = msg_control.len() as _; msghdr.msg_control = msg_control.as_mut_ptr() as *mut _; @@ -39,7 +39,7 @@ fn test_msghdr(f: F) { mod stubs { use s2n_quic_core::inet::AncillaryData; - pub fn decode(_msghdr: &libc::msghdr) -> AncillaryData { + pub fn collect(_iter: crate::message::cmsg::decode::Iter) -> AncillaryData { let ancillary_data = kani::any(); ancillary_data @@ -68,7 +68,7 @@ fn address_inverse_pair_test() { kani::solver(cadical), kani::unwind(65), // it's safe to stub out cmsg::decode since the cmsg result isn't actually checked in this particular test - kani::stub(cmsg::decode, stubs::decode) + kani::stub(cmsg::decode::collect, stubs::collect) )] fn handle_get_set_test() { check!() @@ -92,7 +92,7 @@ fn handle_get_set_test() { // no need to check this on kani since we abstract the decode() function to avoid performance issues #[cfg(not(kani))] { - if cfg!(s2n_quic_platform_pktinfo) + if features::pktinfo::IS_SUPPORTED && !handle.local_address.ip().is_unspecified() { assert_eq!(header.path.local_address.ip(), handle.local_address.ip()); diff --git a/quic/s2n-quic-platform/src/syscall.rs b/quic/s2n-quic-platform/src/syscall.rs index 678534b9d0..2e1cededc3 100644 --- a/quic/s2n-quic-platform/src/syscall.rs +++ b/quic/s2n-quic-platform/src/syscall.rs @@ -207,28 +207,32 @@ pub fn configure_pktinfo(rx_socket: &Socket) -> bool { let mut success = false; // Set up the RX socket to pass information about the local address and interface - #[cfg(s2n_quic_platform_pktinfo)] + #[cfg(unix)] { use std::os::unix::io::AsRawFd; let enabled: libc::c_int = 1; - success |= libc!(setsockopt( - rx_socket.as_raw_fd(), - libc::IPPROTO_IP, - libc::IP_PKTINFO, - &enabled as *const _ as _, - core::mem::size_of_val(&enabled) as _, - )) - .is_ok(); + if let Some((level, ty)) = crate::features::pktinfo_v4::SOCKOPT { + success |= libc!(setsockopt( + rx_socket.as_raw_fd(), + level, + ty, + &enabled as *const _ as _, + core::mem::size_of_val(&enabled) as _, + )) + .is_ok(); + } - success |= libc!(setsockopt( - rx_socket.as_raw_fd(), - libc::IPPROTO_IPV6, - libc::IPV6_RECVPKTINFO, - &enabled as *const _ as _, - core::mem::size_of_val(&enabled) as _, - )) - .is_ok(); + if let Some((level, ty)) = crate::features::pktinfo_v6::SOCKOPT { + success |= libc!(setsockopt( + rx_socket.as_raw_fd(), + level, + ty, + &enabled as *const _ as _, + core::mem::size_of_val(&enabled) as _, + )) + .is_ok(); + } } success