Skip to content

Commit

Permalink
Merge pull request #121 from MaxVerevkin/nonzeroserial
Browse files Browse the repository at this point in the history
Enforce that serials are not zero
  • Loading branch information
KillingSpark authored Mar 22, 2024
2 parents 5875f1f + c843d2e commit f7dbe26
Show file tree
Hide file tree
Showing 11 changed files with 68 additions and 37 deletions.
8 changes: 5 additions & 3 deletions rustbus/benches/marshal_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rustbus::params::Container;
use rustbus::params::DictMap;
Expand All @@ -9,7 +11,7 @@ use rustbus::wire::unmarshal::unmarshal_next_message;
use rustbus::wire::unmarshal_context::Cursor;

fn marsh(msg: &rustbus::message_builder::MarshalledMessage, buf: &mut Vec<u8>) {
marshal(msg, 0, buf).unwrap();
marshal(msg, NonZeroU32::MIN, buf).unwrap();
}

fn unmarshal(buf: &[u8]) {
Expand Down Expand Up @@ -66,7 +68,7 @@ fn criterion_benchmark(c: &mut Criterion) {
.signal("io.killing.spark", "TestSignal", "/io/killing/spark")
.build();
msg.body.push_old_params(&params).unwrap();
msg.dynheader.serial = Some(1);
msg.dynheader.serial = Some(NonZeroU32::MIN);
buf.clear();
marsh(black_box(&msg), &mut buf)
})
Expand All @@ -76,7 +78,7 @@ fn criterion_benchmark(c: &mut Criterion) {
.signal("io.killing.spark", "TestSignal", "/io/killing/spark")
.build();
msg.body.push_old_params(&params).unwrap();
msg.dynheader.serial = Some(1);
msg.dynheader.serial = Some(NonZeroU32::MIN);
buf.clear();
marsh(&msg, &mut buf);
buf.extend_from_slice(msg.get_buf());
Expand Down
3 changes: 2 additions & 1 deletion rustbus/src/bin/create_corpus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rustbus::message_builder::MessageBuilder;
use rustbus::Marshal;

use std::io::Write;
use std::num::NonZeroU32;

fn main() {
make_and_dump(
Expand Down Expand Up @@ -60,7 +61,7 @@ fn make_message() -> MarshalledMessage {

fn dump_message(path: &str, msg: &MarshalledMessage) {
let mut hdrbuf = vec![];
rustbus::wire::marshal::marshal(msg, 0, &mut hdrbuf).unwrap();
rustbus::wire::marshal::marshal(msg, NonZeroU32::MIN, &mut hdrbuf).unwrap();

let mut file = std::fs::File::create(path).unwrap();
file.write_all(&hdrbuf).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion rustbus/src/bin/perf_test.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use rustbus::connection::Timeout;
use rustbus::MessageBuilder;
use rustbus::RpcConn;
Expand Down Expand Up @@ -28,7 +30,7 @@ fn main() {
let mut buf = Vec::new();
for _ in 0..20000000 {
buf.clear();
rustbus::wire::marshal::marshal(&sig, 1, &mut buf).unwrap();
rustbus::wire::marshal::marshal(&sig, NonZeroU32::MIN, &mut buf).unwrap();
}

// for _ in 0..50000000 {
Expand Down
29 changes: 17 additions & 12 deletions rustbus/src/connection/ll_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::wire::errors::UnmarshalError;
use crate::wire::{marshal, unmarshal, UnixFd};

use std::io::{self, IoSlice, IoSliceMut};
use std::num::NonZeroU32;
use std::os::fd::AsFd;
use std::time;

Expand All @@ -26,7 +27,7 @@ pub struct SendConn {
stream: UnixStream,
header_buf: Vec<u8>,

serial_counter: u32,
serial_counter: NonZeroU32,
}

pub struct RecvConn {
Expand Down Expand Up @@ -248,9 +249,12 @@ impl RecvConn {

impl SendConn {
/// get the next new serial
pub fn alloc_serial(&mut self) -> u32 {
pub fn alloc_serial(&mut self) -> NonZeroU32 {
let serial = self.serial_counter;
self.serial_counter += 1;
self.serial_counter = self
.serial_counter
.checked_add(1)
.expect("run out of serials");
serial
}

Expand All @@ -262,9 +266,7 @@ impl SendConn {
let serial = if let Some(serial) = msg.dynheader.serial {
serial
} else {
let serial = self.serial_counter;
self.serial_counter += 1;
serial
self.alloc_serial()
};

// clear the buf before marshalling the new header
Expand All @@ -285,7 +287,7 @@ impl SendConn {
}

/// send a message and block until all bytes have been sent. Returns the serial of the message to match the response.
pub fn send_message_write_all(&mut self, msg: &MarshalledMessage) -> Result<u32> {
pub fn send_message_write_all(&mut self, msg: &MarshalledMessage) -> Result<NonZeroU32> {
let ctx = self.send_message(msg)?;
ctx.write_all().map_err(force_finish_on_error)
}
Expand Down Expand Up @@ -317,7 +319,7 @@ pub struct SendMessageContext<'a> {
#[derive(Debug, Copy, Clone)]
pub struct SendMessageState {
bytes_sent: usize,
serial: u32,
serial: NonZeroU32,
}

/// This panics if the SendMessageContext was dropped when it was not yet finished. Use force_finish / force_finish_on_error
Expand All @@ -333,7 +335,7 @@ impl Drop for SendMessageContext<'_> {
}

impl SendMessageContext<'_> {
pub fn serial(&self) -> u32 {
pub fn serial(&self) -> NonZeroU32 {
self.state.serial
}

Expand Down Expand Up @@ -384,7 +386,10 @@ impl SendMessageContext<'_> {

/// Try writing as many bytes as possible until either no more bytes need to be written or
/// the timeout is reached. For an infinite timeout there is write_all as a shortcut
pub fn write(mut self, timeout: Timeout) -> std::result::Result<u32, (Self, super::Error)> {
pub fn write(
mut self,
timeout: Timeout,
) -> std::result::Result<NonZeroU32, (Self, super::Error)> {
let start_time = std::time::Instant::now();

// loop until either the time is up or all bytes have been written
Expand All @@ -409,7 +414,7 @@ impl SendMessageContext<'_> {
}

/// Block until all bytes have been written
pub fn write_all(self) -> std::result::Result<u32, (Self, super::Error)> {
pub fn write_all(self) -> std::result::Result<NonZeroU32, (Self, super::Error)> {
self.write(Timeout::Infinite)
}

Expand Down Expand Up @@ -513,7 +518,7 @@ impl DuplexConn {
send: SendConn {
stream: stream.try_clone()?,
header_buf: Vec::new(),
serial_counter: 1,
serial_counter: NonZeroU32::MIN,
},
recv: RecvConn {
msg_buf_in: IncomingBuffer::new(),
Expand Down
13 changes: 9 additions & 4 deletions rustbus/src/connection/rpc_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::ll_conn::DuplexConn;
use super::*;
use crate::message_builder::{MarshalledMessage, MessageType};
use std::collections::{HashMap, VecDeque};
use std::num::NonZeroU32;

/// Convenience wrapper around the lowlevel connection
/// ```rust,no_run
Expand Down Expand Up @@ -32,7 +33,7 @@ use std::collections::{HashMap, VecDeque};
pub struct RpcConn {
signals: VecDeque<MarshalledMessage>,
calls: VecDeque<MarshalledMessage>,
responses: HashMap<u32, MarshalledMessage>,
responses: HashMap<NonZeroU32, MarshalledMessage>,
conn: DuplexConn,
filter: MessageFilter,
}
Expand Down Expand Up @@ -91,7 +92,7 @@ impl RpcConn {
}

/// get the next new serial
pub fn alloc_serial(&mut self) -> u32 {
pub fn alloc_serial(&mut self) -> NonZeroU32 {
self.conn.send.alloc_serial()
}

Expand Down Expand Up @@ -124,12 +125,16 @@ impl RpcConn {
}

/// Return a response if one is there but dont block
pub fn try_get_response(&mut self, serial: u32) -> Option<MarshalledMessage> {
pub fn try_get_response(&mut self, serial: NonZeroU32) -> Option<MarshalledMessage> {
self.responses.remove(&serial)
}

/// Return a response if one is there or block until it arrives
pub fn wait_response(&mut self, serial: u32, timeout: Timeout) -> Result<MarshalledMessage> {
pub fn wait_response(
&mut self,
serial: NonZeroU32,
timeout: Timeout,
) -> Result<MarshalledMessage> {
let start_time = time::Instant::now();
loop {
if let Some(msg) = self.try_get_response(serial) {
Expand Down
5 changes: 3 additions & 2 deletions rustbus/src/message_builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Build new messages that you want to send over a connection
use std::num::NonZeroU32;
use std::os::fd::RawFd;

use crate::params::message;
Expand Down Expand Up @@ -66,11 +67,11 @@ pub struct DynamicHeader {
pub member: Option<String>,
pub object: Option<String>,
pub destination: Option<String>,
pub serial: Option<u32>,
pub serial: Option<NonZeroU32>,
pub sender: Option<String>,
pub signature: Option<String>,
pub error_name: Option<String>,
pub response_serial: Option<u32>,
pub response_serial: Option<NonZeroU32>,
pub num_fds: Option<u32>,
}

Expand Down
14 changes: 8 additions & 6 deletions rustbus/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use crate::params::Base;
use crate::params::Param;
use crate::wire::marshal::marshal;
Expand Down Expand Up @@ -40,9 +42,9 @@ fn test_marshal_unmarshal() {
params.push(128u64.into());
params.push(128i32.into());

msg.dynheader.serial = Some(1);
msg.dynheader.serial = Some(NonZeroU32::MIN);
let mut buf = Vec::new();
marshal(&msg, 0, &mut buf).unwrap();
marshal(&msg, NonZeroU32::MIN, &mut buf).unwrap();

let mut cursor = Cursor::new(&buf);
let header = unmarshal_header(&mut cursor).unwrap();
Expand Down Expand Up @@ -97,13 +99,13 @@ fn test_invalid_stuff() {
let mut msg = crate::message_builder::MessageBuilder::new()
.signal(".......io.killing.spark", "TestSignal", "/io/killing/spark")
.build();
msg.dynheader.serial = Some(1);
msg.dynheader.serial = Some(NonZeroU32::MIN);
let mut buf = Vec::new();
assert_eq!(
Err(crate::wire::errors::MarshalError::Validation(
crate::params::validation::Error::InvalidInterface
)),
marshal(&msg, 0, &mut buf)
marshal(&msg, NonZeroU32::MIN, &mut buf)
);

// invalid member
Expand All @@ -114,12 +116,12 @@ fn test_invalid_stuff() {
"/io/killing/spark",
)
.build();
msg.dynheader.serial = Some(1);
msg.dynheader.serial = Some(NonZeroU32::MIN);
let mut buf = Vec::new();
assert_eq!(
Err(crate::wire::errors::MarshalError::Validation(
crate::params::validation::Error::InvalidMembername
)),
marshal(&msg, 0, &mut buf)
marshal(&msg, NonZeroU32::MIN, &mut buf)
);
}
5 changes: 4 additions & 1 deletion rustbus/src/wire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub mod validate_raw;
pub mod variant_macros;

mod wrapper_types;

use std::num::NonZeroU32;

pub use wrapper_types::unixfd::UnixFd;
pub use wrapper_types::ObjectPath;
pub use wrapper_types::SignatureWrapper;
Expand All @@ -20,7 +23,7 @@ pub enum HeaderField {
Interface(String),
Member(String),
ErrorName(String),
ReplySerial(u32),
ReplySerial(NonZeroU32),
Destination(String),
Sender(String),
Signature(String),
Expand Down
3 changes: 3 additions & 0 deletions rustbus/src/wire/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ pub enum UnmarshalError {
/// A message indicated an invalid byteorder in the header
#[error("A message indicated an invalid byteorder in the header")]
InvalidByteOrder,
/// A message has an invalid (zero) serial in the header
#[error("A message has an invalid (zero) serial in the header")]
InvalidSerial,
/// A message indicated an invalid message type
#[error("A message indicated an invalid message type")]
InvalidMessageType,
Expand Down
10 changes: 6 additions & 4 deletions rustbus/src/wire/marshal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//! * `base` and `container` are for the Param approach that map dbus concepts to enums/structs
//! * `traits` is for the trait based approach
use std::num::NonZeroU32;

use crate::message_builder;
use crate::params;
use crate::wire::HeaderField;
Expand Down Expand Up @@ -34,7 +36,7 @@ impl MarshalContext<'_, '_> {
/// and use get_buf() to get to the contents
pub fn marshal(
msg: &crate::message_builder::MarshalledMessage,
chosen_serial: u32,
chosen_serial: NonZeroU32,
buf: &mut Vec<u8>,
) -> MarshalResult<()> {
marshal_header(msg, chosen_serial, buf)?;
Expand All @@ -51,7 +53,7 @@ pub fn marshal(

fn marshal_header(
msg: &crate::message_builder::MarshalledMessage,
chosen_serial: u32,
chosen_serial: NonZeroU32,
buf: &mut Vec<u8>,
) -> MarshalResult<()> {
let byteorder = msg.body.byteorder();
Expand Down Expand Up @@ -84,7 +86,7 @@ fn marshal_header(
// Zero bytes where the length of the message will be put
buf.extend_from_slice(&[0, 0, 0, 0]);

write_u32(chosen_serial, byteorder, buf);
write_u32(chosen_serial.get(), byteorder, buf);

// Zero bytes where the length of the header fields will be put
let pos = buf.len();
Expand Down Expand Up @@ -172,7 +174,7 @@ fn marshal_header_field(
buf.push(b'u');
buf.push(0);
pad_to_align(4, buf);
write_u32(*rs, byteorder, buf);
write_u32(rs.get(), byteorder, buf);
}
HeaderField::Destination(dest) => {
params::validate_busname(dest)?;
Expand Down
11 changes: 8 additions & 3 deletions rustbus/src/wire/unmarshal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//! * `traits` is for the trait based approach
//! * `iter` is an experimental approach to an libdbus-like iterator
use std::num::NonZeroU32;

use crate::message_builder::DynamicHeader;
use crate::message_builder::MarshalledMessage;
use crate::message_builder::MarshalledMessageBody;
Expand Down Expand Up @@ -33,7 +35,7 @@ pub struct Header {
pub flags: u8,
pub version: u8,
pub body_len: u32,
pub serial: u32,
pub serial: NonZeroU32,
}

impl From<crate::signature::Error> for UnmarshalError {
Expand Down Expand Up @@ -75,7 +77,8 @@ pub fn unmarshal_header(cursor: &mut Cursor) -> UnmarshalResult<Header> {
let flags = cursor.read_u8()?;
let version = cursor.read_u8()?;
let body_len = cursor.read_u32(byteorder)?;
let serial = cursor.read_u32(byteorder)?;
let serial =
NonZeroU32::new(cursor.read_u32(byteorder)?).ok_or(UnmarshalError::InvalidSerial)?;

Ok(Header {
byteorder,
Expand Down Expand Up @@ -242,7 +245,9 @@ fn unmarshal_header_field(header: &Header, cursor: &mut Cursor) -> UnmarshalResu
},
5 => match sig {
signature::Type::Base(signature::Base::Uint32) => {
Ok(HeaderField::ReplySerial(cursor.read_u32(header.byteorder)?))
NonZeroU32::new(cursor.read_u32(header.byteorder)?)
.ok_or(UnmarshalError::InvalidHeaderField)
.map(HeaderField::ReplySerial)
}
_ => Err(UnmarshalError::WrongSignature),
},
Expand Down

0 comments on commit f7dbe26

Please sign in to comment.