diff --git a/Cargo.toml b/Cargo.toml index 0c0a3ad..80f2e54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ readme = "README.md" [dependencies] log = "0.4" -regex = "1.1" nix = { version = "0.25.0", default-features = false, features = ["event", "fs", "process"] } libc = "0.2" serde = { version = "1.0", features = ["derive"], optional = true } diff --git a/src/hugetlb.rs b/src/hugetlb.rs index d25d9de..311dca1 100644 --- a/src/hugetlb.rs +++ b/src/hugetlb.rs @@ -181,7 +181,6 @@ impl HugeTlbController { } pub const HUGEPAGESIZE_DIR: &str = "/sys/kernel/mm/hugepages"; -use regex::Regex; use std::collections::HashMap; use std::fs; @@ -263,37 +262,46 @@ pub fn get_decimal_abbrs() -> Vec { } fn parse_size(s: &str, m: &HashMap) -> Result { - let re = Regex::new(r"(?P\d+)(?P[kKmMgGtTpP]?)[bB]?$"); + // Remove leading/trailing whitespace. + let s = s.trim(); - if re.is_err() { + // Remove an optional trailing 'b' or 'B' + let s = if let Some(stripped) = s.strip_suffix('b').or_else(|| s.strip_suffix('B')) { + stripped + } else { + s + }; + + // Ensure that the string is not empty after stripping. + if s.is_empty() { return Err(Error::new(InvalidBytesSize)); } - let caps = re.unwrap().captures(s).unwrap(); - let num = caps.name("num"); - let size: u128 = if let Some(num) = num { - let n = num.as_str().trim().parse::(); - if n.is_err() { - return Err(Error::new(InvalidBytesSize)); - } - n.unwrap() - } else { + // The last character should be the multiplier letter. + let last_char = s.chars().last().unwrap(); + if !"kKmMgGtTpP".contains(last_char) { return Err(Error::new(InvalidBytesSize)); - }; + } - let q = caps.name("mul"); - let mul: u128 = if let Some(q) = q { - let t = m.get(q.as_str()); - if let Some(t) = t { - *t - } else { - return Err(Error::new(InvalidBytesSize)); - } - } else { + // The numeric part is everything before the multiplier letter. + let num_part = &s[..s.len() - last_char.len_utf8()]; + if num_part.trim().is_empty() { return Err(Error::new(InvalidBytesSize)); - }; + } + + // Parse the numeric part into a u128. + let number: u128 = num_part + .trim() + .parse() + .map_err(|_| Error::new(InvalidBytesSize))?; - Ok(size * mul) + // Look up the multiplier in the provided HashMap. + let multiplier_key = last_char.to_string(); + let multiplier = m + .get(&multiplier_key) + .ok_or_else(|| Error::new(InvalidBytesSize))?; + + Ok(number * multiplier) } fn custom_size(mut size: f64, base: f64, m: &[String]) -> String { @@ -305,3 +313,60 @@ fn custom_size(mut size: f64, base: f64, m: &[String]) -> String { format!("{}{}", size, m[i].as_str()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_binary_size_valid() { + let m = get_binary_size_map(); + // Valid inputs must include a multiplier letter. + assert_eq!(parse_size("1k", &m).unwrap(), KiB); + assert_eq!(parse_size("2m", &m).unwrap(), 2 * MiB); + assert_eq!(parse_size("3g", &m).unwrap(), 3 * GiB); + assert_eq!(parse_size("4t", &m).unwrap(), 4 * TiB); + assert_eq!(parse_size("5p", &m).unwrap(), 5 * PiB); + } + + #[test] + fn test_decimal_size_valid() { + let m = get_decimal_size_map(); + assert_eq!(parse_size("1k", &m).unwrap(), KB); + assert_eq!(parse_size("2m", &m).unwrap(), 2 * MB); + assert_eq!(parse_size("3g", &m).unwrap(), 3 * GB); + assert_eq!(parse_size("4t", &m).unwrap(), 4 * TB); + assert_eq!(parse_size("5p", &m).unwrap(), 5 * PB); + } + + #[test] + fn test_trailing_b_suffix() { + let m = get_binary_size_map(); + // Trailing 'b' or 'B' should be accepted. + assert_eq!(parse_size("1kb", &m).unwrap(), KiB); + assert_eq!(parse_size("2mB", &m).unwrap(), 2 * MiB); + } + + #[test] + fn test_invalid_inputs() { + let m = get_binary_size_map(); + // Missing multiplier letter results in error. + assert!(parse_size("1", &m).is_err()); + // Invalid multiplier letter. + assert!(parse_size("10x", &m).is_err()); + // Non-numeric input. + assert!(parse_size("abc", &m).is_err()); + // Only multiplier letter with no number. + assert!(parse_size("k", &m).is_err()); + // Number with an invalid trailing character. + assert!(parse_size("123z", &m).is_err()); + } + + #[test] + fn test_uppercase_multiplier_fails() { + let m = get_binary_size_map(); + // Although the regex matches uppercase letters, the provided map only contains lowercase keys. + // Therefore, "1K" does not match any key and should produce an error. + assert!(parse_size("1K", &m).is_err()); + } +}