Skip to content

Commit

Permalink
Merge pull request #36 from krazijames/load-model-from-memory
Browse files Browse the repository at this point in the history
Allow loading models from memory
  • Loading branch information
nbigaouette authored Nov 1, 2020
2 parents b6383af + da17267 commit 32e9935
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,61 @@ impl SessionBuilder {
outputs,
})
}

/// Load an ONNX graph from memory and commit the session
pub fn with_model_from_memory<B>(self, model_bytes: B) -> Result<Session>
where
B: AsRef<[u8]>,
{
self.with_model_from_memory_monomorphized(model_bytes.as_ref())
}

fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> Result<Session> {
let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();

let env_ptr: *const sys::OrtEnv = self.env.env_ptr();

let status = unsafe {
let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
let model_data_length = model_bytes.len() as u64;
g_ort().CreateSessionFromArray.unwrap()(
env_ptr,
model_data,
model_data_length,
self.session_options_ptr,
&mut session_ptr,
)
};
status_to_result(status).map_err(OrtError::Session)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_ptr, std::ptr::null_mut());

let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
status_to_result(status).map_err(OrtError::Allocator)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(allocator_ptr, std::ptr::null_mut());

let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;

// Extract input and output properties
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
let inputs = (0..num_input_nodes)
.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Input>>>()?;
let outputs = (0..num_output_nodes)
.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
.collect::<Result<Vec<Output>>>()?;

Ok(Session {
session_ptr,
allocator_ptr,
memory_info,
inputs,
outputs,
})
}
}

/// Type storing the session information, built from an [`Environment`](environment/struct.Environment.html)
Expand Down

0 comments on commit 32e9935

Please sign in to comment.