diff --git a/.gitignore b/.gitignore index a147089..0d41899 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target/ -env/ \ No newline at end of file +env/ +*.lock \ No newline at end of file diff --git a/workflow-with-state-management-poc/runtime/Cargo.toml b/workflow-with-state-management-poc/runtime/Cargo.toml new file mode 100644 index 0000000..3c8896e --- /dev/null +++ b/workflow-with-state-management-poc/runtime/Cargo.toml @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +[package] +name = "test_util" +version = "0.1.0" +edition = "2021" + +[dependencies] +libc = "0.2.49" +serde = { version = "1.0.81", features = ["derive"] } +serde_json = "1.0.81" +serde_derive = "1.0.81" +anyhow = "1.0.56" +wasmtime = "0.36.0" +wasmtime-wasi = "0.36.0" +wasi-common = "0.36.0" +bytes = "1" +futures = "0.3" +http = "0.2" +reqwest = { version = "0.11", default-features = true, features = [ + "json", + "blocking", +] } +thiserror = "1.0" +tokio = { version = "1.4.0", features = ["full"] } +tracing = { version = "0.1", features = ["log"] } +url = "2.2.1" +openssl = { version = "0.10", features = ["vendored"] } +openwhisk-client-rust = { git = "https://github.com/HugoByte/openwhisk-client-rust.git", branch = "master" } +wiremock = "0.5.17" +async-std = { version = "1.12.0", features = ["attributes"] } +dyn-clone = "1.0.7" +cached = { version = "0.49.2", features = [ "redis_store"]} +sha256 = "1.5.0" \ No newline at end of file diff --git a/workflow-with-state-management-poc/runtime/src/helper.rs b/workflow-with-state-management-poc/runtime/src/helper.rs new file mode 100644 index 0000000..873d0f5 --- /dev/null +++ b/workflow-with-state-management-poc/runtime/src/helper.rs @@ -0,0 +1,184 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use wiremock::{ + matchers::{method, path}, + Mock, MockServer, ResponseTemplate, +}; + +async fn create_server(add: &str) -> MockServer { + let listener = std::net::TcpListener::bind(add).unwrap(); + MockServer::builder().listener(listener).start().await +} + +pub async fn post(address: &str) -> MockServer { + let server = create_server(address).await; + + let mut r = HashMap::new(); + r.insert( + "maruthi".to_string(), + vec!["800".to_string(), "alto".to_string()], + ); + + let res = Cartype { + car_company_list: r, + }; + + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/cartype")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + let res = ModelAvail { + models: vec!["800".to_string(), "alto".to_string()], + }; + + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/modelavail")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + let mut r = HashMap::new(); + r.insert("800".to_string(), 1200000); + r.insert("alto".to_string(), 1800000); + + // let asd = serde_json::json!({ + // "result" : r + // }); + + let res = ModelPrice { + model_price_list: r, + }; + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/modelsprice")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + let res = Purchase { + message: String::from("Thank you for the purchase"), + }; + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/purchase")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + let res = EmplyeeIds { + ids: vec![1, 2, 3, 4, 5], + }; + + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/employee_ids")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + let res = GetSalary { salary: 10000000 }; + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/getsalaries")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + let res = GetAddress { + address: "HugoByte".to_string(), + }; + + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/getaddress")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + let res = vec!["Salary creditted for emp id 1 from Hugobyte "]; + + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/salary")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + let res = serde_json::json!({ + "company" : "Hugobyte-Ai-Labs", + "company_reg_id": "851-Hugobyte-Ai-Labs", + }); + + Mock::given(method("POST")) + .and(path("/api/v1/namespaces/guest/actions/get_company_name")) + .respond_with( + ResponseTemplate::new(200) + .insert_header("Content-Type", "application/json") + .set_body_json(res), + ) + .mount(&server) + .await; + + server +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Cartype { + car_company_list: HashMap>, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ModelAvail { + models: Vec, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ModelPrice { + model_price_list: HashMap, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Purchase { + message: String, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct EmplyeeIds { + ids: Vec, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct GetSalary { + salary: i32, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct GetAddress { + address: String, +} diff --git a/workflow-with-state-management-poc/runtime/src/lib.rs b/workflow-with-state-management-poc/runtime/src/lib.rs new file mode 100644 index 0000000..851be41 --- /dev/null +++ b/workflow-with-state-management-poc/runtime/src/lib.rs @@ -0,0 +1,243 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +mod types; +#[cfg(test)] +mod wasi_http; + +pub use types::*; +pub mod helper; +pub use helper::*; + +#[cfg(test)] +mod tests { + use super::*; + use cached::IOCached; + use cached::RedisCache; + use sha256::digest; + use std::collections::HashMap; + use std::convert::TryInto; + use std::{ + fs, + sync::{Arc, Mutex}, + }; + use wasi_common::WasiCtx; + use wasi_http::HttpCtx; + use wasmtime::Linker; + use wasmtime::*; + use wasmtime_wasi::sync::WasiCtxBuilder; + + #[allow(dead_code)] + + fn run_workflow(data: Value, path: String) -> (Output, Vec) { + let redis_cache: RedisCache)> = + RedisCache::new("workflow".to_string(), 30) + .set_connection_string("redis://127.0.0.1:6379") + .set_refresh(true) + // .set_namespace("workflows") + .build() + .unwrap(); + + let key = digest(format!("{:?}{:?}", data, path)); + let output = redis_cache.cache_get(&key).unwrap(); + if output.is_some() { + println!("workflow cache hit!"); + return output.unwrap(); + } + + let wasm_file = fs::read(path).unwrap(); + let input: MainInput = serde_json::from_value(data).unwrap(); + let engine = Engine::default(); + let mut linker = Linker::new(&engine); + + let output: Arc> = Arc::new(Mutex::new(Output { + result: serde_json::json!({}), + })); + + let output_ = output.clone(); + let buf = serde_json::to_vec(&input).expect("should serialize"); + let mem_size: i32 = buf.len() as i32; + + linker + .func_wrap("host", "get_input_size", move || -> i32 { mem_size }) + .expect("should define the function"); + + linker + .func_wrap( + "host", + "set_output", + move |mut caller: Caller<'_, WasiCtx>, ptr: i32, capacity: i32| { + let output = output_.clone(); + let mem = match caller.get_export("memory") { + Some(Extern::Memory(mem)) => mem, + _ => return Err(Trap::new("failed to find host memory")), + }; + let offset = ptr as u32 as usize; + let mut buffer: Vec = vec![0; capacity as usize]; + match mem.read(&caller, offset, &mut buffer) { + Ok(()) => match serde_json::from_slice::(&buffer) { + Ok(serialized_output) => { + let mut output = output.lock().unwrap(); + *output = serialized_output; + Ok(()) + } + Err(err) => { + let msg = format!("failed to serialize host memory: {}", err); + Err(Trap::new(msg)) + } + }, + _ => Err(Trap::new("failed to read host memory")), + } + }, + ) + .expect("should define the function"); + + let output_2: Arc>> = Arc::new(Mutex::new(Vec::new())); + + let output_ = output_2.clone(); + + linker + .func_wrap( + "host", + "set_state", + move |mut caller: Caller<'_, WasiCtx>, ptr: i32, capacity: i32| { + let output_2 = output_.clone(); + let mem = match caller.get_export("memory") { + Some(Extern::Memory(mem)) => mem, + _ => return Err(Trap::new("failed to find host memory")), + }; + let offset = ptr as u32 as usize; + let mut buffer: Vec = vec![0; capacity as usize]; + match mem.read(&caller, offset, &mut buffer) { + Ok(()) => match serde_json::from_slice::(&buffer) { + Ok(serialized_output) => { + let mut output_2 = output_2.lock().unwrap(); + output_2.push(serialized_output); + Ok(()) + } + Err(err) => { + let msg = format!("failed to serialize host memory: {}", err); + Err(Trap::new(msg)) + } + }, + _ => Err(Trap::new("failed to read host memory")), + } + }, + ) + .expect("should define the function"); + + wasmtime_wasi::add_to_linker(&mut linker, |s| s).unwrap(); + let wasi = WasiCtxBuilder::new() + .inherit_stdio() + .inherit_args() + .unwrap() + .build(); + let mut store = Store::new(&engine, wasi); + let module = Module::from_binary(&engine, &wasm_file).unwrap(); + let max_concurrent_requests = Some(42); + + let http = HttpCtx::new(input.allowed_hosts, max_concurrent_requests).unwrap(); + http.add_to_linker(&mut linker).unwrap(); + + let linking = linker.instantiate(&mut store, &module).unwrap(); + + let malloc = linking + .get_typed_func::<(i32, i32), i32, _>(&mut store, "memory_alloc") + .unwrap(); + let data = serde_json::to_vec(&input.data).unwrap(); + let data_ptr = malloc.call(&mut store, (data.len() as i32, 2)).unwrap(); + + let memory = linking.get_memory(&mut store, "memory").unwrap(); + memory.write(&mut store, data_ptr as usize, &data).unwrap(); + let len: i32 = data.len().try_into().unwrap(); + let run = linking + .get_typed_func::<(i32, i32), (), _>(&mut store, "_start") + .unwrap(); + + let _result_from_wasm = run.call(&mut store, (data_ptr, len)); + let malloc = linking + .get_typed_func::<(i32, i32, i32), (), _>(&mut store, "free_memory") + .unwrap(); + malloc + .call(&mut store, (data_ptr, data.len() as i32, 2)) + .unwrap(); + + let state_output = output_2.lock().unwrap().clone(); + let res = output.lock().unwrap().clone(); + + if !res.is_err() { + // caching the results + println!("workflow cache set!"); + redis_cache + .cache_set(key, (res.clone(), state_output.clone())) + .unwrap(); + } else { + println!("workflow cache miss!"); + } + + (res, state_output) + } + + #[async_std::test] + async fn test_employee_salary_with_concat_operator() { + let path = std::env::var("WORKFLOW_WASM").unwrap_or(format!( + "../state-managed-workflow/target/wasm32-wasi/release/boilerplate.wasm" // "/Users/ajaykumar/Downloads/Github/Hugobyte/testing/runtime_dev/sled_db_examples/boilerplate-copy.wasm" + )); + let server = post("127.0.0.1:1234").await; + let input = serde_json::json!({ + "allowed_hosts": [ + server.uri() + ], + "data": { + "role":"Software Developer", + } + }); + + let (result, state_data) = run_workflow(input, path); + + // println!("State_data => {:#?}", state_data); + + let mut outputs: HashMap = HashMap::new(); + + for sd in state_data { + if sd.is_success() { + outputs.insert(sd.get_action_name(), sd.get_output()); + } + } + + println!("Outputs => {:#?}", outputs); + println!("Result => {:#?}", result); + + assert!(result + .result + .to_string() + .contains("Salary creditted for emp id 1 from Hugobyte")) + } + + #[async_std::test] + async fn test_car_market_place() { + let path = std::env::var("WORKFLOW_WASM").unwrap_or(format!( + "../state-managed-workflow/target/wasm32-wasi/release/boilerplate.wasm" + )); + + let server = post("127.0.0.1:8080").await; + let input = serde_json::json!({ + "allowed_hosts": [ + server.uri() + ], + "data": { + "car_type":"hatchback", + "company_name":"maruthi", + "model_name":"alto", + "price":1200000 + } + }); + let (result, _state_data) = run_workflow(input, path); + + println!("State_data => {:#?}", _state_data); + + assert!(result + .result + .to_string() + .contains("Thank you for the purchase")) + } +} diff --git a/workflow-with-state-management-poc/runtime/src/types.rs b/workflow-with-state-management-poc/runtime/src/types.rs new file mode 100644 index 0000000..65adab7 --- /dev/null +++ b/workflow-with-state-management-poc/runtime/src/types.rs @@ -0,0 +1,62 @@ +use super::*; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub enum ExecutionState { + // Init, + Running, + Aborted, + Success, +} + +impl Default for ExecutionState { + fn default() -> Self { + ExecutionState::Running + } +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +#[serde(rename= "StateManager")] +pub struct StateData { + // execution_state: ExecutionState, // to represent the task life cycle + action_name: String, // task name + task_index: isize, // n'th task out of m tasks + execution_state: ExecutionState, + output: Option, + error: Option, // to define the error kind +} + +impl StateData{ + pub fn is_success(&self) -> bool { + self.execution_state == ExecutionState::Success && self.task_index > -1 + } + + pub fn get_output(&self) -> Value { + self.output.clone().into() + } + + pub fn get_action_name(&self) -> String { + self.action_name.clone() + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Output { + pub result: Value, +} + +impl Output{ + pub fn is_err(&self) -> bool{ + self.result.get("Err").is_some() + } +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct Resultss { + pub result: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MainInput { + pub allowed_hosts: Option>, + pub data: Value, +} \ No newline at end of file diff --git a/workflow-with-state-management-poc/runtime/src/wasi_http.rs b/workflow-with-state-management-poc/runtime/src/wasi_http.rs new file mode 100644 index 0000000..2dd963b --- /dev/null +++ b/workflow-with-state-management-poc/runtime/src/wasi_http.rs @@ -0,0 +1,727 @@ +use anyhow::Error; +use bytes::Bytes; +use futures::executor::block_on; +use http::{header::HeaderName, HeaderMap, HeaderValue}; +use reqwest::{Client, Method}; +use std::{ + collections::HashMap, + str::FromStr, + sync::{Arc, PoisonError, RwLock}, +}; +use tokio::runtime::Handle; +use url::Url; +use wasmtime::*; + +const MEMORY: &str = "memory"; + +pub type WasiHttpHandle = u32; + +/// Response body for HTTP requests, consumed by guest modules. +struct Body { + bytes: Bytes, + pos: usize, +} + +/// An HTTP response abstraction that is persisted across multiple +/// host calls. +struct Response { + headers: HeaderMap, + body: Body, +} + +/// Host state for the responses of the instance. +#[derive(Default)] +struct State { + responses: HashMap, + current_handle: WasiHttpHandle, +} + +#[derive(Debug, thiserror::Error)] +enum HttpError { + #[error("Invalid handle: [{0}]")] + InvalidHandle(WasiHttpHandle), + #[error("Memory not found")] + MemoryNotFound, + #[error("Memory access error")] + MemoryAccessError(#[from] wasmtime::MemoryAccessError), + #[error("Buffer too small")] + BufferTooSmall, + #[error("Header not found")] + HeaderNotFound, + #[error("UTF-8 error")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("Destination not allowed")] + DestinationNotAllowed(String), + #[error("Invalid method")] + InvalidMethod, + #[error("Invalid encoding")] + InvalidEncoding, + #[error("Invalid URL")] + InvalidUrl, + #[error("HTTP error")] + RequestError(#[from] reqwest::Error), + #[error("Runtime error")] + RuntimeError, + #[error("Too many sessions")] + TooManySessions, +} + +impl From for u32 { + fn from(e: HttpError) -> u32 { + match e { + HttpError::InvalidHandle(_) => 1, + HttpError::MemoryNotFound => 2, + HttpError::MemoryAccessError(_) => 3, + HttpError::BufferTooSmall => 4, + HttpError::HeaderNotFound => 5, + HttpError::Utf8Error(_) => 6, + HttpError::DestinationNotAllowed(_) => 7, + HttpError::InvalidMethod => 8, + HttpError::InvalidEncoding => 9, + HttpError::InvalidUrl => 10, + HttpError::RequestError(_) => 11, + HttpError::RuntimeError => 12, + HttpError::TooManySessions => 13, + } + } +} + +impl From>> for HttpError { + fn from(_: PoisonError>) -> Self { + HttpError::RuntimeError + } +} + +impl From>> for HttpError { + fn from(_: PoisonError>) -> Self { + HttpError::RuntimeError + } +} + +impl From> for HttpError { + fn from(_: PoisonError<&mut State>) -> Self { + HttpError::RuntimeError + } +} + +struct HostCalls; + +impl HostCalls { + /// Remove the current handle from the state. + /// Depending on the implementation, guest modules might + /// have to manually call `close`. + // TODO (@radu-matei) + // Fix the clippy warning. + #[allow(clippy::unnecessary_wraps)] + fn close(st: Arc>, handle: WasiHttpHandle) -> Result<(), HttpError> { + let mut st = st.write()?; + st.responses.remove(&handle); + Ok(()) + } + + /// Read `buf_len` bytes from the response of `handle` and + /// write them into `buf_ptr`. + fn body_read( + st: Arc>, + memory: Memory, + mut store: impl AsContextMut, + handle: WasiHttpHandle, + buf_ptr: u32, + buf_len: u32, + buf_read_ptr: u32, + ) -> Result<(), HttpError> { + let mut st = st.write()?; + + let body = &mut st.responses.get_mut(&handle).unwrap().body; + let mut context = store.as_context_mut(); + + // Write at most either the remaining of the response body, or the entire + // length requested by the guest. + let available = std::cmp::min(buf_len as _, body.bytes.len() - body.pos); + memory.write( + &mut context, + buf_ptr as _, + &body.bytes[body.pos..body.pos + available], + )?; + body.pos += available; + // Write the number of bytes written back to the guest. + memory.write( + &mut context, + buf_read_ptr as _, + &(available as u32).to_le_bytes(), + )?; + Ok(()) + } + + /// Get a response header value given a key. + #[allow(clippy::too_many_arguments)] + fn header_get( + st: Arc>, + memory: Memory, + mut store: impl AsContextMut, + handle: WasiHttpHandle, + name_ptr: u32, + name_len: u32, + value_ptr: u32, + value_len: u32, + value_written_ptr: u32, + ) -> Result<(), HttpError> { + let st = st.read()?; + + // Get the current response headers. + let headers = &st + .responses + .get(&handle) + .ok_or(HttpError::InvalidHandle(handle))? + .headers; + + let mut store = store.as_context_mut(); + + // Read the header key from the module's memory. + let key = string_from_memory(&memory, &mut store, name_ptr, name_len)?.to_ascii_lowercase(); + // Attempt to get the corresponding value from the resposne headers. + let value = headers.get(key).ok_or(HttpError::HeaderNotFound)?; + if value.len() > value_len as _ { + return Err(HttpError::BufferTooSmall); + } + // Write the header value and its length. + memory.write(&mut store, value_ptr as _, value.as_bytes())?; + memory.write( + &mut store, + value_written_ptr as _, + &(value.len() as u32).to_le_bytes(), + )?; + Ok(()) + } + + fn headers_get_all( + st: Arc>, + memory: Memory, + mut store: impl AsContextMut, + handle: WasiHttpHandle, + buf_ptr: u32, + buf_len: u32, + buf_written_ptr: u32, + ) -> Result<(), HttpError> { + let st = st.read()?; + + let headers = &st + .responses + .get(&handle) + .ok_or(HttpError::InvalidHandle(handle))? + .headers; + + let headers = match header_map_to_string(headers) { + Ok(res) => res, + Err(_) => return Err(HttpError::RuntimeError), + }; + + if headers.len() > buf_len as _ { + return Err(HttpError::BufferTooSmall); + } + + let mut store = store.as_context_mut(); + + memory.write(&mut store, buf_ptr as _, headers.as_bytes())?; + memory.write( + &mut store, + buf_written_ptr as _, + &(headers.len() as u32).to_le_bytes(), + )?; + Ok(()) + } + + /// Execute a request for a guest module, given + /// the request data. + #[allow(clippy::too_many_arguments)] + fn req( + st: Arc>, + allowed_hosts: Option<&[String]>, + max_concurrent_requests: Option, + memory: Memory, + mut store: impl AsContextMut, + url_ptr: u32, + url_len: u32, + method_ptr: u32, + method_len: u32, + req_headers_ptr: u32, + req_headers_len: u32, + req_body_ptr: u32, + req_body_len: u32, + status_code_ptr: u32, + res_handle_ptr: u32, + ) -> Result<(), HttpError> { + let span = tracing::trace_span!("req"); + let _enter = span.enter(); + + let mut st = st.write()?; + if let Some(max) = max_concurrent_requests { + if st.responses.len() > (max - 1) as usize { + return Err(HttpError::TooManySessions); + } + }; + + let mut store = store.as_context_mut(); + + // Read the request parts from the module's linear memory and check early if + // the guest is allowed to make a request to the given URL. + let url = string_from_memory(&memory, &mut store, url_ptr, url_len)?; + if !is_allowed(url.as_str(), allowed_hosts)? { + return Err(HttpError::DestinationNotAllowed(url)); + } + + let method = Method::from_str( + string_from_memory(&memory, &mut store, method_ptr, method_len)?.as_str(), + ) + .map_err(|_| HttpError::InvalidMethod)?; + let req_body = slice_from_memory(&memory, &mut store, req_body_ptr, req_body_len)?; + let headers = string_to_header_map( + string_from_memory(&memory, &mut store, req_headers_ptr, req_headers_len)?.as_str(), + ) + .map_err(|_| HttpError::InvalidEncoding)?; + + // Send the request. + let (status, resp_headers, resp_body) = + request(url.as_str(), headers, method, req_body.as_slice())?; + tracing::debug!( + status, + ?resp_headers, + body_len = resp_body.as_ref().len(), + "got HTTP response, writing back to memory" + ); + + // Write the status code to the guest. + memory.write(&mut store, status_code_ptr as _, &status.to_le_bytes())?; + + // Construct the response, add it to the current state, and write + // the handle to the guest. + let response = Response { + headers: resp_headers, + body: Body { + bytes: resp_body, + pos: 0, + }, + }; + + let initial_handle = st.current_handle; + while st.responses.get(&st.current_handle).is_some() { + st.current_handle += 1; + if st.current_handle == initial_handle { + return Err(HttpError::TooManySessions); + } + } + let handle = st.current_handle; + st.responses.insert(handle, response); + memory.write(&mut store, res_handle_ptr as _, &handle.to_le_bytes())?; + + Ok(()) + } +} + +/// Experimental HTTP extension object for Wasmtime. +pub struct HttpCtx { + state: Arc>, + allowed_hosts: Arc>>, + max_concurrent_requests: Option, +} + +impl HttpCtx { + pub const MODULE: &'static str = "wasi_experimental_http"; + pub fn new( + allowed_hosts: Option>, + max_concurrent_requests: Option, + ) -> Result { + let state = Arc::new(RwLock::new(State::default())); + let allowed_hosts = Arc::new(allowed_hosts); + Ok(HttpCtx { + state, + allowed_hosts, + max_concurrent_requests, + }) + } + + pub fn add_to_linker(&self, linker: &mut Linker) -> Result<(), Error> { + let st = self.state.clone(); + linker.func_wrap( + Self::MODULE, + "close", + move |handle: WasiHttpHandle| -> u32 { + match HostCalls::close(st.clone(), handle) { + Ok(()) => 0, + Err(e) => e.into(), + } + }, + )?; + + let st = self.state.clone(); + linker.func_wrap( + Self::MODULE, + "body_read", + move |mut caller: Caller<'_, T>, + handle: WasiHttpHandle, + buf_ptr: u32, + buf_len: u32, + buf_read_ptr: u32| + -> u32 { + let memory = match memory_get(&mut caller) { + Ok(m) => m, + Err(e) => return e.into(), + }; + + let ctx = caller.as_context_mut(); + + match HostCalls::body_read( + st.clone(), + memory, + ctx, + handle, + buf_ptr, + buf_len, + buf_read_ptr, + ) { + Ok(()) => 0, + Err(e) => e.into(), + } + }, + )?; + + let st = self.state.clone(); + linker.func_wrap( + Self::MODULE, + "header_get", + move |mut caller: Caller<'_, T>, + handle: WasiHttpHandle, + name_ptr: u32, + name_len: u32, + value_ptr: u32, + value_len: u32, + value_written_ptr: u32| + -> u32 { + let memory = match memory_get(&mut caller) { + Ok(m) => m, + Err(e) => return e.into(), + }; + + let ctx = caller.as_context_mut(); + + match HostCalls::header_get( + st.clone(), + memory, + ctx, + handle, + name_ptr, + name_len, + value_ptr, + value_len, + value_written_ptr, + ) { + Ok(()) => 0, + Err(e) => e.into(), + } + }, + )?; + + let st = self.state.clone(); + linker.func_wrap( + Self::MODULE, + "headers_get_all", + move |mut caller: Caller<'_, T>, + handle: WasiHttpHandle, + buf_ptr: u32, + buf_len: u32, + buf_read_ptr: u32| + -> u32 { + let memory = match memory_get(&mut caller) { + Ok(m) => m, + Err(e) => return e.into(), + }; + + let ctx = caller.as_context_mut(); + + match HostCalls::headers_get_all( + st.clone(), + memory, + ctx, + handle, + buf_ptr, + buf_len, + buf_read_ptr, + ) { + Ok(()) => 0, + Err(e) => e.into(), + } + }, + )?; + + let st = self.state.clone(); + let allowed_hosts = self.allowed_hosts.clone(); + let max_concurrent_requests = self.max_concurrent_requests; + linker.func_wrap( + Self::MODULE, + "req", + move |mut caller: Caller<'_, T>, + url_ptr: u32, + url_len: u32, + method_ptr: u32, + method_len: u32, + req_headers_ptr: u32, + req_headers_len: u32, + req_body_ptr: u32, + req_body_len: u32, + status_code_ptr: u32, + res_handle_ptr: u32| + -> u32 { + let memory = match memory_get(&mut caller) { + Ok(m) => m, + Err(e) => return e.into(), + }; + + let ctx = caller.as_context_mut(); + + match HostCalls::req( + st.clone(), + allowed_hosts.as_deref(), + max_concurrent_requests, + memory, + ctx, + url_ptr, + url_len, + method_ptr, + method_len, + req_headers_ptr, + req_headers_len, + req_body_ptr, + req_body_len, + status_code_ptr, + res_handle_ptr, + ) { + Ok(()) => 0, + Err(e) => e.into(), + } + }, + )?; + + Ok(()) + } +} + +#[tracing::instrument] +fn request( + url: &str, + headers: HeaderMap, + method: Method, + body: &[u8], +) -> Result<(u16, HeaderMap, Bytes), HttpError> { + tracing::debug!( + %url, + ?headers, + ?method, + body_len = body.len(), + "performing request" + ); + let url: Url = url.parse().map_err(|_| HttpError::InvalidUrl)?; + let body = body.to_vec(); + match Handle::try_current() { + Ok(r) => { + // If running in a Tokio runtime, spawn a new blocking executor + // that will send the HTTP request, and block on its execution. + // This attempts to avoid any deadlocks from other operations + // already executing on the same executor (compared with just + // blocking on the current one). + // + // This should only be a temporary workaround, until we take + // advantage of async functions in Wasmtime. + tracing::trace!("tokio runtime available, spawning request on tokio thread"); + block_on(r.spawn_blocking(move || { + let mut client = Client::builder(); + let headr = headers.get("Upgrade-Insecure-Requests").unwrap(); + + if headr.to_str().unwrap() == "1" { + client = client.danger_accept_invalid_certs(true); + } else { + client = client.danger_accept_invalid_certs(false); + } + + let res = block_on( + client + .build() + .unwrap() + .request(method, url) + .headers(headers) + .body(body) + .send(), + )?; + Ok(( + res.status().as_u16(), + res.headers().clone(), + block_on(res.bytes())?, + )) + })) + .map_err(|_| HttpError::RuntimeError)? + } + Err(_) => { + tracing::trace!("no tokio runtime available, using blocking request"); + let mut client = reqwest::blocking::Client::builder(); + let headr = headers.get("Upgrade-Insecure-Requests").unwrap(); + if headr.to_str().unwrap() == "1" { + client = client.danger_accept_invalid_certs(true); + } else { + client = client.danger_accept_invalid_certs(false); + } + + let res = client + .build() + .unwrap() + .request(method, url) + .headers(headers) + .body(body) + .send()?; + return Ok((res.status().as_u16(), res.headers().clone(), res.bytes()?)); + } + } +} + +/// Get the exported memory block called `memory`. +/// This will return an `HttpError::MemoryNotFound` if the module does +/// not export a memory block. +fn memory_get(caller: &mut Caller<'_, T>) -> Result { + if let Some(Extern::Memory(mem)) = caller.get_export(MEMORY) { + Ok(mem) + } else { + Err(HttpError::MemoryNotFound) + } +} + +/// Get a slice of length `len` from `memory`, starting at `offset`. +/// This will return an `HttpError::BufferTooSmall` if the size of the +/// requested slice is larger than the memory size. +fn slice_from_memory( + memory: &Memory, + mut ctx: impl AsContextMut, + offset: u32, + len: u32, +) -> Result, HttpError> { + let required_memory_size = offset.checked_add(len).ok_or(HttpError::BufferTooSmall)? as usize; + + if required_memory_size > memory.data_size(&mut ctx) { + return Err(HttpError::BufferTooSmall); + } + + let mut buf = vec![0u8; len as usize]; + memory.read(&mut ctx, offset as usize, buf.as_mut_slice())?; + Ok(buf) +} + +/// Read a string of byte length `len` from `memory`, starting at `offset`. +fn string_from_memory( + memory: &Memory, + ctx: impl AsContextMut, + offset: u32, + len: u32, +) -> Result { + let slice = slice_from_memory(memory, ctx, offset, len)?; + Ok(std::str::from_utf8(&slice)?.to_string()) +} + +/// Check if guest module is allowed to send request to URL, based on the list of +/// allowed hosts defined by the runtime. +/// If `None` is passed, the guest module is not allowed to send the request. +fn is_allowed(url: &str, allowed_hosts: Option<&[String]>) -> Result { + let url_host = Url::parse(url) + .map_err(|_| HttpError::InvalidUrl)? + .host_str() + .ok_or(HttpError::InvalidUrl)? + .to_owned(); + match allowed_hosts { + Some(domains) => { + let allowed: Result, _> = domains.iter().map(|d| Url::parse(d)).collect(); + let allowed = allowed.map_err(|_| HttpError::InvalidUrl)?; + Ok(allowed + .iter() + .map(|u| u.host_str().unwrap()) + .any(|x| x == url_host.as_str())) + } + None => Ok(false), + } +} + +// The following two functions are copied from the `wasi_experimental_http` +// crate, because the Windows linker apparently cannot handle unresolved +// symbols from a crate, even when the caller does not actually use any of the +// external symbols. +// +// https://github.com/rust-lang/rust/issues/86125 + +/// Decode a header map from a string. +fn string_to_header_map(s: &str) -> Result { + let mut headers = HeaderMap::new(); + for entry in s.lines() { + let mut parts = entry.splitn(2, ':'); + #[allow(clippy::or_fun_call)] + let k = parts.next().ok_or(anyhow::format_err!( + "Invalid serialized header: [{}]", + entry + ))?; + let v = parts.next().unwrap(); + headers.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?); + } + Ok(headers) +} + +/// Encode a header map as a string. +fn header_map_to_string(hm: &HeaderMap) -> Result { + let mut res = String::new(); + for (name, value) in hm + .iter() + .map(|(name, value)| (name.as_str(), std::str::from_utf8(value.as_bytes()))) + { + let value = value?; + anyhow::ensure!( + !name + .chars() + .any(|x| x.is_control() || "(),/:;<=>?@[\\]{}".contains(x)), + "Invalid header name" + ); + anyhow::ensure!( + !value.chars().any(|x| x.is_control()), + "Invalid header value" + ); + res.push_str(&format!("{}:{}\n", name, value)); + } + Ok(res) +} + +#[test] +#[allow(clippy::bool_assert_comparison)] +fn test_allowed_domains() { + let allowed_domains = vec![ + "https://api.brigade.sh".to_string(), + "https://example.com".to_string(), + "http://192.168.0.1".to_string(), + ]; + + assert_eq!( + true, + is_allowed( + "https://api.brigade.sh/healthz", + Some(allowed_domains.as_ref()) + ) + .unwrap() + ); + assert_eq!( + true, + is_allowed( + "https://example.com/some/path/with/more/paths", + Some(allowed_domains.as_ref()) + ) + .unwrap() + ); + assert_eq!( + true, + is_allowed("http://192.168.0.1/login", Some(allowed_domains.as_ref())).unwrap() + ); + assert_eq!( + false, + is_allowed("https://test.brigade.sh", Some(allowed_domains.as_ref())).unwrap() + ); +} diff --git a/workflow-with-state-management-poc/state-managed-workflow/Cargo.toml b/workflow-with-state-management-poc/state-managed-workflow/Cargo.toml new file mode 100644 index 0000000..a861b7a --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/Cargo.toml @@ -0,0 +1,37 @@ + +[package] +name = "boilerplate" +version = "0.0.1" +edition = "2018" + + +[lib] +crate-type = ["cdylib"] + +[profile.release] +lto = true +codegen-units = 1 +overflow-checks = true +# Tell `rustc` to optimize for small code size. +opt-level = "z" +debug = false + +[workspace] + +[dependencies] +derive-enum-from-into = "0.1.1" +serde_derive = "1.0.192" +paste = "1.0.7" +dyn-clone = "1.0.7" +workflow_macro = "0.0.3" +openwhisk-rust = "0.1.2" +serde_json = { version = "1.0", features = ["raw_value"] } +serde = { version = "1.0.192", features = ["derive"] } +codec = { package = "parity-scale-codec", features = [ + "derive", +], version = "3.1.5" } + + +openwhisk_macro = "0.1.6" +# cached = { version = "0.49.2", features = [ "redis_store"]} +# sha256 = "1.5.0" diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/common.rs b/workflow-with-state-management-poc/state-managed-workflow/src/common.rs new file mode 100644 index 0000000..3beda5b --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/common.rs @@ -0,0 +1,285 @@ +#![allow(unused_imports)] +use super::*; +use alloc::task; +use paste::paste; +#[derive(Debug)] +pub struct WorkflowGraph { + edges: Vec<(usize, usize)>, + nodes: Vec>, + pub workflow_id: String, + pub state_manger: StateManager, +} + +impl WorkflowGraph { + pub fn new(size: usize, workflow_id: &str) -> Self { + WorkflowGraph { + nodes: Vec::with_capacity(size), + edges: Vec::new(), + workflow_id: workflow_id.to_string(), + state_manger: StateManager::init(), + } + } +} + +impl WorkflowGraph { + pub fn node_count(&self) -> usize { + self.nodes.len() + } + + pub fn add_node(&mut self, task: Box) -> usize { + let len = self.nodes.len(); + self.nodes.push(task); + len + } + + pub fn add_edge(&mut self, parent: usize, child: usize) { + self.edges.push((parent, child)); + } + + pub fn add_edges(&mut self, edges: &[(usize, usize)]) { + edges + .iter() + .for_each(|(source, destination)| self.add_edge(*source, *destination)); + } + + pub fn get_task(&self, index: usize) -> &Box { + self.nodes.get(index).unwrap() + } + + pub fn get_task_as_mut(&mut self, index: usize) -> &mut Box { + self.nodes.get_mut(index).unwrap() + } + + pub fn node_indices(&self) -> Vec { + (0..self.node_count()).collect::>() + } + + // pub fn init(&mut self) -> Result<&mut Self, String> { + // match self.get_task_as_mut(0).execute() { + // Ok(()) => Ok(self), + // Err(err) => Err(err), + // } + // } + + // pub fn term(&mut self, task_index: Option) -> Result { + + // match task_index { + // Some(index) => { + // let mut list = Vec::new(); + // let edges_list = self.edges.clone(); + // edges_list.iter().for_each(|(source, destination)| { + // if destination == &index { + // list.push(source) + // } + // }); + // match list.len() { + // 0 => { + // let current_task = self.get_task_as_mut(index); + // match current_task.execute() { + // Ok(()) => Ok(current_task.get_task_output()), + // Err(err) => Err(err), + // } + // } + // 1 => { + // let previous_task_output = self.get_task(*list[0]).get_task_output(); + // let current_task = self.get_task_as_mut(index); + // current_task.set_output_to_task(previous_task_output); + // match current_task.execute() { + // Ok(()) => Ok(current_task.get_task_output()), + // Err(err) => Err(err), + // } + // } + // _ => { + // let res: Vec = list + // .iter() + // .map(|index| { + // let previous_task = self.get_task(**index); + // let previous_task_output = previous_task.get_task_output(); + // previous_task_output + // }) + // .collect(); + + // let s: Value = res.into(); + // let current_task = self.get_task_as_mut(index); + // current_task.set_output_to_task(s); + + // match current_task.execute() { + // Ok(()) => Ok(current_task.get_task_output()), + // Err(err) => Err(err), + // } + // } + // } + // } + // None => { + // let len = self.node_count(); + // Ok(self.get_task(len - 1).get_task_output()) + // } + // } + // } + + pub fn pipe(&mut self, task_index: usize) -> Result<&mut Self, String> { + let len = self.nodes.len() - 1; + + // let redis_cache: RedisCache = RedisCache::new(self.workflow_id.clone(), 60*60) + // .set_connection_string("redis://127.0.0.1:6379") + // .set_refresh(true) + // .set_namespace("workflows") + // .build() + // .unwrap(); + + let task = self.get_task(task_index); + // let key = digest(serde_json::to_string(&task.get_json_string()).unwrap()); + let action_name = task.get_action_name(); + + self.state_manger.update_running(&action_name, task_index as isize); + + // let output = redis_cache.cache_get(&key).unwrap(); + // if output.is_some() { + // println!("cache hit for task {}", action_name); + // let output = output.unwrap(); + // let task = self.get_task_as_mut(task_index); + // task.set_output_to_task(output); + // } + + let result = { + + let mut list = Vec::new(); + let edges_list = self.edges.clone(); + edges_list.iter().for_each(|(source, destination)| { + if destination == &task_index { + list.push(source) + } + }); + + match list.len() { + 0 => { + + let task = self.get_task_as_mut(task_index); + match task.execute() { + Ok(()) => Ok(task.get_task_output()), + Err(err) => Err(err), + } + }, + 1 => { + let previous_task_output = self.get_task(*list[0]).get_task_output(); + let current_task = self.get_task_as_mut(task_index); + current_task.set_output_to_task(previous_task_output); + match current_task.execute() { + Ok(()) => Ok(current_task.get_task_output()), + Err(err) => Err(err), + } + } + _ => { + let mut res: Vec = list + .iter() + .map(|index| { + let previous_task = self.get_task(**index); + let previous_task_output = previous_task.get_task_output(); + previous_task_output + }) + .collect(); + + let s: Value = res.into(); + let current_task = self.get_task_as_mut(task_index); + current_task.set_output_to_task(s); + + match current_task.execute() { + Ok(()) => Ok(current_task.get_task_output()), + Err(err) => Err(err), + } + } + } + }; + + // let result = if let 3 = task_index { + // Err("error in task".to_string()) + // }else{ + // result + // }; + + match result { + Ok(output) => { + self.state_manger.update_success(output); + Ok(self) + } + Err(err) => { + self.state_manger.update_err(&err); + Err(err) + } + } + } +} + +#[macro_export] +macro_rules! impl_execute_trait { + ($ ($struct : ty), *) => { + + paste!{$( + impl Execute for $struct { + fn execute(&mut self) -> Result<(),String>{ + self.run() + } + + fn get_task_output(&self) -> Value { + self.output().clone().into() + } + + fn set_output_to_task(&mut self, input: Value) { + self.setter(input) + } + + fn get_action_name(&self) -> String{ + self.action_name.clone() + } + + fn get_json_string(&self) -> String{ + serde_json::to_string(&self).unwrap() + } + + } + )*} + }; +} + +#[allow(dead_code, unused)] +pub fn join_hashmap( + first: HashMap, + second: HashMap, +) -> HashMap { + let mut data: HashMap = HashMap::new(); + for (key, value) in first { + for (s_key, s_value) in &second { + if key.clone() == *s_key { + data.insert(key.clone(), (value.clone(), s_value.clone())); + } + } + } + data +} + +#[no_mangle] +pub unsafe extern "C" fn free_memory(ptr: *mut u8, size: u32, alignment: u32) { + let layout = Layout::from_size_align_unchecked(size as usize, alignment as usize); + alloc::alloc::dealloc(ptr, layout); +} + +#[link(wasm_import_module = "host")] +extern "C" { + pub fn set_output(ptr: i32, size: i32); +} + +#[link(wasm_import_module = "host")] +extern "C" { + pub fn set_state(ptr: i32, size: i32); +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Output { + pub result: Value, +} + +#[no_mangle] +pub unsafe extern "C" fn memory_alloc(size: u32, alignment: u32) -> *mut u8 { + let layout = Layout::from_size_align_unchecked(size as usize, alignment as usize); + alloc::alloc::alloc(layout) +} diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/lib.rs b/workflow-with-state-management-poc/state-managed-workflow/src/lib.rs new file mode 100644 index 0000000..e26d240 --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/lib.rs @@ -0,0 +1,65 @@ +#![allow(unused_imports)] +#![allow(unused_macros)] +#![allow(unused_variables)] +#![allow(dead_code)] +#![allow(forgetting_copy_types)] +#![allow(unused_mut)] +#![allow(unused_must_use)] + +mod common; +mod macros; +mod state_manager; +mod traits; +mod types; + +use state_manager::*; + +use common::*; +use derive_enum_from_into::{EnumFrom, EnumTryInto}; +use dyn_clone::{clone_trait_object, DynClone}; +use macros::*; +use openwhisk_rust::*; +use paste::*; +use serde::{Deserialize, Serialize}; +use serde_json::to_value; +use serde_json::Value; +use std::collections::HashMap; +use std::convert::TryInto; +use std::fmt::Debug; +use traits::*; +use types::*; +use workflow_macro::Flow; +extern crate alloc; +use codec::{Decode, Encode}; +use core::alloc::Layout; +// use cached::RedisCache; +// use cached::IOCached; +// use sha256::digest; + +#[no_mangle] +pub fn _start(ptr: *mut u8, length: i32) { + let result: Value; + unsafe { + let mut vect = Vec::new(); + for i in 1..=length { + if let Some(val_back) = ptr.as_ref() { + vect.push(val_back.clone()); + } + *ptr = *ptr.add(i as usize); + } + result = serde_json::from_slice(&vect).unwrap(); + } + + let res = main(result); + let output = Output { + result: serde_json::to_value(res).unwrap(), + }; + let serialized = serde_json::to_vec(&output).unwrap(); + let size = serialized.len() as i32; + let ptr = serialized.as_ptr(); + std::mem::forget(ptr); + + unsafe { + set_output(ptr as i32, size); + } +} diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/macros.rs b/workflow-with-state-management-poc/state-managed-workflow/src/macros.rs new file mode 100644 index 0000000..c18db7e --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/macros.rs @@ -0,0 +1,175 @@ +use super::*; + +#[macro_export] +macro_rules! make_input_struct { + ( + $x:ident, + [$( + $(#[$default_derive:stmt])? + $visibility:vis $element:ident : $ty:ty),*], + [$($der:ident),*] +) => { + #[derive($($der),*)] + pub struct $x { + $( + $(#[serde(default=$default_derive)])? + $visibility $element: $ty + ),* + } + } +} + +#[macro_export] +macro_rules! make_main_struct { + ( + $name:ident, + $input:ty, + [$($der:ident),*], + [$($key:ident : $val:expr),*], + $output_field: ident +) => { + #[derive($($der),*)] + $( + #[$key = $val] + )* + pub struct $name { + action_name: String, + pub input: $input, + pub output: Value, + pub mapout: Value + } + impl $name{ + pub fn output(&self) -> Value { + self.$output_field.clone() + } + } + } +} + +#[macro_export] +macro_rules! impl_new { + ( + $name:ident, + $input:ident, + [] + ) => { + impl $name{ + pub fn new(action_name:String) -> Self{ + Self{ + action_name, + input: $input{ + ..Default::default() + }, + ..Default::default() + } + } + } + }; + ( + $name:ident, + $input:ident, + [$($element:ident : $ty:ty),*] + ) => { + impl $name{ + pub fn new($( $element: $ty),*, action_name:String) -> Self{ + Self{ + action_name, + input: $input{ + $($element),*, + ..Default::default() + }, + ..Default::default() + } + } + } + } +} + +#[macro_export] +macro_rules! impl_setter { + ( + $name:ty, + [$($element:ident : $key:expr),*] + ) => { + impl $name{ + pub fn setter(&mut self, value: Value) { + $( + let val = value.get($key).unwrap(); + self.input.$element = serde_json::from_value(val.clone()).unwrap(); + )* + } + } + } +} + +#[macro_export] +macro_rules! impl_map_setter { + ( + $name:ty, + $element:ident : $key:expr, + $typ_name : ty, + $out:expr + ) => { + impl $name { + pub fn setter(&mut self, val: Value) { + let value = val.get($key).unwrap(); + let value = serde_json::from_value::>(value.clone()).unwrap(); + let mut map: HashMap<_, _> = value + .iter() + .map(|x| { + self.input.$element = x.to_owned() as $typ_name; + self.run(); + (x.to_owned(), self.output.get($out).unwrap().to_owned()) + }) + .collect(); + self.mapout = to_value(map).unwrap(); + } + } + }; +} + +#[macro_export] +macro_rules! impl_concat_setter { + ( + $name:ty, + $input:ident + ) => { + impl $name { + pub fn setter(&mut self, val: Value) { + let val: Vec = serde_json::from_value(val).unwrap(); + let res = join_hashmap( + serde_json::from_value(val[0].to_owned()).unwrap(), + serde_json::from_value(val[1].to_owned()).unwrap(), + ); + self.input.$input = res; + } + } + }; +} + +#[allow(unused)] +#[macro_export] +macro_rules! impl_combine_setter { + ( + $name:ty, + [$( + $(($value_input:ident))? + $([$index:expr])? + $element:ident : $key:expr),*] + ) => { + impl $name{ + pub fn setter(&mut self, value: Value) { + + let value: Vec = serde_json::from_value(value).unwrap(); + $( + if stringify!($($value_input)*).is_empty(){ + let val = value[$($index)*].get($key).unwrap(); + self.input.$element = serde_json::from_value(val.clone()).unwrap(); + }else{ + self.input.$element = serde_json::from_value(value[$($index)*].to_owned()).unwrap(); + } + )* + } + } + } +} diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/state_manager.rs b/workflow-with-state-management-poc/state-managed-workflow/src/state_manager.rs new file mode 100644 index 0000000..18cdf26 --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/state_manager.rs @@ -0,0 +1,100 @@ +use super::*; +use crate::WorkflowGraph; +use core::default; + +#[derive(Debug, Serialize, Deserialize, Clone)] +enum ExecutionState { + // Init, + Running, + Aborted, + Success, +} + +impl Default for ExecutionState { + fn default() -> Self { + ExecutionState::Running + } +} + +#[derive(Default, Debug)] +pub struct StateManager { + // execution_state: ExecutionState, // to represent the task life cycle + action_name: String, // task name + task_index: isize, // n'th task out of m tasks + execution_state: ExecutionState, + output: Option, + error: Option, // to define the error kind +} + +impl StateManager { + fn update_state_data(&self) { + let state_data: serde_json::Value = serde_json::json!( + { + "action_name": self.action_name, + "task_index": self.task_index, + "execution_state": self.execution_state, + "output": self.output, + "error": self.error + } + ); + + let serialized = serde_json::to_vec(&state_data).unwrap(); + let size = serialized.len() as i32; + let ptr = serialized.as_ptr(); + + std::mem::forget(ptr); + + unsafe { + super::set_state(ptr as i32, size); + } + } + + pub fn init() -> Self { + let state_data = StateManager { + action_name: "Initializing Workflow".to_string(), + execution_state: ExecutionState::Running, + output: None, + task_index: -1, + error: None, + }; + + state_data.update_state_data(); + state_data + } + + pub fn update_workflow_initialized(&mut self) { + self.execution_state = ExecutionState::Success; + self.task_index = -1; + self.error = None; + self.update_state_data(); + } + + // pub fn update(&mut self, action_name: &str, task_index: isize, execution_state: &str, error: Option) { + // self.action_name = action_name.to_string(); + // self.task_index = task_index; + // self.execution_state = execution_state.to_string(); + // self.error = error; + // self.update_state_data(); + // } + + pub fn update_running(&mut self, action_name: &str, task_index: isize) { + self.action_name = action_name.to_string(); + self.task_index = task_index; + self.execution_state = ExecutionState::Running; + self.output = None; + self.update_state_data(); + } + + pub fn update_success(&mut self, output: Value) { + self.output = Some(output); + self.execution_state = ExecutionState::Success; + self.update_state_data(); + } + + pub fn update_err(&mut self, error: &str) { + self.execution_state = ExecutionState::Aborted; + self.error = Some(error.to_string()); + self.update_state_data(); + } + +} diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/traits.rs b/workflow-with-state-management-poc/state-managed-workflow/src/traits.rs new file mode 100644 index 0000000..0af49dc --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/traits.rs @@ -0,0 +1,11 @@ +use super::*; + +pub trait Execute: Debug + DynClone { + fn execute(&mut self) -> Result<(), String>; + fn get_task_output(&self) -> Value; + fn set_output_to_task(&mut self, inp: Value); + fn get_action_name(&self) -> String; + fn get_json_string(&self) -> String; +} + +clone_trait_object!(Execute); diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/types-cmp.rs b/workflow-with-state-management-poc/state-managed-workflow/src/types-cmp.rs new file mode 100644 index 0000000..59e5c82 --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/types-cmp.rs @@ -0,0 +1,122 @@ +use super::*; +use openwhisk_macro::*; +use openwhisk_rust::*; + +make_input_struct!( +ModelavailInput, +[car_company_list:HashMap>,company_name:String], +[Debug, Clone, Default, Serialize, Deserialize] +); +make_input_struct!( +PurchaseInput, +[model_price_list:HashMap,model_name:String,price:i32], +[Debug, Clone, Default, Serialize, Deserialize] +); +make_input_struct!( +CartypeInput, +[car_type:String], +[Debug, Clone, Default, Serialize, Deserialize] +); +make_input_struct!( +ModelspriceInput, +[models:Vec], +[Debug, Clone, Default, Serialize, Deserialize] +); + +make_main_struct!( + Modelavail, + ModelavailInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP",ApiHost:"http://127.0.0.1:8080",Insecure:"true",Namespace:"guest"], + output +); +impl_new!( + Modelavail, + ModelavailInput, + [company_name:String] +); + +make_main_struct!( + Purchase, + PurchaseInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [Namespace:"guest",AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP",Insecure:"true",ApiHost:"http://127.0.0.1:8080"], + output +); +impl_new!( + Purchase, + PurchaseInput, + [model_name:String,price:i32] +); + +make_main_struct!( + Cartype, + CartypeInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP",Namespace:"guest",Insecure:"true",ApiHost:"http://127.0.0.1:8080"], + output +); +impl_new!( + Cartype, + CartypeInput, + [car_type:String] +); + +make_main_struct!( + Modelsprice, + ModelspriceInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP",ApiHost:"http://127.0.0.1:8080",Insecure:"true",Namespace:"guest"], + output +); +impl_new!(Modelsprice, ModelspriceInput, []); + +impl_setter!(Modelavail, [car_company_list:"car_company_list"]); +impl_setter!(Purchase, [model_price_list:"model_price_list"]); +impl_setter!(Cartype, []); +impl_setter!(Modelsprice, [models:"models"]); + +pub fn car_type_fn() -> String { + "tesla".to_string() +} + +make_input_struct!( +Input, +[company_name:String,model_name:String,price:i32,#["car_type_fn"] car_type:String], +[Debug, Clone, Default, Serialize, Deserialize] +); +impl_execute_trait!(Modelavail, Purchase, Cartype, Modelsprice); +#[allow(dead_code, unused)] +pub fn main(args: Value) -> Result { + const LIMIT: usize = 4; + let mut workflow = WorkflowGraph::new(LIMIT); + let input: Input = serde_json::from_value(args).map_err(|e| e.to_string())?; + + let modelavail = Modelavail::new(input.company_name, "modelavail".to_string()); + let purchase = Purchase::new(input.model_name, input.price, "purchase".to_string()); + let cartype = Cartype::new(input.car_type, "cartype".to_string()); + let modelsprice = Modelsprice::new("modelsprice".to_string()); + + let cartype_index = workflow.add_node(Box::new(cartype)); + let modelavail_index = workflow.add_node(Box::new(modelavail)); + let modelsprice_index = workflow.add_node(Box::new(modelsprice)); + let purchase_index = workflow.add_node(Box::new(purchase)); + + workflow.add_edges(&[ + (cartype_index, modelavail_index), + (modelavail_index, modelsprice_index), + (modelsprice_index, purchase_index), + ]); + let result = workflow + .pipe(cartype_index)? + .pipe(modelavail_index)? + .pipe(modelsprice_index)? + .pipe(purchase_index)?; + + + let len = workflow.node_count(); + let output = workflow.get_task(len - 1).get_task_output(); + + let result = serde_json::to_value(output).unwrap(); + Ok(result) +} diff --git a/workflow-with-state-management-poc/state-managed-workflow/src/types.rs b/workflow-with-state-management-poc/state-managed-workflow/src/types.rs new file mode 100644 index 0000000..50f7f08 --- /dev/null +++ b/workflow-with-state-management-poc/state-managed-workflow/src/types.rs @@ -0,0 +1,117 @@ +use std::borrow::Borrow; + +use super::*; +use openwhisk_macro::*; +use openwhisk_rust::*; + +make_input_struct!( +EmployeeIdsInput, +[role:String], +[Debug, Clone, Default, Serialize, Deserialize] +); +make_input_struct!( +GetsalariesInput, +[id:i32], +[Debug, Clone, Default, Serialize, Deserialize] +); +make_input_struct!( +SalaryInput, +[details:HashMap], +[Debug, Clone, Default, Serialize, Deserialize] +); +make_input_struct!( +GetaddressInput, +[id:i32], +[Debug, Clone, Default, Serialize, Deserialize] +); + +make_main_struct!( + EmployeeIds, + EmployeeIdsInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP",Insecure:"true",Namespace:"guest",ApiHost:"http://127.0.0.1:1234"], + output +); +impl_new!( + EmployeeIds, + EmployeeIdsInput, + [role:String] +); + +make_main_struct!( + Getsalaries, + GetsalariesInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [Insecure:"true",AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP",Namespace:"guest",ApiHost:"http://127.0.0.1:1234"], + mapout +); +impl_new!(Getsalaries, GetsalariesInput, []); + +make_main_struct!( + Salary, + SalaryInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [ApiHost:"http://127.0.0.1:1234",Namespace:"guest",Insecure:"true",AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP"], + output +); +impl_new!(Salary, SalaryInput, []); + +make_main_struct!( + Getaddress, + GetaddressInput, + [Debug, Clone, Default, Serialize, Deserialize, OpenWhisk], + [Namespace:"guest",ApiHost:"http://127.0.0.1:1234",Insecure:"true",AuthKey:"23bc46b1-71f6-4ed5-8c54-816aa4f8c502:123zO3xZCLrMN6v2BKK1dXYFpXlPkccOFqm12CdAsMgRU4VrNZ9lyGVCGuMDGIwP"], + mapout +); +impl_new!(Getaddress, GetaddressInput, []); + +impl_setter!(EmployeeIds, []); +impl_map_setter!(Getsalaries, id:"ids", i32, "salary"); +impl_concat_setter!(Salary, details); +impl_map_setter!(Getaddress, id:"ids", i32, "address"); + +make_input_struct!( +Input, +[role:String], +[Debug, Clone, Default, Serialize, Deserialize] +); +impl_execute_trait!(EmployeeIds, Getsalaries, Salary, Getaddress); + +#[allow(dead_code, unused)] +pub fn main(args: Value) -> Result { + const LIMIT: usize = 4; + let mut workflow = WorkflowGraph::new(LIMIT, "employee_salary_id"); + workflow.state_manger.update_workflow_initialized(); + + let input: Input = serde_json::from_value(args).map_err(|e| e.to_string())?; + + let employee_ids = EmployeeIds::new(input.role, "employee_ids".to_string()); + let getsalaries = Getsalaries::new("getsalaries".to_string()); + let salary = Salary::new("salary".to_string()); + let getaddress = Getaddress::new("getaddress".to_string()); + + let employee_ids_index = workflow.add_node(Box::new(employee_ids)); + let getsalaries_index = workflow.add_node(Box::new(getsalaries)); + let getaddress_index = workflow.add_node(Box::new(getaddress)); + let salary_index = workflow.add_node(Box::new(salary)); + + workflow.add_edges(&[ + (employee_ids_index, getsalaries_index), + (employee_ids_index, getaddress_index), + (getsalaries_index, salary_index), + (getaddress_index, salary_index), + ]); + + workflow + .pipe(employee_ids_index)? + .pipe(getsalaries_index)? + .pipe(getaddress_index)? + .pipe(salary_index)?; // salary is depending and term does not handle multiple deps + + + let len = workflow.node_count(); + let output = workflow.get_task(len - 1).get_task_output(); + // simply returns the output + let result = serde_json::to_value(output).unwrap(); + Ok(result) +}