Skip to content

Commit

Permalink
Refactor HeaderTryFromError
Browse files Browse the repository at this point in the history
I was not happy with my original approach to error wrapping, so I opted
for some semplifications using the `From` trait.
  • Loading branch information
matei-radu committed Jul 21, 2024
1 parent 0fb56e5 commit 2578458
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 45 deletions.
60 changes: 36 additions & 24 deletions lib/src/message/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,29 @@ impl fmt::Display for RCodeTryFromError {
impl Error for RCodeTryFromError {}

#[derive(Debug, PartialEq)]
pub enum MalformedFlagsError {
OpCode(OpCodeTryFromError),
Z(ZTryFromError),
RCode(RCodeTryFromError),
pub enum HeaderTryFromError {
InsufficientHeaderBytes(usize),
OpCodeTryFromError(OpCodeTryFromError),
ZTryFromError(ZTryFromError),
RCodeTryFromError(RCodeTryFromError),
}

impl fmt::Display for MalformedFlagsError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MalformedFlagsError::OpCode(e) => e.fmt(f),
MalformedFlagsError::Z(e) => e.fmt(f),
MalformedFlagsError::RCode(e) => e.fmt(f),
}
impl From<OpCodeTryFromError> for HeaderTryFromError {
fn from(error: OpCodeTryFromError) -> HeaderTryFromError {
HeaderTryFromError::OpCodeTryFromError(error)
}
}

impl Error for MalformedFlagsError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
MalformedFlagsError::OpCode(e) => Some(e),
MalformedFlagsError::Z(e) => Some(e),
MalformedFlagsError::RCode(e) => Some(e),
}
impl From<ZTryFromError> for HeaderTryFromError {
fn from(error: ZTryFromError) -> HeaderTryFromError {
HeaderTryFromError::ZTryFromError(error)
}
}

#[derive(Debug, PartialEq)]
pub enum HeaderTryFromError {
InsufficientHeaderBytes(usize),
MalformedFlags(MalformedFlagsError),
impl From<RCodeTryFromError> for HeaderTryFromError {
fn from(error: RCodeTryFromError) -> HeaderTryFromError {
HeaderTryFromError::RCodeTryFromError(error)
}
}

impl fmt::Display for HeaderTryFromError {
Expand All @@ -87,16 +80,35 @@ impl fmt::Display for HeaderTryFromError {
HeaderTryFromError::InsufficientHeaderBytes(len) => {
write!(f, "insufficient header bytes ({} found, 12 required)", len)
}
HeaderTryFromError::MalformedFlags(e) => e.fmt(f),
HeaderTryFromError::OpCodeTryFromError(e) => e.fmt(f),
HeaderTryFromError::ZTryFromError(e) => e.fmt(f),
HeaderTryFromError::RCodeTryFromError(e) => e.fmt(f),
}
}
}

impl Error for HeaderTryFromError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
HeaderTryFromError::MalformedFlags(e) => Some(e),
HeaderTryFromError::OpCodeTryFromError(e) => Some(e),
HeaderTryFromError::ZTryFromError(e) => Some(e),
HeaderTryFromError::RCodeTryFromError(e) => Some(e),
_ => None,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;

#[rstest]
#[case(HeaderTryFromError::InsufficientHeaderBytes(3), "insufficient header bytes (3 found, 12 required)".to_string())]
#[case(OpCodeTryFromError(14).into(), "OPCODE '14' is not supported".to_string())]
#[case(ZTryFromError.into(), "all Z bits most be zero".to_string())]
#[case(RCodeTryFromError(7).into(), "RCODE '7' is not supported".to_string())]
fn header_try_from_error_display(#[case] err: HeaderTryFromError, #[case] msg: String) {
assert_eq!(err.to_string(), msg);
}
}
39 changes: 18 additions & 21 deletions lib/src/message/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use crate::message::error::{
HeaderTryFromError, MalformedFlagsError, OpCodeTryFromError, RCodeTryFromError, ZTryFromError,
HeaderTryFromError, OpCodeTryFromError, RCodeTryFromError, ZTryFromError,
};

/// `Header` section of a DNS `Message`.
Expand Down Expand Up @@ -82,12 +82,9 @@ impl TryFrom<&[u8]> for Header {

let flags = u16::from_be_bytes([value[2], value[3]]);

let op_code = OpCode::try_from(flags)
.map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::OpCode(e)))?;
let z = Z::try_from(flags)
.map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::Z(e)))?;
let r_code = RCode::try_from(flags)
.map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::RCode(e)))?;
let op_code = OpCode::try_from(flags).map_err(|e| Self::Error::from(e))?;
let z = Z::try_from(flags).map_err(|e| Self::Error::from(e))?;
let r_code = RCode::try_from(flags).map_err(|e| Self::Error::from(e))?;

Ok(Header {
id: u16::from_be_bytes([value[0], value[1]]),
Expand Down Expand Up @@ -340,13 +337,13 @@ mod tests {
}

#[rstest]
#[case(0b0_0011_0_0_0_0_000_0000, "OPCODE '3' is not supported".to_string())]
#[case(0b0_1101_0_0_0_0_000_0000, "OPCODE '13' is not supported".to_string())]
#[case(0b0_1111_0_0_0_0_000_0000, "OPCODE '15' is not supported".to_string())]
fn op_code_try_from_u16_fails(#[case] input: u16, #[case] error_msg: String) {
#[case(0b0_0011_0_0_0_0_000_0000, OpCodeTryFromError(3))]
#[case(0b0_1101_0_0_0_0_000_0000, OpCodeTryFromError(13))]
#[case(0b0_1111_0_0_0_0_000_0000, OpCodeTryFromError(15))]
fn op_code_try_from_u16_fails(#[case] input: u16, #[case] err: OpCodeTryFromError) {
let result = OpCode::try_from(input);
assert!(result.is_err());
assert_eq!(result.unwrap_err().to_string(), error_msg);
assert_eq!(result.unwrap_err(), err);
}

#[rstest]
Expand All @@ -364,7 +361,7 @@ mod tests {
fn z_try_from_u16_fails(#[case] input: u16) {
let result = Z::try_from(input);
assert!(result.is_err());
assert_eq!(result.unwrap_err().to_string(), "all Z bits most be zero");
assert_eq!(result.unwrap_err(), ZTryFromError);
}

#[rstest]
Expand All @@ -381,13 +378,13 @@ mod tests {
}

#[rstest]
#[case(0b0_0000_0_0_0_0_000_0110, "RCODE '6' is not supported".to_string())]
#[case(0b0_0000_0_0_0_0_000_1101, "RCODE '13' is not supported".to_string())]
#[case(0b0_0000_0_0_0_0_000_1111, "RCODE '15' is not supported".to_string())]
fn r_code_try_from_u16_fails(#[case] input: u16, #[case] error_msg: String) {
#[case(0b0_0000_0_0_0_0_000_0110, RCodeTryFromError(6))]
#[case(0b0_0000_0_0_0_0_000_1101, RCodeTryFromError(13))]
#[case(0b0_0000_0_0_0_0_000_1111, RCodeTryFromError(15))]
fn r_code_try_from_u16_fails(#[case] input: u16, #[case] err: RCodeTryFromError) {
let result = RCode::try_from(input);
assert!(result.is_err());
assert_eq!(result.unwrap_err().to_string(), error_msg);
assert_eq!(result.unwrap_err(), err);
}

#[rstest]
Expand Down Expand Up @@ -417,17 +414,17 @@ mod tests {
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0111_0_0_0, 0b0_000_0000, 0, 1, 0, 0, 0, 0, 0, 0],
HeaderTryFromError::MalformedFlags(MalformedFlagsError::OpCode(OpCodeTryFromError(7)))
OpCodeTryFromError(7).into()
)]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0000_0_0_0, 0b0_010_0000, 0, 1, 0, 0, 0, 0, 0, 0],
HeaderTryFromError::MalformedFlags(MalformedFlagsError::Z(ZTryFromError))
ZTryFromError.into()
)]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0000_0_0_0, 0b0_000_1100, 0, 1, 0, 0, 0, 0, 0, 0],
HeaderTryFromError::MalformedFlags(MalformedFlagsError::RCode(RCodeTryFromError(12)))
RCodeTryFromError(12).into()
)]
fn header_try_from_fails(#[case] input: &[u8], #[case] expected: HeaderTryFromError) {
let result = Header::try_from(input);
Expand Down

0 comments on commit 2578458

Please sign in to comment.