From 07124ee94a247a441d4e2d4e0dbaa33c1466c5f8 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Wed, 31 Jul 2024 17:36:53 +0800 Subject: [PATCH 01/26] feat(net): add `get_socket_option` on `Socket` --- compio-net/src/socket.rs | 45 +++++++++++++++++++++++++++++++++++++++- compio-net/src/udp.rs | 5 +++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index d7bc5a8d..74d8b7ee 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -1,4 +1,8 @@ -use std::{future::Future, io, mem::ManuallyDrop}; +use std::{ + future::Future, + io, + mem::{ManuallyDrop, MaybeUninit}, +}; use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut}; #[cfg(unix)] @@ -319,6 +323,45 @@ impl Socket { compio_runtime::submit(op).await.into_inner() } + #[cfg(unix)] + pub fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + let mut value: MaybeUninit = MaybeUninit::uninit(); + let mut len = size_of::() as libc::socklen_t; + syscall!(libc::getsockopt( + self.socket.as_raw_fd(), + level, + name, + value.as_mut_ptr() as _, + &mut len + )) + .map(|_| { + debug_assert_eq!(len as usize, size_of::()); + // SAFETY: The value is initialized by `getsockopt`. + unsafe { value.assume_init() } + }) + } + + #[cfg(windows)] + pub fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + let mut value: MaybeUninit = MaybeUninit::uninit(); + let mut len = size_of::() as i32; + syscall!( + SOCKET, + windows_sys::Win32::Networking::WinSock::getsockopt( + self.socket.as_raw_fd() as _, + level, + name, + value.as_mut_ptr() as _, + &mut len + ) + ) + .map(|_| { + debug_assert_eq!(len as usize, size_of::()); + // SAFETY: The value is initialized by `getsockopt`. + unsafe { value.assume_init() } + }) + } + #[cfg(unix)] pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { syscall!(libc::setsockopt( diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index 13e59d73..d0855833 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -316,6 +316,11 @@ impl UdpSocket { .await } + /// Gets a socket option. + pub fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + self.inner.get_socket_option(level, name) + } + /// Sets a socket option. pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { self.inner.set_socket_option(level, name, value) From c4e9db5531320d0a2db74ae58f3ed68ba6db9e03 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 1 Aug 2024 02:42:18 +0800 Subject: [PATCH 02/26] feat(quic): basic endpoint & connection --- Cargo.toml | 2 + compio-quic/Cargo.toml | 57 +++ compio-quic/examples/client.rs | 34 ++ compio-quic/examples/server.rs | 29 ++ compio-quic/src/builder.rs | 512 ++++++++++++++++++++++ compio-quic/src/connection.rs | 465 ++++++++++++++++++++ compio-quic/src/endpoint.rs | 449 +++++++++++++++++++ compio-quic/src/incoming.rs | 141 ++++++ compio-quic/src/lib.rs | 27 ++ compio-quic/src/socket.rs | 774 +++++++++++++++++++++++++++++++++ compio-tls/Cargo.toml | 4 +- 11 files changed, 2492 insertions(+), 2 deletions(-) create mode 100644 compio-quic/Cargo.toml create mode 100644 compio-quic/examples/client.rs create mode 100644 compio-quic/examples/server.rs create mode 100644 compio-quic/src/builder.rs create mode 100644 compio-quic/src/connection.rs create mode 100644 compio-quic/src/endpoint.rs create mode 100644 compio-quic/src/incoming.rs create mode 100644 compio-quic/src/lib.rs create mode 100644 compio-quic/src/socket.rs diff --git a/Cargo.toml b/Cargo.toml index 02a44d46..0a00f90f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "compio-tls", "compio-log", "compio-process", + "compio-quic", ] resolver = "2" @@ -49,6 +50,7 @@ nix = "0.29.0" once_cell = "1.18.0" os_pipe = "1.1.4" paste = "1.0.14" +rustls = { version = "0.23.1", default-features = false } slab = "0.4.9" socket2 = "0.5.6" tempfile = "3.8.1" diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml new file mode 100644 index 00000000..c162ec2e --- /dev/null +++ b/compio-quic/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "compio-quic" +version = "0.1.0" +description = "QUIC for compio" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "net", "quic"] +edition = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +repository = { workspace = true } + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +# Workspace dependencies +compio-buf = { workspace = true, features = ["bytes"] } +compio-log = { workspace = true } +compio-net = { workspace = true } +compio-runtime = { workspace = true, features = ["time"] } + +quinn-proto = "0.11.3" +rustls = { workspace = true } +rustls-platform-verifier = { version = "0.3.3", optional = true } +rustls-native-certs = { version = "0.7.1", optional = true } +webpki-roots = { version = "0.26.3", optional = true } + +# Utils +event-listener = "5.3.1" +flume = { workspace = true } +futures-util = { workspace = true } +thiserror = "1.0.63" + +# Windows specific dependencies +[target.'cfg(windows)'.dependencies] +windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } + +# Linux specific dependencies +[target.'cfg(target_os = "linux")'.dependencies] + +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } + +[dev-dependencies] +compio-driver = { workspace = true } +compio-macros = { workspace = true } +rcgen = "0.13.1" +socket2 = { workspace = true, features = ["all"] } +tracing-subscriber = "0.3.18" + +[features] +default = ["webpki-roots"] +platform-verifier = ["dep:rustls-platform-verifier"] +native-certs = ["dep:rustls-native-certs"] +webpki-roots = ["dep:webpki-roots"] diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/client.rs new file mode 100644 index 00000000..a243570f --- /dev/null +++ b/compio-quic/examples/client.rs @@ -0,0 +1,34 @@ +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; + +use compio_quic::Endpoint; +use tracing_subscriber::filter::LevelFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_max_level(LevelFilter::TRACE) + .init(); + + let endpoint = Endpoint::client() + .with_no_server_verification() + .with_alpn_protocols(&["hq-29"]) + .with_key_log() + .bind("[::1]:0") + .await + .unwrap(); + + { + let conn = endpoint + .connect( + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 4433), + "localhost", + None, + ) + .unwrap() + .await + .unwrap(); + conn.close(1u32.into(), "bye"); + conn.closed().await; + } + endpoint.close(0u32.into(), "").await.unwrap(); +} diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs new file mode 100644 index 00000000..98bb9bdc --- /dev/null +++ b/compio-quic/examples/server.rs @@ -0,0 +1,29 @@ +use compio_quic::Endpoint; +use tracing_subscriber::filter::LevelFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_max_level(LevelFilter::TRACE) + .init(); + + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_chain = vec![cert.cert.into()]; + let key_der = cert.key_pair.serialize_der().try_into().unwrap(); + + let endpoint = Endpoint::server() + .with_single_cert(cert_chain, key_der) + .unwrap() + .with_alpn_protocols(&["hq-29"]) + .with_key_log() + .bind("[::1]:4433") + .await + .unwrap(); + + if let Some(incoming) = endpoint.wait_incoming().await { + let conn = incoming.await.unwrap(); + conn.closed().await; + } + + endpoint.close(0u32.into(), "").await.unwrap(); +} diff --git a/compio-quic/src/builder.rs b/compio-quic/src/builder.rs new file mode 100644 index 00000000..d22a8e44 --- /dev/null +++ b/compio-quic/src/builder.rs @@ -0,0 +1,512 @@ +use std::{ + io, + net::{SocketAddrV4, SocketAddrV6}, + sync::Arc, + time::Duration, +}; + +use compio_net::{ToSocketAddrsAsync, UdpSocket}; +use quinn_proto::{ + crypto::rustls::{QuicClientConfig, QuicServerConfig}, + ClientConfig, EndpointConfig, ServerConfig, TransportConfig, +}; + +use crate::Endpoint; + +/// A [builder] for [`Endpoint`] in client mode. +/// +/// To get one, call [`Endpoint::client()`] or [`ClientBuilder::default()`]. +/// +/// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html +#[derive(Debug)] +pub struct ClientBuilder { + inner: T, + + alpn_protocols: Vec>, + key_log: bool, + enable_early_data: bool, + + transport: Option, + version: Option, + + endpoint_config: EndpointConfig, +} + +impl Default for ClientBuilder<()> { + fn default() -> Self { + Self { + inner: (), + alpn_protocols: Vec::new(), + key_log: false, + enable_early_data: true, + transport: None, + version: None, + endpoint_config: EndpointConfig::default(), + } + } +} + +impl From>> for Result, E> { + fn from(builder: ClientBuilder>) -> Self { + builder.inner.map(|inner| ClientBuilder { + inner, + alpn_protocols: builder.alpn_protocols, + key_log: builder.key_log, + enable_early_data: builder.enable_early_data, + transport: builder.transport, + version: builder.version, + endpoint_config: builder.endpoint_config, + }) + } +} + +impl ClientBuilder { + fn map_inner(self, f: impl FnOnce(T) -> S) -> ClientBuilder { + ClientBuilder { + inner: f(self.inner), + alpn_protocols: self.alpn_protocols, + key_log: self.key_log, + enable_early_data: self.enable_early_data, + transport: self.transport, + version: self.version, + endpoint_config: self.endpoint_config, + } + } + + /// Set the ALPN protocols to use. + pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { + self.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); + self + } + + /// Logging key material to a file for debugging. The file's name is given + /// by the `SSLKEYLOGFILE` environment variable. + /// + /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot + /// be written, this does nothing. + pub fn with_key_log(mut self) -> Self { + self.key_log = true; + self + } + + /// Set a custom [`TransportConfig`]. + pub fn with_transport_config(mut self, transport: TransportConfig) -> Self { + self.transport = Some(transport); + self + } + + /// Set the QUIC version to use. + pub fn with_version(mut self, version: u32) -> Self { + self.version = Some(version); + self + } + + /// Use the provided [`EndpointConfig`]. + pub fn with_endpoint_config(mut self, endpoint_config: EndpointConfig) -> Self { + self.endpoint_config = endpoint_config; + self + } +} + +impl ClientBuilder<()> { + /// Use the provided [`rustls::ClientConfig`]. + pub fn with_rustls_client_config( + self, + client_config: rustls::ClientConfig, + ) -> ClientBuilder { + self.map_inner(|_| client_config) + } + + /// Do not verify the server's certificate. It is vulnerable to MITM + /// attacks, but convenient for testing. + pub fn with_no_server_verification( + self, + ) -> ClientBuilder> { + self.map_inner(|_| Arc::new(verifier::SkipServerVerification::new()) as _) + } + + /// Use [`rustls_platform_verifier`]. + #[cfg(feature = "platform-verifier")] + pub fn with_platform_verifier( + self, + ) -> ClientBuilder> { + self.map_inner(|_| Arc::new(rustls_platform_verifier::Verifier::new()) as _) + } + + /// Use an empty [`rustls::RootCertStore`]. + pub fn with_root_certificates(self) -> ClientBuilder { + self.map_inner(|_| rustls::RootCertStore::empty()) + } +} + +impl ClientBuilder { + /// Create an [`Endpoint`] binding to the addr provided. + pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { + let mut client_config = self.inner; + + client_config.alpn_protocols = self.alpn_protocols; + if self.key_log { + client_config.key_log = Arc::new(rustls::KeyLogFile::new()); + } + client_config.enable_early_data = self.enable_early_data; + + let mut client_config = ClientConfig::new(Arc::new( + QuicClientConfig::try_from(client_config) + .expect("should support TLS13_AES_128_GCM_SHA256"), + )); + + if let Some(transport) = self.transport { + client_config.transport_config(Arc::new(transport)); + } + if let Some(version) = self.version { + client_config.version(version); + } + + let socket = UdpSocket::bind(addr).await?; + Endpoint::new(socket, self.endpoint_config, None, Some(client_config)) + } +} + +impl ClientBuilder> { + /// Create an [`Endpoint`] binding to the addr provided. + pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { + self.map_inner(|verifier| { + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(verifier) + .with_no_client_auth() + }) + .bind(addr) + .await + } +} + +impl ClientBuilder { + /// Use [`rustls_native_certs`]. + #[cfg(feature = "native-certs")] + pub fn with_native_certs(mut self) -> io::Result { + self.inner + .add_parsable_certificates(rustls_native_certs::load_native_certs()?); + Ok(self) + } + + /// Use [`webpki_roots`]. + #[cfg(feature = "webpki-roots")] + pub fn with_webpki_roots(mut self) -> Self { + self.inner + .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + self + } + + /// Add a custom certificate. + pub fn with_custom_certificate( + mut self, + der: rustls::pki_types::CertificateDer, + ) -> Result { + self.inner.add(der)?; + Ok(self) + } + + /// Verify the revocation state of presented client certificates against the + /// provided certificate revocation lists (CRLs). + pub fn with_crls( + self, + crls: impl IntoIterator>, + ) -> Result< + ClientBuilder>, + rustls::client::VerifierBuilderError, + > { + self.map_inner(|roots| { + rustls::client::WebPkiServerVerifier::builder(Arc::new(roots)) + .with_crls(crls) + .build() + .map(|v| v as _) + }) + .into() + } + + /// Create an [`Endpoint`] binding to the addr provided. + pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { + self.map_inner(|roots| { + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_root_certificates(roots) + .with_no_client_auth() + }) + .bind(addr) + .await + } +} + +/// A [builder] for [`Endpoint`] in server mode. +/// +/// To get one, call [`Endpoint::server()`] or [`ServerBuilder::default()`]. +/// +/// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html +#[derive(Debug)] +pub struct ServerBuilder { + inner: T, + + alpn_protocols: Vec>, + key_log: bool, + enable_early_data: bool, + + transport: Option, + retry_token_lifetime: Option, + migration: bool, + preferred_address_v4: Option, + preferred_address_v6: Option, + max_incoming: Option, + incoming_buffer_size: Option, + incoming_buffer_size_total: Option, + + endpoint_config: EndpointConfig, +} + +impl Default for ServerBuilder<()> { + fn default() -> Self { + Self { + inner: (), + alpn_protocols: Vec::new(), + key_log: false, + enable_early_data: true, + transport: None, + retry_token_lifetime: None, + migration: true, + preferred_address_v4: None, + preferred_address_v6: None, + max_incoming: None, + incoming_buffer_size: None, + incoming_buffer_size_total: None, + endpoint_config: EndpointConfig::default(), + } + } +} + +impl ServerBuilder { + fn map_inner(self, f: impl FnOnce(T) -> S) -> ServerBuilder { + ServerBuilder { + inner: f(self.inner), + alpn_protocols: self.alpn_protocols, + key_log: self.key_log, + enable_early_data: self.enable_early_data, + transport: self.transport, + retry_token_lifetime: self.retry_token_lifetime, + migration: self.migration, + preferred_address_v4: self.preferred_address_v4, + preferred_address_v6: self.preferred_address_v6, + max_incoming: self.max_incoming, + incoming_buffer_size: self.incoming_buffer_size, + incoming_buffer_size_total: self.incoming_buffer_size_total, + endpoint_config: self.endpoint_config, + } + } + + /// Set the ALPN protocols to use. + pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { + self.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); + self + } + + /// Logging key material to a file for debugging. The file's name is given + /// by the `SSLKEYLOGFILE` environment variable. + /// + /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot + /// be written, this does nothing. + pub fn with_key_log(mut self) -> Self { + self.key_log = true; + self + } + + /// Set a custom [`TransportConfig`]. + pub fn with_transport_config(mut self, transport: TransportConfig) -> Self { + self.transport = Some(transport); + self + } + + /// Duration after a stateless retry token was issued for which it's + /// considered valid. + pub fn with_retry_token_lifetime(mut self, retry_token_lifetime: Duration) -> Self { + self.retry_token_lifetime = Some(retry_token_lifetime); + self + } + + /// Whether to allow clients to migrate to new addresses. + /// + /// See [`quinn_proto::ServerConfig::migration`]. + pub fn with_migration(mut self, migration: bool) -> Self { + self.migration = migration; + self + } + + /// The preferred IPv4 address during handshaking. + /// + /// See [`quinn_proto::ServerConfig::preferred_address_v4`]. + pub fn with_preferred_address_v4(mut self, addr: SocketAddrV4) -> Self { + self.preferred_address_v4 = Some(addr); + self + } + + /// The preferred IPv6 address during handshaking. + /// + /// See [`quinn_proto::ServerConfig::preferred_address_v6`]. + pub fn with_preferred_address_v6(mut self, addr: SocketAddrV6) -> Self { + self.preferred_address_v6 = Some(addr); + self + } + + /// Maximum number of [`Incoming`][crate::Incoming] to allow to exist at a + /// time. + /// + /// See [`quinn_proto::ServerConfig::max_incoming`]. + pub fn with_max_incoming(mut self, max_incoming: usize) -> Self { + self.max_incoming = Some(max_incoming); + self + } + + /// Maximum number of received bytes to buffer for each + /// [`Incoming`][crate::Incoming]. + /// + /// See [`quinn_proto::ServerConfig::incoming_buffer_size`]. + pub fn with_incoming_buffer_size(mut self, incoming_buffer_size: u64) -> Self { + self.incoming_buffer_size = Some(incoming_buffer_size); + self + } + + /// Maximum number of received bytes to buffer for all + /// [`Incoming`][crate::Incoming] collectively. + /// + /// See [`quinn_proto::ServerConfig::incoming_buffer_size_total`]. + pub fn with_incoming_buffer_size_total(mut self, incoming_buffer_size_total: u64) -> Self { + self.incoming_buffer_size_total = Some(incoming_buffer_size_total); + self + } + + /// Use the provided [`EndpointConfig`]. + pub fn with_endpoint_config(mut self, endpoint_config: EndpointConfig) -> Self { + self.endpoint_config = endpoint_config; + self + } +} + +impl ServerBuilder<()> { + /// Use the provided [`rustls::ServerConfig`]. + pub fn with_rustls_server_config( + self, + server_config: rustls::ServerConfig, + ) -> ServerBuilder { + self.map_inner(|_| server_config) + } + + /// Sets a single certificate chain and matching private key. + pub fn with_single_cert( + self, + cert_chain: Vec>, + key_der: rustls::pki_types::PrivateKeyDer<'static>, + ) -> Result, rustls::Error> { + let server_config = + rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert(cert_chain, key_der)?; + Ok(self.with_rustls_server_config(server_config)) + } +} + +impl ServerBuilder { + /// Create an [`Endpoint`] binding to the addr provided. + pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { + let mut server_config = self.inner; + + server_config.alpn_protocols = self.alpn_protocols; + if self.key_log { + server_config.key_log = Arc::new(rustls::KeyLogFile::new()); + } + if self.enable_early_data { + server_config.max_early_data_size = u32::MAX; + } + + let mut server_config = ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(server_config) + .expect("should support TLS13_AES_128_GCM_SHA256"), + )); + + if let Some(transport) = self.transport { + server_config.transport_config(Arc::new(transport)); + } + if let Some(retry_token_lifetime) = self.retry_token_lifetime { + server_config.retry_token_lifetime(retry_token_lifetime); + } + server_config + .migration(self.migration) + .preferred_address_v4(self.preferred_address_v4) + .preferred_address_v6(self.preferred_address_v6); + if let Some(max_incoming) = self.max_incoming { + server_config.max_incoming(max_incoming); + } + if let Some(incoming_buffer_size) = self.incoming_buffer_size { + server_config.incoming_buffer_size(incoming_buffer_size); + } + if let Some(incoming_buffer_size_total) = self.incoming_buffer_size_total { + server_config.incoming_buffer_size_total(incoming_buffer_size_total); + } + + let socket = UdpSocket::bind(addr).await?; + Endpoint::new(socket, self.endpoint_config, Some(server_config), None) + } +} + +mod verifier { + use rustls::{ + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + crypto::WebPkiSupportedAlgorithms, + pki_types::{CertificateDer, ServerName, UnixTime}, + DigitallySignedStruct, Error, SignatureScheme, + }; + + #[derive(Debug)] + pub struct SkipServerVerification(WebPkiSupportedAlgorithms); + + impl SkipServerVerification { + pub fn new() -> Self { + Self( + rustls::crypto::CryptoProvider::get_default() + .unwrap() + .signature_verification_algorithms, + ) + } + } + + impl ServerCertVerifier for SkipServerVerification { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0) + } + + fn supported_verify_schemes(&self) -> Vec { + self.0.supported_schemes() + } + } +} diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs new file mode 100644 index 00000000..15874116 --- /dev/null +++ b/compio-quic/src/connection.rs @@ -0,0 +1,465 @@ +use std::{ + io, + net::{IpAddr, SocketAddr}, + pin::{pin, Pin}, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, + time::{Duration, Instant}, +}; + +use compio_buf::BufResult; +use compio_runtime::JoinHandle; +use flume::{Receiver, Sender}; +use futures_util::{ + future::{poll_fn, Fuse, FusedFuture, Future, LocalBoxFuture}, + select, FutureExt, +}; +use quinn_proto::{ + congestion::Controller, crypto::rustls::HandshakeData, ConnectionError, ConnectionHandle, + ConnectionStats, EndpointEvent, VarInt, +}; + +use crate::Socket; + +#[derive(Debug)] +pub(crate) enum ConnectionEvent { + Close(VarInt, String), + Proto(quinn_proto::ConnectionEvent), +} + +#[derive(Debug)] +struct ConnectionState { + conn: quinn_proto::Connection, + connected: bool, + error: Option, + worker: Option>, + on_connected: Option, + on_handshake_data: Option, +} + +impl ConnectionState { + fn terminate(&mut self, reason: ConnectionError) { + self.error = Some(reason); + self.connected = false; + + if let Some(waker) = self.on_connected.take() { + waker.wake() + } + if let Some(waker) = self.on_handshake_data.take() { + waker.wake() + } + } + + #[inline] + fn try_map(&self, f: impl Fn(&Self) -> Option) -> Option> { + if let Some(error) = &self.error { + Some(Err(error.clone())) + } else { + f(self).map(Ok) + } + } + + #[inline] + fn try_handshake_data(&self) -> Option, ConnectionError>> { + self.try_map(|state| { + state + .conn + .crypto_session() + .handshake_data() + .map(|data| data.downcast::().unwrap()) + }) + } +} + +#[derive(Debug)] +struct ConnectionInner { + state: Mutex, + handle: ConnectionHandle, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver, +} + +impl ConnectionInner { + fn new( + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver, + ) -> Self { + Self { + state: Mutex::new(ConnectionState { + conn, + connected: false, + error: None, + worker: None, + on_connected: None, + on_handshake_data: None, + }), + handle, + socket, + events_tx, + events_rx, + } + } + + fn close(&self, error_code: VarInt, reason: String) { + let mut state = self.state.lock().unwrap(); + state.conn.close(Instant::now(), error_code, reason.into()); + state.terminate(ConnectionError::LocallyClosed); + } + + async fn run(&self) -> io::Result<()> { + let mut send_buf = Some(Vec::with_capacity( + self.state.lock().unwrap().conn.current_mtu() as usize, + )); + let mut transmit_fut = pin!(Fuse::terminated()); + + let mut timer = Timer::new(); + + loop { + { + let now = Instant::now(); + let mut state = self.state.lock().unwrap(); + + if let Some(mut buf) = send_buf.take() { + if let Some(transmit) = + state + .conn + .poll_transmit(now, self.socket.max_gso_segments(), &mut buf) + { + transmit_fut + .set(async move { self.socket.send(buf, &transmit).await }.fuse()) + } else { + send_buf = Some(buf); + } + } + + timer.reset(state.conn.poll_timeout()); + + while let Some(event) = state.conn.poll_endpoint_events() { + let _ = self.events_tx.send((self.handle, event)); + } + + while let Some(event) = state.conn.poll() { + use quinn_proto::Event::*; + match event { + HandshakeDataReady => { + if let Some(waker) = state.on_handshake_data.take() { + waker.wake() + } + } + Connected => { + state.connected = true; + if let Some(waker) = state.on_connected.take() { + waker.wake() + } + } + ConnectionLost { reason } => state.terminate(reason), + _ => {} + } + } + + if state.conn.is_drained() { + break Ok(()); + } + } + + select! { + _ = timer => { + self.state.lock().unwrap().conn.handle_timeout(Instant::now()); + timer.reset(None); + }, + ev = self.events_rx.recv_async() => match ev { + Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason), + Ok(ConnectionEvent::Proto(ev)) => self.state.lock().unwrap().conn.handle_event(ev), + Err(_) => unreachable!("endpoint dropped connection"), + }, + BufResult(res, mut buf) = transmit_fut => match res { + Ok(()) => { + buf.clear(); + send_buf = Some(buf); + }, + Err(e) => break Err(e), + }, + } + } + } +} + +macro_rules! conn_fn { + () => { + /// The local IP address which was used when the peer established + /// the connection. + /// + /// This can be different from the address the endpoint is bound to, in case + /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`. + /// + /// This will return `None` for clients, or when the platform does not + /// expose this information. + pub fn local_ip(&self) -> Option { + self.0.state.lock().unwrap().conn.local_ip() + } + + /// The peer's UDP address. + /// + /// Will panic if called after `poll` has returned `Ready`. + pub fn remote_address(&self) -> SocketAddr { + self.0.state.lock().unwrap().conn.remote_address() + } + + /// Current best estimate of this connection's latency (round-trip-time). + pub fn rtt(&self) -> Duration { + self.0.state.lock().unwrap().conn.rtt() + } + + /// Connection statistics. + pub fn stats(&self) -> ConnectionStats { + self.0.state.lock().unwrap().conn.stats() + } + + /// Current state of the congestion control algorithm. (For debugging + /// purposes) + pub fn congestion_state(&self) -> Box { + self.0 + .state + .lock() + .unwrap() + .conn + .congestion_state() + .clone_box() + } + + /// Cryptographic identity of the peer. + pub fn peer_identity( + &self, + ) -> Option>>> { + self.0 + .state + .lock() + .unwrap() + .conn + .crypto_session() + .peer_identity() + .map(|v| v.downcast().unwrap()) + } + }; +} + +/// In-progress connection attempt future +#[derive(Debug)] +#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] +pub struct Connecting(Arc); + +impl Connecting { + conn_fn!(); + + pub(crate) fn new( + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + events_rx: Receiver, + ) -> Self { + let inner = Arc::new(ConnectionInner::new( + handle, conn, socket, events_tx, events_rx, + )); + let worker = compio_runtime::spawn({ + let inner = inner.clone(); + async move { inner.run().await.unwrap() } + }); + inner.state.lock().unwrap().worker = Some(worker); + Self(inner) + } + + /// Parameters negotiated during the handshake. + pub async fn handshake_data(&mut self) -> Result, ConnectionError> { + poll_fn(|cx| { + let mut state = self.0.state.lock().unwrap(); + if let Some(res) = state.try_handshake_data() { + return Poll::Ready(res); + } + + match &state.on_handshake_data { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.on_handshake_data = Some(cx.waker().clone()), + } + + if let Some(res) = state.try_handshake_data() { + Poll::Ready(res) + } else { + Poll::Pending + } + }) + .await + } +} + +impl Future for Connecting { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut state = self.0.state.lock().unwrap(); + + if let Some(res) = + state.try_map(|state| state.connected.then(|| Connection(self.0.clone()))) + { + return Poll::Ready(res); + } + + match &state.on_connected { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.on_connected = Some(cx.waker().clone()), + } + + if let Some(res) = + state.try_map(|state| state.connected.then(|| Connection(self.0.clone()))) + { + Poll::Ready(res) + } else { + Poll::Pending + } + } +} + +impl Drop for Connecting { + fn drop(&mut self) { + if Arc::strong_count(&self.0) == 2 { + self.0.close(0u32.into(), String::new()) + } + } +} + +/// A QUIC connection. +#[derive(Debug)] +pub struct Connection(Arc); + +impl Connection { + conn_fn!(); + + /// Parameters negotiated during the handshake. + pub fn handshake_data(&mut self) -> Result, ConnectionError> { + self.0.state.lock().unwrap().try_handshake_data().unwrap() + } + + /// Compute the maximum size of datagrams that may be passed to + /// [`send_datagram()`](Self::send_datagram). + /// + /// Returns `None` if datagrams are unsupported by the peer or disabled + /// locally. + /// + /// This may change over the lifetime of a connection according to variation + /// in the path MTU estimate. The peer can also enforce an arbitrarily small + /// fixed limit, but if the peer's limit is large this is guaranteed to be a + /// little over a kilobyte at minimum. + /// + /// Not necessarily the maximum size of received datagrams. + pub fn max_datagram_size(&self) -> Option { + self.0.state.lock().unwrap().conn.datagrams().max_size() + } + + /// Bytes available in the outgoing datagram buffer. + /// + /// When greater than zero, calling [`send_datagram()`](Self::send_datagram) + /// with a datagram of at most this size is guaranteed not to cause older + /// datagrams to be dropped. + pub fn datagram_send_buffer_space(&self) -> usize { + self.0 + .state + .lock() + .unwrap() + .conn + .datagrams() + .send_buffer_space() + } + + /// Close the connection immediately. + /// + /// Pending operations will fail immediately with + /// [`ConnectionError::LocallyClosed`]. Delivery of data on unfinished + /// streams is not guaranteed, so the application must call this only when + /// all important communications have been completed, e.g. by calling + /// [`finish`] on outstanding [`SendStream`]s and waiting for the resulting + /// futures to complete. + /// + /// `error_code` and `reason` are not interpreted, and are provided directly + /// to the peer. + /// + /// `reason` will be truncated to fit in a single packet with overhead; to + /// improve odds that it is preserved in full, it should be kept under 1KiB. + /// + /// [`ConnectionError::LocallyClosed`]: quinn_proto::ConnectionError::LocallyClosed + /// [`finish`]: crate::SendStream::finish + /// [`SendStream`]: crate::SendStream + pub fn close(&self, error_code: VarInt, reason: &str) { + self.0.close(error_code, reason.to_string()); + } + + /// Wait for the connection to be closed for any reason + pub async fn closed(&self) -> ConnectionError { + let worker = self.0.state.lock().unwrap().worker.take(); + if let Some(worker) = worker { + let _ = worker.await; + } + + self.0.state.lock().unwrap().error.clone().unwrap() + } +} + +impl PartialEq for Connection { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl Eq for Connection {} + +impl Drop for Connection { + fn drop(&mut self) { + if Arc::strong_count(&self.0) == 2 { + self.close(0u32.into(), "") + } + } +} + +struct Timer { + deadline: Option, + fut: Fuse>, +} + +impl Timer { + fn new() -> Self { + Self { + deadline: None, + fut: Fuse::terminated(), + } + } + + fn reset(&mut self, deadline: Option) { + if let Some(deadline) = deadline { + if self.deadline.is_none() || self.deadline != Some(deadline) { + self.fut = compio_runtime::time::sleep_until(deadline) + .boxed_local() + .fuse(); + } + } else { + self.fut = Fuse::terminated(); + } + self.deadline = deadline; + } +} + +impl Future for Timer { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.fut.poll_unpin(cx) + } +} + +impl FusedFuture for Timer { + fn is_terminated(&self) -> bool { + self.fut.is_terminated() + } +} diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs new file mode 100644 index 00000000..958bc6d0 --- /dev/null +++ b/compio-quic/src/endpoint.rs @@ -0,0 +1,449 @@ +use std::{ + collections::{HashMap, VecDeque}, + io, + mem::ManuallyDrop, + net::{SocketAddr, SocketAddrV6}, + pin::pin, + sync::{Arc, Mutex}, + task::Poll, + time::Instant, +}; + +use compio_buf::BufResult; +use compio_net::UdpSocket; +use compio_runtime::JoinHandle; +use event_listener::{Event, IntoNotification}; +use flume::{unbounded, Receiver, Sender}; +use futures_util::{ + future::{poll_fn, Fuse, FusedFuture}, + select, + task::AtomicWaker, + FutureExt, +}; +use quinn_proto::{ + ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig, + EndpointEvent, ServerConfig, Transmit, VarInt, +}; + +use crate::{ + ClientBuilder, Connecting, ConnectionEvent, Incoming, RecvMeta, ServerBuilder, Socket, +}; + +#[derive(Debug)] +struct EndpointState { + endpoint: quinn_proto::Endpoint, + worker: Option>, + connections: HashMap>, + close: Option<(VarInt, String)>, + resp: VecDeque<(Vec, Transmit)>, + incoming: VecDeque, +} + +impl EndpointState { + fn handle_data(&mut self, meta: RecvMeta, buf: &[u8]) { + let now = Instant::now(); + for data in buf[..meta.len] + .chunks(meta.stride.min(meta.len)) + .map(Into::into) + { + let mut resp_buf = Vec::new(); + match self.endpoint.handle( + now, + meta.remote, + meta.local_ip, + meta.ecn, + data, + &mut resp_buf, + ) { + Some(DatagramEvent::NewConnection(incoming)) => { + if self.close.is_none() { + self.incoming.push_back(incoming); + } else { + let transmit = self.endpoint.refuse(incoming, &mut resp_buf); + self.resp.push_back((resp_buf, transmit)); + } + } + Some(DatagramEvent::ConnectionEvent(ch, event)) => { + let _ = self + .connections + .get(&ch) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + Some(DatagramEvent::Response(transmit)) => { + self.resp.push_back((resp_buf, transmit)) + } + None => {} + } + } + } + + fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) { + if event.is_drained() { + self.connections.remove(&ch); + } + if let Some(event) = self.endpoint.handle_event(ch, event) { + let _ = self + .connections + .get(&ch) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + } + + fn is_idle(&self) -> bool { + self.connections.is_empty() && self.resp.is_empty() + } + + fn try_get_incoming(&mut self) -> Option> { + if self.close.is_none() { + self.incoming.pop_front().map(Some) + } else { + Some(None) + } + } + + fn new_connection( + &mut self, + handle: ConnectionHandle, + conn: quinn_proto::Connection, + socket: Socket, + events_tx: Sender<(ConnectionHandle, EndpointEvent)>, + ) -> Connecting { + let (tx, rx) = unbounded(); + if let Some((error_code, reason)) = &self.close { + tx.send(ConnectionEvent::Close(*error_code, reason.clone())) + .unwrap(); + } + self.connections.insert(handle, tx); + Connecting::new(handle, conn, socket, events_tx, rx) + } +} + +type ChannelPair = (Sender, Receiver); + +#[derive(Debug)] +pub(crate) struct EndpointInner { + state: Mutex, + socket: Socket, + ipv6: bool, + events: ChannelPair<(ConnectionHandle, EndpointEvent)>, + done: AtomicWaker, + incoming: Event, +} + +impl EndpointInner { + fn new( + socket: UdpSocket, + config: EndpointConfig, + server_config: Option, + ) -> io::Result { + let socket = Socket::new(socket)?; + let ipv6 = socket.local_addr()?.is_ipv6(); + let allow_mtud = !socket.may_fragment(); + + Ok(Self { + state: Mutex::new(EndpointState { + endpoint: quinn_proto::Endpoint::new( + Arc::new(config), + server_config.map(Arc::new), + allow_mtud, + None, + ), + worker: None, + connections: HashMap::new(), + close: None, + resp: VecDeque::new(), + incoming: VecDeque::new(), + }), + socket, + ipv6, + events: unbounded(), + done: AtomicWaker::new(), + incoming: Event::new(), + }) + } + + fn connect( + &self, + remote: SocketAddr, + server_name: &str, + config: ClientConfig, + ) -> Result { + let mut state = self.state.lock().unwrap(); + + if state.worker.is_none() { + return Err(ConnectError::EndpointStopping); + } + if remote.is_ipv6() && !self.ipv6 { + return Err(ConnectError::InvalidRemoteAddress(remote)); + } + let remote = if self.ipv6 { + SocketAddr::V6(match remote { + SocketAddr::V4(addr) => { + SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0) + } + SocketAddr::V6(addr) => addr, + }) + } else { + remote + }; + + let (handle, conn) = state + .endpoint + .connect(Instant::now(), config, remote, server_name)?; + + Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone())) + } + + pub(crate) fn accept( + &self, + incoming: quinn_proto::Incoming, + server_config: Option, + ) -> Result { + let mut state = self.state.lock().unwrap(); + let mut resp_buf = Vec::new(); + let now = Instant::now(); + match state + .endpoint + .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new)) + { + Ok((handle, conn)) => { + Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone())) + } + Err(err) => { + if let Some(transmit) = err.response { + state.resp.push_back((resp_buf, transmit)); + } + Err(err.cause) + } + } + } + + pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) { + let mut state = self.state.lock().unwrap(); + let mut resp_buf = Vec::new(); + let transmit = state.endpoint.refuse(incoming, &mut resp_buf); + state.resp.push_back((resp_buf, transmit)); + } + + pub(crate) fn retry( + &self, + incoming: quinn_proto::Incoming, + ) -> Result<(), quinn_proto::RetryError> { + let mut state = self.state.lock().unwrap(); + let mut resp_buf = Vec::new(); + let transmit = state.endpoint.retry(incoming, &mut resp_buf)?; + state.resp.push_back((resp_buf, transmit)); + Ok(()) + } + + pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) { + let mut state = self.state.lock().unwrap(); + state.endpoint.ignore(incoming); + } + + async fn run(&self) -> io::Result<()> { + let mut recv_fut = pin!( + self.socket + .recv(Vec::with_capacity( + self.state + .lock() + .unwrap() + .endpoint + .config() + .get_max_udp_payload_size() + .min(64 * 1024) as usize + * self.socket.max_gro_segments(), + )) + .fuse() + ); + + let mut resp_fut = pin!(Fuse::terminated()); + + loop { + select! { + BufResult(res, recv_buf) = recv_fut => { + match res { + Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf), + Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} + Err(e) => break Err(e), + } + recv_fut.set(self.socket.recv(recv_buf).fuse()); + }, + (ch, event) = self.events.1.recv_async().map(Result::unwrap) => { + self.state.lock().unwrap().handle_event(ch, event); + }, + _ = resp_fut => {}, + } + + let mut state = self.state.lock().unwrap(); + if resp_fut.is_terminated() && !state.resp.is_empty() { + let (data, transmit) = state.resp.pop_front().unwrap(); + resp_fut.set(async move { self.socket.send(data, &transmit).await }.fuse()); + } + + if state.close.is_some() && state.is_idle() { + break Ok(()); + } + + if !state.incoming.is_empty() { + self.incoming.notify(state.incoming.len().additional()); + } + } + } +} + +/// A QUIC endpoint. +#[derive(Debug, Clone)] +pub struct Endpoint { + inner: Arc, + /// The client configuration used by `connect` + pub default_client_config: Option, +} + +impl Endpoint { + /// Create a QUIC endpoint. + pub fn new( + socket: UdpSocket, + config: EndpointConfig, + server_config: Option, + default_client_config: Option, + ) -> io::Result { + let inner = Arc::new(EndpointInner::new(socket, config, server_config)?); + let worker = compio_runtime::spawn({ + let inner = inner.clone(); + async move { inner.run().await.unwrap() } + }); + inner.state.lock().unwrap().worker = Some(worker); + Ok(Self { + inner, + default_client_config, + }) + } + + /// Create a builder for a QUIC client. + pub fn client() -> ClientBuilder<()> { + ClientBuilder::default() + } + + /// Create a builder for a QUIC server. + pub fn server() -> ServerBuilder<()> { + ServerBuilder::default() + } + + /// Connect to a remote endpoint. + pub fn connect( + &self, + remote: SocketAddr, + server_name: &str, + config: Option, + ) -> Result { + let config = if let Some(config) = config { + config + } else if let Some(config) = &self.default_client_config { + config.clone() + } else { + return Err(ConnectError::NoDefaultClientConfig); + }; + + self.inner.connect(remote, server_name, config) + } + + /// Wait for the next incoming connection attempt from a client. + /// + /// Yields [`Incoming`]s, or `None` if the endpoint is + /// [`close`](Self::close)d. [`Incoming`] can be `await`ed to obtain the + /// final [`Connection`](crate::Connection), or used to e.g. filter + /// connection attempts or force address validation, or converted into an + /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT + /// data. + pub async fn wait_incoming(&self) -> Option { + loop { + if let Some(incoming) = self.inner.state.lock().unwrap().try_get_incoming() { + return incoming.map(|incoming| Incoming::new(incoming, self.inner.clone())); + } + + let listener = self.inner.incoming.listen(); + + if let Some(incoming) = self.inner.state.lock().unwrap().try_get_incoming() { + return incoming.map(|incoming| Incoming::new(incoming, self.inner.clone())); + } + + listener.await; + } + } + + // Modified from [`SharedFd::try_unwrap_inner`], see notes there. + unsafe fn try_unwrap_inner(this: &ManuallyDrop) -> Option { + let ptr = ManuallyDrop::new(std::ptr::read(&this.inner)); + match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) { + Ok(inner) => Some(inner), + Err(ptr) => { + std::mem::forget(ptr); + None + } + } + } + + /// Shutdown the endpoint and close the underlying socket. + /// + /// This will close all connections and the underlying socket. Note that it + /// will wait for all connections and all clones of the endpoint (and any + /// clone of the underlying socket) to be dropped before closing the socket. + /// + /// If the endpoint has already been closed or is closing, this will return + /// immediately with `Ok(())`. + /// + /// See [`Connection::close()`](crate::Connection::close) for details. + pub async fn close(self, error_code: VarInt, reason: &str) -> io::Result<()> { + let reason = reason.to_string(); + + { + let close = &mut self.inner.state.lock().unwrap().close; + if close.is_some() { + return Ok(()); + } + close.replace((error_code, reason.clone())); + } + + for conn in self.inner.state.lock().unwrap().connections.values() { + let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone())); + } + + let worker = self.inner.state.lock().unwrap().worker.take(); + if let Some(worker) = worker { + if self.inner.state.lock().unwrap().is_idle() { + worker.cancel().await; + } else { + let _ = worker.await; + } + } + + let this = ManuallyDrop::new(self); + let inner = poll_fn(move |cx| { + if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } { + return Poll::Ready(inner); + } + + this.inner.done.register(cx.waker()); + + if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } { + Poll::Ready(inner) + } else { + Poll::Pending + } + }) + .await; + + inner.socket.close().await + } +} + +impl Drop for Endpoint { + fn drop(&mut self) { + if Arc::strong_count(&self.inner) == 2 { + self.inner.done.wake(); + } + } +} diff --git a/compio-quic/src/incoming.rs b/compio-quic/src/incoming.rs new file mode 100644 index 00000000..0af41b62 --- /dev/null +++ b/compio-quic/src/incoming.rs @@ -0,0 +1,141 @@ +use std::{ + future::{Future, IntoFuture}, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures_util::FutureExt; +use quinn_proto::{ConnectionError, ServerConfig}; +use thiserror::Error; + +use crate::{Connecting, Connection, EndpointInner}; + +#[derive(Debug)] +pub(crate) struct IncomingInner { + pub(crate) incoming: quinn_proto::Incoming, + pub(crate) endpoint: Arc, +} + +/// An incoming connection for which the server has not yet begun its part +/// of the handshake. +#[derive(Debug)] +pub struct Incoming(Option); + +impl Incoming { + pub(crate) fn new(incoming: quinn_proto::Incoming, endpoint: Arc) -> Self { + Self(Some(IncomingInner { incoming, endpoint })) + } + + /// Attempt to accept this incoming connection (an error may still + /// occur). + pub fn accept(mut self) -> Result { + let inner = self.0.take().unwrap(); + inner.endpoint.accept(inner.incoming, None) + } + + /// Accept this incoming connection using a custom configuration. + /// + /// See [`accept()`] for more details. + /// + /// [`accept()`]: Incoming::accept + pub fn accept_with( + mut self, + server_config: ServerConfig, + ) -> Result { + let inner = self.0.take().unwrap(); + inner.endpoint.accept(inner.incoming, Some(server_config)) + } + + /// Reject this incoming connection attempt. + pub fn refuse(mut self) { + let inner = self.0.take().unwrap(); + inner.endpoint.refuse(inner.incoming); + } + + /// Respond with a retry packet, requiring the client to retry with + /// address validation. + /// + /// Errors if `remote_address_validated()` is true. + pub fn retry(mut self) -> Result<(), RetryError> { + let inner = self.0.take().unwrap(); + inner + .endpoint + .retry(inner.incoming) + .map_err(|e| RetryError(Self::new(e.into_incoming(), inner.endpoint))) + } + + /// Ignore this incoming connection attempt, not sending any packet in + /// response. + pub fn ignore(mut self) { + let inner = self.0.take().unwrap(); + inner.endpoint.ignore(inner.incoming); + } + + /// The local IP address which was used when the peer established + /// the connection. + pub fn local_ip(&self) -> Option { + self.0.as_ref().unwrap().incoming.local_ip() + } + + /// The peer's UDP address. + pub fn remote_address(&self) -> SocketAddr { + self.0.as_ref().unwrap().incoming.remote_address() + } + + /// Whether the socket address that is initiating this connection has + /// been validated. + /// + /// This means that the sender of the initial packet has proved that + /// they can receive traffic sent to `self.remote_address()`. + pub fn remote_address_validated(&self) -> bool { + self.0.as_ref().unwrap().incoming.remote_address_validated() + } +} + +impl Drop for Incoming { + fn drop(&mut self) { + // Implicit reject, similar to Connection's implicit close + if let Some(inner) = self.0.take() { + inner.endpoint.refuse(inner.incoming); + } + } +} + +/// Error for attempting to retry an [`Incoming`] which already bears an +/// address validation token from a previous retry. +#[derive(Debug, Error)] +#[error("retry() with validated Incoming")] +pub struct RetryError(Incoming); + +impl RetryError { + /// Get the [`Incoming`] + pub fn into_incoming(self) -> Incoming { + self.0 + } +} + +/// Basic adapter to let [`Incoming`] be `await`-ed like a [`Connecting`]. +#[derive(Debug)] +pub struct IncomingFuture(Result); + +impl Future for IncomingFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match &mut self.0 { + Ok(connecting) => connecting.poll_unpin(cx), + Err(e) => Poll::Ready(Err(e.clone())), + } + } +} + +impl IntoFuture for Incoming { + type IntoFuture = IncomingFuture; + type Output = Result; + + fn into_future(self) -> Self::IntoFuture { + IncomingFuture(self.accept()) + } +} diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs new file mode 100644 index 00000000..7d23f7d2 --- /dev/null +++ b/compio-quic/src/lib.rs @@ -0,0 +1,27 @@ +//! QUIC implementation for compio +//! +//! Ported from [`quinn`]. +//! +//! [`quinn`]: https://docs.rs/quinn + +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![warn(missing_docs)] + +pub use quinn_proto::{ + congestion, crypto, AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, + ConfigError, ConnectError, ConnectionClose, ConnectionError, ConnectionStats, EndpointConfig, + IdleTimeout, MtuDiscoveryConfig, ServerConfig, StreamId, Transmit, TransportConfig, VarInt, +}; + +mod builder; +mod connection; +mod endpoint; +mod incoming; +mod socket; + +pub use builder::*; +pub(crate) use connection::ConnectionEvent; +pub use connection::*; +pub use endpoint::*; +pub use incoming::*; +pub(crate) use socket::*; diff --git a/compio-quic/src/socket.rs b/compio-quic/src/socket.rs new file mode 100644 index 00000000..c3f66454 --- /dev/null +++ b/compio-quic/src/socket.rs @@ -0,0 +1,774 @@ +//! Simple wrapper around UDP socket with advanced features useful for QUIC, +//! ported from [`quinn-udp`] +//! +//! Differences from [`quinn-udp`]: +//! - [quinn-rs/quinn#1516] is not implemented +//! - `recvmmsg` is not available +//! +//! [`quinn-udp`]: https://docs.rs/quinn-udp +//! [quinn-rs/quinn#1516]: https://github.com/quinn-rs/quinn/pull/1516 + +use std::{ + future::Future, + io, + net::{IpAddr, SocketAddr}, + ops::{Deref, DerefMut}, + sync::atomic::{AtomicBool, Ordering}, +}; + +use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; +use compio_net::{CMsgBuilder, CMsgIter, UdpSocket}; +use quinn_proto::{EcnCodepoint, Transmit}; +#[cfg(windows)] +use windows_sys::Win32::Networking::WinSock; + +/// Metadata for a single buffer filled with bytes received from the network +/// +/// This associated buffer can contain one or more datagrams, see [`stride`]. +/// +/// [`stride`]: RecvMeta::stride +#[derive(Debug)] +pub(crate) struct RecvMeta { + /// The source address of the datagram(s) contained in the buffer + pub remote: SocketAddr, + /// The number of bytes the associated buffer has + pub len: usize, + /// The size of a single datagram in the associated buffer + /// + /// When GRO (Generic Receive Offload) is used this indicates the size of a + /// single datagram inside the buffer. If the buffer is larger, that is + /// if [`len`] is greater then this value, then the individual datagrams + /// contained have their boundaries at `stride` increments from the + /// start. The last datagram could be smaller than `stride`. + /// + /// [`len`]: RecvMeta::len + pub stride: usize, + /// The Explicit Congestion Notification bits for the datagram(s) in the + /// buffer + pub ecn: Option, + /// The destination IP address which was encoded in this datagram + /// + /// Populated on platforms: Windows, Linux, Android, FreeBSD, OpenBSD, + /// NetBSD, macOS, and iOS. + pub local_ip: Option, +} + +const CMSG_LEN: usize = 128; + +struct Ancillary { + inner: [u8; N], + len: usize, + #[cfg(unix)] + _align: [libc::cmsghdr; 0], + #[cfg(windows)] + _align: [WinSock::CMSGHDR; 0], +} + +impl Ancillary { + fn new() -> Self { + Self { + inner: [0u8; N], + len: N, + _align: [], + } + } +} + +unsafe impl IoBuf for Ancillary { + fn as_buf_ptr(&self) -> *const u8 { + self.inner.as_buf_ptr() + } + + fn buf_len(&self) -> usize { + self.len + } + + fn buf_capacity(&self) -> usize { + N + } +} + +impl SetBufInit for Ancillary { + unsafe fn set_buf_init(&mut self, len: usize) { + debug_assert!(len <= N); + self.len = len; + } +} + +unsafe impl IoBufMut for Ancillary { + fn as_buf_mut_ptr(&mut self) -> *mut u8 { + self.inner.as_buf_mut_ptr() + } +} + +impl Deref for Ancillary { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.inner[0..self.len] + } +} + +impl DerefMut for Ancillary { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner[0..self.len] + } +} + +#[cfg(target_os = "linux")] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + socket.get_socket_option::(libc::SOL_UDP, libc::UDP_SEGMENT)?; + Ok(64) +} +#[cfg(windows)] +#[inline] +fn max_gso_segments(socket: &UdpSocket) -> io::Result { + socket.get_socket_option::(WinSock::IPPROTO_UDP, WinSock::UDP_SEND_MSG_SIZE)?; + Ok(512) +} +#[cfg(not(any(target_os = "linux", windows)))] +#[inline] +fn max_gso_segments(_socket: &UdpSocket) -> io::Result { + Err(io::Error::from(io::ErrorKind::Unsupported)) +} + +macro_rules! set_socket_option { + ($socket:expr, $level:expr, $name:expr, $value:expr $(,)?) => { + match $socket.set_socket_option($level, $name, $value) { + Ok(()) => true, + Err(e) => { + compio_log::warn!( + level = stringify!($level), + name = stringify!($name), + "failed to set socket option: {}", + e + ); + if e.kind() == io::ErrorKind::InvalidInput { + true + } else if e.raw_os_error() + == Some( + #[cfg(unix)] + libc::ENOPROTOOPT, + #[cfg(windows)] + WinSock::WSAENOPROTOOPT, + ) + { + false + } else { + return Err(e); + } + } + } + }; +} + +#[derive(Debug)] +pub(crate) struct Socket { + inner: UdpSocket, + max_gro_segments: usize, + max_gso_segments: usize, + may_fragment: bool, + has_gso_error: AtomicBool, + #[cfg(target_os = "freebsd")] + encode_src_ip_v4: bool, +} + +impl Socket { + pub fn new(socket: UdpSocket) -> io::Result { + let is_ipv6 = socket.local_addr()?.is_ipv6(); + #[cfg(unix)] + let only_v6 = is_ipv6 + && socket.get_socket_option::(libc::IPPROTO_IPV6, libc::IPV6_V6ONLY)? != 0; + #[cfg(windows)] + let only_v6 = is_ipv6 + && socket.get_socket_option::(WinSock::IPPROTO_IPV6, WinSock::IPV6_V6ONLY)? != 0; + let is_ipv4 = socket.local_addr()?.is_ipv4() || !only_v6; + + // ECN + if is_ipv4 { + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_RECVTOS, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IP, WinSock::IP_ECN, &1); + } + if is_ipv6 { + #[cfg(unix)] + set_socket_option!(socket, libc::IPPROTO_IPV6, libc::IPV6_RECVTCLASS, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, &1); + } + + // pktinfo / destination address + if is_ipv4 { + #[cfg(any(target_os = "linux", target_os = "android"))] + set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_PKTINFO, &1); + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios" + ))] + set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_RECVDSTADDR, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, &1); + } + if is_ipv6 { + #[cfg(unix)] + set_socket_option!(socket, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, &1); + #[cfg(windows)] + set_socket_option!(socket, WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, &1); + } + + // disable fragmentation + let mut may_fragment = false; + if is_ipv4 { + #[cfg(any(target_os = "linux", target_os = "android"))] + { + may_fragment |= set_socket_option!( + socket, + libc::IPPROTO_IP, + libc::IP_MTU_DISCOVER, + &libc::IP_PMTUDISC_PROBE, + ); + } + #[cfg(any( + target_os = "aix", + target_os = "freebsd", + target_os = "macos", + target_os = "ios" + ))] + { + may_fragment |= set_socket_option!(socket, libc::IPPROTO_IP, libc::IP_DONTFRAG, &1); + } + #[cfg(windows)] + { + may_fragment |= + set_socket_option!(socket, WinSock::IPPROTO_IP, WinSock::IP_DONTFRAGMENT, &1); + } + } + if is_ipv6 { + #[cfg(any(target_os = "linux", target_os = "android"))] + { + may_fragment |= set_socket_option!( + socket, + libc::IPPROTO_IPV6, + libc::IPV6_MTU_DISCOVER, + &libc::IPV6_PMTUDISC_PROBE, + ); + } + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + { + may_fragment |= + set_socket_option!(socket, libc::IPPROTO_IPV6, libc::IPV6_DONTFRAG, &1); + } + #[cfg(any(target_os = "openbsd", target_os = "netbsd"))] + { + // FIXME: workaround until https://github.com/rust-lang/libc/pull/3716 is released (at least in 0.2.155) + may_fragment |= set_socket_option!(socket, libc::IPPROTO_IPV6, 62, &1); + } + #[cfg(windows)] + { + may_fragment |= + set_socket_option!(socket, WinSock::IPPROTO_IPV6, WinSock::IPV6_DONTFRAG, &1); + } + } + + // GRO + #[allow(unused_mut)] // only mutable on Linux and Windows + let mut max_gro_segments = 1; + #[cfg(target_os = "linux")] + if set_socket_option!(socket, libc::SOL_UDP, libc::UDP_GRO, &1) { + max_gro_segments = 64; + } + #[cfg(windows)] + if set_socket_option!( + socket, + WinSock::IPPROTO_UDP, + WinSock::UDP_RECV_MAX_COALESCED_SIZE, + &(u16::MAX as u32), + ) { + max_gro_segments = 64; + } + + // GSO + let max_gso_segments = max_gso_segments(&socket).unwrap_or(1); + + #[cfg(target_os = "freebsd")] + let encode_src_ip_v4 = + socket.local_addr().unwrap().ip() == IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); + + Ok(Self { + inner: socket, + max_gro_segments, + max_gso_segments, + may_fragment, + has_gso_error: AtomicBool::new(false), + #[cfg(target_os = "freebsd")] + encode_src_ip_v4, + }) + } + + #[inline] + pub fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + #[inline] + pub fn may_fragment(&self) -> bool { + self.may_fragment + } + + #[inline] + pub fn max_gro_segments(&self) -> usize { + self.max_gro_segments + } + + #[inline] + pub fn max_gso_segments(&self) -> usize { + if self.has_gso_error.load(Ordering::Relaxed) { + 1 + } else { + self.max_gso_segments + } + } + + pub async fn recv(&self, buffer: T) -> BufResult { + let control = Ancillary::::new(); + + let BufResult(res, (buffer, control)) = self.inner.recv_msg(buffer, control).await; + let ((len, _, remote), buffer) = buf_try!(res, buffer); + + let mut ecn_bits = 0u8; + let mut local_ip = None; + #[allow(unused_mut)] // only mutable on Linux + let mut stride = len; + + // SAFETY: `control` contains valid data + unsafe { + for cmsg in CMsgIter::new(&control) { + #[cfg(windows)] + const UDP_COALESCED_INFO: i32 = WinSock::UDP_COALESCED_INFO as i32; + + match (cmsg.level(), cmsg.ty()) { + // ECN + #[cfg(unix)] + (libc::IPPROTO_IP, libc::IP_TOS) => ecn_bits = *cmsg.data::(), + #[cfg(all(unix, not(any(target_os = "openbsd", target_os = "netbsd"))))] + (libc::IPPROTO_IP, libc::IP_RECVTOS) => ecn_bits = *cmsg.data::(), + #[cfg(unix)] + (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => { + // NOTE: It's OK to use `c_int` instead of `u8` on Apple systems + ecn_bits = *cmsg.data::() as u8 + } + #[cfg(windows)] + (WinSock::IPPROTO_IP, WinSock::IP_ECN) + | (WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN) => { + ecn_bits = *cmsg.data::() as u8 + } + + // pktinfo / destination address + #[cfg(any(target_os = "linux", target_os = "android"))] + (libc::IPPROTO_IP, libc::IP_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi_addr.s_addr.to_ne_bytes())); + } + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { + let in_addr = cmsg.data::(); + local_ip = Some(IpAddr::from(in_addr.s_addr.to_ne_bytes())); + } + #[cfg(windows)] + (WinSock::IPPROTO_IP, WinSock::IP_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi_addr.S_un.S_addr.to_ne_bytes())); + } + #[cfg(unix)] + (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.s6_addr)); + } + #[cfg(windows)] + (WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO) => { + let pktinfo = cmsg.data::(); + local_ip = Some(IpAddr::from(pktinfo.ipi6_addr.u.Byte)); + } + + // GRO + #[cfg(target_os = "linux")] + (libc::SOL_UDP, libc::UDP_GRO) => stride = *cmsg.data::() as usize, + #[cfg(windows)] + (WinSock::IPPROTO_UDP, UDP_COALESCED_INFO) => { + stride = *cmsg.data::() as usize + } + + _ => {} + } + } + } + + let meta = RecvMeta { + remote, + len, + stride, + ecn: EcnCodepoint::from_bits(ecn_bits), + local_ip, + }; + BufResult(Ok(meta), buffer) + } + + pub async fn send(&self, buffer: T, transmit: &Transmit) -> BufResult<(), T> { + let is_ipv4 = transmit.destination.ip().to_canonical().is_ipv4(); + let ecn = transmit.ecn.map_or(0, |x| x as u8); + + let mut control = Ancillary::::new(); + let mut builder = CMsgBuilder::new(&mut control); + + // ECN + if is_ipv4 { + #[cfg(all(unix, not(any(target_os = "freebsd", target_os = "netbsd"))))] + builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_int); + #[cfg(target_os = "freebsd")] + builder.try_push(libc::IPPROTO_IP, libc::IP_TOS, ecn as libc::c_uchar); + #[cfg(windows)] + builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_ECN, ecn as i32); + } else { + #[cfg(unix)] + builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn as libc::c_int); + #[cfg(windows)] + builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_ECN, ecn as i32); + } + + // pktinfo / destination address + match transmit.src_ip { + Some(IpAddr::V4(ip)) => { + let addr = u32::from_ne_bytes(ip.octets()); + #[cfg(any(target_os = "linux", target_os = "android"))] + { + let pktinfo = libc::in_pktinfo { + ipi_ifindex: 0, + ipi_spec_dst: libc::in_addr { s_addr: addr }, + ipi_addr: libc::in_addr { s_addr: 0 }, + }; + builder.try_push(libc::IPPROTO_IP, libc::IP_PKTINFO, pktinfo); + } + #[cfg(any( + target_os = "freebsd", + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + { + #[cfg(target_os = "freebsd")] + let encode_src_ip_v4 = self.encode_src_ip_v4; + #[cfg(any( + target_os = "openbsd", + target_os = "netbsd", + target_os = "macos", + target_os = "ios", + ))] + let encode_src_ip_v4 = true; + + if encode_src_ip_v4 { + let addr = libc::in_addr { s_addr: addr }; + builder.try_push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr); + } + } + #[cfg(windows)] + { + let pktinfo = WinSock::IN_PKTINFO { + ipi_addr: WinSock::IN_ADDR { + S_un: WinSock::IN_ADDR_0 { S_addr: addr }, + }, + ipi_ifindex: 0, + }; + builder.try_push(WinSock::IPPROTO_IP, WinSock::IP_PKTINFO, pktinfo); + } + } + Some(IpAddr::V6(ip)) => { + #[cfg(unix)] + { + let pktinfo = libc::in6_pktinfo { + ipi6_ifindex: 0, + ipi6_addr: libc::in6_addr { + s6_addr: ip.octets(), + }, + }; + builder.try_push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo); + } + #[cfg(windows)] + { + let pktinfo = WinSock::IN6_PKTINFO { + ipi6_addr: WinSock::IN6_ADDR { + u: WinSock::IN6_ADDR_0 { Byte: ip.octets() }, + }, + ipi6_ifindex: 0, + }; + builder.try_push(WinSock::IPPROTO_IPV6, WinSock::IPV6_PKTINFO, pktinfo); + } + } + None => {} + } + + // GSO + if let Some(segment_size) = transmit.segment_size { + #[cfg(target_os = "linux")] + builder.try_push(libc::SOL_UDP, libc::UDP_SEGMENT, segment_size as u16); + #[cfg(windows)] + builder.try_push( + WinSock::IPPROTO_UDP, + WinSock::UDP_SEND_MSG_SIZE, + segment_size as u32, + ); + #[cfg(not(any(target_os = "linux", windows)))] + let _ = segment_size; + } + + let len = builder.finish(); + control.len = len; + + let buffer = buffer.slice(0..transmit.size); + let BufResult(res, (buffer, _)) = self + .inner + .send_msg(buffer, control, transmit.destination) + .await; + let buffer = buffer.into_inner(); + match res { + Ok(_) => BufResult(Ok(()), buffer), + Err(e) => { + #[cfg(target_os = "linux")] + if let Some(libc::EIO) | Some(libc::EINVAL) = e.raw_os_error() { + if self.max_gso_segments() > 1 { + self.has_gso_error.store(true, Ordering::Relaxed); + } + } + BufResult(Err(e), buffer) + } + } + } + + pub fn close(self) -> impl Future> { + self.inner.close() + } +} + +impl Clone for Socket { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + may_fragment: self.may_fragment, + max_gro_segments: self.max_gro_segments, + max_gso_segments: self.max_gso_segments, + has_gso_error: AtomicBool::new(self.has_gso_error.load(Ordering::Relaxed)), + #[cfg(target_os = "freebsd")] + encode_src_ip_v4: self.encode_src_ip_v4.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, Ipv6Addr}; + + use compio_driver::AsRawFd; + use socket2::{Domain, Protocol, Socket as Socket2, Type}; + + use super::*; + + async fn test_send_recv( + passive: Socket, + active: Socket, + content: T, + transmit: Transmit, + ) { + let passive_addr = passive.local_addr().unwrap(); + let active_addr = active.local_addr().unwrap(); + + let (_, content) = active.send(content, &transmit).await.unwrap(); + + let segment_size = transmit.segment_size.unwrap_or(transmit.size); + let expected_datagrams = transmit.size / segment_size; + let mut datagrams = 0; + while datagrams < expected_datagrams { + let (meta, buf) = passive + .recv(Vec::with_capacity(u16::MAX as usize)) + .await + .unwrap(); + let segments = meta.len / meta.stride; + for i in 0..segments { + assert_eq!( + &content.as_slice() + [(datagrams + i) * segment_size..(datagrams + i + 1) * segment_size], + &buf[(i * meta.stride)..((i + 1) * meta.stride)] + ); + } + datagrams += segments; + + assert_eq!(meta.ecn, transmit.ecn); + + assert_eq!(meta.remote.port(), active_addr.port()); + for addr in [meta.remote.ip(), meta.local_ip.unwrap()] { + match (active_addr.is_ipv6(), passive_addr.is_ipv6()) { + (_, false) => assert_eq!(addr, Ipv4Addr::LOCALHOST), + (false, true) => assert!( + addr == Ipv4Addr::LOCALHOST || addr == Ipv4Addr::LOCALHOST.to_ipv6_mapped() + ), + (true, true) => assert!( + addr == Ipv6Addr::LOCALHOST || addr == Ipv4Addr::LOCALHOST.to_ipv6_mapped() + ), + } + } + } + assert_eq!(datagrams, expected_datagrams); + } + + /// Helper function to create dualstack udp socket. + /// This is only used for testing. + fn bind_udp_dualstack() -> io::Result { + #[cfg(unix)] + use std::os::fd::{FromRawFd, IntoRawFd}; + #[cfg(windows)] + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let socket = Socket2::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(false)?; + socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0).into())?; + + compio_runtime::Runtime::with_current(|r| r.attach(socket.as_raw_fd()))?; + #[cfg(unix)] + unsafe { + Ok(UdpSocket::from_raw_fd(socket.into_raw_fd())) + } + #[cfg(windows)] + unsafe { + Ok(UdpSocket::from_raw_socket(socket.into_raw_socket())) + } + } + + #[compio_macros::test] + async fn basic() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: None, + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive, active, content, transmit).await; + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_v4() { + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + async fn ecn_v6() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_dualstack() { + let passive = Socket::new(bind_udp_dualstack().unwrap()).unwrap(); + + let mut dst_v4 = passive.local_addr().unwrap(); + dst_v4.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); + let mut dst_v6 = dst_v4; + dst_v6.set_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)); + + for (src, dst) in [("[::1]:0", dst_v6), ("127.0.0.1:0", dst_v4)] { + let active = Socket::new(UdpSocket::bind(src).await.unwrap()).unwrap(); + + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: dst, + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + } + + #[compio_macros::test] + #[cfg_attr(any(target_os = "openbsd", target_os = "netbsd"), ignore)] + async fn ecn_v4_mapped_v6() { + let passive = Socket::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()).unwrap(); + let active = Socket::new(bind_udp_dualstack().unwrap()).unwrap(); + + let mut dst_addr = passive.local_addr().unwrap(); + dst_addr.set_ip(IpAddr::V6(Ipv4Addr::LOCALHOST.to_ipv6_mapped())); + + for ecn in [EcnCodepoint::Ect0, EcnCodepoint::Ect1] { + let content = b"hello"; + let transmit = Transmit { + destination: dst_addr, + ecn: Some(ecn), + size: content.len(), + segment_size: None, + src_ip: None, + }; + test_send_recv(passive.clone(), active.clone(), content, transmit).await; + } + } + + #[compio_macros::test] + #[cfg_attr(not(any(target_os = "linux", target_os = "windows")), ignore)] + async fn gso() { + let passive = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + let active = Socket::new(UdpSocket::bind("[::1]:0").await.unwrap()).unwrap(); + + let max_segments = active.max_gso_segments(); + const SEGMENT_SIZE: usize = 128; + let content = vec![0xAB; SEGMENT_SIZE * max_segments]; + + let transmit = Transmit { + destination: passive.local_addr().unwrap(), + ecn: None, + size: content.len(), + segment_size: Some(SEGMENT_SIZE), + src_ip: None, + }; + test_send_recv(passive, active, content, transmit).await; + } +} diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml index 34bb7d90..ef70c561 100644 --- a/compio-tls/Cargo.toml +++ b/compio-tls/Cargo.toml @@ -19,7 +19,7 @@ compio-buf = { workspace = true } compio-io = { workspace = true, features = ["compat"] } native-tls = { version = "0.2.11", optional = true } -rustls = { version = "0.23.1", default-features = false, optional = true, features = [ +rustls = { workspace = true, default-features = false, optional = true, features = [ "logging", "std", "tls12", @@ -30,7 +30,7 @@ compio-net = { workspace = true } compio-runtime = { workspace = true } compio-macros = { workspace = true } -rustls = { version = "0.23.1", default-features = false, features = ["ring"] } +rustls = { workspace = true, default-features = false, features = ["ring"] } rustls-native-certs = "0.7.0" [features] From 110011bd6b45417de097007b87d3dd26e024a0e9 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Mon, 5 Aug 2024 18:13:34 +0800 Subject: [PATCH 03/26] fix(quic): blocking --- compio-quic/Cargo.toml | 4 +- compio-quic/src/builder.rs | 6 +- compio-quic/src/connection.rs | 185 ++++++++++++++++++++++++---------- compio-quic/src/endpoint.rs | 47 ++++----- 4 files changed, 161 insertions(+), 81 deletions(-) diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index c162ec2e..cbb71bb1 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -16,13 +16,13 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] # Workspace dependencies -compio-buf = { workspace = true, features = ["bytes"] } +compio-buf = { workspace = true } compio-log = { workspace = true } compio-net = { workspace = true } compio-runtime = { workspace = true, features = ["time"] } quinn-proto = "0.11.3" -rustls = { workspace = true } +rustls = { workspace = true, features = ["ring"] } rustls-platform-verifier = { version = "0.3.3", optional = true } rustls-native-certs = { version = "0.7.1", optional = true } webpki-roots = { version = "0.26.3", optional = true } diff --git a/compio-quic/src/builder.rs b/compio-quic/src/builder.rs index d22a8e44..c8b20030 100644 --- a/compio-quic/src/builder.rs +++ b/compio-quic/src/builder.rs @@ -469,8 +469,10 @@ mod verifier { pub fn new() -> Self { Self( rustls::crypto::CryptoProvider::get_default() - .unwrap() - .signature_verification_algorithms, + .map(|provider| provider.signature_verification_algorithms) + .unwrap_or_else(|| { + rustls::crypto::ring::default_provider().signature_verification_algorithms + }), ) } } diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index 15874116..08b71078 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -11,12 +11,12 @@ use compio_buf::BufResult; use compio_runtime::JoinHandle; use flume::{Receiver, Sender}; use futures_util::{ - future::{poll_fn, Fuse, FusedFuture, Future, LocalBoxFuture}, - select, FutureExt, + future::{self, Fuse, FusedFuture, LocalBoxFuture}, + select, stream, Future, FutureExt, StreamExt, }; use quinn_proto::{ congestion::Controller, crypto::rustls::HandshakeData, ConnectionError, ConnectionHandle, - ConnectionStats, EndpointEvent, VarInt, + ConnectionStats, Dir, EndpointEvent, VarInt, }; use crate::Socket; @@ -33,6 +33,7 @@ struct ConnectionState { connected: bool, error: Option, worker: Option>, + poll_waker: Option, on_connected: Option, on_handshake_data: Option, } @@ -50,6 +51,12 @@ impl ConnectionState { } } + fn wake(&mut self) { + if let Some(waker) = self.poll_waker.take() { + waker.wake() + } + } + #[inline] fn try_map(&self, f: impl Fn(&Self) -> Option) -> Option> { if let Some(error) = &self.error { @@ -94,6 +101,7 @@ impl ConnectionInner { connected: false, error: None, worker: None, + poll_waker: None, on_connected: None, on_handshake_data: None, }), @@ -108,6 +116,7 @@ impl ConnectionInner { let mut state = self.state.lock().unwrap(); state.conn.close(Instant::now(), error_code, reason.into()); state.terminate(ConnectionError::LocallyClosed); + state.wake(); } async fn run(&self) -> io::Result<()> { @@ -118,65 +127,34 @@ impl ConnectionInner { let mut timer = Timer::new(); - loop { - { - let now = Instant::now(); - let mut state = self.state.lock().unwrap(); - - if let Some(mut buf) = send_buf.take() { - if let Some(transmit) = - state - .conn - .poll_transmit(now, self.socket.max_gso_segments(), &mut buf) - { - transmit_fut - .set(async move { self.socket.send(buf, &transmit).await }.fuse()) - } else { - send_buf = Some(buf); - } - } - - timer.reset(state.conn.poll_timeout()); - - while let Some(event) = state.conn.poll_endpoint_events() { - let _ = self.events_tx.send((self.handle, event)); - } - - while let Some(event) = state.conn.poll() { - use quinn_proto::Event::*; - match event { - HandshakeDataReady => { - if let Some(waker) = state.on_handshake_data.take() { - waker.wake() - } - } - Connected => { - state.connected = true; - if let Some(waker) = state.on_connected.take() { - waker.wake() - } - } - ConnectionLost { reason } => state.terminate(reason), - _ => {} - } - } - - if state.conn.is_drained() { - break Ok(()); - } + let mut poller = stream::poll_fn(|cx| { + let mut state = self.state.lock().unwrap(); + let ready = state.poll_waker.is_none(); + match &state.poll_waker { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.poll_waker = Some(cx.waker().clone()), + }; + if ready { + Poll::Ready(Some(())) + } else { + Poll::Pending } + }) + .fuse(); + loop { select! { + _ = poller.next() => {} _ = timer => { self.state.lock().unwrap().conn.handle_timeout(Instant::now()); timer.reset(None); - }, + } ev = self.events_rx.recv_async() => match ev { Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason), Ok(ConnectionEvent::Proto(ev)) => self.state.lock().unwrap().conn.handle_event(ev), Err(_) => unreachable!("endpoint dropped connection"), }, - BufResult(res, mut buf) = transmit_fut => match res { + BufResult::<(), Vec>(res, mut buf) = transmit_fut => match res { Ok(()) => { buf.clear(); send_buf = Some(buf); @@ -184,6 +162,50 @@ impl ConnectionInner { Err(e) => break Err(e), }, } + + let now = Instant::now(); + let mut state = self.state.lock().unwrap(); + + if let Some(mut buf) = send_buf.take() { + if let Some(transmit) = + state + .conn + .poll_transmit(now, self.socket.max_gso_segments(), &mut buf) + { + transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse()) + } else { + send_buf = Some(buf); + } + } + + timer.reset(state.conn.poll_timeout()); + + while let Some(event) = state.conn.poll_endpoint_events() { + let _ = self.events_tx.send((self.handle, event)); + } + + while let Some(event) = state.conn.poll() { + use quinn_proto::Event::*; + match event { + HandshakeDataReady => { + if let Some(waker) = state.on_handshake_data.take() { + waker.wake() + } + } + Connected => { + state.connected = true; + if let Some(waker) = state.on_connected.take() { + waker.wake() + } + } + ConnectionLost { reason } => state.terminate(reason), + _ => {} + } + } + + if state.conn.is_drained() { + break Ok(()); + } } } } @@ -244,6 +266,32 @@ macro_rules! conn_fn { .peer_identity() .map(|v| v.downcast().unwrap()) } + + /// Derive keying material from this connection's TLS session secrets. + /// + /// When both peers call this method with the same `label` and `context` + /// arguments and `output` buffers of equal length, they will get the + /// same sequence of bytes in `output`. These bytes are cryptographically + /// strong and pseudorandom, and are suitable for use as keying material. + /// + /// This function fails if called with an empty `output` or called prior to + /// the handshake completing. + /// + /// See [RFC5705](https://tools.ietf.org/html/rfc5705) for more information. + pub fn export_keying_material( + &self, + output: &mut [u8], + label: &[u8], + context: &[u8], + ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> { + self.0 + .state + .lock() + .unwrap() + .conn + .crypto_session() + .export_keying_material(output, label, context) + } }; } @@ -275,7 +323,7 @@ impl Connecting { /// Parameters negotiated during the handshake. pub async fn handshake_data(&mut self) -> Result, ConnectionError> { - poll_fn(|cx| { + future::poll_fn(|cx| { let mut state = self.0.state.lock().unwrap(); if let Some(res) = state.try_handshake_data() { return Poll::Ready(res); @@ -374,6 +422,39 @@ impl Connection { .send_buffer_space() } + /// Modify the number of remotely initiated unidirectional streams that may + /// be concurrently open + /// + /// No streams may be opened by the peer unless fewer than `count` are + /// already open. Large `count`s increase both minimum and worst-case + /// memory consumption. + pub fn set_max_concurrent_uni_streams(&self, count: VarInt) { + let mut state = self.0.state.lock().unwrap(); + state.conn.set_max_concurrent_streams(Dir::Uni, count); + // May need to send MAX_STREAMS to make progress + state.wake(); + } + + /// See [`quinn_proto::TransportConfig::receive_window()`] + pub fn set_receive_window(&self, receive_window: VarInt) { + let mut state = self.0.state.lock().unwrap(); + state.conn.set_receive_window(receive_window); + state.wake(); + } + + /// Modify the number of remotely initiated bidirectional streams that may + /// be concurrently open + /// + /// No streams may be opened by the peer unless fewer than `count` are + /// already open. Large `count`s increase both minimum and worst-case + /// memory consumption. + pub fn set_max_concurrent_bi_streams(&self, count: VarInt) { + let mut state = self.0.state.lock().unwrap(); + state.conn.set_max_concurrent_streams(Dir::Bi, count); + // May need to send MAX_STREAMS to make progress + state.wake(); + } + /// Close the connection immediately. /// /// Pending operations will fail immediately with diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 958bc6d0..2ca6b59d 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -12,10 +12,10 @@ use std::{ use compio_buf::BufResult; use compio_net::UdpSocket; use compio_runtime::JoinHandle; -use event_listener::{Event, IntoNotification}; +use event_listener::{listener, Event, IntoNotification}; use flume::{unbounded, Receiver, Sender}; use futures_util::{ - future::{poll_fn, Fuse, FusedFuture}, + future::{self}, select, task::AtomicWaker, FutureExt, @@ -35,12 +35,11 @@ struct EndpointState { worker: Option>, connections: HashMap>, close: Option<(VarInt, String)>, - resp: VecDeque<(Vec, Transmit)>, incoming: VecDeque, } impl EndpointState { - fn handle_data(&mut self, meta: RecvMeta, buf: &[u8]) { + fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec, Transmit)) { let now = Instant::now(); for data in buf[..meta.len] .chunks(meta.stride.min(meta.len)) @@ -60,7 +59,7 @@ impl EndpointState { self.incoming.push_back(incoming); } else { let transmit = self.endpoint.refuse(incoming, &mut resp_buf); - self.resp.push_back((resp_buf, transmit)); + respond_fn(resp_buf, transmit); } } Some(DatagramEvent::ConnectionEvent(ch, event)) => { @@ -70,9 +69,7 @@ impl EndpointState { .unwrap() .send(ConnectionEvent::Proto(event)); } - Some(DatagramEvent::Response(transmit)) => { - self.resp.push_back((resp_buf, transmit)) - } + Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit), None => {} } } @@ -92,7 +89,7 @@ impl EndpointState { } fn is_idle(&self) -> bool { - self.connections.is_empty() && self.resp.is_empty() + self.connections.is_empty() } fn try_get_incoming(&mut self) -> Option> { @@ -153,7 +150,6 @@ impl EndpointInner { worker: None, connections: HashMap::new(), close: None, - resp: VecDeque::new(), incoming: VecDeque::new(), }), socket, @@ -196,6 +192,14 @@ impl EndpointInner { Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone())) } + fn respond(&self, buf: Vec, transmit: Transmit) { + let socket = self.socket.clone(); + compio_runtime::spawn(async move { + let _ = socket.send(buf, &transmit).await; + }) + .detach(); + } + pub(crate) fn accept( &self, incoming: quinn_proto::Incoming, @@ -213,7 +217,7 @@ impl EndpointInner { } Err(err) => { if let Some(transmit) = err.response { - state.resp.push_back((resp_buf, transmit)); + self.respond(resp_buf, transmit); } Err(err.cause) } @@ -224,7 +228,7 @@ impl EndpointInner { let mut state = self.state.lock().unwrap(); let mut resp_buf = Vec::new(); let transmit = state.endpoint.refuse(incoming, &mut resp_buf); - state.resp.push_back((resp_buf, transmit)); + self.respond(resp_buf, transmit); } pub(crate) fn retry( @@ -234,7 +238,7 @@ impl EndpointInner { let mut state = self.state.lock().unwrap(); let mut resp_buf = Vec::new(); let transmit = state.endpoint.retry(incoming, &mut resp_buf)?; - state.resp.push_back((resp_buf, transmit)); + self.respond(resp_buf, transmit); Ok(()) } @@ -259,13 +263,13 @@ impl EndpointInner { .fuse() ); - let mut resp_fut = pin!(Fuse::terminated()); + let respond_fn = |buf: Vec, transmit: Transmit| self.respond(buf, transmit); loop { select! { BufResult(res, recv_buf) = recv_fut => { match res { - Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf), + Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn), Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} Err(e) => break Err(e), } @@ -274,19 +278,12 @@ impl EndpointInner { (ch, event) = self.events.1.recv_async().map(Result::unwrap) => { self.state.lock().unwrap().handle_event(ch, event); }, - _ = resp_fut => {}, - } - - let mut state = self.state.lock().unwrap(); - if resp_fut.is_terminated() && !state.resp.is_empty() { - let (data, transmit) = state.resp.pop_front().unwrap(); - resp_fut.set(async move { self.socket.send(data, &transmit).await }.fuse()); } + let state = self.state.lock().unwrap(); if state.close.is_some() && state.is_idle() { break Ok(()); } - if !state.incoming.is_empty() { self.incoming.notify(state.incoming.len().additional()); } @@ -364,7 +361,7 @@ impl Endpoint { return incoming.map(|incoming| Incoming::new(incoming, self.inner.clone())); } - let listener = self.inner.incoming.listen(); + listener!(self.inner.incoming => listener); if let Some(incoming) = self.inner.state.lock().unwrap().try_get_incoming() { return incoming.map(|incoming| Incoming::new(incoming, self.inner.clone())); @@ -421,7 +418,7 @@ impl Endpoint { } let this = ManuallyDrop::new(self); - let inner = poll_fn(move |cx| { + let inner = future::poll_fn(move |cx| { if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } { return Poll::Ready(inner); } From d84891a387e5988c87afb38a296e3b43eba3d806 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Mon, 5 Aug 2024 21:45:49 +0800 Subject: [PATCH 04/26] feat(quic): datagram --- compio-quic/Cargo.toml | 1 + compio-quic/src/connection.rs | 157 +++++++++++++++++++++++++++++++++- 2 files changed, 157 insertions(+), 1 deletion(-) diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index cbb71bb1..1383297a 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -28,6 +28,7 @@ rustls-native-certs = { version = "0.7.1", optional = true } webpki-roots = { version = "0.26.3", optional = true } # Utils +bytes = "1.7.1" event-listener = "5.3.1" flume = { workspace = true } futures-util = { workspace = true } diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index 08b71078..bc2da142 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -7,8 +7,10 @@ use std::{ time::{Duration, Instant}, }; +use bytes::Bytes; use compio_buf::BufResult; use compio_runtime::JoinHandle; +use event_listener::{listener, Event, IntoNotification}; use flume::{Receiver, Sender}; use futures_util::{ future::{self, Fuse, FusedFuture, LocalBoxFuture}, @@ -18,6 +20,7 @@ use quinn_proto::{ congestion::Controller, crypto::rustls::HandshakeData, ConnectionError, ConnectionHandle, ConnectionStats, Dir, EndpointEvent, VarInt, }; +use thiserror::Error; use crate::Socket; @@ -66,6 +69,18 @@ impl ConnectionState { } } + #[inline] + fn try_map_mut( + &mut self, + f: impl Fn(&mut Self) -> Option, + ) -> Option> { + if let Some(error) = &self.error { + Some(Err(error.clone())) + } else { + f(self).map(Ok) + } + } + #[inline] fn try_handshake_data(&self) -> Option, ConnectionError>> { self.try_map(|state| { @@ -85,6 +100,8 @@ struct ConnectionInner { socket: Socket, events_tx: Sender<(ConnectionHandle, EndpointEvent)>, events_rx: Receiver, + datagram_received: Event, + datagrams_unblocked: Event, } impl ConnectionInner { @@ -109,14 +126,22 @@ impl ConnectionInner { socket, events_tx, events_rx, + datagram_received: Event::new(), + datagrams_unblocked: Event::new(), } } + fn notify_events(&self) { + self.datagram_received.notify(usize::MAX.additional()); + self.datagrams_unblocked.notify(usize::MAX.additional()); + } + fn close(&self, error_code: VarInt, reason: String) { let mut state = self.state.lock().unwrap(); state.conn.close(Instant::now(), error_code, reason.into()); state.terminate(ConnectionError::LocallyClosed); state.wake(); + self.notify_events(); } async fn run(&self) -> io::Result<()> { @@ -198,7 +223,16 @@ impl ConnectionInner { waker.wake() } } - ConnectionLost { reason } => state.terminate(reason), + ConnectionLost { reason } => { + state.terminate(reason); + self.notify_events(); + } + DatagramReceived => { + self.datagram_received.notify(usize::MAX.additional()); + } + DatagramsUnblocked => { + self.datagrams_unblocked.notify(usize::MAX.additional()); + } _ => {} } } @@ -486,6 +520,93 @@ impl Connection { self.0.state.lock().unwrap().error.clone().unwrap() } + + /// Receive an application datagram + pub async fn read_datagram(&self) -> Result, ConnectionError> { + loop { + if let Some(res) = self + .0 + .state + .lock() + .unwrap() + .try_map_mut(|state| state.conn.datagrams().recv().map(Into::into)) + { + return res; + } + + listener!(self.0.datagram_received => listener); + + if let Some(res) = self + .0 + .state + .lock() + .unwrap() + .try_map_mut(|state| state.conn.datagrams().recv().map(Into::into)) + { + return res; + } + + listener.await; + } + } + + fn try_send_datagram( + &self, + data: Bytes, + drop: bool, + ) -> Result, Bytes> { + let mut state = self.0.state.lock().unwrap(); + if let Some(err) = &state.error { + return Ok(Err(err.clone().into())); + } + match state.conn.datagrams().send(data, drop) { + Ok(()) => { + state.wake(); + Ok(Ok(())) + } + Err(e) => e.try_into().map(Err), + } + } + + /// Transmit `data` as an unreliable, unordered application datagram + /// + /// Application datagrams are a low-level primitive. They may be lost or + /// delivered out of order, and `data` must both fit inside a single + /// QUIC packet and be smaller than the maximum dictated by the peer. + pub fn send_datagram(&self, data: impl Into) -> Result<(), SendDatagramError> { + self.try_send_datagram(data.into(), true).unwrap() + } + + /// Transmit `data` as an unreliable, unordered application datagram + /// + /// Unlike [`send_datagram()`], this method will wait for buffer space + /// during congestion conditions, which effectively prioritizes old + /// datagrams over new datagrams. + /// + /// See [`send_datagram()`] for details. + /// + /// [`send_datagram()`]: Connection::send_datagram + pub async fn send_datagram_wait( + &self, + data: impl Into, + ) -> Result<(), SendDatagramError> { + let mut data = Some(data.into()); + loop { + match self.try_send_datagram(data.take().unwrap(), false) { + Ok(res) => return res, + Err(b) => data.replace(b), + }; + + listener!(self.0.datagrams_unblocked => listener); + + match self.try_send_datagram(data.take().unwrap(), false) { + Ok(res) => return res, + Err(b) => data.replace(b), + }; + + listener.await; + } + } } impl PartialEq for Connection { @@ -544,3 +665,37 @@ impl FusedFuture for Timer { self.fut.is_terminated() } } + +/// Errors that can arise when sending a datagram +#[derive(Debug, Error, Clone, Eq, PartialEq)] +pub enum SendDatagramError { + /// The peer does not support receiving datagram frames + #[error("datagrams not supported by peer")] + UnsupportedByPeer, + /// Datagram support is disabled locally + #[error("datagram support disabled")] + Disabled, + /// The datagram is larger than the connection can currently accommodate + /// + /// Indicates that the path MTU minus overhead or the limit advertised by + /// the peer has been exceeded. + #[error("datagram too large")] + TooLarge, + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), +} + +impl TryFrom for SendDatagramError { + type Error = Bytes; + + fn try_from(value: quinn_proto::SendDatagramError) -> Result { + use quinn_proto::SendDatagramError::*; + match value { + UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer), + Disabled => Ok(SendDatagramError::Disabled), + TooLarge => Ok(SendDatagramError::TooLarge), + Blocked(data) => Err(data), + } + } +} From e466fc03e6110c996da33a231c1d863e5f54abfb Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 8 Aug 2024 13:06:32 +0800 Subject: [PATCH 05/26] feat(quic): stream --- compio-quic/Cargo.toml | 2 + compio-quic/examples/client.rs | 15 +- compio-quic/examples/server.rs | 10 + compio-quic/src/connection.rs | 411 +++++++++++++++++----------- compio-quic/src/endpoint.rs | 27 +- compio-quic/src/lib.rs | 59 +++- compio-quic/src/recv_stream.rs | 478 +++++++++++++++++++++++++++++++++ compio-quic/src/send_stream.rs | 346 ++++++++++++++++++++++++ 8 files changed, 1165 insertions(+), 183 deletions(-) create mode 100644 compio-quic/src/recv_stream.rs create mode 100644 compio-quic/src/send_stream.rs diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 1383297a..025be868 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -16,6 +16,7 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] # Workspace dependencies +compio-io = { workspace = true } compio-buf = { workspace = true } compio-log = { workspace = true } compio-net = { workspace = true } @@ -53,6 +54,7 @@ tracing-subscriber = "0.3.18" [features] default = ["webpki-roots"] +futures-io = ["futures-util/io"] platform-verifier = ["dep:rustls-platform-verifier"] native-certs = ["dep:rustls-native-certs"] webpki-roots = ["dep:webpki-roots"] diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/client.rs index a243570f..36e3235c 100644 --- a/compio-quic/examples/client.rs +++ b/compio-quic/examples/client.rs @@ -27,7 +27,20 @@ async fn main() { .unwrap() .await .unwrap(); - conn.close(1u32.into(), "bye"); + + println!("Connected to {:?}", conn.remote_address()); + + let (mut send, mut recv) = conn.open_bi().unwrap(); + send.write(&[1, 2, 3]).await.unwrap(); + send.finish().unwrap(); + + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{:?}", buf); + + let _ = dbg!(send.write(&[1, 2, 3]).await); + + conn.close(1u32.into(), "qaq"); conn.closed().await; } endpoint.close(0u32.into(), "").await.unwrap(); diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs index 98bb9bdc..8a727177 100644 --- a/compio-quic/examples/server.rs +++ b/compio-quic/examples/server.rs @@ -22,6 +22,16 @@ async fn main() { if let Some(incoming) = endpoint.wait_incoming().await { let conn = incoming.await.unwrap(); + + let (mut send, mut recv) = conn.accept_bi().await.unwrap(); + + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{:?}", buf); + + send.write(&[4, 5, 6]).await.unwrap(); + send.finish().unwrap(); + conn.closed().await; } diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index bc2da142..a372c96f 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -1,8 +1,9 @@ use std::{ + collections::HashMap, io, net::{IpAddr, SocketAddr}, pin::{pin, Pin}, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll, Waker}, time::{Duration, Instant}, }; @@ -10,7 +11,7 @@ use std::{ use bytes::Bytes; use compio_buf::BufResult; use compio_runtime::JoinHandle; -use event_listener::{listener, Event, IntoNotification}; +use event_listener::{Event, IntoNotification}; use flume::{Receiver, Sender}; use futures_util::{ future::{self, Fuse, FusedFuture, LocalBoxFuture}, @@ -18,11 +19,11 @@ use futures_util::{ }; use quinn_proto::{ congestion::Controller, crypto::rustls::HandshakeData, ConnectionError, ConnectionHandle, - ConnectionStats, Dir, EndpointEvent, VarInt, + ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt, }; use thiserror::Error; -use crate::Socket; +use crate::{wait_event, RecvStream, SendStream, Socket}; #[derive(Debug)] pub(crate) enum ConnectionEvent { @@ -31,14 +32,17 @@ pub(crate) enum ConnectionEvent { } #[derive(Debug)] -struct ConnectionState { - conn: quinn_proto::Connection, +pub(crate) struct ConnectionState { + pub(crate) conn: quinn_proto::Connection, + pub(crate) error: Option, connected: bool, - error: Option, worker: Option>, poll_waker: Option, on_connected: Option, on_handshake_data: Option, + pub(crate) writable: HashMap, + pub(crate) readable: HashMap, + pub(crate) stopped: HashMap, } impl ConnectionState { @@ -46,55 +50,43 @@ impl ConnectionState { self.error = Some(reason); self.connected = false; - if let Some(waker) = self.on_connected.take() { + if let Some(waker) = self.on_handshake_data.take() { waker.wake() } - if let Some(waker) = self.on_handshake_data.take() { + if let Some(waker) = self.on_connected.take() { waker.wake() } + wake_all_streams(&mut self.writable); + wake_all_streams(&mut self.readable); + wake_all_streams(&mut self.stopped); } - fn wake(&mut self) { + pub(crate) fn wake(&mut self) { if let Some(waker) = self.poll_waker.take() { waker.wake() } } - #[inline] - fn try_map(&self, f: impl Fn(&Self) -> Option) -> Option> { - if let Some(error) = &self.error { - Some(Err(error.clone())) - } else { - f(self).map(Ok) - } + fn handshake_data(&self) -> Option> { + self.conn + .crypto_session() + .handshake_data() + .map(|data| data.downcast::().unwrap()) } +} - #[inline] - fn try_map_mut( - &mut self, - f: impl Fn(&mut Self) -> Option, - ) -> Option> { - if let Some(error) = &self.error { - Some(Err(error.clone())) - } else { - f(self).map(Ok) - } +fn wake_stream(stream: StreamId, wakers: &mut HashMap) { + if let Some(waker) = wakers.remove(&stream) { + waker.wake(); } +} - #[inline] - fn try_handshake_data(&self) -> Option, ConnectionError>> { - self.try_map(|state| { - state - .conn - .crypto_session() - .handshake_data() - .map(|data| data.downcast::().unwrap()) - }) - } +fn wake_all_streams(wakers: &mut HashMap) { + wakers.drain().for_each(|(_, waker)| waker.wake()) } #[derive(Debug)] -struct ConnectionInner { +pub(crate) struct ConnectionInner { state: Mutex, handle: ConnectionHandle, socket: Socket, @@ -102,6 +94,8 @@ struct ConnectionInner { events_rx: Receiver, datagram_received: Event, datagrams_unblocked: Event, + stream_opened: [Event; 2], + stream_available: [Event; 2], } impl ConnectionInner { @@ -121,6 +115,9 @@ impl ConnectionInner { poll_waker: None, on_connected: None, on_handshake_data: None, + writable: HashMap::new(), + readable: HashMap::new(), + stopped: HashMap::new(), }), handle, socket, @@ -128,16 +125,39 @@ impl ConnectionInner { events_rx, datagram_received: Event::new(), datagrams_unblocked: Event::new(), + stream_opened: [Event::new(), Event::new()], + stream_available: [Event::new(), Event::new()], + } + } + + #[inline] + pub(crate) fn state(&self) -> MutexGuard { + self.state.lock().unwrap() + } + + #[inline] + pub(crate) fn try_state(&self) -> Result, ConnectionError> { + let state = self.state(); + if let Some(error) = &state.error { + Err(error.clone()) + } else { + Ok(state) } } fn notify_events(&self) { self.datagram_received.notify(usize::MAX.additional()); self.datagrams_unblocked.notify(usize::MAX.additional()); + for e in &self.stream_opened { + e.notify(usize::MAX.additional()); + } + for e in &self.stream_available { + e.notify(usize::MAX.additional()); + } } fn close(&self, error_code: VarInt, reason: String) { - let mut state = self.state.lock().unwrap(); + let mut state = self.state(); state.conn.close(Instant::now(), error_code, reason.into()); state.terminate(ConnectionError::LocallyClosed); state.wake(); @@ -145,15 +165,13 @@ impl ConnectionInner { } async fn run(&self) -> io::Result<()> { - let mut send_buf = Some(Vec::with_capacity( - self.state.lock().unwrap().conn.current_mtu() as usize, - )); + let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize)); let mut transmit_fut = pin!(Fuse::terminated()); let mut timer = Timer::new(); let mut poller = stream::poll_fn(|cx| { - let mut state = self.state.lock().unwrap(); + let mut state = self.state(); let ready = state.poll_waker.is_none(); match &state.poll_waker { Some(waker) if waker.will_wake(cx.waker()) => {} @@ -171,12 +189,12 @@ impl ConnectionInner { select! { _ = poller.next() => {} _ = timer => { - self.state.lock().unwrap().conn.handle_timeout(Instant::now()); + self.state().conn.handle_timeout(Instant::now()); timer.reset(None); } ev = self.events_rx.recv_async() => match ev { Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason), - Ok(ConnectionEvent::Proto(ev)) => self.state.lock().unwrap().conn.handle_event(ev), + Ok(ConnectionEvent::Proto(ev)) => self.state().conn.handle_event(ev), Err(_) => unreachable!("endpoint dropped connection"), }, BufResult::<(), Vec>(res, mut buf) = transmit_fut => match res { @@ -189,7 +207,7 @@ impl ConnectionInner { } let now = Instant::now(); - let mut state = self.state.lock().unwrap(); + let mut state = self.state(); if let Some(mut buf) = send_buf.take() { if let Some(transmit) = @@ -227,13 +245,25 @@ impl ConnectionInner { state.terminate(reason); self.notify_events(); } + Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable), + Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable), + Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped), + Stream(StreamEvent::Stopped { id, .. }) => { + wake_stream(id, &mut state.stopped); + wake_stream(id, &mut state.writable); + } + Stream(StreamEvent::Available { dir }) => { + self.stream_available[dir as usize].notify(usize::MAX.additional()); + } + Stream(StreamEvent::Opened { dir }) => { + self.stream_opened[dir as usize].notify(usize::MAX.additional()); + } DatagramReceived => { self.datagram_received.notify(usize::MAX.additional()); } DatagramsUnblocked => { self.datagrams_unblocked.notify(usize::MAX.additional()); } - _ => {} } } @@ -255,36 +285,30 @@ macro_rules! conn_fn { /// This will return `None` for clients, or when the platform does not /// expose this information. pub fn local_ip(&self) -> Option { - self.0.state.lock().unwrap().conn.local_ip() + self.0.state().conn.local_ip() } /// The peer's UDP address. /// /// Will panic if called after `poll` has returned `Ready`. pub fn remote_address(&self) -> SocketAddr { - self.0.state.lock().unwrap().conn.remote_address() + self.0.state().conn.remote_address() } /// Current best estimate of this connection's latency (round-trip-time). pub fn rtt(&self) -> Duration { - self.0.state.lock().unwrap().conn.rtt() + self.0.state().conn.rtt() } /// Connection statistics. pub fn stats(&self) -> ConnectionStats { - self.0.state.lock().unwrap().conn.stats() + self.0.state().conn.stats() } /// Current state of the congestion control algorithm. (For debugging /// purposes) pub fn congestion_state(&self) -> Box { - self.0 - .state - .lock() - .unwrap() - .conn - .congestion_state() - .clone_box() + self.0.state().conn.congestion_state().clone_box() } /// Cryptographic identity of the peer. @@ -292,9 +316,7 @@ macro_rules! conn_fn { &self, ) -> Option>>> { self.0 - .state - .lock() - .unwrap() + .state() .conn .crypto_session() .peer_identity() @@ -319,9 +341,7 @@ macro_rules! conn_fn { context: &[u8], ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> { self.0 - .state - .lock() - .unwrap() + .state() .conn .crypto_session() .export_keying_material(output, label, context) @@ -351,16 +371,16 @@ impl Connecting { let inner = inner.clone(); async move { inner.run().await.unwrap() } }); - inner.state.lock().unwrap().worker = Some(worker); + inner.state().worker = Some(worker); Self(inner) } /// Parameters negotiated during the handshake. pub async fn handshake_data(&mut self) -> Result, ConnectionError> { future::poll_fn(|cx| { - let mut state = self.0.state.lock().unwrap(); - if let Some(res) = state.try_handshake_data() { - return Poll::Ready(res); + let mut state = self.0.try_state()?; + if let Some(data) = state.handshake_data() { + return Poll::Ready(Ok(data)); } match &state.on_handshake_data { @@ -368,11 +388,7 @@ impl Connecting { _ => state.on_handshake_data = Some(cx.waker().clone()), } - if let Some(res) = state.try_handshake_data() { - Poll::Ready(res) - } else { - Poll::Pending - } + Poll::Pending }) .await } @@ -382,12 +398,10 @@ impl Future for Connecting { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut state = self.0.state.lock().unwrap(); + let mut state = self.0.try_state()?; - if let Some(res) = - state.try_map(|state| state.connected.then(|| Connection(self.0.clone()))) - { - return Poll::Ready(res); + if state.connected { + return Poll::Ready(Ok(Connection(self.0.clone()))); } match &state.on_connected { @@ -395,13 +409,7 @@ impl Future for Connecting { _ => state.on_connected = Some(cx.waker().clone()), } - if let Some(res) = - state.try_map(|state| state.connected.then(|| Connection(self.0.clone()))) - { - Poll::Ready(res) - } else { - Poll::Pending - } + Poll::Pending } } @@ -422,7 +430,7 @@ impl Connection { /// Parameters negotiated during the handshake. pub fn handshake_data(&mut self) -> Result, ConnectionError> { - self.0.state.lock().unwrap().try_handshake_data().unwrap() + Ok(self.0.try_state()?.handshake_data().unwrap()) } /// Compute the maximum size of datagrams that may be passed to @@ -438,7 +446,7 @@ impl Connection { /// /// Not necessarily the maximum size of received datagrams. pub fn max_datagram_size(&self) -> Option { - self.0.state.lock().unwrap().conn.datagrams().max_size() + self.0.state().conn.datagrams().max_size() } /// Bytes available in the outgoing datagram buffer. @@ -447,23 +455,17 @@ impl Connection { /// with a datagram of at most this size is guaranteed not to cause older /// datagrams to be dropped. pub fn datagram_send_buffer_space(&self) -> usize { - self.0 - .state - .lock() - .unwrap() - .conn - .datagrams() - .send_buffer_space() + self.0.state().conn.datagrams().send_buffer_space() } /// Modify the number of remotely initiated unidirectional streams that may - /// be concurrently open + /// be concurrently open. /// /// No streams may be opened by the peer unless fewer than `count` are /// already open. Large `count`s increase both minimum and worst-case /// memory consumption. pub fn set_max_concurrent_uni_streams(&self, count: VarInt) { - let mut state = self.0.state.lock().unwrap(); + let mut state = self.0.state(); state.conn.set_max_concurrent_streams(Dir::Uni, count); // May need to send MAX_STREAMS to make progress state.wake(); @@ -471,19 +473,19 @@ impl Connection { /// See [`quinn_proto::TransportConfig::receive_window()`] pub fn set_receive_window(&self, receive_window: VarInt) { - let mut state = self.0.state.lock().unwrap(); + let mut state = self.0.state(); state.conn.set_receive_window(receive_window); state.wake(); } /// Modify the number of remotely initiated bidirectional streams that may - /// be concurrently open + /// be concurrently open. /// /// No streams may be opened by the peer unless fewer than `count` are /// already open. Large `count`s increase both minimum and worst-case /// memory consumption. pub fn set_max_concurrent_bi_streams(&self, count: VarInt) { - let mut state = self.0.state.lock().unwrap(); + let mut state = self.0.state(); state.conn.set_max_concurrent_streams(Dir::Bi, count); // May need to send MAX_STREAMS to make progress state.wake(); @@ -511,73 +513,52 @@ impl Connection { self.0.close(error_code, reason.to_string()); } - /// Wait for the connection to be closed for any reason + /// Wait for the connection to be closed for any reason. pub async fn closed(&self) -> ConnectionError { - let worker = self.0.state.lock().unwrap().worker.take(); + let worker = self.0.state().worker.take(); if let Some(worker) = worker { let _ = worker.await; } - self.0.state.lock().unwrap().error.clone().unwrap() + self.0.state().error.clone().unwrap() } - /// Receive an application datagram - pub async fn read_datagram(&self) -> Result, ConnectionError> { - loop { - if let Some(res) = self - .0 - .state - .lock() - .unwrap() - .try_map_mut(|state| state.conn.datagrams().recv().map(Into::into)) - { - return res; - } - - listener!(self.0.datagram_received => listener); - - if let Some(res) = self - .0 - .state - .lock() - .unwrap() - .try_map_mut(|state| state.conn.datagrams().recv().map(Into::into)) - { - return res; + /// Receive an application datagram. + pub async fn recv_datagram(&self) -> Result { + let bytes = wait_event!( + self.0.datagram_received, + if let Some(bytes) = self.0.try_state()?.conn.datagrams().recv() { + break bytes; } - - listener.await; - } + ); + Ok(bytes) } fn try_send_datagram( &self, data: Bytes, drop: bool, - ) -> Result, Bytes> { - let mut state = self.0.state.lock().unwrap(); - if let Some(err) = &state.error { - return Ok(Err(err.clone().into())); - } - match state.conn.datagrams().send(data, drop) { - Ok(()) => { - state.wake(); - Ok(Ok(())) - } - Err(e) => e.try_into().map(Err), - } + ) -> Result<(), Result> { + let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?; + state + .conn + .datagrams() + .send(data, drop) + .map_err(TryInto::try_into)?; + state.wake(); + Ok(()) } - /// Transmit `data` as an unreliable, unordered application datagram + /// Transmit `data` as an unreliable, unordered application datagram. /// /// Application datagrams are a low-level primitive. They may be lost or /// delivered out of order, and `data` must both fit inside a single /// QUIC packet and be smaller than the maximum dictated by the peer. - pub fn send_datagram(&self, data: impl Into) -> Result<(), SendDatagramError> { - self.try_send_datagram(data.into(), true).unwrap() + pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> { + self.try_send_datagram(data, true).map_err(Result::unwrap) } - /// Transmit `data` as an unreliable, unordered application datagram + /// Transmit `data` as an unreliable, unordered application datagram. /// /// Unlike [`send_datagram()`], this method will wait for buffer space /// during congestion conditions, which effectively prioritizes old @@ -586,26 +567,125 @@ impl Connection { /// See [`send_datagram()`] for details. /// /// [`send_datagram()`]: Connection::send_datagram - pub async fn send_datagram_wait( - &self, - data: impl Into, - ) -> Result<(), SendDatagramError> { - let mut data = Some(data.into()); - loop { + pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> { + let mut data = Some(data); + wait_event!( + self.0.datagrams_unblocked, match self.try_send_datagram(data.take().unwrap(), false) { - Ok(res) => return res, - Err(b) => data.replace(b), - }; + Ok(res) => break Ok(res), + Err(Ok(e)) => break Err(e), + Err(Err(b)) => data.replace(b), + } + ) + } - listener!(self.0.datagrams_unblocked => listener); + fn try_open_stream(&self, dir: Dir) -> Result { + self.0 + .try_state()? + .conn + .streams() + .open(dir) + .ok_or(OpenStreamError::StreamsExhausted) + } + + async fn open_stream(&self, dir: Dir) -> Result { + wait_event!( + self.0.stream_available[dir as usize], + match self.try_open_stream(dir) { + Ok(stream) => break Ok(stream), + Err(OpenStreamError::StreamsExhausted) => {} + Err(OpenStreamError::ConnectionLost(e)) => break Err(e), + } + ) + } - match self.try_send_datagram(data.take().unwrap(), false) { - Ok(res) => return res, - Err(b) => data.replace(b), - }; + /// Initiate a new outgoing unidirectional stream. + /// + /// Streams are cheap and instantaneous to open. As a consequence, the peer + /// won't be notified that a stream has been opened until the stream is + /// actually used. + pub fn open_uni(&self) -> Result { + let stream = self.try_open_stream(Dir::Uni)?; + Ok(SendStream::new(self.0.clone(), stream)) + } - listener.await; - } + /// Initiate a new outgoing unidirectional stream. + /// + /// Unlike [`open_uni()`], this method will wait for the connection to allow + /// a new stream to be opened. + /// + /// See [`open_uni()`] for details. + /// + /// [`open_uni()`]: crate::Connection::open_uni + pub async fn open_uni_wait(&self) -> Result { + let stream = self.open_stream(Dir::Uni).await?; + Ok(SendStream::new(self.0.clone(), stream)) + } + + /// Initiate a new outgoing bidirectional stream. + /// + /// Streams are cheap and instantaneous to open. As a consequence, the peer + /// won't be notified that a stream has been opened until the stream is + /// actually used. + pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> { + let stream = self.try_open_stream(Dir::Bi)?; + Ok(( + SendStream::new(self.0.clone(), stream), + RecvStream::new(self.0.clone(), stream), + )) + } + + /// Initiate a new outgoing bidirectional stream. + /// + /// Unlike [`open_bi()`], this method will wait for the connection to allow + /// a new stream to be opened. + /// + /// See [`open_bi()`] for details. + /// + /// [`open_bi()`]: crate::Connection::open_bi + pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> { + let stream = self.open_stream(Dir::Bi).await?; + Ok(( + SendStream::new(self.0.clone(), stream), + RecvStream::new(self.0.clone(), stream), + )) + } + + async fn accept_stream(&self, dir: Dir) -> Result { + wait_event!(self.0.stream_opened[dir as usize], { + let mut state = self.0.state(); + if let Some(stream) = state.conn.streams().accept(dir) { + state.wake(); + break Ok(stream); + } else if let Some(error) = &state.error { + break Err(error.clone()); + } + }) + } + + /// Accept the next incoming uni-directional stream + pub async fn accept_uni(&self) -> Result { + let stream = self.accept_stream(Dir::Uni).await?; + Ok(RecvStream::new(self.0.clone(), stream)) + } + + /// Accept the next incoming bidirectional stream + /// + /// **Important Note**: The `Connection` that calls [`open_bi()`] must write + /// to its [`SendStream`] before the other `Connection` is able to + /// `accept_bi()`. Calling [`open_bi()`] then waiting on the [`RecvStream`] + /// without writing anything to [`SendStream`] will never succeed. + /// + /// [`accept_bi()`]: crate::Connection::accept_bi + /// [`open_bi()`]: crate::Connection::open_bi + /// [`SendStream`]: crate::SendStream + /// [`RecvStream`]: crate::RecvStream + pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { + let stream = self.accept_stream(Dir::Bi).await?; + Ok(( + SendStream::new(self.0.clone(), stream), + RecvStream::new(self.0.clone(), stream), + )) } } @@ -699,3 +779,14 @@ impl TryFrom for SendDatagramError { } } } + +/// Errors that can arise when trying to open a stream +#[derive(Debug, Error, Clone, Eq, PartialEq)] +pub enum OpenStreamError { + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + // The streams in the given direction are currently exhausted + #[error("streams exhausted")] + StreamsExhausted, +} diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 2ca6b59d..78448960 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -12,7 +12,7 @@ use std::{ use compio_buf::BufResult; use compio_net::UdpSocket; use compio_runtime::JoinHandle; -use event_listener::{listener, Event, IntoNotification}; +use event_listener::{Event, IntoNotification}; use flume::{unbounded, Receiver, Sender}; use futures_util::{ future::{self}, @@ -26,7 +26,8 @@ use quinn_proto::{ }; use crate::{ - ClientBuilder, Connecting, ConnectionEvent, Incoming, RecvMeta, ServerBuilder, Socket, + wait_event, ClientBuilder, Connecting, ConnectionEvent, Incoming, RecvMeta, ServerBuilder, + Socket, }; #[derive(Debug)] @@ -94,9 +95,9 @@ impl EndpointState { fn try_get_incoming(&mut self) -> Option> { if self.close.is_none() { - self.incoming.pop_front().map(Some) + Some(self.incoming.pop_front()) } else { - Some(None) + None } } @@ -356,19 +357,13 @@ impl Endpoint { /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT /// data. pub async fn wait_incoming(&self) -> Option { - loop { - if let Some(incoming) = self.inner.state.lock().unwrap().try_get_incoming() { - return incoming.map(|incoming| Incoming::new(incoming, self.inner.clone())); - } - - listener!(self.inner.incoming => listener); - - if let Some(incoming) = self.inner.state.lock().unwrap().try_get_incoming() { - return incoming.map(|incoming| Incoming::new(incoming, self.inner.clone())); + let incoming = wait_event!( + self.inner.incoming, + if let Some(res) = self.inner.state.lock().unwrap().try_get_incoming()? { + break res; } - - listener.await; - } + ); + Some(Incoming::new(incoming, self.inner.clone())) } // Modified from [`SharedFd::try_unwrap_inner`], see notes there. diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs index 7d23f7d2..9144feb1 100644 --- a/compio-quic/src/lib.rs +++ b/compio-quic/src/lib.rs @@ -17,11 +17,58 @@ mod builder; mod connection; mod endpoint; mod incoming; +mod recv_stream; +mod send_stream; mod socket; -pub use builder::*; -pub(crate) use connection::ConnectionEvent; -pub use connection::*; -pub use endpoint::*; -pub use incoming::*; -pub(crate) use socket::*; +pub use builder::{ClientBuilder, ServerBuilder}; +pub use connection::{Connecting, Connection}; +pub use endpoint::Endpoint; +pub use incoming::{Incoming, IncomingFuture}; +pub use recv_stream::{ReadError, RecvStream}; +pub use send_stream::{SendStream, WriteError}; + +pub(crate) use crate::{ + connection::{ConnectionEvent, ConnectionInner}, + endpoint::EndpointInner, + socket::*, +}; + +/// Errors from [`SendStream::stopped`] and [`RecvStream::stopped`]. +#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)] +pub enum StoppedError { + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + /// This was a 0-RTT stream and the server rejected it + /// + /// Can only occur on clients for 0-RTT streams, which can be opened using + /// [`Connecting::into_0rtt()`]. + /// + /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() + #[error("0-RTT rejected")] + ZeroRttRejected, +} + +impl From for std::io::Error { + fn from(x: StoppedError) -> Self { + use StoppedError::*; + let kind = match x { + ZeroRttRejected => std::io::ErrorKind::ConnectionReset, + ConnectionLost(_) => std::io::ErrorKind::NotConnected, + }; + Self::new(kind, x) + } +} + +macro_rules! wait_event { + ($event:expr, $break:expr) => { + loop { + $break; + event_listener::listener!($event => listener); + $break; + listener.await; + } + }; +} +pub(crate) use wait_event; diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs new file mode 100644 index 00000000..e055bcfd --- /dev/null +++ b/compio-quic/src/recv_stream.rs @@ -0,0 +1,478 @@ +use std::{ + collections::BTreeMap, + io, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::{BufMut, Bytes}; +use compio_buf::{BufResult, IoBufMut}; +use compio_io::AsyncRead; +use futures_util::future::poll_fn; +use quinn_proto::{Chunk, Chunks, ClosedStream, ConnectionError, ReadableError, StreamId, VarInt}; +use thiserror::Error; + +use crate::{ConnectionInner, StoppedError}; + +/// A stream that can only be used to receive data +/// +/// `stop(0)` is implicitly called on drop unless: +/// - A variant of [`ReadError`] has been yielded by a read call +/// - [`stop()`] was called explicitly +/// +/// # Cancellation +/// +/// A `read` method is said to be *cancel-safe* when dropping its future before +/// the future becomes ready cannot lead to loss of stream data. This is true of +/// methods which succeed immediately when any progress is made, and is not true +/// of methods which might need to perform multiple reads internally before +/// succeeding. Each `read` method documents whether it is cancel-safe. +/// +/// # Common issues +/// +/// ## Data never received on a locally-opened stream +/// +/// Peers are not notified of streams until they or a later-numbered stream are +/// used to send data. If a bidirectional stream is locally opened but never +/// used to send, then the peer may never see it. Application protocols should +/// always arrange for the endpoint which will first transmit on a stream to be +/// the endpoint responsible for opening it. +/// +/// ## Data never received on a remotely-opened stream +/// +/// Verify that the stream you are receiving is the same one that the server is +/// sending on, e.g. by logging the [`id`] of each. Streams are always accepted +/// in the same order as they are created, i.e. ascending order by [`StreamId`]. +/// For example, even if a sender first transmits on bidirectional stream 1, the +/// first stream yielded by [`Connection::accept_bi`] on the receiver +/// will be bidirectional stream 0. +/// +/// [`stop()`]: RecvStream::stop +/// [`id`]: RecvStream::id +/// [`Connection::accept_bi`]: crate::Connection::accept_bi +#[derive(Debug)] +pub struct RecvStream { + conn: Arc, + stream: StreamId, + all_data_read: bool, + reset: Option, +} + +impl RecvStream { + pub(crate) fn new(conn: Arc, stream: StreamId) -> Self { + Self { + conn, + stream, + all_data_read: false, + reset: None, + } + } + + /// Get the identity of this stream + pub fn id(&self) -> StreamId { + self.stream + } + + /// Stop accepting data + /// + /// Discards unread data and notifies the peer to stop transmitting. Once + /// stopped, further attempts to operate on a stream will yield + /// `ClosedStream` errors. + pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { + let mut state = self.conn.state(); + state.conn.recv_stream(self.stream).stop(error_code)?; + state.wake(); + self.all_data_read = true; + Ok(()) + } + + /// Completes when the stream has been reset by the peer or otherwise + /// closed. + /// + /// Yields `Some` with the reset error code when the stream is reset by the + /// peer. Yields `None` when the stream was previously + /// [`stop()`](Self::stop)ed, or when the stream was + /// [`finish()`](crate::SendStream::finish)ed by the peer and all data has + /// been received, after which it is no longer meaningful for the stream to + /// be reset. + /// + /// This operation is cancel-safe. + pub async fn stopped(&mut self) -> Result, StoppedError> { + poll_fn(|cx| { + let mut state = self.conn.state(); + + if let Some(code) = self.reset { + return Poll::Ready(Ok(Some(code))); + } + + match state.conn.recv_stream(self.stream).received_reset() { + Err(_) => Poll::Ready(Ok(None)), + Ok(Some(error_code)) => { + // Stream state has just now been freed, so the connection may need to issue new + // stream ID flow control credit + state.wake(); + Poll::Ready(Ok(Some(error_code))) + } + Ok(None) => { + if let Some(e) = &state.error { + return Poll::Ready(Err(e.clone().into())); + } + // Resets always notify readers, since a reset is an immediate read error. We + // could introduce a dedicated channel to reduce the risk of spurious wakeups, + // but that increased complexity is probably not justified, as an application + // that is expecting a reset is not likely to receive large amounts of data. + state.readable.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + } + }) + .await + } + + /// Handle common logic related to reading out of a receive stream. + /// + /// This takes an `FnMut` closure that takes care of the actual reading + /// process, matching the detailed read semantics for the calling + /// function with a particular return type. The closure can read from + /// the passed `&mut Chunks` and has to return the status after reading: + /// the amount of data read, and the status after the final read call. + fn execute_poll_read( + &mut self, + cx: &mut Context, + ordered: bool, + mut read_fn: F, + ) -> Poll, ReadError>> + where + F: FnMut(&mut Chunks) -> ReadStatus, + { + use quinn_proto::ReadError::*; + + if self.all_data_read { + return Poll::Ready(Ok(None)); + } + + let mut state = self.conn.state(); + + // If we stored an error during a previous call, return it now. This can happen + // if a `read_fn` both wants to return data and also returns an error in + // its final stream status. + let status = match self.reset { + Some(code) => ReadStatus::Failed(None, Reset(code)), + None => { + let mut recv = state.conn.recv_stream(self.stream); + let mut chunks = recv.read(ordered)?; + let status = read_fn(&mut chunks); + if chunks.finalize().should_transmit() { + state.wake(); + } + status + } + }; + + match status { + ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))), + ReadStatus::Finished(read) => { + self.all_data_read = true; + Poll::Ready(Ok(read)) + } + ReadStatus::Failed(read, Blocked) => match read { + Some(val) => Poll::Ready(Ok(Some(val))), + None => { + if let Some(error) = &state.error { + return Poll::Ready(Err(error.clone().into())); + } + state.readable.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + }, + ReadStatus::Failed(read, Reset(error_code)) => match read { + None => { + self.all_data_read = true; + self.reset = Some(error_code); + Poll::Ready(Err(ReadError::Reset(error_code))) + } + done => { + self.reset = Some(error_code); + Poll::Ready(Ok(done)) + } + }, + } + } + + fn poll_read( + &mut self, + cx: &mut Context, + mut buf: impl BufMut, + ) -> Poll, ReadError>> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(Some(0))); + } + + self.execute_poll_read(cx, true, |chunks| { + let mut read = 0; + loop { + if !buf.has_remaining_mut() { + // We know `read` is `true` because `buf.remaining()` was not 0 before + return ReadStatus::Readable(read); + } + + match chunks.next(buf.remaining_mut()) { + Ok(Some(chunk)) => { + read += chunk.bytes.len(); + buf.put(chunk.bytes); + } + res => { + return (if read == 0 { None } else { Some(read) }, res.err()).into(); + } + } + } + }) + } + + /// Read data contiguously from the stream. + /// + /// Yields the number of bytes read into `buf` on success, or `None` if the + /// stream was finished. + /// + /// This operation is cancel-safe. + pub async fn read(&mut self, mut buf: impl BufMut) -> Result, ReadError> { + poll_fn(|cx| self.poll_read(cx, &mut buf)).await + } + + /// Read the next segment of data. + /// + /// Yields `None` if the stream was finished. Otherwise, yields a segment of + /// data and its offset in the stream. If `ordered` is `true`, the chunk's + /// offset will be immediately after the last data yielded by + /// [`read()`](Self::read) or [`read_chunk()`](Self::read_chunk). If + /// `ordered` is `false`, segments may be received in any order, and the + /// `Chunk`'s `offset` field can be used to determine ordering in the + /// caller. Unordered reads are less prone to head-of-line blocking within a + /// stream, but require the application to manage reassembling the original + /// data. + /// + /// Slightly more efficient than `read` due to not copying. Chunk boundaries + /// do not correspond to peer writes, and hence cannot be used as framing. + /// + /// This operation is cancel-safe. + pub async fn read_chunk( + &mut self, + max_length: usize, + ordered: bool, + ) -> Result, ReadError> { + poll_fn(|cx| { + self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) { + Ok(Some(chunk)) => ReadStatus::Readable(chunk), + res => (None, res.err()).into(), + }) + }) + .await + } + + /// Read the next segments of data. + /// + /// Fills `bufs` with the segments of data beginning immediately after the + /// last data yielded by `read` or `read_chunk`, or `None` if the stream was + /// finished. + /// + /// Slightly more efficient than `read` due to not copying. Chunk boundaries + /// do not correspond to peer writes, and hence cannot be used as framing. + /// + /// This operation is cancel-safe. + pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result, ReadError> { + if bufs.is_empty() { + return Ok(Some(0)); + } + + poll_fn(|cx| { + self.execute_poll_read(cx, true, |chunks| { + let mut read = 0; + loop { + if read >= bufs.len() { + // We know `read > 0` because `bufs` cannot be empty here + return ReadStatus::Readable(read); + } + + match chunks.next(usize::MAX) { + Ok(Some(chunk)) => { + bufs[read] = chunk.bytes; + read += 1; + } + res => { + return (if read == 0 { None } else { Some(read) }, res.err()).into(); + } + } + } + }) + }) + .await + } + + /// Convenience method to read all remaining data into a buffer. + /// + /// Uses unordered reads to be more efficient than using [`AsyncRead`]. If + /// unordered reads have already been made, the resulting buffer may have + /// gaps containing zero. + /// + /// Depending on [`BufMut`] implementation, this method may fail with + /// [`ReadError::BufferTooShort`] if the buffer is not large enough to + /// hold the entire stream. For example when using a `&mut [u8]` it will + /// never receive bytes more than the length of the slice, but when using a + /// `&mut Vec` it will allocate more memory as needed. + /// + /// This operation is *not* cancel-safe. + pub async fn read_to_end(&mut self, mut buf: impl BufMut) -> Result { + let mut start = u64::MAX; + let mut end = 0; + let mut chunks = BTreeMap::new(); + loop { + let Some(chunk) = self.read_chunk(usize::MAX, false).await? else { + break; + }; + start = start.min(chunk.offset); + end = end.max(chunk.offset + chunk.bytes.len() as u64); + if end - start > buf.remaining_mut() as u64 { + return Err(ReadError::BufferTooShort); + } + chunks.insert(chunk.offset, chunk.bytes); + } + let mut last = 0; + for (offset, bytes) in chunks { + let offset = (offset - start) as usize; + if offset > last { + buf.put_bytes(0, offset - last); + } + last = offset + bytes.len(); + buf.put(bytes); + } + Ok((end - start) as usize) + } +} + +impl Drop for RecvStream { + fn drop(&mut self) { + let mut state = self.conn.state(); + + // clean up any previously registered wakers + state.readable.remove(&self.stream); + + if state.error.is_some() { + return; + } + if !self.all_data_read { + // Ignore ClosedStream errors + let _ = state.conn.recv_stream(self.stream).stop(0u32.into()); + state.wake(); + } + } +} + +enum ReadStatus { + Readable(T), + Finished(Option), + Failed(Option, quinn_proto::ReadError), +} + +impl From<(Option, Option)> for ReadStatus { + fn from(status: (Option, Option)) -> Self { + match status { + (read, None) => Self::Finished(read), + (read, Some(e)) => Self::Failed(read, e), + } + } +} + +/// Errors that arise from reading from a stream. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ReadError { + /// The peer abandoned transmitting data on this stream. + /// + /// Carries an application-defined error code. + #[error("stream reset by peer: error {0}")] + Reset(VarInt), + /// The connection was lost. + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + /// The stream has already been stopped, finished, or reset. + #[error("closed stream")] + ClosedStream, + /// Attempted an ordered read following an unordered read. + /// + /// Performing an unordered read allows discontinuities to arise in the + /// receive buffer of a stream which cannot be recovered, making further + /// ordered reads impossible. + #[error("ordered read after unordered read")] + IllegalOrderedRead, + /// This was a 0-RTT stream and the server rejected it. + /// + /// Can only occur on clients for 0-RTT streams, which can be opened using + /// [`Connecting::into_0rtt()`]. + /// + /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() + #[error("0-RTT rejected")] + ZeroRttRejected, + /// The stream is larger than the user-supplied buffer capacity. + /// + /// Can only occur when using [`read_to_end()`](RecvStream::read_to_end). + #[error("buffer too short")] + BufferTooShort, +} + +impl From for ReadError { + fn from(e: ReadableError) -> Self { + match e { + ReadableError::ClosedStream => Self::ClosedStream, + ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead, + } + } +} + +impl From for ReadError { + fn from(e: StoppedError) -> Self { + match e { + StoppedError::ConnectionLost(e) => Self::ConnectionLost(e), + StoppedError::ZeroRttRejected => Self::ZeroRttRejected, + } + } +} + +impl From for io::Error { + fn from(x: ReadError) -> Self { + use self::ReadError::*; + let kind = match x { + Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset, + ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected, + IllegalOrderedRead | BufferTooShort => io::ErrorKind::InvalidInput, + }; + Self::new(kind, x) + } +} + +impl AsyncRead for RecvStream { + async fn read(&mut self, mut buf: B) -> BufResult { + let res = self + .read(buf.as_mut_slice()) + .await + .map(|n| { + let n = n.unwrap_or_default(); + unsafe { buf.set_buf_init(n) } + n + }) + .map_err(Into::into); + BufResult(res, buf) + } +} + +#[cfg(feature = "futures-io")] +impl futures_util::AsyncRead for RecvStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.get_mut() + .poll_read(cx, buf) + .map_ok(Option::unwrap_or_default) + .map_err(Into::into) + } +} diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs new file mode 100644 index 00000000..e564839f --- /dev/null +++ b/compio-quic/src/send_stream.rs @@ -0,0 +1,346 @@ +use std::{ + io, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use compio_buf::{BufResult, IoBuf}; +use compio_io::AsyncWrite; +use futures_util::{future::poll_fn, ready}; +use quinn_proto::{ClosedStream, ConnectionError, FinishError, StreamId, VarInt, Written}; +use thiserror::Error; + +use crate::{ConnectionInner, StoppedError}; + +/// A stream that can only be used to send data. +/// +/// If dropped, streams that haven't been explicitly [`reset()`] will be +/// implicitly [`finish()`]ed, continuing to (re)transmit previously written +/// data until it has been fully acknowledged or the connection is closed. +/// +/// # Cancellation +/// +/// A `write` method is said to be *cancel-safe* when dropping its future before +/// the future becomes ready will always result in no data being written to the +/// stream. This is true of methods which succeed immediately when any progress +/// is made, and is not true of methods which might need to perform multiple +/// writes internally before succeeding. Each `write` method documents whether +/// it is cancel-safe. +/// +/// [`reset()`]: SendStream::reset +/// [`finish()`]: SendStream::finish +#[derive(Debug)] +pub struct SendStream { + conn: Arc, + stream: StreamId, +} + +impl SendStream { + pub(crate) fn new(conn: Arc, stream: StreamId) -> Self { + Self { conn, stream } + } + + /// Get the identity of this stream + pub fn id(&self) -> StreamId { + self.stream + } + + /// Notify the peer that no more data will ever be written to this stream. + /// + /// It is an error to write to a stream after `finish()`ing it. [`reset()`] + /// may still be called after `finish` to abandon transmission of any stream + /// data that might still be buffered. + /// + /// To wait for the peer to receive all buffered stream data, see + /// [`stopped()`]. + /// + /// May fail if [`finish()`] or [`reset()`] was previously called.This + /// error is harmless and serves only to indicate that the caller may have + /// incorrect assumptions about the stream's state. + /// + /// [`reset()`]: Self::reset + /// [`stopped()`]: Self::stopped + /// [`finish()`]: Self::finish + pub fn finish(&mut self) -> Result<(), ClosedStream> { + let mut state = self.conn.state(); + match state.conn.send_stream(self.stream).finish() { + Ok(()) => { + state.wake(); + Ok(()) + } + Err(FinishError::ClosedStream) => Err(ClosedStream::new()), + // Harmless. If the application needs to know about stopped streams at this point, + // it should call `stopped`. + Err(FinishError::Stopped(_)) => Ok(()), + } + } + + /// Close the stream immediately. + /// + /// No new data can be written after calling this method. Locally buffered + /// data is dropped, and previously transmitted data will no longer be + /// retransmitted if lost. If an attempt has already been made to finish + /// the stream, the peer may still receive all written data. + /// + /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was + /// previously called. This error is harmless and serves only to + /// indicate that the caller may have incorrect assumptions about the + /// stream's state. + pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { + let mut state = self.conn.state(); + state.conn.send_stream(self.stream).reset(error_code)?; + state.wake(); + Ok(()) + } + + /// Set the priority of the stream. + /// + /// Every stream has an initial priority of 0. Locally buffered data + /// from streams with higher priority will be transmitted before data + /// from streams with lower priority. Changing the priority of a stream + /// with pending data may only take effect after that data has been + /// transmitted. Using many different priority levels per connection may + /// have a negative impact on performance. + pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> { + self.conn + .state() + .conn + .send_stream(self.stream) + .set_priority(priority) + } + + /// Get the priority of the stream + pub fn priority(&self) -> Result { + self.conn.state().conn.send_stream(self.stream).priority() + } + + /// Completes when the peer stops the stream or reads the stream to + /// completion. + /// + /// Yields `Some` with the stop error code if the peer stops the stream. + /// Yields `None` if the local side [`finish()`](Self::finish)es the stream + /// and then the peer acknowledges receipt of all stream data (although not + /// necessarily the processing of it), after which the peer closing the + /// stream is no longer meaningful. + /// + /// For a variety of reasons, the peer may not send acknowledgements + /// immediately upon receiving data. As such, relying on `stopped` to + /// know when the peer has read a stream to completion may introduce + /// more latency than using an application-level response of some sort. + pub async fn stopped(&mut self) -> Result, StoppedError> { + poll_fn(|cx| { + let mut state = self.conn.state(); + match state.conn.send_stream(self.stream).stopped() { + Err(_) => Poll::Ready(Ok(None)), + Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))), + Ok(None) => { + if let Some(e) = &state.error { + return Poll::Ready(Err(e.clone().into())); + } + state.stopped.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + } + }) + .await + } + + fn execute_poll_write(&mut self, cx: &mut Context, f: F) -> Poll> + where + F: FnOnce(quinn_proto::SendStream) -> Result, + { + let mut state = self.conn.try_state()?; + match f(state.conn.send_stream(self.stream)) { + Ok(r) => { + state.wake(); + Poll::Ready(Ok(r)) + } + Err(e) => match e.try_into() { + Ok(e) => Poll::Ready(Err(e)), + Err(()) => { + state.writable.insert(self.stream, cx.waker().clone()); + Poll::Pending + } + }, + } + } + + /// Write bytes to the stream. + /// + /// Yields the number of bytes written on success. Congestion and flow + /// control may cause this to be shorter than `buf.len()`, indicating + /// that only a prefix of `buf` was written. + /// + /// This operation is cancel-safe. + pub async fn write(&mut self, buf: &[u8]) -> Result { + poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf))).await + } + + /// Convenience method to write an entire buffer to the stream. + /// + /// This operation is *not* cancel-safe. + pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> { + let mut count = 0; + poll_fn(|cx| { + loop { + if count == buf.len() { + return Poll::Ready(Ok(())); + } + let n = + ready!(self.execute_poll_write(cx, |mut stream| stream.write(&buf[count..])))?; + count += n; + } + }) + .await + } + + /// Write chunks to the stream. + /// + /// Yields the number of bytes and chunks written on success. + /// Congestion and flow control may cause this to be shorter than + /// `buf.len()`, indicating that only a prefix of `bufs` was written. + /// + /// This operation is cancel-safe. + pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result { + poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write_chunks(bufs))).await + } + + /// Convenience method to write an entire list of chunks to the stream. + /// + /// This operation is *not* cancel-safe. + pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> { + let mut chunks = 0; + poll_fn(|cx| { + loop { + if chunks == bufs.len() { + return Poll::Ready(Ok(())); + } + let written = ready!(self.execute_poll_write(cx, |mut stream| { + stream.write_chunks(&mut bufs[chunks..]) + }))?; + chunks += written.chunks; + } + }) + .await + } +} + +impl Drop for SendStream { + fn drop(&mut self) { + let mut state = self.conn.state(); + + // clean up any previously registered wakers + state.stopped.remove(&self.stream); + state.writable.remove(&self.stream); + + if state.error.is_some() { + return; + } + match state.conn.send_stream(self.stream).finish() { + Ok(()) => state.wake(), + Err(FinishError::Stopped(reason)) => { + if state.conn.send_stream(self.stream).reset(reason).is_ok() { + state.wake(); + } + } + // Already finished or reset, which is fine. + Err(FinishError::ClosedStream) => {} + } + } +} + +/// Errors that arise from writing to a stream +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum WriteError { + /// The peer is no longer accepting data on this stream + /// + /// Carries an application-defined error code. + #[error("sending stopped by peer: error {0}")] + Stopped(VarInt), + /// The connection was lost + #[error("connection lost")] + ConnectionLost(#[from] ConnectionError), + /// The stream has already been finished or reset + #[error("closed stream")] + ClosedStream, + /// This was a 0-RTT stream and the server rejected it + /// + /// Can only occur on clients for 0-RTT streams, which can be opened using + /// [`Connecting::into_0rtt()`]. + /// + /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() + #[error("0-RTT rejected")] + ZeroRttRejected, +} + +impl TryFrom for WriteError { + type Error = (); + + fn try_from(value: quinn_proto::WriteError) -> Result { + use quinn_proto::WriteError::*; + match value { + Stopped(e) => Ok(Self::Stopped(e)), + ClosedStream => Ok(Self::ClosedStream), + Blocked => Err(()), + } + } +} + +impl From for WriteError { + fn from(x: StoppedError) -> Self { + match x { + StoppedError::ConnectionLost(e) => Self::ConnectionLost(e), + StoppedError::ZeroRttRejected => Self::ZeroRttRejected, + } + } +} + +impl From for io::Error { + fn from(x: WriteError) -> Self { + use WriteError::*; + let kind = match x { + Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset, + ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected, + }; + Self::new(kind, x) + } +} + +impl AsyncWrite for SendStream { + async fn write(&mut self, buf: T) -> BufResult { + let res = self.write(buf.as_slice()).await.map_err(Into::into); + BufResult(res, buf) + } + + async fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.finish()?; + Ok(()) + } +} + +#[cfg(feature = "futures-io")] +impl futures_util::AsyncWrite for SendStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut() + .execute_poll_write(cx, |mut stream| stream.write(buf)) + .map_err(Into::into) + } + + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.get_mut().finish()?; + Poll::Ready(Ok(())) + } +} From a24f248f7fe5e303c9f6edd8bc8616254e4e9122 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 8 Aug 2024 16:47:54 +0800 Subject: [PATCH 06/26] feat(quic): 0-rtt --- compio-quic/src/connection.rs | 146 +++++++++++++++++++++++++++------ compio-quic/src/recv_stream.rs | 26 +++++- compio-quic/src/send_stream.rs | 24 +++++- 3 files changed, 168 insertions(+), 28 deletions(-) diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index a372c96f..af102737 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -73,6 +73,14 @@ impl ConnectionState { .handshake_data() .map(|data| data.downcast::().unwrap()) } + + pub(crate) fn check_0rtt(&self) -> Result<(), ()> { + if self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt() { + Ok(()) + } else { + Err(()) + } + } } fn wake_stream(stream: StreamId, wakers: &mut HashMap) { @@ -240,6 +248,13 @@ impl ConnectionInner { if let Some(waker) = state.on_connected.take() { waker.wake() } + if state.conn.side().is_client() && !state.conn.accepted_0rtt() { + // Wake up rejected 0-RTT streams so they can fail immediately with + // `ZeroRttRejected` errors. + wake_all_streams(&mut state.writable); + wake_all_streams(&mut state.readable); + wake_all_streams(&mut state.stopped); + } } ConnectionLost { reason } => { state.terminate(reason); @@ -392,6 +407,65 @@ impl Connecting { }) .await } + + /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened + /// security. + /// + /// Returns `Ok` immediately if the local endpoint is able to attempt + /// sending 0/0.5-RTT data. If so, the returned [`Connection`] can be used + /// to send application data without waiting for the rest of the handshake + /// to complete, at the cost of weakened cryptographic security guarantees. + /// The [`Connection::accepted_0rtt`] method resolves when the handshake + /// does complete, at which point subsequently opened streams and written + /// data will have full cryptographic protection. + /// + /// ## Outgoing + /// + /// For outgoing connections, the initial attempt to convert to a + /// [`Connection`] which sends 0-RTT data will proceed if the + /// [`crypto::ClientConfig`][crate::crypto::ClientConfig] attempts to resume + /// a previous TLS session. However, **the remote endpoint may not actually + /// _accept_ the 0-RTT data**--yet still accept the connection attempt in + /// general. This possibility is conveyed through the + /// [`Connection::accepted_0rtt`] method--when the handshake completes, it + /// resolves to true if the 0-RTT data was accepted and false if it was + /// rejected. If it was rejected, the existence of streams opened and other + /// application data sent prior to the handshake completing will not be + /// conveyed to the remote application, and local operations on them will + /// return `ZeroRttRejected` errors. + /// + /// A server may reject 0-RTT data at its discretion, but accepting 0-RTT + /// data requires the relevant resumption state to be stored in the server, + /// which servers may limit or lose for various reasons including not + /// persisting resumption state across server restarts. + /// + /// ## Incoming + /// + /// For incoming connections, conversion to 0.5-RTT will always fully + /// succeed. `into_0rtt` will always return `Ok` and + /// [`Connection::accepted_0rtt`] will always resolve to true. + /// + /// ## Security + /// + /// On outgoing connections, this enables transmission of 0-RTT data, which + /// is vulnerable to replay attacks, and should therefore never invoke + /// non-idempotent operations. + /// + /// On incoming connections, this enables transmission of 0.5-RTT data, + /// which may be sent before TLS client authentication has occurred, and + /// should therefore not be used to send data for which client + /// authentication is being used. + pub fn into_0rtt(self) -> Result { + let is_ok = { + let state = self.0.state(); + state.conn.has_0rtt() || state.conn.side().is_server() + }; + if is_ok { + Ok(Connection(self.0.clone())) + } else { + Err(self) + } + } } impl Future for Connecting { @@ -579,20 +653,24 @@ impl Connection { ) } - fn try_open_stream(&self, dir: Dir) -> Result { - self.0 - .try_state()? + fn try_open_stream(&self, dir: Dir) -> Result<(StreamId, bool), OpenStreamError> { + let mut state = self.0.try_state()?; + let stream = state .conn .streams() .open(dir) - .ok_or(OpenStreamError::StreamsExhausted) + .ok_or(OpenStreamError::StreamsExhausted)?; + Ok(( + stream, + state.conn.side().is_client() && state.conn.is_handshaking(), + )) } - async fn open_stream(&self, dir: Dir) -> Result { + async fn open_stream(&self, dir: Dir) -> Result<(StreamId, bool), ConnectionError> { wait_event!( self.0.stream_available[dir as usize], match self.try_open_stream(dir) { - Ok(stream) => break Ok(stream), + Ok(res) => break Ok(res), Err(OpenStreamError::StreamsExhausted) => {} Err(OpenStreamError::ConnectionLost(e)) => break Err(e), } @@ -605,8 +683,8 @@ impl Connection { /// won't be notified that a stream has been opened until the stream is /// actually used. pub fn open_uni(&self) -> Result { - let stream = self.try_open_stream(Dir::Uni)?; - Ok(SendStream::new(self.0.clone(), stream)) + let (stream, is_0rtt) = self.try_open_stream(Dir::Uni)?; + Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) } /// Initiate a new outgoing unidirectional stream. @@ -618,8 +696,8 @@ impl Connection { /// /// [`open_uni()`]: crate::Connection::open_uni pub async fn open_uni_wait(&self) -> Result { - let stream = self.open_stream(Dir::Uni).await?; - Ok(SendStream::new(self.0.clone(), stream)) + let (stream, is_0rtt) = self.open_stream(Dir::Uni).await?; + Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) } /// Initiate a new outgoing bidirectional stream. @@ -628,10 +706,10 @@ impl Connection { /// won't be notified that a stream has been opened until the stream is /// actually used. pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> { - let stream = self.try_open_stream(Dir::Bi)?; + let (stream, is_0rtt) = self.try_open_stream(Dir::Bi)?; Ok(( - SendStream::new(self.0.clone(), stream), - RecvStream::new(self.0.clone(), stream), + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), )) } @@ -644,19 +722,19 @@ impl Connection { /// /// [`open_bi()`]: crate::Connection::open_bi pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> { - let stream = self.open_stream(Dir::Bi).await?; + let (stream, is_0rtt) = self.open_stream(Dir::Bi).await?; Ok(( - SendStream::new(self.0.clone(), stream), - RecvStream::new(self.0.clone(), stream), + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), )) } - async fn accept_stream(&self, dir: Dir) -> Result { + async fn accept_stream(&self, dir: Dir) -> Result<(StreamId, bool), ConnectionError> { wait_event!(self.0.stream_opened[dir as usize], { let mut state = self.0.state(); if let Some(stream) = state.conn.streams().accept(dir) { state.wake(); - break Ok(stream); + break Ok((stream, state.conn.is_handshaking())); } else if let Some(error) = &state.error { break Err(error.clone()); } @@ -665,8 +743,8 @@ impl Connection { /// Accept the next incoming uni-directional stream pub async fn accept_uni(&self) -> Result { - let stream = self.accept_stream(Dir::Uni).await?; - Ok(RecvStream::new(self.0.clone(), stream)) + let (stream, is_0rtt) = self.accept_stream(Dir::Uni).await?; + Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)) } /// Accept the next incoming bidirectional stream @@ -681,12 +759,34 @@ impl Connection { /// [`SendStream`]: crate::SendStream /// [`RecvStream`]: crate::RecvStream pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { - let stream = self.accept_stream(Dir::Bi).await?; + let (stream, is_0rtt) = self.accept_stream(Dir::Bi).await?; Ok(( - SendStream::new(self.0.clone(), stream), - RecvStream::new(self.0.clone(), stream), + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), )) } + + /// Wait for the connection to be fully established. + /// + /// For clients, the resulting value indicates if 0-RTT was accepted. For + /// servers, the resulting value is meaningless. + pub async fn accepted_0rtt(&self) -> Result { + future::poll_fn(|cx| { + let mut state = self.0.try_state()?; + + if state.connected { + return Poll::Ready(Ok(state.conn.accepted_0rtt())); + } + + match &state.on_connected { + Some(waker) if waker.will_wake(cx.waker()) => {} + _ => state.on_connected = Some(cx.waker().clone()), + } + + Poll::Pending + }) + .await + } } impl PartialEq for Connection { diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs index e055bcfd..5a12f99a 100644 --- a/compio-quic/src/recv_stream.rs +++ b/compio-quic/src/recv_stream.rs @@ -54,15 +54,17 @@ use crate::{ConnectionInner, StoppedError}; pub struct RecvStream { conn: Arc, stream: StreamId, + is_0rtt: bool, all_data_read: bool, reset: Option, } impl RecvStream { - pub(crate) fn new(conn: Arc, stream: StreamId) -> Self { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { Self { conn, stream, + is_0rtt, all_data_read: false, reset: None, } @@ -73,6 +75,15 @@ impl RecvStream { self.stream } + /// Check if this stream has been opened during 0-RTT. + /// + /// In which case any non-idempotent request should be considered dangerous + /// at the application level. Because read data is subject to replay + /// attacks. + pub fn is_0rtt(&self) -> bool { + self.is_0rtt + } + /// Stop accepting data /// /// Discards unread data and notifies the peer to stop transmitting. Once @@ -80,6 +91,9 @@ impl RecvStream { /// `ClosedStream` errors. pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { let mut state = self.conn.state(); + if self.is_0rtt && state.check_0rtt().is_err() { + return Ok(()); + } state.conn.recv_stream(self.stream).stop(error_code)?; state.wake(); self.all_data_read = true; @@ -101,6 +115,9 @@ impl RecvStream { poll_fn(|cx| { let mut state = self.conn.state(); + if self.is_0rtt && state.check_0rtt().is_err() { + return Poll::Ready(Err(StoppedError::ZeroRttRejected)); + } if let Some(code) = self.reset { return Poll::Ready(Ok(Some(code))); } @@ -152,6 +169,11 @@ impl RecvStream { } let mut state = self.conn.state(); + if self.is_0rtt { + state + .check_0rtt() + .map_err(|()| ReadError::ZeroRttRejected)?; + } // If we stored an error during a previous call, return it now. This can happen // if a `read_fn` both wants to return data and also returns an error in @@ -356,7 +378,7 @@ impl Drop for RecvStream { // clean up any previously registered wakers state.readable.remove(&self.stream); - if state.error.is_some() { + if state.error.is_some() || (self.is_0rtt && state.check_0rtt().is_err()) { return; } if !self.all_data_read { diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs index e564839f..bf8fc41a 100644 --- a/compio-quic/src/send_stream.rs +++ b/compio-quic/src/send_stream.rs @@ -34,11 +34,16 @@ use crate::{ConnectionInner, StoppedError}; pub struct SendStream { conn: Arc, stream: StreamId, + is_0rtt: bool, } impl SendStream { - pub(crate) fn new(conn: Arc, stream: StreamId) -> Self { - Self { conn, stream } + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + conn, + stream, + is_0rtt, + } } /// Get the identity of this stream @@ -89,6 +94,9 @@ impl SendStream { /// stream's state. pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { let mut state = self.conn.state(); + if self.is_0rtt && state.check_0rtt().is_err() { + return Ok(()); + } state.conn.send_stream(self.stream).reset(error_code)?; state.wake(); Ok(()) @@ -131,6 +139,11 @@ impl SendStream { pub async fn stopped(&mut self) -> Result, StoppedError> { poll_fn(|cx| { let mut state = self.conn.state(); + if self.is_0rtt { + state + .check_0rtt() + .map_err(|()| StoppedError::ZeroRttRejected)?; + } match state.conn.send_stream(self.stream).stopped() { Err(_) => Poll::Ready(Ok(None)), Ok(Some(error_code)) => Poll::Ready(Ok(Some(error_code))), @@ -151,6 +164,11 @@ impl SendStream { F: FnOnce(quinn_proto::SendStream) -> Result, { let mut state = self.conn.try_state()?; + if self.is_0rtt { + state + .check_0rtt() + .map_err(|()| WriteError::ZeroRttRejected)?; + } match f(state.conn.send_stream(self.stream)) { Ok(r) => { state.wake(); @@ -234,7 +252,7 @@ impl Drop for SendStream { state.stopped.remove(&self.stream); state.writable.remove(&self.stream); - if state.error.is_some() { + if state.error.is_some() || (self.is_0rtt && state.check_0rtt().is_err()) { return; } match state.conn.send_stream(self.stream).finish() { From 9a58d2ffeb0b7111b71b0ac2f4f4d58ec7eff761 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 8 Aug 2024 22:37:23 +0800 Subject: [PATCH 07/26] feat(quic): redesign builders --- compio-quic/Cargo.toml | 3 - compio-quic/examples/client.rs | 16 +- compio-quic/examples/server.rs | 6 +- compio-quic/src/builder.rs | 507 +++++++++------------------------ compio-quic/src/endpoint.rs | 58 ++-- 5 files changed, 178 insertions(+), 412 deletions(-) diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 025be868..5400e0eb 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -39,9 +39,6 @@ thiserror = "1.0.63" [target.'cfg(windows)'.dependencies] windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } -# Linux specific dependencies -[target.'cfg(target_os = "linux")'.dependencies] - [target.'cfg(unix)'.dependencies] libc = { workspace = true } diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/client.rs index 36e3235c..9703cdc7 100644 --- a/compio-quic/examples/client.rs +++ b/compio-quic/examples/client.rs @@ -1,6 +1,6 @@ use std::net::{IpAddr, Ipv6Addr, SocketAddr}; -use compio_quic::Endpoint; +use compio_quic::ClientBuilder; use tracing_subscriber::filter::LevelFilter; #[compio_macros::main] @@ -9,9 +9,7 @@ async fn main() { .with_max_level(LevelFilter::TRACE) .init(); - let endpoint = Endpoint::client() - .with_no_server_verification() - .with_alpn_protocols(&["hq-29"]) + let endpoint = ClientBuilder::new_with_no_server_verification() .with_key_log() .bind("[::1]:0") .await @@ -25,11 +23,11 @@ async fn main() { None, ) .unwrap() + .into_0rtt() + .unwrap_err() .await .unwrap(); - println!("Connected to {:?}", conn.remote_address()); - let (mut send, mut recv) = conn.open_bi().unwrap(); send.write(&[1, 2, 3]).await.unwrap(); send.finish().unwrap(); @@ -38,10 +36,8 @@ async fn main() { recv.read_to_end(&mut buf).await.unwrap(); println!("{:?}", buf); - let _ = dbg!(send.write(&[1, 2, 3]).await); - - conn.close(1u32.into(), "qaq"); - conn.closed().await; + conn.close(1u32.into(), "bye"); } + endpoint.close(0u32.into(), "").await.unwrap(); } diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs index 8a727177..d2e55e8d 100644 --- a/compio-quic/examples/server.rs +++ b/compio-quic/examples/server.rs @@ -1,4 +1,4 @@ -use compio_quic::Endpoint; +use compio_quic::ServerBuilder; use tracing_subscriber::filter::LevelFilter; #[compio_macros::main] @@ -11,10 +11,8 @@ async fn main() { let cert_chain = vec![cert.cert.into()]; let key_der = cert.key_pair.serialize_der().try_into().unwrap(); - let endpoint = Endpoint::server() - .with_single_cert(cert_chain, key_der) + let endpoint = ServerBuilder::new_with_single_cert(cert_chain, key_der) .unwrap() - .with_alpn_protocols(&["hq-29"]) .with_key_log() .bind("[::1]:4433") .await diff --git a/compio-quic/src/builder.rs b/compio-quic/src/builder.rs index c8b20030..85cd4c14 100644 --- a/compio-quic/src/builder.rs +++ b/compio-quic/src/builder.rs @@ -1,309 +1,189 @@ -use std::{ - io, - net::{SocketAddrV4, SocketAddrV6}, - sync::Arc, - time::Duration, -}; +use std::{io, sync::Arc}; -use compio_net::{ToSocketAddrsAsync, UdpSocket}; +use compio_net::ToSocketAddrsAsync; use quinn_proto::{ crypto::rustls::{QuicClientConfig, QuicServerConfig}, - ClientConfig, EndpointConfig, ServerConfig, TransportConfig, + ClientConfig, ServerConfig, }; use crate::Endpoint; -/// A [builder] for [`Endpoint`] in client mode. +/// Helper to construct an [`Endpoint`] for use with outgoing connections only. /// -/// To get one, call [`Endpoint::client()`] or [`ClientBuilder::default()`]. +/// To get one, call `new_with_xxx` methods. /// /// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html #[derive(Debug)] -pub struct ClientBuilder { - inner: T, - - alpn_protocols: Vec>, - key_log: bool, - enable_early_data: bool, - - transport: Option, - version: Option, +pub struct ClientBuilder(T); - endpoint_config: EndpointConfig, -} - -impl Default for ClientBuilder<()> { - fn default() -> Self { - Self { - inner: (), - alpn_protocols: Vec::new(), - key_log: false, - enable_early_data: true, - transport: None, - version: None, - endpoint_config: EndpointConfig::default(), - } - } -} - -impl From>> for Result, E> { - fn from(builder: ClientBuilder>) -> Self { - builder.inner.map(|inner| ClientBuilder { - inner, - alpn_protocols: builder.alpn_protocols, - key_log: builder.key_log, - enable_early_data: builder.enable_early_data, - transport: builder.transport, - version: builder.version, - endpoint_config: builder.endpoint_config, - }) - } -} - -impl ClientBuilder { - fn map_inner(self, f: impl FnOnce(T) -> S) -> ClientBuilder { - ClientBuilder { - inner: f(self.inner), - alpn_protocols: self.alpn_protocols, - key_log: self.key_log, - enable_early_data: self.enable_early_data, - transport: self.transport, - version: self.version, - endpoint_config: self.endpoint_config, - } +impl ClientBuilder { + /// Create a builder with an empty [`rustls::RootCertStore`]. + pub fn new_with_empty_roots() -> Self { + ClientBuilder(rustls::RootCertStore::empty()) } - /// Set the ALPN protocols to use. - pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { - self.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); - self + /// Create a builder with [`rustls_native_certs`]. + #[cfg(feature = "native-certs")] + pub fn new_with_native_certs() -> io::Result { + let mut roots = rustls::RootCertStore::empty(); + roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?); + Ok(ClientBuilder(roots)) } - /// Logging key material to a file for debugging. The file's name is given - /// by the `SSLKEYLOGFILE` environment variable. - /// - /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot - /// be written, this does nothing. - pub fn with_key_log(mut self) -> Self { - self.key_log = true; - self + /// Create a builder with [`webpki_roots`]. + #[cfg(feature = "webpki-roots")] + pub fn new_with_webpki_roots() -> Self { + let roots = + rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + ClientBuilder(roots) } - /// Set a custom [`TransportConfig`]. - pub fn with_transport_config(mut self, transport: TransportConfig) -> Self { - self.transport = Some(transport); - self + /// Add a custom certificate. + pub fn with_custom_certificate( + mut self, + der: rustls::pki_types::CertificateDer, + ) -> Result { + self.0.add(der)?; + Ok(self) } - /// Set the QUIC version to use. - pub fn with_version(mut self, version: u32) -> Self { - self.version = Some(version); - self + /// Don't configure revocation. + pub fn with_no_crls(self) -> ClientBuilder { + ClientBuilder::new_with_root_certificates(self.0) } - /// Use the provided [`EndpointConfig`]. - pub fn with_endpoint_config(mut self, endpoint_config: EndpointConfig) -> Self { - self.endpoint_config = endpoint_config; - self + /// Verify the revocation state of presented client certificates against the + /// provided certificate revocation lists (CRLs). + pub fn with_crls( + self, + crls: impl IntoIterator>, + ) -> Result, rustls::client::VerifierBuilderError> { + let verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(self.0)) + .with_crls(crls) + .build()?; + Ok(ClientBuilder::new_with_webpki_verifier(verifier)) } } -impl ClientBuilder<()> { - /// Use the provided [`rustls::ClientConfig`]. - pub fn with_rustls_client_config( - self, +impl ClientBuilder { + /// Create a builder with the provided [`rustls::ClientConfig`]. + pub fn new_with_rustls_client_config( client_config: rustls::ClientConfig, ) -> ClientBuilder { - self.map_inner(|_| client_config) + ClientBuilder(client_config) } /// Do not verify the server's certificate. It is vulnerable to MITM /// attacks, but convenient for testing. - pub fn with_no_server_verification( - self, - ) -> ClientBuilder> { - self.map_inner(|_| Arc::new(verifier::SkipServerVerification::new()) as _) + pub fn new_with_no_server_verification() -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier::SkipServerVerification::new())) + .with_no_client_auth(), + ) } - /// Use [`rustls_platform_verifier`]. + /// Create a builder with [`rustls_platform_verifier`]. #[cfg(feature = "platform-verifier")] - pub fn with_platform_verifier( - self, - ) -> ClientBuilder> { - self.map_inner(|_| Arc::new(rustls_platform_verifier::Verifier::new()) as _) - } - - /// Use an empty [`rustls::RootCertStore`]. - pub fn with_root_certificates(self) -> ClientBuilder { - self.map_inner(|_| rustls::RootCertStore::empty()) - } -} - -impl ClientBuilder { - /// Create an [`Endpoint`] binding to the addr provided. - pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { - let mut client_config = self.inner; - - client_config.alpn_protocols = self.alpn_protocols; - if self.key_log { - client_config.key_log = Arc::new(rustls::KeyLogFile::new()); - } - client_config.enable_early_data = self.enable_early_data; - - let mut client_config = ClientConfig::new(Arc::new( - QuicClientConfig::try_from(client_config) - .expect("should support TLS13_AES_128_GCM_SHA256"), - )); - - if let Some(transport) = self.transport { - client_config.transport_config(Arc::new(transport)); - } - if let Some(version) = self.version { - client_config.version(version); - } - - let socket = UdpSocket::bind(addr).await?; - Endpoint::new(socket, self.endpoint_config, None, Some(client_config)) + pub fn new_with_platform_verifier() -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new( + rustls_platform_verifier::Verifier::new(), + )) + .with_no_client_auth(), + ) } -} -impl ClientBuilder> { - /// Create an [`Endpoint`] binding to the addr provided. - pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { - self.map_inner(|verifier| { + /// Create a builder with the provided [`rustls::RootCertStore`]. + pub fn new_with_root_certificates( + roots: rustls::RootCertStore, + ) -> ClientBuilder { + ClientBuilder( rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .dangerous() - .with_custom_certificate_verifier(verifier) - .with_no_client_auth() - }) - .bind(addr) - .await + .with_root_certificates(roots) + .with_no_client_auth(), + ) } -} -impl ClientBuilder { - /// Use [`rustls_native_certs`]. - #[cfg(feature = "native-certs")] - pub fn with_native_certs(mut self) -> io::Result { - self.inner - .add_parsable_certificates(rustls_native_certs::load_native_certs()?); - Ok(self) + /// Create a builder with a custom [`rustls::client::WebPkiServerVerifier`]. + pub fn new_with_webpki_verifier( + verifier: Arc, + ) -> ClientBuilder { + ClientBuilder( + rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_webpki_verifier(verifier) + .with_no_client_auth(), + ) } - /// Use [`webpki_roots`]. - #[cfg(feature = "webpki-roots")] - pub fn with_webpki_roots(mut self) -> Self { - self.inner - .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + /// Set the ALPN protocols to use. + pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { + self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); self } - /// Add a custom certificate. - pub fn with_custom_certificate( - mut self, - der: rustls::pki_types::CertificateDer, - ) -> Result { - self.inner.add(der)?; - Ok(self) + /// Logging key material to a file for debugging. The file's name is given + /// by the `SSLKEYLOGFILE` environment variable. + /// + /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot + /// be written, this does nothing. + pub fn with_key_log(mut self) -> Self { + self.0.key_log = Arc::new(rustls::KeyLogFile::new()); + self } - /// Verify the revocation state of presented client certificates against the - /// provided certificate revocation lists (CRLs). - pub fn with_crls( - self, - crls: impl IntoIterator>, - ) -> Result< - ClientBuilder>, - rustls::client::VerifierBuilderError, - > { - self.map_inner(|roots| { - rustls::client::WebPkiServerVerifier::builder(Arc::new(roots)) - .with_crls(crls) - .build() - .map(|v| v as _) - }) - .into() + /// Build a [`ClientConfig`]. + pub fn build(mut self) -> ClientConfig { + self.0.enable_early_data = true; + ClientConfig::new(Arc::new( + QuicClientConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"), + )) } - /// Create an [`Endpoint`] binding to the addr provided. + /// Create a new [`Endpoint`]. + /// + /// See [`Endpoint::client`] for more information. pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { - self.map_inner(|roots| { - rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_root_certificates(roots) - .with_no_client_auth() - }) - .bind(addr) - .await + let mut endpoint = Endpoint::client(addr).await?; + endpoint.default_client_config = Some(self.build()); + Ok(endpoint) } } -/// A [builder] for [`Endpoint`] in server mode. +/// Helper to construct an [`Endpoint`] for use with incoming connections. /// -/// To get one, call [`Endpoint::server()`] or [`ServerBuilder::default()`]. +/// To get one, call `new_with_xxx` methods. /// /// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html #[derive(Debug)] -pub struct ServerBuilder { - inner: T, - - alpn_protocols: Vec>, - key_log: bool, - enable_early_data: bool, - - transport: Option, - retry_token_lifetime: Option, - migration: bool, - preferred_address_v4: Option, - preferred_address_v6: Option, - max_incoming: Option, - incoming_buffer_size: Option, - incoming_buffer_size_total: Option, - - endpoint_config: EndpointConfig, -} +pub struct ServerBuilder(T); -impl Default for ServerBuilder<()> { - fn default() -> Self { - Self { - inner: (), - alpn_protocols: Vec::new(), - key_log: false, - enable_early_data: true, - transport: None, - retry_token_lifetime: None, - migration: true, - preferred_address_v4: None, - preferred_address_v6: None, - max_incoming: None, - incoming_buffer_size: None, - incoming_buffer_size_total: None, - endpoint_config: EndpointConfig::default(), - } +impl ServerBuilder { + /// Create a builder with the provided [`rustls::ServerConfig`]. + pub fn new_with_rustls_server_config(server_config: rustls::ServerConfig) -> Self { + Self(server_config) } -} -impl ServerBuilder { - fn map_inner(self, f: impl FnOnce(T) -> S) -> ServerBuilder { - ServerBuilder { - inner: f(self.inner), - alpn_protocols: self.alpn_protocols, - key_log: self.key_log, - enable_early_data: self.enable_early_data, - transport: self.transport, - retry_token_lifetime: self.retry_token_lifetime, - migration: self.migration, - preferred_address_v4: self.preferred_address_v4, - preferred_address_v6: self.preferred_address_v6, - max_incoming: self.max_incoming, - incoming_buffer_size: self.incoming_buffer_size, - incoming_buffer_size_total: self.incoming_buffer_size_total, - endpoint_config: self.endpoint_config, - } + /// Create a builder with a single certificate chain and matching private + /// key. Using this method gets the same result as calling + /// [`ServerConfig::with_single_cert`]. + pub fn new_with_single_cert( + cert_chain: Vec>, + key_der: rustls::pki_types::PrivateKeyDer<'static>, + ) -> Result { + let server_config = + rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert(cert_chain, key_der)?; + Ok(Self::new_with_rustls_server_config(server_config)) } /// Set the ALPN protocols to use. pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self { - self.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); + self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect(); self } @@ -313,144 +193,23 @@ impl ServerBuilder { /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot /// be written, this does nothing. pub fn with_key_log(mut self) -> Self { - self.key_log = true; + self.0.key_log = Arc::new(rustls::KeyLogFile::new()); self } - /// Set a custom [`TransportConfig`]. - pub fn with_transport_config(mut self, transport: TransportConfig) -> Self { - self.transport = Some(transport); - self + /// Build a [`ServerConfig`]. + pub fn build(mut self) -> ServerConfig { + self.0.max_early_data_size = u32::MAX; + ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"), + )) } - /// Duration after a stateless retry token was issued for which it's - /// considered valid. - pub fn with_retry_token_lifetime(mut self, retry_token_lifetime: Duration) -> Self { - self.retry_token_lifetime = Some(retry_token_lifetime); - self - } - - /// Whether to allow clients to migrate to new addresses. + /// Create a new [`Endpoint`]. /// - /// See [`quinn_proto::ServerConfig::migration`]. - pub fn with_migration(mut self, migration: bool) -> Self { - self.migration = migration; - self - } - - /// The preferred IPv4 address during handshaking. - /// - /// See [`quinn_proto::ServerConfig::preferred_address_v4`]. - pub fn with_preferred_address_v4(mut self, addr: SocketAddrV4) -> Self { - self.preferred_address_v4 = Some(addr); - self - } - - /// The preferred IPv6 address during handshaking. - /// - /// See [`quinn_proto::ServerConfig::preferred_address_v6`]. - pub fn with_preferred_address_v6(mut self, addr: SocketAddrV6) -> Self { - self.preferred_address_v6 = Some(addr); - self - } - - /// Maximum number of [`Incoming`][crate::Incoming] to allow to exist at a - /// time. - /// - /// See [`quinn_proto::ServerConfig::max_incoming`]. - pub fn with_max_incoming(mut self, max_incoming: usize) -> Self { - self.max_incoming = Some(max_incoming); - self - } - - /// Maximum number of received bytes to buffer for each - /// [`Incoming`][crate::Incoming]. - /// - /// See [`quinn_proto::ServerConfig::incoming_buffer_size`]. - pub fn with_incoming_buffer_size(mut self, incoming_buffer_size: u64) -> Self { - self.incoming_buffer_size = Some(incoming_buffer_size); - self - } - - /// Maximum number of received bytes to buffer for all - /// [`Incoming`][crate::Incoming] collectively. - /// - /// See [`quinn_proto::ServerConfig::incoming_buffer_size_total`]. - pub fn with_incoming_buffer_size_total(mut self, incoming_buffer_size_total: u64) -> Self { - self.incoming_buffer_size_total = Some(incoming_buffer_size_total); - self - } - - /// Use the provided [`EndpointConfig`]. - pub fn with_endpoint_config(mut self, endpoint_config: EndpointConfig) -> Self { - self.endpoint_config = endpoint_config; - self - } -} - -impl ServerBuilder<()> { - /// Use the provided [`rustls::ServerConfig`]. - pub fn with_rustls_server_config( - self, - server_config: rustls::ServerConfig, - ) -> ServerBuilder { - self.map_inner(|_| server_config) - } - - /// Sets a single certificate chain and matching private key. - pub fn with_single_cert( - self, - cert_chain: Vec>, - key_der: rustls::pki_types::PrivateKeyDer<'static>, - ) -> Result, rustls::Error> { - let server_config = - rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_single_cert(cert_chain, key_der)?; - Ok(self.with_rustls_server_config(server_config)) - } -} - -impl ServerBuilder { - /// Create an [`Endpoint`] binding to the addr provided. + /// See [`Endpoint::server`] for more information. pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result { - let mut server_config = self.inner; - - server_config.alpn_protocols = self.alpn_protocols; - if self.key_log { - server_config.key_log = Arc::new(rustls::KeyLogFile::new()); - } - if self.enable_early_data { - server_config.max_early_data_size = u32::MAX; - } - - let mut server_config = ServerConfig::with_crypto(Arc::new( - QuicServerConfig::try_from(server_config) - .expect("should support TLS13_AES_128_GCM_SHA256"), - )); - - if let Some(transport) = self.transport { - server_config.transport_config(Arc::new(transport)); - } - if let Some(retry_token_lifetime) = self.retry_token_lifetime { - server_config.retry_token_lifetime(retry_token_lifetime); - } - server_config - .migration(self.migration) - .preferred_address_v4(self.preferred_address_v4) - .preferred_address_v6(self.preferred_address_v6); - if let Some(max_incoming) = self.max_incoming { - server_config.max_incoming(max_incoming); - } - if let Some(incoming_buffer_size) = self.incoming_buffer_size { - server_config.incoming_buffer_size(incoming_buffer_size); - } - if let Some(incoming_buffer_size_total) = self.incoming_buffer_size_total { - server_config.incoming_buffer_size_total(incoming_buffer_size_total); - } - - let socket = UdpSocket::bind(addr).await?; - Endpoint::new(socket, self.endpoint_config, Some(server_config), None) + Endpoint::server(addr, self.build()).await } } diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 78448960..9353aeab 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -10,7 +10,7 @@ use std::{ }; use compio_buf::BufResult; -use compio_net::UdpSocket; +use compio_net::{ToSocketAddrsAsync, UdpSocket}; use compio_runtime::JoinHandle; use event_listener::{Event, IntoNotification}; use flume::{unbounded, Receiver, Sender}; @@ -25,10 +25,7 @@ use quinn_proto::{ EndpointEvent, ServerConfig, Transmit, VarInt, }; -use crate::{ - wait_event, ClientBuilder, Connecting, ConnectionEvent, Incoming, RecvMeta, ServerBuilder, - Socket, -}; +use crate::{wait_event, Connecting, ConnectionEvent, Incoming, RecvMeta, Socket}; #[derive(Debug)] struct EndpointState { @@ -320,14 +317,37 @@ impl Endpoint { }) } - /// Create a builder for a QUIC client. - pub fn client() -> ClientBuilder<()> { - ClientBuilder::default() + /// Helper to construct an endpoint for use with outgoing connections only. + /// + /// Note that `addr` is the *local* address to bind to, which should usually + /// be a wildcard address like `0.0.0.0:0` or `[::]:0`, which allow + /// communication with any reachable IPv4 or IPv6 address respectively + /// from an OS-assigned port. + /// + /// If an IPv6 address is provided, the socket may dual-stack depending on + /// the platform, so as to allow communication with both IPv4 and IPv6 + /// addresses. As such, calling this method with the address `[::]:0` is a + /// reasonable default to maximize the ability to connect to other + /// address. + /// + /// IPv4 client is never dual-stack. + pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result { + // TODO: try to enable dual-stack on all platforms, notably Windows + let socket = UdpSocket::bind(addr).await?; + Self::new(socket, EndpointConfig::default(), None, None) } - /// Create a builder for a QUIC server. - pub fn server() -> ServerBuilder<()> { - ServerBuilder::default() + /// Helper to construct an endpoint for use with both incoming and outgoing + /// connections + /// + /// Platform defaults for dual-stack sockets vary. For example, any socket + /// bound to a wildcard IPv6 address on Windows will not by default be + /// able to communicate with IPv4 addresses. Portable applications + /// should bind an address that matches the family they wish to + /// communicate within. + pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result { + let socket = UdpSocket::bind(addr).await?; + Self::new(socket, EndpointConfig::default(), Some(config), None) } /// Connect to a remote endpoint. @@ -337,13 +357,9 @@ impl Endpoint { server_name: &str, config: Option, ) -> Result { - let config = if let Some(config) = config { - config - } else if let Some(config) = &self.default_client_config { - config.clone() - } else { - return Err(ConnectError::NoDefaultClientConfig); - }; + let config = config + .or_else(|| self.default_client_config.clone()) + .ok_or(ConnectError::NoDefaultClientConfig)?; self.inner.connect(remote, server_name, config) } @@ -392,11 +408,11 @@ impl Endpoint { let reason = reason.to_string(); { - let close = &mut self.inner.state.lock().unwrap().close; - if close.is_some() { + let mut state = self.inner.state.lock().unwrap(); + if state.close.is_some() { return Ok(()); } - close.replace((error_code, reason.clone())); + state.close = Some((error_code, reason.clone())); } for conn in self.inner.state.lock().unwrap().connections.values() { From 29ed9f9fee4ba5ccee56260e530a7acd95c7ab3e Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Fri, 9 Aug 2024 01:33:16 +0800 Subject: [PATCH 08/26] test(quic): port from quinn --- compio-quic/Cargo.toml | 3 +- compio-quic/examples/client.rs | 4 +- compio-quic/examples/server.rs | 4 +- compio-quic/src/connection.rs | 2 +- compio-quic/src/endpoint.rs | 24 +++ compio-quic/tests/basic.rs | 271 ++++++++++++++++++++++++++++++++ compio-quic/tests/common/mod.rs | 33 ++++ compio-quic/tests/control.rs | 91 +++++++++++ compio-quic/tests/echo.rs | 198 +++++++++++++++++++++++ 9 files changed, 624 insertions(+), 6 deletions(-) create mode 100644 compio-quic/tests/basic.rs create mode 100644 compio-quic/tests/common/mod.rs create mode 100644 compio-quic/tests/control.rs create mode 100644 compio-quic/tests/echo.rs diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 5400e0eb..da99a0a9 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -45,9 +45,10 @@ libc = { workspace = true } [dev-dependencies] compio-driver = { workspace = true } compio-macros = { workspace = true } +rand = "0.8.5" rcgen = "0.13.1" socket2 = { workspace = true, features = ["all"] } -tracing-subscriber = "0.3.18" +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } [features] default = ["webpki-roots"] diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/client.rs index 9703cdc7..e634649b 100644 --- a/compio-quic/examples/client.rs +++ b/compio-quic/examples/client.rs @@ -1,12 +1,12 @@ use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use compio_quic::ClientBuilder; -use tracing_subscriber::filter::LevelFilter; +use tracing_subscriber::EnvFilter; #[compio_macros::main] async fn main() { tracing_subscriber::fmt() - .with_max_level(LevelFilter::TRACE) + .with_env_filter(EnvFilter::from_default_env()) .init(); let endpoint = ClientBuilder::new_with_no_server_verification() diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs index d2e55e8d..9e52e8a9 100644 --- a/compio-quic/examples/server.rs +++ b/compio-quic/examples/server.rs @@ -1,10 +1,10 @@ use compio_quic::ServerBuilder; -use tracing_subscriber::filter::LevelFilter; +use tracing_subscriber::EnvFilter; #[compio_macros::main] async fn main() { tracing_subscriber::fmt() - .with_max_level(LevelFilter::TRACE) + .with_env_filter(EnvFilter::from_default_env()) .init(); let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index af102737..b03981d4 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -496,7 +496,7 @@ impl Drop for Connecting { } /// A QUIC connection. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Connection(Arc); impl Connection { diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 9353aeab..59d0f306 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -382,6 +382,30 @@ impl Endpoint { Some(Incoming::new(incoming, self.inner.clone())) } + /// Replace the server configuration, affecting new incoming connections + /// only. + /// + /// Useful for e.g. refreshing TLS certificates without disrupting existing + /// connections. + pub fn set_server_config(&self, server_config: Option) { + self.inner + .state + .lock() + .unwrap() + .endpoint + .set_server_config(server_config.map(Arc::new)) + } + + /// Get the local `SocketAddr` the underlying socket is bound to. + pub fn local_addr(&self) -> io::Result { + self.inner.socket.local_addr() + } + + /// Get the number of connections that are currently open. + pub fn open_connections(&self) -> usize { + self.inner.state.lock().unwrap().endpoint.open_connections() + } + // Modified from [`SharedFd::try_unwrap_inner`], see notes there. unsafe fn try_unwrap_inner(this: &ManuallyDrop) -> Option { let ptr = ManuallyDrop::new(std::ptr::read(&this.inner)); diff --git a/compio-quic/tests/basic.rs b/compio-quic/tests/basic.rs new file mode 100644 index 00000000..ccf720ab --- /dev/null +++ b/compio-quic/tests/basic.rs @@ -0,0 +1,271 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::{Duration, Instant}, +}; + +use compio_quic::{ClientBuilder, ConnectionError, Endpoint, TransportConfig}; +use futures_util::join; + +mod common; +use common::{config_pair, subscribe}; + +#[compio_macros::test] +async fn handshake_timeout() { + let _guard = subscribe(); + + let endpoint = Endpoint::client("127.0.0.1:0").await.unwrap(); + + const IDLE_TIMEOUT: Duration = Duration::from_millis(100); + + let mut transport_config = TransportConfig::default(); + transport_config + .max_idle_timeout(Some(IDLE_TIMEOUT.try_into().unwrap())) + .initial_rtt(Duration::from_millis(10)); + let mut client_config = ClientBuilder::new_with_no_server_verification().build(); + client_config.transport_config(Arc::new(transport_config)); + + let start = Instant::now(); + match endpoint + .connect( + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1), + "localhost", + Some(client_config), + ) + .unwrap() + .await + { + Err(ConnectionError::TimedOut) => {} + Err(e) => panic!("unexpected error: {e:?}"), + Ok(_) => panic!("unexpected success"), + } + let dt = start.elapsed(); + assert!(dt > IDLE_TIMEOUT && dt < 2 * IDLE_TIMEOUT); +} + +#[compio_macros::test] +async fn close_endpoint() { + let _guard = subscribe(); + + let endpoint = ClientBuilder::new_with_no_server_verification() + .bind("127.0.0.1:0") + .await + .unwrap(); + + let conn = endpoint + .connect( + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1), + "localhost", + None, + ) + .unwrap(); + + compio_runtime::spawn(endpoint.close(0u32.into(), "")).detach(); + + match conn.await { + Err(ConnectionError::LocallyClosed) => (), + Err(e) => panic!("unexpected error: {e}"), + Ok(_) => { + panic!("unexpected success"); + } + } +} + +async fn endpoint() -> Endpoint { + let (server_config, client_config) = config_pair(None); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + endpoint +} + +#[compio_macros::test] +async fn read_after_close() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + const MSG: &[u8] = b"goodbye!"; + + join!( + async { + let conn = endpoint.wait_incoming().await.unwrap().await.unwrap(); + let mut s = conn.open_uni().unwrap(); + s.write_all(MSG).await.unwrap(); + s.finish().unwrap(); + // Wait for the stream to be closed, one way or another. + let _ = s.stopped().await; + }, + async { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap(); + let mut recv = conn.accept_uni().await.unwrap(); + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, MSG); + }, + ); +} + +#[compio_macros::test] +async fn export_keying_material() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + let (conn1, conn2) = join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap() + }, + async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, + ); + let mut buf1 = [0u8; 64]; + let mut buf2 = [0u8; 64]; + conn1 + .export_keying_material(&mut buf1, b"qaq", b"qwq") + .unwrap(); + conn2 + .export_keying_material(&mut buf2, b"qaq", b"qwq") + .unwrap(); + assert_eq!(buf1, buf2); +} + +#[compio_macros::test] +async fn zero_rtt() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + const MSG0: &[u8] = b"zero"; + const MSG1: &[u8] = b"one"; + + join!( + async { + for _ in 0..2 { + let conn = endpoint + .wait_incoming() + .await + .unwrap() + .accept() + .unwrap() + .into_0rtt() + .unwrap(); + join!( + async { + while let Ok(mut recv) = conn.accept_uni().await { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, MSG0); + } + }, + async { + let mut send = conn.open_uni().unwrap(); + send.write_all(MSG0).await.unwrap(); + send.finish().unwrap(); + conn.accepted_0rtt().await.unwrap(); + let mut send = conn.open_uni().unwrap(); + send.write_all(MSG1).await.unwrap(); + send.finish().unwrap(); + // no need to wait for the stream to be closed due to + // the `while` loop above + }, + ); + } + }, + async { + { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .into_0rtt() + .unwrap_err() + .await + .unwrap(); + + let mut buf = vec![]; + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG0); + + buf.clear(); + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG1); + } + + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .into_0rtt() + .unwrap(); + + let mut send = conn.open_uni().unwrap(); + send.write_all(MSG0).await.unwrap(); + send.finish().unwrap(); + + let mut buf = vec![]; + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG0); + + assert!(conn.accepted_0rtt().await.unwrap()); + + buf.clear(); + let mut recv = conn.accept_uni().await.unwrap(); + recv.read_to_end(&mut buf).await.expect("read_to_end"); + assert_eq!(buf, MSG1); + }, + ); +} + +#[compio_macros::test] +async fn two_datagram_readers() { + let _guard = subscribe(); + + let endpoint = endpoint().await; + + const MSG1: &[u8] = b"one"; + const MSG2: &[u8] = b"two"; + + let (conn1, conn2) = join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap() + }, + async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, + ); + + let ev = event_listener::Event::new(); + + let (a, b, _) = join!( + async { + let x = conn1.recv_datagram().await.unwrap(); + ev.notify(1); + x + }, + async { + let x = conn1.recv_datagram().await.unwrap(); + ev.notify(1); + x + }, + async { + conn2.send_datagram(MSG1.into()).unwrap(); + ev.listen().await; + conn2.send_datagram_wait(MSG2.into()).await.unwrap(); + } + ); + + assert!(a == MSG1 || b == MSG1); + assert!(a == MSG2 || b == MSG2); +} diff --git a/compio-quic/tests/common/mod.rs b/compio-quic/tests/common/mod.rs new file mode 100644 index 00000000..05fbe3f0 --- /dev/null +++ b/compio-quic/tests/common/mod.rs @@ -0,0 +1,33 @@ +use std::sync::Arc; + +use compio_log::subscriber::DefaultGuard; +use compio_quic::{ClientBuilder, ClientConfig, ServerBuilder, ServerConfig, TransportConfig}; +use tracing_subscriber::{util::SubscriberInitExt, EnvFilter}; + +pub fn subscribe() -> DefaultGuard { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish() + .set_default() +} + +pub fn config_pair(transport: Option) -> (ServerConfig, ClientConfig) { + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert_chain = vec![cert.cert.der().clone()]; + let key_der = cert.key_pair.serialize_der().try_into().unwrap(); + + let mut server_config = ServerBuilder::new_with_single_cert(cert_chain, key_der) + .unwrap() + .build(); + let mut client_config = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert.cert.into()) + .unwrap() + .with_no_crls() + .build(); + if let Some(transport) = transport { + let transport = Arc::new(transport); + server_config.transport_config(transport.clone()); + client_config.transport_config(transport); + } + (server_config, client_config) +} diff --git a/compio-quic/tests/control.rs b/compio-quic/tests/control.rs new file mode 100644 index 00000000..6610feb9 --- /dev/null +++ b/compio-quic/tests/control.rs @@ -0,0 +1,91 @@ +use compio_quic::{ConnectionError, Endpoint, TransportConfig}; + +mod common; +use common::{config_pair, subscribe}; +use futures_util::join; + +#[compio_macros::test] +async fn ip_blocking() { + let _guard = subscribe(); + + let (server_config, client_config) = config_pair(None); + + let server = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + let server_addr = server.local_addr().unwrap(); + + let client1 = Endpoint::client("127.0.0.1:0").await.unwrap(); + let client1_addr = client1.local_addr().unwrap(); + let client2 = Endpoint::client("127.0.0.1:0").await.unwrap(); + + let srv = compio_runtime::spawn(async move { + loop { + let incoming = server.wait_incoming().await.unwrap(); + if incoming.remote_address() == client1_addr { + incoming.refuse(); + } else if incoming.remote_address_validated() { + incoming.await.unwrap(); + } else { + incoming.retry().unwrap(); + } + } + }); + + let e = client1 + .connect(server_addr, "localhost", Some(client_config.clone())) + .unwrap() + .await + .unwrap_err(); + assert!(matches!(e, ConnectionError::ConnectionClosed(_))); + client2 + .connect(server_addr, "localhost", Some(client_config)) + .unwrap() + .await + .unwrap(); + + let _ = srv.cancel().await; +} + +#[compio_macros::test] +async fn stream_id_flow_control() { + let _guard = subscribe(); + + let mut cfg = TransportConfig::default(); + cfg.max_concurrent_uni_streams(1u32.into()); + + let (server_config, client_config) = config_pair(Some(cfg)); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + + let (conn1, conn2) = join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap() + }, + async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, + ); + + // If `open_uni_wait` doesn't get unblocked when the previous stream is dropped, + // this will time out. + join!( + async { + conn1.open_uni_wait().await.unwrap(); + }, + async { + conn1.open_uni_wait().await.unwrap(); + }, + async { + conn1.open_uni_wait().await.unwrap(); + }, + async { + conn2.accept_uni().await.unwrap(); + conn2.accept_uni().await.unwrap(); + } + ); +} diff --git a/compio-quic/tests/echo.rs b/compio-quic/tests/echo.rs new file mode 100644 index 00000000..b76c392f --- /dev/null +++ b/compio-quic/tests/echo.rs @@ -0,0 +1,198 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use bytes::Bytes; +use compio_quic::{Endpoint, RecvStream, SendStream, TransportConfig}; + +mod common; +use common::{config_pair, subscribe}; +use futures_util::join; +use rand::{rngs::StdRng, RngCore, SeedableRng}; + +struct EchoArgs { + client_addr: SocketAddr, + server_addr: SocketAddr, + nr_streams: usize, + stream_size: usize, + receive_window: Option, + stream_receive_window: Option, +} + +async fn echo((mut send, mut recv): (SendStream, RecvStream)) { + loop { + // These are 32 buffers, for reading approximately 32kB at once + #[rustfmt::skip] + let mut bufs = [ + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + ]; + + match recv.read_chunks(&mut bufs).await.unwrap() { + Some(n) => { + send.write_all_chunks(&mut bufs[..n]).await.unwrap(); + } + None => break, + } + } + + let _ = send.finish(); +} + +/// This is just an arbitrary number to generate deterministic test data +const SEED: u64 = 0x12345678; + +fn gen_data(size: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(SEED); + let mut buf = vec![0; size]; + rng.fill_bytes(&mut buf); + buf +} + +async fn run_echo(args: EchoArgs) { + // Use small receive windows + let mut transport_config = TransportConfig::default(); + if let Some(receive_window) = args.receive_window { + transport_config.receive_window(receive_window.try_into().unwrap()); + } + if let Some(stream_receive_window) = args.stream_receive_window { + transport_config.stream_receive_window(stream_receive_window.try_into().unwrap()); + } + transport_config.max_concurrent_bidi_streams(1_u8.into()); + transport_config.max_concurrent_uni_streams(1_u8.into()); + + let (server_config, client_config) = config_pair(Some(transport_config)); + + let server = Endpoint::server(args.server_addr, server_config) + .await + .unwrap(); + let client = Endpoint::client(args.client_addr).await.unwrap(); + + join!( + async { + let conn = server.wait_incoming().await.unwrap().await.unwrap(); + + while let Ok(stream) = conn.accept_bi().await { + compio_runtime::spawn(echo(stream)).detach(); + } + }, + async { + let conn = client + .connect( + server.local_addr().unwrap(), + "localhost", + Some(client_config), + ) + .unwrap() + .await + .unwrap(); + + for _ in 0..args.nr_streams { + let (mut send, mut recv) = conn.open_bi_wait().await.unwrap(); + let msg = gen_data(args.stream_size); + + let (_, data) = join!( + async { + send.write_all(&msg).await.unwrap(); + send.finish().unwrap(); + }, + async { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + buf + } + ); + + assert_eq!(data, msg); + } + } + ); +} + +#[compio_macros::test] +async fn echo_v6() { + let _guard = subscribe(); + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + nr_streams: 1, + stream_size: 10 * 1024, + receive_window: None, + stream_receive_window: None, + }) + .await; +} + +#[compio_macros::test] +async fn echo_v4() { + let _guard = subscribe(); + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 1, + stream_size: 10 * 1024, + receive_window: None, + stream_receive_window: None, + }) + .await; +} + +#[compio_macros::test] +#[cfg_attr(target_os = "windows", ignore)] +async fn echo_dualstack() { + let _guard = subscribe(); + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 1, + stream_size: 10 * 1024, + receive_window: None, + stream_receive_window: None, + }) + .await; +} + +#[compio_macros::test] +async fn stress_receive_window() { + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 50, + stream_size: 25 * 1024 + 11, + receive_window: Some(37), + stream_receive_window: Some(100 * 1024 * 1024), + }) + .await; +} + +#[compio_macros::test] +async fn stress_stream_receive_window() { + // Note that there is no point in running this with too many streams, + // since the window is only active within a stream. + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 2, + stream_size: 250 * 1024 + 11, + receive_window: Some(100 * 1024 * 1024), + stream_receive_window: Some(37), + }) + .await; +} + +#[compio_macros::test] +async fn stress_both_windows() { + run_echo(EchoArgs { + client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + server_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + nr_streams: 50, + stream_size: 25 * 1024 + 11, + receive_window: Some(37), + stream_receive_window: Some(37), + }) + .await; +} From efa73237e352e0d0d6ef913f2d08e702cabcd512 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Fri, 9 Aug 2024 03:17:16 +0800 Subject: [PATCH 09/26] fix(driver,iocp): fix incorrect opcode impl --- compio-driver/src/iocp/op.rs | 61 +++++++++++++++--------------------- compio-quic/tests/basic.rs | 1 + 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index 8369b234..e09b555e 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -1,13 +1,7 @@ #[cfg(feature = "once_cell_try")] use std::sync::OnceLock; use std::{ - io, - marker::PhantomPinned, - net::Shutdown, - os::windows::io::AsRawSocket, - pin::Pin, - ptr::{null, null_mut}, - task::Poll, + io, marker::PhantomPinned, net::Shutdown, os::windows::io::AsRawSocket, pin::Pin, ptr::{null, null_mut}, task::Poll }; use aligned_array::{Aligned, A8}; @@ -781,12 +775,12 @@ static WSA_RECVMSG: OnceLock = OnceLock::new(); /// Receive data and source address with ancillary data into vectored buffer. pub struct RecvMsg { + msg: WSAMSG, addr: SOCKADDR_STORAGE, - addr_len: socklen_t, fd: SharedFd, buffer: T, control: C, - control_len: u32, + slices: Vec, _p: PhantomPinned, } @@ -802,12 +796,12 @@ impl RecvMsg { "misaligned control message buffer" ); Self { + msg: unsafe { std::mem::zeroed() }, addr: unsafe { std::mem::zeroed() }, - addr_len: std::mem::size_of::() as _, fd, buffer, control, - control_len: 0, + slices: vec![], _p: PhantomPinned, } } @@ -820,8 +814,8 @@ impl IntoInner for RecvMsg { ( (self.buffer, self.control), self.addr, - self.addr_len, - self.control_len as _, + self.msg.namelen, + self.msg.Control.len as _, ) } } @@ -835,26 +829,22 @@ impl OpCode for RecvMsg { })?; let this = self.get_unchecked_mut(); - let mut slices = this.buffer.io_slices_mut(); - let mut msg = WSAMSG { - name: &mut this.addr as *mut _ as _, - namelen: this.addr_len, - lpBuffers: slices.as_mut_ptr() as _, - dwBufferCount: slices.len() as _, - Control: std::mem::transmute::(this.control.as_io_slice_mut()), - dwFlags: 0, - }; - this.control_len = 0; + + this.slices = this.buffer.io_slices_mut(); + this.msg.name = &mut this.addr as *mut _ as _; + this.msg.namelen = std::mem::size_of::() as _; + this.msg.lpBuffers = this.slices.as_mut_ptr() as _; + this.msg.dwBufferCount = this.slices.len() as _; + this.msg.Control = std::mem::transmute::(this.control.as_io_slice_mut()); let mut received = 0; let res = recvmsg_fn( this.fd.as_raw_fd() as _, - &mut msg, + &mut this.msg, &mut received, optr, None, ); - this.control_len = msg.Control.len; winsock_result(res, received) } @@ -866,10 +856,12 @@ impl OpCode for RecvMsg { /// Send data to specified address accompanied by ancillary data from vectored /// buffer. pub struct SendMsg { + msg: WSAMSG, fd: SharedFd, buffer: T, control: C, addr: SockAddr, + pub(crate) slices: Vec, _p: PhantomPinned, } @@ -885,10 +877,12 @@ impl SendMsg { "misaligned control message buffer" ); Self { + msg: unsafe { std::mem::zeroed() }, fd, buffer, control, addr, + slices: vec![], _p: PhantomPinned, } } @@ -906,18 +900,15 @@ impl OpCode for SendMsg { unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { let this = self.get_unchecked_mut(); - let slices = this.buffer.io_slices(); - let msg = WSAMSG { - name: this.addr.as_ptr() as _, - namelen: this.addr.len(), - lpBuffers: slices.as_ptr() as _, - dwBufferCount: slices.len() as _, - Control: std::mem::transmute::(this.control.as_io_slice()), - dwFlags: 0, - }; + this.slices = this.buffer.io_slices(); + this.msg.name = this.addr.as_ptr() as _; + this.msg.namelen = this.addr.len(); + this.msg.lpBuffers = this.slices.as_ptr() as _; + this.msg.dwBufferCount = this.slices.len() as _; + this.msg.Control = std::mem::transmute::(this.control.as_io_slice()); let mut sent = 0; - let res = WSASendMsg(this.fd.as_raw_fd() as _, &msg, 0, &mut sent, optr, None); + let res = WSASendMsg(this.fd.as_raw_fd() as _, &this.msg, 0, &mut sent, optr, None); winsock_result(res, sent) } diff --git a/compio-quic/tests/basic.rs b/compio-quic/tests/basic.rs index ccf720ab..dc87ff11 100644 --- a/compio-quic/tests/basic.rs +++ b/compio-quic/tests/basic.rs @@ -11,6 +11,7 @@ mod common; use common::{config_pair, subscribe}; #[compio_macros::test] +#[cfg_attr(target_os = "windows", ignore)] // FIXME: ERROR_PORT_UNREACHABLE async fn handshake_timeout() { let _guard = subscribe(); From 7fb5067a8371ce37fb7b788ad6c069ac35b330d3 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Fri, 9 Aug 2024 13:16:33 +0800 Subject: [PATCH 10/26] fix(quic): windows specific bug --- compio-quic/src/endpoint.rs | 2 ++ compio-quic/tests/basic.rs | 1 - compio-quic/tests/echo.rs | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 59d0f306..5c09c5d1 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -269,6 +269,8 @@ impl EndpointInner { match res { Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn), Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} + #[cfg(windows)] + Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} Err(e) => break Err(e), } recv_fut.set(self.socket.recv(recv_buf).fuse()); diff --git a/compio-quic/tests/basic.rs b/compio-quic/tests/basic.rs index dc87ff11..ccf720ab 100644 --- a/compio-quic/tests/basic.rs +++ b/compio-quic/tests/basic.rs @@ -11,7 +11,6 @@ mod common; use common::{config_pair, subscribe}; #[compio_macros::test] -#[cfg_attr(target_os = "windows", ignore)] // FIXME: ERROR_PORT_UNREACHABLE async fn handshake_timeout() { let _guard = subscribe(); diff --git a/compio-quic/tests/echo.rs b/compio-quic/tests/echo.rs index b76c392f..1ab51364 100644 --- a/compio-quic/tests/echo.rs +++ b/compio-quic/tests/echo.rs @@ -142,7 +142,7 @@ async fn echo_v4() { } #[compio_macros::test] -#[cfg_attr(target_os = "windows", ignore)] +#[cfg_attr(windows, ignore)] // FIXME: dual-stack socket on Windows async fn echo_dualstack() { let _guard = subscribe(); run_echo(EchoArgs { From afb0539cb59c98621d2f6b6f133bf51d9eb44d63 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Fri, 9 Aug 2024 13:53:33 +0800 Subject: [PATCH 11/26] chore(quic): interaction with tracing --- compio-quic/src/connection.rs | 9 ++++++++- compio-quic/src/endpoint.rs | 13 ++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index b03981d4..f5d4c52d 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -10,6 +10,7 @@ use std::{ use bytes::Bytes; use compio_buf::BufResult; +use compio_log::{error, Instrument}; use compio_runtime::JoinHandle; use event_listener::{Event, IntoNotification}; use flume::{Receiver, Sender}; @@ -384,7 +385,13 @@ impl Connecting { )); let worker = compio_runtime::spawn({ let inner = inner.clone(); - async move { inner.run().await.unwrap() } + async move { + #[allow(unused)] + if let Err(e) = inner.run().await { + error!("I/O error: {}", e); + } + } + .in_current_span() }); inner.state().worker = Some(worker); Self(inner) diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 5c09c5d1..954501c6 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -10,6 +10,7 @@ use std::{ }; use compio_buf::BufResult; +use compio_log::{error, Instrument}; use compio_net::{ToSocketAddrsAsync, UdpSocket}; use compio_runtime::JoinHandle; use event_listener::{Event, IntoNotification}; @@ -269,8 +270,8 @@ impl EndpointInner { match res { Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn), Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} - #[cfg(windows)] - Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} + // #[cfg(windows)] + // Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} Err(e) => break Err(e), } recv_fut.set(self.socket.recv(recv_buf).fuse()); @@ -310,7 +311,13 @@ impl Endpoint { let inner = Arc::new(EndpointInner::new(socket, config, server_config)?); let worker = compio_runtime::spawn({ let inner = inner.clone(); - async move { inner.run().await.unwrap() } + async move { + #[allow(unused)] + if let Err(e) = inner.run().await { + error!("I/O error: {}", e); + } + } + .in_current_span() }); inner.state.lock().unwrap().worker = Some(worker); Ok(Self { From 83b928baf826c3ae256252fa7fba1f101b99e682 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 15 Aug 2024 17:39:08 +0800 Subject: [PATCH 12/26] feat(quic): add bench --- Cargo.toml | 1 + compio-quic/Cargo.toml | 15 ++- compio-quic/benches/quic.rs | 196 +++++++++++++++++++++++++++++ compio-quic/examples/dispatcher.rs | 75 +++++++++++ compio-quic/examples/server.rs | 9 +- compio-quic/src/lib.rs | 2 +- compio-quic/src/recv_stream.rs | 32 ++++- compio-quic/src/send_stream.rs | 2 +- compio-quic/tests/common/mod.rs | 11 +- compio-quic/tests/echo.rs | 17 +-- compio/Cargo.toml | 5 +- compio/src/lib.rs | 3 + 12 files changed, 340 insertions(+), 28 deletions(-) create mode 100644 compio-quic/benches/quic.rs create mode 100644 compio-quic/examples/dispatcher.rs diff --git a/Cargo.toml b/Cargo.toml index 0a00f90f..ad7520b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ compio-dispatcher = { path = "./compio-dispatcher", version = "0.3.0" } compio-log = { path = "./compio-log", version = "0.1.0" } compio-tls = { path = "./compio-tls", version = "0.2.0", default-features = false } compio-process = { path = "./compio-process", version = "0.1.0" } +compio-quic = { path = "./compio-quic", version = "0.1.0" } flume = "0.11.0" cfg-if = "1.0.0" diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index da99a0a9..795cf084 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -43,16 +43,27 @@ windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } libc = { workspace = true } [dev-dependencies] +compio-dispatcher = { workspace = true } compio-driver = { workspace = true } compio-macros = { workspace = true } +compio-runtime = { workspace = true, features = ["criterion"] } + rand = "0.8.5" rcgen = "0.13.1" socket2 = { workspace = true, features = ["all"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +criterion = { workspace = true, features = ["async_tokio"] } +quinn = "0.11.3" +tokio = { workspace = true, features = ["rt", "macros"] } + [features] -default = ["webpki-roots"] -futures-io = ["futures-util/io"] +default = [] +io-compat = ["futures-util/io"] platform-verifier = ["dep:rustls-platform-verifier"] native-certs = ["dep:rustls-native-certs"] webpki-roots = ["dep:webpki-roots"] + +[[bench]] +name = "quic" +harness = false diff --git a/compio-quic/benches/quic.rs b/compio-quic/benches/quic.rs new file mode 100644 index 00000000..e318a069 --- /dev/null +++ b/compio-quic/benches/quic.rs @@ -0,0 +1,196 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::Instant, +}; + +use bytes::Bytes; +use criterion::{criterion_group, criterion_main, Bencher, Criterion, Throughput}; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use rand::{thread_rng, RngCore}; + +criterion_group!(quic, echo); +criterion_main!(quic); + +fn gen_cert() -> ( + rustls::pki_types::CertificateDer<'static>, + rustls::pki_types::PrivateKeyDer<'static>, +) { + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + (cert, key_der) +} + +macro_rules! echo_impl { + ($send:ident, $recv:ident) => { + loop { + // These are 32 buffers, for reading approximately 32kB at once + let mut bufs: [Bytes; 32] = std::array::from_fn(|_| Bytes::new()); + + match $recv.read_chunks(&mut bufs).await.unwrap() { + Some(n) => { + $send.write_all_chunks(&mut bufs[..n]).await.unwrap(); + } + None => break, + } + } + + let _ = $send.finish(); + }; +} + +fn echo_compio_quic(b: &mut Bencher, content: &[u8], streams: usize) { + use compio_quic::{ClientBuilder, ServerBuilder}; + + let runtime = compio_runtime::Runtime::new().unwrap(); + b.to_async(runtime).iter_custom(|iter| async move { + let (cert, key_der) = gen_cert(); + let server = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .bind("127.0.0.1:0") + .await + .unwrap(); + let client = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .bind("127.0.0.1:0") + .await + .unwrap(); + let addr = server.local_addr().unwrap(); + + let (client_conn, server_conn) = futures_util::join!( + async move { + client + .connect(addr, "localhost", None) + .unwrap() + .await + .unwrap() + }, + async move { server.wait_incoming().await.unwrap().await.unwrap() } + ); + + let start = Instant::now(); + let handle = compio_runtime::spawn(async move { + while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { + compio_runtime::spawn(async move { + echo_impl!(send, recv); + }) + .detach(); + } + }); + for _i in 0..iter { + let mut futures = (0..streams) + .map(|_| async { + let (mut send, mut recv) = client_conn.open_bi_wait().await.unwrap(); + futures_util::join!( + async { + send.write_all(content).await.unwrap(); + send.finish().unwrap(); + }, + async { + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + } + ); + }) + .collect::>(); + while futures.next().await.is_some() {} + } + drop(handle); + start.elapsed() + }) +} + +fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { + use quinn::{ClientConfig, Endpoint, ServerConfig}; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + b.to_async(&runtime).iter_custom(|iter| async move { + let (cert, key_der) = gen_cert(); + let server_config = ServerConfig::with_single_cert(vec![cert.clone()], key_der).unwrap(); + let mut roots = rustls::RootCertStore::empty(); + roots.add(cert).unwrap(); + let client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap(); + let server = Endpoint::server( + server_config, + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + ) + .unwrap(); + let mut client = + Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(); + client.set_default_client_config(client_config); + let addr = server.local_addr().unwrap(); + + let (client_conn, server_conn) = futures_util::join!( + async move { client.connect(addr, "localhost").unwrap().await.unwrap() }, + async move { server.accept().await.unwrap().await.unwrap() } + ); + + let start = Instant::now(); + tokio::spawn(async move { + while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { + tokio::spawn(async move { + echo_impl!(send, recv); + }); + } + }); + for _i in 0..iter { + let mut futures = (0..streams) + .map(|_| async { + let (mut send, mut recv) = client_conn.open_bi().await.unwrap(); + tokio::join!( + async { + send.write_all(content).await.unwrap(); + send.finish().unwrap(); + }, + async { + recv.read_to_end(usize::MAX).await.unwrap(); + } + ); + }) + .collect::>(); + while futures.next().await.is_some() {} + } + start.elapsed() + }); +} + +fn echo(c: &mut Criterion) { + let mut rng = thread_rng(); + + let mut large_data = [0u8; 1024 * 1024]; + rng.fill_bytes(&mut large_data); + + let mut small_data = [0u8; 10]; + rng.fill_bytes(&mut small_data); + + let mut group = c.benchmark_group("echo-large-data-1-stream"); + group.throughput(Throughput::Bytes((large_data.len() * 2) as u64)); + + group.bench_function("compio-quic", |b| echo_compio_quic(b, &large_data, 1)); + group.bench_function("quinn", |b| echo_quinn(b, &large_data, 1)); + + group.finish(); + + let mut group = c.benchmark_group("echo-large-data-10-streams"); + group.throughput(Throughput::Bytes((large_data.len() * 10 * 2) as u64)); + + group.bench_function("compio-quic", |b| echo_compio_quic(b, &large_data, 10)); + group.bench_function("quinn", |b| echo_quinn(b, &large_data, 10)); + + group.finish(); + + let mut group = c.benchmark_group("echo-small-data-100-streams"); + group.throughput(Throughput::Bytes((small_data.len() * 10 * 2) as u64)); + + group.bench_function("compio-quic", |b| echo_compio_quic(b, &small_data, 100)); + group.bench_function("quinn", |b| echo_quinn(b, &small_data, 100)); + + group.finish(); +} diff --git a/compio-quic/examples/dispatcher.rs b/compio-quic/examples/dispatcher.rs new file mode 100644 index 00000000..851debcf --- /dev/null +++ b/compio-quic/examples/dispatcher.rs @@ -0,0 +1,75 @@ +use std::num::NonZeroUsize; + +use compio_dispatcher::Dispatcher; +use compio_quic::{ClientBuilder, Endpoint, ServerBuilder}; +use compio_runtime::spawn; +use futures_util::{stream::FuturesUnordered, StreamExt}; + +#[compio_macros::main] +async fn main() { + const THREAD_NUM: usize = 5; + const CLIENT_NUM: usize = 10; + + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let server_config = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) + .unwrap() + .build(); + let client_config = ClientBuilder::new_with_empty_roots() + .with_custom_certificate(cert) + .unwrap() + .with_no_crls() + .build(); + let mut endpoint = Endpoint::server("127.0.0.1:0", server_config) + .await + .unwrap(); + endpoint.default_client_config = Some(client_config); + + spawn({ + let endpoint = endpoint.clone(); + async move { + let mut futures = FuturesUnordered::from_iter((0..CLIENT_NUM).map(|i| { + let endpoint = &endpoint; + async move { + let conn = endpoint + .connect(endpoint.local_addr().unwrap(), "localhost", None) + .unwrap() + .await + .unwrap(); + let mut send = conn.open_uni().unwrap(); + send.write_all(format!("Hello world {}!", i).as_bytes()) + .await + .unwrap(); + send.finish().unwrap(); + send.stopped().await.unwrap(); + } + })); + while let Some(()) = futures.next().await {} + } + }) + .detach(); + + let dispatcher = Dispatcher::builder() + .worker_threads(NonZeroUsize::new(THREAD_NUM).unwrap()) + .build() + .unwrap(); + let mut handles = FuturesUnordered::new(); + for _i in 0..CLIENT_NUM { + let incoming = endpoint.wait_incoming().await.unwrap(); + let handle = dispatcher + .dispatch(move || async move { + let conn = incoming.await.unwrap(); + let mut recv = conn.accept_uni().await.unwrap(); + let mut buf = vec![]; + recv.read_to_end(&mut buf).await.unwrap(); + println!("{}", std::str::from_utf8(&buf).unwrap()); + }) + .unwrap(); + handles.push(handle); + } + while handles.next().await.is_some() {} + dispatcher.join().await.unwrap(); +} diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs index 9e52e8a9..3a380f88 100644 --- a/compio-quic/examples/server.rs +++ b/compio-quic/examples/server.rs @@ -7,11 +7,12 @@ async fn main() { .with_env_filter(EnvFilter::from_default_env()) .init(); - let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let cert_chain = vec![cert.cert.into()]; - let key_der = cert.key_pair.serialize_der().try_into().unwrap(); + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); - let endpoint = ServerBuilder::new_with_single_cert(cert_chain, key_der) + let endpoint = ServerBuilder::new_with_single_cert(vec![cert], key_der) .unwrap() .with_key_log() .bind("[::1]:4433") diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs index 9144feb1..d73a3c65 100644 --- a/compio-quic/src/lib.rs +++ b/compio-quic/src/lib.rs @@ -25,7 +25,7 @@ pub use builder::{ClientBuilder, ServerBuilder}; pub use connection::{Connecting, Connection}; pub use endpoint::Endpoint; pub use incoming::{Incoming, IncomingFuture}; -pub use recv_stream::{ReadError, RecvStream}; +pub use recv_stream::{ReadError, ReadExactError, RecvStream}; pub use send_stream::{SendStream, WriteError}; pub(crate) use crate::{ diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs index 5a12f99a..723c5075 100644 --- a/compio-quic/src/recv_stream.rs +++ b/compio-quic/src/recv_stream.rs @@ -8,7 +8,7 @@ use std::{ use bytes::{BufMut, Bytes}; use compio_buf::{BufResult, IoBufMut}; use compio_io::AsyncRead; -use futures_util::future::poll_fn; +use futures_util::{future::poll_fn, ready}; use quinn_proto::{Chunk, Chunks, ClosedStream, ConnectionError, ReadableError, StreamId, VarInt}; use thiserror::Error; @@ -261,6 +261,23 @@ impl RecvStream { poll_fn(|cx| self.poll_read(cx, &mut buf)).await } + /// Read an exact number of bytes contiguously from the stream. + /// + /// See [`read()`] for details. This operation is *not* cancel-safe. + /// + /// [`read()`]: RecvStream::read + pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> { + poll_fn(|cx| { + while buf.has_remaining_mut() { + if ready!(self.poll_read(cx, &mut buf))?.is_none() { + return Poll::Ready(Err(ReadExactError::FinishedEarly(buf.remaining_mut()))); + } + } + Poll::Ready(Ok(())) + }) + .await + } + /// Read the next segment of data. /// /// Yields `None` if the stream was finished. Otherwise, yields a segment of @@ -470,6 +487,17 @@ impl From for io::Error { } } +/// Errors that arise from reading from a stream. +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ReadExactError { + /// The stream finished before all bytes were read + #[error("stream finished early (expected {0} bytes more)")] + FinishedEarly(usize), + /// A read error occurred + #[error(transparent)] + ReadError(#[from] ReadError), +} + impl AsyncRead for RecvStream { async fn read(&mut self, mut buf: B) -> BufResult { let res = self @@ -485,7 +513,7 @@ impl AsyncRead for RecvStream { } } -#[cfg(feature = "futures-io")] +#[cfg(feature = "io-compat")] impl futures_util::AsyncRead for RecvStream { fn poll_read( self: std::pin::Pin<&mut Self>, diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs index bf8fc41a..7801e726 100644 --- a/compio-quic/src/send_stream.rs +++ b/compio-quic/src/send_stream.rs @@ -341,7 +341,7 @@ impl AsyncWrite for SendStream { } } -#[cfg(feature = "futures-io")] +#[cfg(feature = "io-compat")] impl futures_util::AsyncWrite for SendStream { fn poll_write( self: std::pin::Pin<&mut Self>, diff --git a/compio-quic/tests/common/mod.rs b/compio-quic/tests/common/mod.rs index 05fbe3f0..08745b3d 100644 --- a/compio-quic/tests/common/mod.rs +++ b/compio-quic/tests/common/mod.rs @@ -12,15 +12,16 @@ pub fn subscribe() -> DefaultGuard { } pub fn config_pair(transport: Option) -> (ServerConfig, ClientConfig) { - let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let cert_chain = vec![cert.cert.der().clone()]; - let key_der = cert.key_pair.serialize_der().try_into().unwrap(); + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); - let mut server_config = ServerBuilder::new_with_single_cert(cert_chain, key_der) + let mut server_config = ServerBuilder::new_with_single_cert(vec![cert.clone()], key_der) .unwrap() .build(); let mut client_config = ClientBuilder::new_with_empty_roots() - .with_custom_certificate(cert.cert.into()) + .with_custom_certificate(cert) .unwrap() .with_no_crls() .build(); diff --git a/compio-quic/tests/echo.rs b/compio-quic/tests/echo.rs index 1ab51364..69d942ad 100644 --- a/compio-quic/tests/echo.rs +++ b/compio-quic/tests/echo.rs @@ -1,4 +1,7 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::{ + array, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, +}; use bytes::Bytes; use compio_quic::{Endpoint, RecvStream, SendStream, TransportConfig}; @@ -20,17 +23,7 @@ struct EchoArgs { async fn echo((mut send, mut recv): (SendStream, RecvStream)) { loop { // These are 32 buffers, for reading approximately 32kB at once - #[rustfmt::skip] - let mut bufs = [ - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), - ]; + let mut bufs: [Bytes; 32] = array::from_fn(|_| Bytes::new()); match recv.read_chunks(&mut bufs).await.unwrap() { Some(n) => { diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 692d7ac1..8cbb3715 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -42,6 +42,7 @@ compio-dispatcher = { workspace = true, optional = true } compio-log = { workspace = true } compio-tls = { workspace = true, optional = true } compio-process = { workspace = true, optional = true } +compio-quic = { workspace = true, optional = true } # Shared dev dependencies for all platforms [dev-dependencies] @@ -83,7 +84,7 @@ io-uring = [ ] polling = ["compio-driver/polling"] io = ["dep:compio-io"] -io-compat = ["io", "compio-io/compat"] +io-compat = ["io", "compio-io/compat", "compio-quic/io-compat"] runtime = ["dep:compio-runtime", "dep:compio-fs", "dep:compio-net", "io"] macros = ["dep:compio-macros", "runtime"] event = ["compio-runtime/event", "runtime"] @@ -94,6 +95,7 @@ tls = ["dep:compio-tls"] native-tls = ["tls", "compio-tls/native-tls"] rustls = ["tls", "compio-tls/rustls"] process = ["dep:compio-process"] +quic = ["dep:compio-quic"] all = [ "time", "macros", @@ -102,6 +104,7 @@ all = [ "native-tls", "rustls", "process", + "quic", ] arrayvec = ["compio-buf/arrayvec"] diff --git a/compio/src/lib.rs b/compio/src/lib.rs index 244d8b37..8b6c5c09 100644 --- a/compio/src/lib.rs +++ b/compio/src/lib.rs @@ -41,6 +41,9 @@ pub use compio_macros::*; #[cfg(feature = "process")] #[doc(inline)] pub use compio_process as process; +#[cfg(feature = "quic")] +#[doc(inline)] +pub use compio_quic as quic; #[cfg(feature = "signal")] #[doc(inline)] pub use compio_signal as signal; From aae6db4b5a9419e765400ef11af8f2fd29ed88e7 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 15 Aug 2024 20:50:32 +0800 Subject: [PATCH 13/26] feat(quic): improved close logic --- compio-quic/examples/client.rs | 4 +- compio-quic/examples/server.rs | 3 +- compio-quic/src/connection.rs | 61 ++++++++++++++++++++--------- compio-quic/src/endpoint.rs | 71 +++++++++++++++++++++------------- compio-quic/tests/basic.rs | 4 +- 5 files changed, 93 insertions(+), 50 deletions(-) diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/client.rs index e634649b..167bbab1 100644 --- a/compio-quic/examples/client.rs +++ b/compio-quic/examples/client.rs @@ -36,8 +36,8 @@ async fn main() { recv.read_to_end(&mut buf).await.unwrap(); println!("{:?}", buf); - conn.close(1u32.into(), "bye"); + conn.close(1u32.into(), b"bye"); } - endpoint.close(0u32.into(), "").await.unwrap(); + endpoint.shutdown().await.unwrap(); } diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/server.rs index 3a380f88..20b6c01b 100644 --- a/compio-quic/examples/server.rs +++ b/compio-quic/examples/server.rs @@ -34,5 +34,6 @@ async fn main() { conn.closed().await; } - endpoint.close(0u32.into(), "").await.unwrap(); + endpoint.close(0u32.into(), b""); + endpoint.shutdown().await.unwrap(); } diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index f5d4c52d..df9905ca 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -28,7 +28,7 @@ use crate::{wait_event, RecvStream, SendStream, Socket}; #[derive(Debug)] pub(crate) enum ConnectionEvent { - Close(VarInt, String), + Close(VarInt, Bytes), Proto(quinn_proto::ConnectionEvent), } @@ -165,9 +165,9 @@ impl ConnectionInner { } } - fn close(&self, error_code: VarInt, reason: String) { + fn close(&self, error_code: VarInt, reason: Bytes) { let mut state = self.state(); - state.conn.close(Instant::now(), error_code, reason.into()); + state.conn.close(Instant::now(), error_code, reason); state.terminate(ConnectionError::LocallyClosed); state.wake(); self.notify_events(); @@ -497,7 +497,7 @@ impl Future for Connecting { impl Drop for Connecting { fn drop(&mut self) { if Arc::strong_count(&self.0) == 2 { - self.0.close(0u32.into(), String::new()) + self.0.close(0u32.into(), Bytes::new()) } } } @@ -575,23 +575,41 @@ impl Connection { /// Close the connection immediately. /// /// Pending operations will fail immediately with - /// [`ConnectionError::LocallyClosed`]. Delivery of data on unfinished - /// streams is not guaranteed, so the application must call this only when - /// all important communications have been completed, e.g. by calling - /// [`finish`] on outstanding [`SendStream`]s and waiting for the resulting - /// futures to complete. + /// [`ConnectionError::LocallyClosed`]. No more data is sent to the peer + /// and the peer may drop buffered data upon receiving + /// the CONNECTION_CLOSE frame. /// /// `error_code` and `reason` are not interpreted, and are provided directly /// to the peer. /// /// `reason` will be truncated to fit in a single packet with overhead; to - /// improve odds that it is preserved in full, it should be kept under 1KiB. - /// - /// [`ConnectionError::LocallyClosed`]: quinn_proto::ConnectionError::LocallyClosed - /// [`finish`]: crate::SendStream::finish - /// [`SendStream`]: crate::SendStream - pub fn close(&self, error_code: VarInt, reason: &str) { - self.0.close(error_code, reason.to_string()); + /// improve odds that it is preserved in full, it should be kept under + /// 1KiB. + /// + /// # Gracefully closing a connection + /// + /// Only the peer last receiving application data can be certain that all + /// data is delivered. The only reliable action it can then take is to + /// close the connection, potentially with a custom error code. The + /// delivery of the final CONNECTION_CLOSE frame is very likely if both + /// endpoints stay online long enough, and [`Endpoint::shutdown()`] can + /// be used to provide sufficient time. Otherwise, the remote peer will + /// time out the connection, provided that the idle timeout is not + /// disabled. + /// + /// The sending side can not guarantee all stream data is delivered to the + /// remote application. It only knows the data is delivered to the QUIC + /// stack of the remote endpoint. Once the local side sends a + /// CONNECTION_CLOSE frame in response to calling [`close()`] the remote + /// endpoint may drop any data it received but is as yet undelivered to + /// the application, including data that was acknowledged as received to + /// the local endpoint. + /// + /// [`ConnectionError::LocallyClosed`]: ConnectionError::LocallyClosed + /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown + /// [`close()`]: Connection::close + pub fn close(&self, error_code: VarInt, reason: &[u8]) { + self.0.close(error_code, Bytes::copy_from_slice(reason)); } /// Wait for the connection to be closed for any reason. @@ -601,7 +619,14 @@ impl Connection { let _ = worker.await; } - self.0.state().error.clone().unwrap() + self.0.try_state().unwrap_err() + } + + /// If the connection is closed, the reason why. + /// + /// Returns `None` if the connection is still open. + pub fn close_reason(&self) -> Option { + self.0.try_state().err() } /// Receive an application datagram. @@ -807,7 +832,7 @@ impl Eq for Connection {} impl Drop for Connection { fn drop(&mut self) { if Arc::strong_count(&self.0) == 2 { - self.close(0u32.into(), "") + self.close(0u32.into(), b"") } } } diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 954501c6..3288d206 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -9,6 +9,7 @@ use std::{ time::Instant, }; +use bytes::Bytes; use compio_buf::BufResult; use compio_log::{error, Instrument}; use compio_net::{ToSocketAddrsAsync, UdpSocket}; @@ -33,7 +34,8 @@ struct EndpointState { endpoint: quinn_proto::Endpoint, worker: Option>, connections: HashMap>, - close: Option<(VarInt, String)>, + close: Option<(VarInt, Bytes)>, + exit_on_idle: bool, incoming: VecDeque, } @@ -149,6 +151,7 @@ impl EndpointInner { worker: None, connections: HashMap::new(), close: None, + exit_on_idle: false, incoming: VecDeque::new(), }), socket, @@ -270,8 +273,8 @@ impl EndpointInner { match res { Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn), Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} - // #[cfg(windows)] - // Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} + #[cfg(windows)] + Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} Err(e) => break Err(e), } recv_fut.set(self.socket.recv(recv_buf).fuse()); @@ -282,7 +285,7 @@ impl EndpointInner { } let state = self.state.lock().unwrap(); - if state.close.is_some() && state.is_idle() { + if state.exit_on_idle && state.is_idle() { break Ok(()); } if !state.incoming.is_empty() { @@ -415,6 +418,25 @@ impl Endpoint { self.inner.state.lock().unwrap().endpoint.open_connections() } + /// Close all of this endpoint's connections immediately and cease accepting + /// new connections. + /// + /// See [`Connection::close()`] for details. + /// + /// [`Connection::close()`]: crate::Connection::close + pub fn close(&self, error_code: VarInt, reason: &[u8]) { + let reason = Bytes::copy_from_slice(reason); + let mut state = self.inner.state.lock().unwrap(); + if state.close.is_some() { + return; + } + state.close = Some((error_code, reason.clone())); + for conn in state.connections.values() { + let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone())); + } + self.inner.incoming.notify(usize::MAX.additional()); + } + // Modified from [`SharedFd::try_unwrap_inner`], see notes there. unsafe fn try_unwrap_inner(this: &ManuallyDrop) -> Option { let ptr = ManuallyDrop::new(std::ptr::read(&this.inner)); @@ -427,36 +449,29 @@ impl Endpoint { } } - /// Shutdown the endpoint and close the underlying socket. + /// Gracefully shutdown the endpoint. /// - /// This will close all connections and the underlying socket. Note that it - /// will wait for all connections and all clones of the endpoint (and any - /// clone of the underlying socket) to be dropped before closing the socket. + /// Wait for all connections on the endpoint to be cleanly shut down and + /// close the underlying socket. This will wait for all clones of the + /// endpoint, all connections and all streams to be dropped before + /// closing the socket. /// - /// If the endpoint has already been closed or is closing, this will return - /// immediately with `Ok(())`. + /// Waiting for this condition before exiting ensures that a good-faith + /// effort is made to notify peers of recent connection closes, whereas + /// exiting immediately could force them to wait out the idle timeout + /// period. /// - /// See [`Connection::close()`](crate::Connection::close) for details. - pub async fn close(self, error_code: VarInt, reason: &str) -> io::Result<()> { - let reason = reason.to_string(); - - { - let mut state = self.inner.state.lock().unwrap(); - if state.close.is_some() { - return Ok(()); - } - state.close = Some((error_code, reason.clone())); - } - - for conn in self.inner.state.lock().unwrap().connections.values() { - let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone())); - } - + /// Does not proactively close existing connections. Consider calling + /// [`close()`] if that is desired. + /// + /// [`close()`]: Endpoint::close + pub async fn shutdown(self) -> io::Result<()> { let worker = self.inner.state.lock().unwrap().worker.take(); if let Some(worker) = worker { if self.inner.state.lock().unwrap().is_idle() { worker.cancel().await; } else { + self.inner.state.lock().unwrap().exit_on_idle = true; let _ = worker.await; } } @@ -484,7 +499,11 @@ impl Endpoint { impl Drop for Endpoint { fn drop(&mut self) { if Arc::strong_count(&self.inner) == 2 { + // There are actually two cases: + // 1. User is trying to shutdown the socket. self.inner.done.wake(); + // 2. User dropped the endpoint but the worker is still running. + self.inner.state.lock().unwrap().exit_on_idle = true; } } } diff --git a/compio-quic/tests/basic.rs b/compio-quic/tests/basic.rs index ccf720ab..415dae3b 100644 --- a/compio-quic/tests/basic.rs +++ b/compio-quic/tests/basic.rs @@ -59,9 +59,7 @@ async fn close_endpoint() { None, ) .unwrap(); - - compio_runtime::spawn(endpoint.close(0u32.into(), "")).detach(); - + endpoint.close(0u32.into(), b""); match conn.await { Err(ConnectionError::LocallyClosed) => (), Err(e) => panic!("unexpected error: {e}"), From 13fa97823f070623025e3d5c1aabb531c94c9f9c Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 17 Aug 2024 23:14:55 +0800 Subject: [PATCH 14/26] fix(net): mark get/set_socket_option as unsafe --- compio-net/src/socket.rs | 12 ++++++------ compio-net/src/udp.rs | 12 ++++++++++-- compio-net/tests/udp.rs | 8 +++++--- compio-quic/src/socket.rs | 23 ++++++++++++++++------- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 74d8b7ee..03bb8b27 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -324,7 +324,7 @@ impl Socket { } #[cfg(unix)] - pub fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { let mut value: MaybeUninit = MaybeUninit::uninit(); let mut len = size_of::() as libc::socklen_t; syscall!(libc::getsockopt( @@ -337,12 +337,12 @@ impl Socket { .map(|_| { debug_assert_eq!(len as usize, size_of::()); // SAFETY: The value is initialized by `getsockopt`. - unsafe { value.assume_init() } + value.assume_init() }) } #[cfg(windows)] - pub fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { let mut value: MaybeUninit = MaybeUninit::uninit(); let mut len = size_of::() as i32; syscall!( @@ -358,12 +358,12 @@ impl Socket { .map(|_| { debug_assert_eq!(len as usize, size_of::()); // SAFETY: The value is initialized by `getsockopt`. - unsafe { value.assume_init() } + value.assume_init() }) } #[cfg(unix)] - pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { syscall!(libc::setsockopt( self.socket.as_raw_fd(), level, @@ -375,7 +375,7 @@ impl Socket { } #[cfg(windows)] - pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { syscall!( SOCKET, windows_sys::Win32::Networking::WinSock::setsockopt( diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index d0855833..7a28025a 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -317,12 +317,20 @@ impl UdpSocket { } /// Gets a socket option. - pub fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + /// + /// # Safety + /// + /// The caller must ensure `T` is the correct type for `level` and `name`. + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { self.inner.get_socket_option(level, name) } /// Sets a socket option. - pub fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + /// + /// # Safety + /// + /// The caller must ensure `T` is the correct type for `level` and `name`. + pub unsafe fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { self.inner.set_socket_option(level, name, value) } } diff --git a/compio-net/tests/udp.rs b/compio-net/tests/udp.rs index d813290e..7cea74c9 100644 --- a/compio-net/tests/udp.rs +++ b/compio-net/tests/udp.rs @@ -79,9 +79,11 @@ async fn send_msg_with_ipv6_ecn() { let passive = UdpSocket::bind("[::1]:0").await.unwrap(); let passive_addr = passive.local_addr().unwrap(); - passive - .set_socket_option(IPPROTO_IPV6, IPV6_RECVTCLASS, &1) - .unwrap(); + unsafe { + passive + .set_socket_option(IPPROTO_IPV6, IPV6_RECVTCLASS, &1) + .unwrap(); + } let active = UdpSocket::bind("[::1]:0").await.unwrap(); let active_addr = active.local_addr().unwrap(); diff --git a/compio-quic/src/socket.rs b/compio-quic/src/socket.rs index c3f66454..4ea39e97 100644 --- a/compio-quic/src/socket.rs +++ b/compio-quic/src/socket.rs @@ -118,13 +118,17 @@ impl DerefMut for Ancillary { #[cfg(target_os = "linux")] #[inline] fn max_gso_segments(socket: &UdpSocket) -> io::Result { - socket.get_socket_option::(libc::SOL_UDP, libc::UDP_SEGMENT)?; + unsafe { + socket.get_socket_option::(libc::SOL_UDP, libc::UDP_SEGMENT)?; + } Ok(64) } #[cfg(windows)] #[inline] fn max_gso_segments(socket: &UdpSocket) -> io::Result { - socket.get_socket_option::(WinSock::IPPROTO_UDP, WinSock::UDP_SEND_MSG_SIZE)?; + unsafe { + socket.get_socket_option::(WinSock::IPPROTO_UDP, WinSock::UDP_SEND_MSG_SIZE)?; + } Ok(512) } #[cfg(not(any(target_os = "linux", windows)))] @@ -135,7 +139,7 @@ fn max_gso_segments(_socket: &UdpSocket) -> io::Result { macro_rules! set_socket_option { ($socket:expr, $level:expr, $name:expr, $value:expr $(,)?) => { - match $socket.set_socket_option($level, $name, $value) { + match unsafe { $socket.set_socket_option($level, $name, $value) } { Ok(()) => true, Err(e) => { compio_log::warn!( @@ -178,11 +182,16 @@ impl Socket { pub fn new(socket: UdpSocket) -> io::Result { let is_ipv6 = socket.local_addr()?.is_ipv6(); #[cfg(unix)] - let only_v6 = is_ipv6 - && socket.get_socket_option::(libc::IPPROTO_IPV6, libc::IPV6_V6ONLY)? != 0; + let only_v6 = unsafe { + is_ipv6 + && socket.get_socket_option::(libc::IPPROTO_IPV6, libc::IPV6_V6ONLY)? + != 0 + }; #[cfg(windows)] - let only_v6 = is_ipv6 - && socket.get_socket_option::(WinSock::IPPROTO_IPV6, WinSock::IPV6_V6ONLY)? != 0; + let only_v6 = unsafe { + is_ipv6 + && socket.get_socket_option::(WinSock::IPPROTO_IPV6, WinSock::IPV6_V6ONLY)? != 0 + }; let is_ipv4 = socket.local_addr()?.is_ipv4() || !only_v6; // ECN From c329a7c0a238ecc613b1c6182d46da43385d0207 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 17 Aug 2024 23:19:47 +0800 Subject: [PATCH 15/26] test(net): remove redundant send_msg test --- compio-net/tests/udp.rs | 58 +---------------------------------------- 1 file changed, 1 insertion(+), 57 deletions(-) diff --git a/compio-net/tests/udp.rs b/compio-net/tests/udp.rs index 7cea74c9..699dec25 100644 --- a/compio-net/tests/udp.rs +++ b/compio-net/tests/udp.rs @@ -1,4 +1,4 @@ -use compio_net::{CMsgBuilder, CMsgIter, UdpSocket}; +use compio_net::UdpSocket; #[compio_macros::test] async fn connect() { @@ -64,59 +64,3 @@ async fn send_to() { active_addr ); } - -#[compio_macros::test] -async fn send_msg_with_ipv6_ecn() { - #[cfg(unix)] - use libc::{IPPROTO_IPV6, IPV6_RECVTCLASS, IPV6_TCLASS}; - #[cfg(windows)] - use windows_sys::Win32::Networking::WinSock::{ - IPPROTO_IPV6, IPV6_ECN, IPV6_RECVTCLASS, IPV6_TCLASS, - }; - - const MSG: &str = "foo bar baz"; - - let passive = UdpSocket::bind("[::1]:0").await.unwrap(); - let passive_addr = passive.local_addr().unwrap(); - - unsafe { - passive - .set_socket_option(IPPROTO_IPV6, IPV6_RECVTCLASS, &1) - .unwrap(); - } - - let active = UdpSocket::bind("[::1]:0").await.unwrap(); - let active_addr = active.local_addr().unwrap(); - - let mut control = vec![0u8; 32]; - let mut builder = CMsgBuilder::new(&mut control); - - const ECN_BITS: i32 = 0b10; - - #[cfg(unix)] - builder - .try_push(IPPROTO_IPV6, IPV6_TCLASS, ECN_BITS) - .unwrap(); - #[cfg(windows)] - builder.try_push(IPPROTO_IPV6, IPV6_ECN, ECN_BITS).unwrap(); - - let len = builder.finish(); - control.truncate(len); - - active.send_msg(MSG, control, passive_addr).await.unwrap(); - - let ((_, _, addr), (buffer, control)) = passive - .recv_msg(Vec::with_capacity(20), Vec::with_capacity(32)) - .await - .unwrap(); - assert_eq!(addr, active_addr); - assert_eq!(buffer, MSG.as_bytes()); - unsafe { - let mut iter = CMsgIter::new(&control); - let cmsg = iter.next().unwrap(); - assert_eq!(cmsg.level(), IPPROTO_IPV6); - assert_eq!(cmsg.ty(), IPV6_TCLASS); - assert_eq!(cmsg.data::(), &ECN_BITS); - assert!(iter.next().is_none()); - } -} From 32e07e06f94af194cebcbfed911581b02fc1f635 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Tue, 20 Aug 2024 02:18:39 +0800 Subject: [PATCH 16/26] test(quic): ignore echo_dualstack on unsupported platforms --- compio-quic/tests/echo.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compio-quic/tests/echo.rs b/compio-quic/tests/echo.rs index 69d942ad..9ac40196 100644 --- a/compio-quic/tests/echo.rs +++ b/compio-quic/tests/echo.rs @@ -135,7 +135,7 @@ async fn echo_v4() { } #[compio_macros::test] -#[cfg_attr(windows, ignore)] // FIXME: dual-stack socket on Windows +#[cfg_attr(any(target_os = "openbsd", target_os = "netbsd", windows), ignore)] async fn echo_dualstack() { let _guard = subscribe(); run_echo(EchoArgs { From 20a3683892317ddd0addba17de3c71fe8a988c0a Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Tue, 20 Aug 2024 17:26:56 +0800 Subject: [PATCH 17/26] feat(quic): remove event-listener --- compio-quic/Cargo.toml | 1 - compio-quic/src/connection.rs | 221 +++++++++++++++++----------------- compio-quic/src/endpoint.rs | 37 +++--- compio-quic/src/lib.rs | 12 -- compio-quic/tests/basic.rs | 8 +- 5 files changed, 132 insertions(+), 147 deletions(-) diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 795cf084..08abac05 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -30,7 +30,6 @@ webpki-roots = { version = "0.26.3", optional = true } # Utils bytes = "1.7.1" -event-listener = "5.3.1" flume = { workspace = true } futures-util = { workspace = true } thiserror = "1.0.63" diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index df9905ca..44d7ce55 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, io, net::{IpAddr, SocketAddr}, pin::{pin, Pin}, @@ -12,7 +12,6 @@ use bytes::Bytes; use compio_buf::BufResult; use compio_log::{error, Instrument}; use compio_runtime::JoinHandle; -use event_listener::{Event, IntoNotification}; use flume::{Receiver, Sender}; use futures_util::{ future::{self, Fuse, FusedFuture, LocalBoxFuture}, @@ -24,7 +23,7 @@ use quinn_proto::{ }; use thiserror::Error; -use crate::{wait_event, RecvStream, SendStream, Socket}; +use crate::{RecvStream, SendStream, Socket}; #[derive(Debug)] pub(crate) enum ConnectionEvent { @@ -41,6 +40,10 @@ pub(crate) struct ConnectionState { poll_waker: Option, on_connected: Option, on_handshake_data: Option, + datagram_received: VecDeque, + datagrams_unblocked: VecDeque, + stream_opened: [VecDeque; 2], + stream_available: [VecDeque; 2], pub(crate) writable: HashMap, pub(crate) readable: HashMap, pub(crate) stopped: HashMap, @@ -57,6 +60,14 @@ impl ConnectionState { if let Some(waker) = self.on_connected.take() { waker.wake() } + self.datagram_received.drain(..).for_each(Waker::wake); + self.datagrams_unblocked.drain(..).for_each(Waker::wake); + for e in &mut self.stream_opened { + e.drain(..).for_each(Waker::wake); + } + for e in &mut self.stream_available { + e.drain(..).for_each(Waker::wake); + } wake_all_streams(&mut self.writable); wake_all_streams(&mut self.readable); wake_all_streams(&mut self.stopped); @@ -101,10 +112,6 @@ pub(crate) struct ConnectionInner { socket: Socket, events_tx: Sender<(ConnectionHandle, EndpointEvent)>, events_rx: Receiver, - datagram_received: Event, - datagrams_unblocked: Event, - stream_opened: [Event; 2], - stream_available: [Event; 2], } impl ConnectionInner { @@ -124,6 +131,10 @@ impl ConnectionInner { poll_waker: None, on_connected: None, on_handshake_data: None, + datagram_received: VecDeque::new(), + datagrams_unblocked: VecDeque::new(), + stream_opened: [VecDeque::new(), VecDeque::new()], + stream_available: [VecDeque::new(), VecDeque::new()], writable: HashMap::new(), readable: HashMap::new(), stopped: HashMap::new(), @@ -132,10 +143,6 @@ impl ConnectionInner { socket, events_tx, events_rx, - datagram_received: Event::new(), - datagrams_unblocked: Event::new(), - stream_opened: [Event::new(), Event::new()], - stream_available: [Event::new(), Event::new()], } } @@ -154,23 +161,11 @@ impl ConnectionInner { } } - fn notify_events(&self) { - self.datagram_received.notify(usize::MAX.additional()); - self.datagrams_unblocked.notify(usize::MAX.additional()); - for e in &self.stream_opened { - e.notify(usize::MAX.additional()); - } - for e in &self.stream_available { - e.notify(usize::MAX.additional()); - } - } - fn close(&self, error_code: VarInt, reason: Bytes) { let mut state = self.state(); state.conn.close(Instant::now(), error_code, reason); state.terminate(ConnectionError::LocallyClosed); state.wake(); - self.notify_events(); } async fn run(&self) -> io::Result<()> { @@ -257,10 +252,7 @@ impl ConnectionInner { wake_all_streams(&mut state.stopped); } } - ConnectionLost { reason } => { - state.terminate(reason); - self.notify_events(); - } + ConnectionLost { reason } => state.terminate(reason), Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable), Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable), Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped), @@ -268,18 +260,14 @@ impl ConnectionInner { wake_stream(id, &mut state.stopped); wake_stream(id, &mut state.writable); } - Stream(StreamEvent::Available { dir }) => { - self.stream_available[dir as usize].notify(usize::MAX.additional()); - } - Stream(StreamEvent::Opened { dir }) => { - self.stream_opened[dir as usize].notify(usize::MAX.additional()); - } - DatagramReceived => { - self.datagram_received.notify(usize::MAX.additional()); - } - DatagramsUnblocked => { - self.datagrams_unblocked.notify(usize::MAX.additional()); - } + Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize] + .drain(..) + .for_each(Waker::wake), + Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize] + .drain(..) + .for_each(Waker::wake), + DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake), + DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake), } } @@ -629,28 +617,42 @@ impl Connection { self.0.try_state().err() } + fn poll_recv_datagram(&self, cx: &mut Context) -> Poll> { + let mut state = self.0.try_state()?; + if let Some(bytes) = state.conn.datagrams().recv() { + return Poll::Ready(Ok(bytes)); + } + state.datagram_received.push_back(cx.waker().clone()); + Poll::Pending + } + /// Receive an application datagram. pub async fn recv_datagram(&self) -> Result { - let bytes = wait_event!( - self.0.datagram_received, - if let Some(bytes) = self.0.try_state()?.conn.datagrams().recv() { - break bytes; - } - ); - Ok(bytes) + future::poll_fn(|cx| self.poll_recv_datagram(cx)).await } fn try_send_datagram( &self, + cx: Option<&mut Context>, data: Bytes, - drop: bool, ) -> Result<(), Result> { + use quinn_proto::SendDatagramError::*; let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?; state .conn .datagrams() - .send(data, drop) - .map_err(TryInto::try_into)?; + .send(data, cx.is_none()) + .map_err(|err| match err { + UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer), + Disabled => Ok(SendDatagramError::Disabled), + TooLarge => Ok(SendDatagramError::TooLarge), + Blocked(data) => { + state + .datagrams_unblocked + .push_back(cx.unwrap().waker().clone()); + Err(data) + } + })?; state.wake(); Ok(()) } @@ -661,7 +663,7 @@ impl Connection { /// delivered out of order, and `data` must both fit inside a single /// QUIC packet and be smaller than the maximum dictated by the peer. pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> { - self.try_send_datagram(data, true).map_err(Result::unwrap) + self.try_send_datagram(None, data).map_err(Result::unwrap) } /// Transmit `data` as an unreliable, unordered application datagram. @@ -675,38 +677,36 @@ impl Connection { /// [`send_datagram()`]: Connection::send_datagram pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> { let mut data = Some(data); - wait_event!( - self.0.datagrams_unblocked, - match self.try_send_datagram(data.take().unwrap(), false) { - Ok(res) => break Ok(res), - Err(Ok(e)) => break Err(e), - Err(Err(b)) => data.replace(b), - } + future::poll_fn( + |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) { + Ok(()) => Poll::Ready(Ok(())), + Err(Ok(e)) => Poll::Ready(Err(e)), + Err(Err(b)) => { + data.replace(b); + Poll::Pending + } + }, ) + .await } - fn try_open_stream(&self, dir: Dir) -> Result<(StreamId, bool), OpenStreamError> { + fn poll_open_stream( + &self, + cx: Option<&mut Context>, + dir: Dir, + ) -> Poll> { let mut state = self.0.try_state()?; - let stream = state - .conn - .streams() - .open(dir) - .ok_or(OpenStreamError::StreamsExhausted)?; - Ok(( - stream, - state.conn.side().is_client() && state.conn.is_handshaking(), - )) - } - - async fn open_stream(&self, dir: Dir) -> Result<(StreamId, bool), ConnectionError> { - wait_event!( - self.0.stream_available[dir as usize], - match self.try_open_stream(dir) { - Ok(res) => break Ok(res), - Err(OpenStreamError::StreamsExhausted) => {} - Err(OpenStreamError::ConnectionLost(e)) => break Err(e), + if let Some(stream) = state.conn.streams().open(dir) { + Poll::Ready(Ok(( + stream, + state.conn.side().is_client() && state.conn.is_handshaking(), + ))) + } else { + if let Some(cx) = cx { + state.stream_available[dir as usize].push_back(cx.waker().clone()); } - ) + Poll::Pending + } } /// Initiate a new outgoing unidirectional stream. @@ -715,8 +715,11 @@ impl Connection { /// won't be notified that a stream has been opened until the stream is /// actually used. pub fn open_uni(&self) -> Result { - let (stream, is_0rtt) = self.try_open_stream(Dir::Uni)?; - Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) + if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? { + Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) + } else { + Err(OpenStreamError::StreamsExhausted) + } } /// Initiate a new outgoing unidirectional stream. @@ -728,7 +731,8 @@ impl Connection { /// /// [`open_uni()`]: crate::Connection::open_uni pub async fn open_uni_wait(&self) -> Result { - let (stream, is_0rtt) = self.open_stream(Dir::Uni).await?; + let (stream, is_0rtt) = + future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?; Ok(SendStream::new(self.0.clone(), stream, is_0rtt)) } @@ -738,11 +742,14 @@ impl Connection { /// won't be notified that a stream has been opened until the stream is /// actually used. pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> { - let (stream, is_0rtt) = self.try_open_stream(Dir::Bi)?; - Ok(( - SendStream::new(self.0.clone(), stream, is_0rtt), - RecvStream::new(self.0.clone(), stream, is_0rtt), - )) + if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? { + Ok(( + SendStream::new(self.0.clone(), stream, is_0rtt), + RecvStream::new(self.0.clone(), stream, is_0rtt), + )) + } else { + Err(OpenStreamError::StreamsExhausted) + } } /// Initiate a new outgoing bidirectional stream. @@ -754,28 +761,32 @@ impl Connection { /// /// [`open_bi()`]: crate::Connection::open_bi pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> { - let (stream, is_0rtt) = self.open_stream(Dir::Bi).await?; + let (stream, is_0rtt) = + future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?; Ok(( SendStream::new(self.0.clone(), stream, is_0rtt), RecvStream::new(self.0.clone(), stream, is_0rtt), )) } - async fn accept_stream(&self, dir: Dir) -> Result<(StreamId, bool), ConnectionError> { - wait_event!(self.0.stream_opened[dir as usize], { - let mut state = self.0.state(); - if let Some(stream) = state.conn.streams().accept(dir) { - state.wake(); - break Ok((stream, state.conn.is_handshaking())); - } else if let Some(error) = &state.error { - break Err(error.clone()); - } - }) + fn poll_accept_stream( + &self, + cx: &mut Context, + dir: Dir, + ) -> Poll> { + let mut state = self.0.try_state()?; + if let Some(stream) = state.conn.streams().accept(dir) { + state.wake(); + Poll::Ready(Ok((stream, state.conn.is_handshaking()))) + } else { + state.stream_opened[dir as usize].push_back(cx.waker().clone()); + Poll::Pending + } } /// Accept the next incoming uni-directional stream pub async fn accept_uni(&self) -> Result { - let (stream, is_0rtt) = self.accept_stream(Dir::Uni).await?; + let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?; Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)) } @@ -791,7 +802,7 @@ impl Connection { /// [`SendStream`]: crate::SendStream /// [`RecvStream`]: crate::RecvStream pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { - let (stream, is_0rtt) = self.accept_stream(Dir::Bi).await?; + let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?; Ok(( SendStream::new(self.0.clone(), stream, is_0rtt), RecvStream::new(self.0.clone(), stream, is_0rtt), @@ -898,20 +909,6 @@ pub enum SendDatagramError { ConnectionLost(#[from] ConnectionError), } -impl TryFrom for SendDatagramError { - type Error = Bytes; - - fn try_from(value: quinn_proto::SendDatagramError) -> Result { - use quinn_proto::SendDatagramError::*; - match value { - UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer), - Disabled => Ok(SendDatagramError::Disabled), - TooLarge => Ok(SendDatagramError::TooLarge), - Blocked(data) => Err(data), - } - } -} - /// Errors that can arise when trying to open a stream #[derive(Debug, Error, Clone, Eq, PartialEq)] pub enum OpenStreamError { diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 3288d206..2721ffd4 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -5,7 +5,7 @@ use std::{ net::{SocketAddr, SocketAddrV6}, pin::pin, sync::{Arc, Mutex}, - task::Poll, + task::{Context, Poll, Waker}, time::Instant, }; @@ -14,7 +14,6 @@ use compio_buf::BufResult; use compio_log::{error, Instrument}; use compio_net::{ToSocketAddrsAsync, UdpSocket}; use compio_runtime::JoinHandle; -use event_listener::{Event, IntoNotification}; use flume::{unbounded, Receiver, Sender}; use futures_util::{ future::{self}, @@ -27,7 +26,7 @@ use quinn_proto::{ EndpointEvent, ServerConfig, Transmit, VarInt, }; -use crate::{wait_event, Connecting, ConnectionEvent, Incoming, RecvMeta, Socket}; +use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket}; #[derive(Debug)] struct EndpointState { @@ -37,6 +36,7 @@ struct EndpointState { close: Option<(VarInt, Bytes)>, exit_on_idle: bool, incoming: VecDeque, + incoming_wakers: VecDeque, } impl EndpointState { @@ -93,11 +93,16 @@ impl EndpointState { self.connections.is_empty() } - fn try_get_incoming(&mut self) -> Option> { + fn poll_incoming(&mut self, cx: &mut Context) -> Poll> { if self.close.is_none() { - Some(self.incoming.pop_front()) + if let Some(incoming) = self.incoming.pop_front() { + Poll::Ready(Some(incoming)) + } else { + self.incoming_wakers.push_back(cx.waker().clone()); + Poll::Pending + } } else { - None + Poll::Ready(None) } } @@ -127,7 +132,6 @@ pub(crate) struct EndpointInner { ipv6: bool, events: ChannelPair<(ConnectionHandle, EndpointEvent)>, done: AtomicWaker, - incoming: Event, } impl EndpointInner { @@ -153,12 +157,12 @@ impl EndpointInner { close: None, exit_on_idle: false, incoming: VecDeque::new(), + incoming_wakers: VecDeque::new(), }), socket, ipv6, events: unbounded(), done: AtomicWaker::new(), - incoming: Event::new(), }) } @@ -284,12 +288,13 @@ impl EndpointInner { }, } - let state = self.state.lock().unwrap(); + let mut state = self.state.lock().unwrap(); if state.exit_on_idle && state.is_idle() { break Ok(()); } if !state.incoming.is_empty() { - self.incoming.notify(state.incoming.len().additional()); + let n = state.incoming.len().min(state.incoming_wakers.len()); + state.incoming_wakers.drain(..n).for_each(Waker::wake); } } } @@ -385,13 +390,9 @@ impl Endpoint { /// intermediate `Connecting` future which can be used to e.g. send 0.5-RTT /// data. pub async fn wait_incoming(&self) -> Option { - let incoming = wait_event!( - self.inner.incoming, - if let Some(res) = self.inner.state.lock().unwrap().try_get_incoming()? { - break res; - } - ); - Some(Incoming::new(incoming, self.inner.clone())) + future::poll_fn(|cx| self.inner.state.lock().unwrap().poll_incoming(cx)) + .await + .map(|incoming| Incoming::new(incoming, self.inner.clone())) } /// Replace the server configuration, affecting new incoming connections @@ -434,7 +435,7 @@ impl Endpoint { for conn in state.connections.values() { let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone())); } - self.inner.incoming.notify(usize::MAX.additional()); + state.incoming_wakers.drain(..).for_each(Waker::wake); } // Modified from [`SharedFd::try_unwrap_inner`], see notes there. diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs index d73a3c65..259d1718 100644 --- a/compio-quic/src/lib.rs +++ b/compio-quic/src/lib.rs @@ -60,15 +60,3 @@ impl From for std::io::Error { Self::new(kind, x) } } - -macro_rules! wait_event { - ($event:expr, $break:expr) => { - loop { - $break; - event_listener::listener!($event => listener); - $break; - listener.await; - } - }; -} -pub(crate) use wait_event; diff --git a/compio-quic/tests/basic.rs b/compio-quic/tests/basic.rs index 415dae3b..4f0ed775 100644 --- a/compio-quic/tests/basic.rs +++ b/compio-quic/tests/basic.rs @@ -244,22 +244,22 @@ async fn two_datagram_readers() { async { endpoint.wait_incoming().await.unwrap().await.unwrap() }, ); - let ev = event_listener::Event::new(); + let (tx, rx) = flume::bounded::<()>(1); let (a, b, _) = join!( async { let x = conn1.recv_datagram().await.unwrap(); - ev.notify(1); + let _ = tx.try_send(()); x }, async { let x = conn1.recv_datagram().await.unwrap(); - ev.notify(1); + let _ = tx.try_send(()); x }, async { conn2.send_datagram(MSG1.into()).unwrap(); - ev.listen().await; + rx.recv_async().await.unwrap(); conn2.send_datagram_wait(MSG2.into()).await.unwrap(); } ); From e99bc8d23baa8c17fb9360fdfe124faa71652111 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Wed, 21 Aug 2024 02:11:45 +0800 Subject: [PATCH 18/26] feat(quic): http3 --- compio-quic/Cargo.toml | 15 +- compio-quic/examples/client.rs | 2 - compio-quic/examples/http3-client.rs | 84 +++++++ compio-quic/examples/http3-server.rs | 57 +++++ compio-quic/src/connection.rs | 315 ++++++++++++++++++++++++++- compio-quic/src/incoming.rs | 8 +- compio-quic/src/lib.rs | 17 +- compio-quic/src/recv_stream.rs | 51 ++++- compio-quic/src/send_stream.rs | 118 +++++++++- 9 files changed, 650 insertions(+), 17 deletions(-) create mode 100644 compio-quic/examples/http3-client.rs create mode 100644 compio-quic/examples/http3-server.rs diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 08abac05..1457f306 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -27,6 +27,7 @@ rustls = { workspace = true, features = ["ring"] } rustls-platform-verifier = { version = "0.3.3", optional = true } rustls-native-certs = { version = "0.7.1", optional = true } webpki-roots = { version = "0.26.3", optional = true } +h3 = { version = "0.0.6", optional = true } # Utils bytes = "1.7.1" @@ -42,8 +43,10 @@ windows-sys = { workspace = true, features = ["Win32_Networking_WinSock"] } libc = { workspace = true } [dev-dependencies] +compio-buf = { workspace = true, features = ["bytes"] } compio-dispatcher = { workspace = true } compio-driver = { workspace = true } +compio-fs = { workspace = true } compio-macros = { workspace = true } compio-runtime = { workspace = true, features = ["criterion"] } @@ -55,13 +58,23 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } criterion = { workspace = true, features = ["async_tokio"] } quinn = "0.11.3" tokio = { workspace = true, features = ["rt", "macros"] } +http = "1.1.0" [features] -default = [] +default = ["h3"] io-compat = ["futures-util/io"] platform-verifier = ["dep:rustls-platform-verifier"] native-certs = ["dep:rustls-native-certs"] webpki-roots = ["dep:webpki-roots"] +h3 = ["dep:h3"] + +[[example]] +name = "http3-client" +required-features = ["h3"] + +[[example]] +name = "http3-server" +required-features = ["h3"] [[bench]] name = "quic" diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/client.rs index 167bbab1..85fb23be 100644 --- a/compio-quic/examples/client.rs +++ b/compio-quic/examples/client.rs @@ -23,8 +23,6 @@ async fn main() { None, ) .unwrap() - .into_0rtt() - .unwrap_err() .await .unwrap(); diff --git a/compio-quic/examples/http3-client.rs b/compio-quic/examples/http3-client.rs new file mode 100644 index 00000000..1ddc6936 --- /dev/null +++ b/compio-quic/examples/http3-client.rs @@ -0,0 +1,84 @@ +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + path::PathBuf, + str::FromStr, +}; + +use bytes::Buf; +use compio_io::AsyncWriteAtExt; +use compio_net::ToSocketAddrsAsync; +use compio_quic::ClientBuilder; +use http::{Request, Uri}; +use tracing_subscriber::EnvFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + let args = std::env::args().collect::>(); + if args.len() != 3 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + + let uri = Uri::from_str(&args[1]).unwrap(); + let outpath = PathBuf::from(&args[2]); + + let host = uri.host().unwrap(); + let remote = (host, uri.port_u16().unwrap_or(443)) + .to_socket_addrs_async() + .await + .unwrap() + .next() + .unwrap(); + + let endpoint = ClientBuilder::new_with_no_server_verification() + .with_key_log() + .with_alpn_protocols(&["h3"]) + .bind(SocketAddr::new( + if remote.is_ipv6() { + IpAddr::V6(Ipv6Addr::UNSPECIFIED) + } else { + IpAddr::V4(Ipv4Addr::UNSPECIFIED) + }, + 0, + )) + .await + .unwrap(); + + { + println!("Connecting to {} at {}", host, remote); + let conn = endpoint.connect(remote, host, None).unwrap().await.unwrap(); + + let (mut conn, mut send_req) = compio_quic::h3::client::new(conn).await.unwrap(); + let handle = compio_runtime::spawn(async move { conn.wait_idle().await }); + + let req = Request::get(uri).body(()).unwrap(); + let mut stream = send_req.send_request(req).await.unwrap(); + stream.finish().await.unwrap(); + + let resp = stream.recv_response().await.unwrap(); + println!("{:?}", resp); + + let mut out = compio_fs::File::create(outpath).await.unwrap(); + let mut pos = 0; + while let Some(mut chunk) = stream.recv_data().await.unwrap() { + let len = chunk.remaining(); + out.write_all_at(chunk.copy_to_bytes(len), pos) + .await + .unwrap(); + pos += len as u64; + } + if let Some(headers) = stream.recv_trailers().await.unwrap() { + println!("{:?}", headers); + } + + drop(send_req); + + handle.await.unwrap().unwrap(); + } + + endpoint.shutdown().await.unwrap(); +} diff --git a/compio-quic/examples/http3-server.rs b/compio-quic/examples/http3-server.rs new file mode 100644 index 00000000..96450910 --- /dev/null +++ b/compio-quic/examples/http3-server.rs @@ -0,0 +1,57 @@ +use bytes::Bytes; +use compio_quic::ServerBuilder; +use http::{HeaderMap, Response}; +use tracing_subscriber::EnvFilter; + +#[compio_macros::main] +async fn main() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + + let rcgen::CertifiedKey { cert, key_pair } = + rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let cert = cert.der().clone(); + let key_der = key_pair.serialize_der().try_into().unwrap(); + + let endpoint = ServerBuilder::new_with_single_cert(vec![cert], key_der) + .unwrap() + .with_key_log() + .with_alpn_protocols(&["h3"]) + .bind("[::1]:4433") + .await + .unwrap(); + + while let Some(incoming) = endpoint.wait_incoming().await { + compio_runtime::spawn(async move { + let conn = incoming.await.unwrap(); + println!("Accepted connection from {}", conn.remote_address()); + + let mut conn = compio_quic::h3::server::builder() + .build::<_, Bytes>(conn) + .await + .unwrap(); + + while let Ok(Some((req, mut stream))) = conn.accept().await { + println!("Received request: {:?}", req); + stream + .send_response( + Response::builder() + .header("server", "compio-quic") + .body(()) + .unwrap(), + ) + .await + .unwrap(); + stream + .send_data("hello from compio-quic".into()) + .await + .unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("msg", "byebye".parse().unwrap()); + stream.send_trailers(headers).await.unwrap(); + } + }) + .detach(); + } +} diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index 44d7ce55..3b2b7ba8 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -18,8 +18,8 @@ use futures_util::{ select, stream, Future, FutureExt, StreamExt, }; use quinn_proto::{ - congestion::Controller, crypto::rustls::HandshakeData, ConnectionError, ConnectionHandle, - ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt, + congestion::Controller, crypto::rustls::HandshakeData, ConnectionHandle, ConnectionStats, Dir, + EndpointEvent, StreamEvent, StreamId, VarInt, }; use thiserror::Error; @@ -252,7 +252,7 @@ impl ConnectionInner { wake_all_streams(&mut state.stopped); } } - ConnectionLost { reason } => state.terminate(reason), + ConnectionLost { reason } => state.terminate(reason.into()), Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable), Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable), Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped), @@ -889,6 +889,63 @@ impl FusedFuture for Timer { } } +/// Reasons why a connection might be lost +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum ConnectionError { + /// The peer doesn't implement any supported version + #[error("peer doesn't implement any supported version")] + VersionMismatch, + /// The peer violated the QUIC specification as understood by this + /// implementation + #[error(transparent)] + TransportError(#[from] quinn_proto::TransportError), + /// The peer's QUIC stack aborted the connection automatically + #[error("aborted by peer: {0}")] + ConnectionClosed(quinn_proto::ConnectionClose), + /// The peer closed the connection + #[error("closed by peer: {0}")] + ApplicationClosed(quinn_proto::ApplicationClose), + /// The peer is unable to continue processing this connection, usually due + /// to having restarted + #[error("reset by peer")] + Reset, + /// Communication with the peer has lapsed for longer than the negotiated + /// idle timeout + /// + /// If neither side is sending keep-alives, a connection will time out after + /// a long enough idle period even if the peer is still reachable. See + /// also [`TransportConfig::max_idle_timeout()`] + /// and [`TransportConfig::keep_alive_interval()`]. + #[error("timed out")] + TimedOut, + /// The local application closed the connection + #[error("closed")] + LocallyClosed, + /// The connection could not be created because not enough of the CID space + /// is available + /// + /// Try using longer connection IDs. + #[error("CIDs exhausted")] + CidsExhausted, +} + +impl From for ConnectionError { + fn from(value: quinn_proto::ConnectionError) -> Self { + use quinn_proto::ConnectionError::*; + + match value { + VersionMismatch => ConnectionError::VersionMismatch, + TransportError(e) => ConnectionError::TransportError(e), + ConnectionClosed(e) => ConnectionError::ConnectionClosed(e), + ApplicationClosed(e) => ConnectionError::ApplicationClosed(e), + Reset => ConnectionError::Reset, + TimedOut => ConnectionError::TimedOut, + LocallyClosed => ConnectionError::LocallyClosed, + CidsExhausted => ConnectionError::CidsExhausted, + } + } +} + /// Errors that can arise when sending a datagram #[derive(Debug, Error, Clone, Eq, PartialEq)] pub enum SendDatagramError { @@ -919,3 +976,255 @@ pub enum OpenStreamError { #[error("streams exhausted")] StreamsExhausted, } + +#[cfg(feature = "h3")] +pub(crate) mod h3_impl { + use bytes::{Buf, BytesMut}; + use futures_util::ready; + use h3::{ + error::Code, + ext::Datagram, + quic::{self, Error, RecvDatagramExt, SendDatagramExt, WriteBuf}, + }; + + use super::*; + use crate::{send_stream::h3_impl::SendStream, ReadError, WriteError}; + + impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self, ConnectionError::TimedOut) + } + + fn err_code(&self) -> Option { + match &self { + ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { + error_code, + .. + }) => Some(error_code.into_inner()), + _ => None, + } + } + } + + impl Error for SendDatagramError { + fn is_timeout(&self) -> bool { + false + } + + fn err_code(&self) -> Option { + match self { + SendDatagramError::ConnectionLost(ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) => Some(error_code.into_inner()), + _ => None, + } + } + } + + impl SendDatagramExt for Connection + where + B: Buf, + { + type Error = SendDatagramError; + + fn send_datagram(&mut self, data: Datagram) -> Result<(), Self::Error> { + let mut buf = BytesMut::new(); + data.encode(&mut buf); + Connection::send_datagram(self, buf.freeze()) + } + } + + impl RecvDatagramExt for Connection { + type Buf = Bytes; + type Error = ConnectionError; + + fn poll_accept_datagram( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(Some(ready!(self.poll_recv_datagram(cx))?))) + } + } + + /// Bidirectional stream. + pub struct BidiStream { + send: SendStream, + recv: RecvStream, + } + + impl BidiStream { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + send: SendStream::new(conn.clone(), stream, is_0rtt), + recv: RecvStream::new(conn, stream, is_0rtt), + } + } + } + + impl quic::BidiStream for BidiStream + where + B: Buf, + { + type RecvStream = RecvStream; + type SendStream = SendStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + (self.send, self.recv) + } + } + + impl quic::RecvStream for BidiStream + where + B: Buf, + { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.recv.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.recv.recv_id() + } + } + + impl quic::SendStream for BidiStream + where + B: Buf, + { + type Error = WriteError; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.send.poll_ready(cx) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.send.send_id() + } + } + + impl quic::SendStreamUnframed for BidiStream + where + B: Buf, + { + fn poll_send( + &mut self, + cx: &mut Context<'_>, + buf: &mut D, + ) -> Poll> { + self.send.poll_send(cx, buf) + } + } + + /// Stream opener. + #[derive(Clone)] + pub struct OpenStreams(Connection); + + impl quic::OpenStreams for OpenStreams + where + B: Buf, + { + type BidiStream = BidiStream; + type OpenError = ConnectionError; + type SendStream = SendStream; + + fn poll_open_bidi( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?; + Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt))) + } + + fn poll_open_send( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?; + Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt))) + } + + fn close(&mut self, code: Code, reason: &[u8]) { + self.0 + .close(code.value().try_into().expect("invalid code"), reason) + } + } + + impl quic::OpenStreams for Connection + where + B: Buf, + { + type BidiStream = BidiStream; + type OpenError = ConnectionError; + type SendStream = SendStream; + + fn poll_open_bidi( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?; + Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt))) + } + + fn poll_open_send( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?; + Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt))) + } + + fn close(&mut self, code: Code, reason: &[u8]) { + Connection::close(self, code.value().try_into().expect("invalid code"), reason) + } + } + + impl quic::Connection for Connection + where + B: Buf, + { + type AcceptError = ConnectionError; + type OpenStreams = OpenStreams; + type RecvStream = RecvStream; + + fn poll_accept_recv( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::AcceptError>> { + let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?; + Poll::Ready(Ok(Some(RecvStream::new(self.0.clone(), stream, is_0rtt)))) + } + + fn poll_accept_bidi( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::AcceptError>> { + let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?; + Poll::Ready(Ok(Some(BidiStream::new(self.0.clone(), stream, is_0rtt)))) + } + + fn opener(&self) -> Self::OpenStreams { + OpenStreams(self.clone()) + } + } +} diff --git a/compio-quic/src/incoming.rs b/compio-quic/src/incoming.rs index 0af41b62..d0eec213 100644 --- a/compio-quic/src/incoming.rs +++ b/compio-quic/src/incoming.rs @@ -7,10 +7,10 @@ use std::{ }; use futures_util::FutureExt; -use quinn_proto::{ConnectionError, ServerConfig}; +use quinn_proto::ServerConfig; use thiserror::Error; -use crate::{Connecting, Connection, EndpointInner}; +use crate::{Connecting, Connection, ConnectionError, EndpointInner}; #[derive(Debug)] pub(crate) struct IncomingInner { @@ -32,7 +32,7 @@ impl Incoming { /// occur). pub fn accept(mut self) -> Result { let inner = self.0.take().unwrap(); - inner.endpoint.accept(inner.incoming, None) + Ok(inner.endpoint.accept(inner.incoming, None)?) } /// Accept this incoming connection using a custom configuration. @@ -45,7 +45,7 @@ impl Incoming { server_config: ServerConfig, ) -> Result { let inner = self.0.take().unwrap(); - inner.endpoint.accept(inner.incoming, Some(server_config)) + Ok(inner.endpoint.accept(inner.incoming, Some(server_config))?) } /// Reject this incoming connection attempt. diff --git a/compio-quic/src/lib.rs b/compio-quic/src/lib.rs index 259d1718..3e3e36ff 100644 --- a/compio-quic/src/lib.rs +++ b/compio-quic/src/lib.rs @@ -9,8 +9,8 @@ pub use quinn_proto::{ congestion, crypto, AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, - ConfigError, ConnectError, ConnectionClose, ConnectionError, ConnectionStats, EndpointConfig, - IdleTimeout, MtuDiscoveryConfig, ServerConfig, StreamId, Transmit, TransportConfig, VarInt, + ConfigError, ConnectError, ConnectionClose, ConnectionStats, EndpointConfig, IdleTimeout, + MtuDiscoveryConfig, ServerConfig, StreamId, Transmit, TransportConfig, VarInt, }; mod builder; @@ -22,7 +22,7 @@ mod send_stream; mod socket; pub use builder::{ClientBuilder, ServerBuilder}; -pub use connection::{Connecting, Connection}; +pub use connection::{Connecting, Connection, ConnectionError}; pub use endpoint::Endpoint; pub use incoming::{Incoming, IncomingFuture}; pub use recv_stream::{ReadError, ReadExactError, RecvStream}; @@ -60,3 +60,14 @@ impl From for std::io::Error { Self::new(kind, x) } } + +/// HTTP/3 support via [`h3`]. +#[cfg(feature = "h3")] +pub mod h3 { + pub use h3::*; + + pub use crate::{ + connection::h3_impl::{BidiStream, OpenStreams}, + send_stream::h3_impl::SendStream, + }; +} diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs index 723c5075..caa664f1 100644 --- a/compio-quic/src/recv_stream.rs +++ b/compio-quic/src/recv_stream.rs @@ -9,10 +9,10 @@ use bytes::{BufMut, Bytes}; use compio_buf::{BufResult, IoBufMut}; use compio_io::AsyncRead; use futures_util::{future::poll_fn, ready}; -use quinn_proto::{Chunk, Chunks, ClosedStream, ConnectionError, ReadableError, StreamId, VarInt}; +use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt}; use thiserror::Error; -use crate::{ConnectionInner, StoppedError}; +use crate::{ConnectionError, ConnectionInner, StoppedError}; /// A stream that can only be used to receive data /// @@ -526,3 +526,50 @@ impl futures_util::AsyncRead for RecvStream { .map_err(Into::into) } } + +#[cfg(feature = "h3")] +pub(crate) mod h3_impl { + use h3::quic::{self, Error}; + + use super::*; + + impl Error for ReadError { + fn is_timeout(&self) -> bool { + matches!(self, Self::ConnectionLost(ConnectionError::TimedOut)) + } + + fn err_code(&self) -> Option { + match self { + Self::ConnectionLost(ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) + | Self::Reset(error_code) => Some(error_code.into_inner()), + _ => None, + } + } + } + + impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.execute_poll_read(cx, true, |chunks| match chunks.next(usize::MAX) { + Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes), + res => (None, res.err()).into(), + }) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stop(error_code.try_into().expect("invalid error_code")) + .ok(); + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.0.try_into().unwrap() + } + } +} diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs index 7801e726..78349ead 100644 --- a/compio-quic/src/send_stream.rs +++ b/compio-quic/src/send_stream.rs @@ -8,10 +8,10 @@ use bytes::Bytes; use compio_buf::{BufResult, IoBuf}; use compio_io::AsyncWrite; use futures_util::{future::poll_fn, ready}; -use quinn_proto::{ClosedStream, ConnectionError, FinishError, StreamId, VarInt, Written}; +use quinn_proto::{ClosedStream, FinishError, StreamId, VarInt, Written}; use thiserror::Error; -use crate::{ConnectionInner, StoppedError}; +use crate::{ConnectionError, ConnectionInner, StoppedError}; /// A stream that can only be used to send data. /// @@ -290,6 +290,11 @@ pub enum WriteError { /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt() #[error("0-RTT rejected")] ZeroRttRejected, + /// Error when the stream is not ready, because it is still sending + /// data from a previous call + #[cfg(feature = "h3")] + #[error("stream not ready")] + NotReady, } impl TryFrom for WriteError { @@ -320,6 +325,8 @@ impl From for io::Error { let kind = match x { Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset, ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected, + #[cfg(feature = "h3")] + NotReady => io::ErrorKind::Other, }; Self::new(kind, x) } @@ -362,3 +369,110 @@ impl futures_util::AsyncWrite for SendStream { Poll::Ready(Ok(())) } } + +#[cfg(feature = "h3")] +pub(crate) mod h3_impl { + use bytes::Buf; + use h3::quic::{self, Error, WriteBuf}; + + use super::*; + + impl Error for WriteError { + fn is_timeout(&self) -> bool { + matches!(self, Self::ConnectionLost(ConnectionError::TimedOut)) + } + + fn err_code(&self) -> Option { + match self { + Self::ConnectionLost(ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) + | Self::Stopped(error_code) => Some(error_code.into_inner()), + _ => None, + } + } + } + + /// A wrapper around `SendStream` that implements `quic::SendStream` and + /// `quic::SendStreamUnframed`. + pub struct SendStream { + inner: super::SendStream, + buf: Option>, + } + + impl SendStream { + pub(crate) fn new(conn: Arc, stream: StreamId, is_0rtt: bool) -> Self { + Self { + inner: super::SendStream::new(conn, stream, is_0rtt), + buf: None, + } + } + } + + impl quic::SendStream for SendStream + where + B: Buf, + { + type Error = WriteError; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(data) = &mut self.buf { + while data.has_remaining() { + let n = ready!( + self.inner + .execute_poll_write(cx, |mut stream| stream.write(data.chunk())) + )?; + data.advance(n); + } + } + self.buf = None; + Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + if self.buf.is_some() { + return Err(WriteError::NotReady); + } + self.buf = Some(data.into()); + Ok(()) + } + + fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.inner.finish().map_err(|_| WriteError::ClosedStream)) + } + + fn reset(&mut self, reset_code: u64) { + self.inner + .reset(reset_code.try_into().unwrap_or(VarInt::MAX)) + .ok(); + } + + fn send_id(&self) -> quic::StreamId { + self.inner.stream.0.try_into().unwrap() + } + } + + impl quic::SendStreamUnframed for SendStream + where + B: Buf, + { + fn poll_send( + &mut self, + cx: &mut Context<'_>, + buf: &mut D, + ) -> Poll> { + // This signifies a bug in implementation + debug_assert!( + self.buf.is_some(), + "poll_send called while send stream is not ready" + ); + + let n = ready!( + self.inner + .execute_poll_write(cx, |mut stream| stream.write(buf.chunk())) + )?; + buf.advance(n); + Poll::Ready(Ok(n)) + } + } +} From 56bcf29674afaad18ca905c86664b98e3683c09a Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 24 Aug 2024 14:52:10 +0800 Subject: [PATCH 19/26] fix(quic): apply suggestions from code review --- compio-driver/src/iocp/op.rs | 22 +++++++++++++++---- compio-quic/Cargo.toml | 2 +- .../examples/{client.rs => quic-client.rs} | 0 .../{dispatcher.rs => quic-dispatcher.rs} | 0 .../examples/{server.rs => quic-server.rs} | 0 compio/Cargo.toml | 3 ++- 6 files changed, 21 insertions(+), 6 deletions(-) rename compio-quic/examples/{client.rs => quic-client.rs} (100%) rename compio-quic/examples/{dispatcher.rs => quic-dispatcher.rs} (100%) rename compio-quic/examples/{server.rs => quic-server.rs} (100%) diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index e09b555e..0da6d7bd 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -1,7 +1,13 @@ #[cfg(feature = "once_cell_try")] use std::sync::OnceLock; use std::{ - io, marker::PhantomPinned, net::Shutdown, os::windows::io::AsRawSocket, pin::Pin, ptr::{null, null_mut}, task::Poll + io, + marker::PhantomPinned, + net::Shutdown, + os::windows::io::AsRawSocket, + pin::Pin, + ptr::{null, null_mut}, + task::Poll, }; use aligned_array::{Aligned, A8}; @@ -829,13 +835,14 @@ impl OpCode for RecvMsg { })?; let this = self.get_unchecked_mut(); - + this.slices = this.buffer.io_slices_mut(); this.msg.name = &mut this.addr as *mut _ as _; this.msg.namelen = std::mem::size_of::() as _; this.msg.lpBuffers = this.slices.as_mut_ptr() as _; this.msg.dwBufferCount = this.slices.len() as _; - this.msg.Control = std::mem::transmute::(this.control.as_io_slice_mut()); + this.msg.Control = + std::mem::transmute::(this.control.as_io_slice_mut()); let mut received = 0; let res = recvmsg_fn( @@ -908,7 +915,14 @@ impl OpCode for SendMsg { this.msg.Control = std::mem::transmute::(this.control.as_io_slice()); let mut sent = 0; - let res = WSASendMsg(this.fd.as_raw_fd() as _, &this.msg, 0, &mut sent, optr, None); + let res = WSASendMsg( + this.fd.as_raw_fd() as _, + &this.msg, + 0, + &mut sent, + optr, + None, + ); winsock_result(res, sent) } diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 1457f306..2d9f2c32 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -61,7 +61,7 @@ tokio = { workspace = true, features = ["rt", "macros"] } http = "1.1.0" [features] -default = ["h3"] +default = [] io-compat = ["futures-util/io"] platform-verifier = ["dep:rustls-platform-verifier"] native-certs = ["dep:rustls-native-certs"] diff --git a/compio-quic/examples/client.rs b/compio-quic/examples/quic-client.rs similarity index 100% rename from compio-quic/examples/client.rs rename to compio-quic/examples/quic-client.rs diff --git a/compio-quic/examples/dispatcher.rs b/compio-quic/examples/quic-dispatcher.rs similarity index 100% rename from compio-quic/examples/dispatcher.rs rename to compio-quic/examples/quic-dispatcher.rs diff --git a/compio-quic/examples/server.rs b/compio-quic/examples/quic-server.rs similarity index 100% rename from compio-quic/examples/server.rs rename to compio-quic/examples/quic-server.rs diff --git a/compio/Cargo.toml b/compio/Cargo.toml index 8cbb3715..f6820dc3 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -84,7 +84,7 @@ io-uring = [ ] polling = ["compio-driver/polling"] io = ["dep:compio-io"] -io-compat = ["io", "compio-io/compat", "compio-quic/io-compat"] +io-compat = ["io", "compio-io/compat", "compio-quic?/io-compat"] runtime = ["dep:compio-runtime", "dep:compio-fs", "dep:compio-net", "io"] macros = ["dep:compio-macros", "runtime"] event = ["compio-runtime/event", "runtime"] @@ -96,6 +96,7 @@ native-tls = ["tls", "compio-tls/native-tls"] rustls = ["tls", "compio-tls/rustls"] process = ["dep:compio-process"] quic = ["dep:compio-quic"] +h3 = ["quic", "compio-quic/h3"] all = [ "time", "macros", From aa0ec20d182655aec4500bec9dbf2ff702c2044e Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 24 Aug 2024 17:33:52 +0800 Subject: [PATCH 20/26] fix(driver, iocp): remove unnecessary field --- compio-driver/src/iocp/op.rs | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/compio-driver/src/iocp/op.rs b/compio-driver/src/iocp/op.rs index 0da6d7bd..672316ae 100644 --- a/compio-driver/src/iocp/op.rs +++ b/compio-driver/src/iocp/op.rs @@ -786,7 +786,6 @@ pub struct RecvMsg { fd: SharedFd, buffer: T, control: C, - slices: Vec, _p: PhantomPinned, } @@ -807,7 +806,6 @@ impl RecvMsg { fd, buffer, control, - slices: vec![], _p: PhantomPinned, } } @@ -836,11 +834,11 @@ impl OpCode for RecvMsg { let this = self.get_unchecked_mut(); - this.slices = this.buffer.io_slices_mut(); + let mut slices = this.buffer.io_slices_mut(); this.msg.name = &mut this.addr as *mut _ as _; this.msg.namelen = std::mem::size_of::() as _; - this.msg.lpBuffers = this.slices.as_mut_ptr() as _; - this.msg.dwBufferCount = this.slices.len() as _; + this.msg.lpBuffers = slices.as_mut_ptr() as _; + this.msg.dwBufferCount = slices.len() as _; this.msg.Control = std::mem::transmute::(this.control.as_io_slice_mut()); @@ -863,12 +861,10 @@ impl OpCode for RecvMsg { /// Send data to specified address accompanied by ancillary data from vectored /// buffer. pub struct SendMsg { - msg: WSAMSG, fd: SharedFd, buffer: T, control: C, addr: SockAddr, - pub(crate) slices: Vec, _p: PhantomPinned, } @@ -884,12 +880,10 @@ impl SendMsg { "misaligned control message buffer" ); Self { - msg: unsafe { std::mem::zeroed() }, fd, buffer, control, addr, - slices: vec![], _p: PhantomPinned, } } @@ -907,22 +901,18 @@ impl OpCode for SendMsg { unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll> { let this = self.get_unchecked_mut(); - this.slices = this.buffer.io_slices(); - this.msg.name = this.addr.as_ptr() as _; - this.msg.namelen = this.addr.len(); - this.msg.lpBuffers = this.slices.as_ptr() as _; - this.msg.dwBufferCount = this.slices.len() as _; - this.msg.Control = std::mem::transmute::(this.control.as_io_slice()); + let slices = this.buffer.io_slices(); + let msg = WSAMSG { + name: this.addr.as_ptr() as _, + namelen: this.addr.len(), + lpBuffers: slices.as_ptr() as _, + dwBufferCount: slices.len() as _, + Control: std::mem::transmute::(this.control.as_io_slice()), + dwFlags: 0, + }; let mut sent = 0; - let res = WSASendMsg( - this.fd.as_raw_fd() as _, - &this.msg, - 0, - &mut sent, - optr, - None, - ); + let res = WSASendMsg(this.fd.as_raw_fd() as _, &msg, 0, &mut sent, optr, None); winsock_result(res, sent) } From 514b9b75c33411c12b7b0fda3d3926990fddd9fb Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 24 Aug 2024 20:43:10 +0800 Subject: [PATCH 21/26] chore(quic): rustls provider --- compio-quic/Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 2d9f2c32..bc5cca9b 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -23,7 +23,7 @@ compio-net = { workspace = true } compio-runtime = { workspace = true, features = ["time"] } quinn-proto = "0.11.3" -rustls = { workspace = true, features = ["ring"] } +rustls = { workspace = true } rustls-platform-verifier = { version = "0.3.3", optional = true } rustls-native-certs = { version = "0.7.1", optional = true } webpki-roots = { version = "0.26.3", optional = true } @@ -67,6 +67,7 @@ platform-verifier = ["dep:rustls-platform-verifier"] native-certs = ["dep:rustls-native-certs"] webpki-roots = ["dep:webpki-roots"] h3 = ["dep:h3"] +# FIXME: see https://github.com/quinn-rs/quinn/pull/1962 [[example]] name = "http3-client" From 65908336ce72980389d7fcb7d596cf93701093de Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Sat, 24 Aug 2024 21:17:56 +0800 Subject: [PATCH 22/26] bench(quic): various size --- compio-quic/benches/quic.rs | 54 +++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/compio-quic/benches/quic.rs b/compio-quic/benches/quic.rs index e318a069..63d12c88 100644 --- a/compio-quic/benches/quic.rs +++ b/compio-quic/benches/quic.rs @@ -5,7 +5,7 @@ use std::{ }; use bytes::Bytes; -use criterion::{criterion_group, criterion_main, Bencher, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Bencher, BenchmarkId, Criterion, Throughput}; use futures_util::{stream::FuturesUnordered, StreamExt}; use rand::{thread_rng, RngCore}; @@ -161,36 +161,32 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { }); } +const DATA_SIZES: &[usize] = &[1, 10, 1024, 1200, 1024 * 16, 1024 * 128]; +const STREAMS: &[usize] = &[1, 10, 100]; + fn echo(c: &mut Criterion) { let mut rng = thread_rng(); - let mut large_data = [0u8; 1024 * 1024]; - rng.fill_bytes(&mut large_data); - - let mut small_data = [0u8; 10]; - rng.fill_bytes(&mut small_data); - - let mut group = c.benchmark_group("echo-large-data-1-stream"); - group.throughput(Throughput::Bytes((large_data.len() * 2) as u64)); - - group.bench_function("compio-quic", |b| echo_compio_quic(b, &large_data, 1)); - group.bench_function("quinn", |b| echo_quinn(b, &large_data, 1)); - - group.finish(); - - let mut group = c.benchmark_group("echo-large-data-10-streams"); - group.throughput(Throughput::Bytes((large_data.len() * 10 * 2) as u64)); - - group.bench_function("compio-quic", |b| echo_compio_quic(b, &large_data, 10)); - group.bench_function("quinn", |b| echo_quinn(b, &large_data, 10)); - - group.finish(); - - let mut group = c.benchmark_group("echo-small-data-100-streams"); - group.throughput(Throughput::Bytes((small_data.len() * 10 * 2) as u64)); - - group.bench_function("compio-quic", |b| echo_compio_quic(b, &small_data, 100)); - group.bench_function("quinn", |b| echo_quinn(b, &small_data, 100)); - + let mut data = vec![0u8; *DATA_SIZES.last().unwrap()]; + rng.fill_bytes(&mut data); + + let mut group = c.benchmark_group("echo"); + for &size in DATA_SIZES { + let data = &data[..size]; + for &streams in STREAMS { + group.throughput(Throughput::Bytes((data.len() * streams * 2) as u64)); + + group.bench_with_input( + BenchmarkId::new("compio-quic", format!("{}-streams-{}-bytes", streams, size)), + &(), + |b, _| echo_compio_quic(b, data, streams), + ); + group.bench_with_input( + BenchmarkId::new("quinn", format!("{}-streams-{}-bytes", streams, size)), + &(), + |b, _| echo_quinn(b, data, streams), + ); + } + } group.finish(); } From 77f1e23eb2620cb4ce0a7dd58665adf48a1df227 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Mon, 26 Aug 2024 19:52:34 +0800 Subject: [PATCH 23/26] chore: extract common deps into workspace --- Cargo.toml | 3 +++ compio-buf/Cargo.toml | 2 +- compio-log/Cargo.toml | 2 +- compio-quic/Cargo.toml | 13 ++++++------- compio/Cargo.toml | 2 +- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ad7520b3..5e44ed7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ compio-tls = { path = "./compio-tls", version = "0.2.0", default-features = fals compio-process = { path = "./compio-process", version = "0.1.0" } compio-quic = { path = "./compio-quic", version = "0.1.0" } +bytes = "1.7.1" flume = "0.11.0" cfg-if = "1.0.0" criterion = "0.5.1" @@ -51,11 +52,13 @@ nix = "0.29.0" once_cell = "1.18.0" os_pipe = "1.1.4" paste = "1.0.14" +rand = "0.8.5" rustls = { version = "0.23.1", default-features = false } slab = "0.4.9" socket2 = "0.5.6" tempfile = "3.8.1" tokio = "1.33.0" +tracing-subscriber = "0.3.18" widestring = "1.0.2" windows-sys = "0.52.0" diff --git a/compio-buf/Cargo.toml b/compio-buf/Cargo.toml index 5644c243..16c34ed4 100644 --- a/compio-buf/Cargo.toml +++ b/compio-buf/Cargo.toml @@ -17,7 +17,7 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] bumpalo = { version = "3.14.0", optional = true } arrayvec = { version = "0.7.4", optional = true } -bytes = { version = "1.5.0", optional = true } +bytes = { workspace = true, optional = true } [target.'cfg(unix)'.dependencies] libc = { workspace = true } diff --git a/compio-log/Cargo.toml b/compio-log/Cargo.toml index bc07ac8b..eb1e26b8 100644 --- a/compio-log/Cargo.toml +++ b/compio-log/Cargo.toml @@ -13,7 +13,7 @@ repository = { workspace = true } tracing = { version = "0.1", default-features = false } [dev-dependencies] -tracing-subscriber = "0.3" +tracing-subscriber = { workspace = true } [features] enable_log = [] diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index bc5cca9b..1642eb68 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -30,7 +30,7 @@ webpki-roots = { version = "0.26.3", optional = true } h3 = { version = "0.0.6", optional = true } # Utils -bytes = "1.7.1" +bytes = { workspace = true } flume = { workspace = true } futures-util = { workspace = true } thiserror = "1.0.63" @@ -50,15 +50,14 @@ compio-fs = { workspace = true } compio-macros = { workspace = true } compio-runtime = { workspace = true, features = ["criterion"] } -rand = "0.8.5" -rcgen = "0.13.1" -socket2 = { workspace = true, features = ["all"] } -tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } - criterion = { workspace = true, features = ["async_tokio"] } +http = "1.1.0" quinn = "0.11.3" +rand = { workspace = true } +rcgen = "0.13.1" +socket2 = { workspace = true, features = ["all"] } tokio = { workspace = true, features = ["rt", "macros"] } -http = "1.1.0" +tracing-subscriber = { workspace = true, features = ["env-filter"] } [features] default = [] diff --git a/compio/Cargo.toml b/compio/Cargo.toml index f6820dc3..921e1692 100644 --- a/compio/Cargo.toml +++ b/compio/Cargo.toml @@ -53,7 +53,7 @@ compio-macros = { workspace = true } criterion = { workspace = true, features = ["async_tokio"] } futures-channel = { workspace = true } futures-util = { workspace = true } -rand = "0.8.5" +rand = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = [ "fs", From b1ed765f40301325a2db5507bb8bfa65bae6f82f Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Tue, 27 Aug 2024 21:23:37 +0800 Subject: [PATCH 24/26] chore(quic): simplify check_0rtt --- compio-quic/src/connection.rs | 8 ++------ compio-quic/src/recv_stream.rs | 12 +++++------- compio-quic/src/send_stream.rs | 16 ++++++---------- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index 3b2b7ba8..83a41cf0 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -86,12 +86,8 @@ impl ConnectionState { .map(|data| data.downcast::().unwrap()) } - pub(crate) fn check_0rtt(&self) -> Result<(), ()> { - if self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt() { - Ok(()) - } else { - Err(()) - } + pub(crate) fn check_0rtt(&self) -> bool { + self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt() } } diff --git a/compio-quic/src/recv_stream.rs b/compio-quic/src/recv_stream.rs index caa664f1..93cff0c2 100644 --- a/compio-quic/src/recv_stream.rs +++ b/compio-quic/src/recv_stream.rs @@ -91,7 +91,7 @@ impl RecvStream { /// `ClosedStream` errors. pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { let mut state = self.conn.state(); - if self.is_0rtt && state.check_0rtt().is_err() { + if self.is_0rtt && !state.check_0rtt() { return Ok(()); } state.conn.recv_stream(self.stream).stop(error_code)?; @@ -115,7 +115,7 @@ impl RecvStream { poll_fn(|cx| { let mut state = self.conn.state(); - if self.is_0rtt && state.check_0rtt().is_err() { + if self.is_0rtt && !state.check_0rtt() { return Poll::Ready(Err(StoppedError::ZeroRttRejected)); } if let Some(code) = self.reset { @@ -169,10 +169,8 @@ impl RecvStream { } let mut state = self.conn.state(); - if self.is_0rtt { - state - .check_0rtt() - .map_err(|()| ReadError::ZeroRttRejected)?; + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(ReadError::ZeroRttRejected)); } // If we stored an error during a previous call, return it now. This can happen @@ -395,7 +393,7 @@ impl Drop for RecvStream { // clean up any previously registered wakers state.readable.remove(&self.stream); - if state.error.is_some() || (self.is_0rtt && state.check_0rtt().is_err()) { + if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) { return; } if !self.all_data_read { diff --git a/compio-quic/src/send_stream.rs b/compio-quic/src/send_stream.rs index 78349ead..3e0a1a2a 100644 --- a/compio-quic/src/send_stream.rs +++ b/compio-quic/src/send_stream.rs @@ -94,7 +94,7 @@ impl SendStream { /// stream's state. pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> { let mut state = self.conn.state(); - if self.is_0rtt && state.check_0rtt().is_err() { + if self.is_0rtt && !state.check_0rtt() { return Ok(()); } state.conn.send_stream(self.stream).reset(error_code)?; @@ -139,10 +139,8 @@ impl SendStream { pub async fn stopped(&mut self) -> Result, StoppedError> { poll_fn(|cx| { let mut state = self.conn.state(); - if self.is_0rtt { - state - .check_0rtt() - .map_err(|()| StoppedError::ZeroRttRejected)?; + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(StoppedError::ZeroRttRejected)); } match state.conn.send_stream(self.stream).stopped() { Err(_) => Poll::Ready(Ok(None)), @@ -164,10 +162,8 @@ impl SendStream { F: FnOnce(quinn_proto::SendStream) -> Result, { let mut state = self.conn.try_state()?; - if self.is_0rtt { - state - .check_0rtt() - .map_err(|()| WriteError::ZeroRttRejected)?; + if self.is_0rtt && !state.check_0rtt() { + return Poll::Ready(Err(WriteError::ZeroRttRejected)); } match f(state.conn.send_stream(self.stream)) { Ok(r) => { @@ -252,7 +248,7 @@ impl Drop for SendStream { state.stopped.remove(&self.stream); state.writable.remove(&self.stream); - if state.error.is_some() || (self.is_0rtt && state.check_0rtt().is_err()) { + if state.error.is_some() || (self.is_0rtt && !state.check_0rtt()) { return; } match state.conn.send_stream(self.stream).finish() { From 148f3acd620d6f417636ab34b99bfa5f9ee595d4 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Tue, 27 Aug 2024 21:26:08 +0800 Subject: [PATCH 25/26] chore(net): add type constraint for get/set_socket_option --- compio-net/src/socket.rs | 18 ++++++++++++++---- compio-net/src/udp.rs | 9 +++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/compio-net/src/socket.rs b/compio-net/src/socket.rs index 03bb8b27..31f6d6be 100644 --- a/compio-net/src/socket.rs +++ b/compio-net/src/socket.rs @@ -324,7 +324,7 @@ impl Socket { } #[cfg(unix)] - pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { let mut value: MaybeUninit = MaybeUninit::uninit(); let mut len = size_of::() as libc::socklen_t; syscall!(libc::getsockopt( @@ -342,7 +342,7 @@ impl Socket { } #[cfg(windows)] - pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { let mut value: MaybeUninit = MaybeUninit::uninit(); let mut len = size_of::() as i32; syscall!( @@ -363,7 +363,12 @@ impl Socket { } #[cfg(unix)] - pub unsafe fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn set_socket_option( + &self, + level: i32, + name: i32, + value: &T, + ) -> io::Result<()> { syscall!(libc::setsockopt( self.socket.as_raw_fd(), level, @@ -375,7 +380,12 @@ impl Socket { } #[cfg(windows)] - pub unsafe fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn set_socket_option( + &self, + level: i32, + name: i32, + value: &T, + ) -> io::Result<()> { syscall!( SOCKET, windows_sys::Win32::Networking::WinSock::setsockopt( diff --git a/compio-net/src/udp.rs b/compio-net/src/udp.rs index 7a28025a..33f39c2d 100644 --- a/compio-net/src/udp.rs +++ b/compio-net/src/udp.rs @@ -321,7 +321,7 @@ impl UdpSocket { /// # Safety /// /// The caller must ensure `T` is the correct type for `level` and `name`. - pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { + pub unsafe fn get_socket_option(&self, level: i32, name: i32) -> io::Result { self.inner.get_socket_option(level, name) } @@ -330,7 +330,12 @@ impl UdpSocket { /// # Safety /// /// The caller must ensure `T` is the correct type for `level` and `name`. - pub unsafe fn set_socket_option(&self, level: i32, name: i32, value: &T) -> io::Result<()> { + pub unsafe fn set_socket_option( + &self, + level: i32, + name: i32, + value: &T, + ) -> io::Result<()> { self.inner.set_socket_option(level, name, value) } } From ee92ffdc0b32b7e93cf1ab2d38f9b788e2570e79 Mon Sep 17 00:00:00 2001 From: Asakura Mizu Date: Thu, 29 Aug 2024 03:29:58 +0800 Subject: [PATCH 26/26] perf(quic): batch recv on channel --- compio-quic/Cargo.toml | 1 + compio-quic/benches/quic.rs | 5 +- compio-quic/src/connection.rs | 101 ++++++++++++++++++---------------- compio-quic/src/endpoint.rs | 28 ++++++---- 4 files changed, 77 insertions(+), 58 deletions(-) diff --git a/compio-quic/Cargo.toml b/compio-quic/Cargo.toml index 1642eb68..00575023 100644 --- a/compio-quic/Cargo.toml +++ b/compio-quic/Cargo.toml @@ -33,6 +33,7 @@ h3 = { version = "0.0.6", optional = true } bytes = { workspace = true } flume = { workspace = true } futures-util = { workspace = true } +rustc-hash = "2.0.0" thiserror = "1.0.63" # Windows specific dependencies diff --git a/compio-quic/benches/quic.rs b/compio-quic/benches/quic.rs index 63d12c88..66694da5 100644 --- a/compio-quic/benches/quic.rs +++ b/compio-quic/benches/quic.rs @@ -127,13 +127,13 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { client.set_default_client_config(client_config); let addr = server.local_addr().unwrap(); - let (client_conn, server_conn) = futures_util::join!( + let (client_conn, server_conn) = tokio::join!( async move { client.connect(addr, "localhost").unwrap().await.unwrap() }, async move { server.accept().await.unwrap().await.unwrap() } ); let start = Instant::now(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { while let Ok((mut send, mut recv)) = server_conn.accept_bi().await { tokio::spawn(async move { echo_impl!(send, recv); @@ -157,6 +157,7 @@ fn echo_quinn(b: &mut Bencher, content: &[u8], streams: usize) { .collect::>(); while futures.next().await.is_some() {} } + handle.abort(); start.elapsed() }); } diff --git a/compio-quic/src/connection.rs b/compio-quic/src/connection.rs index 83a41cf0..181aac18 100644 --- a/compio-quic/src/connection.rs +++ b/compio-quic/src/connection.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, VecDeque}, + collections::VecDeque, io, net::{IpAddr, SocketAddr}, pin::{pin, Pin}, @@ -21,6 +21,7 @@ use quinn_proto::{ congestion::Controller, crypto::rustls::HandshakeData, ConnectionHandle, ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt, }; +use rustc_hash::FxHashMap as HashMap; use thiserror::Error; use crate::{RecvStream, SendStream, Socket}; @@ -37,7 +38,7 @@ pub(crate) struct ConnectionState { pub(crate) error: Option, connected: bool, worker: Option>, - poll_waker: Option, + poller: Option, on_connected: Option, on_handshake_data: Option, datagram_received: VecDeque, @@ -73,8 +74,14 @@ impl ConnectionState { wake_all_streams(&mut self.stopped); } + fn close(&mut self, error_code: VarInt, reason: Bytes) { + self.conn.close(Instant::now(), error_code, reason); + self.terminate(ConnectionError::LocallyClosed); + self.wake(); + } + pub(crate) fn wake(&mut self) { - if let Some(waker) = self.poll_waker.take() { + if let Some(waker) = self.poller.take() { waker.wake() } } @@ -110,6 +117,12 @@ pub(crate) struct ConnectionInner { events_rx: Receiver, } +fn implicit_close(this: &Arc) { + if Arc::strong_count(this) == 2 { + this.state().close(0u32.into(), Bytes::new()) + } +} + impl ConnectionInner { fn new( handle: ConnectionHandle, @@ -124,16 +137,16 @@ impl ConnectionInner { connected: false, error: None, worker: None, - poll_waker: None, + poller: None, on_connected: None, on_handshake_data: None, datagram_received: VecDeque::new(), datagrams_unblocked: VecDeque::new(), stream_opened: [VecDeque::new(), VecDeque::new()], stream_available: [VecDeque::new(), VecDeque::new()], - writable: HashMap::new(), - readable: HashMap::new(), - stopped: HashMap::new(), + writable: HashMap::default(), + readable: HashMap::default(), + stopped: HashMap::default(), }), handle, socket, @@ -157,25 +170,13 @@ impl ConnectionInner { } } - fn close(&self, error_code: VarInt, reason: Bytes) { - let mut state = self.state(); - state.conn.close(Instant::now(), error_code, reason); - state.terminate(ConnectionError::LocallyClosed); - state.wake(); - } - - async fn run(&self) -> io::Result<()> { - let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize)); - let mut transmit_fut = pin!(Fuse::terminated()); - - let mut timer = Timer::new(); - + async fn run(self: &Arc) -> io::Result<()> { let mut poller = stream::poll_fn(|cx| { let mut state = self.state(); - let ready = state.poll_waker.is_none(); - match &state.poll_waker { + let ready = state.poller.is_none(); + match &state.poller { Some(waker) if waker.will_wake(cx.waker()) => {} - _ => state.poll_waker = Some(cx.waker().clone()), + _ => state.poller = Some(cx.waker().clone()), }; if ready { Poll::Ready(Some(())) @@ -185,36 +186,46 @@ impl ConnectionInner { }) .fuse(); + let mut timer = Timer::new(); + let mut event_stream = self.events_rx.stream().ready_chunks(100); + let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize)); + let mut transmit_fut = pin!(Fuse::terminated()); + loop { - select! { - _ = poller.next() => {} + let mut state = select! { + _ = poller.select_next_some() => self.state(), _ = timer => { - self.state().conn.handle_timeout(Instant::now()); timer.reset(None); + let mut state = self.state(); + state.conn.handle_timeout(Instant::now()); + state } - ev = self.events_rx.recv_async() => match ev { - Ok(ConnectionEvent::Close(error_code, reason)) => self.close(error_code, reason), - Ok(ConnectionEvent::Proto(ev)) => self.state().conn.handle_event(ev), - Err(_) => unreachable!("endpoint dropped connection"), + events = event_stream.select_next_some() => { + let mut state = self.state(); + for event in events { + match event { + ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason), + ConnectionEvent::Proto(event) => state.conn.handle_event(event), + } + } + state }, BufResult::<(), Vec>(res, mut buf) = transmit_fut => match res { Ok(()) => { buf.clear(); send_buf = Some(buf); + self.state() }, Err(e) => break Err(e), }, - } - - let now = Instant::now(); - let mut state = self.state(); + }; if let Some(mut buf) = send_buf.take() { - if let Some(transmit) = - state - .conn - .poll_transmit(now, self.socket.max_gso_segments(), &mut buf) - { + if let Some(transmit) = state.conn.poll_transmit( + Instant::now(), + self.socket.max_gso_segments(), + &mut buf, + ) { transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse()) } else { send_buf = Some(buf); @@ -480,9 +491,7 @@ impl Future for Connecting { impl Drop for Connecting { fn drop(&mut self) { - if Arc::strong_count(&self.0) == 2 { - self.0.close(0u32.into(), Bytes::new()) - } + implicit_close(&self.0) } } @@ -593,7 +602,9 @@ impl Connection { /// [`Endpoint::shutdown()`]: crate::Endpoint::shutdown /// [`close()`]: Connection::close pub fn close(&self, error_code: VarInt, reason: &[u8]) { - self.0.close(error_code, Bytes::copy_from_slice(reason)); + self.0 + .state() + .close(error_code, Bytes::copy_from_slice(reason)); } /// Wait for the connection to be closed for any reason. @@ -838,9 +849,7 @@ impl Eq for Connection {} impl Drop for Connection { fn drop(&mut self) { - if Arc::strong_count(&self.0) == 2 { - self.close(0u32.into(), b"") - } + implicit_close(&self.0) } } diff --git a/compio-quic/src/endpoint.rs b/compio-quic/src/endpoint.rs index 2721ffd4..99d7400f 100644 --- a/compio-quic/src/endpoint.rs +++ b/compio-quic/src/endpoint.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, VecDeque}, + collections::VecDeque, io, mem::ManuallyDrop, net::{SocketAddr, SocketAddrV6}, @@ -19,12 +19,13 @@ use futures_util::{ future::{self}, select, task::AtomicWaker, - FutureExt, + FutureExt, StreamExt, }; use quinn_proto::{ ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig, EndpointEvent, ServerConfig, Transmit, VarInt, }; +use rustc_hash::FxHashMap as HashMap; use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket}; @@ -153,7 +154,7 @@ impl EndpointInner { None, ), worker: None, - connections: HashMap::new(), + connections: HashMap::default(), close: None, exit_on_idle: false, incoming: VecDeque::new(), @@ -254,6 +255,8 @@ impl EndpointInner { } async fn run(&self) -> io::Result<()> { + let respond_fn = |buf: Vec, transmit: Transmit| self.respond(buf, transmit); + let mut recv_fut = pin!( self.socket .recv(Vec::with_capacity( @@ -269,26 +272,31 @@ impl EndpointInner { .fuse() ); - let respond_fn = |buf: Vec, transmit: Transmit| self.respond(buf, transmit); + let mut event_stream = self.events.1.stream().ready_chunks(100); loop { - select! { + let mut state = select! { BufResult(res, recv_buf) = recv_fut => { + let mut state = self.state.lock().unwrap(); match res { - Ok(meta) => self.state.lock().unwrap().handle_data(meta, &recv_buf, respond_fn), + Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn), Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} #[cfg(windows)] Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {} Err(e) => break Err(e), } recv_fut.set(self.socket.recv(recv_buf).fuse()); + state }, - (ch, event) = self.events.1.recv_async().map(Result::unwrap) => { - self.state.lock().unwrap().handle_event(ch, event); + events = event_stream.select_next_some() => { + let mut state = self.state.lock().unwrap(); + for (ch, event) in events { + state.handle_event(ch, event); + } + state }, - } + }; - let mut state = self.state.lock().unwrap(); if state.exit_on_idle && state.is_idle() { break Ok(()); }