diff --git a/modules/core/Cargo.toml b/modules/core/Cargo.toml index b5ccc22..d661c6f 100644 --- a/modules/core/Cargo.toml +++ b/modules/core/Cargo.toml @@ -9,8 +9,6 @@ license-file = "LICENSE" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -axum-feature = ["axum"] -actix-feature = ["actix-web"] default = [] # below are the features for testing different engines sklearn-tests = [] @@ -20,7 +18,7 @@ tensorflow-tests = [] [dependencies] regex = "1.9.3" -ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false } +ort = { version = "2.0.0-rc.5", features = ["load-dynamic", "ndarray"], default-features = false } ndarray = "0.15.6" once_cell = "1.18.0" bytes = "1.5.0" @@ -29,9 +27,7 @@ futures-core = "0.3.28" thiserror = "1.0.57" serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0" -axum = { version = "0.7.4", optional = true } -actix-web = { version = "4.5.1", optional = true } - +nanoservices-utils = "0.1.5" [dev-dependencies] tokio = { version = "1.12.0", features = ["full"] } diff --git a/modules/core/src/errors.rs b/modules/core/src/errors.rs new file mode 100644 index 0000000..b6e0c4b --- /dev/null +++ b/modules/core/src/errors.rs @@ -0,0 +1,20 @@ + +#[macro_export] +macro_rules! safe_eject_option { + ($check:expr) => { + match $check {Some(x) => x, None => {let file_track = format!("{}:{}", file!(), line!());let message = format!("{}=>The value is not found", file_track);return Err(NanoServiceError::new(message, NanoServiceErrorStatus::NotFound))}} + }; +} + + +#[macro_export] +macro_rules! safe_eject_internal { + // Match when the optional string is provided + ($e:expr, $err_status:expr, $msg:expr) => { + $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); NanoServiceError::new(formatted_error, NanoServiceErrorStatus::Unknown)})? + }; + // Match when the optional string is not provided + ($e:expr) => { + $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); NanoServiceError::new(formatted_error, NanoServiceErrorStatus::Unknown)})? + }; +} \ No newline at end of file diff --git a/modules/core/src/errors/actix.rs b/modules/core/src/errors/actix.rs deleted file mode 100644 index 3ad0b95..0000000 --- a/modules/core/src/errors/actix.rs +++ /dev/null @@ -1,57 +0,0 @@ -//! Implements the `ResponseError` trait for the `SurrealError` type for the `actix_web` web framework. -use actix_web::{HttpResponse, error::ResponseError, http::StatusCode}; -pub use crate::errors::error::{SurrealErrorStatus, SurrealError}; - - -impl ResponseError for SurrealError { - - /// Yields the status code for the error. - /// - /// # Returns - /// * `StatusCode` - The status code for the error. - fn status_code(&self) -> StatusCode { - match self.status { - SurrealErrorStatus::NotFound => StatusCode::NOT_FOUND, - SurrealErrorStatus::Forbidden => StatusCode::FORBIDDEN, - SurrealErrorStatus::Unknown => StatusCode::INTERNAL_SERVER_ERROR, - SurrealErrorStatus::BadRequest => StatusCode::BAD_REQUEST, - SurrealErrorStatus::Conflict => StatusCode::CONFLICT, - SurrealErrorStatus::Unauthorized => StatusCode::UNAUTHORIZED - } - } - - /// Constructs a HTTP response for the error. - /// - /// # Returns - /// * `HttpResponse` - The HTTP response for the error. - fn error_response(&self) -> HttpResponse { - let status_code = self.status_code(); - HttpResponse::build(status_code).json(self.message.clone()) - } -} - - -#[cfg(test)] -mod tests { - use super::*; - use actix_web::http::StatusCode; - - #[test] - fn test_status_code() { - let error = SurrealError { - message: "Test".to_string(), - status: SurrealErrorStatus::NotFound - }; - assert_eq!(error.status_code(), StatusCode::NOT_FOUND); - } - - #[test] - fn test_error_response() { - let error = SurrealError { - message: "Test".to_string(), - status: SurrealErrorStatus::NotFound - }; - let response = error.error_response(); - assert_eq!(response.status(), StatusCode::NOT_FOUND); - } -} diff --git a/modules/core/src/errors/axum.rs b/modules/core/src/errors/axum.rs deleted file mode 100644 index b5a3821..0000000 --- a/modules/core/src/errors/axum.rs +++ /dev/null @@ -1,44 +0,0 @@ -//! Implements the `IntoResponse` trait for the `SurrealError` type for the `axum` web framework. -use axum::response::{IntoResponse, Response}; -use axum::body::Body; -pub use crate::errors::error::{SurrealErrorStatus, SurrealError}; - - -impl IntoResponse for SurrealError { - - /// Constructs a HTTP response for the error. - /// - /// # Returns - /// * `Response` - The HTTP response for the error. - fn into_response(self) -> Response { - let status_code = match self.status { - SurrealErrorStatus::NotFound => axum::http::StatusCode::NOT_FOUND, - SurrealErrorStatus::Forbidden => axum::http::StatusCode::FORBIDDEN, - SurrealErrorStatus::Unknown => axum::http::StatusCode::INTERNAL_SERVER_ERROR, - SurrealErrorStatus::BadRequest => axum::http::StatusCode::BAD_REQUEST, - SurrealErrorStatus::Conflict => axum::http::StatusCode::CONFLICT, - SurrealErrorStatus::Unauthorized => axum::http::StatusCode::UNAUTHORIZED - }; - axum::http::Response::builder() - .status(status_code) - .body(Body::new(self.message)) - .unwrap() - } -} - - -#[cfg(test)] -mod tests { - use super::*; - use axum::http::StatusCode; - - #[test] - fn test_into_response() { - let error = SurrealError { - message: "Test".to_string(), - status: SurrealErrorStatus::NotFound - }; - let response = error.into_response(); - assert_eq!(response.status(), StatusCode::NOT_FOUND); - } -} diff --git a/modules/core/src/errors/error.rs b/modules/core/src/errors/error.rs deleted file mode 100644 index edc30f2..0000000 --- a/modules/core/src/errors/error.rs +++ /dev/null @@ -1,101 +0,0 @@ -//! Custom error that can be attached to a web framework to automcatically result in a http response, -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use std::fmt; - - -#[macro_export] -macro_rules! safe_eject { - // Match when the optional string is provided - ($e:expr, $err_status:expr, $msg:expr) => { - $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, $err_status)})? - }; - // Match when the optional string is not provided - ($e:expr, $err_status:expr) => { - $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, $err_status)})? - }; -} - - -#[macro_export] -macro_rules! safe_eject_internal { - // Match when the optional string is provided - ($e:expr, $err_status:expr, $msg:expr) => { - $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, SurrealErrorStatus::Unknown)})? - }; - // Match when the optional string is not provided - ($e:expr) => { - $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, SurrealErrorStatus::Unknown)})? - }; -} - - -#[macro_export] -macro_rules! safe_eject_option { - ($check:expr) => { - match $check {Some(x) => x, None => {let file_track = format!("{}:{}", file!(), line!());let message = format!("{}=>The value is not found", file_track);return Err(SurrealError::new(message, SurrealErrorStatus::NotFound))}} - }; -} - - -/// The status of the custom error. -/// -/// # Fields -/// * `NotFound` - The request was not found. -/// * `Forbidden` - You are forbidden to access. -/// * `Unknown` - An unknown internal error occurred. -/// * `BadRequest` - The request was bad. -/// * `Conflict` - The request conflicted with the current state of the server. -#[derive(Error, Debug, Serialize, Deserialize, PartialEq)] -pub enum SurrealErrorStatus { - #[error("not found")] - NotFound, - #[error("You are forbidden to access resource")] - Forbidden, - #[error("Unknown Internal Error")] - Unknown, - #[error("Bad Request")] - BadRequest, - #[error("Conflict")] - Conflict, - #[error("Unauthorized")] - Unauthorized -} - - -/// The custom error that the web framework will construct into a HTTP response. -/// -/// # Fields -/// * `message` - The message of the error. -/// * `status` - The status of the error. -#[derive(Serialize, Deserialize, Debug, Error)] -pub struct SurrealError { - pub message: String, - pub status: SurrealErrorStatus -} - - -impl SurrealError { - - /// Create a new custom error. - /// - /// # Arguments - /// * `message` - The message of the error. - /// * `status` - The status of the error. - /// - /// # Returns - /// A new custom error. - pub fn new(message: String, status: SurrealErrorStatus) -> Self { - SurrealError { - message, - status - } - } -} - - -impl fmt::Display for SurrealError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.message) - } -} diff --git a/modules/core/src/errors/mod.rs b/modules/core/src/errors/mod.rs deleted file mode 100644 index e61abff..0000000 --- a/modules/core/src/errors/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod error; - -#[cfg(feature = "actix-feature")] -pub mod actix; - -#[cfg(feature = "axum-feature")] -pub mod axum; diff --git a/modules/core/src/execution/compute.rs b/modules/core/src/execution/compute.rs index 2f75637..5b7c652 100644 --- a/modules/core/src/execution/compute.rs +++ b/modules/core/src/execution/compute.rs @@ -1,12 +1,11 @@ //! Defines the operations around performing computations on a loaded model. use crate::storage::surml_file::SurMlFile; +use super::session::session; use std::collections::HashMap; -use ndarray::{ArrayD, CowArray}; -use ort::{SessionBuilder, Value, session::Input}; - -use super::onnx_environment::ENVIRONMENT; -use crate::safe_eject; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use ndarray::ArrayD; +use ort::{Input, ValueType}; +use nanoservices_utils::safe_eject; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// A wrapper for the loaded machine learning model so we can perform computations on the loaded model. @@ -27,7 +26,7 @@ impl <'a>ModelComputation<'a> { /// /// # Returns /// A Tensor that can be used as input to the loaded model. - pub fn input_tensor_from_key_bindings(&self, input_values: HashMap) -> Result, SurrealError> { + pub fn input_tensor_from_key_bindings(&self, input_values: HashMap) -> Result, NanoServiceError> { let buffer = self.input_vector_from_key_bindings(input_values)?; Ok(ndarray::arr1::(&buffer).into_dyn()) } @@ -39,15 +38,20 @@ impl <'a>ModelComputation<'a> { /// /// # Returns /// A vector of dimensions for the input tensor to be reshaped into from the loaded model. - fn process_input_dims(input_dims: &Input) -> Vec { - let mut buffer = Vec::new(); - for dim in input_dims.dimensions() { - match dim { - Some(dim) => buffer.push(dim as usize), - None => buffer.push(1) - } + fn process_input_dims(input_dims: &Input) -> Result, NanoServiceError> { + match &input_dims.input_type { + ValueType::Tensor { dimensions, .. } => { + let mut buffer = Vec::new(); + for dim in dimensions { + buffer.push(*dim as usize); + } + Ok(buffer) + }, + _ => Err(NanoServiceError::new( + String::from("compute => process_input_dims: Unknown input type for input dims"), + NanoServiceErrorStatus::Unknown + )) } - buffer } /// Creates a Vector that can be used manipulated with other operations such as normalisation from a hashmap of keys and values. @@ -57,13 +61,13 @@ impl <'a>ModelComputation<'a> { /// /// # Returns /// A Vector that can be used manipulated with other operations such as normalisation. - pub fn input_vector_from_key_bindings(&self, mut input_values: HashMap) -> Result, SurrealError> { + pub fn input_vector_from_key_bindings(&self, mut input_values: HashMap) -> Result, NanoServiceError> { let mut buffer = Vec::with_capacity(self.surml_file.header.keys.store.len()); for key in &self.surml_file.header.keys.store { let value = match input_values.get_mut(key) { Some(value) => value, - None => return Err(SurrealError::new(format!("src/execution/compute.rs 67: Key {} not found in input values", key), SurrealErrorStatus::NotFound)) + None => return Err(NanoServiceError::new(format!("src/execution/compute.rs 67: Key {} not found in input values", key), NanoServiceErrorStatus::NotFound)) }; buffer.push(std::mem::take(value)); } @@ -78,27 +82,52 @@ impl <'a>ModelComputation<'a> { /// /// # Returns /// The computed output tensor from the loaded model. - pub fn raw_compute(&self, tensor: ArrayD, _dims: Option<(i32, i32)>) -> Result, SurrealError> { - let session = safe_eject!(SessionBuilder::new(&ENVIRONMENT), SurrealErrorStatus::Unknown); - let session = safe_eject!(session.with_model_from_memory(&self.surml_file.model), SurrealErrorStatus::Unknown); - let unwrapped_dims = ModelComputation::process_input_dims(&session.inputs[0]); - let tensor = safe_eject!(tensor.into_shape(unwrapped_dims), SurrealErrorStatus::Unknown); + pub fn raw_compute(&self, tensor: ArrayD, _dims: Option<(i32, i32)>) -> Result, NanoServiceError> { + let session = session(&self.surml_file.model)?; + let unwrapped_dims = ModelComputation::process_input_dims(&session.inputs[0])?; + let tensor = safe_eject!( + tensor.into_shape(unwrapped_dims.clone()), + NanoServiceErrorStatus::Unknown, + "problem with reshaping tensor for raw_compute" + )?; + + // let x = CowArray::from(tensor).into_dyn(); + let mut buffer = Vec::with_capacity(tensor.len()); + for i in tensor.iter() { + buffer.push(*i); + } + + let buffer = ort::Tensor::from_array(( + unwrapped_dims, + buffer.into_boxed_slice() + )).unwrap(); - let x = CowArray::from(tensor).into_dyn(); - let input_values = safe_eject!(Value::from_array(session.allocator(), &x), SurrealErrorStatus::Unknown); - let outputs = safe_eject!(session.run(vec![input_values]), SurrealErrorStatus::Unknown); + let input_values = safe_eject!( + ort::inputs![buffer], + NanoServiceErrorStatus::Unknown, + "problem with creating input values in raw_compute" + )?; + let outputs = safe_eject!( + session.run(input_values), + NanoServiceErrorStatus::Unknown, + "problem with running session in raw_compute" + )?; - let mut buffer: Vec = Vec::new(); + let mut buffer: Vec = Vec::with_capacity(outputs.len()); // extract the output tensor converting the values to f32 if they are i64 - match outputs[0].try_extract::() { + match outputs[0].try_extract_tensor::() { Ok(y) => { for i in y.view().clone().into_iter() { buffer.push(*i); } }, Err(_) => { - for i in safe_eject!(outputs[0].try_extract::(), SurrealErrorStatus::Unknown).view().clone().into_iter() { + for i in safe_eject!( + outputs[0].try_extract_tensor::(), + NanoServiceErrorStatus::Unknown, + "problem with extracting output tensor in raw_compute" + )?.view().into_iter() { buffer.push(*i as f32); } } @@ -117,7 +146,7 @@ impl <'a>ModelComputation<'a> { /// /// # Returns /// The computed output tensor from the loaded model. - pub fn buffered_compute(&self, input_values: &mut HashMap) -> Result, SurrealError> { + pub fn buffered_compute(&self, input_values: &mut HashMap) -> Result, NanoServiceError> { // applying normalisers if present for (key, value) in &mut *input_values { let value_ref = value.clone(); @@ -139,9 +168,9 @@ impl <'a>ModelComputation<'a> { // apply the normaliser to the output let output_normaliser = match self.surml_file.header.output.normaliser.as_ref() { Some(normaliser) => normaliser, - None => return Err(SurrealError::new( + None => return Err(NanoServiceError::new( String::from("No normaliser present for output which shouldn't happen as passed initial check for").to_string(), - SurrealErrorStatus::Unknown + NanoServiceErrorStatus::Unknown )) }; let mut buffer = Vec::with_capacity(output.len()); diff --git a/modules/core/src/execution/mod.rs b/modules/core/src/execution/mod.rs index 39ebd06..b1f6c84 100644 --- a/modules/core/src/execution/mod.rs +++ b/modules/core/src/execution/mod.rs @@ -1,3 +1,4 @@ //! Defines operations around performing computations on a loaded model. pub mod compute; -pub mod onnx_environment; +pub mod session; +// pub mod onnx_environment; // This is deactivated for now to test the new implementation of ort v2 diff --git a/modules/core/src/execution/session.rs b/modules/core/src/execution/session.rs new file mode 100644 index 0000000..cda7748 --- /dev/null +++ b/modules/core/src/execution/session.rs @@ -0,0 +1,9 @@ +use ort::{InMemorySession, Session}; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; +use nanoservices_utils::safe_eject; + + +pub fn session(model_data: &Vec) -> Result { + let builder = Session::builder().unwrap(); + safe_eject!(builder.commit_from_memory_directly(model_data), NanoServiceErrorStatus::Unknown) +} diff --git a/modules/core/src/storage/header/keys.rs b/modules/core/src/storage/header/keys.rs index 28a3cd2..d7fbd5f 100644 --- a/modules/core/src/storage/header/keys.rs +++ b/modules/core/src/storage/header/keys.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use serde::{Serialize, Deserialize}; use crate::safe_eject_internal; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// Defines the key bindings for input data. @@ -86,7 +86,7 @@ impl KeyBindings { /// /// # Returns /// The key bindings constructed from the bytes. - pub fn from_bytes(data: &[u8]) -> Result { + pub fn from_bytes(data: &[u8]) -> Result { let data = safe_eject_internal!(String::from_utf8(data.to_vec())); Ok(Self::from_string(data)) } diff --git a/modules/core/src/storage/header/mod.rs b/modules/core/src/storage/header/mod.rs index 689a4ee..0611da6 100644 --- a/modules/core/src/storage/header/mod.rs +++ b/modules/core/src/storage/header/mod.rs @@ -20,8 +20,8 @@ use version::Version; use engine::Engine; use origin::Origin; use input_dims::InputDims; -use crate::safe_eject; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use nanoservices_utils::safe_eject; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// The header of the model file. @@ -81,7 +81,7 @@ impl Header { /// /// # Arguments /// * `version` - The version to be added. - pub fn add_version(&mut self, version: String) -> Result<(), SurrealError> { + pub fn add_version(&mut self, version: String) -> Result<(), NanoServiceError> { self.version = Version::from_string(version)?; Ok(()) } @@ -108,7 +108,7 @@ impl Header { /// # Arguments /// * `column_name` - The name of the column to which the normaliser will be applied. /// * `normaliser` - The normaliser to be applied to the column. - pub fn add_normaliser(&mut self, column_name: String, normaliser: NormaliserType) -> Result<(), SurrealError> { + pub fn add_normaliser(&mut self, column_name: String, normaliser: NormaliserType) -> Result<(), NanoServiceError> { let _ = self.normalisers.add_normaliser(normaliser, column_name, &self.keys)?; Ok(()) } @@ -120,7 +120,7 @@ impl Header { /// /// # Returns /// The normaliser for the given column name. - pub fn get_normaliser(&self, column_name: &String) -> Result, SurrealError> { + pub fn get_normaliser(&self, column_name: &String) -> Result, NanoServiceError> { self.normalisers.get_normaliser(column_name.to_string(), &self.keys) } @@ -154,7 +154,7 @@ impl Header { /// /// # Arguments /// * `origin` - The origin to be added. - pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> { + pub fn add_origin(&mut self, origin: String) -> Result<(), NanoServiceError> { self.origin.add_origin(origin) } @@ -170,9 +170,9 @@ impl Header { /// /// # Returns /// The `Header` struct. - pub fn from_bytes(data: Vec) -> Result { + pub fn from_bytes(data: Vec) -> Result { - let string_data = safe_eject!(String::from_utf8(data), SurrealErrorStatus::BadRequest); + let string_data = safe_eject!(String::from_utf8(data), NanoServiceErrorStatus::BadRequest)?; let buffer = string_data.split(Self::delimiter()).collect::>(); diff --git a/modules/core/src/storage/header/normalisers/mod.rs b/modules/core/src/storage/header/normalisers/mod.rs index 6f4f3ca..519fa10 100644 --- a/modules/core/src/storage/header/normalisers/mod.rs +++ b/modules/core/src/storage/header/normalisers/mod.rs @@ -14,7 +14,7 @@ use super::keys::KeyBindings; use utils::{extract_label, extract_two_numbers}; use wrapper::NormaliserType; use crate::safe_eject_option; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// A map of normalisers so they can be accessed by column name and input index. @@ -50,7 +50,7 @@ impl NormaliserMap { /// * `normaliser` - The normaliser to add. /// * `column_name` - The name of the column to which the normaliser is applied. /// * `keys_reference` - A reference to the key bindings to extract the index. - pub fn add_normaliser(&mut self, normaliser: NormaliserType, column_name: String, keys_reference: &KeyBindings) -> Result<(), SurrealError> { + pub fn add_normaliser(&mut self, normaliser: NormaliserType, column_name: String, keys_reference: &KeyBindings) -> Result<(), NanoServiceError> { let counter = self.store.len(); let column_input_index = safe_eject_option!(keys_reference.reference.get(column_name.as_str())); self.reference.insert(column_input_index.clone() as usize, counter as usize); @@ -67,7 +67,7 @@ impl NormaliserMap { /// /// # Returns /// The normaliser corresponding to the column name. - pub fn get_normaliser(&self, column_name: String, keys_reference: &KeyBindings) -> Result, SurrealError> { + pub fn get_normaliser(&self, column_name: String, keys_reference: &KeyBindings) -> Result, NanoServiceError> { let column_input_index = safe_eject_option!(keys_reference.reference.get(column_name.as_str())); let normaliser_index = self.reference.get(column_input_index); match normaliser_index { @@ -83,7 +83,7 @@ impl NormaliserMap { /// /// # Returns /// A tuple containing the label (type of normaliser), the numbers and the column name. - pub fn unpack_normaliser_data(normaliser_data: &str) -> Result<(String, [f32; 2], String), SurrealError> { + pub fn unpack_normaliser_data(normaliser_data: &str) -> Result<(String, [f32; 2], String), NanoServiceError> { let mut normaliser_buffer = normaliser_data.split("=>"); let column_name = safe_eject_option!(normaliser_buffer.next()); @@ -102,7 +102,7 @@ impl NormaliserMap { /// /// # Returns /// A `NormaliserMap` containing the normalisers. - pub fn from_string(data: String, keys_reference: &KeyBindings) -> Result { + pub fn from_string(data: String, keys_reference: &KeyBindings) -> Result { if data.len() == 0 { return Ok(NormaliserMap::fresh()) } diff --git a/modules/core/src/storage/header/normalisers/utils.rs b/modules/core/src/storage/header/normalisers/utils.rs index 21d2b58..8077eed 100644 --- a/modules/core/src/storage/header/normalisers/utils.rs +++ b/modules/core/src/storage/header/normalisers/utils.rs @@ -4,14 +4,14 @@ use crate::{ safe_eject_option, safe_eject_internal, }; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// Extracts the label from a normaliser string. /// /// # Arguments /// * `data` - A string containing the normaliser data. -pub fn extract_label(data: &String) -> Result { +pub fn extract_label(data: &String) -> Result { let re: Regex = safe_eject_internal!(Regex::new(r"^(.*?)\(")); let captures: Captures = safe_eject_option!(re.captures(data)); Ok(safe_eject_option!(captures.get(1)).as_str().to_string()) @@ -25,7 +25,7 @@ pub fn extract_label(data: &String) -> Result { /// /// # Returns /// [number1, number2] from `"(number1, number2)"` -pub fn extract_two_numbers(data: &String) -> Result<[f32; 2], SurrealError> { +pub fn extract_two_numbers(data: &String) -> Result<[f32; 2], NanoServiceError> { let re: Regex = safe_eject_internal!(Regex::new(r"[-+]?\d+(\.\d+)?")); let mut numbers = re.find_iter(data); let mut buffer: [f32; 2] = [0.0, 0.0]; diff --git a/modules/core/src/storage/header/normalisers/wrapper.rs b/modules/core/src/storage/header/normalisers/wrapper.rs index 92edc91..f9cb4ae 100644 --- a/modules/core/src/storage/header/normalisers/wrapper.rs +++ b/modules/core/src/storage/header/normalisers/wrapper.rs @@ -9,7 +9,7 @@ use super::utils::{extract_label, extract_two_numbers}; use super::traits::Normaliser; use crate::safe_eject_option; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// A wrapper for all different types of normalisers. @@ -56,7 +56,7 @@ impl NormaliserType { /// /// # Returns /// (type of normaliser, [normaliser parameters], column name) - pub fn unpack_normaliser_data(normaliser_data: &str) -> Result<(String, [f32; 2], String), SurrealError> { + pub fn unpack_normaliser_data(normaliser_data: &str) -> Result<(String, [f32; 2], String), NanoServiceError> { let mut normaliser_buffer = normaliser_data.split("=>"); let column_name = safe_eject_option!(normaliser_buffer.next()); @@ -74,7 +74,7 @@ impl NormaliserType { /// /// # Returns /// (normaliser, column name) - pub fn from_string(data: String) -> Result<(Self, String), SurrealError> { + pub fn from_string(data: String) -> Result<(Self, String), NanoServiceError> { let (label, numbers, column_name) = Self::unpack_normaliser_data(&data)?; let normaliser = match label.as_str() { "linear_scaling" => { @@ -98,9 +98,9 @@ impl NormaliserType { NormaliserType::ZScore(z_score::ZScore{mean, std_dev}) }, _ => { - let error = SurrealError::new( + let error = NanoServiceError::new( format!("Unknown normaliser type: {}", label).to_string(), - SurrealErrorStatus::Unknown + NanoServiceErrorStatus::Unknown ); return Err(error) } diff --git a/modules/core/src/storage/header/origin.rs b/modules/core/src/storage/header/origin.rs index 1f56f7d..c6a7130 100644 --- a/modules/core/src/storage/header/origin.rs +++ b/modules/core/src/storage/header/origin.rs @@ -1,6 +1,6 @@ //! Defines the origin of the model in the file. use serde::{Serialize, Deserialize}; -use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; use super::string_value::StringValue; @@ -40,12 +40,12 @@ impl OriginValue { /// /// # Returns /// A new `OriginValue`. - pub fn from_string(origin: String) -> Result { + pub fn from_string(origin: String) -> Result { match origin.to_lowercase().as_str() { LOCAL => Ok(OriginValue::Local(StringValue::from_string(origin))), SURREAL_DB => Ok(OriginValue::SurrealDb(StringValue::from_string(origin))), NONE => Ok(OriginValue::None(StringValue::from_string(origin))), - _ => Err(SurrealError::new(format!("invalid origin: {}", origin), SurrealErrorStatus::BadRequest)) + _ => Err(NanoServiceError::new(format!("invalid origin: {}", origin), NanoServiceErrorStatus::BadRequest)) } } @@ -100,7 +100,7 @@ impl Origin { /// Adds an origin to the origin struct. /// /// # Arguments - pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> { + pub fn add_origin(&mut self, origin: String) -> Result<(), NanoServiceError> { self.origin = OriginValue::from_string(origin)?; Ok(()) } @@ -123,7 +123,7 @@ impl Origin { /// /// # Returns /// A new origin. - pub fn from_string(origin: String) -> Result { + pub fn from_string(origin: String) -> Result { if origin == "".to_string() { return Ok(Origin::fresh()); } diff --git a/modules/core/src/storage/header/output.rs b/modules/core/src/storage/header/output.rs index e8aa6a8..0bce51d 100644 --- a/modules/core/src/storage/header/output.rs +++ b/modules/core/src/storage/header/output.rs @@ -1,13 +1,8 @@ //! Defines the struct housing data around the outputs of the model. use serde::{Serialize, Deserialize}; use super::normalisers::wrapper::NormaliserType; -use crate::{ - safe_eject_option, - errors::error::{ - SurrealError, - SurrealErrorStatus - } -}; +use crate::safe_eject_option; +use nanoservices_utils::errors::{NanoServiceError, NanoServiceErrorStatus}; /// Houses data around the outputs of the model. @@ -84,7 +79,7 @@ impl Output { /// /// # Returns /// * `Output` - The string as an instance of the Output struct. - pub fn from_string(data: String) -> Result { + pub fn from_string(data: String) -> Result { if data.contains("=>") == false { return Ok(Output::fresh()) } diff --git a/modules/core/src/storage/header/version.rs b/modules/core/src/storage/header/version.rs index 6a1a8c9..56c0777 100644 --- a/modules/core/src/storage/header/version.rs +++ b/modules/core/src/storage/header/version.rs @@ -1,12 +1,9 @@ //! Defines the process of managing the version of the `surml` file in the file. use serde::{Serialize, Deserialize}; -use crate::{ - safe_eject_option, +use crate::safe_eject_option; +use nanoservices_utils::{ safe_eject, - errors::error::{ - SurrealError, - SurrealErrorStatus - } + errors::{NanoServiceError, NanoServiceErrorStatus} }; @@ -56,7 +53,7 @@ impl Version { /// /// # Returns /// A new `Version` struct. - pub fn from_string(version: String) -> Result { + pub fn from_string(version: String) -> Result { if version == "".to_string() { return Ok(Version::fresh()) } @@ -66,9 +63,9 @@ impl Version { let three_str = safe_eject_option!(split.next()); Ok(Version { - one: safe_eject!(one_str.parse::(), SurrealErrorStatus::BadRequest), - two: safe_eject!(two_str.parse::(), SurrealErrorStatus::BadRequest), - three: safe_eject!(three_str.parse::(), SurrealErrorStatus::BadRequest), + one: safe_eject!(one_str.parse::(), NanoServiceErrorStatus::BadRequest)?, + two: safe_eject!(two_str.parse::(), NanoServiceErrorStatus::BadRequest)?, + three: safe_eject!(three_str.parse::(), NanoServiceErrorStatus::BadRequest)?, }) } diff --git a/modules/core/src/storage/stream_adapter.rs b/modules/core/src/storage/stream_adapter.rs index a5ad91a..ede0c82 100644 --- a/modules/core/src/storage/stream_adapter.rs +++ b/modules/core/src/storage/stream_adapter.rs @@ -7,11 +7,11 @@ use futures_core::stream::Stream; use futures_core::task::{Context, Poll}; use std::pin::Pin; use std::error::Error; -use crate::{ +use nanoservices_utils::{ safe_eject, - errors::error::{ - SurrealError, - SurrealErrorStatus + errors::{ + NanoServiceError, + NanoServiceErrorStatus } }; @@ -36,8 +36,8 @@ impl StreamAdapter { /// /// # Returns /// A new `StreamAdapter` struct. - pub fn new(chunk_size: usize, file_path: String) -> Result { - let file_pointer = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound); + pub fn new(chunk_size: usize, file_path: String) -> Result { + let file_pointer = safe_eject!(File::open(file_path), NanoServiceErrorStatus::NotFound)?; Ok(StreamAdapter { chunk_size, file_pointer diff --git a/modules/core/src/storage/surml_file.rs b/modules/core/src/storage/surml_file.rs index aa056ca..bbe8abd 100644 --- a/modules/core/src/storage/surml_file.rs +++ b/modules/core/src/storage/surml_file.rs @@ -1,14 +1,14 @@ //! Defines the saving and loading of the entire `surml` file. use std::fs::File; use std::io::{Read, Write}; +use super::header::Header; -use crate::{ - safe_eject_internal, +use crate::safe_eject_internal; +use nanoservices_utils::{ safe_eject, - storage::header::Header, - errors::error::{ - SurrealError, - SurrealErrorStatus + errors::{ + NanoServiceError, + NanoServiceErrorStatus } }; @@ -62,13 +62,13 @@ impl SurMlFile { /// /// # Returns /// A new `SurMlFile` struct. - pub fn from_bytes(bytes: Vec) -> Result { + pub fn from_bytes(bytes: Vec) -> Result { // check to see if there is enough bytes to read if bytes.len() < 4 { return Err( - SurrealError::new( + NanoServiceError::new( "Not enough bytes to read".to_string(), - SurrealErrorStatus::BadRequest + NanoServiceErrorStatus::BadRequest ) ); } @@ -83,9 +83,9 @@ impl SurMlFile { // check to see if there is enough bytes to read if bytes.len() < (4 + integer_value as usize) { return Err( - SurrealError::new( + NanoServiceError::new( "Not enough bytes to read for header, maybe the file format is not correct".to_string(), - SurrealErrorStatus::BadRequest + NanoServiceErrorStatus::BadRequest ) ); } @@ -112,23 +112,23 @@ impl SurMlFile { /// /// # Returns /// A new `SurMlFile` struct. - pub fn from_file(file_path: &str) -> Result { - let mut file = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound); + pub fn from_file(file_path: &str) -> Result { + let mut file = safe_eject!(File::open(file_path), NanoServiceErrorStatus::NotFound)?; // extract the first 4 bytes as an integer to get the length of the header let mut buffer = [0u8; 4]; - safe_eject!(file.read_exact(&mut buffer), SurrealErrorStatus::BadRequest); + safe_eject!(file.read_exact(&mut buffer), NanoServiceErrorStatus::BadRequest)?; let integer_value = u32::from_be_bytes(buffer); // Read the next integer_value bytes for the header let mut header_buffer = vec![0u8; integer_value as usize]; - safe_eject!(file.read_exact(&mut header_buffer), SurrealErrorStatus::BadRequest); + safe_eject!(file.read_exact(&mut header_buffer), NanoServiceErrorStatus::BadRequest)?; // Create a Vec to store the data let mut model_buffer = Vec::new(); // Read the rest of the file into the buffer - safe_eject!(file.take(usize::MAX as u64).read_to_end(&mut model_buffer), SurrealErrorStatus::BadRequest); + safe_eject!(file.take(usize::MAX as u64).read_to_end(&mut model_buffer), NanoServiceErrorStatus::BadRequest)?; // construct the header and C model from the bytes let header = Header::from_bytes(header_buffer)?; @@ -162,7 +162,7 @@ impl SurMlFile { /// /// # Returns /// An `io::Result` indicating whether the write was successful. - pub fn write(&self, file_path: &str) -> Result<(), SurrealError> { + pub fn write(&self, file_path: &str) -> Result<(), NanoServiceError> { let combined_vec = self.to_bytes(); // write the bytes to a file @@ -219,7 +219,7 @@ mod tests { match SurMlFile::from_bytes(bytes) { Ok(_) => assert!(false), Err(error) => { - assert_eq!(error.status, SurrealErrorStatus::BadRequest); + assert_eq!(error.status, NanoServiceErrorStatus::BadRequest); assert_eq!(error.to_string(), "Not enough bytes to read"); } }