Skip to content

Commit

Permalink
Use new API calls for OpenVino (#1084)
Browse files Browse the repository at this point in the history
* Use new API calls

* Fix some mistakes

* Remove scratchpad state, directly fill the tensor instead (for now)

* Remove manually specifying generics

* Remove AsBytes trait
  • Loading branch information
tuxbotix authored and okiwi6 committed Sep 30, 2024
1 parent ded8d40 commit 140750b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 37 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ nix = { version = "0.28", features = ["ioctl"] }
num-derive = "0.4.2"
num-traits = "0.2"
once_cell = "1.19.0"
openvino = { version = "0.7.1", features = ["runtime-linking"] }
openvino = { version = "0.7.2", features = ["runtime-linking"] }
opn = { path = "crates/opn" }
object_detection = { path = "crates/object_detection" }
opusfile-ng = "0.1.0"
Expand Down
35 changes: 5 additions & 30 deletions crates/object_detection/src/pose_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ const STRIDE: usize = DETECTION_IMAGE_HEIGHT * DETECTION_IMAGE_WIDTH;

#[derive(Deserialize, Serialize)]
pub struct PoseDetection {
#[serde(skip, default = "deserialize_not_implemented")]
scratchpad: Box<[f32]>,
#[serde(skip, default = "deserialize_not_implemented")]
network: CompiledModel,
}
Expand Down Expand Up @@ -108,16 +106,7 @@ impl PoseDetection {
bail!("expected exactly one input and one output");
}

let input_shape = network
.get_input_by_index(0)
.wrap_err("failed to get input node")?
.get_shape()
.wrap_err("failed to get shape of input node")?;
let number_of_elements = input_shape.get_dimensions().iter().product::<i64>();
let scratchpad = vec![0.0; number_of_elements as usize].into_boxed_slice();

Ok(Self {
scratchpad,
network: core.compile_model(&network, DeviceType::CPU)?,
})
}
Expand All @@ -139,10 +128,12 @@ impl PoseDetection {
};

let image = context.image;

let mut tensor = Tensor::new(ElementType::F32, &self.network.get_input()?.get_shape()?)?;
{
let earlier = SystemTime::now();

load_into_scratchpad(self.scratchpad.as_mut(), image);
load_into_scratchpad(tensor.get_data_mut()?, image);

context.preprocess_duration.fill_if_subscribed(|| {
SystemTime::now()
Expand All @@ -152,11 +143,7 @@ impl PoseDetection {
}

let mut infer_request = self.network.create_infer_request()?;
let tensor = Tensor::new_from_host_ptr(
ElementType::F32,
&self.network.get_input()?.get_shape()?,
self.scratchpad.as_bytes(),
)?;

infer_request.set_input_tensor(&tensor)?;

{
Expand All @@ -169,7 +156,7 @@ impl PoseDetection {
.expect("time ran backwards")
});
}
let mut prediction = infer_request.get_output_tensor()?;
let prediction = infer_request.get_output_tensor()?;
let prediction =
ArrayView::from_shape((56, MAX_DETECTIONS), prediction.get_data::<f32>()?)?;

Expand Down Expand Up @@ -267,15 +254,3 @@ fn non_maximum_suppression(

poses
}

trait AsBytes {
fn as_bytes(&self) -> &[u8];
}

impl AsBytes for [f32] {
fn as_bytes(&self) -> &[u8] {
unsafe {
std::slice::from_raw_parts(self.as_ptr() as *const u8, std::mem::size_of_val(self))
}
}
}

0 comments on commit 140750b

Please sign in to comment.