From 343adcc667ede44bd4c0dfceaea86822c5a9176e Mon Sep 17 00:00:00 2001 From: Boquan Fang Date: Thu, 6 Feb 2025 22:41:57 +0000 Subject: [PATCH] fix: allow configurable IPv4/IPv6 socket binding --- quic/s2n-quic-platform/src/io/tokio.rs | 5 ++- .../s2n-quic-platform/src/io/tokio/builder.rs | 6 +++ quic/s2n-quic-platform/src/io/tokio/tests.rs | 40 +++++++++++++++---- quic/s2n-quic-platform/src/io/xdp.rs | 3 +- quic/s2n-quic-platform/src/socket/options.rs | 4 +- quic/s2n-quic-platform/src/syscall.rs | 8 ++-- 6 files changed, 51 insertions(+), 15 deletions(-) diff --git a/quic/s2n-quic-platform/src/io/tokio.rs b/quic/s2n-quic-platform/src/io/tokio.rs index a1fbf8de67..c74162c81b 100644 --- a/quic/s2n-quic-platform/src/io/tokio.rs +++ b/quic/s2n-quic-platform/src/io/tokio.rs @@ -59,6 +59,7 @@ impl Io { gro_enabled, reuse_address, reuse_port, + only_v6, } = self.builder; let clock = Clock::default(); @@ -91,7 +92,7 @@ impl Io { let rx_socket = if let Some(rx_socket) = rx_socket { rx_socket } else if let Some(recv_addr) = recv_addr { - syscall::bind_udp(recv_addr, reuse_address, reuse_port)? + syscall::bind_udp(recv_addr, reuse_address, reuse_port, only_v6)? } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -104,7 +105,7 @@ impl Io { let tx_socket = if let Some(tx_socket) = tx_socket { tx_socket } else if let Some(send_addr) = send_addr { - syscall::bind_udp(send_addr, reuse_address, reuse_port)? + syscall::bind_udp(send_addr, reuse_address, reuse_port, only_v6)? } else { // No tx_socket or send address was specified, so the tx socket // will be a handle to the rx socket. diff --git a/quic/s2n-quic-platform/src/io/tokio/builder.rs b/quic/s2n-quic-platform/src/io/tokio/builder.rs index 4e6621bb4b..260ec38ad9 100644 --- a/quic/s2n-quic-platform/src/io/tokio/builder.rs +++ b/quic/s2n-quic-platform/src/io/tokio/builder.rs @@ -19,6 +19,7 @@ pub struct Builder { pub(super) gro_enabled: Option, pub(super) reuse_address: bool, pub(super) reuse_port: bool, + pub(super) only_v6: bool, } impl Builder { @@ -236,6 +237,11 @@ impl Builder { Ok(self) } + pub fn with_only_v6(mut self, only_v6: bool) -> io::Result { + self.only_v6 = only_v6; + Ok(self) + } + pub fn build(self) -> io::Result { Ok(Io { builder: self }) } diff --git a/quic/s2n-quic-platform/src/io/tokio/tests.rs b/quic/s2n-quic-platform/src/io/tokio/tests.rs index 21b0a6ab26..326c2ed887 100644 --- a/quic/s2n-quic-platform/src/io/tokio/tests.rs +++ b/quic/s2n-quic-platform/src/io/tokio/tests.rs @@ -141,16 +141,19 @@ impl Endpoint for TestEndpoint { async fn runtime( receive_addr: A, send_addr: Option, + only_v6: bool, ) -> io::Result<(super::Io, SocketAddress)> { - let rx_socket = syscall::bind_udp(receive_addr, false, false)?; + let mut io_builder = Io::builder().with_only_v6(only_v6)?; + + let rx_socket = syscall::bind_udp(receive_addr, false, false, only_v6)?; rx_socket.set_nonblocking(true)?; let rx_socket: std::net::UdpSocket = rx_socket.into(); let rx_addr = rx_socket.local_addr()?; - let mut io_builder = Io::builder().with_rx_socket(rx_socket)?; + io_builder = io_builder.with_rx_socket(rx_socket)?; if let Some(tx_addr) = send_addr { - let tx_socket = syscall::bind_udp(tx_addr, false, false)?; + let tx_socket = syscall::bind_udp(tx_addr, false, false, only_v6)?; tx_socket.set_nonblocking(true)?; let tx_socket: std::net::UdpSocket = tx_socket.into(); io_builder = io_builder.with_tx_socket(tx_socket)? @@ -177,9 +180,10 @@ async fn test( server_tx_addr: Option, client_rx_addr: A, client_tx_addr: Option, + only_v6: bool, ) -> io::Result<()> { - let (server_io, server_addr) = runtime(server_rx_addr, server_tx_addr).await?; - let (client_io, client_addr) = runtime(client_rx_addr, client_tx_addr).await?; + let (server_io, server_addr) = runtime(server_rx_addr, server_tx_addr, only_v6).await?; + let (client_io, client_addr) = runtime(client_rx_addr, client_tx_addr, only_v6).await?; let server_endpoint = { let mut handle = PathHandle::from_remote_address(client_addr.into()); @@ -212,17 +216,20 @@ static IPV6_LOCALHOST: &str = "[::1]:0"; #[tokio::test] #[cfg_attr(miri, ignore)] async fn ipv4_test() -> io::Result<()> { - test(IPV4_LOCALHOST, None, IPV4_LOCALHOST, None).await + let only_v6: bool = false; + test(IPV4_LOCALHOST, None, IPV4_LOCALHOST, None, only_v6).await } #[tokio::test] #[cfg_attr(miri, ignore)] async fn ipv4_two_socket_test() -> io::Result<()> { + let only_v6: bool = false; test( IPV4_LOCALHOST, Some(IPV4_LOCALHOST), IPV4_LOCALHOST, Some(IPV4_LOCALHOST), + only_v6, ) .await } @@ -230,7 +237,8 @@ async fn ipv4_two_socket_test() -> io::Result<()> { #[tokio::test] #[cfg_attr(miri, ignore)] async fn ipv6_test() -> io::Result<()> { - let result = test(IPV6_LOCALHOST, None, IPV6_LOCALHOST, None).await; + let only_v6: bool = false; + let result = test(IPV6_LOCALHOST, None, IPV6_LOCALHOST, None, only_v6).await; match result { Err(err) if err.kind() == io::ErrorKind::AddrNotAvailable => { @@ -244,11 +252,13 @@ async fn ipv6_test() -> io::Result<()> { #[tokio::test] #[cfg_attr(miri, ignore)] async fn ipv6_two_socket_test() -> io::Result<()> { + let only_v6: bool = false; let result = test( IPV6_LOCALHOST, Some(IPV6_LOCALHOST), IPV6_LOCALHOST, Some(IPV6_LOCALHOST), + only_v6, ) .await; @@ -260,3 +270,19 @@ async fn ipv6_two_socket_test() -> io::Result<()> { other => other, } } + +#[tokio::test] +#[cfg_attr(miri, ignore)] +async fn only_v6_test() -> io::Result<()> { + let mut only_v6 = true; + + let socket = syscall::bind_udp(IPV6_LOCALHOST, false, false, only_v6)?; + assert_eq!(socket.only_v6()?, only_v6); + + only_v6 = false; + + let socket = syscall::bind_udp(IPV6_LOCALHOST, false, false, only_v6)?; + assert_eq!(socket.only_v6()?, only_v6); + + Ok(()) +} diff --git a/quic/s2n-quic-platform/src/io/xdp.rs b/quic/s2n-quic-platform/src/io/xdp.rs index 7394b5eadf..69935e4a06 100644 --- a/quic/s2n-quic-platform/src/io/xdp.rs +++ b/quic/s2n-quic-platform/src/io/xdp.rs @@ -28,7 +28,8 @@ pub mod socket { interface: &::std::ffi::CStr, addr: ::std::net::SocketAddr, ) -> ::std::io::Result<::std::net::UdpSocket> { - let socket = crate::syscall::udp_socket(addr)?; + let only_v6 = false; + let socket = crate::syscall::udp_socket(addr, only_v6)?; // associate the socket with a single interface crate::syscall::bind_to_interface(&socket, interface)?; diff --git a/quic/s2n-quic-platform/src/socket/options.rs b/quic/s2n-quic-platform/src/socket/options.rs index b06e2a390b..6b54c7dba3 100644 --- a/quic/s2n-quic-platform/src/socket/options.rs +++ b/quic/s2n-quic-platform/src/socket/options.rs @@ -32,6 +32,7 @@ pub struct Options { pub send_buffer: Option, pub recv_buffer: Option, pub backlog: usize, + pub only_v6: bool, } impl Default for Options { @@ -47,6 +48,7 @@ impl Default for Options { recv_buffer: None, delay: false, backlog: 4096, + only_v6: false, } } } @@ -62,7 +64,7 @@ impl Options { #[inline] pub fn build_udp(&self) -> io::Result { - let socket = syscall::udp_socket(self.addr)?; + let socket = syscall::udp_socket(self.addr, self.only_v6)?; if self.gro { let _ = syscall::configure_gro(&socket); diff --git a/quic/s2n-quic-platform/src/syscall.rs b/quic/s2n-quic-platform/src/syscall.rs index 346d736539..3ea01c6cbe 100644 --- a/quic/s2n-quic-platform/src/syscall.rs +++ b/quic/s2n-quic-platform/src/syscall.rs @@ -66,15 +66,14 @@ pub trait UnixMessage: crate::message::Message { ); } -pub fn udp_socket(addr: std::net::SocketAddr) -> io::Result { +pub fn udp_socket(addr: std::net::SocketAddr, only_v6: bool) -> io::Result { let domain = Domain::for_address(addr); let socket_type = Type::DGRAM; let protocol = Some(Protocol::UDP); let socket = Socket::new(domain, socket_type, protocol)?; - // allow ipv4 to also connect - ignore the error if it fails - let _ = socket.set_only_v6(false); + let _ = socket.set_only_v6(only_v6); Ok(socket) } @@ -84,6 +83,7 @@ pub fn bind_udp( addr: A, reuse_address: bool, reuse_port: bool, + only_v6: bool, ) -> io::Result { let addr = addr.to_socket_addrs()?.next().ok_or_else(|| { std::io::Error::new( @@ -91,7 +91,7 @@ pub fn bind_udp( "the provided bind address was empty", ) })?; - let socket = udp_socket(addr)?; + let socket = udp_socket(addr, only_v6)?; socket.set_reuse_address(reuse_address)?;