Skip to content

Commit

Permalink
Support big endian headers in addition to little endian.
Browse files Browse the repository at this point in the history
  • Loading branch information
de-vri-es committed Dec 4, 2023
1 parent b568d9b commit 58b14dc
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 57 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Unreleased
- [change][major] Mark `StreamConfig` and `UnixConfig` as non-exhaustive structs.
- [change][major] Make the `MessageHeader::encode/decode()` functions take an `endian` parameter.
- [add][major] Add an `endian` field to `StreamConfig` and `UnixConfig`.

# Version 0.7.1 - 2023-11-26
- [change][patch] Remove dependency on `byteorder` crate.

Expand Down
41 changes: 11 additions & 30 deletions src/message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::Error;
use crate::error::private::InnerError;
use crate::transport::Endian;

/// The encoded length of a message header.
///
Expand Down Expand Up @@ -220,16 +221,16 @@ impl MessageHeader {
}
}

/// Decode a message header from a byte slice.
/// Decode a message header from a byte slice using the given endianness for the header fields.
///
/// The byte slice should NOT contain the message size.
///
/// # Panic
/// This function panics if the buffer does not contain a full header.
pub fn decode(buffer: &[u8]) -> Result<Self, Error> {
let message_type = read_u32_le(&buffer[0..]);
let request_id = read_u32_le(&buffer[4..]);
let service_id = read_i32_le(&buffer[8..]);
pub fn decode(buffer: &[u8], endian: Endian) -> Result<Self, Error> {
let message_type = endian.read_u32(&buffer[0..]);
let request_id = endian.read_u32(&buffer[4..]);
let service_id = endian.read_i32(&buffer[8..]);

let message_type = MessageType::from_u32(message_type)?;
Ok(Self {
Expand All @@ -239,17 +240,17 @@ impl MessageHeader {
})
}

/// Encode a message header into a byte slice.
/// Encode a message header into a byte slice using the given endianness for the header fields.
///
/// This will NOT add a message size (which would be impossible even if we wanted to).
///
/// # Panic
/// This function panics if the buffer is not large enough to hold a full header.
pub fn encode(&self, buffer: &mut [u8]) {
pub fn encode(&self, buffer: &mut [u8], endian: Endian) {
assert!(buffer.len() >= 12);
write_u32_le(&mut buffer[0..], self.message_type as u32);
write_u32_le(&mut buffer[4..], self.request_id);
write_i32_le(&mut buffer[8..], self.service_id);
endian.write_u32(&mut buffer[0..], self.message_type as u32);
endian.write_u32(&mut buffer[4..], self.request_id);
endian.write_i32(&mut buffer[8..], self.service_id);
}
}

Expand All @@ -260,23 +261,3 @@ impl<Body> std::fmt::Debug for Message<Body> {
.finish_non_exhaustive()
}
}

/// Read a [`u32`] from a buffer in little endian format.
fn read_u32_le(buffer: &[u8]) -> u32 {
u32::from_le_bytes(buffer[0..4].try_into().unwrap())
}

/// Read a [`i32`] from a buffer in little endian format.
fn read_i32_le(buffer: &[u8]) -> i32 {
i32::from_le_bytes(buffer[0..4].try_into().unwrap())
}

/// Write a [`i32`] to a buffer in little endian format.
fn write_i32_le(buffer: &mut [u8], value: i32) {
buffer[0..4].copy_from_slice(&value.to_le_bytes());
}

/// Write a [`u32`] to a buffer in little endian format.
fn write_u32_le(buffer: &mut [u8], value: u32) {
buffer[0..4].copy_from_slice(&value.to_le_bytes());
}
50 changes: 50 additions & 0 deletions src/transport/endian.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/// The endianness to use for encoding header fields.
///
/// The encoding and serialization of message bodies is up to the application code,
/// and it not affected by this configuration parameter.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Endian {
/// Encode header fields in little endian.
LittleEndian,

/// Encode header fields in big endian.
BigEndian,
}

impl Endian {
/// Read a [`u32`] from a buffer in the correct endianness.
pub(crate) fn read_u32(self, buffer: &[u8]) -> u32 {
let buffer = buffer[0..4].try_into().unwrap();
match self {
Self::LittleEndian => u32::from_le_bytes(buffer),
Self::BigEndian => u32::from_be_bytes(buffer),
}
}

/// Write a [`u32`] to a buffer in thcorrect endianness.
pub(crate) fn write_u32(self, buffer: &mut [u8], value: u32) {
let bytes = match self {
Self::LittleEndian => value.to_le_bytes(),
Self::BigEndian => value.to_be_bytes(),
};
buffer[0..4].copy_from_slice(&bytes);
}

/// Read a [`i32`] from a buffer in the correct endianness.
pub(crate) fn read_i32(self, buffer: &[u8]) -> i32 {
let buffer = buffer[0..4].try_into().unwrap();
match self {
Self::LittleEndian => i32::from_le_bytes(buffer),
Self::BigEndian => i32::from_be_bytes(buffer),
}
}

/// Write a [`i32`] to a buffer in thcorrect endianness.
pub(crate) fn write_i32(self, buffer: &mut [u8], value: i32) {
let bytes = match self {
Self::LittleEndian => value.to_le_bytes(),
Self::BigEndian => value.to_be_bytes(),
};
buffer[0..4].copy_from_slice(&bytes);
}
}
3 changes: 3 additions & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ use std::task::{Context, Poll};

use crate::{Error, Message, MessageHeader};

mod endian;
pub use endian::Endian;

pub(crate) mod stream;
pub use stream::StreamTransport;

Expand Down
10 changes: 10 additions & 0 deletions src/transport/stream/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::transport::Endian;

/// Configuration for a byte-stream transport.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StreamConfig {
/// The maximum body size for incoming messages.
///
Expand All @@ -13,13 +16,20 @@ pub struct StreamConfig {
/// the message is discarded and an error is returned.
/// Stream sockets remain usable since the message header will not be sent either.
pub max_body_len_write: u32,

/// The endianness to use when encoding/decoding header fields.
///
/// The encoding and serialization of message bodies is up to the application code,
/// and it not affected by this configuration parameter.
pub endian: Endian,
}

impl Default for StreamConfig {
fn default() -> Self {
Self {
max_body_len_read: 8 * 1024,
max_body_len_write: 8 * 1024,
endian: Endian::LittleEndian,
}
}
}
8 changes: 4 additions & 4 deletions src/transport/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ mod impl_unix_stream {

fn split(&mut self) -> (StreamReadHalf<tokio::net::unix::ReadHalf>, StreamWriteHalf<tokio::net::unix::WriteHalf>) {
let (read_half, write_half) = self.stream.split();
let read_half = StreamReadHalf::new(read_half, self.config.max_body_len_read);
let write_half = StreamWriteHalf::new(write_half, self.config.max_body_len_write);
let read_half = StreamReadHalf::new(read_half, self.config.max_body_len_read, self.config.endian);
let write_half = StreamWriteHalf::new(write_half, self.config.max_body_len_write, self.config.endian);
(read_half, write_half)
}

Expand Down Expand Up @@ -155,8 +155,8 @@ mod impl_tcp {

fn split(&mut self) -> (StreamReadHalf<tokio::net::tcp::ReadHalf>, StreamWriteHalf<tokio::net::tcp::WriteHalf>) {
let (read_half, write_half) = self.stream.split();
let read_half = StreamReadHalf::new(read_half, self.config.max_body_len_read);
let write_half = StreamWriteHalf::new(write_half, self.config.max_body_len_write);
let read_half = StreamReadHalf::new(read_half, self.config.max_body_len_read, self.config.endian);
let write_half = StreamWriteHalf::new(write_half, self.config.max_body_len_write, self.config.endian);
(read_half, write_half)
}

Expand Down
32 changes: 15 additions & 17 deletions src/transport/stream/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use tokio::io::{AsyncRead, AsyncWrite};

use super::{StreamBody, StreamConfig};
use crate::error::private::check_payload_too_large;
use crate::transport::TransportError;
use crate::transport::{TransportError, Endian};
use crate::{Message, MessageHeader};

/// Length of a message frame and header.
Expand All @@ -30,6 +30,9 @@ pub struct StreamReadHalf<ReadStream> {
/// The maximum body length to accept when reading messages.
pub(super) max_body_len: u32,

/// The endianness to use for decoding header fields.
pub(super) endian: Endian,

/// The number of bytes read for the current message.
pub(super) bytes_read: usize,

Expand All @@ -52,6 +55,9 @@ pub struct StreamWriteHalf<WriteStream> {
/// The maximum body length to enforce for messages.
pub(super) max_body_len: u32,

/// The endianness to use for encoding header fields.
pub(super) endian: Endian,

/// The number of bytes written for the current message.
pub(super) bytes_written: usize,

Expand Down Expand Up @@ -91,10 +97,11 @@ where

impl<ReadStream> StreamReadHalf<ReadStream> {
#[allow(dead_code)] // Not used when transports are disabled.
pub(super) fn new(stream: ReadStream, max_body_len: u32) -> Self {
pub(super) fn new(stream: ReadStream, max_body_len: u32, endian: Endian) -> Self {
Self {
stream,
max_body_len,
endian,
header_buffer: [0u8; FRAMED_HEADER_LEN],
bytes_read: 0,
parsed_header: MessageHeader::request(0, 0),
Expand All @@ -117,10 +124,11 @@ impl<ReadStream> StreamReadHalf<ReadStream> {

impl<WriteStream> StreamWriteHalf<WriteStream> {
#[allow(dead_code)] // Not used when transports are disabled.
pub(super) fn new(stream: WriteStream, max_body_len: u32) -> Self {
pub(super) fn new(stream: WriteStream, max_body_len: u32, endian: Endian) -> Self {
Self {
stream,
max_body_len,
endian,
header_buffer: None,
bytes_written: 0,
}
Expand Down Expand Up @@ -171,8 +179,8 @@ where
// Check if we have the whole frame + header.
if this.bytes_read == FRAMED_HEADER_LEN {
// Parse frame and header.
let length = read_u32_le(&this.header_buffer[0..]);
this.parsed_header = MessageHeader::decode(&this.header_buffer[4..])
let length = this.endian.read_u32(&this.header_buffer[0..]);
this.parsed_header = MessageHeader::decode(&this.header_buffer[4..], this.endian)
.map_err(TransportError::new_fatal)?;

// Check body length and create body buffer.
Expand Down Expand Up @@ -218,8 +226,8 @@ where
// Encode the header if we haven't done that yet.
let header_buffer = this.header_buffer.get_or_insert_with(|| {
let mut buffer = [0u8; FRAMED_HEADER_LEN];
write_u32_le(&mut buffer[0..], body.len() as u32 + crate::HEADER_LEN);
header.encode(&mut buffer[4..]);
this.endian.write_u32(&mut buffer[0..], body.len() as u32 + crate::HEADER_LEN);
header.encode(&mut buffer[4..], this.endian);
buffer
});

Expand All @@ -241,13 +249,3 @@ where
Poll::Ready(Ok(()))
}
}

/// Read a [`u32`] from a buffer in little endian format.
fn read_u32_le(buffer: &[u8]) -> u32 {
u32::from_le_bytes(buffer[0..4].try_into().unwrap())
}

/// Write a [`u32`] to a buffer in little endian format.
fn write_u32_le(buffer: &mut [u8], value: u32) {
buffer[0..4].copy_from_slice(&value.to_le_bytes());
}
10 changes: 10 additions & 0 deletions src/transport/unix/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::transport::Endian;

/// Configuration for Unix datagram transports.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct UnixConfig {
/// The maximum body size for incoming messages.
///
Expand All @@ -22,6 +25,12 @@ pub struct UnixConfig {

/// The maximum number of attached file descriptors for sending messages.
pub max_fds_write: u32,

/// The endianness to use when encoding/decoding header fields.
///
/// The encoding and serialization of message bodies is up to the application code,
/// and it not affected by this configuration parameter.
pub endian: Endian,
}

impl Default for UnixConfig {
Expand All @@ -31,6 +40,7 @@ impl Default for UnixConfig {
max_body_len_write: 4 * 1024,
max_fds_read: 10,
max_fds_write: 10,
endian: Endian::LittleEndian,
}
}
}
4 changes: 2 additions & 2 deletions src/transport/unix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ mod impl_unix_seqpacket {

fn split(&mut self) -> (UnixReadHalf<&tokio_seqpacket::UnixSeqpacket>, UnixWriteHalf<&tokio_seqpacket::UnixSeqpacket>) {
let (read_half, write_half) = (&self.socket, &self.socket);
let read_half = UnixReadHalf::new(read_half, self.config.max_body_len_read, self.config.max_fds_read);
let write_half = UnixWriteHalf::new(write_half, self.config.max_body_len_write, self.config.max_fds_write);
let read_half = UnixReadHalf::new(read_half, self.config.max_body_len_read, self.config.max_fds_read, self.config.endian);
let write_half = UnixWriteHalf::new(write_half, self.config.max_body_len_write, self.config.max_fds_write, self.config.endian);
(read_half, write_half)
}

Expand Down
Loading

0 comments on commit 58b14dc

Please sign in to comment.