From 0a9313cfd1cf3a1e3df15feea87ac724d774f113 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Wed, 31 Jul 2024 20:44:59 +0800 Subject: [PATCH] feat(quic): basic endpoint & connection --- Cargo.toml | 2 + compio-quic/Cargo.toml | 51 +++ compio-quic/src/connection.rs | 340 +++++++++++++++ compio-quic/src/endpoint.rs | 349 +++++++++++++++ compio-quic/src/lib.rs | 13 + compio-quic/src/socket.rs | 775 ++++++++++++++++++++++++++++++++++ compio-tls/Cargo.toml | 4 +- 7 files changed, 1532 insertions(+), 2 deletions(-) create mode 100644 compio-quic/Cargo.toml create mode 100644 compio-quic/src/connection.rs create mode 100644 compio-quic/src/endpoint.rs create mode 100644 compio-quic/src/lib.rs create mode 100644 compio-quic/src/socket.rs diff --git a/Cargo.toml b/Cargo.toml index 02a44d46..0a00f90f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "compio-tls", "compio-log", "compio-process", + "compio-quic", ] resolver = "2" @@ -49,6 +50,7 @@ nix = "0.29.0" once_cell = "1.18.0" os_pipe = "1.1.4" paste = "1.0.14" +rustls = { version = "0.23.1", default-features = false } slab = "0.4.9" socket2 = "0.5.6" tempfile = "3.8.1" diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml new file mode 100644 index 00000000..b8ad45a0 --- /dev/null +++ b/compio-quic/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "compio-quic" +version = "0.1.0" +description = "QUIC for compio" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "net", "quic"] +edition = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +repository = { workspace = true } + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +# Workspace dependencies +compio-buf = { workspace = true, features = ["bytes"] } +compio-net = { workspace = true, features = ["io-uring"] } +compio-runtime = { workspace = true, features = ["event", "time"] } + +rustls = { 
workspace = true, default-features = false, features = [ + "std", + "logging", + "ring", +] } +quinn-proto = "0.11.3" + +# Utils +flume = { workspace = true } +futures-util = { workspace = true } + +# Windows specific dependencies +[target.'cfg(windows)'.dependencies] +windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } + +# Linux specific dependencies +[target.'cfg(target_os = "linux")'.dependencies] + +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } + +[dev-dependencies] +compio-driver = { workspace = true } +compio-macros = { workspace = true } +socket2 = { workspace = true, features = ["all"] } +tracing-subscriber = "0.3.18" + +[features] +default = [] diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs new file mode 100644 index 00000000..3f92c83f --- /dev/null +++ b/compio-quic/src/connection.rs @@ -0,0 +1,340 @@ +use std::{ + cell::RefCell, + io, + net::{IpAddr, SocketAddr}, + pin::{pin, Pin}, + rc::Rc, + task::{Context, Poll}, + time::Instant, +}; + +use compio_buf::BufResult; +use compio_runtime::JoinHandle; +use flume::{Receiver, Sender}; +use futures_util::{ + future::{poll_fn, Fuse, FusedFuture, Future, LocalBoxFuture}, + select, + task::AtomicWaker, + FutureExt, +}; +use quinn_proto::{ + crypto::rustls::HandshakeData, ConnectionError, ConnectionHandle, EndpointEvent, VarInt, +}; + +use crate::socket::Socket; + +#[derive(Debug)] +pub enum ConnectionEvent { + Close(VarInt, String), + Proto(quinn_proto::ConnectionEvent), +} + +#[derive(Debug)] +struct ConnectionState { + conn: quinn_proto::Connection, + connected: bool, + error: Option, + worker: Option>, +} + +impl ConnectionState { + fn new(conn: quinn_proto::Connection) -> Self { + Self { + conn, + connected: false, + error: None, + worker: None, + } + } + + fn terminate(&mut self, reason: ConnectionError) { + self.error = Some(reason); + self.connected = false; + } +} + +#[derive(Debug)] +struct ConnectionInner { + state: RefCell, + handle: 
ConnectionHandle, + socket: Socket, + on_connected: AtomicWaker, + on_handshake_data: AtomicWaker, + endpoint_tx: Sender<(ConnectionHandle, EndpointEvent)>, + conn_rx: Receiver, +} + +impl ConnectionInner { + fn new( + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + endpoint_tx: Sender<(ConnectionHandle, EndpointEvent)>, + conn_rx: Receiver, + ) -> Self { + Self { + state: RefCell::new(ConnectionState::new(conn)), + handle, + socket, + on_connected: AtomicWaker::new(), + on_handshake_data: AtomicWaker::new(), + endpoint_tx, + conn_rx, + } + } + + fn on_close(&self) { + self.on_handshake_data.wake(); + self.on_connected.wake(); + } + + fn close(&self, error_code: VarInt, reason: String) { + let mut state = self.state.borrow_mut(); + state.conn.close(Instant::now(), error_code, reason.into()); + state.terminate(ConnectionError::LocallyClosed); + self.on_close(); + } + + async fn run(&self) -> io::Result<()> { + let mut send_buf = Some(Vec::with_capacity( + self.state.borrow().conn.current_mtu() as usize + )); + let mut transmit_fut = pin!(Fuse::terminated()); + + let mut timer = Timer::new(); + + loop { + { + let now = Instant::now(); + let mut state = self.state.borrow_mut(); + + if let Some(mut buf) = send_buf.take() { + if let Some(transmit) = + state + .conn + .poll_transmit(now, self.socket.max_gso_segments(), &mut buf) + { + transmit_fut + .set(async move { self.socket.send(buf, &transmit).await }.fuse()) + } else { + send_buf = Some(buf); + } + } + + timer.reset(state.conn.poll_timeout()); + + while let Some(event) = state.conn.poll_endpoint_events() { + let _ = self.endpoint_tx.send((self.handle, event)); + } + + while let Some(event) = state.conn.poll() { + use quinn_proto::Event::*; + match event { + HandshakeDataReady => { + self.on_handshake_data.wake(); + } + Connected => { + state.connected = true; + self.on_connected.wake(); + } + ConnectionLost { reason } => { + state.terminate(reason); + self.on_close(); + } + _ => {} + 
} + } + + if state.conn.is_drained() { + break Ok(()); + } + } + + select! { + _ = timer => { + self.state.borrow_mut().conn.handle_timeout(Instant::now()); + timer.reset(None); + }, + ev = self.conn_rx.recv_async() =>match ev { + Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason), + Ok(ConnectionEvent::Proto(ev)) => self.state.borrow_mut().conn.handle_event(ev), + Err(_) => unreachable!("endpoint dropped connection"), + }, + BufResult(res, mut buf) = transmit_fut => match res { + Ok(()) => { + buf.clear(); + send_buf = Some(buf); + }, + Err(e) => break Err(e), + }, + } + } + } +} + +macro_rules! conn_fn { + () => { + /// The local IP address which was used when the peer established + /// the connection. + /// + /// This can be different from the address the endpoint is bound to, in case + /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`. + /// + /// This will return `None` for clients, or when the platform does not + /// expose this information. + pub fn local_ip(&self) -> Option { + self.0.state.borrow().conn.local_ip() + } + + /// The peer's UDP address. + /// + /// Will panic if called after `poll` has returned `Ready`. + pub fn remote_address(&self) -> SocketAddr { + self.0.state.borrow().conn.remote_address() + } + + /// Parameters negotiated during the handshake. 
+ pub async fn handshake_data(&mut self) -> Result, ConnectionError> { + poll_fn(|cx| { + let state = self.0.state.borrow(); + if let Some(handshake_data) = state.conn.crypto_session().handshake_data() { + Poll::Ready(Ok(handshake_data.downcast::().unwrap())) + } else if let Some(ref error) = state.error { + Poll::Ready(Err(error.clone())) + } else { + self.0.on_handshake_data.register(cx.waker()); + Poll::Pending + } + }) + .await + } + }; +} + +/// In-progress connection attempt future +#[derive(Debug)] +#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] +pub struct Connecting(Rc); + +impl Connecting { + conn_fn!(); + + pub(crate) fn new( + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + endpoint_tx: Sender<(ConnectionHandle, EndpointEvent)>, + conn_rx: Receiver, + ) -> Self { + let inner = Rc::new(ConnectionInner::new( + handle, + conn, + socket, + endpoint_tx, + conn_rx, + )); + let worker = compio_runtime::spawn({ + let inner = inner.clone(); + async move { inner.run().await.unwrap() } + }); + inner.state.borrow_mut().worker = Some(worker); + Self(inner) + } +} + +impl Future for Connecting { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let state = self.0.state.borrow(); + if state.connected { + Poll::Ready(Ok(Connection(self.0.clone()))) + } else if let Some(error) = &state.error { + Poll::Ready(Err(error.clone())) + } else { + self.0.on_connected.register(cx.waker()); + Poll::Pending + } + } +} + +/// A QUIC connection. +#[derive(Debug)] +pub struct Connection(Rc); + +impl Connection { + conn_fn!(); + + /// Close the connection immediately. + /// + /// Pending operations will fail immediately with + /// [`ConnectionError::LocallyClosed`]. Delivery of data on unfinished + /// streams is not guaranteed, so the application must call this only when + /// all important communications have been completed, e.g. 
by calling + /// [`finish`] on outstanding [`SendStream`]s and waiting for the resulting + /// futures to complete. + /// + /// `error_code` and `reason` are not interpreted, and are provided directly + /// to the peer. + /// + /// `reason` will be truncated to fit in a single packet with overhead; to + /// improve odds that it is preserved in full, it should be kept under 1KiB. + /// + /// [`ConnectionError::LocallyClosed`]: quinn_proto::ConnectionError::LocallyClosed + /// [`finish`]: crate::SendStream::finish + /// [`SendStream`]: crate::SendStream + pub fn close(&self, error_code: VarInt, reason: &str) { + self.0.close(error_code, reason.to_string()); + } + + /// Wait for the connection to be closed for any reason + pub async fn closed(&self) -> ConnectionError { + let worker = self.0.state.borrow_mut().worker.take(); + if let Some(worker) = worker { + let _ = worker.await; + } + + self.0.state.borrow().error.clone().unwrap() + } +} + +struct Timer { + deadline: Option, + fut: Fuse>, +} + +impl Timer { + fn new() -> Self { + Self { + deadline: None, + fut: Fuse::terminated(), + } + } + + fn reset(&mut self, deadline: Option) { + if let Some(deadline) = deadline { + if self.deadline.is_none() || self.deadline != Some(deadline) { + self.fut = compio_runtime::time::sleep_until(deadline) + .boxed_local() + .fuse(); + } + } else { + self.fut = Fuse::terminated(); + } + self.deadline = deadline; + } +} + +impl Future for Timer { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.fut.poll_unpin(cx) + } +} + +impl FusedFuture for Timer { + fn is_terminated(&self) -> bool { + self.fut.is_terminated() + } +} diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs new file mode 100644 index 00000000..26a95aaa --- /dev/null +++ b/compio-quic/src/endpoint.rs @@ -0,0 +1,349 @@ +use std::{ + cell::RefCell, + collections::HashMap, + io, + mem::ManuallyDrop, + net::{SocketAddr, SocketAddrV6}, + pin::pin, + rc::Rc, + 
sync::Arc, + task::Poll, + time::Instant, +}; + +use compio_buf::BufResult; +use compio_net::UdpSocket; +use compio_runtime::JoinHandle; +use flume::{unbounded, Receiver, Sender}; +use futures_util::{ + future::{join_all, poll_fn}, + select, + task::AtomicWaker, + FutureExt, +}; +use quinn_proto::{ + ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, EndpointConfig, EndpointEvent, + ServerConfig, Transmit, VarInt, +}; + +use crate::{ + connection::{Connecting, ConnectionEvent}, + socket::{RecvMeta, Socket}, +}; + +#[derive(Debug)] +struct EndpointState { + endpoint: quinn_proto::Endpoint, + worker: Option>, + connections: HashMap>, + close: Option<(VarInt, String)>, +} + +impl EndpointState { + fn new(config: EndpointConfig, server_config: Option, allow_mtud: bool) -> Self { + Self { + endpoint: quinn_proto::Endpoint::new( + Arc::new(config), + server_config.map(Arc::new), + allow_mtud, + None, + ), + worker: None, + connections: HashMap::new(), + close: None, + } + } + + fn handle_data(&mut self, meta: RecvMeta, buf: &[u8]) -> Vec<(Vec, Transmit)> { + let mut outgoing = vec![]; + let now = Instant::now(); + for data in buf[..meta.len] + .chunks(meta.stride.min(meta.len).max(1)) + .map(Into::into) + { + let mut resp_buf = Vec::new(); + match self.endpoint.handle( + now, + meta.remote, + meta.local_ip, + meta.ecn, + data, + &mut resp_buf, + ) { + Some(DatagramEvent::NewConnection(_incoming)) => todo!("server"), + Some(DatagramEvent::ConnectionEvent(ch, event)) => { + let _ = self + .connections + .get(&ch) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + Some(DatagramEvent::Response(transmit)) => outgoing.push((resp_buf, transmit)), + None => {} + } + } + outgoing + } + + fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) -> bool { + let mut is_idle = false; + if event.is_drained() { + self.connections.remove(&ch); + if self.connections.is_empty() { + is_idle = true; + } + } + if let Some(event) = self.endpoint.handle_event(ch, event) 
{ + let _ = self + .connections + .get(&ch) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + is_idle + } +} + +#[derive(Debug)] +struct EndpointInner { + state: RefCell, + socket: Socket, + ipv6: bool, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver<(ConnectionHandle, EndpointEvent)>, + done: AtomicWaker, +} + +impl EndpointInner { + fn new( + socket: UdpSocket, + config: EndpointConfig, + server_config: Option, + ) -> io::Result { + let socket = Socket::new(socket)?; + let ipv6 = socket.local_addr()?.is_ipv6(); + let allow_mtud = !socket.may_fragment(); + + let (events_tx, events_rx) = unbounded(); + + Ok(Self { + state: RefCell::new(EndpointState::new(config, server_config, allow_mtud)), + socket, + ipv6, + events_tx, + events_rx, + done: AtomicWaker::new(), + }) + } + + fn connect( + &self, + remote: SocketAddr, + server_name: &str, + config: ClientConfig, + ) -> Result { + let (handle, conn) = { + let mut state = self.state.borrow_mut(); + + if state.worker.is_none() { + return Err(ConnectError::EndpointStopping); + } + if remote.is_ipv6() && !self.ipv6 { + return Err(ConnectError::InvalidRemoteAddress(remote)); + } + let remote = if self.ipv6 { + SocketAddr::V6(match remote { + SocketAddr::V4(addr) => { + SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0) + } + SocketAddr::V6(addr) => addr, + }) + } else { + remote + }; + + state + .endpoint + .connect(Instant::now(), config, remote, server_name)? 
+ }; + + Ok(self.new_connection(handle, conn)) + } + + fn new_connection( + &self, + handle: ConnectionHandle, + conn: quinn_proto::Connection, + ) -> Connecting { + let (tx, rx) = unbounded(); + + let mut state = self.state.borrow_mut(); + if let Some((error_code, reason)) = &state.close { + tx.send(ConnectionEvent::Close(*error_code, reason.clone())) + .unwrap(); + } + state.connections.insert(handle, tx); + Connecting::new( + handle, + conn, + self.socket.clone(), + self.events_tx.clone(), + rx, + ) + } + + async fn run(&self) -> io::Result<()> { + let mut recv_fut = pin!( + self.socket + .recv(Vec::with_capacity( + self.state + .borrow() + .endpoint + .config() + .get_max_udp_payload_size() + .min(64 * 1024) as usize + * self.socket.max_gro_segments(), + )) + .fuse() + ); + + let idle = AtomicWaker::new(); + let mut close_fut = poll_fn(|cx| { + let state = self.state.borrow(); + if state.close.is_some() && state.connections.is_empty() { + Poll::Ready(()) + } else { + idle.register(cx.waker()); + Poll::Pending + } + }) + .fuse(); + + loop { + select! { + BufResult(res, recv_buf) = recv_fut => { + match res { + Ok(meta) => { + let outgoing = self.state.borrow_mut().handle_data(meta, &recv_buf); + join_all( + outgoing.into_iter().map(|(buf, transmit)| async move { + self.socket.send(buf, &transmit).await + }), + ) + .await; + } + Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} + Err(e) => break Err(e), + } + recv_fut.set(self.socket.recv(recv_buf).fuse()); + }, + (ch, event) = self.events_rx.recv_async().map(Result::unwrap) => { + if self.state.borrow_mut().handle_event(ch, event) { + idle.wake(); + } + }, + _ = close_fut => break Ok(()), + } + } + } +} + +/// A QUIC endpoint. +#[derive(Debug, Clone)] +pub struct Endpoint { + inner: Rc, + /// The client configuration used by `connect` + pub default_client_config: Option, +} + +impl Endpoint { + /// Create a QUIC endpoint. 
+ pub fn new( + socket: UdpSocket, + config: EndpointConfig, + server_config: Option, + default_client_config: Option, + ) -> io::Result { + let inner = Rc::new(EndpointInner::new(socket, config, server_config)?); + let worker = compio_runtime::spawn({ + let inner = inner.clone(); + async move { inner.run().await.unwrap() } + }); + inner.state.borrow_mut().worker = Some(worker); + Ok(Self { + inner, + default_client_config, + }) + } + + /// Connect to a remote endpoint. + pub fn connect( + &self, + remote: SocketAddr, + server_name: &str, + config: Option, + ) -> Result { + let config = if let Some(config) = config { + config + } else if let Some(config) = &self.default_client_config { + config.clone() + } else { + return Err(ConnectError::NoDefaultClientConfig); + }; + + self.inner.connect(remote, server_name, config) + } + + // Modified from [`SharedFd::try_unwrap_inner`], see notes there. + unsafe fn try_unwrap_inner(this: &ManuallyDrop) -> Option { + let ptr = ManuallyDrop::new(std::ptr::read(&this.inner)); + match Rc::try_unwrap(ManuallyDrop::into_inner(ptr)) { + Ok(inner) => Some(inner), + Err(ptr) => { + std::mem::forget(ptr); + None + } + } + } + + /// Shutdown the endpoint. + /// + /// This will close all connections and the underlying socket. Note that it + /// will wait for all connections and all clones of the endpoint (and any + /// clone of the underlying socket) to be dropped before closing the socket. + /// + /// See [`Connection::close()`](crate::Connection::close) for details. 
+ pub async fn close(self, error_code: VarInt, reason: &str) -> io::Result<()> { + let reason = reason.to_string(); + self.inner.state.borrow_mut().close = Some((error_code, reason.clone())); + for conn in self.inner.state.borrow().connections.values() { + let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone())); + } + + let worker = self.inner.state.borrow_mut().worker.take(); + if let Some(worker) = worker { + let _ = worker.await; + } + + let this = ManuallyDrop::new(self); + let inner = poll_fn(move |cx| { + if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } { + Poll::Ready(inner) + } else { + this.inner.done.register(cx.waker()); + Poll::Pending + } + }) + .await; + + inner.socket.close().await + } +} + +impl Drop for Endpoint { + fn drop(&mut self) { + if Rc::strong_count(&self.inner) == 2 { + self.inner.done.wake(); + } + } +} diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs new file mode 100644 index 00000000..83a7f1db --- /dev/null +++ b/compio-quic/src/lib.rs @@ -0,0 +1,13 @@ +//! QUIC implementation for compio + +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![warn(missing_docs)] + +mod connection; +mod endpoint; +mod socket; + +pub use crate::{ + connection::{Connecting, Connection}, + endpoint::Endpoint, +}; diff --git a/compio-quic/src/socket.rs b/compio-quic/src/socket.rs new file mode 100644 index 00000000..15a3cd74 --- /dev/null +++ b/compio-quic/src/socket.rs @@ -0,0 +1,775 @@ +//! Simple wrapper around UDP socket with advanced features useful for QUIC, +//! ported from [`quinn-udp`] +//! +//! Differences from [`quinn-udp`]: +//! - [quinn-rs/quinn#1516] is not implemented +//! - `recvmmsg` is not available +//! +//! [`quinn-udp`]: https://docs.rs/quinn-udp +//! 
[quinn-rs/quinn#1516]: https://github.com/quinn-rs/quinn/pull/1516 + +use std::{ + future::Future, + io, + net::{IpAddr, SocketAddr}, + ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, Ordering}, +}; + +use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; +use compio_net::{CMsgBuilder, CMsgIter, UdpSocket}; +use quinn_proto::{EcnCodepoint, Transmit}; +#[cfg(windows)] +use windows_sys::Win32::Networking::WinSock; + +trait IoResultExt { + fn map_noprotoopt(self) -> io::Result; +} + +impl IoResultExt for io::Result<()> { + fn map_noprotoopt(self) -> io::Result { + match self { + Ok(()) => Ok(false), + Err(e) => match e.raw_os_error() { + #[cfg(unix)] + Some(libc::ENOPROTOOPT) => Ok(true), + #[cfg(windows)] + Some(WinSock::WSAENOPROTOOPT) => Ok(true), + _ => Err(e), + }, + } + } +} + +/// Metadata for a single buffer filled with bytes received from the network +/// +/// This associated buffer can contain one or more datagrams, see [`stride`]. +/// +/// [`stride`]: RecvMeta::stride +#[derive(Debug)] +pub(crate) struct RecvMeta { + /// The source address of the datagram(s) contained in the buffer + pub remote: SocketAddr, + /// The number of bytes the associated buffer has + pub len: usize, + /// The size of a single datagram in the associated buffer + /// + /// When GRO (Generic Receive Offload) is used this indicates the size of a + /// single datagram inside the buffer. If the buffer is larger, that is + /// if [`len`] is greater than this value, then the individual datagrams + /// contained have their boundaries at `stride` increments from the + /// start. The last datagram could be smaller than `stride`. + /// + /// [`len`]: RecvMeta::len + pub stride: usize, + /// The Explicit Congestion Notification bits for the datagram(s) in the + /// buffer + pub ecn: Option, + /// The destination IP address which was encoded in this datagram + /// + /// Populated on platforms: Windows, Linux, Android, FreeBSD, OpenBSD, + /// NetBSD, macOS, and iOS. 
+ pub local_ip: Option, +} + +const CMSG_LEN: usize = 128; + +struct Ancillary { + inner: [u8; N], + len: usize, + #[cfg(unix)] + _align: [libc::cmsghdr; 0], + #[cfg(windows)] + _align: [WinSock::CMSGHDR; 0], +} + +impl Ancillary { + fn new() -> Self { + Self { + inner: [0u8; N], + len: N, + _align: [], + } + } +} + +unsafe impl IoBuf for Ancillary { + fn as_buf_ptr(&self) -> *const u8 { + self.inner.as_buf_ptr() + } + + fn buf_len(&self) -> usize { + self.len + } + + fn buf_capacity(&self) -> usize { + N + } +} + +impl SetBufInit for Ancillary { + unsafe fn set_buf_init(&mut self, len: usize) { + debug_assert!(len <= N); + self.len = len; + } +} + +unsafe impl IoBufMut for Ancillary { + fn as_buf_mut_ptr(&mut self) -> *mut u8 { + self.inner.as_buf_mut_ptr() + } +} + +impl Deref for Ancillary { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.inner[0..self.len] + } +} + +impl DerefMut for Ancillary { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner[0..self.len] + } +} + +#[cfg(target_os = "linux")] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + socket.get_socket_option::(libc::SOL_UDP, libc::UDP_SEGMENT)?; + Ok(64) +} +#[cfg(windows)] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + socket.get_socket_option::(WinSock::IPPROTO_UDP, WinSock::UDP_SEND_MSG_SIZE)?; + Ok(512) +} +#[cfg(not(any(target_os = "linux", windows)))] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) +} + +#[derive(Debug)] +pub(crate) struct Socket { + inner: UdpSocket, + max_gro_segments: usize, + max_gso_segments: usize, + may_fragment: bool, + has_gso_error: AtomicBool, + #[cfg(target_os = "freebsd")] + encode_src_ip_v4: bool, +} + +impl Socket { + pub fn new(socket: UdpSocket) -> io::Result { + let is_ipv6 = socket.local_addr()?.is_ipv6(); + #[cfg(unix)] + let only_v6 = is_ipv6 + && socket.get_socket_option::(libc::IPPROTO_IPV6, libc::IPV6_V6ONLY)? 
!= 0; + #[cfg(windows)] + let only_v6 = is_ipv6 + && socket.get_socket_option::(WinSock::IPPROTO_IPV6, WinSock::IPV6_V6ONLY)? != 0; + let is_ipv4 = socket.local_addr()?.is_ipv4() || !only_v6; + + // ECN + if is_ipv4 { + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + socket.set_socket_option(libc::IPPROTO_IP, libc::IP_RECVTOS, &1)?; + #[cfg(windows)] + socket.set_socket_option(WinSock::IPPROTO_IP, WinSock::IP_ECN, &1)?; + } + if is_ipv6 { + #[cfg(unix)] + socket.set_socket_option(libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, &1)?; + #[cfg(windows)] + socket.set_socket_option(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, &1)?; + } + + // pktinfo / destination address + if is_ipv4 { + #[cfg(any(target_os = "linux", target_os = "android"))] + socket.set_socket_option(libc::IPPROTO_IP, libc::IP_PKTINFO, &1)?; + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios" + ))] + socket.set_socket_option(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, &1)?; + #[cfg(windows)] + socket.set_socket_option(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, &1)?; + } + if is_ipv6 { + #[cfg(unix)] + socket.set_socket_option(libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, &1)?; + #[cfg(windows)] + socket.set_socket_option(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, &1)?; + } + + // disable fragmentation + let mut may_fragment = false; + if is_ipv4 { + #[cfg(any(target_os = "linux", target_os = "android"))] + { + may_fragment |= socket + .set_socket_option( + libc::IPPROTO_IP, + libc::IP_MTU_DISCOVER, + &libc::IP_PMTUDISC_PROBE, + ) + .map_noprotoopt()?; + } + #[cfg(any( + target_os = "aix", + target_os = "freebsd", + target_os = "macos", + target_os = "ios" + ))] + { + may_fragment |= socket + .set_socket_option(libc::IPPROTO_IP, libc::IP_DONTFRAG, &1) + .map_noprotoopt()?; + } + #[cfg(windows)] + { + may_fragment |= socket + .set_socket_option(WinSock::IPPROTO_IP, WinSock::IP_DONTFRAGMENT, &1) + 
.map_noprotoopt()?; + } + } + if is_ipv6 { + #[cfg(any(target_os = "linux", target_os = "android"))] + { + may_fragment |= socket + .set_socket_option( + libc::IPPROTO_IPV6, + libc::IPV6_MTU_DISCOVER, + &libc::IPV6_PMTUDISC_PROBE, + ) + .map_noprotoopt()?; + } + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + { + may_fragment |= socket + .set_socket_option(libc::IPPROTO_IPV6, libc::IPV6_DONTFRAG, &1) + .map_noprotoopt()?; + } + #[cfg(any(target_os = "openbsd", target_os = "netbsd"))] + { + // FIXME: workaround until https://github.com/rust-lang/libc/pull/3716 is released (at least in 0.2.155) + may_fragment |= socket + .set_socket_option(libc::IPPROTO_IPV6, 62, &1) + .map_noprotoopt()?; + } + #[cfg(windows)] + { + may_fragment |= socket + .set_socket_option(WinSock::IPPROTO_IPV6, WinSock::IPV6_DONTFRAG, &1) + .map_noprotoopt()?; + } + } + + // GRO + #[allow(unused_mut)] // only mutable on Linux and Windows + let mut max_gro_segments = 1; + #[cfg(target_os = "linux")] + if socket + .set_socket_option(libc::SOL_UDP, libc::UDP_GRO, &1) + .is_ok() + { + max_gro_segments = 64; + } + #[cfg(windows)] + if socket + .set_socket_option( + WinSock::IPPROTO_UDP, + WinSock::UDP_RECV_MAX_COALESCED_SIZE, + &(u16::MAX as u32), + ) + .is_ok() + { + max_gro_segments = 64; + } + + // GSO + let max_gso_segments = max_gso_segments(&socket).unwrap_or(1); + + #[cfg(target_os = "freebsd")] + let encode_src_ip_v4 = + socket.local_addr().unwrap().ip() == IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); + + Ok(Self { + inner: socket, + max_gro_segments, + max_gso_segments, + may_fragment, + has_gso_error: AtomicBool::new(false), + #[cfg(target_os = "freebsd")] + encode_src_ip_v4, + }) + } + + #[inline] + pub fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + #[inline] + pub fn may_fragment(&self) -> bool { + self.may_fragment + } + + #[inline] + pub fn max_gro_segments(&self) -> usize { + self.max_gro_segments + } + + #[inline] + pub fn 
max_gso_segments(&self) -> usize { + if self.has_gso_error.load(Ordering::Relaxed) { + 1 + } else { + self.max_gso_segments + } + } + + pub async fn recv(&self, buffer: T) -> BufResult { + let control = Ancillary::::new(); + + let BufResult(res, (buffer, control)) = self.inner.recv_msg(buffer, control).await; + let ((len, _, remote), buffer) = buf_try!(res, buffer); + + let mut ecn_bits = 0u8; + let mut local_ip = None; + #[allow(unused_mut)] // only mutable on Linux + let mut stride = len; + + // SAFETY: `control` contains valid data + unsafe { + for cmsg in CMsgIter::new(&control) { + #[cfg(windows)] + const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32; + + match (cmsg.level(), cmsg.ty()) { + // ECN + #[cfg(unix)] + (libc::IPPROTO_IP, libc::IP_TOS) => ecn_bits = *cmsg.data::(), + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + (libc::IPPROTO_IP, libc::IP_RECVTOS) => ecn_bits = *cmsg.data::(), + #[cfg(unix)] + (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => { + // NOTE: It's OK to use `c_int` instead of `u8` on Apple systems + ecn_bits = *cmsg.data::() as u8 + } + #[cfg(windows)] + (WinSock::IPPROTO_IP, WinSock::IP_ECN) + | (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { + ecn_bits = *cmsg.data::() as u8 + } + + // pktinfo / destination address + #[cfg(any(target_os = "linux", target_os = "android"))] + (libc::IPPROTO_IP, libc::IP_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes())); + } + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { + let in_addr = cmsg.data::(); + local_ip = Some(IpAddr::from(in_addr.s_addr.to_ne_bytes())); + } + #[cfg(windows)] + (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi_addr.S_un.S_addr.to_ne_bytes())); + } + #[cfg(unix)] + 
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.s6_addr)); + } + #[cfg(windows)] + (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.u.Byte)); + } + + // GRO + #[cfg(target_os = "linux")] + (libc::SOL_UDP, libc::UDP_GRO) => stride = *cmsg.data::() as usize, + #[cfg(windows)] + (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => { + stride = *cmsg.data::() as usize + } + + _ => {} + } + } + } + + let meta = RecvMeta { + remote, + len, + stride, + ecn: EcnCodepoint::from_bits(ecn_bits), + local_ip, + }; + BufResult(Ok(meta), buffer) + } + + pub async fn send(&self, buffer: T, transmit: &Transmit) -> BufResult<(), T> { + let is_ipv4 = transmit.destination.ip().to_canonical().is_ipv4(); + let ecn = transmit.ecn.map_or(0, |x| x as u8); + + let mut control = Ancillary::::new(); + let mut builder = CMsgBuilder::new(&mut control); + + // ECN + if is_ipv4 { + #[cfg(all(unix, not(any(target_os = "freebsd", target_os = "netbsd"))))] + builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_int); + #[cfg(target_os = "freebsd")] + builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_uchar); + #[cfg(windows)] + builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn as i32); + } else { + #[cfg(unix)] + builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn as libc::c_int); + #[cfg(windows)] + builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn as i32); + } + + // pktinfo / destination address + match transmit.src_ip { + Some(IpAddr::V4(ip)) => { + let addr = u32::from_ne_bytes(ip.octets()); + #[cfg(any(target_os = "linux", target_os = "android"))] + { + let pktinfo = libc::in_pktinfo { + ipi_ifindex: 0, + ipi_spec_dst: libc::in_addr { s_addr: addr }, + ipi_addr: libc::in_addr { s_addr: 0 }, + }; + builder.try_push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo); + } + #[cfg(any( + target_os = 
"freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + { + #[cfg(target_os = "freebsd")] + let encode_src_ip_v4 = self.encode_src_ip_v4; + #[cfg(any( + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + let encode_src_ip_v4 = true; + + if encode_src_ip_v4 { + let addr = libc::in_addr { s_addr: addr }; + builder.try_push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr); + } + } + #[cfg(windows)] + { + let pktinfo = WinSock::IN_PKTINFO { + ipi_addr: WinSock::IN_ADDR { + S_un: WinSock::IN_ADDR_0 { S_addr: addr }, + }, + ipi_ifindex: 0, + }; + builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo); + } + } + Some(IpAddr::V6(ip)) => { + #[cfg(unix)] + { + let pktinfo = libc::in6_pktinfo { + ipi6_ifindex: 0, + ipi6_addr: libc::in6_addr { + s6_addr: ip.octets(), + }, + }; + builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo); + } + #[cfg(windows)] + { + let pktinfo = WinSock::IN6_PKTINFO { + ipi6_addr: WinSock::IN6_ADDR { + u: WinSock::IN6_ADDR_0 { Byte: ip.octets() }, + }, + ipi6_ifindex: 0, + }; + builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo); + } + } + None => {} + } + + // GSO + if let Some(segment_size) = transmit.segment_size { + #[cfg(target_os = "linux")] + builder.try_push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size as u16); + #[cfg(windows)] + builder.try_push( + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + segment_size as u32, + ); + } + + let len = builder.finish(); + control.len = len; + + let buffer = buffer.slice(0..transmit.size); + let BufResult(res, (buffer, _)) = self + .inner + .send_msg(buffer, control, transmit.destination) + .await; + let buffer = buffer.into_inner(); + match res { + Ok(_) => BufResult(Ok(()), buffer), + Err(e) => { + #[cfg(target_os = "linux")] + if let Some(libc::EIO) | Some(libc::EINVAL) = e.raw_os_error() { + if self.max_gso_segments() > 1 { + 
self.has_gso_error.store(true, Ordering::Relaxed); + } + } + BufResult(Err(e), buffer) + } + } + } + + pub fn close(self) -> impl Future> { + self.inner.close() + } +} + +impl Clone for Socket { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + may_fragment: self.may_fragment, + max_gro_segments: self.max_gro_segments, + max_gso_segments: self.max_gso_segments, + has_gso_error: AtomicBool::new(self.has_gso_error.load(Ordering::Relaxed)), + #[cfg(target_os = "freebsd")] + encode_src_ip_v4: self.encode_src_ip_v4.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, Ipv6Addr}; + + use compio_driver::AsRawFd; + use socket2::{Domain, Protocol, Socket as Socket2, Type}; + + use super::*; + + async fn test_send_recv( + passive: Socket, + active: Socket, + content: T, + transmit: Transmit, + ) { + let passive_addr = passive.local_addr().unwrap(); + let active_addr = active.local_addr().unwrap(); + + let (_, content) = active.send(content, &transmit).await.unwrap(); + + let segment_size = transmit.segment_size.unwrap_or(transmit.size); + let expected_datagrams = transmit.size / segment_size; + let mut datagrams = 0; + while datagrams < expected_datagrams { + let (meta, buf) = passive + .recv(Vec::with_capacity(u16::MAX as usize)) + .await + .unwrap(); + let segments = meta.len / meta.stride; + for i in 0..segments { + assert_eq!( + &content.as_slice() + [(datagrams + i) * segment_size..(datagrams + i + 1) * segment_size], + &buf[(i * meta.stride)..((i + 1) * meta.stride)] + ); + } + datagrams += segments; + + assert_eq!(meta.ecn, transmit.ecn); + + assert_eq!(meta.remote.port(), active_addr.port()); + for addr in [meta.remote.ip(), meta.local_ip.unwrap()] { + match (active_addr.is_ipv6(), passive_addr.is_ipv6()) { + (_, false) => assert_eq!(addr, Ipv4Addr::LOCALHOST), + (false, true) => assert!( + addr == Ipv4Addr::LOCALHOST || addr == Ipv4Addr::LOCALHOST.to_ipv6_mapped() + ), + (true, true) => assert!( + addr == Ipv6Addr::LOCALHOST 
|| addr == Ipv4Addr::LOCALHOST.to_ipv6_mapped() + ), + } + } + } + assert_eq!(datagrams, expected_datagrams); + } + + /// Helper function to create dualstack udp socket. + /// This is only used for testing. + fn bind_udp_dualstack() -> io::Result { + #[cfg(unix)] + use std::os::fd::{FromRawFd, IntoRawFd}; + #[cfg(windows)] + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let socket = Socket2::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(false)?; + socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).into())?; + + compio_runtime::Runtime::with_current(|r| r.attach(socket.as_raw_fd()))?; + #[cfg(unix)] + unsafe { + Ok(UdpSocket::from_raw_fd(socket.into_raw_fd())) + } + #[cfg(windows)] + unsafe { + Ok(UdpSocket::from_raw_socket(socket.into_raw_socket())) + } + } + + #[compio_macros::test] + async fn basic_v4() { // both ends bind to the IPv4 loopback: this is the v4 smoke test + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: None, + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive, active, content, transmit).await; + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_v4() { + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + async fn ecn_v6() { + let passive = 
Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_dualstack() { + let passive = Socket::new(bind_udp_dualstack().unwrap()).unwrap(); + + let mut dst_v4 = passive.local_addr().unwrap(); + dst_v4.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); + let mut dst_v6 = dst_v4; + dst_v6.set_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)); + + for (src, dst) in [("[::1]:0", dst_v6), ("127.0.0.1:0", dst_v4)] { + let active = Socket::new(UdpSocket::bind(src).await.unwrap()).unwrap(); + + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: dst, + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_v4_mapped_v6() { + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(bind_udp_dualstack().unwrap()).unwrap(); + + let mut dst_addr = passive.local_addr().unwrap(); + dst_addr.set_ip(IpAddr::V6(Ipv4Addr::LOCALHOST.to_ipv6_mapped())); + + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: dst_addr, + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + 
} + } + + #[compio_macros::test] + #[cfg_attr(not(any(target_os = "linux", target_os = "windows")), ignore)] + async fn gso() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + + let max_segments = active.max_gso_segments(); + const SEGMENT_SIZE: usize = 128; + let content = vec![0xAB; SEGMENT_SIZE * max_segments]; + + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: None, + size: content.len(), + segment_size: Some(SEGMENT_SIZE), + src_ip: None, + }; + test_send_recv(passive, active, content, transmit).await; + } +} diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml index 34bb7d90..ef70c561 100644 --- a/compio-tls/Cargo.toml +++ b/compio-tls/Cargo.toml @@ -19,7 +19,7 @@ compio-buf = { workspace = true } compio-io = { workspace = true, features = ["compat"] } native-tls = { version = "0.2.11", optional = true } -rustls = { version = "0.23.1", default-features = false, optional = true, features = [ +rustls = { workspace = true, default-features = false, optional = true, features = [ "logging", "std", "tls12", @@ -30,7 +30,7 @@ compio-net = { workspace = true } compio-runtime = { workspace = true } compio-macros = { workspace = true } -rustls = { version = "0.23.1", default-features = false, features = ["ring"] } +rustls = { workspace = true, default-features = false, features = ["ring"] } rustls-native-certs = "0.7.0" [features]