diff --git a/neqo-bin/src/udp.rs b/neqo-bin/src/udp.rs index 88a037d701..a1c82e2eb7 100644 --- a/neqo-bin/src/udp.rs +++ b/neqo-bin/src/udp.rs @@ -9,6 +9,7 @@ use std::{ io::{self, IoSliceMut}, + mem::MaybeUninit, net::{SocketAddr, ToSocketAddrs}, slice, }; @@ -18,6 +19,13 @@ use tokio::io::Interest; use neqo_common::{Datagram, IpTos}; +#[cfg(not(any(target_os = "macos", target_os = "ios")))] +// Chosen somewhat arbitrarily; might benefit from additional tuning. +pub(crate) const BATCH_SIZE: usize = 32; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub(crate) const BATCH_SIZE: usize = 1; + /// Socket receive buffer size. /// /// Allows reading multiple datagrams in a single [`Socket::recv`] call. @@ -26,7 +34,8 @@ const RECV_BUF_SIZE: usize = u16::MAX as usize; pub struct Socket { socket: tokio::net::UdpSocket, state: UdpSocketState, - recv_buf: Vec, + // TODO: Rename + recv_buf: [Vec; BATCH_SIZE], } impl Socket { @@ -37,7 +46,12 @@ impl Socket { Ok(Self { state: quinn_udp::UdpSocketState::new((&socket).into())?, socket: tokio::net::UdpSocket::from_std(socket)?, - recv_buf: vec![0; RECV_BUF_SIZE], + recv_buf: (0..BATCH_SIZE) + .into_iter() + .map(|_| vec![0; RECV_BUF_SIZE]) + .collect::>() + .try_into() + .expect("successful array instantiation"), }) } @@ -78,18 +92,25 @@ impl Socket { /// Receive a UDP datagram on the specified socket. pub fn recv(&mut self, local_address: &SocketAddr) -> Result, io::Error> { - let mut meta = RecvMeta::default(); - - match self.socket.try_io(Interest::READABLE, || { - self.state.recv( - (&self.socket).into(), - &mut [IoSliceMut::new(&mut self.recv_buf)], - slice::from_mut(&mut meta), - ) + let mut metas = [RecvMeta::default(); BATCH_SIZE]; + + // TODO: Safe? + let mut iovs = MaybeUninit::<[IoSliceMut<'_>; BATCH_SIZE]>::uninit(); + for (i, buf) in self.recv_buf.iter_mut().enumerate() { + unsafe { + iovs.as_mut_ptr() + .cast::() + .add(i) + .write(IoSliceMut::new(buf)) + }; + } + let mut iovs = unsafe { iovs.assume_init() }; + + let msgs = match self.socket.try_io(Interest::READABLE, || { + self.state + .recv((&self.socket).into(), &mut iovs, &mut metas) }) { - Ok(n) => { - assert_eq!(n, 1, "only passed one slice"); - } + Ok(n) => n, Err(ref err) if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted => @@ -101,28 +122,36 @@ impl Socket { } }; - if meta.len == 0 { - eprintln!("zero length datagram received?"); - return Ok(vec![]); - } - if meta.len == self.recv_buf.len() { - eprintln!( - "Might have received more than {} bytes", - self.recv_buf.len() - ); - } - - Ok(self.recv_buf[0..meta.len] - .chunks(meta.stride.min(self.recv_buf.len())) - .map(|d| { - Datagram::new( - meta.addr, - *local_address, - meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), - None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 - d, - ) + // TODO + // if meta.len == 0 { + // eprintln!("zero length datagram received?"); + // return Ok(vec![]); + // } + // if meta.len == self.recv_buf.len() { + // eprintln!( + // "Might have received more than {} bytes", + // self.recv_buf.len() + // ); + // } + + Ok(metas + .iter() + .zip(iovs.iter()) + .take(msgs) + .map(|(meta, buf)| { + buf[0..meta.len] + .chunks(meta.stride.min(buf.len())) + .map(|d| { + Datagram::new( + meta.addr, + *local_address, + meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), + None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 + d, + ) + }) }) + .flatten() .collect()) } }