Skip to content

Commit

Permalink
change serialization to only serialize the bytes
Browse files Browse the repository at this point in the history
introduce max alignment (which depends on platform anyway) and dont serialize that part
fixes Clone, Debug, and Eq impls to work on the bytes, not the pointers.
  • Loading branch information
WorldSEnder committed Oct 24, 2024
1 parent cc8549f commit 45748ca
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 31 deletions.
5 changes: 3 additions & 2 deletions crates/burn-core/src/record/serde/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,10 @@ impl NestedValue {
}

/// Get the nested value as a vector of bytes.
pub fn as_bytes(self) -> Option<Vec<u8>> {
pub fn as_bytes(self) -> Option<Bytes> {
match self {
NestedValue::U8s(u) => Some(u),
NestedValue::Bytes(u) => Some(u),
NestedValue::U8s(u) => Some(Bytes::from_elems(u)),
_ => None,
}
}
Expand Down
6 changes: 5 additions & 1 deletion crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
where
V: Visitor<'de>,
{
visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap())
let bytes = self.value.unwrap().as_bytes().unwrap();
match bytes.try_into_vec::<u8>() {
Ok(bytes) => visitor.visit_byte_buf(bytes),
Err(bytes) => visitor.visit_bytes(&bytes),
}
}

fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,6 @@ mod tests {
.clone()
.as_bytes()
.expect("has bytes vec");
assert_eq!(bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
}
}
94 changes: 67 additions & 27 deletions crates/burn-tensor/src/tensor/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
//! A verion of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements.
//! A version of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements.
use core::{
alloc::{Layout, LayoutError},
ops::{Deref, DerefMut},
ptr::NonNull,
};
use std::alloc::{Layout, LayoutError};
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;

/// A sort of `Box<[u8]>` that remembers the original alignment and can contain trailing uninitialized bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Bytes {
/// SAFETY:
/// - If `layout.size() > 0`, `ptr` points to a valid allocation from the global allocator
Expand All @@ -19,23 +16,50 @@ pub struct Bytes {
layout: Layout,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct WireFormat<'a> {
align: usize,
#[serde(with = "serde_bytes", borrow)]
data: Cow<'a, [u8]>,
/// The maximum supported alignment. The limit exists to not have to store alignment when serializing. Instead,
/// the bytes are always over-aligned when deserializing to MAX_ALIGN.
const MAX_ALIGN: usize = std::mem::align_of::<u128>();

fn debug_from_fn<F: Fn(&mut std::fmt::Formatter<'_>) -> std::fmt::Result>(
f: F,
) -> impl std::fmt::Debug {
// See also: std::fmt::from_fn
struct FromFn<F>(F);
impl<F> std::fmt::Debug for FromFn<F>
where
F: Fn(&mut std::fmt::Formatter<'_>) -> std::fmt::Result,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
(self.0)(f)
}
}
FromFn(f)
}

impl std::fmt::Debug for Bytes {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let data = &**self;
let fmt_data = move |f: &mut std::fmt::Formatter<'_>| {
if data.len() > 3 {
// There is a nightly API `debug_more_non_exhaustive` which has `finish_non_exhaustive`
f.debug_list().entries(&data[0..3]).entry(&"...").finish()
} else {
f.debug_list().entries(data).finish()
}
};
f.debug_struct("Bytes")
.field("data", &debug_from_fn(fmt_data))
.field("layout", &self.layout)
.finish()
}
}

impl serde::Serialize for Bytes {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
WireFormat {
align: self.layout.align(),
data: Cow::Borrowed(self),
}
.serialize(serializer)
serde_bytes::serialize(self.deref(), serializer)
}
}

Expand All @@ -44,21 +68,36 @@ impl<'de> serde::Deserialize<'de> for Bytes {
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let wire = WireFormat::deserialize(deserializer)?;
Self::from_data(wire.align, wire.data)
.map_err(|_| Error::custom("alignment is invalid, or length too large"))
// When deserializing, we over-align the data. This saves us from having to encode the alignment (which is platform-dependent in any case).
let data: Vec<u8> = serde_bytes::deserialize(deserializer)?;
Self::try_from_data(MAX_ALIGN, Cow::Owned(data))
.map_err(|_| serde::de::Error::custom("alignment is invalid, or length too large"))
}
}

impl Clone for Bytes {
fn clone(&self) -> Self {
// unwrap here: the layout is always valid as it has the alignment & size of self
Self::try_from_data(self.layout.align(), Cow::Borrowed(self.deref())).unwrap()
}
}

impl PartialEq for Bytes {
fn eq(&self, other: &Self) -> bool {
self.deref() == other.deref()
}
}

impl Eq for Bytes {}

impl Bytes {
/// Convert from possibly owned data to Bytes.
fn from_data(align: usize, data: Cow<'_, [u8]>) -> Result<Self, LayoutError> {
fn try_from_data(align: usize, data: Cow<'_, [u8]>) -> Result<Self, LayoutError> {
let len = data.len();
let layout = Layout::from_size_align(len, align)?;
// TODO: we can possibly avoid a copy here (or even earlier by replacing serde_bytes::deserialize) by deserializing into an existing,
// correctly aligned, slice of bytes. Since we might not be able to fully predict the length and align ahead of time, this does currently
// not seem worth the hazzle.
// not seem worth the hassle.
let bytes = unsafe {
let mem = std::alloc::alloc(layout);
std::ptr::copy_nonoverlapping(data.as_ref().as_ptr(), mem, len);
Expand All @@ -72,17 +111,18 @@ impl Bytes {
}

/// Erase the element type of a vector by converting into a sequence of [Bytes].
pub fn from_elems<E: bytemuck::NoUninit>(mut elems: Vec<E>) -> Self {
pub fn from_elems<E: bytemuck::NoUninit + Copy>(mut elems: Vec<E>) -> Self {
let _: () = const {
assert!(
!std::mem::needs_drop::<E>(),
"elements must not need a drop impl"
std::mem::align_of::<E>() <= MAX_ALIGN,
"element type not supported due to too large alignment"
);
};
// Note: going through a Box as in Vec::into_boxed_slice would re-allocate on excess capacity. Avoid that.
let byte_len = elems.len() * std::mem::size_of::<E>();
// Set the length to 0, then all data is in the "spare capacity".
// SAFETY: Careful not to panic now, or this leaks our data!
// SAFETY: Data is Copy, so in particular does not need to be dropped. In any case, try not to panic until
// we have taken ownership of the data!
unsafe { elems.set_len(0) };
let data = elems.spare_capacity_mut();
// We do this to get one contiguous slice of data to pass to Layout::for_value.
Expand Down

0 comments on commit 45748ca

Please sign in to comment.