Skip to content

Commit

Permalink
better abstraction for buffering incoming bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz Borcherding committed Feb 27, 2024
1 parent bf87291 commit 75d3ebe
Showing 1 changed file with 98 additions and 56 deletions.
154 changes: 98 additions & 56 deletions rustbus/src/connection/ll_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ pub struct SendConn {
pub struct RecvConn {
stream: UnixStream,

msg_buf_in: Vec<u8>,
msg_buf_filled: usize,
msg_buf_in: IncomingBuffer,
cmsgs_in: Vec<ControlMessageOwned>,
cmsgspace: Vec<u8>,
}
Expand All @@ -44,6 +43,51 @@ pub struct DuplexConn {
pub recv: RecvConn,
}

struct IncomingBuffer {
buf: Vec<u8>,
filled: usize,
}

impl IncomingBuffer {
fn new() -> Self {
IncomingBuffer {
buf: Vec::new(),
filled: 0,
}
}

fn reserve(&mut self, new_len: usize) {
if self.buf.len() < new_len {
self.buf.resize(new_len, 0);
}
}

fn spare_capacity_mut(&mut self) -> &mut [u8] {
&mut self.buf[self.filled..]
}

fn read(&mut self, r: impl FnOnce(&mut [u8]) -> Result<usize>) -> Result<()> {
let read = r(self.spare_capacity_mut())?;
self.filled += read;
debug_assert!(self.filled <= self.buf.len());
Ok(())
}

fn len(&self) -> usize {
self.filled
}

fn take(&mut self) -> Vec<u8> {
self.buf.truncate(self.filled);
self.filled = 0;
std::mem::replace(&mut self.buf, Vec::new())

Check failure on line 83 in rustbus/src/connection/ll_conn.rs

View workflow job for this annotation

GitHub Actions / Lints

replacing a value of type `T` with `T::default()` is better expressed using `std::mem::take`
}

fn peek(&self) -> &[u8] {
&self.buf[..self.filled]
}
}

impl RecvConn {
#[deprecated = "use poll() or select() on the file descriptor"]
pub fn can_read_from_source(&self) -> io::Result<bool> {
Expand All @@ -60,59 +104,60 @@ impl RecvConn {
/// Reads from the source once but takes care that the internal buffer only reaches at maximum max_buffer_size
/// so we can process messages separatly and avoid leaking file descriptors to wrong messages
fn refill_buffer(&mut self, max_buffer_size: usize, timeout: Timeout) -> Result<()> {
if self.msg_buf_in.len() != max_buffer_size {
self.msg_buf_in.resize(max_buffer_size, 0);
}
self.msg_buf_in.reserve(max_buffer_size);

let iovec = IoSliceMut::new(&mut self.msg_buf_in[self.msg_buf_filled..max_buffer_size]);
// Borrow all the fields because we can't use self in the closure...
let cmsgspace = &mut self.cmsgspace;
cmsgspace.clear();
let cmsgs_in = &mut self.cmsgs_in;
let stream = &mut self.stream;

self.cmsgspace.clear();
let flags = MsgFlags::empty();
self.msg_buf_in.read(|buffer| {
let iovec = IoSliceMut::new(buffer);

let old_timeout = self.stream.read_timeout()?;
match timeout {
Timeout::Duration(d) => {
self.stream.set_read_timeout(Some(d))?;
}
Timeout::Infinite => {
self.stream.set_read_timeout(None)?;
}
Timeout::Nonblock => {
self.stream.set_nonblocking(true)?;
let flags = MsgFlags::empty();

let old_timeout = stream.read_timeout()?;
match timeout {
Timeout::Duration(d) => {
stream.set_read_timeout(Some(d))?;
}
Timeout::Infinite => {
stream.set_read_timeout(None)?;
}
Timeout::Nonblock => {
stream.set_nonblocking(true)?;
}
}
}
let iovec_mut = &mut [iovec];
let msg = recvmsg::<SockaddrStorage>(
self.stream.as_raw_fd(),
iovec_mut,
Some(&mut self.cmsgspace),
flags,
)
.map_err(|e| match e {
nix::errno::Errno::EAGAIN => Error::TimedOut,
_ => Error::IoError(e.into()),
});
let iovec_mut = &mut [iovec];
let msg =
recvmsg::<SockaddrStorage>(stream.as_raw_fd(), iovec_mut, Some(cmsgspace), flags)
.map_err(|e| match e {
nix::errno::Errno::EAGAIN => Error::TimedOut,
_ => Error::IoError(e.into()),
});

self.stream.set_nonblocking(false)?;
self.stream.set_read_timeout(old_timeout)?;
stream.set_nonblocking(false)?;
stream.set_read_timeout(old_timeout)?;

let msg = msg?;
let msg = msg?;

if msg.bytes == 0 {
return Err(Error::ConnectionClosed);
}
if msg.bytes == 0 {
return Err(Error::ConnectionClosed);
}

cmsgs_in.extend(msg.cmsgs());
Ok(msg.bytes)
})?;

self.cmsgs_in.extend(msg.cmsgs());
let bytes = msg.bytes;
self.msg_buf_filled += bytes;
Ok(())
}

pub fn bytes_needed_for_current_message(&self) -> Result<usize> {
if self.msg_buf_filled < 16 {
if self.msg_buf_in.len() < 16 {
return Ok(16);
}
let msg_buf_in = &self.msg_buf_in[..self.msg_buf_filled];
let msg_buf_in = &self.msg_buf_in.peek();
let (_, header) = unmarshal::unmarshal_header(msg_buf_in, 0)?;
let (_, header_fields_len) =
crate::wire::util::parse_u32(&msg_buf_in[unmarshal::HEADER_LEN..], header.byteorder)?;
Expand All @@ -132,7 +177,7 @@ impl RecvConn {

// Checks if the internal buffer currently holds a complete message
pub fn buffer_contains_whole_message(&self) -> Result<bool> {
if self.msg_buf_filled < 16 {
if self.msg_buf_in.len() < 16 {
return Ok(false);
}
let bytes_needed = self.bytes_needed_for_current_message();
Expand All @@ -144,7 +189,7 @@ impl RecvConn {
Err(e)
}
}
Ok(bytes_needed) => Ok(self.msg_buf_filled >= bytes_needed),
Ok(bytes_needed) => Ok(self.msg_buf_in.len() >= bytes_needed),
}
}
/// Blocks until a message has been read from the conn or the timeout has been reached
Expand Down Expand Up @@ -172,22 +217,18 @@ impl RecvConn {
/// Blocks until a message has been read from the conn or the timeout has been reached
pub fn get_next_message(&mut self, timeout: Timeout) -> Result<MarshalledMessage> {
self.read_whole_message(timeout)?;
debug_assert_eq!(self.msg_buf_filled, self.msg_buf_in.len());
let (hdrbytes, header) = unmarshal::unmarshal_header(&self.msg_buf_in, 0)?;
let (hdrbytes, header) = unmarshal::unmarshal_header(self.msg_buf_in.peek(), 0)?;
let (dynhdrbytes, dynheader) =
unmarshal::unmarshal_dynamic_header(&header, &self.msg_buf_in, hdrbytes)?;
unmarshal::unmarshal_dynamic_header(&header, self.msg_buf_in.peek(), hdrbytes)?;

let (bytes_used, mut msg) = unmarshal::unmarshal_next_message(
&header,
dynheader,
std::mem::take(&mut self.msg_buf_in),
hdrbytes + dynhdrbytes,
)?;
let buf = self.msg_buf_in.take();
let buf_len = buf.len();
let (bytes_used, mut msg) =
unmarshal::unmarshal_next_message(&header, dynheader, buf, hdrbytes + dynhdrbytes)?;

if self.msg_buf_filled != bytes_used + hdrbytes + dynhdrbytes {
if buf_len != bytes_used + hdrbytes + dynhdrbytes {
return Err(Error::UnmarshalError(UnmarshalError::NotAllBytesUsed));
}
self.msg_buf_filled = 0;

for cmsg in &self.cmsgs_in {
match cmsg {
Expand Down Expand Up @@ -483,8 +524,7 @@ impl DuplexConn {
serial_counter: 1,
},
recv: RecvConn {
msg_buf_in: Vec::new(),
msg_buf_filled: 0,
msg_buf_in: IncomingBuffer::new(),
cmsgs_in: Vec::new(),
cmsgspace: cmsg_space!([RawFd; 10]),
stream,
Expand Down Expand Up @@ -523,13 +563,15 @@ impl AsRawFd for SendConn {
self.stream.as_raw_fd()
}
}

impl AsRawFd for RecvConn {
/// Reading or writing to the `RawFd` may result in undefined behavior
/// and break the `Conn`.
fn as_raw_fd(&self) -> RawFd {
self.stream.as_raw_fd()
}
}

impl AsRawFd for DuplexConn {
/// Reading or writing to the `RawFd` may result in undefined behavior
/// and break the `Conn`.
Expand Down

0 comments on commit 75d3ebe

Please sign in to comment.