From 5db04d5332f935825aa3df07e9dba4b6e48c9246 Mon Sep 17 00:00:00 2001 From: Matei Radu Date: Thu, 18 Jul 2024 22:23:57 +0200 Subject: [PATCH] Implement TryFrom<&[u8]> for Header --- Cargo.lock | 2 +- lib/Cargo.toml | 2 +- lib/src/message/error.rs | 53 +++++++++++++++ lib/src/message/header.rs | 139 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 187 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a0100c1..e6a7049 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,7 +25,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "dns_lib" -version = "0.15.0" +version = "0.16.0" dependencies = [ "rstest", ] diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 51d8bcc..f62574e 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dns_lib" -version = "0.15.0" +version = "0.16.0" description = "An implementation of the DNS protocol from scratch based on the many DNS RFCs." rust-version.workspace = true diff --git a/lib/src/message/error.rs b/lib/src/message/error.rs index 9650143..557afe4 100644 --- a/lib/src/message/error.rs +++ b/lib/src/message/error.rs @@ -47,3 +47,56 @@ impl fmt::Display for RCodeTryFromError { } impl Error for RCodeTryFromError {} + +#[derive(Debug, PartialEq)] +pub enum MalformedFlagsError { + OpCode(OpCodeTryFromError), + Z(ZTryFromError), + RCode(RCodeTryFromError), +} + +impl fmt::Display for MalformedFlagsError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + MalformedFlagsError::OpCode(e) => e.fmt(f), + MalformedFlagsError::Z(e) => e.fmt(f), + MalformedFlagsError::RCode(e) => e.fmt(f), + } + } +} + +impl Error for MalformedFlagsError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + MalformedFlagsError::OpCode(e) => Some(e), + MalformedFlagsError::Z(e) => Some(e), + MalformedFlagsError::RCode(e) => Some(e), + } + } +} + +#[derive(Debug, PartialEq)] +pub enum HeaderTryFromError { + InsufficientHeaderBytes(usize), + MalformedFlags(MalformedFlagsError), +} + +impl fmt::Display for HeaderTryFromError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + HeaderTryFromError::InsufficientHeaderBytes(len) => { + write!(f, "insufficient header bytes ({} found, 12 required)", len) + } + HeaderTryFromError::MalformedFlags(e) => e.fmt(f), + } + } +} + +impl Error for HeaderTryFromError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + HeaderTryFromError::MalformedFlags(e) => Some(e), + _ => None, + } + } +} diff --git a/lib/src/message/header.rs b/lib/src/message/header.rs index f6cd5c2..f43c3de 100644 --- a/lib/src/message/header.rs +++ b/lib/src/message/header.rs @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::message::error::{OpCodeTryFromError, RCodeTryFromError, ZTryFromError}; +use crate::message::error::{ + HeaderTryFromError, MalformedFlagsError, OpCodeTryFromError, RCodeTryFromError, ZTryFromError, +}; /// `Header` section of a DNS `Message`. /// @@ -24,22 +26,101 @@ use crate::message::error::{OpCodeTryFromError, RCodeTryFromError, ZTryFromError /// For more details, see [RFC 1035, Section 4.1.1]. /// /// [RFC 1035, Section 4.1.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 +#[derive(Debug, PartialEq)] pub struct Header { pub id: u16, pub qr: QR, - pub opcode: OpCode, + pub op_code: OpCode, pub aa: bool, pub tc: bool, pub rd: bool, pub ra: bool, pub z: Z, - pub rcode: RCode, + pub r_code: RCode, + + pub qd_count: u16, + pub an_count: u16, + pub ns_count: u16, + pub ar_count: u16, +} + +impl TryFrom<&[u8]> for Header { + type Error = HeaderTryFromError; + + /// Tries to convert a slice `&[u8]` into a DNS message `Header`. + /// + /// A valid DNS message header requires at least 12 bytes. Trying to convert + /// a smaller slice will result in an error. Errors will also be triggered + /// if any header flag is found to use reserved values. + /// + /// For more details, see [RFC 1035, Section 4.1.1]. + /// + /// ```text + /// 1 1 1 1 1 1 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// | ID | + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// |QR| OPCODE |AA|TC|RD|RA| Z | RCODE | + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// | QDCOUNT | + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// | ANCOUNT | + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// | NSCOUNT | + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// | ARCOUNT | + /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + /// ``` + /// + /// [RFC 1035, Section 4.1.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 + fn try_from(value: &[u8]) -> Result { + if value.len() < 12 { + return Err(HeaderTryFromError::InsufficientHeaderBytes(value.len())); + } + + let flags = u16::from_be_bytes([value[2], value[3]]); + + let op_code = OpCode::try_from(flags) + .map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::OpCode(e)))?; + let z = Z::try_from(flags) + .map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::Z(e)))?; + let r_code = RCode::try_from(flags) + .map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::RCode(e)))?; + + Ok(Header { + id: u16::from_be_bytes([value[0], value[1]]), + qr: QR::from(flags), + op_code, + aa: parse_aa_flag(flags), + tc: parse_tc_flag(flags), + rd: parse_rd_flag(flags), + ra: parse_ra_flag(flags), + z, + r_code, + qd_count: u16::from_be_bytes([value[4], value[5]]), + an_count: u16::from_be_bytes([value[6], value[7]]), + ns_count: u16::from_be_bytes([value[8], value[9]]), + ar_count: u16::from_be_bytes([value[10], value[11]]), + }) + } +} + +fn parse_aa_flag(value: u16) -> bool { + (value & 0b0_0000_1_0_0_0_000_0000) >> 10 == 1 +} + +fn parse_tc_flag(value: u16) -> bool { + (value & 0b0_0000_0_1_0_0_000_0000) >> 9 == 1 +} + +fn parse_rd_flag(value: u16) -> bool { + (value & 0b0_0000_0_0_1_0_000_0000) >> 8 == 1 +} - pub qdcount: u16, - pub ancount: u16, - pub nscount: u16, - pub arcount: u16, +fn parse_ra_flag(value: u16) -> bool { + (value & 0b0_0000_0_0_0_1_000_0000) >> 7 == 1 } #[derive(Debug, PartialEq)] @@ -308,4 +389,48 @@ mod tests { assert!(result.is_err()); assert_eq!(result.unwrap_err().to_string(), error_msg); } + + #[rstest] + #[case( + // ID , Flags , QD , AN , NS , AR + &[0, 255, 0b0_0000_0_0_0, 0b0_000_0000, 0, 1, 0, 0, 0, 0, 0, 0], + Header{ id: 255, qr: QR::Query, op_code: OpCode::Query, aa: false, tc: false, rd: false, ra: false, z: Z::AllZeros, r_code: RCode::NoError, qd_count: 1, an_count: 0, ns_count: 0, ar_count: 0 } + )] + #[case( + // ID , Flags , QD , AN , NS , AR + &[2, 255, 0b1_0010_0_1_0, 0b0_000_0000, 0, 2, 0, 0, 0, 0, 0, 1], + Header{ id: 767, qr: QR::Response, op_code: OpCode::Status, aa: false, tc: true, rd: false, ra: false, z: Z::AllZeros, r_code: RCode::NoError, qd_count: 2, an_count: 0, ns_count: 0, ar_count: 1 } + )] + #[case( + // ID , Flags , QD , AN , NS , AR + &[0, 1, 0b1_0001_1_1_1, 0b1_000_0011, 0, 4, 0, 4, 0, 4, 0, 4], + Header{ id: 1, qr: QR::Response, op_code: OpCode::InverseQuery, aa: true, tc: true, rd: true, ra: true, z: Z::AllZeros, r_code: RCode::NameError, qd_count: 4, an_count: 4, ns_count: 4, ar_count: 4 } + )] + fn header_try_from_succeeds(#[case] input: &[u8], #[case] expected: Header) { + let result = Header::try_from(input); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), expected); + } + + #[rstest] + #[case(&[0x01, 0x02], HeaderTryFromError::InsufficientHeaderBytes(2))] + #[case( + // ID , Flags , QD , AN , NS , AR + &[0, 255, 0b0_0111_0_0_0, 0b0_000_0000, 0, 1, 0, 0, 0, 0, 0, 0], + HeaderTryFromError::MalformedFlags(MalformedFlagsError::OpCode(OpCodeTryFromError(7))) + )] + #[case( + // ID , Flags , QD , AN , NS , AR + &[0, 255, 0b0_0000_0_0_0, 0b0_010_0000, 0, 1, 0, 0, 0, 0, 0, 0], + HeaderTryFromError::MalformedFlags(MalformedFlagsError::Z(ZTryFromError)) + )] + #[case( + // ID , Flags , QD , AN , NS , AR + &[0, 255, 0b0_0000_0_0_0, 0b0_000_1100, 0, 1, 0, 0, 0, 0, 0, 0], + HeaderTryFromError::MalformedFlags(MalformedFlagsError::RCode(RCodeTryFromError(12))) + )] + fn header_try_from_fails(#[case] input: &[u8], #[case] expected: HeaderTryFromError) { + let result = Header::try_from(input); + assert_eq!(result.unwrap_err(), expected); + } }