Skip to content

Commit

Permalink
Implement TryFrom<&[u8]> for Header
Browse files Browse the repository at this point in the history
  • Loading branch information
matei-radu committed Jul 18, 2024
1 parent 69bc257 commit 5db04d5
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lib/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dns_lib"
version = "0.15.0"
version = "0.16.0"
description = "An implementation of the DNS protocol from scratch based on the many DNS RFCs."

rust-version.workspace = true
Expand Down
53 changes: 53 additions & 0 deletions lib/src/message/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,56 @@ impl fmt::Display for RCodeTryFromError {
}

impl Error for RCodeTryFromError {}

#[derive(Debug, PartialEq)]
pub enum MalformedFlagsError {
OpCode(OpCodeTryFromError),
Z(ZTryFromError),
RCode(RCodeTryFromError),
}

impl fmt::Display for MalformedFlagsError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MalformedFlagsError::OpCode(e) => e.fmt(f),
MalformedFlagsError::Z(e) => e.fmt(f),
MalformedFlagsError::RCode(e) => e.fmt(f),
}
}
}

impl Error for MalformedFlagsError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
MalformedFlagsError::OpCode(e) => Some(e),
MalformedFlagsError::Z(e) => Some(e),
MalformedFlagsError::RCode(e) => Some(e),
}
}
}

#[derive(Debug, PartialEq)]
pub enum HeaderTryFromError {
InsufficientHeaderBytes(usize),
MalformedFlags(MalformedFlagsError),
}

impl fmt::Display for HeaderTryFromError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
HeaderTryFromError::InsufficientHeaderBytes(len) => {
write!(f, "insufficient header bytes ({} found, 12 required)", len)
}
HeaderTryFromError::MalformedFlags(e) => e.fmt(f),
}
}
}

impl Error for HeaderTryFromError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
HeaderTryFromError::MalformedFlags(e) => Some(e),
_ => None,
}
}
}
139 changes: 132 additions & 7 deletions lib/src/message/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::message::error::{OpCodeTryFromError, RCodeTryFromError, ZTryFromError};
use crate::message::error::{
HeaderTryFromError, MalformedFlagsError, OpCodeTryFromError, RCodeTryFromError, ZTryFromError,
};

/// `Header` section of a DNS `Message`.
///
Expand All @@ -24,22 +26,101 @@ use crate::message::error::{OpCodeTryFromError, RCodeTryFromError, ZTryFromError
/// For more details, see [RFC 1035, Section 4.1.1].
///
/// [RFC 1035, Section 4.1.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
#[derive(Debug, PartialEq)]
pub struct Header {
pub id: u16,

pub qr: QR,
pub opcode: OpCode,
pub op_code: OpCode,
pub aa: bool,
pub tc: bool,
pub rd: bool,
pub ra: bool,
pub z: Z,
pub rcode: RCode,
pub r_code: RCode,

pub qd_count: u16,
pub an_count: u16,
pub ns_count: u16,
pub ar_count: u16,
}

impl TryFrom<&[u8]> for Header {
type Error = HeaderTryFromError;

/// Tries to convert a slice `&[u8]` into a DNS message `Header`.
///
/// A valid DNS message header requires at least 12 bytes. Trying to convert
/// a smaller slice will result in an error. Errors will also be triggered
/// if any header flag is found to use reserved values.
///
/// For more details, see [RFC 1035, Section 4.1.1].
///
/// ```text
/// 1 1 1 1 1 1
/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// | ID |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// |QR| OPCODE |AA|TC|RD|RA| Z | RCODE |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// | QDCOUNT |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// | ANCOUNT |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// | NSCOUNT |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// | ARCOUNT |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// ```
///
/// [RFC 1035, Section 4.1.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if value.len() < 12 {
return Err(HeaderTryFromError::InsufficientHeaderBytes(value.len()));
}

let flags = u16::from_be_bytes([value[2], value[3]]);

let op_code = OpCode::try_from(flags)
.map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::OpCode(e)))?;
let z = Z::try_from(flags)
.map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::Z(e)))?;
let r_code = RCode::try_from(flags)
.map_err(|e| Self::Error::MalformedFlags(MalformedFlagsError::RCode(e)))?;

Ok(Header {
id: u16::from_be_bytes([value[0], value[1]]),
qr: QR::from(flags),
op_code,
aa: parse_aa_flag(flags),
tc: parse_tc_flag(flags),
rd: parse_rd_flag(flags),
ra: parse_ra_flag(flags),
z,
r_code,
qd_count: u16::from_be_bytes([value[4], value[5]]),
an_count: u16::from_be_bytes([value[6], value[7]]),
ns_count: u16::from_be_bytes([value[8], value[9]]),
ar_count: u16::from_be_bytes([value[10], value[11]]),
})
}
}

fn parse_aa_flag(value: u16) -> bool {
(value & 0b0_0000_1_0_0_0_000_0000) >> 10 == 1
}

fn parse_tc_flag(value: u16) -> bool {
(value & 0b0_0000_0_1_0_0_000_0000) >> 9 == 1
}

fn parse_rd_flag(value: u16) -> bool {
(value & 0b0_0000_0_0_1_0_000_0000) >> 8 == 1
}

pub qdcount: u16,
pub ancount: u16,
pub nscount: u16,
pub arcount: u16,
fn parse_ra_flag(value: u16) -> bool {
(value & 0b0_0000_0_0_0_1_000_0000) >> 7 == 1
}

#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -308,4 +389,48 @@ mod tests {
assert!(result.is_err());
assert_eq!(result.unwrap_err().to_string(), error_msg);
}

#[rstest]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0000_0_0_0, 0b0_000_0000, 0, 1, 0, 0, 0, 0, 0, 0],
Header{ id: 255, qr: QR::Query, op_code: OpCode::Query, aa: false, tc: false, rd: false, ra: false, z: Z::AllZeros, r_code: RCode::NoError, qd_count: 1, an_count: 0, ns_count: 0, ar_count: 0 }
)]
#[case(
// ID , Flags , QD , AN , NS , AR
&[2, 255, 0b1_0010_0_1_0, 0b0_000_0000, 0, 2, 0, 0, 0, 0, 0, 1],
Header{ id: 767, qr: QR::Response, op_code: OpCode::Status, aa: false, tc: true, rd: false, ra: false, z: Z::AllZeros, r_code: RCode::NoError, qd_count: 2, an_count: 0, ns_count: 0, ar_count: 1 }
)]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 1, 0b1_0001_1_1_1, 0b1_000_0011, 0, 4, 0, 4, 0, 4, 0, 4],
Header{ id: 1, qr: QR::Response, op_code: OpCode::InverseQuery, aa: true, tc: true, rd: true, ra: true, z: Z::AllZeros, r_code: RCode::NameError, qd_count: 4, an_count: 4, ns_count: 4, ar_count: 4 }
)]
fn header_try_from_succeeds(#[case] input: &[u8], #[case] expected: Header) {
let result = Header::try_from(input);
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected);
}

#[rstest]
#[case(&[0x01, 0x02], HeaderTryFromError::InsufficientHeaderBytes(2))]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0111_0_0_0, 0b0_000_0000, 0, 1, 0, 0, 0, 0, 0, 0],
HeaderTryFromError::MalformedFlags(MalformedFlagsError::OpCode(OpCodeTryFromError(7)))
)]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0000_0_0_0, 0b0_010_0000, 0, 1, 0, 0, 0, 0, 0, 0],
HeaderTryFromError::MalformedFlags(MalformedFlagsError::Z(ZTryFromError))
)]
#[case(
// ID , Flags , QD , AN , NS , AR
&[0, 255, 0b0_0000_0_0_0, 0b0_000_1100, 0, 1, 0, 0, 0, 0, 0, 0],
HeaderTryFromError::MalformedFlags(MalformedFlagsError::RCode(RCodeTryFromError(12)))
)]
fn header_try_from_fails(#[case] input: &[u8], #[case] expected: HeaderTryFromError) {
let result = Header::try_from(input);
assert_eq!(result.unwrap_err(), expected);
}
}

0 comments on commit 5db04d5

Please sign in to comment.