diff --git a/.bleep b/.bleep index 321c1598..a5d99c62 100644 --- a/.bleep +++ b/.bleep @@ -1 +1 @@ -7e74e88f3e50bfdb38eeea244f3100d8477db7e4 \ No newline at end of file +2b28af8029c2e74b642c3a7445dfd6768eda1b24 \ No newline at end of file diff --git a/pingora-core/src/protocols/l4/stream.rs b/pingora-core/src/protocols/l4/stream.rs index edcb188e..37f7486c 100644 --- a/pingora-core/src/protocols/l4/stream.rs +++ b/pingora-core/src/protocols/l4/stream.rs @@ -17,13 +17,15 @@ use async_trait::async_trait; use futures::FutureExt; use log::{debug, error}; + use pingora_error::{ErrorType::*, OrErr, Result}; +use std::io::IoSliceMut; use std::os::unix::io::AsRawFd; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant, SystemTime}; -use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf}; +use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, Interest, ReadBuf}; use tokio::net::{TcpStream, UnixStream}; use crate::protocols::l4::ext::{set_tcp_keepalive, TcpKeepalive}; @@ -118,6 +120,162 @@ impl AsRawFd for RawStream { } } +#[derive(Debug)] +struct RawStreamWrapper { + pub(crate) stream: RawStream, + /// store the last rx timestamp of the stream. + pub(crate) rx_ts: Option, + #[cfg(target_os = "linux")] + /// This can be reused across multiple recvmsg calls. The cmsg buffer may + /// come from old sockets created by older version of pingora and so, + /// this vector can only grow. + reusable_cmsg_space: Vec, +} + +impl RawStreamWrapper { + pub fn new(stream: RawStream) -> Self { + RawStreamWrapper { + stream, + rx_ts: None, + #[cfg(target_os = "linux")] + reusable_cmsg_space: nix::cmsg_space!(nix::sys::time::TimeSpec), + } + } +} + +impl AsyncRead for RawStreamWrapper { + #[cfg(not(target_os = "linux"))] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // Safety: Basic enum pin projection + unsafe { + let rs_wrapper = Pin::get_unchecked_mut(self); + match &mut rs_wrapper.stream { + RawStream::Tcp(s) => Pin::new_unchecked(s).poll_read(cx, buf), + RawStream::Unix(s) => Pin::new_unchecked(s).poll_read(cx, buf), + } + } + } + + #[cfg(target_os = "linux")] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + use futures::ready; + use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags, SockaddrStorage}; + + // Safety: Basic pin projection to get mutable stream + let rs_wrapper = unsafe { Pin::get_unchecked_mut(self) }; + match &mut rs_wrapper.stream { + RawStream::Tcp(s) => { + loop { + ready!(s.poll_read_ready(cx))?; + // Safety: maybe uninitialized bytes will only be passed to recvmsg + let b = unsafe { + &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit] + as *mut [u8]) + }; + let mut iov = [IoSliceMut::new(b)]; + rs_wrapper.reusable_cmsg_space.clear(); + + match s.try_io(Interest::READABLE, || { + recvmsg::( + s.as_raw_fd(), + &mut iov, + Some(&mut rs_wrapper.reusable_cmsg_space), + MsgFlags::empty(), + ) + .map_err(|errno| errno.into()) + }) { + Ok(r) => { + if let Some(ControlMessageOwned::ScmTimestampsns(rtime)) = r + .cmsgs() + .find(|i| matches!(i, ControlMessageOwned::ScmTimestampsns(_))) + { + // The returned timestamp is a real (i.e. not monotonic) timestamp + // https://docs.kernel.org/networking/timestamping.html + rs_wrapper.rx_ts = + SystemTime::UNIX_EPOCH.checked_add(rtime.system.into()); + } + // Safety: We trust `recvmsg` to have filled up `r.bytes` bytes in the buffer. + unsafe { + buf.assume_init(r.bytes); + } + buf.advance(r.bytes); + return Poll::Ready(Ok(())); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => return Poll::Ready(Err(e)), + } + } + } + // Unix RX timestamp only works with datagram for now, so we do not care about it + RawStream::Unix(s) => unsafe { Pin::new_unchecked(s).poll_read(cx, buf) }, + } + } +} + +impl AsyncWrite for RawStreamWrapper { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + // Safety: Basic enum pin projection + unsafe { + match &mut Pin::get_unchecked_mut(self).stream { + RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write(cx, buf), + RawStream::Unix(s) => Pin::new_unchecked(s).poll_write(cx, buf), + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // Safety: Basic enum pin projection + unsafe { + match &mut Pin::get_unchecked_mut(self).stream { + RawStream::Tcp(s) => Pin::new_unchecked(s).poll_flush(cx), + RawStream::Unix(s) => Pin::new_unchecked(s).poll_flush(cx), + } + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // Safety: Basic enum pin projection + unsafe { + match &mut Pin::get_unchecked_mut(self).stream { + RawStream::Tcp(s) => Pin::new_unchecked(s).poll_shutdown(cx), + RawStream::Unix(s) => Pin::new_unchecked(s).poll_shutdown(cx), + } + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + // Safety: Basic enum pin projection + unsafe { + match &mut Pin::get_unchecked_mut(self).stream { + RawStream::Tcp(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), + RawStream::Unix(s) => Pin::new_unchecked(s).poll_write_vectored(cx, bufs), + } + } + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } +} + +impl AsRawFd for RawStreamWrapper { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + self.stream.as_raw_fd() + } +} + // Large read buffering helps reducing syscalls with little trade-off // Ssl layer always does "small" reads in 16k (TLS record size) so L4 read buffer helps a lot. const BUF_READ_SIZE: usize = 64 * 1024; @@ -133,7 +291,7 @@ const BUF_WRITE_SIZE: usize = 1460; /// A concrete type for transport layer connection + extra fields for logging #[derive(Debug)] pub struct Stream { - stream: BufStream, + stream: BufStream, buffer_write: bool, proxy_digest: Option>, socket_digest: Option>, @@ -143,12 +301,14 @@ pub struct Stream { pub tracer: Option, read_pending_time: AccumulatedDuration, write_pending_time: AccumulatedDuration, + /// Last rx timestamp associated with the last recvmsg call. + pub rx_ts: Option, } impl Stream { /// set TCP nodelay for this connection if `self` is TCP pub fn set_nodelay(&mut self) -> Result<()> { - if let RawStream::Tcp(s) = &self.stream.get_ref() { + if let RawStream::Tcp(s) = &self.stream.get_mut().stream { s.set_nodelay(true) .or_err(ConnectError, "failed to set_nodelay")?; } @@ -157,18 +317,40 @@ impl Stream { /// set TCP keepalive settings for this connection if `self` is TCP pub fn set_keepalive(&mut self, ka: &TcpKeepalive) -> Result<()> { - if let RawStream::Tcp(s) = &self.stream.get_ref() { + if let RawStream::Tcp(s) = &self.stream.get_mut().stream { debug!("Setting tcp keepalive"); set_tcp_keepalive(s, ka)?; } Ok(()) } + + #[cfg(target_os = "linux")] + pub fn set_rx_timestamp(&mut self) -> Result<()> { + use nix::sys::socket::{setsockopt, sockopt, TimestampingFlag}; + + if let RawStream::Tcp(s) = &self.stream.get_mut().stream { + let timestamp_options = TimestampingFlag::SOF_TIMESTAMPING_RX_SOFTWARE + | TimestampingFlag::SOF_TIMESTAMPING_SOFTWARE; + return setsockopt(s.as_raw_fd(), sockopt::Timestamping, ×tamp_options) + .or_err(InternalError, "failed to set SOF_TIMESTAMPING_RX_SOFTWARE"); + } + Ok(()) + } + + #[cfg(not(target_os = "linux"))] + pub fn set_rx_timestamp(&mut self) -> io::Result<()> { + Ok(()) + } } impl From for Stream { fn from(s: TcpStream) -> Self { Stream { - stream: BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, RawStream::Tcp(s)), + stream: BufStream::with_capacity( + BUF_READ_SIZE, + BUF_WRITE_SIZE, + RawStreamWrapper::new(RawStream::Tcp(s)), + ), buffer_write: true, established_ts: SystemTime::now(), proxy_digest: None, @@ -176,6 +358,7 @@ impl From for Stream { tracer: None, read_pending_time: AccumulatedDuration::new(), write_pending_time: AccumulatedDuration::new(), + rx_ts: None, } } } @@ -183,7 +366,11 @@ impl From for Stream { impl From for Stream { fn from(s: UnixStream) -> Self { Stream { - stream: BufStream::with_capacity(BUF_READ_SIZE, BUF_WRITE_SIZE, RawStream::Unix(s)), + stream: BufStream::with_capacity( + BUF_READ_SIZE, + BUF_WRITE_SIZE, + RawStreamWrapper::new(RawStream::Unix(s)), + ), buffer_write: true, established_ts: SystemTime::now(), proxy_digest: None, @@ -191,6 +378,7 @@ impl From for Stream { tracer: None, read_pending_time: AccumulatedDuration::new(), write_pending_time: AccumulatedDuration::new(), + rx_ts: None, } } } @@ -262,7 +450,7 @@ impl Drop for Stream { t.0.on_disconnected(); } /* use nodelay/local_addr function to detect socket status */ - let ret = match &self.stream.get_ref() { + let ret = match &self.stream.get_ref().stream { RawStream::Tcp(s) => s.nodelay().err(), RawStream::Unix(s) => s.local_addr().err(), }; @@ -298,6 +486,7 @@ impl AsyncRead for Stream { ) -> Poll> { let result = Pin::new(&mut self.stream).poll_read(cx, buf); self.read_pending_time.poll_time(&result); + self.rx_ts = self.stream.get_ref().rx_ts; result } } @@ -528,3 +717,42 @@ impl AccumulatedDuration { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use tokio::io::AsyncReadExt; + use tokio::io::AsyncWriteExt; + use tokio::net::TcpListener; + use tokio::sync::Notify; + + #[cfg(target_os = "linux")] + #[tokio::test] + async fn test_rx_timestamp() { + let message = "hello world".as_bytes(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + notify2.notified().await; + stream.write_all(message).await.unwrap(); + }); + + let mut stream: Stream = TcpStream::connect(addr).await.unwrap().into(); + stream.set_rx_timestamp().unwrap(); + // Receive the message + // setsockopt for SO_TIMESTAMPING is asynchronous so sleep a little bit + // to let kernel do the work + std::thread::sleep(Duration::from_micros(100)); + notify.notify_one(); + + let mut buffer = vec![0u8; message.len()]; + let n = stream.read(buffer.as_mut_slice()).await.unwrap(); + assert_eq!(n, message.len()); + assert!(stream.rx_ts.is_some()); + } +}