From c843d2e2c580119c3799a3907249eb0a1b91f479 Mon Sep 17 00:00:00 2001 From: MaxVerevkin Date: Fri, 22 Mar 2024 10:44:56 +0200 Subject: [PATCH] Enforce that serials are not zero From the spec: > 2nd UINT32 - The serial of this message, used as a cookie by the > sender to identify the reply corresponding to this request. This must > not be zero. --- rustbus/benches/marshal_benchmark.rs | 8 +++++--- rustbus/src/bin/create_corpus.rs | 3 ++- rustbus/src/bin/perf_test.rs | 4 +++- rustbus/src/connection/ll_conn.rs | 29 ++++++++++++++++------------ rustbus/src/connection/rpc_conn.rs | 13 +++++++++---- rustbus/src/message_builder.rs | 5 +++-- rustbus/src/tests.rs | 14 ++++++++------ rustbus/src/wire.rs | 5 ++++- rustbus/src/wire/errors.rs | 3 +++ rustbus/src/wire/marshal.rs | 10 ++++++---- rustbus/src/wire/unmarshal.rs | 11 ++++++++--- 11 files changed, 68 insertions(+), 37 deletions(-) diff --git a/rustbus/benches/marshal_benchmark.rs b/rustbus/benches/marshal_benchmark.rs index bf3470e..11487d5 100644 --- a/rustbus/benches/marshal_benchmark.rs +++ b/rustbus/benches/marshal_benchmark.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU32; + use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rustbus::params::Container; use rustbus::params::DictMap; @@ -9,7 +11,7 @@ use rustbus::wire::unmarshal::unmarshal_next_message; use rustbus::wire::unmarshal_context::Cursor; fn marsh(msg: &rustbus::message_builder::MarshalledMessage, buf: &mut Vec) { - marshal(msg, 0, buf).unwrap(); + marshal(msg, NonZeroU32::MIN, buf).unwrap(); } fn unmarshal(buf: &[u8]) { @@ -66,7 +68,7 @@ fn criterion_benchmark(c: &mut Criterion) { .signal("io.killing.spark", "TestSignal", "/io/killing/spark") .build(); msg.body.push_old_params(¶ms).unwrap(); - msg.dynheader.serial = Some(1); + msg.dynheader.serial = Some(NonZeroU32::MIN); buf.clear(); marsh(black_box(&msg), &mut buf) }) @@ -76,7 +78,7 @@ fn criterion_benchmark(c: &mut Criterion) { .signal("io.killing.spark", "TestSignal", "/io/killing/spark") .build(); msg.body.push_old_params(¶ms).unwrap(); - msg.dynheader.serial = Some(1); + msg.dynheader.serial = Some(NonZeroU32::MIN); buf.clear(); marsh(&msg, &mut buf); buf.extend_from_slice(msg.get_buf()); diff --git a/rustbus/src/bin/create_corpus.rs b/rustbus/src/bin/create_corpus.rs index 2c99be6..56caf30 100644 --- a/rustbus/src/bin/create_corpus.rs +++ b/rustbus/src/bin/create_corpus.rs @@ -5,6 +5,7 @@ use rustbus::message_builder::MessageBuilder; use rustbus::Marshal; use std::io::Write; +use std::num::NonZeroU32; fn main() { make_and_dump( @@ -60,7 +61,7 @@ fn make_message() -> MarshalledMessage { fn dump_message(path: &str, msg: &MarshalledMessage) { let mut hdrbuf = vec![]; - rustbus::wire::marshal::marshal(msg, 0, &mut hdrbuf).unwrap(); + rustbus::wire::marshal::marshal(msg, NonZeroU32::MIN, &mut hdrbuf).unwrap(); let mut file = std::fs::File::create(path).unwrap(); file.write_all(&hdrbuf).unwrap(); diff --git a/rustbus/src/bin/perf_test.rs b/rustbus/src/bin/perf_test.rs index 544d175..0c0f58e 100644 --- a/rustbus/src/bin/perf_test.rs +++ b/rustbus/src/bin/perf_test.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU32; + use rustbus::connection::Timeout; use rustbus::MessageBuilder; use rustbus::RpcConn; @@ -28,7 +30,7 @@ fn main() { let mut buf = Vec::new(); for _ in 0..20000000 { buf.clear(); - rustbus::wire::marshal::marshal(&sig, 1, &mut buf).unwrap(); + rustbus::wire::marshal::marshal(&sig, NonZeroU32::MIN, &mut buf).unwrap(); } // for _ in 0..50000000 { diff --git a/rustbus/src/connection/ll_conn.rs b/rustbus/src/connection/ll_conn.rs index a3fc154..e55cfca 100644 --- a/rustbus/src/connection/ll_conn.rs +++ b/rustbus/src/connection/ll_conn.rs @@ -5,6 +5,7 @@ use crate::wire::errors::UnmarshalError; use crate::wire::{marshal, unmarshal, UnixFd}; use std::io::{self, IoSlice, IoSliceMut}; +use std::num::NonZeroU32; use std::os::fd::AsFd; use std::time; @@ -26,7 +27,7 @@ pub struct SendConn { stream: UnixStream, header_buf: Vec, - serial_counter: u32, + serial_counter: NonZeroU32, } pub struct RecvConn { @@ -248,9 +249,12 @@ impl RecvConn { impl SendConn { /// get the next new serial - pub fn alloc_serial(&mut self) -> u32 { + pub fn alloc_serial(&mut self) -> NonZeroU32 { let serial = self.serial_counter; - self.serial_counter += 1; + self.serial_counter = self + .serial_counter + .checked_add(1) + .expect("run out of serials"); serial } @@ -262,9 +266,7 @@ impl SendConn { let serial = if let Some(serial) = msg.dynheader.serial { serial } else { - let serial = self.serial_counter; - self.serial_counter += 1; - serial + self.alloc_serial() }; // clear the buf before marshalling the new header @@ -285,7 +287,7 @@ impl SendConn { } /// send a message and block until all bytes have been sent. Returns the serial of the message to match the response. - pub fn send_message_write_all(&mut self, msg: &MarshalledMessage) -> Result { + pub fn send_message_write_all(&mut self, msg: &MarshalledMessage) -> Result { let ctx = self.send_message(msg)?; ctx.write_all().map_err(force_finish_on_error) } @@ -317,7 +319,7 @@ pub struct SendMessageContext<'a> { #[derive(Debug, Copy, Clone)] pub struct SendMessageState { bytes_sent: usize, - serial: u32, + serial: NonZeroU32, } /// This panics if the SendMessageContext was dropped when it was not yet finished. Use force_finish / force_finish_on_error @@ -333,7 +335,7 @@ impl Drop for SendMessageContext<'_> { } impl SendMessageContext<'_> { - pub fn serial(&self) -> u32 { + pub fn serial(&self) -> NonZeroU32 { self.state.serial } @@ -384,7 +386,10 @@ impl SendMessageContext<'_> { /// Try writing as many bytes as possible until either no more bytes need to be written or /// the timeout is reached. For an infinite timeout there is write_all as a shortcut - pub fn write(mut self, timeout: Timeout) -> std::result::Result { + pub fn write( + mut self, + timeout: Timeout, + ) -> std::result::Result { let start_time = std::time::Instant::now(); // loop until either the time is up or all bytes have been written @@ -409,7 +414,7 @@ impl SendMessageContext<'_> { } /// Block until all bytes have been written - pub fn write_all(self) -> std::result::Result { + pub fn write_all(self) -> std::result::Result { self.write(Timeout::Infinite) } @@ -513,7 +518,7 @@ impl DuplexConn { send: SendConn { stream: stream.try_clone()?, header_buf: Vec::new(), - serial_counter: 1, + serial_counter: NonZeroU32::MIN, }, recv: RecvConn { msg_buf_in: IncomingBuffer::new(), diff --git a/rustbus/src/connection/rpc_conn.rs b/rustbus/src/connection/rpc_conn.rs index 0d2ff4b..6040d9e 100644 --- a/rustbus/src/connection/rpc_conn.rs +++ b/rustbus/src/connection/rpc_conn.rs @@ -5,6 +5,7 @@ use super::ll_conn::DuplexConn; use super::*; use crate::message_builder::{MarshalledMessage, MessageType}; use std::collections::{HashMap, VecDeque}; +use std::num::NonZeroU32; /// Convenience wrapper around the lowlevel connection /// ```rust,no_run @@ -32,7 +33,7 @@ use std::collections::{HashMap, VecDeque}; pub struct RpcConn { signals: VecDeque, calls: VecDeque, - responses: HashMap, + responses: HashMap, conn: DuplexConn, filter: MessageFilter, } @@ -91,7 +92,7 @@ impl RpcConn { } /// get the next new serial - pub fn alloc_serial(&mut self) -> u32 { + pub fn alloc_serial(&mut self) -> NonZeroU32 { self.conn.send.alloc_serial() } @@ -124,12 +125,16 @@ impl RpcConn { } /// Return a response if one is there but dont block - pub fn try_get_response(&mut self, serial: u32) -> Option { + pub fn try_get_response(&mut self, serial: NonZeroU32) -> Option { self.responses.remove(&serial) } /// Return a response if one is there or block until it arrives - pub fn wait_response(&mut self, serial: u32, timeout: Timeout) -> Result { + pub fn wait_response( + &mut self, + serial: NonZeroU32, + timeout: Timeout, + ) -> Result { let start_time = time::Instant::now(); loop { if let Some(msg) = self.try_get_response(serial) { diff --git a/rustbus/src/message_builder.rs b/rustbus/src/message_builder.rs index 42c289f..c0e1ce4 100644 --- a/rustbus/src/message_builder.rs +++ b/rustbus/src/message_builder.rs @@ -1,4 +1,5 @@ //! Build new messages that you want to send over a connection +use std::num::NonZeroU32; use std::os::fd::RawFd; use crate::params::message; @@ -66,11 +67,11 @@ pub struct DynamicHeader { pub member: Option, pub object: Option, pub destination: Option, - pub serial: Option, + pub serial: Option, pub sender: Option, pub signature: Option, pub error_name: Option, - pub response_serial: Option, + pub response_serial: Option, pub num_fds: Option, } diff --git a/rustbus/src/tests.rs b/rustbus/src/tests.rs index 6186656..f361ef8 100644 --- a/rustbus/src/tests.rs +++ b/rustbus/src/tests.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroU32; + use crate::params::Base; use crate::params::Param; use crate::wire::marshal::marshal; @@ -40,9 +42,9 @@ fn test_marshal_unmarshal() { params.push(128u64.into()); params.push(128i32.into()); - msg.dynheader.serial = Some(1); + msg.dynheader.serial = Some(NonZeroU32::MIN); let mut buf = Vec::new(); - marshal(&msg, 0, &mut buf).unwrap(); + marshal(&msg, NonZeroU32::MIN, &mut buf).unwrap(); let mut cursor = Cursor::new(&buf); let header = unmarshal_header(&mut cursor).unwrap(); @@ -97,13 +99,13 @@ fn test_invalid_stuff() { let mut msg = crate::message_builder::MessageBuilder::new() .signal(".......io.killing.spark", "TestSignal", "/io/killing/spark") .build(); - msg.dynheader.serial = Some(1); + msg.dynheader.serial = Some(NonZeroU32::MIN); let mut buf = Vec::new(); assert_eq!( Err(crate::wire::errors::MarshalError::Validation( crate::params::validation::Error::InvalidInterface )), - marshal(&msg, 0, &mut buf) + marshal(&msg, NonZeroU32::MIN, &mut buf) ); // invalid member @@ -114,12 +116,12 @@ fn test_invalid_stuff() { "/io/killing/spark", ) .build(); - msg.dynheader.serial = Some(1); + msg.dynheader.serial = Some(NonZeroU32::MIN); let mut buf = Vec::new(); assert_eq!( Err(crate::wire::errors::MarshalError::Validation( crate::params::validation::Error::InvalidMembername )), - marshal(&msg, 0, &mut buf) + marshal(&msg, NonZeroU32::MIN, &mut buf) ); } diff --git a/rustbus/src/wire.rs b/rustbus/src/wire.rs index 1b22795..1245f0b 100644 --- a/rustbus/src/wire.rs +++ b/rustbus/src/wire.rs @@ -9,6 +9,9 @@ pub mod validate_raw; pub mod variant_macros; mod wrapper_types; + +use std::num::NonZeroU32; + pub use wrapper_types::unixfd::UnixFd; pub use wrapper_types::ObjectPath; pub use wrapper_types::SignatureWrapper; @@ -20,7 +23,7 @@ pub enum HeaderField { Interface(String), Member(String), ErrorName(String), - ReplySerial(u32), + ReplySerial(NonZeroU32), Destination(String), Sender(String), Signature(String), diff --git a/rustbus/src/wire/errors.rs b/rustbus/src/wire/errors.rs index 9ee857c..3c8422e 100644 --- a/rustbus/src/wire/errors.rs +++ b/rustbus/src/wire/errors.rs @@ -45,6 +45,9 @@ pub enum UnmarshalError { /// A message indicated an invalid byteorder in the header #[error("A message indicated an invalid byteorder in the header")] InvalidByteOrder, + /// A message has an invalid (zero) serial in the header + #[error("A message has an invalid (zero) serial in the header")] + InvalidSerial, /// A message indicated an invalid message type #[error("A message indicated an invalid message type")] InvalidMessageType, diff --git a/rustbus/src/wire/marshal.rs b/rustbus/src/wire/marshal.rs index 1588f31..ec22914 100644 --- a/rustbus/src/wire/marshal.rs +++ b/rustbus/src/wire/marshal.rs @@ -3,6 +3,8 @@ //! * `base` and `container` are for the Param approach that map dbus concepts to enums/structs //! * `traits` is for the trait based approach +use std::num::NonZeroU32; + use crate::message_builder; use crate::params; use crate::wire::HeaderField; @@ -34,7 +36,7 @@ impl MarshalContext<'_, '_> { /// and use get_buf() to get to the contents pub fn marshal( msg: &crate::message_builder::MarshalledMessage, - chosen_serial: u32, + chosen_serial: NonZeroU32, buf: &mut Vec, ) -> MarshalResult<()> { marshal_header(msg, chosen_serial, buf)?; @@ -51,7 +53,7 @@ pub fn marshal( fn marshal_header( msg: &crate::message_builder::MarshalledMessage, - chosen_serial: u32, + chosen_serial: NonZeroU32, buf: &mut Vec, ) -> MarshalResult<()> { let byteorder = msg.body.byteorder(); @@ -84,7 +86,7 @@ fn marshal_header( // Zero bytes where the length of the message will be put buf.extend_from_slice(&[0, 0, 0, 0]); - write_u32(chosen_serial, byteorder, buf); + write_u32(chosen_serial.get(), byteorder, buf); // Zero bytes where the length of the header fields will be put let pos = buf.len(); @@ -172,7 +174,7 @@ fn marshal_header_field( buf.push(b'u'); buf.push(0); pad_to_align(4, buf); - write_u32(*rs, byteorder, buf); + write_u32(rs.get(), byteorder, buf); } HeaderField::Destination(dest) => { params::validate_busname(dest)?; diff --git a/rustbus/src/wire/unmarshal.rs b/rustbus/src/wire/unmarshal.rs index 117ac54..7faf2b3 100644 --- a/rustbus/src/wire/unmarshal.rs +++ b/rustbus/src/wire/unmarshal.rs @@ -4,6 +4,8 @@ //! * `traits` is for the trait based approach //! * `iter` is an experimental approach to an libdbus-like iterator +use std::num::NonZeroU32; + use crate::message_builder::DynamicHeader; use crate::message_builder::MarshalledMessage; use crate::message_builder::MarshalledMessageBody; @@ -33,7 +35,7 @@ pub struct Header { pub flags: u8, pub version: u8, pub body_len: u32, - pub serial: u32, + pub serial: NonZeroU32, } impl From for UnmarshalError { @@ -75,7 +77,8 @@ pub fn unmarshal_header(cursor: &mut Cursor) -> UnmarshalResult
{ let flags = cursor.read_u8()?; let version = cursor.read_u8()?; let body_len = cursor.read_u32(byteorder)?; - let serial = cursor.read_u32(byteorder)?; + let serial = + NonZeroU32::new(cursor.read_u32(byteorder)?).ok_or(UnmarshalError::InvalidSerial)?; Ok(Header { byteorder, @@ -242,7 +245,9 @@ fn unmarshal_header_field(header: &Header, cursor: &mut Cursor) -> UnmarshalResu }, 5 => match sig { signature::Type::Base(signature::Base::Uint32) => { - Ok(HeaderField::ReplySerial(cursor.read_u32(header.byteorder)?)) + NonZeroU32::new(cursor.read_u32(header.byteorder)?) + .ok_or(UnmarshalError::InvalidHeaderField) + .map(HeaderField::ReplySerial) } _ => Err(UnmarshalError::WrongSignature), },