diff --git a/src/error.rs b/src/error.rs index 6f402cd6..7b58be33 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,6 +17,10 @@ pub enum SnOsError { Runner(CairoRunError), #[error("SnOs Output Error: {0}")] Output(String), + #[error(transparent)] + IO(#[from] std::io::Error), + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), } #[derive(thiserror::Error, Clone, Debug)] diff --git a/src/hints/mod.rs b/src/hints/mod.rs index 68278f14..3309622f 100644 --- a/src/hints/mod.rs +++ b/src/hints/mod.rs @@ -107,9 +107,12 @@ pub fn starknet_os_input( ap_tracking: &ApTracking, _constants: &HashMap, ) -> Result<(), HintError> { - let input_path = exec_scopes.get::("input_path").unwrap_or(DEFAULT_INPUT_PATH.to_string()); + let input_path = + std::path::PathBuf::from(exec_scopes.get::("input_path").unwrap_or(DEFAULT_INPUT_PATH.to_string())); - let os_input = Box::new(StarknetOsInput::load(&input_path)); + let os_input = Box::new( + StarknetOsInput::load(&input_path).map_err(|e| HintError::CustomHint(e.to_string().into_boxed_str()))?, + ); exec_scopes.assign_or_update_variable("os_input", os_input); let initial_carried_outputs_ptr = get_ptr_from_var_name("initial_carried_outputs", vm, ids_data, ap_tracking)?; diff --git a/src/io/mod.rs b/src/io/mod.rs index 0f4503dc..608f9652 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -34,15 +34,17 @@ pub struct StarknetOsInput { } impl StarknetOsInput { - pub fn load(path: &str) -> Self { - let raw_input = fs::read_to_string(path::PathBuf::from(path)).unwrap(); - serde_json::from_str(&raw_input).unwrap() + pub fn load(path: &path::Path) -> Result { + let raw_input = fs::read_to_string(path)?; + let input = serde_json::from_str(&raw_input)?; + + Ok(input) } - pub fn dump(&self, path: &str) -> Result<(), SnOsError> { - fs::File::create(path) - .unwrap() - .write_all(&serde_json::to_vec(&self).unwrap()) - .map_err(|e| SnOsError::CatchAll(format!("{e}"))) + + pub fn dump(&self, path: &path::Path) -> Result<(), SnOsError> { + fs::File::create(path)?.write_all(&serde_json::to_vec(&self)?)?; + + Ok(()) } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index d2e0be00..283f56c3 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -32,7 +32,7 @@ pub fn load_and_write_input() { #[fixture] #[once] pub fn load_input(_load_and_write_input: ()) -> StarknetOsInput { - StarknetOsInput::load(DEFAULT_INPUT_PATH) + StarknetOsInput::load(std::path::Path::new(DEFAULT_INPUT_PATH)).unwrap() } #[fixture]