diff --git a/aws_secretsmanager_agent/src/config.rs b/aws_secretsmanager_agent/src/config.rs index f071afb..71e7dbe 100644 --- a/aws_secretsmanager_agent/src/config.rs +++ b/aws_secretsmanager_agent/src/config.rs @@ -1,17 +1,21 @@ use crate::constants::EMPTY_ENV_LIST_MSG; use crate::constants::{BAD_MAX_CONN_MSG, BAD_PREFIX_MSG, EMPTY_SSRF_LIST_MSG}; use crate::constants::{DEFAULT_MAX_CONNECTIONS, GENERIC_CONFIG_ERR_MSG}; -use crate::constants::{INVALID_CACHE_SIZE_ERR_MSG, INVALID_HTTP_PORT_ERR_MSG}; +use crate::constants::{ + INVALID_CACHE_SIZE_ERR_MSG, INVALID_HTTP_ADDRESS_ERR_MSG, INVALID_HTTP_PORT_ERR_MSG, +}; use crate::constants::{INVALID_LOG_LEVEL_ERR_MSG, INVALID_TTL_SECONDS_ERR_MSG}; use config::Config as ConfigLib; use config::File; use serde_derive::Deserialize; +use std::net::Ipv4Addr; use std::num::NonZeroUsize; use std::ops::Range; use std::str::FromStr; use std::time::Duration; const DEFAULT_LOG_LEVEL: &str = "info"; +const DEFAULT_HTTP_ADDRESS: &str = "127.0.0.1"; const DEFAULT_HTTP_PORT: &str = "2773"; const DEFAULT_TTL_SECONDS: &str = "300"; const DEFAULT_CACHE_SIZE: &str = "1000"; @@ -32,6 +36,7 @@ const DEFAULT_REGION: Option = None; #[serde(deny_unknown_fields)] // We want to error out when file has misspelled or unknown configurations. struct ConfigFile { log_level: String, + http_address: String, http_port: String, ttl_seconds: String, cache_size: String, @@ -75,6 +80,8 @@ pub struct Config { /// The level of logging the agent provides ie. debug, info, warn, error or none. log_level: LogLevel, + http_address: [u8; 4], + /// The port for the local HTTP server. http_port: u16, @@ -130,6 +137,7 @@ impl Config { // Setting default configurations let mut config = ConfigLib::builder() .set_default("log_level", DEFAULT_LOG_LEVEL)? + .set_default("http_address", DEFAULT_HTTP_ADDRESS)? .set_default("http_port", DEFAULT_HTTP_PORT)? .set_default("ttl_seconds", DEFAULT_TTL_SECONDS)? .set_default("cache_size", DEFAULT_CACHE_SIZE)? @@ -164,6 +172,15 @@ impl Config { self.log_level } + /// The address for the local HTTP server to listen for incoming requests. + /// + /// # Returns + /// + /// * `address` - The address. Defaults to 127.0.0.1. + pub fn http_address(&self) -> [u8; 4] { + self.http_address + } + /// The port for the local HTTP server to listen for incomming requests. /// /// # Returns @@ -263,6 +280,7 @@ impl Config { let config = Config { // Configurations that are allowed to be overridden. log_level: LogLevel::from_str(config_file.log_level.as_str())?, + http_address: parse_address(&config_file.http_address, INVALID_HTTP_ADDRESS_ERR_MSG)?, http_port: parse_num::( &config_file.http_port, INVALID_HTTP_PORT_ERR_MSG, @@ -312,6 +330,32 @@ impl Config { } } +/// Private helper to convert a string to an array of u8 values for an IP address, returning a custom error on failure. +/// +/// # Arguments +/// +/// * `str_val` - The string to convert. +/// * `msg` - The custom error message. +/// +/// # Returns +/// +/// * `Ok(arr)` - When the string can be parsed into an IP address successfully. +/// * `Err(Error)` - The custom error message on failure. +/// +/// # Example +/// +/// ``` +/// assert_eq!(parse_address(&String::from("127.0.0.1"), "Unable to parse IP")); +/// ``` +#[doc(hidden)] +fn parse_address(str_val: &str, msg: &str) -> Result<[u8; 4], Box> { + let addr = str_val.parse::(); + match addr { + Ok(val) => Ok(val.octets()), + Err(_) => Err(msg)?, + } +} + /// Private helper to convert a string to number and perform range checks, returning a custom error on failure. /// /// # Arguments @@ -369,6 +413,7 @@ mod tests { fn get_default_config_file() -> ConfigFile { ConfigFile { log_level: String::from(DEFAULT_LOG_LEVEL), + http_address: String::from(DEFAULT_HTTP_ADDRESS), http_port: String::from(DEFAULT_HTTP_PORT), ttl_seconds: String::from(DEFAULT_TTL_SECONDS), cache_size: String::from(DEFAULT_CACHE_SIZE), @@ -386,6 +431,7 @@ mod tests { fn test_default_config() { let config = Config::default(); assert_eq!(config.clone().log_level(), LogLevel::Info); + assert_eq!(config.clone().http_address(), [127, 0, 0, 1]); assert_eq!(config.clone().http_port(), 2773); assert_eq!(config.clone().ttl(), Duration::from_secs(300)); assert_eq!( @@ -411,6 +457,7 @@ mod tests { fn test_config_overrides() { let config = Config::new(Some("tests/resources/configs/config_file_valid.toml")).unwrap(); assert_eq!(config.clone().log_level(), LogLevel::Debug); + assert_eq!(config.clone().http_address(), [0, 0, 0, 0]); assert_eq!(config.clone().http_port(), 65535); assert_eq!(config.clone().ttl(), Duration::from_secs(300)); assert_eq!( @@ -433,15 +480,25 @@ mod tests { /// Tests that an Err is returned when an invalid value is provided in one of the configurations. #[test] - fn test_config_overrides_invalid_value() { + fn test_config_overrides_invalid_log_level() { match Config::new(Some( - "tests/resources/configs/config_file_with_invalid_config.toml", + "tests/resources/configs/config_file_with_invalid_log_level.toml", )) { Ok(_) => panic!(), Err(e) => assert_eq!(e.to_string(), INVALID_LOG_LEVEL_ERR_MSG), }; } + #[test] + fn test_config_overrides_invalid_http_addr() { + match Config::new(Some( + "tests/resources/configs/config_file_with_invalid_http_address.toml", + )) { + Ok(_) => panic!(), + Err(e) => assert_eq!(e.to_string(), INVALID_HTTP_ADDRESS_ERR_MSG), + }; + } + /// Tests that an valid log level values don't return an Err. #[test] fn test_validate_config_valid_log_level_values() { @@ -487,6 +544,18 @@ mod tests { }; } + #[test] + fn test_validate_config_invalid_http_address() { + let invalid_config = ConfigFile { + http_address: String::from("invalid"), + ..get_default_config_file() + }; + match Config::build(invalid_config) { + Ok(_) => panic!(), + Err(e) => assert_eq!(e.to_string(), INVALID_HTTP_ADDRESS_ERR_MSG), + }; + } + /// Tests that an invalid http port value returns an Err #[test] fn test_validate_config_http_port_invalid_values() { diff --git a/aws_secretsmanager_agent/src/constants.rs b/aws_secretsmanager_agent/src/constants.rs index 9d4edde..54a603f 100644 --- a/aws_secretsmanager_agent/src/constants.rs +++ b/aws_secretsmanager_agent/src/constants.rs @@ -1,5 +1,6 @@ /// User visible error messages. pub const INVALID_LOG_LEVEL_ERR_MSG: &str = "The log level specified in the configuration file isn't valid. The log level must be DEBUG, INFO, WARN, ERROR, or NONE."; +pub const INVALID_HTTP_ADDRESS_ERR_MSG: &str = "The HTTP address specified in the configuration file isn't valid. The HTTP address must be an IP address like 127.0.0.1."; pub const INVALID_HTTP_PORT_ERR_MSG: &str = "The HTTP port specified in the configuration file isn't valid. The HTTP port must be in the range 1024 to 65535."; pub const INVALID_TTL_SECONDS_ERR_MSG: &str = "The TTL in seconds specified in the configuration file isn't valid. The TTL in seconds must be in the range 1 to 3600."; pub const INVALID_CACHE_SIZE_ERR_MSG: &str = "The cache size specified in the configuration file isn't valid. The cache size must be in the range 1 to 1000."; diff --git a/aws_secretsmanager_agent/src/main.rs b/aws_secretsmanager_agent/src/main.rs index e822dea..bbf3fb5 100644 --- a/aws_secretsmanager_agent/src/main.rs +++ b/aws_secretsmanager_agent/src/main.rs @@ -167,7 +167,7 @@ async fn init(args: impl IntoIterator) -> (Config, TcpListener) { } // Bind the listener to the specified port - let addr: SocketAddr = ([127, 0, 0, 1], config.http_port()).into(); + let addr: SocketAddr = (config.http_address(), config.http_port()).into(); let listener: TcpListener = match TcpListener::bind(addr).await { Ok(x) => x, Err(err) => { diff --git a/aws_secretsmanager_agent/tests/resources/configs/config_file_valid.toml b/aws_secretsmanager_agent/tests/resources/configs/config_file_valid.toml index d15186c..f36ed3a 100644 --- a/aws_secretsmanager_agent/tests/resources/configs/config_file_valid.toml +++ b/aws_secretsmanager_agent/tests/resources/configs/config_file_valid.toml @@ -1,5 +1,6 @@ # checking that all caps for log level is accpeted. log_level = "DEBUG" +http_address = "0.0.0.0" http_port = "65535" ssrf_headers = ["X-Aws-Parameters-Secrets-Token"] ssrf_env_variables = ["MY_TOKEN"] diff --git a/aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_http_address.toml b/aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_http_address.toml new file mode 100644 index 0000000..003bbee --- /dev/null +++ b/aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_http_address.toml @@ -0,0 +1,2 @@ +http_address = "invalid" +http_port = "9999" diff --git a/aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_config.toml b/aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_log_level.toml similarity index 100% rename from aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_config.toml rename to aws_secretsmanager_agent/tests/resources/configs/config_file_with_invalid_log_level.toml