From 28f99d14e9a43ffe31c3d9652da4ac4ff00ee30b Mon Sep 17 00:00:00 2001 From: WorldSEnder Date: Tue, 17 Dec 2024 16:40:44 +0100 Subject: [PATCH] Fix alignment issue of TensorData bytes (#2416) * implement memory-safe bytes that can be serialized and cloned * change serialization to only serialize the bytes introduce max alignment (which depends on platform anyway) and dont serialize that part fixes Clone, Debug, and Eq impls to work on the bytes, not the pointers. * make bytes no-std compatible * enforce Send and Sync for Bytes * avoid a copy during deserialization if data is already aligned this already improves readability a bit by separating out alloc/dealloc logic and adding a bunch of safety comments and better error messages * revert back to using Vec as deserialization intermediate borrowing from the deserializer will not save a copy, and is moreover inefficient when we could take ownership of an existing byte buffer * add serialization and conversion tests * make Bytes tests run under miri both changes only target miri's borrowing semantics, oprationally the pointers are the same, but they obey different borrow-stack rules. * let the Bytes buffer grow * Clean the code by separation of concerns The new Allocation struct keeps the raw allocation and its layout, the Bytes struct wraps an Allocation and asserts that len bytes of it are initialized * nit: change typo and improve internal naming * use Bytes in jit ops --- Cargo.lock | 1 + crates/burn-core/src/record/serde/data.rs | 12 +- crates/burn-core/src/record/serde/de.rs | 6 +- crates/burn-core/src/record/serde/ser.rs | 2 +- crates/burn-import/src/pytorch/reader.rs | 2 +- crates/burn-jit/src/backend.rs | 3 +- crates/burn-jit/src/ops/qtensor.rs | 8 +- crates/burn-jit/src/ops/transaction.rs | 8 +- crates/burn-tensor/Cargo.toml | 1 + crates/burn-tensor/src/tensor/bytes.rs | 547 ++++++++++++++++++ crates/burn-tensor/src/tensor/data.rs | 122 ++-- crates/burn-tensor/src/tensor/mod.rs | 2 + .../src/tensor/quantization/data.rs | 7 +- 13 files changed, 641 insertions(+), 80 deletions(-) create mode 100644 crates/burn-tensor/src/tensor/bytes.rs diff --git a/Cargo.lock b/Cargo.lock index 9eb57e9c70..0fa466bb87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -937,6 +937,7 @@ dependencies = [ name = "burn-tensor" version = "0.16.0" dependencies = [ + "bincode", "burn-common", "burn-tensor-testgen", "bytemuck", diff --git a/crates/burn-core/src/record/serde/data.rs b/crates/burn-core/src/record/serde/data.rs index 27a55c5b2a..b5e5442b54 100644 --- a/crates/burn-core/src/record/serde/data.rs +++ b/crates/burn-core/src/record/serde/data.rs @@ -8,6 +8,7 @@ use crate::record::{PrecisionSettings, Record}; use crate::tensor::backend::Backend; use alloc::fmt; +use burn_tensor::Bytes; use num_traits::cast::ToPrimitive; use regex::Regex; use serde::Deserialize; @@ -66,7 +67,11 @@ pub enum NestedValue { /// A vector of 32-bit floating point values. F32s(Vec), + + /// An opaque vector of bytes, with alignment. + Bytes(Bytes), } + impl NestedValue { /// Get the nested value as a map. pub fn as_map(self) -> Option> { @@ -184,9 +189,10 @@ impl NestedValue { } /// Get the nested value as a vector of bytes. - pub fn as_bytes(self) -> Option> { + pub fn as_bytes(self) -> Option { match self { - NestedValue::U8s(u) => Some(u), + NestedValue::Bytes(u) => Some(u), + NestedValue::U8s(u) => Some(Bytes::from_elems(u)), _ => None, } } @@ -368,6 +374,7 @@ impl fmt::Debug for NestedValue { NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f), + NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f), // Handle other variants as usual NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(), NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(), @@ -385,6 +392,7 @@ impl fmt::Debug for NestedValue { NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(), NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(), + NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(), } } } diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs index 3c93afed16..04bed6899f 100644 --- a/crates/burn-core/src/record/serde/de.rs +++ b/crates/burn-core/src/record/serde/de.rs @@ -233,7 +233,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer { where V: Visitor<'de>, { - visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap()) + let bytes = self.value.unwrap().as_bytes().unwrap(); + match bytes.try_into_vec::() { + Ok(bytes) => visitor.visit_byte_buf(bytes), + Err(bytes) => visitor.visit_bytes(&bytes), + } } fn deserialize_option(self, visitor: V) -> Result diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index f2803d8bdc..b0baaa5cd1 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -383,6 +383,6 @@ mod tests { .clone() .as_bytes() .expect("has bytes vec"); - assert_eq!(bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened()); + assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened()); } } diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs index 55a8d9838a..87a76278b8 100644 --- a/crates/burn-import/src/pytorch/reader.rs +++ b/crates/burn-import/src/pytorch/reader.rs @@ -152,7 +152,7 @@ where // Because serializer copies individual elements of TensorData `value` into a new Vec, // which is not necessary and inefficient. let mut tensor_data: HashMap = HashMap::new(); - tensor_data.insert("bytes".into(), NestedValue::U8s(bytes)); + tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes)); tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?); tensor_data.insert("dtype".into(), dtype.serialize(serializer)?); diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index 3f5637e412..e29a98ca96 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -7,8 +7,7 @@ use std::{marker::PhantomData, sync::Mutex}; #[cfg(not(feature = "fusion"))] use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - quantization::QuantizationScheme, - repr::{HandleKind, ReprBackend, TensorHandle}, + repr::{ReprBackend, TensorHandle}, }; pub(crate) static SEED: Mutex> = Mutex::new(None); diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index cde44b0552..feca09a41f 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -3,7 +3,7 @@ use std::ops::Range; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType}, - DType, Device, Shape, TensorData, + Bytes, DType, Device, Shape, TensorData, }; use crate::{ @@ -82,12 +82,8 @@ where let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; - // TODO: this should be refactored such that the bytes type is opaque. - // With this, the logic for handling the bytes representation of quantized data - // (as well as all data manipulations) will be encapsulated in the type. - // Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate. TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape: tensor.shape.into(), dtype: tensor.dtype, } diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs index b67740bc97..49bc1936ee 100644 --- a/crates/burn-jit/src/ops/transaction.rs +++ b/crates/burn-jit/src/ops/transaction.rs @@ -1,6 +1,6 @@ use burn_tensor::{ ops::{TransactionOps, TransactionPrimitiveResult}, - DType, TensorData, + Bytes, DType, TensorData, }; use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime}; @@ -74,7 +74,7 @@ where Kind::Float(index, shape, dtype) => { let bytes = data.get_mut(index).unwrap().take().unwrap(); result.read_floats.push(TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape, dtype, }); @@ -82,7 +82,7 @@ where Kind::Int(index, shape, dtype) => { let bytes = data.get_mut(index).unwrap().take().unwrap(); result.read_ints.push(TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape, dtype, }); @@ -90,7 +90,7 @@ where Kind::Bool(index, shape, dtype) => { let bytes = data.get_mut(index).unwrap().take().unwrap(); result.read_bools.push(TensorData { - bytes, + bytes: Bytes::from_bytes_vec(bytes), shape, dtype, }); diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 9c4dfa3a49..02683531c6 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -56,6 +56,7 @@ portable-atomic-util = { workspace = true } [dev-dependencies] rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std +bincode = { workspace = true } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-tensor/src/tensor/bytes.rs b/crates/burn-tensor/src/tensor/bytes.rs new file mode 100644 index 0000000000..f9cb238a26 --- /dev/null +++ b/crates/burn-tensor/src/tensor/bytes.rs @@ -0,0 +1,547 @@ +//! A version of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements. + +use alloc::alloc::{Layout, LayoutError}; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; + +use alloc::vec::Vec; + +/// Internally used to avoid accidentally leaking an allocation or using the wrong layout. +struct Allocation { + /// SAFETY: + /// - If `layout.size() > 0`, `ptr` points to a valid allocation from the global allocator + /// of the specified layout. The first `len` bytes are initialized. + /// - If `layout.size() == 0`, `ptr` is aligned to `layout.align()` and `len` is 0. + /// `ptr` is further suitable to be used as the argument for `Vec::from_raw_parts` see [buffer alloc] + /// for more details. + ptr: NonNull, + layout: Layout, +} + +/// A sort of `Box<[u8]>` that remembers the original alignment and can contain trailing uninitialized bytes. +pub struct Bytes { + alloc: Allocation, + // SAFETY: The first `len` bytes of the allocation are initialized + len: usize, +} + +/// The maximum supported alignment. The limit exists to not have to store alignment when serializing. Instead, +/// the bytes are always over-aligned when deserializing to MAX_ALIGN. +const MAX_ALIGN: usize = core::mem::align_of::(); + +fn debug_from_fn) -> core::fmt::Result>( + f: F, +) -> impl core::fmt::Debug { + // See also: std::fmt::from_fn + struct FromFn(F); + impl core::fmt::Debug for FromFn + where + F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result, + { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + (self.0)(f) + } + } + FromFn(f) +} + +impl core::fmt::Debug for Bytes { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let data = &**self; + let fmt_data = move |f: &mut core::fmt::Formatter<'_>| { + if data.len() > 3 { + // There is a nightly API `debug_more_non_exhaustive` which has `finish_non_exhaustive` + f.debug_list().entries(&data[0..3]).entry(&"...").finish() + } else { + f.debug_list().entries(data).finish() + } + }; + f.debug_struct("Bytes") + .field("data", &debug_from_fn(fmt_data)) + .field("len", &self.len) + .finish() + } +} + +impl serde::Serialize for Bytes { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serde_bytes::serialize(self.deref(), serializer) + } +} + +impl<'de> serde::Deserialize<'de> for Bytes { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[cold] + fn too_large(len: usize, align: usize) -> E { + // max_length = largest multiple of align that is <= isize::MAX + // align is a power of 2, hence a multiple has the lower bits unset. Mask them off to find the largest multiple + let max_length = (isize::MAX as usize) & !(align - 1); + E::custom(core::format_args!( + "length too large: {len}. Expected at most {max_length} bytes" + )) + } + + // TODO: we can possibly avoid one copy here by deserializing into an existing, correctly aligned, slice of bytes. + // We might not be able to predict the length of the data, hence it's far more convenient to let `Vec` handle the growth and re-allocations. + // Further, on a lot of systems, the allocator naturally aligns data to some reasonably large alignment, where no further copy is then + // necessary. + let data: Vec = serde_bytes::deserialize(deserializer)?; + // When deserializing, we over-align the data. This saves us from having to encode the alignment (which is platform-dependent in any case). + // If we had more context information here, we could enforce some (smaller) alignment per data type. But this information is only available + // in `TensorData`. Moreover it depends on the Deserializer there whether the datatype or data comes first. + let align = MAX_ALIGN; + let mut bytes = Self::from_elems(data); + bytes + .try_enforce_runtime_align(align) + .map_err(|_| too_large(bytes.len(), align))?; + Ok(bytes) + } +} + +impl Clone for Bytes { + fn clone(&self) -> Self { + // unwrap here: the layout is valid as it has the alignment & size of self + Self::try_from_data(MAX_ALIGN, self.deref()).unwrap() + } +} + +impl PartialEq for Bytes { + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl Eq for Bytes {} + +impl Allocation { + // Wrap the allocation of a vector without copying + fn from_vec(vec: Vec) -> Self { + let mut elems = core::mem::ManuallyDrop::new(vec); + // Set the length to 0, then all data is in the "spare capacity". + // SAFETY: Data is Copy, so in particular does not need to be dropped. In any case, try not to panic until + // we have taken ownership of the data! + unsafe { elems.set_len(0) }; + let data = elems.spare_capacity_mut(); + // We now have one contiguous slice of data to pass to Layout::for_value. + let layout = Layout::for_value(data); + // SAFETY: data is the allocation of a vec, hence can not be null. We use unchecked to avoid a panic-path. + let ptr = unsafe { NonNull::new_unchecked(elems.as_mut_ptr().cast()) }; + Self { ptr, layout } + } + // Create a new allocation with the specified layout + fn new(layout: Layout) -> Self { + let ptr = buffer_alloc(layout); + Self { ptr, layout } + } + // Reallocate to fit at least the size and align of min_layout + fn grow(&mut self, min_layout: Layout) { + (self.layout, self.ptr) = buffer_grow(self.layout, self.ptr, min_layout); + } + // Returns a mutable view of the memory of the whole allocation + fn memory_mut(&mut self) -> &mut [MaybeUninit] { + // SAFETY: See type invariants + unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr().cast(), self.layout.size()) } + } + // Return a pointer to the underlying allocation. This pointer is valid for reads and writes until the allocation is dropped or reallocated. + fn as_mut_ptr(&self) -> *mut u8 { + self.ptr.as_ptr() + } + // Try to convert the allocation to a Vec. The Vec has a length of 0 when returned, but correct capacity and pointer! + fn try_into_vec(self) -> Result, Self> { + let byte_capacity = self.layout.size(); + let Some(capacity) = byte_capacity.checked_div(size_of::()) else { + return Err(self); + }; + if capacity * size_of::() != byte_capacity { + return Err(self); + }; + if self.layout.align() != align_of::() { + return Err(self); + } + // Okay, let's commit + let ptr = self.ptr.as_ptr().cast(); + core::mem::forget(self); + // SAFETY: + // - ptr was allocated by the global allocator as per type-invariant + // - `E` has the same alignment as indicated by the stored layout. + // - capacity * size_of:: == layout.size() + // - 0 <= capacity + // - no bytes are claimed to be initialized + // - the layout represents a valid allocation, hence has allocation size less than isize::MAX + Ok(unsafe { Vec::from_raw_parts(ptr, 0, capacity) }) + } +} + +impl Drop for Allocation { + fn drop(&mut self) { + buffer_dealloc(self.layout, self.ptr); + } +} + +// Allocate a pointer that can be passed to Vec::from_raw_parts +fn buffer_alloc(layout: Layout) -> NonNull { + // [buffer alloc]: The current docs of Vec::from_raw_parts(ptr, ...) say: + // > ptr must have been allocated using the global allocator + // Yet, an empty Vec is guaranteed to not allocate (it is even illegal! to allocate with a zero-sized layout) + // Hence, we slightly re-interpret the above to only needing to hold if `capacity > 0`. Still, the pointer + // must be non-zero. So in case we need a pointer for an empty vec, use a correctly aligned, dangling one. + if layout.size() == 0 { + // we would use NonNull:dangling() but we don't have a concrete type for the requested alignment + let ptr = core::ptr::null_mut::().wrapping_add(layout.align()); + // SAFETY: layout.align() is never 0 + unsafe { NonNull::new_unchecked(ptr) } + } else { + // SAFETY: layout has non-zero size. + let ptr = unsafe { alloc::alloc::alloc(layout) }; + NonNull::new(ptr).unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout)) + } +} + +fn expect_dangling(align: usize, buffer: NonNull) { + debug_assert!( + buffer.as_ptr().wrapping_sub(align).is_null(), + "expected a nullptr for size 0" + ); +} + +#[cold] +fn alloc_overflow() -> ! { + panic!("Overflow, too many elements") +} + +// Grow the buffer while keeping alignment +fn buffer_grow( + old_layout: Layout, + buffer: NonNull, + min_layout: Layout, +) -> (Layout, NonNull) { + let new_align = min_layout.align().max(old_layout.align()); // Don't let data become less aligned + let new_size = min_layout.size().next_multiple_of(new_align); + if new_size > isize::MAX as usize { + alloc_overflow(); + } + + assert!(new_size > old_layout.size(), "size must actually grow"); + if old_layout.size() == 0 { + expect_dangling(old_layout.align(), buffer); + let new_layout = Layout::from_size_align(new_size, new_align).unwrap(); + let buffer = buffer_alloc(new_layout); + return (new_layout, buffer); + }; + let realloc = || { + let new_layout = Layout::from_size_align(new_size, old_layout.align()).unwrap(); + // SAFETY: + // - buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`]. + // - old_layout is the same as with which the pointer was allocated + // - new_size is not 0, since it is larger than old_layout.size() which is non-zero + // - size constitutes a valid layout + let ptr = unsafe { alloc::alloc::realloc(buffer.as_ptr(), old_layout, new_layout.size()) }; + (new_layout, ptr) + }; + if new_align == old_layout.align() { + // happy path. We can just realloc. + let (new_layout, ptr) = realloc(); + let buffer = NonNull::new(ptr); + let buffer = buffer.unwrap_or_else(|| alloc::alloc::handle_alloc_error(new_layout)); + return (new_layout, buffer); + } + // [buffer grow]: alloc::realloc can *not* change the alignment of the allocation's layout. + // The unstable Allocator::{grow,shrink} API changes this, but might take a while to make it + // into alloc::GlobalAlloc. + // + // As such, we can not request a specific alignment. But most allocators will give us the required + // alignment "for free". Hence, we speculatively avoid a mem-copy by using realloc. + // + // If in the future requesting an alignment change for an existing is available, this can be removed. + #[cfg(target_has_atomic = "8")] + mod alignment_assumption { + use core::sync::atomic::{AtomicBool, Ordering}; + static SPECULATE: AtomicBool = AtomicBool::new(true); + pub fn speculate() -> bool { + // We load and store with relaxed order, since worst case this leads to a few more memcopies + SPECULATE.load(Ordering::Relaxed) + } + pub fn report_violation() { + SPECULATE.store(false, Ordering::Relaxed) + } + } + #[cfg(not(target_has_atomic = "8"))] + mod alignment_assumption { + // On these platforms we don't speculate, and take the hit of performance + pub fn speculate() -> bool { + false + } + pub fn report_violation() {} + } + // reminder: old_layout.align() < new_align + let mut old_buffer = buffer; + let mut old_layout = old_layout; + if alignment_assumption::speculate() { + let (realloc_layout, ptr) = realloc(); + if let Some(buffer) = NonNull::new(ptr) { + if buffer.align_offset(new_align) == 0 { + return (realloc_layout, buffer); + } + // Speculating hasn't succeeded, but access now has to go through the reallocated buffer + alignment_assumption::report_violation(); + old_buffer = buffer; + old_layout = realloc_layout; + } else { + // If realloc fails, the later alloc will likely too, but don't report this yet + } + } + // realloc but change alignment. This requires a mem copy as pointed out above + let new_layout = Layout::from_size_align(new_size, new_align).unwrap(); + let new_buffer = buffer_alloc(new_layout); + // SAFETY: two different memory allocations, and old buffer's size is smaller than new_size + unsafe { + core::ptr::copy_nonoverlapping(old_buffer.as_ptr(), new_buffer.as_ptr(), old_layout.size()); + } + buffer_dealloc(old_layout, old_buffer); + (new_layout, new_buffer) +} + +// Deallocate a buffer of a Vec +fn buffer_dealloc(layout: Layout, buffer: NonNull) { + if layout.size() != 0 { + // SAFETY: buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`]. + // The layout is the same as per type-invariants + unsafe { + alloc::alloc::dealloc(buffer.as_ptr(), layout); + } + } else { + // An empty Vec does not allocate, hence nothing to dealloc + expect_dangling(layout.align(), buffer); + } +} + +impl Bytes { + /// Copy an existing slice of data into Bytes that are aligned to `align` + fn try_from_data(align: usize, data: &[u8]) -> Result { + let len = data.len(); + let layout = Layout::from_size_align(len, align)?; + let alloc = Allocation::new(layout); + unsafe { + // SAFETY: + // - data and alloc are distinct allocations of `len` bytes + core::ptr::copy_nonoverlapping::(data.as_ref().as_ptr(), alloc.as_mut_ptr(), len); + }; + Ok(Self { alloc, len }) + } + + /// Ensure the contained buffer is aligned to `align` by possibly moving it to a new buffer. + fn try_enforce_runtime_align(&mut self, align: usize) -> Result<(), LayoutError> { + if self.as_mut_ptr().align_offset(align) == 0 { + // data is already aligned correctly + return Ok(()); + } + *self = Self::try_from_data(align, self)?; + Ok(()) + } + + /// Create a sequence of [Bytes] from the memory representation of an unknown type of elements. + /// Prefer this over [Self::from_elems] when the datatype is not statically known and erased at runtime. + pub fn from_bytes_vec(bytes: Vec) -> Self { + let mut bytes = Self::from_elems(bytes); + // TODO: this method could be datatype aware and enforce a less strict alignment. + // On most platforms, this alignment check is fulfilled either way though, so + // the benefits of potentially saving a memcopy are negligible. + bytes.try_enforce_runtime_align(MAX_ALIGN).unwrap(); + bytes + } + + /// Erase the element type of a vector by converting into a sequence of [Bytes]. + /// + /// In case the element type is not statically known at runtime, prefer to use [Self::from_bytes_vec]. + pub fn from_elems(elems: Vec) -> Self + where + // NoUninit implies Copy + E: bytemuck::NoUninit + Send + Sync, + { + let _: () = const { + assert!( + core::mem::align_of::() <= MAX_ALIGN, + "element type not supported due to too large alignment" + ); + }; + // Note: going through a Box as in Vec::into_boxed_slice would re-allocate on excess capacity. Avoid that. + let byte_len = elems.len() * core::mem::size_of::(); + let alloc = Allocation::from_vec(elems); + Self { + alloc, + len: byte_len, + } + } + + fn reserve(&mut self, additional: usize) { + let needs_to_grow = additional > self.capacity().wrapping_sub(self.len()); + if !needs_to_grow { + return; + } + let Some(required_cap) = self.len().checked_add(additional) else { + alloc_overflow() + }; + // guarantee exponential growth for amortization + let new_cap = required_cap.max(self.capacity() * 2); + let new_cap = new_cap.max(MAX_ALIGN); // Small allocations would be pointless + let Ok(new_layout) = Layout::from_size_align(new_cap, MAX_ALIGN) else { + alloc_overflow() + }; + self.alloc.grow(new_layout); + } + + /// Extend the byte buffer from a slice of bytes + pub fn extend_from_byte_slice(&mut self, bytes: &[u8]) { + let additional = bytes.len(); + self.reserve(additional); + let len = self.len(); + let new_cap = len.wrapping_add(additional); // Can not overflow, as we've just successfully reserved sufficient space for it + let uninit_spare = &mut self.alloc.memory_mut()[len..new_cap]; + // SAFETY: reinterpreting the slice as a MaybeUninit. + // See also #![feature(maybe_uninit_write_slice)], which would replace this with safe code + uninit_spare.copy_from_slice(unsafe { + core::slice::from_raw_parts(bytes.as_ptr().cast(), additional) + }); + self.len = new_cap; + } + + /// Get the total capacity, in bytes, of the wrapped allocation. + pub fn capacity(&self) -> usize { + self.alloc.layout.size() + } + + /// Convert the bytes back into a vector. This requires that the type has the same alignment as the element + /// type this [Bytes] was initialized with. + /// This only returns with Ok(_) if the conversion can be done without a memcopy + pub fn try_into_vec( + mut self, + ) -> Result, Self> { + // See if the length is compatible + let Ok(data) = bytemuck::checked::try_cast_slice_mut::<_, E>(&mut self) else { + return Err(self); + }; + let length = data.len(); + // If so, try to convert the allocation to a vec + let mut vec = match self.alloc.try_into_vec::() { + Ok(vec) => vec, + Err(alloc) => { + self.alloc = alloc; + return Err(self); + } + }; + // SAFETY: We computed this length from the bytemuck-ed slice into this allocation + unsafe { + vec.set_len(length); + }; + Ok(vec) + } +} + +impl Deref for Bytes { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + // SAFETY: see type invariants + unsafe { core::slice::from_raw_parts(self.alloc.as_mut_ptr(), self.len) } + } +} + +impl DerefMut for Bytes { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: see type invariants + unsafe { core::slice::from_raw_parts_mut(self.alloc.as_mut_ptr(), self.len) } + } +} + +// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Send +unsafe impl Send for Bytes {} +// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Sync +unsafe impl Sync for Bytes {} + +#[cfg(test)] +mod tests { + use super::Bytes; + use alloc::{vec, vec::Vec}; + + const _CONST_ASSERTS: fn() = || { + fn test_send() {} + fn test_sync() {} + test_send::(); + test_sync::(); + }; + + fn test_serialization_roundtrip(bytes: &Bytes) { + let config = bincode::config::standard(); + let serialized = + bincode::serde::encode_to_vec(bytes, config).expect("serialization to succeed"); + let (roundtripped, _) = bincode::serde::decode_from_slice(&serialized, config) + .expect("deserialization to succeed"); + assert_eq!( + bytes, &roundtripped, + "roundtripping through serialization didn't lead to equal Bytes" + ); + } + + #[test] + fn test_serialization() { + test_serialization_roundtrip(&Bytes::from_elems::(vec![])); + test_serialization_roundtrip(&Bytes::from_elems(vec![0xdead, 0xbeaf])); + } + + #[test] + fn test_into_vec() { + // We test an edge case here, where the capacity (but not actual size) makes it impossible to convert to a vec + let mut bytes = Vec::with_capacity(6); + let actual_cap = bytes.capacity(); + bytes.extend_from_slice(&[0, 1, 2, 3]); + let mut bytes = Bytes::from_elems::(bytes); + + bytes = bytes + .try_into_vec::<[u8; 0]>() + .expect_err("Conversion should not succeed for a zero-sized type"); + if actual_cap % 4 != 0 { + // We most likely get actual_cap == 6, we can't force Vec to actually do that. Code coverage should complain if the actual test misses this + bytes = bytes.try_into_vec::<[u8; 4]>().err().unwrap_or_else(|| { + panic!("Conversion should not succeed due to capacity {actual_cap} not fitting a whole number of elements"); + }); + } + bytes = bytes + .try_into_vec::() + .expect_err("Conversion should not succeed due to mismatched alignment"); + bytes = bytes.try_into_vec::<[u8; 3]>().expect_err( + "Conversion should not succeed due to size not fitting a whole number of elements", + ); + let bytes = bytes.try_into_vec::<[u8; 2]>().expect("Conversion should succeed for bit-convertible types of equal alignment and compatible size"); + assert_eq!(bytes, &[[0, 1], [2, 3]]); + } + + #[test] + fn test_grow() { + let mut bytes = Bytes::from_elems::(vec![]); + bytes.extend_from_byte_slice(&[0, 1, 2, 3]); + assert_eq!(bytes[..], [0, 1, 2, 3][..]); + + let mut bytes = Bytes::from_elems(vec![42u8; 4]); + bytes.extend_from_byte_slice(&[0, 1, 2, 3]); + assert_eq!(bytes[..], [42, 42, 42, 42, 0, 1, 2, 3][..]); + } + + #[test] + fn test_large_elems() { + let mut bytes = Bytes::from_elems(vec![42u128]); + const TEST_BYTES: [u8; 16] = [ + 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, + 0x34, 0x12, + ]; + bytes.extend_from_byte_slice(&TEST_BYTES); + let vec = bytes.try_into_vec::().unwrap(); + assert_eq!(vec, [42u128, u128::from_ne_bytes(TEST_BYTES)]); + } +} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index d52a181929..c65572a068 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -13,7 +13,7 @@ use half::{bf16, f16}; use crate::{ quantization::{AffineQuantization, Quantization, QuantizationStrategy}, - tensor::Shape, + tensor::{bytes::Bytes, Shape}, DType, Distribution, Element, ElementConversion, }; @@ -43,8 +43,7 @@ pub enum DataError { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct TensorData { /// The values of the tensor (as bytes). - #[serde(with = "serde_bytes")] - pub bytes: Vec, + pub bytes: Bytes, /// The shape of the tensor. pub shape: Vec, @@ -53,33 +52,12 @@ pub struct TensorData { pub dtype: DType, } -fn into_bytes(mut value: Vec) -> Vec { - // Ensure `E` satisfies the `Pod` trait requirements - assert_eq!(core::mem::size_of::() % core::mem::size_of::(), 0); - - let factor = core::mem::size_of::() / core::mem::size_of::(); - let len = value.len() * factor; - let capacity = value.capacity() * factor; - let ptr = value.as_mut_ptr(); - - core::mem::forget(value); - - unsafe { Vec::from_raw_parts(ptr as *mut u8, len, capacity) } -} - impl TensorData { /// Creates a new tensor data structure. - pub fn new>>(mut value: Vec, shape: S) -> Self { + pub fn new>>(value: Vec, shape: S) -> Self { // Ensure shape is valid let shape = shape.into(); - let shape_numel = Self::numel(&shape); - value.truncate(shape_numel); - let numel = value.len(); - assert_eq!( - shape_numel, numel, - "Shape {:?} is invalid for input of size {:?}", - shape, numel, - ); + Self::check_data_len(&value, &shape, None); Self::init(value, shape, E::dtype()) } @@ -93,10 +71,10 @@ impl TensorData { shape: S, strategy: QuantizationStrategy, ) -> Self { - // TODO: this method should go into a dedicated Bytes opaque type with other bytes - // handling logic - let mut value = into_bytes(value); + let shape = shape.into(); + Self::check_data_len(&value, &shape, Some(&strategy)); + let mut bytes: Bytes; // Notes on quantization data representation: // 1) The quantized values are packed into 32-bit unsigned integers. For example, int8 // quantized values pack 4 grouped values into a single `u32`. When unpacking these values, @@ -107,9 +85,9 @@ impl TensorData { match strategy { QuantizationStrategy::PerTensorAffineInt8(q) => { if TypeId::of::() == TypeId::of::() { - value = bytemuck::checked::cast_slice(&value).to_vec(); // already packed values - } else if TypeId::of::() == TypeId::of::() { - value = bytemuck::checked::cast_slice(&pack_i8s_to_u32s(&value)).to_vec(); + bytes = Bytes::from_elems(value); // already packed values + } else if let Some(value) = ::downcast_ref::>(&value) { + bytes = Bytes::from_elems(pack_i8s_to_u32s(value)); } else { panic!("Invalid quantized type"); } @@ -117,31 +95,62 @@ impl TensorData { let offset = q.offset as i32; let scale_bytes = bytemuck::bytes_of(&q.scale); let offset_bytes = bytemuck::bytes_of(&offset); - value.extend_from_slice(offset_bytes); - value.extend_from_slice(scale_bytes); + bytes.extend_from_byte_slice(offset_bytes); + bytes.extend_from_byte_slice(scale_bytes); } QuantizationStrategy::PerTensorSymmetricInt8(q) => { if TypeId::of::() == TypeId::of::() { - value = bytemuck::checked::cast_slice(&value).to_vec(); // already packed values - } else if TypeId::of::() == TypeId::of::() { - let packed = pack_i8s_to_u32s(&value); - value = bytemuck::checked::cast_slice(&packed).to_vec(); + bytes = Bytes::from_elems(value); // already packed values + } else if let Some(value) = ::downcast_ref::>(&value) { + bytes = Bytes::from_elems(pack_i8s_to_u32s(value)); } else { panic!("Invalid quantized type"); } let scale_bytes = bytemuck::bytes_of(&q.scale); - value.extend_from_slice(scale_bytes); + bytes.extend_from_byte_slice(scale_bytes); } } - Self::init(value, shape, DType::QFloat(strategy.scheme())) + Self { + bytes, + shape, + dtype: DType::QFloat(strategy.scheme()), + } + } + + // Check that the input vector contains a correct number of elements + fn check_data_len( + data: &[E], + shape: &Vec, + quantization: Option<&QuantizationStrategy>, + ) { + let mut expected_data_len = Self::numel(shape); + if let Some(quantization) = quantization { + let elem_per_data = match quantization { + QuantizationStrategy::PerTensorAffineInt8(_) + | QuantizationStrategy::PerTensorSymmetricInt8(_) => { + if TypeId::of::() == TypeId::of::() { + 4 + } else { + 1 + } + } + }; + expected_data_len = expected_data_len.div_ceil(elem_per_data); + } + let num_data = data.len(); + assert_eq!( + expected_data_len, num_data, + "Shape {:?} is invalid for input of size {:?}", + shape, num_data, + ); } /// Initializes a new tensor data structure from the provided values. - fn init>>(value: Vec, shape: S, dtype: DType) -> Self { + fn init(value: Vec, shape: Vec, dtype: DType) -> Self { Self { - bytes: into_bytes(value), - shape: shape.into(), + bytes: Bytes::from_elems(value), + shape, dtype, } } @@ -185,7 +194,7 @@ impl TensorData { } /// Returns the tensor data as a vector of scalar values. - pub fn into_vec(mut self) -> Result, DataError> { + pub fn into_vec(self) -> Result, DataError> { if E::dtype() != self.dtype { return Err(DataError::TypeMismatch(format!( "Invalid target element type (expected {:?}, got {:?})", @@ -194,19 +203,16 @@ impl TensorData { ))); } - let capacity_bytes = self.bytes.capacity(); - let length_bytes = self.bytes.len(); - let size_elem = core::mem::size_of::(); - - let capacity = capacity_bytes / size_elem; - let length = length_bytes / size_elem; - - unsafe { - let ptr = self.bytes.as_mut_ptr(); - core::mem::forget(self.bytes); - - Ok(Vec::from_raw_parts(ptr.cast::(), length, capacity)) - } + let mut me = self; + me.bytes = match me.bytes.try_into_vec::() { + Ok(elems) => return Ok(elems), + Err(bytes) => bytes, + }; + // The bytes might have been deserialized and allocated with a different align. + // In that case, we have to memcopy the data into a new vector, more suitably allocated + Ok(bytemuck::checked::try_cast_slice(me.values_as_bytes()) + .map_err(DataError::CastError)? + .to_vec()) } /// Returns an iterator over the values of the tensor data. @@ -405,7 +411,7 @@ impl TensorData { /// Returns the data as a slice of bytes. pub fn as_bytes(&self) -> &[u8] { - self.bytes.as_slice() + &self.bytes } /// Applies the data quantization strategy. diff --git a/crates/burn-tensor/src/tensor/mod.rs b/crates/burn-tensor/src/tensor/mod.rs index feed9571c4..40fa4b0f2c 100644 --- a/crates/burn-tensor/src/tensor/mod.rs +++ b/crates/burn-tensor/src/tensor/mod.rs @@ -1,12 +1,14 @@ pub(crate) mod stats; mod api; +mod bytes; mod data; mod distribution; mod element; mod shape; pub use api::*; +pub use bytes::*; pub use data::*; pub use distribution::*; pub use element::*; diff --git a/crates/burn-tensor/src/tensor/quantization/data.rs b/crates/burn-tensor/src/tensor/quantization/data.rs index 96096b784b..7f833edc58 100644 --- a/crates/burn-tensor/src/tensor/quantization/data.rs +++ b/crates/burn-tensor/src/tensor/quantization/data.rs @@ -1,10 +1,7 @@ use alloc::vec::Vec; /// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers. -/// -/// # Note -/// This assumes that the bytes represent `i8` values. -pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec { +pub fn pack_i8s_to_u32s(bytes: &[i8]) -> Vec { // Shift and combine groups of four 8-bit values into a u32. // Same as doing this: // let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF); @@ -12,7 +9,7 @@ pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec { .chunks(4) .map(|x| { x.iter().enumerate().fold(0u32, |acc, (i, x)| { - acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8) + acc | (*x as u32 & 0xFF) << ((3 - i) * 8) }) }) .collect()