Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Tensor::buffer_mut #127

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/openvino/src/prepostprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
//! # let data = fs::read("tests/fixtures/inception/tensor-1x3x299x299-f32.bgr").expect("to read the tensor from file");
//! # let input_shape = Shape::new(&vec![1, 299, 299, 3]).expect("to create a new shape");
//! # let mut tensor = Tensor::new(ElementType::F32, &input_shape).expect("to create a new tensor");
//! # let buffer = tensor.buffer_mut().unwrap();
//! # let buffer = tensor.get_raw_data_mut().unwrap();
//! # buffer.copy_from_slice(&data);
//! // Insantiate a new core, read in a model, and set up a tensor with input data before performing pre/post processing
//! // Instantiate a new core, read in a model, and set up a tensor with input data before performing pre/post processing
//! // Pre-process the input by:
//! // - converting NHWC to NCHW
//! // - resizing the input image
Expand Down
100 changes: 72 additions & 28 deletions crates/openvino/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,49 +86,79 @@ impl Tensor {
Ok(byte_size)
}

/// Get a mutable reference to the data of the tensor.
pub fn get_data<T>(&mut self) -> Result<&mut [T]> {
let mut data = std::ptr::null_mut();
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(data),))?;
let size = self.get_byte_size()? / std::mem::size_of::<T>();
let slice = unsafe { std::slice::from_raw_parts_mut(data.cast::<T>(), size) };
/// Get the underlying data for the tensor.
pub fn get_raw_data(&self) -> Result<&[u8]> {
let mut buffer = std::ptr::null_mut();
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(buffer)))?;
let size = self.get_byte_size()?;
let slice = unsafe { std::slice::from_raw_parts(buffer.cast::<u8>(), size) };
Ok(slice)
}

/// Get a mutable reference to the buffer of the tensor.
///
/// # Returns
///
/// A mutable reference to the buffer of the tensor.
pub fn buffer_mut(&mut self) -> Result<&mut [u8]> {
/// Get a mutable reference to the underlying data for the tensor.
pub fn get_raw_data_mut(&mut self) -> Result<&mut [u8]> {
let mut buffer = std::ptr::null_mut();
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(buffer)))?;
let size = self.get_byte_size()?;
let slice = unsafe { std::slice::from_raw_parts_mut(buffer.cast::<u8>(), size) };
Ok(slice)
}

/// Get a `T`-casted slice of the underlying data for the tensor.
///
/// # Panics
///
/// This method will panic if it can't cast the data to `T` due to the type size or the
/// underlying pointer's alignment.
pub fn get_data<T>(&self) -> Result<&[T]> {
let raw_data = self.get_raw_data()?;
let len = get_safe_len::<T>(raw_data);
let slice = unsafe { std::slice::from_raw_parts(raw_data.as_ptr().cast::<T>(), len) };
Ok(slice)
}

/// Get a mutable `T`-casted slice of the underlying data for the tensor.
///
/// # Panics
///
/// This method will panic if it can't cast the data to `T` due to the type size or the
/// underlying pointer's alignment.
pub fn get_data_mut<T>(&mut self) -> Result<&mut [T]> {
let raw_data = self.get_raw_data_mut()?;
let len = get_safe_len::<T>(raw_data);
let slice =
unsafe { std::slice::from_raw_parts_mut(raw_data.as_mut_ptr().cast::<T>(), len) };
Ok(slice)
}
}

/// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
/// length of that slice.
fn get_safe_len<T>(data: &[u8]) -> usize {
if data.len() % std::mem::size_of::<T>() != 0 {
panic!("data size is not a multiple of the size of `T`");
}
if data.as_ptr() as usize % std::mem::align_of::<T>() != 0 {
panic!("raw data is not aligned to `T`'s alignment");
}
data.len() / std::mem::size_of::<T>()
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{ElementType, LoadingError, Shape};

#[test]
fn test_create_tensor() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let shape = Shape::new(&vec![1, 3, 227, 227]).unwrap();
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
assert!(!tensor.ptr.is_null());
}

#[test]
fn test_get_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -140,9 +170,7 @@ mod tests {

#[test]
fn test_get_element_type() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -154,9 +182,7 @@ mod tests {

#[test]
fn test_get_size() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -168,9 +194,7 @@ mod tests {

#[test]
fn test_get_byte_size() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -182,4 +206,24 @@ mod tests {
1 * 3 * 227 * 227 * std::mem::size_of::<f32>() as usize
);
}

#[test]
fn casting() {
openvino_sys::library::load().unwrap();
let shape = Shape::new(&vec![10, 10, 10]).unwrap();
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
let data = tensor.get_data::<f32>().unwrap();
assert_eq!(data.len(), 10 * 10 * 10);
}

#[test]
#[should_panic(expected = "data size is not a multiple of the size of `T`")]
fn casting_check() {
openvino_sys::library::load().unwrap();
let shape = Shape::new(&vec![10, 10, 10]).unwrap();
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
#[allow(dead_code)]
struct LargeOddType([u8; 1061]);
tensor.get_data::<LargeOddType>().unwrap();
}
}
2 changes: 1 addition & 1 deletion crates/openvino/tests/classify-alexnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn classify_alexnet() -> anyhow::Result<()> {
let input_shape = Shape::new(&vec![1, 227, 227, 3])?;
let element_type = ElementType::F32;
let mut tensor = Tensor::new(element_type, &input_shape)?;
let buffer = tensor.buffer_mut()?;
let buffer = tensor.get_raw_data_mut()?;
buffer.copy_from_slice(&data);

// Pre-process the input by:
Expand Down
2 changes: 1 addition & 1 deletion crates/openvino/tests/classify-inception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn classify_inception() -> anyhow::Result<()> {
let input_shape = Shape::new(&vec![1, 299, 299, 3])?;
let element_type = ElementType::F32;
let mut tensor = Tensor::new(element_type, &input_shape)?;
let buffer = tensor.buffer_mut()?;
let buffer = tensor.get_raw_data_mut()?;
buffer.copy_from_slice(&data);

// Pre-process the input by:
Expand Down
2 changes: 1 addition & 1 deletion crates/openvino/tests/classify-mobilenet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn classify_mobilenet() -> anyhow::Result<()> {
let input_shape = Shape::new(&vec![1, 224, 224, 3])?;
let element_type = ElementType::F32;
let mut tensor = Tensor::new(element_type, &input_shape)?;
let buffer = tensor.buffer_mut()?;
let buffer = tensor.get_raw_data_mut()?;
buffer.copy_from_slice(&data);

// Pre-process the input by:
Expand Down
Loading