From d886107356ee59269be754800139f06467e99628 Mon Sep 17 00:00:00 2001 From: krazijames Date: Fri, 30 Oct 2020 22:07:13 +0900 Subject: [PATCH 1/2] Allow loading models from memory --- onnxruntime/src/session.rs | 55 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index bafb498b..56a4c13a 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -231,6 +231,61 @@ impl SessionBuilder { outputs, }) } + + /// Load an ONNX graph from memory and commit the session + pub fn with_model_from_memory(self, model_bytes: B) -> Result + where + B: AsRef<[u8]> + { + self.with_model_from_memory_monomorphized(model_bytes.as_ref()) + } + + fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result { + let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut(); + + let env_ptr: *const sys::OrtEnv = self.env.env_ptr(); + + let status = unsafe { + let model_data = model_bytes.as_ptr() as *const std::ffi::c_void; + let model_data_length = model_bytes.len() as u64; + g_ort().CreateSessionFromArray.unwrap()( + env_ptr, + model_data, + model_data_length, + self.session_options_ptr, + &mut session_ptr, + ) + }; + status_to_result(status).map_err(OrtError::Session)?; + assert_eq!(status, std::ptr::null_mut()); + assert_ne!(session_ptr, std::ptr::null_mut()); + + let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut(); + let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) }; + status_to_result(status).map_err(OrtError::Allocator)?; + assert_eq!(status, std::ptr::null_mut()); + assert_ne!(allocator_ptr, std::ptr::null_mut()); + + let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?; + + // Extract input and output properties + let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?; + let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?; + let inputs = (0..num_input_nodes) + .map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i)) + .collect::>>()?; + let outputs = (0..num_output_nodes) + .map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i)) + .collect::>>()?; + + Ok(Session { + session_ptr, + allocator_ptr, + memory_info, + inputs, + outputs, + }) + } } /// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html) From da17267c792c952d63cab84ebe08b0add7407a06 Mon Sep 17 00:00:00 2001 From: krazijames Date: Fri, 30 Oct 2020 22:45:51 +0900 Subject: [PATCH 2/2] Format --- onnxruntime/src/session.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 56a4c13a..2cef12e5 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -235,7 +235,7 @@ impl SessionBuilder { /// Load an ONNX graph from memory and commit the session pub fn with_model_from_memory(self, model_bytes: B) -> Result where - B: AsRef<[u8]> + B: AsRef<[u8]>, { self.with_model_from_memory_monomorphized(model_bytes.as_ref()) }