diff --git a/src/cli/error.rs b/src/cli/error.rs index 1b09b1c756..646575e05f 100644 --- a/src/cli/error.rs +++ b/src/cli/error.rs @@ -6,7 +6,7 @@ use thiserror::Error; use crate::valid::ValidationError; -#[derive(Debug, Error, Setters)] +#[derive(Debug, Error, Setters, PartialEq, Clone)] pub struct CLIError { is_root: bool, #[setters(skip)] @@ -156,6 +156,31 @@ impl From for CLIError { } } +impl From for CLIError { + fn from(error: anyhow::Error) -> Self { + // Convert other errors to CLIError + let cli_error = match error.downcast::() { + Ok(cli_error) => cli_error, + Err(error) => { + // Convert other errors to CLIError + let cli_error = match error.downcast::>() { + Ok(validation_error) => CLIError::from(validation_error), + Err(error) => { + let sources = error + .source() + .map(|error| vec![CLIError::new(error.to_string().as_str())]) + .unwrap_or_default(); + + CLIError::new(&error.to_string()).caused_by(sources) + } + }; + cli_error + } + }; + cli_error + } +} + impl From for CLIError { fn from(error: std::io::Error) -> Self { let cli_error = CLIError::new("IO Error"); @@ -353,4 +378,40 @@ mod tests { assert_eq!(error.to_string(), expected); } + + #[test] + fn test_cli_error_identity() { + let cli_error = CLIError::new("Server could not be started") + .description("The port is already in use".to_string()) + .trace(vec!["@server".into(), "port".into()]); + let anyhow_error: anyhow::Error = cli_error.clone().into(); + + let actual = CLIError::from(anyhow_error); + let expected = cli_error; + + assert_eq!(actual, expected); + } + + #[test] + fn test_validation_error_identity() { + let validation_error = ValidationError::from( + Cause::new("Test Error".to_string()).trace(vec!["Query".to_string()]), + ); + let anyhow_error: anyhow::Error = validation_error.clone().into(); + + let actual = CLIError::from(anyhow_error); + let expected = CLIError::from(validation_error); + + assert_eq!(actual, expected); + } + + #[test] + fn test_generic_error() { + let anyhow_error = anyhow::anyhow!("Some error msg"); + + let actual: CLIError = CLIError::from(anyhow_error); + let expected = CLIError::new("Some error msg"); + + assert_eq!(actual, expected); + } } diff --git a/src/main.rs b/src/main.rs index 07bc492a0e..bd2aa9d7ae 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,6 @@ use std::cell::Cell; use mimalloc::MiMalloc; use tailcall::cli::CLIError; use tailcall::tracing::default_tailcall_tracing; -use tailcall::valid::ValidationError; use tracing::subscriber::DefaultGuard; #[global_allocator] @@ -47,24 +46,7 @@ fn main() -> anyhow::Result<()> { Ok(_) => {} Err(error) => { // Ensure all errors are converted to CLIErrors before being printed. - let cli_error = match error.downcast::() { - Ok(cli_error) => cli_error, - Err(error) => { - // Convert other errors to CLIError - let cli_error = match error.downcast::>() { - Ok(validation_error) => CLIError::from(validation_error), - Err(error) => { - let sources = error - .source() - .map(|error| vec![CLIError::new(error.to_string().as_str())]) - .unwrap_or_default(); - - CLIError::new(&error.to_string()).caused_by(sources) - } - }; - cli_error - } - }; + let cli_error: CLIError = error.into(); tracing::error!("{}", cli_error.color(true)); std::process::exit(exitcode::CONFIG); }