From 69c0c918f9f5557756880fc2b78f8684b7a602cd Mon Sep 17 00:00:00 2001 From: Max Inden Date: Tue, 30 Jan 2024 19:51:17 +0100 Subject: [PATCH] Instantiate socket state once --- neqo-common/src/udp.rs | 12 +++----- neqo-server/Cargo.toml | 1 + neqo-server/src/main.rs | 65 ++++++++++++++++++++++++----------------- 3 files changed, 43 insertions(+), 35 deletions(-) diff --git a/neqo-common/src/udp.rs b/neqo-common/src/udp.rs index a3d45755a3..c680ef2f79 100644 --- a/neqo-common/src/udp.rs +++ b/neqo-common/src/udp.rs @@ -33,9 +33,7 @@ use crate::{Datagram, IpTos}; /// # Panics /// /// Panics if the datagram is too large to send. -pub fn tx(socket: impl AsFd, d: &Datagram) -> io::Result { - // TODO: Don't instantiate on each write. - let send_state = UdpSocketState::new((&socket).into()).unwrap(); +pub fn tx(socket: impl AsFd, state: &UdpSocketState, d: &Datagram) -> io::Result { let transmit = Transmit { destination: d.destination(), ecn: EcnCodepoint::from_bits(Into::::into(d.tos())), @@ -44,7 +42,7 @@ pub fn tx(socket: impl AsFd, d: &Datagram) -> io::Result { // TODO src_ip: None, }; - let n = send_state + let n = state .send((&socket).into(), slice::from_ref(&transmit)) .unwrap(); Ok(n) @@ -72,21 +70,19 @@ pub fn tx(socket: impl AsFd, d: &Datagram) -> io::Result { /// Panics if the datagram is too large to receive. pub fn rx( socket: impl AsFd, + state: &UdpSocketState, buf: &mut [u8], // TODO: Can these be return values instead of mutable inputs? tos: &mut u8, ttl: &mut u8, ) -> io::Result<(usize, SocketAddr)> { let mut meta = RecvMeta::default(); - // TODO: Don't instantiate on each read. - let recv_state = UdpSocketState::new((&socket).into()).unwrap(); - // TODO: needed? // #[cfg(test)] // // `UdpSocketState` switches to non-blocking mode, undo that for the tests. // socket.set_nonblocking(false).unwrap(); - match recv_state.recv( + match state.recv( (&socket).into(), &mut [IoSliceMut::new(buf)], slice::from_mut(&mut meta), diff --git a/neqo-server/Cargo.toml b/neqo-server/Cargo.toml index fc97218c82..40a8567adb 100644 --- a/neqo-server/Cargo.toml +++ b/neqo-server/Cargo.toml @@ -18,6 +18,7 @@ qlog = "0.11.0" regex = "1.9" structopt = "0.3" tokio = { version = "1", features = ["net", "time", "macros", "rt", "rt-multi-thread"] } +quinn-udp = { git = "https://github.com/quinn-rs/quinn/" } [features] deny-warnings = [] diff --git a/neqo-server/src/main.rs b/neqo-server/src/main.rs index e2f4d26bc8..b82aa74ce0 100644 --- a/neqo-server/src/main.rs +++ b/neqo-server/src/main.rs @@ -580,12 +580,13 @@ impl HttpServer for SimpleServer { fn read_dgram( socket: &mut tokio::net::UdpSocket, + state: &quinn_udp::UdpSocketState, local_address: &SocketAddr, ) -> Result, io::Error> { let mut buf = [0; u16::MAX as usize]; let mut tos = 0; let mut ttl = 0; - let (sz, remote_addr) = match udp::rx(socket, &mut buf[..], &mut tos, &mut ttl) { + let (sz, remote_addr) = match udp::rx(socket, state, &mut buf[..], &mut tos, &mut ttl) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(None), Err(err) => { eprintln!("UDP recv error: {err:?}"); @@ -593,6 +594,7 @@ fn read_dgram( } Ok(res) => res, }; + qdebug!("read {}, {:?}", sz, &buf[0..32]); if sz == buf.len() { eprintln!("Might have received more than {} bytes", buf.len()); @@ -616,7 +618,7 @@ struct ServersRunner { args: Args, server: Box, timeout: Option>>, - sockets: Vec<(SocketAddr, tokio::net::UdpSocket)>, + sockets: Vec<(SocketAddr, tokio::net::UdpSocket, quinn_udp::UdpSocketState)>, } impl ServersRunner { @@ -660,14 +662,13 @@ impl ServersRunner { print!("Server waiting for connection on: {local_addr:?}"); - // TODO: needed? - socket - .set_nonblocking(true) - .expect("set_nonblocking to succeed"); + let state = quinn_udp::UdpSocketState::new((&socket).into()).unwrap(); self.sockets.push(( host, - tokio::net::UdpSocket::from_std(socket).expect("conversion to Tokio socket to succeed"), + tokio::net::UdpSocket::from_std(socket) + .expect("conversion to Tokio socket to succeed"), + state, )); } @@ -708,27 +709,30 @@ impl ServersRunner { } /// Tries to find a socket, but then just falls back to sending from the first. - fn find_socket(&mut self, addr: SocketAddr) -> &mut tokio::net::UdpSocket { - let ((_host, first_socket), rest) = self.sockets.split_first_mut().unwrap(); + fn find_socket( + &mut self, + addr: SocketAddr, + ) -> (&mut tokio::net::UdpSocket, &mut quinn_udp::UdpSocketState) { + let ((_host, first_socket, first_state), rest) = self.sockets.split_first_mut().unwrap(); rest.iter_mut() - .map(|(_host, socket)| socket) - .find(|socket| { + .map(|(_host, socket, state)| (socket, state)) + .find(|(socket, _state)| { socket .local_addr() .ok() .map_or(false, |socket_addr| socket_addr == addr) }) - .unwrap_or(first_socket) + .unwrap_or((first_socket, first_state)) } - async fn process(&mut self, mut dgram: Option<&Datagram>) { + async fn process(&mut self, mut dgram: Option<&Datagram>) -> Result<(), io::Error> { + qdebug!("process with {:?}", dgram); loop { match self.server.process(dgram.take(), self.args.now()) { Output::Datagram(dgram) => { - let socket = self.find_socket(dgram.source()); - if let Err(e) = udp::tx(socket, &dgram) { - eprintln!("UDP write error: {}", e); - } + qdebug!("writing to {:?}", dgram.source()); + let (socket, state) = self.find_socket(dgram.source()); + udp::tx(socket, state, &dgram)?; } Output::Callback(new_timeout) => { qinfo!("Setting timeout of {:?}", new_timeout); @@ -741,6 +745,7 @@ impl ServersRunner { } } } + Ok(()) } // Wait for any of the sockets to be readable or the timeout to fire. @@ -748,7 +753,7 @@ impl ServersRunner { let sockets_ready = select_all( self.sockets .iter() - .map(|(_host, socket)| Box::pin(socket.readable())), + .map(|(_host, socket, _state)| Box::pin(socket.readable())), ) .map(|(res, inx, _)| match res { Ok(()) => Ok(Ready::Socket(inx)), @@ -765,22 +770,28 @@ impl ServersRunner { async fn run(&mut self) -> Result<(), io::Error> { loop { + qdebug!("iteration"); match self.ready().await? { - Ready::Socket(inx) => loop { - let (host, socket) = self.sockets.get_mut(inx).unwrap(); - let dgram = read_dgram(socket, host)?; - if dgram.is_none() { - break; + Ready::Socket(inx) => { + qdebug!("socket {} ready", inx); + loop { + qdebug!("reading from {}", inx); + let (host, socket, state) = self.sockets.get_mut(inx).unwrap(); + let dgram = read_dgram(socket, state, host)?; + if dgram.is_none() { + break; + } + self.process(dgram.as_ref()).await?; } - self.process(dgram.as_ref()).await; - }, + } Ready::Timeout => { - self.process(None).await; + qdebug!("timeout fired"); + self.process(None).await?; } } self.server.process_events(&self.args, self.args.now()); - self.process(None).await; + self.process(None).await?; } } }