Skip to content

Commit

Permalink
make bytes no-std compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
WorldSEnder committed Oct 24, 2024
1 parent 45748ca commit 5e9679a
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions crates/burn-tensor/src/tensor/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! A version of [`bytemuck::BoxBytes`] that is cloneable and allows trailing uninitialized elements.
use std::alloc::{Layout, LayoutError};
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use alloc::alloc::{Layout, LayoutError};
use alloc::borrow::Cow;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;

/// A sort of `Box<[u8]>` that remembers the original alignment and can contain trailing uninitialized bytes.
pub struct Bytes {
Expand All @@ -18,28 +18,28 @@ pub struct Bytes {

/// The maximum supported alignment. The limit exists to not have to store alignment when serializing. Instead,
/// the bytes are always over-aligned when deserializing to MAX_ALIGN.
const MAX_ALIGN: usize = std::mem::align_of::<u128>();
const MAX_ALIGN: usize = core::mem::align_of::<u128>();

fn debug_from_fn<F: Fn(&mut std::fmt::Formatter<'_>) -> std::fmt::Result>(
fn debug_from_fn<F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result>(
f: F,
) -> impl std::fmt::Debug {
) -> impl core::fmt::Debug {
// See also: std::fmt::from_fn
struct FromFn<F>(F);
impl<F> std::fmt::Debug for FromFn<F>
impl<F> core::fmt::Debug for FromFn<F>
where
F: Fn(&mut std::fmt::Formatter<'_>) -> std::fmt::Result,
F: Fn(&mut core::fmt::Formatter<'_>) -> core::fmt::Result,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
(self.0)(f)
}
}
FromFn(f)
}

impl std::fmt::Debug for Bytes {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl core::fmt::Debug for Bytes {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let data = &**self;
let fmt_data = move |f: &mut std::fmt::Formatter<'_>| {
let fmt_data = move |f: &mut core::fmt::Formatter<'_>| {
if data.len() > 3 {
// There is a nightly API `debug_more_non_exhaustive` which has `finish_non_exhaustive`
f.debug_list().entries(&data[0..3]).entry(&"...").finish()
Expand Down Expand Up @@ -99,8 +99,8 @@ impl Bytes {
// correctly aligned, slice of bytes. Since we might not be able to fully predict the length and align ahead of time, this does currently
// not seem worth the hassle.
let bytes = unsafe {
let mem = std::alloc::alloc(layout);
std::ptr::copy_nonoverlapping(data.as_ref().as_ptr(), mem, len);
let mem = alloc::alloc::alloc(layout);
core::ptr::copy_nonoverlapping(data.as_ref().as_ptr(), mem, len);
NonNull::new_unchecked(mem)
};
Ok(Self {
Expand All @@ -114,12 +114,12 @@ impl Bytes {
pub fn from_elems<E: bytemuck::NoUninit + Copy>(mut elems: Vec<E>) -> Self {
let _: () = const {
assert!(
std::mem::align_of::<E>() <= MAX_ALIGN,
core::mem::align_of::<E>() <= MAX_ALIGN,
"element type not supported due to too large alignment"
);
};
// Note: going through a Box as in Vec::into_boxed_slice would re-allocate on excess capacity. Avoid that.
let byte_len = elems.len() * std::mem::size_of::<E>();
let byte_len = elems.len() * core::mem::size_of::<E>();
// Set the length to 0, then all data is in the "spare capacity".
// SAFETY: Data is Copy, so in particular does not need to be dropped. In any case, try not to panic until
// we have taken ownership of the data!
Expand All @@ -130,7 +130,7 @@ impl Bytes {
// SAFETY: data is the allocation of a vec, hence can not be null. We use unchecked to avoid a panic-path.
let ptr = unsafe { NonNull::new_unchecked(data.as_mut_ptr() as *mut u8) };
// Now we manage the memory manually, forget the vec.
std::mem::forget(elems);
core::mem::forget(elems);
Self {
ptr,
len: byte_len,
Expand All @@ -151,15 +151,15 @@ impl Bytes {
let Some(capacity) = self.layout.size().checked_div(size_of::<E>()) else {
return Err(self);
};
if self.layout.align() != std::mem::align_of::<E>() {
if self.layout.align() != core::mem::align_of::<E>() {
return Err(self);
}
let Ok(data) = bytemuck::checked::try_cast_slice_mut::<_, E>(&mut self) else {
return Err(self);
};
let length = data.len();
let data = data.as_mut_ptr();
std::mem::forget(self);
core::mem::forget(self);
// SAFETY:
// - data was allocated by the global allocator as per type-invariant
// - `E` has the same alignment as indicated by the stored layout.
Expand All @@ -176,22 +176,22 @@ impl Deref for Bytes {

fn deref(&self) -> &Self::Target {
// SAFETY: see type invariants
unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
}
}

impl DerefMut for Bytes {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: see type invariants
unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut(), self.len) }
unsafe { core::slice::from_raw_parts_mut(self.ptr.as_mut(), self.len) }
}
}

impl Drop for Bytes {
fn drop(&mut self) {
if self.layout.size() != 0 {
unsafe {
std::alloc::dealloc(self.ptr.as_ptr(), self.layout);
alloc::alloc::dealloc(self.ptr.as_ptr(), self.layout);
}
}
}
Expand Down

0 comments on commit 5e9679a

Please sign in to comment.