From 2578458a7cdb80c102300dfebb3960091c1573ad Mon Sep 17 00:00:00 2001 From: Matei Radu Date: Sun, 21 Jul 2024 17:54:09 +0200 Subject: [PATCH] Refactor HeaderTryFromError I was not happy with my original approach to error wrapping, so I opted for some semplifications using the `From` trait. --- lib/src/message/error.rs | 60 +++++++++++++++++++++++---------------- lib/src/message/header.rs | 39 ++++++++++++------------- 2 files changed, 54 insertions(+), 45 deletions(-) diff --git a/lib/src/message/error.rs b/lib/src/message/error.rs index 557afe4..165a56e 100644 --- a/lib/src/message/error.rs +++ b/lib/src/message/error.rs @@ -49,36 +49,29 @@ impl fmt::Display for RCodeTryFromError { impl Error for RCodeTryFromError {} #[derive(Debug, PartialEq)] -pub enum MalformedFlagsError { - OpCode(OpCodeTryFromError), - Z(ZTryFromError), - RCode(RCodeTryFromError), +pub enum HeaderTryFromError { + InsufficientHeaderBytes(usize), + OpCodeTryFromError(OpCodeTryFromError), + ZTryFromError(ZTryFromError), + RCodeTryFromError(RCodeTryFromError), } -impl fmt::Display for MalformedFlagsError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - MalformedFlagsError::OpCode(e) => e.fmt(f), - MalformedFlagsError::Z(e) => e.fmt(f), - MalformedFlagsError::RCode(e) => e.fmt(f), - } +impl From for HeaderTryFromError { + fn from(error: OpCodeTryFromError) -> HeaderTryFromError { + HeaderTryFromError::OpCodeTryFromError(error) } } -impl Error for MalformedFlagsError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - MalformedFlagsError::OpCode(e) => Some(e), - MalformedFlagsError::Z(e) => Some(e), - MalformedFlagsError::RCode(e) => Some(e), - } +impl From for HeaderTryFromError { + fn from(error: ZTryFromError) -> HeaderTryFromError { + HeaderTryFromError::ZTryFromError(error) } } -#[derive(Debug, PartialEq)] -pub enum HeaderTryFromError { - InsufficientHeaderBytes(usize), - MalformedFlags(MalformedFlagsError), +impl From for HeaderTryFromError { + fn from(error: RCodeTryFromError) -> HeaderTryFromError { + HeaderTryFromError::RCodeTryFromError(error) + } } impl fmt::Display for HeaderTryFromError { @@ -87,7 +80,9 @@ impl fmt::Display for HeaderTryFromError { HeaderTryFromError::InsufficientHeaderBytes(len) => { write!(f, "insufficient header bytes ({} found, 12 required)", len) } - HeaderTryFromError::MalformedFlags(e) => e.fmt(f), + HeaderTryFromError::OpCodeTryFromError(e) => e.fmt(f), + HeaderTryFromError::ZTryFromError(e) => e.fmt(f), + HeaderTryFromError::RCodeTryFromError(e) => e.fmt(f), } } } @@ -95,8 +90,25 @@ impl fmt::Display for HeaderTryFromError { impl Error for HeaderTryFromError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - HeaderTryFromError::MalformedFlags(e) => Some(e), + HeaderTryFromError::OpCodeTryFromError(e) => Some(e), + HeaderTryFromError::ZTryFromError(e) => Some(e), + HeaderTryFromError::RCodeTryFromError(e) => Some(e), _ => None, } } } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[rstest] + #[case(HeaderTryFromError::InsufficientHeaderBytes(3), "insufficient header bytes (3 found, 12 required)".to_string())] + #[case(OpCodeTryFromError(14).into(), "OPCODE '14' is not supported".to_string())] + #[case(ZTryFromError.into(), "all Z bits most be zero".to_string())] + #[case(RCodeTryFromError(7).into(), "RCODE '7' is not supported".to_string())] + fn header_try_from_error_display(#[case] err: HeaderTryFromError, #[case] msg: String) { + assert_eq!(err.to_string(), msg); + } +} diff --git a/lib/src/message/header.rs b/lib/src/message/header.rs index f43c3de..79373ad 100644 --- a/lib/src/message/header.rs +++ b/lib/src/message/header.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::message::error::{ - HeaderTryFromError, MalformedFlagsError, OpCodeTryFromError, RCodeTryFromError, ZTryFromError, + HeaderTryFromError, OpCodeTryFromError, RCodeTryFromError, ZTryFromError, }; /// `Header` section of a DNS `Message`. @@ -82,12 +82,9 @@ impl TryFrom<&[u8]> for Header { let flags = u16::from_be_bytes([value[2], value[3]]); - let op_code = OpCode::try_from(flags) - .map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::OpCode(e)))?; - let z = Z::try_from(flags) - .map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::Z(e)))?; - let r_code = RCode::try_from(flags) - .map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::RCode(e)))?; + let op_code = OpCode::try_from(flags).map_err(|e| Self::Error::from(e))?; + let z = Z::try_from(flags).map_err(|e| Self::Error::from(e))?; + let r_code = RCode::try_from(flags).map_err(|e| Self::Error::from(e))?; Ok(Header { id: u16::from_be_bytes([value[0], value[1]]), @@ -340,13 +337,13 @@ mod tests { } #[rstest] - #[case(0b0_0011_0_0_0_0_000_0000, "OPCODE '3' is not supported".to_string())] - #[case(0b0_1101_0_0_0_0_000_0000, "OPCODE '13' is not supported".to_string())] - #[case(0b0_1111_0_0_0_0_000_0000, "OPCODE '15' is not supported".to_string())] - fn op_code_try_from_u16_fails(#[case] input: u16, #[case] error_msg: String) { + #[case(0b0_0011_0_0_0_0_000_0000, OpCodeTryFromError(3))] + #[case(0b0_1101_0_0_0_0_000_0000, OpCodeTryFromError(13))] + #[case(0b0_1111_0_0_0_0_000_0000, OpCodeTryFromError(15))] + fn op_code_try_from_u16_fails(#[case] input: u16, #[case] err: OpCodeTryFromError) { let result = OpCode::try_from(input); assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), error_msg); + assert_eq!(result.unwrap_err(), err); } #[rstest] @@ -364,7 +361,7 @@ mod tests { fn z_try_from_u16_fails(#[case] input: u16) { let result = Z::try_from(input); assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), "all Z bits most be zero"); + assert_eq!(result.unwrap_err(), ZTryFromError); } #[rstest] @@ -381,13 +378,13 @@ mod tests { } #[rstest] - #[case(0b0_0000_0_0_0_0_000_0110, "RCODE '6' is not supported".to_string())] - #[case(0b0_0000_0_0_0_0_000_1101, "RCODE '13' is not supported".to_string())] - #[case(0b0_0000_0_0_0_0_000_1111, "RCODE '15' is not supported".to_string())] - fn r_code_try_from_u16_fails(#[case] input: u16, #[case] error_msg: String) { + #[case(0b0_0000_0_0_0_0_000_0110, RCodeTryFromError(6))] + #[case(0b0_0000_0_0_0_0_000_1101, RCodeTryFromError(13))] + #[case(0b0_0000_0_0_0_0_000_1111, RCodeTryFromError(15))] + fn r_code_try_from_u16_fails(#[case] input: u16, #[case] err: RCodeTryFromError) { let result = RCode::try_from(input); assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), error_msg); + assert_eq!(result.unwrap_err(), err); } #[rstest] @@ -417,17 +414,17 @@ mod tests { #[case( // ID , Flags , QD , AN , NS , AR &[0, 255, 0b0_0111_0_0_0, 0b0_000_0000, 0, 1, 0, 0, 0, 0, 0, 0], - HeaderTryFromError::MalformedFlags(MalformedFlagsError::OpCode(OpCodeTryFromError(7))) + OpCodeTryFromError(7).into() )] #[case( // ID , Flags , QD , AN , NS , AR &[0, 255, 0b0_0000_0_0_0, 0b0_010_0000, 0, 1, 0, 0, 0, 0, 0, 0], - HeaderTryFromError::MalformedFlags(MalformedFlagsError::Z(ZTryFromError)) + ZTryFromError.into() )] #[case( // ID , Flags , QD , AN , NS , AR &[0, 255, 0b0_0000_0_0_0, 0b0_000_1100, 0, 1, 0, 0, 0, 0, 0, 0], - HeaderTryFromError::MalformedFlags(MalformedFlagsError::RCode(RCodeTryFromError(12))) + RCodeTryFromError(12).into() )] fn header_try_from_fails(#[case] input: &[u8], #[case] expected: HeaderTryFromError) { let result = Header::try_from(input);