Skip to content

Commit

Permalink
use Bytes in jit ops
Browse files Browse the repository at this point in the history
  • Loading branch information
WorldSEnder committed Dec 14, 2024
1 parent 2755009 commit f062826
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
8 changes: 2 additions & 6 deletions crates/burn-jit/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::Range;
use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
DType, Device, Shape, TensorData,
Bytes, DType, Device, Shape, TensorData,
};

use crate::{
Expand Down Expand Up @@ -82,12 +82,8 @@ where
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;

// TODO: this should be refactored such that the bytes type is opaque.
// With this, the logic for handling the bytes representation of quantized data
// (as well as all data manipulations) will be encapsulated in the type.
// Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate.
TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape: tensor.shape.into(),
dtype: tensor.dtype,
}
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-jit/src/ops/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use burn_tensor::{
ops::{TransactionOps, TransactionPrimitiveResult},
DType, TensorData,
Bytes, DType, TensorData,
};

use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
Expand Down Expand Up @@ -74,23 +74,23 @@ where
Kind::Float(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.read_floats.push(TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape,
dtype,
});
}
Kind::Int(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.read_ints.push(TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape,
dtype,
});
}
Kind::Bool(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.read_bools.push(TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape,
dtype,
});
Expand Down
13 changes: 13 additions & 0 deletions crates/burn-tensor/src/tensor/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,20 @@ impl Bytes {
Ok(())
}

/// Create a sequence of [Bytes] from the memory representation of an unknown type of elements.
/// Prefer this over [Self::from_elems] when the datatype is not statically known and erased at runtime.
pub fn from_bytes_vec(bytes: Vec<u8>) -> Self {
let mut bytes = Self::from_elems(bytes);
// TODO: this method could be datatype aware and enforce a less strict alignment.
// On most platforms, this alignment check is fulfilled either way though, so
// the benefits of potentially saving a memcopy are negligible.
bytes.try_enforce_runtime_align(MAX_ALIGN).unwrap();
bytes
}

/// Erase the element type of a vector by converting into a sequence of [Bytes].
///
/// In case the element type is not statically known at runtime, prefer to use [Self::from_bytes_vec].
pub fn from_elems<E>(elems: Vec<E>) -> Self
where
// NoUninit implies Copy
Expand Down

0 comments on commit f062826

Please sign in to comment.