diff --git a/Cargo.lock b/Cargo.lock
index 9eb57e9c70..0fa466bb87 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -937,6 +937,7 @@ dependencies = [
 name = "burn-tensor"
 version = "0.16.0"
 dependencies = [
+ "bincode",
 "burn-common",
 "burn-tensor-testgen",
 "bytemuck",
diff --git a/crates/burn-core/src/record/serde/data.rs b/crates/burn-core/src/record/serde/data.rs
index 27a55c5b2a..b5e5442b54 100644
--- a/crates/burn-core/src/record/serde/data.rs
+++ b/crates/burn-core/src/record/serde/data.rs
@@ -8,6 +8,7 @@ use crate::record::{PrecisionSettings, Record};
 use crate::tensor::backend::Backend;
 
 use alloc::fmt;
+use burn_tensor::Bytes;
 use num_traits::cast::ToPrimitive;
 use regex::Regex;
 use serde::Deserialize;
@@ -66,7 +67,11 @@ pub enum NestedValue {
     /// A vector of 32-bit floating point values.
     F32s(Vec<f32>),
+
+    /// An opaque vector of bytes, with alignment.
+    Bytes(Bytes),
 }
+
 impl NestedValue {
     /// Get the nested value as a map.
     pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
@@ -184,9 +189,10 @@ impl NestedValue {
     }
 
     /// Get the nested value as a vector of bytes.
-    pub fn as_bytes(self) -> Option<Vec<u8>> {
+    pub fn as_bytes(self) -> Option<Bytes> {
         match self {
-            NestedValue::U8s(u) => Some(u),
+            NestedValue::Bytes(u) => Some(u),
+            NestedValue::U8s(u) => Some(Bytes::from_elems(u)),
             _ => None,
         }
     }
@@ -368,6 +374,7 @@ impl fmt::Debug for NestedValue {
             NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
             NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
             NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
+            NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f),
             // Handle other variants as usual
             NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
             NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(),
@@ -385,6 +392,7 @@ impl fmt::Debug for NestedValue {
             NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(),
             NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(),
             NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(),
+            NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(),
         }
     }
 }
diff --git a/crates/burn-core/src/record/serde/de.rs b/crates/burn-core/src/record/serde/de.rs
index 3c93afed16..04bed6899f 100644
--- a/crates/burn-core/src/record/serde/de.rs
+++ b/crates/burn-core/src/record/serde/de.rs
@@ -233,7 +233,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
     where
         V: Visitor<'de>,
     {
-        visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap())
+        let bytes = self.value.unwrap().as_bytes().unwrap();
+        match bytes.try_into_vec::<u8>() {
+            Ok(bytes) => visitor.visit_byte_buf(bytes),
+            Err(bytes) => visitor.visit_bytes(&bytes),
+        }
     }
 
     fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs
index f2803d8bdc..b0baaa5cd1 100644
--- a/crates/burn-core/src/record/serde/ser.rs
+++ b/crates/burn-core/src/record/serde/ser.rs
@@ -383,6 +383,6 @@ mod tests {
             .clone()
             .as_bytes()
             .expect("has bytes vec");
-        assert_eq!(bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
+        assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
     }
 }
diff --git a/crates/burn-import/src/pytorch/reader.rs b/crates/burn-import/src/pytorch/reader.rs
index 55a8d9838a..87a76278b8 100644
--- a/crates/burn-import/src/pytorch/reader.rs
+++ b/crates/burn-import/src/pytorch/reader.rs
@@ -152,7 +152,7 @@ where
     // Because serializer copies individual elements of TensorData `value` into a new Vec,
     // which is not necessary and inefficient.
     let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
-    tensor_data.insert("bytes".into(), NestedValue::U8s(bytes));
+    tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes));
     tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
     tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);
 
diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs
index 3f5637e412..e29a98ca96 100644
--- a/crates/burn-jit/src/backend.rs
+++ b/crates/burn-jit/src/backend.rs
@@ -7,8 +7,7 @@ use std::{marker::PhantomData, sync::Mutex};
 #[cfg(not(feature = "fusion"))]
 use burn_tensor::{
     ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
-    quantization::QuantizationScheme,
-    repr::{HandleKind, ReprBackend, TensorHandle},
+    repr::{ReprBackend, TensorHandle},
 };
 
 pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs
index cde44b0552..feca09a41f 100644
--- a/crates/burn-jit/src/ops/qtensor.rs
+++ b/crates/burn-jit/src/ops/qtensor.rs
@@ -3,7 +3,7 @@ use std::ops::Range;
 use burn_tensor::{
     ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
     quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
-    DType, Device, Shape, TensorData,
+    Bytes, DType, Device, Shape, TensorData,
 };
 
 use crate::{
@@ -82,12 +82,8 @@ where
         let tensor = kernel::into_contiguous(tensor);
         let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
 
-        // TODO: this should be refactored such that the bytes type is opaque.
-        // With this, the logic for handling the bytes representation of quantized data
-        // (as well as all data manipulations) will be encapsulated in the type.
-        // Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate.
         TensorData {
-            bytes,
+            bytes: Bytes::from_bytes_vec(bytes),
             shape: tensor.shape.into(),
             dtype: tensor.dtype,
         }
diff --git a/crates/burn-jit/src/ops/transaction.rs b/crates/burn-jit/src/ops/transaction.rs
index b67740bc97..49bc1936ee 100644
--- a/crates/burn-jit/src/ops/transaction.rs
+++ b/crates/burn-jit/src/ops/transaction.rs
@@ -1,6 +1,6 @@
 use burn_tensor::{
     ops::{TransactionOps, TransactionPrimitiveResult},
-    DType, TensorData,
+    Bytes, DType, TensorData,
 };
 
 use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
@@ -74,7 +74,7 @@ where
                 Kind::Float(index, shape, dtype) => {
                     let bytes = data.get_mut(index).unwrap().take().unwrap();
                     result.read_floats.push(TensorData {
-                        bytes,
+                        bytes: Bytes::from_bytes_vec(bytes),
                         shape,
                         dtype,
                     });
@@ -82,7 +82,7 @@ where
                 Kind::Int(index, shape, dtype) => {
                     let bytes = data.get_mut(index).unwrap().take().unwrap();
                     result.read_ints.push(TensorData {
-                        bytes,
+                        bytes: Bytes::from_bytes_vec(bytes),
                         shape,
                         dtype,
                     });
@@ -90,7 +90,7 @@ where
                 Kind::Bool(index, shape, dtype) => {
                     let bytes = data.get_mut(index).unwrap().take().unwrap();
                     result.read_bools.push(TensorData {
-                        bytes,
+                        bytes: Bytes::from_bytes_vec(bytes),
                         shape,
                         dtype,
                     });
diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml
index 9c4dfa3a49..02683531c6 100644
--- a/crates/burn-tensor/Cargo.toml
+++ b/crates/burn-tensor/Cargo.toml
@@ -56,6 +56,7 @@ portable-atomic-util = { workspace = true }
 
 [dev-dependencies]
 rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
+bincode = { workspace = true }
 
 [package.metadata.docs.rs]
 features = ["doc"]
diff --git a/crates/burn-tensor/src/tensor/bytes.rs b/crates/burn-tensor/src/tensor/bytes.rs
new file mode 100644
index 0000000000..f9cb238a26
--- /dev/null
+++ b/crates/burn-tensor/src/tensor/bytes.rs
@@ -0,0 +1,547 @@
+//! A version of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements.
+
+use alloc::alloc::{Layout, LayoutError};
+use core::mem::MaybeUninit;
+use core::ops::{Deref, DerefMut};
+use core::ptr::NonNull;
+
+use alloc::vec::Vec;
+
+/// Internally used to avoid accidentally leaking an allocation or using the wrong layout.
+struct Allocation {
+    /// SAFETY:
+    /// - If `layout.size() > 0`, `ptr` points to a valid allocation from the global allocator
+    ///   of the specified layout. The first `len` bytes are initialized.
+    /// - If `layout.size() == 0`, `ptr` is aligned to `layout.align()` and `len` is 0.
+    /// `ptr` is further suitable to be used as the argument for `Vec::from_raw_parts`, see [buffer alloc]
+    /// for more details.
+    ptr: NonNull<u8>,
+    layout: Layout,
+}
+
+/// A sort of `Box<[u8]>` that remembers the original alignment and can contain trailing uninitialized bytes.
+pub struct Bytes {
+    alloc: Allocation,
+    // SAFETY: The first `len` bytes of the allocation are initialized
+    len: usize,
+}
+
+/// The maximum supported alignment. The limit exists to not have to store alignment when serializing. Instead,
+/// the bytes are always over-aligned when deserializing to MAX_ALIGN.
+const MAX_ALIGN: usize = core::mem::align_of::<u128>();
+
+fn debug_from_fn<F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result>(
+    f: F,
+) -> impl core::fmt::Debug {
+    // See also: std::fmt::from_fn
+    struct FromFn<F>(F);
+    impl<F> core::fmt::Debug for FromFn<F>
+    where
+        F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result,
+    {
+        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+            (self.0)(f)
+        }
+    }
+    FromFn(f)
+}
+
+impl core::fmt::Debug for Bytes {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        let data = &**self;
+        let fmt_data = move |f: &mut core::fmt::Formatter<'_>| {
+            if data.len() > 3 {
+                // There is a nightly API `debug_more_non_exhaustive` which has `finish_non_exhaustive`
+                f.debug_list().entries(&data[0..3]).entry(&"...").finish()
+            } else {
+                f.debug_list().entries(data).finish()
+            }
+        };
+        f.debug_struct("Bytes")
+            .field("data", &debug_from_fn(fmt_data))
+            .field("len", &self.len)
+            .finish()
+    }
+}
+
+impl serde::Serialize for Bytes {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        serde_bytes::serialize(self.deref(), serializer)
+    }
+}
+
+impl<'de> serde::Deserialize<'de> for Bytes {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        #[cold]
+        fn too_large<E: serde::de::Error>(len: usize, align: usize) -> E {
+            // max_length = largest multiple of align that is <= isize::MAX
+            // align is a power of 2, hence a multiple has the lower bits unset. Mask them off to find the largest multiple
+            let max_length = (isize::MAX as usize) & !(align - 1);
+            E::custom(core::format_args!(
+                "length too large: {len}. Expected at most {max_length} bytes"
+            ))
+        }
+
+        // TODO: we can possibly avoid one copy here by deserializing into an existing, correctly aligned, slice of bytes.
+        // We might not be able to predict the length of the data, hence it's far more convenient to let `Vec` handle the growth and re-allocations.
+        // Further, on a lot of systems, the allocator naturally aligns data to some reasonably large alignment, where no further copy is then
+        // necessary.
+        let data: Vec<u8> = serde_bytes::deserialize(deserializer)?;
+        // When deserializing, we over-align the data. This saves us from having to encode the alignment (which is platform-dependent in any case).
+        // If we had more context information here, we could enforce some (smaller) alignment per data type. But this information is only available
+        // in `TensorData`. Moreover it depends on the Deserializer there whether the datatype or data comes first.
+        let align = MAX_ALIGN;
+        let mut bytes = Self::from_elems(data);
+        bytes
+            .try_enforce_runtime_align(align)
+            .map_err(|_| too_large(bytes.len(), align))?;
+        Ok(bytes)
+    }
+}
+
+impl Clone for Bytes {
+    fn clone(&self) -> Self {
+        // unwrap here: the layout is valid as it has the alignment & size of self
+        Self::try_from_data(MAX_ALIGN, self.deref()).unwrap()
+    }
+}
+
+impl PartialEq for Bytes {
+    fn eq(&self, other: &Self) -> bool {
+        self.deref() == other.deref()
+    }
+}
+
+impl Eq for Bytes {}
+
+impl Allocation {
+    // Wrap the allocation of a vector without copying
+    fn from_vec<E: Copy>(vec: Vec<E>) -> Self {
+        let mut elems = core::mem::ManuallyDrop::new(vec);
+        // Set the length to 0, then all data is in the "spare capacity".
+        // SAFETY: Data is Copy, so in particular does not need to be dropped. In any case, try not to panic until
+        // we have taken ownership of the data!
+        unsafe { elems.set_len(0) };
+        let data = elems.spare_capacity_mut();
+        // We now have one contiguous slice of data to pass to Layout::for_value.
+        let layout = Layout::for_value(data);
+        // SAFETY: data is the allocation of a vec, hence can not be null. We use unchecked to avoid a panic-path.
+        let ptr = unsafe { NonNull::new_unchecked(elems.as_mut_ptr().cast()) };
+        Self { ptr, layout }
+    }
+    // Create a new allocation with the specified layout
+    fn new(layout: Layout) -> Self {
+        let ptr = buffer_alloc(layout);
+        Self { ptr, layout }
+    }
+    // Reallocate to fit at least the size and align of min_layout
+    fn grow(&mut self, min_layout: Layout) {
+        (self.layout, self.ptr) = buffer_grow(self.layout, self.ptr, min_layout);
+    }
+    // Returns a mutable view of the memory of the whole allocation
+    fn memory_mut(&mut self) -> &mut [MaybeUninit<u8>] {
+        // SAFETY: See type invariants
+        unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr().cast(), self.layout.size()) }
+    }
+    // Return a pointer to the underlying allocation. This pointer is valid for reads and writes until the allocation is dropped or reallocated.
+    fn as_mut_ptr(&self) -> *mut u8 {
+        self.ptr.as_ptr()
+    }
+    // Try to convert the allocation to a Vec. The Vec has a length of 0 when returned, but correct capacity and pointer!
+    fn try_into_vec<E>(self) -> Result<Vec<E>, Self> {
+        let byte_capacity = self.layout.size();
+        let Some(capacity) = byte_capacity.checked_div(size_of::<E>()) else {
+            return Err(self);
+        };
+        if capacity * size_of::<E>() != byte_capacity {
+            return Err(self);
+        };
+        if self.layout.align() != align_of::<E>() {
+            return Err(self);
+        }
+        // Okay, let's commit
+        let ptr = self.ptr.as_ptr().cast();
+        core::mem::forget(self);
+        // SAFETY:
+        // - ptr was allocated by the global allocator as per type-invariant
+        // - `E` has the same alignment as indicated by the stored layout.
+        // - capacity * size_of::<E> == layout.size()
+        // - 0 <= capacity
+        // - no bytes are claimed to be initialized
+        // - the layout represents a valid allocation, hence has allocation size less than isize::MAX
+        Ok(unsafe { Vec::from_raw_parts(ptr, 0, capacity) })
+    }
+}
+
+impl Drop for Allocation {
+    fn drop(&mut self) {
+        buffer_dealloc(self.layout, self.ptr);
+    }
+}
+
+// Allocate a pointer that can be passed to Vec::from_raw_parts
+fn buffer_alloc(layout: Layout) -> NonNull<u8> {
+    // [buffer alloc]: The current docs of Vec::from_raw_parts(ptr, ...) say:
+    // > ptr must have been allocated using the global allocator
+    // Yet, an empty Vec is guaranteed to not allocate (it is even illegal! to allocate with a zero-sized layout)
+    // Hence, we slightly re-interpret the above to only needing to hold if `capacity > 0`. Still, the pointer
+    // must be non-zero. So in case we need a pointer for an empty vec, use a correctly aligned, dangling one.
+    if layout.size() == 0 {
+        // we would use NonNull::dangling() but we don't have a concrete type for the requested alignment
+        let ptr = core::ptr::null_mut::<u8>().wrapping_add(layout.align());
+        // SAFETY: layout.align() is never 0
+        unsafe { NonNull::new_unchecked(ptr) }
+    } else {
+        // SAFETY: layout has non-zero size.
+        let ptr = unsafe { alloc::alloc::alloc(layout) };
+        NonNull::new(ptr).unwrap_or_else(|| alloc::alloc::handle_alloc_error(layout))
+    }
+}
+
+fn expect_dangling(align: usize, buffer: NonNull<u8>) {
+    debug_assert!(
+        buffer.as_ptr().wrapping_sub(align).is_null(),
+        "expected a nullptr for size 0"
+    );
+}
+
+#[cold]
+fn alloc_overflow() -> ! {
{ + panic!("Overflow, too many elements") +} + +// Grow the buffer while keeping alignment +fn buffer_grow( + old_layout: Layout, + buffer: NonNull, + min_layout: Layout, +) -> (Layout, NonNull) { + let new_align = min_layout.align().max(old_layout.align()); // Don't let data become less aligned + let new_size = min_layout.size().next_multiple_of(new_align); + if new_size > isize::MAX as usize { + alloc_overflow(); + } + + assert!(new_size > old_layout.size(), "size must actually grow"); + if old_layout.size() == 0 { + expect_dangling(old_layout.align(), buffer); + let new_layout = Layout::from_size_align(new_size, new_align).unwrap(); + let buffer = buffer_alloc(new_layout); + return (new_layout, buffer); + }; + let realloc = || { + let new_layout = Layout::from_size_align(new_size, old_layout.align()).unwrap(); + // SAFETY: + // - buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`]. + // - old_layout is the same as with which the pointer was allocated + // - new_size is not 0, since it is larger than old_layout.size() which is non-zero + // - size constitutes a valid layout + let ptr = unsafe { alloc::alloc::realloc(buffer.as_ptr(), old_layout, new_layout.size()) }; + (new_layout, ptr) + }; + if new_align == old_layout.align() { + // happy path. We can just realloc. + let (new_layout, ptr) = realloc(); + let buffer = NonNull::new(ptr); + let buffer = buffer.unwrap_or_else(|| alloc::alloc::handle_alloc_error(new_layout)); + return (new_layout, buffer); + } + // [buffer grow]: alloc::realloc can *not* change the alignment of the allocation's layout. + // The unstable Allocator::{grow,shrink} API changes this, but might take a while to make it + // into alloc::GlobalAlloc. + // + // As such, we can not request a specific alignment. But most allocators will give us the required + // alignment "for free". Hence, we speculatively avoid a mem-copy by using realloc. + // + // If in the future requesting an alignment change for an existing is available, this can be removed. + #[cfg(target_has_atomic = "8")] + mod alignment_assumption { + use core::sync::atomic::{AtomicBool, Ordering}; + static SPECULATE: AtomicBool = AtomicBool::new(true); + pub fn speculate() -> bool { + // We load and store with relaxed order, since worst case this leads to a few more memcopies + SPECULATE.load(Ordering::Relaxed) + } + pub fn report_violation() { + SPECULATE.store(false, Ordering::Relaxed) + } + } + #[cfg(not(target_has_atomic = "8"))] + mod alignment_assumption { + // On these platforms we don't speculate, and take the hit of performance + pub fn speculate() -> bool { + false + } + pub fn report_violation() {} + } + // reminder: old_layout.align() < new_align + let mut old_buffer = buffer; + let mut old_layout = old_layout; + if alignment_assumption::speculate() { + let (realloc_layout, ptr) = realloc(); + if let Some(buffer) = NonNull::new(ptr) { + if buffer.align_offset(new_align) == 0 { + return (realloc_layout, buffer); + } + // Speculating hasn't succeeded, but access now has to go through the reallocated buffer + alignment_assumption::report_violation(); + old_buffer = buffer; + old_layout = realloc_layout; + } else { + // If realloc fails, the later alloc will likely too, but don't report this yet + } + } + // realloc but change alignment. 
+    let new_layout = Layout::from_size_align(new_size, new_align).unwrap();
+    let new_buffer = buffer_alloc(new_layout);
+    // SAFETY: two different memory allocations, and old buffer's size is smaller than new_size
+    unsafe {
+        core::ptr::copy_nonoverlapping(old_buffer.as_ptr(), new_buffer.as_ptr(), old_layout.size());
+    }
+    buffer_dealloc(old_layout, old_buffer);
+    (new_layout, new_buffer)
+}
+
+// Deallocate a buffer of a Vec
+fn buffer_dealloc(layout: Layout, buffer: NonNull<u8>) {
+    if layout.size() != 0 {
+        // SAFETY: buffer comes from a Vec or from [`buffer_alloc`/`buffer_grow`].
+        // The layout is the same as per type-invariants
+        unsafe {
+            alloc::alloc::dealloc(buffer.as_ptr(), layout);
+        }
+    } else {
+        // An empty Vec does not allocate, hence nothing to dealloc
+        expect_dangling(layout.align(), buffer);
+    }
+}
+
+impl Bytes {
+    /// Copy an existing slice of data into Bytes that are aligned to `align`
+    fn try_from_data(align: usize, data: &[u8]) -> Result<Self, LayoutError> {
+        let len = data.len();
+        let layout = Layout::from_size_align(len, align)?;
+        let alloc = Allocation::new(layout);
+        unsafe {
+            // SAFETY:
+            // - data and alloc are distinct allocations of `len` bytes
+            core::ptr::copy_nonoverlapping::<u8>(data.as_ref().as_ptr(), alloc.as_mut_ptr(), len);
+        };
+        Ok(Self { alloc, len })
+    }
+
+    /// Ensure the contained buffer is aligned to `align` by possibly moving it to a new buffer.
+    fn try_enforce_runtime_align(&mut self, align: usize) -> Result<(), LayoutError> {
+        if self.as_mut_ptr().align_offset(align) == 0 {
+            // data is already aligned correctly
+            return Ok(());
+        }
+        *self = Self::try_from_data(align, self)?;
+        Ok(())
+    }
+
+    /// Create a sequence of [Bytes] from the memory representation of an unknown type of elements.
+    /// Prefer this over [Self::from_elems] when the datatype is not statically known and erased at runtime.
+    pub fn from_bytes_vec(bytes: Vec<u8>) -> Self {
+        let mut bytes = Self::from_elems(bytes);
+        // TODO: this method could be datatype aware and enforce a less strict alignment.
+        // On most platforms, this alignment check is fulfilled either way though, so
+        // the benefits of potentially saving a memcopy are negligible.
+        bytes.try_enforce_runtime_align(MAX_ALIGN).unwrap();
+        bytes
+    }
+
+    /// Erase the element type of a vector by converting into a sequence of [Bytes].
+    ///
+    /// In case the element type is not statically known at runtime, prefer to use [Self::from_bytes_vec].
+    pub fn from_elems<E>(elems: Vec<E>) -> Self
+    where
+        // NoUninit implies Copy
+        E: bytemuck::NoUninit + Send + Sync,
+    {
+        let _: () = const {
+            assert!(
+                core::mem::align_of::<E>() <= MAX_ALIGN,
+                "element type not supported due to too large alignment"
+            );
+        };
+        // Note: going through a Box as in Vec::into_boxed_slice would re-allocate on excess capacity. Avoid that.
+        let byte_len = elems.len() * core::mem::size_of::<E>();
+        let alloc = Allocation::from_vec(elems);
+        Self {
+            alloc,
+            len: byte_len,
+        }
+    }
+
+    fn reserve(&mut self, additional: usize) {
+        let needs_to_grow = additional > self.capacity().wrapping_sub(self.len());
+        if !needs_to_grow {
+            return;
+        }
+        let Some(required_cap) = self.len().checked_add(additional) else {
+            alloc_overflow()
+        };
+        // guarantee exponential growth for amortization
+        let new_cap = required_cap.max(self.capacity() * 2);
+        let new_cap = new_cap.max(MAX_ALIGN); // Small allocations would be pointless
+        let Ok(new_layout) = Layout::from_size_align(new_cap, MAX_ALIGN) else {
+            alloc_overflow()
+        };
+        self.alloc.grow(new_layout);
+    }
+
+    /// Extend the byte buffer from a slice of bytes
+    pub fn extend_from_byte_slice(&mut self, bytes: &[u8]) {
+        let additional = bytes.len();
+        self.reserve(additional);
+        let len = self.len();
+        let new_cap = len.wrapping_add(additional); // Can not overflow, as we've just successfully reserved sufficient space for it
+        let uninit_spare = &mut self.alloc.memory_mut()[len..new_cap];
+        // SAFETY: reinterpreting the slice as a MaybeUninit<u8>.
+        // See also #![feature(maybe_uninit_write_slice)], which would replace this with safe code
+        uninit_spare.copy_from_slice(unsafe {
+            core::slice::from_raw_parts(bytes.as_ptr().cast(), additional)
+        });
+        self.len = new_cap;
+    }
+
+    /// Get the total capacity, in bytes, of the wrapped allocation.
+    pub fn capacity(&self) -> usize {
+        self.alloc.layout.size()
+    }
+
+    /// Convert the bytes back into a vector. This requires that the type has the same alignment as the element
+    /// type this [Bytes] was initialized with.
+    /// This only returns with Ok(_) if the conversion can be done without a memcopy
+    pub fn try_into_vec<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(
+        mut self,
+    ) -> Result<Vec<E>, Self> {
+        // See if the length is compatible
+        let Ok(data) = bytemuck::checked::try_cast_slice_mut::<_, E>(&mut self) else {
+            return Err(self);
+        };
+        let length = data.len();
+        // If so, try to convert the allocation to a vec
+        let mut vec = match self.alloc.try_into_vec::<E>() {
+            Ok(vec) => vec,
+            Err(alloc) => {
+                self.alloc = alloc;
+                return Err(self);
+            }
+        };
+        // SAFETY: We computed this length from the bytemuck-ed slice into this allocation
+        unsafe {
+            vec.set_len(length);
+        };
+        Ok(vec)
+    }
+}
+
+impl Deref for Bytes {
+    type Target = [u8];
+
+    fn deref(&self) -> &Self::Target {
+        // SAFETY: see type invariants
+        unsafe { core::slice::from_raw_parts(self.alloc.as_mut_ptr(), self.len) }
+    }
+}
+
+impl DerefMut for Bytes {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        // SAFETY: see type invariants
+        unsafe { core::slice::from_raw_parts_mut(self.alloc.as_mut_ptr(), self.len) }
+    }
+}
+
+// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Send
+unsafe impl Send for Bytes {}
+// SAFETY: Bytes behaves like a Box<[u8]> and can contain only elements that are themselves Sync
+unsafe impl Sync for Bytes {}
+
+#[cfg(test)]
+mod tests {
+    use super::Bytes;
+    use alloc::{vec, vec::Vec};
+
+    const _CONST_ASSERTS: fn() = || {
+        fn test_send<T: Send>() {}
+        fn test_sync<T: Sync>() {}
+        test_send::<Bytes>();
+        test_sync::<Bytes>();
+    };
+
+    fn test_serialization_roundtrip(bytes: &Bytes) {
+        let config = bincode::config::standard();
+        let serialized =
+            bincode::serde::encode_to_vec(bytes, config).expect("serialization to succeed");
+        let (roundtripped, _) = bincode::serde::decode_from_slice(&serialized, config)
+            .expect("deserialization to succeed");
+        assert_eq!(
+            bytes, &roundtripped,
"roundtripping through serialization didn't lead to equal Bytes" + ); + } + + #[test] + fn test_serialization() { + test_serialization_roundtrip(&Bytes::from_elems::(vec![])); + test_serialization_roundtrip(&Bytes::from_elems(vec![0xdead, 0xbeaf])); + } + + #[test] + fn test_into_vec() { + // We test an edge case here, where the capacity (but not actual size) makes it impossible to convert to a vec + let mut bytes = Vec::with_capacity(6); + let actual_cap = bytes.capacity(); + bytes.extend_from_slice(&[0, 1, 2, 3]); + let mut bytes = Bytes::from_elems::(bytes); + + bytes = bytes + .try_into_vec::<[u8; 0]>() + .expect_err("Conversion should not succeed for a zero-sized type"); + if actual_cap % 4 != 0 { + // We most likely get actual_cap == 6, we can't force Vec to actually do that. Code coverage should complain if the actual test misses this + bytes = bytes.try_into_vec::<[u8; 4]>().err().unwrap_or_else(|| { + panic!("Conversion should not succeed due to capacity {actual_cap} not fitting a whole number of elements"); + }); + } + bytes = bytes + .try_into_vec::() + .expect_err("Conversion should not succeed due to mismatched alignment"); + bytes = bytes.try_into_vec::<[u8; 3]>().expect_err( + "Conversion should not succeed due to size not fitting a whole number of elements", + ); + let bytes = bytes.try_into_vec::<[u8; 2]>().expect("Conversion should succeed for bit-convertible types of equal alignment and compatible size"); + assert_eq!(bytes, &[[0, 1], [2, 3]]); + } + + #[test] + fn test_grow() { + let mut bytes = Bytes::from_elems::(vec![]); + bytes.extend_from_byte_slice(&[0, 1, 2, 3]); + assert_eq!(bytes[..], [0, 1, 2, 3][..]); + + let mut bytes = Bytes::from_elems(vec![42u8; 4]); + bytes.extend_from_byte_slice(&[0, 1, 2, 3]); + assert_eq!(bytes[..], [42, 42, 42, 42, 0, 1, 2, 3][..]); + } + + #[test] + fn test_large_elems() { + let mut bytes = Bytes::from_elems(vec![42u128]); + const TEST_BYTES: [u8; 16] = [ + 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, 0x34, 0x12, 0x90, 0x78, 0x56, + 0x34, 0x12, + ]; + bytes.extend_from_byte_slice(&TEST_BYTES); + let vec = bytes.try_into_vec::().unwrap(); + assert_eq!(vec, [42u128, u128::from_ne_bytes(TEST_BYTES)]); + } +} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index d52a181929..c65572a068 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -13,7 +13,7 @@ use half::{bf16, f16}; use crate::{ quantization::{AffineQuantization, Quantization, QuantizationStrategy}, - tensor::Shape, + tensor::{bytes::Bytes, Shape}, DType, Distribution, Element, ElementConversion, }; @@ -43,8 +43,7 @@ pub enum DataError { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct TensorData { /// The values of the tensor (as bytes). - #[serde(with = "serde_bytes")] - pub bytes: Vec, + pub bytes: Bytes, /// The shape of the tensor. pub shape: Vec, @@ -53,33 +52,12 @@ pub struct TensorData { pub dtype: DType, } -fn into_bytes(mut value: Vec) -> Vec { - // Ensure `E` satisfies the `Pod` trait requirements - assert_eq!(core::mem::size_of::() % core::mem::size_of::(), 0); - - let factor = core::mem::size_of::() / core::mem::size_of::(); - let len = value.len() * factor; - let capacity = value.capacity() * factor; - let ptr = value.as_mut_ptr(); - - core::mem::forget(value); - - unsafe { Vec::from_raw_parts(ptr as *mut u8, len, capacity) } -} - impl TensorData { /// Creates a new tensor data structure. 
-    pub fn new<E: Element, S: Into<Vec<usize>>>(mut value: Vec<E>, shape: S) -> Self {
+    pub fn new<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S) -> Self {
         // Ensure shape is valid
         let shape = shape.into();
-        let shape_numel = Self::numel(&shape);
-        value.truncate(shape_numel);
-        let numel = value.len();
-        assert_eq!(
-            shape_numel, numel,
-            "Shape {:?} is invalid for input of size {:?}",
-            shape, numel,
-        );
+        Self::check_data_len(&value, &shape, None);
 
         Self::init(value, shape, E::dtype())
     }
@@ -93,10 +71,10 @@ impl TensorData {
         shape: S,
         strategy: QuantizationStrategy,
     ) -> Self {
-        // TODO: this method should go into a dedicated Bytes opaque type with other bytes
-        // handling logic
-        let mut value = into_bytes(value);
+        let shape = shape.into();
+        Self::check_data_len(&value, &shape, Some(&strategy));
 
+        let mut bytes: Bytes;
         // Notes on quantization data representation:
         // 1) The quantized values are packed into 32-bit unsigned integers. For example, int8
         //    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,
@@ -107,9 +85,9 @@
         match strategy {
             QuantizationStrategy::PerTensorAffineInt8(q) => {
                 if TypeId::of::<E>() == TypeId::of::<u32>() {
-                    value = bytemuck::checked::cast_slice(&value).to_vec(); // already packed values
-                } else if TypeId::of::<E>() == TypeId::of::<i8>() {
-                    value = bytemuck::checked::cast_slice(&pack_i8s_to_u32s(&value)).to_vec();
+                    bytes = Bytes::from_elems(value); // already packed values
+                } else if let Some(value) = <dyn core::any::Any>::downcast_ref::<Vec<i8>>(&value) {
+                    bytes = Bytes::from_elems(pack_i8s_to_u32s(value));
                 } else {
                     panic!("Invalid quantized type");
                 }
@@ -117,31 +95,62 @@
                 let offset = q.offset as i32;
                 let scale_bytes = bytemuck::bytes_of(&q.scale);
                 let offset_bytes = bytemuck::bytes_of(&offset);
-                value.extend_from_slice(offset_bytes);
-                value.extend_from_slice(scale_bytes);
+                bytes.extend_from_byte_slice(offset_bytes);
+                bytes.extend_from_byte_slice(scale_bytes);
             }
             QuantizationStrategy::PerTensorSymmetricInt8(q) => {
                 if TypeId::of::<E>() == TypeId::of::<u32>() {
-                    value = bytemuck::checked::cast_slice(&value).to_vec(); // already packed values
-                } else if TypeId::of::<E>() == TypeId::of::<i8>() {
-                    let packed = pack_i8s_to_u32s(&value);
-                    value = bytemuck::checked::cast_slice(&packed).to_vec();
+                    bytes = Bytes::from_elems(value); // already packed values
+                } else if let Some(value) = <dyn core::any::Any>::downcast_ref::<Vec<i8>>(&value) {
+                    bytes = Bytes::from_elems(pack_i8s_to_u32s(value));
                 } else {
                     panic!("Invalid quantized type");
                 }
                 let scale_bytes = bytemuck::bytes_of(&q.scale);
-                value.extend_from_slice(scale_bytes);
+                bytes.extend_from_byte_slice(scale_bytes);
             }
         }
 
-        Self::init(value, shape, DType::QFloat(strategy.scheme()))
+        Self {
+            bytes,
+            shape,
+            dtype: DType::QFloat(strategy.scheme()),
+        }
+    }
+
+    // Check that the input vector contains a correct number of elements
+    fn check_data_len<E: Element>(
+        data: &[E],
+        shape: &Vec<usize>,
+        quantization: Option<&QuantizationStrategy>,
+    ) {
+        let mut expected_data_len = Self::numel(shape);
+        if let Some(quantization) = quantization {
+            let elem_per_data = match quantization {
+                QuantizationStrategy::PerTensorAffineInt8(_)
+                | QuantizationStrategy::PerTensorSymmetricInt8(_) => {
+                    if TypeId::of::<E>() == TypeId::of::<u32>() {
+                        4
+                    } else {
+                        1
+                    }
+                }
+            };
+            expected_data_len = expected_data_len.div_ceil(elem_per_data);
+        }
+        let num_data = data.len();
+        assert_eq!(
+            expected_data_len, num_data,
+            "Shape {:?} is invalid for input of size {:?}",
+            shape, num_data,
+        );
     }
 
     /// Initializes a new tensor data structure from the provided values.
-    fn init<E: Element, S: Into<Vec<usize>>>(value: Vec<E>, shape: S, dtype: DType) -> Self {
+    fn init<E: Element>(value: Vec<E>, shape: Vec<usize>, dtype: DType) -> Self {
         Self {
-            bytes: into_bytes(value),
-            shape: shape.into(),
+            bytes: Bytes::from_elems(value),
+            shape,
             dtype,
         }
     }
@@ -185,7 +194,7 @@ impl TensorData {
     }
 
     /// Returns the tensor data as a vector of scalar values.
-    pub fn into_vec<E: Element>(mut self) -> Result<Vec<E>, DataError> {
+    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {
         if E::dtype() != self.dtype {
             return Err(DataError::TypeMismatch(format!(
                 "Invalid target element type (expected {:?}, got {:?})",
@@ -194,19 +203,16 @@ impl TensorData {
             )));
         }
 
-        let capacity_bytes = self.bytes.capacity();
-        let length_bytes = self.bytes.len();
-        let size_elem = core::mem::size_of::<E>();
-
-        let capacity = capacity_bytes / size_elem;
-        let length = length_bytes / size_elem;
-
-        unsafe {
-            let ptr = self.bytes.as_mut_ptr();
-            core::mem::forget(self.bytes);
-
-            Ok(Vec::from_raw_parts(ptr.cast::<E>(), length, capacity))
-        }
+        let mut me = self;
+        me.bytes = match me.bytes.try_into_vec::<E>() {
+            Ok(elems) => return Ok(elems),
+            Err(bytes) => bytes,
+        };
+        // The bytes might have been deserialized and allocated with a different align.
+        // In that case, we have to memcopy the data into a new vector, more suitably allocated
+        Ok(bytemuck::checked::try_cast_slice(me.values_as_bytes())
+            .map_err(DataError::CastError)?
+            .to_vec())
     }
 
     /// Returns an iterator over the values of the tensor data.
@@ -405,7 +411,7 @@ impl TensorData {
 
     /// Returns the data as a slice of bytes.
     pub fn as_bytes(&self) -> &[u8] {
-        self.bytes.as_slice()
+        &self.bytes
    }
 
     /// Applies the data quantization strategy.
diff --git a/crates/burn-tensor/src/tensor/mod.rs b/crates/burn-tensor/src/tensor/mod.rs
index feed9571c4..40fa4b0f2c 100644
--- a/crates/burn-tensor/src/tensor/mod.rs
+++ b/crates/burn-tensor/src/tensor/mod.rs
@@ -1,12 +1,14 @@
 pub(crate) mod stats;
 
 mod api;
+mod bytes;
 mod data;
 mod distribution;
 mod element;
 mod shape;
 
 pub use api::*;
+pub use bytes::*;
 pub use data::*;
 pub use distribution::*;
 pub use element::*;
diff --git a/crates/burn-tensor/src/tensor/quantization/data.rs b/crates/burn-tensor/src/tensor/quantization/data.rs
index 96096b784b..7f833edc58 100644
--- a/crates/burn-tensor/src/tensor/quantization/data.rs
+++ b/crates/burn-tensor/src/tensor/quantization/data.rs
@@ -1,10 +1,7 @@
 use alloc::vec::Vec;
 
 /// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
-///
-/// # Note
-/// This assumes that the bytes represent `i8` values.
-pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec<u32> {
+pub fn pack_i8s_to_u32s(bytes: &[i8]) -> Vec<u32> {
     // Shift and combine groups of four 8-bit values into a u32.
     // Same as doing this:
     // let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF);
@@ -12,7 +9,7 @@ pub fn pack_i8s_to_u32s(bytes: &[u8]) -> Vec<u32> {
         .chunks(4)
         .map(|x| {
             x.iter().enumerate().fold(0u32, |acc, (i, x)| {
-                acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8)
+                acc | (*x as u32 & 0xFF) << ((3 - i) * 8)
             })
         })
         .collect()
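
For reviewers, here is a minimal usage sketch of the `Bytes` API introduced above. It is illustrative only and not part of the patch; it assumes the `burn_tensor::Bytes` re-export added in `mod.rs` and a `bytemuck` dependency in the consuming crate. The fallback branch mirrors what `TensorData::into_vec` now does when the allocation no longer lines up with the element type (for example after deserialization or growth, where the buffer is over-aligned to `MAX_ALIGN`).

```rust
use burn_tensor::Bytes;

fn main() {
    // Type-erase a Vec<u32> without copying; Bytes remembers the allocation's layout.
    let mut bytes = Bytes::from_elems(vec![1u32, 2, 3]);
    assert_eq!(bytes.len(), 3 * core::mem::size_of::<u32>());

    // Raw bytes can be appended; the buffer grows while staying (over-)aligned.
    bytes.extend_from_byte_slice(&4u32.to_ne_bytes());

    // Cloning (and serde roundtrips) preserve the byte contents.
    let copy = bytes.clone();
    assert_eq!(&*copy, &*bytes);

    // Converting back into a typed Vec succeeds without a copy only when length,
    // capacity and alignment all match; otherwise the Bytes value is handed back.
    match bytes.try_into_vec::<u32>() {
        Ok(values) => assert_eq!(values, vec![1, 2, 3, 4]),
        Err(bytes) => {
            // Copying fallback, as TensorData::into_vec does; the buffer is
            // sufficiently aligned here, so the checked cast does not panic.
            let values: Vec<u32> = bytemuck::checked::cast_slice(&*bytes).to_vec();
            assert_eq!(values, vec![1, 2, 3, 4]);
        }
    }
}
```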
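
Similarly, a small worked example of the revised `pack_i8s_to_u32s` (the function body is copied from the hunk above): four signed 8-bit values land in one `u32` with the first value in the most significant byte, so callers no longer need the `as i8 as u32` round-trip on `u8` input.

```rust
/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.
fn pack_i8s_to_u32s(bytes: &[i8]) -> Vec<u32> {
    bytes
        .chunks(4)
        .map(|x| {
            x.iter().enumerate().fold(0u32, |acc, (i, x)| {
                acc | (*x as u32 & 0xFF) << ((3 - i) * 8)
            })
        })
        .collect()
}

fn main() {
    // -1i8 becomes the byte 0xFF in the most significant position.
    assert_eq!(pack_i8s_to_u32s(&[-1, 0, 1, 2]), vec![0xFF00_0102]);
}
```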