diff --git a/Cargo.lock b/Cargo.lock index 7fddf68..8c960db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,11 +18,12 @@ checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" name = "nftables" version = "0.2.3" dependencies = [ - "serde", - "serde_json", - "serde_path_to_error", - "strum", - "strum_macros", + "serde", + "serde_json", + "serde_path_to_error", + "strum", + "strum_macros", + "thiserror", ] [[package]] @@ -132,9 +133,29 @@ version = "2.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "15e3fc8c0c74267e2df136e5e5fb656a464158aa57624053375eb9c8c6e25ae2" dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "611040a08a0439f8248d1990b111c95baa9c704c805fa1f62104b39655fd7f90" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.25", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1ef4560..0c1530e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,10 +15,11 @@ exclude = [ ] [dependencies] -serde = {version = "1.0.137", features = ["derive"]} -serde_json = {version = "1.0.81"} +serde = { version = "1.0.137", features = ["derive"] } +serde_json = { version = "1.0.81" } serde_path_to_error = "0.1" strum = "0.24" strum_macros = "0.24" +thiserror = "1.0" [build-dependencies] diff --git a/src/helper.rs b/src/helper.rs index 4885cba..bf19fe6 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,21 +1,48 @@ +use std::string::FromUtf8Error; use std::{ io::{self, Write}, process::{Command, Stdio}, }; +use thiserror::Error; + use crate::schema::Nftables; const NFT_EXECUTABLE: &str = "nft"; // search in PATH -pub fn get_current_ruleset(program: Option<&str>, args: Option>) -> Nftables { - let output = get_current_ruleset_raw(program, args); - let nftables: Nftables = serde_json::from_str(&output).unwrap(); - nftables +#[derive(Error, Debug)] +pub enum NftablesError { + #[error("unable to execute {program}: {inner}")] + NftExecution { program: String, inner: io::Error }, + #[error("{program}'s output contained invalid utf8: {inner}")] + NftOutputEncoding { + program: String, + inner: FromUtf8Error, + }, + #[error("got invalid json: {0}")] + NftInvalidJson(serde_json::Error), + #[error("{program} did not return successfully while {hint}")] + NftFailed { + program: String, + hint: String, + stdout: String, + stderr: String, + }, } -pub fn get_current_ruleset_raw(program: Option<&str>, args: Option>) -> String { - let nft_executable: &str = program.unwrap_or(NFT_EXECUTABLE); - let mut nft_cmd = get_command(Some(nft_executable)); +pub fn get_current_ruleset( + program: Option<&str>, + args: Option>, +) -> Result { + let output = get_current_ruleset_raw(program, args)?; + serde_json::from_str(&output).map_err(NftablesError::NftInvalidJson) +} + +pub fn get_current_ruleset_raw( + program: Option<&str>, + args: Option>, +) -> Result { + let mut nft_cmd = get_command(program); let default_args = ["-j", "list", "ruleset"]; let args: Vec<&str> = match args { Some(mut args) => { @@ -24,21 +51,34 @@ pub fn get_current_ruleset_raw(program: Option<&str>, args: Option>) - } None => default_args.to_vec(), }; - let output = nft_cmd + let process_result = nft_cmd .args(args) .output() - .expect("nft command failed to start"); - if !output.status.success() { - panic!("nft failed to show the current ruleset"); + .map_err(|e| NftablesError::NftExecution { + inner: e, + program: format!("{}", nft_cmd.get_program().to_str().unwrap()), + })?; + + let stdout = read_output(&nft_cmd, process_result.stdout)?; + + if !process_result.status.success() { + let stderr = read_output(&nft_cmd, process_result.stderr)?; + + return Err(NftablesError::NftFailed { + program: format!("{}", nft_cmd.get_program().to_str().unwrap()), + hint: "getting the current ruleset".to_string(), + stdout, + stderr, + }); } - String::from_utf8(output.stdout).expect("failed to decode nft output as utf8") + Ok(stdout) } pub fn apply_ruleset( nftables: &Nftables, program: Option<&str>, args: Option>, -) -> io::Result<()> { +) -> Result<(), NftablesError> { let nftables = serde_json::to_string(nftables).expect("failed to serialize Nftables struct"); apply_ruleset_raw(nftables, program, args) } @@ -47,9 +87,8 @@ pub fn apply_ruleset_raw( payload: String, program: Option<&str>, args: Option>, -) -> io::Result<()> { - let nft_executable: &str = program.unwrap_or(NFT_EXECUTABLE); - let mut nft_cmd = get_command(Some(nft_executable)); +) -> Result<(), NftablesError> { + let mut nft_cmd = get_command(program); let default_args = ["-j", "-f", "-"]; let args: Vec<&str> = match args { Some(mut args) => { @@ -62,19 +101,39 @@ pub fn apply_ruleset_raw( .args(args) .stdin(Stdio::piped()) .stdout(Stdio::piped()) - .spawn()?; + .spawn() + .map_err(|e| NftablesError::NftExecution { + program: format!("{}", nft_cmd.get_program().to_str().unwrap()), + inner: e, + })?; let mut stdin = process.stdin.take().unwrap(); - stdin.write_all(payload.as_bytes())?; + stdin + .write_all(payload.as_bytes()) + .map_err(|e| NftablesError::NftExecution { + program: format!("{}", nft_cmd.get_program().to_str().unwrap()), + inner: e, + })?; drop(stdin); let result = process.wait_with_output(); match result { - Ok(output) => { - assert!(output.status.success()); - Ok(()) + Ok(output) if output.status.success() => Ok(()), + Ok(process_result) => { + let stdout = read_output(&nft_cmd, process_result.stdout)?; + let stderr = read_output(&nft_cmd, process_result.stderr)?; + + Err(NftablesError::NftFailed { + program: format!("{}", nft_cmd.get_program().to_str().unwrap()), + hint: "applying ruleset".to_string(), + stdout, + stderr, + }) } - Err(err) => Err(err), + Err(e) => Err(NftablesError::NftExecution { + program: format!("{}", nft_cmd.get_program().to_str().unwrap()), + inner: e, + }), } } @@ -82,3 +141,10 @@ fn get_command(program: Option<&str>) -> Command { let nft_executable: &str = program.unwrap_or(NFT_EXECUTABLE); Command::new(nft_executable) } + +fn read_output(cmd: &Command, bytes: Vec) -> Result { + String::from_utf8(bytes).map_err(|e| NftablesError::NftOutputEncoding { + inner: e, + program: format!("{}", cmd.get_program().to_str().unwrap()), + }) +} diff --git a/tests/helper_tests.rs b/tests/helper_tests.rs index d334af9..586a3eb 100644 --- a/tests/helper_tests.rs +++ b/tests/helper_tests.rs @@ -1,10 +1,10 @@ -use nftables::{batch::Batch, helper, schema, types, expr}; +use nftables::{batch::Batch, expr, helper, schema, types}; #[test] #[ignore] /// Reads current ruleset from nftables and reads it to `Nftables` Rust struct. fn test_list_ruleset() { - helper::get_current_ruleset(None, None); + helper::get_current_ruleset(None, None).unwrap(); } #[test] @@ -42,7 +42,10 @@ fn example_ruleset() -> schema::Nftables { family: types::NfFamily::IP, table: table_name, name: set_name, - elem: vec![expr::Expression::String("127.0.0.1".to_string()), expr::Expression::String("127.0.0.2".to_string())], + elem: vec![ + expr::Expression::String("127.0.0.1".to_string()), + expr::Expression::String("127.0.0.2".to_string()), + ], })); batch.delete(schema::NfListObject::Table(schema::Table::new( types::NfFamily::IP,