Skip to content

Configurable HTTP Address #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions aws_secretsmanager_agent/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::constants::EMPTY_ENV_LIST_MSG;
use crate::constants::{BAD_MAX_CONN_MSG, BAD_PREFIX_MSG, EMPTY_SSRF_LIST_MSG};
use crate::constants::{DEFAULT_MAX_CONNECTIONS, GENERIC_CONFIG_ERR_MSG};
use crate::constants::{INVALID_CACHE_SIZE_ERR_MSG, INVALID_HTTP_PORT_ERR_MSG};
use crate::constants::{
INVALID_CACHE_SIZE_ERR_MSG, INVALID_HTTP_ADDRESS_ERR_MSG, INVALID_HTTP_PORT_ERR_MSG,
};
use crate::constants::{INVALID_LOG_LEVEL_ERR_MSG, INVALID_TTL_SECONDS_ERR_MSG};
use config::Config as ConfigLib;
use config::File;
use serde_derive::Deserialize;
use std::net::Ipv4Addr;
use std::num::NonZeroUsize;
use std::ops::Range;
use std::str::FromStr;
use std::time::Duration;

const DEFAULT_LOG_LEVEL: &str = "info";
const DEFAULT_HTTP_ADDRESS: &str = "127.0.0.1";
const DEFAULT_HTTP_PORT: &str = "2773";
const DEFAULT_TTL_SECONDS: &str = "300";
const DEFAULT_CACHE_SIZE: &str = "1000";
Expand All @@ -32,6 +36,7 @@ const DEFAULT_REGION: Option<String> = None;
#[serde(deny_unknown_fields)] // We want to error out when file has misspelled or unknown configurations.
struct ConfigFile {
log_level: String,
http_address: String,
http_port: String,
ttl_seconds: String,
cache_size: String,
Expand Down Expand Up @@ -75,6 +80,8 @@ pub struct Config {
/// The level of logging the agent provides ie. debug, info, warn, error or none.
log_level: LogLevel,

http_address: [u8; 4],

/// The port for the local HTTP server.
http_port: u16,

Expand Down Expand Up @@ -130,6 +137,7 @@ impl Config {
// Setting default configurations
let mut config = ConfigLib::builder()
.set_default("log_level", DEFAULT_LOG_LEVEL)?
.set_default("http_address", DEFAULT_HTTP_ADDRESS)?
.set_default("http_port", DEFAULT_HTTP_PORT)?
.set_default("ttl_seconds", DEFAULT_TTL_SECONDS)?
.set_default("cache_size", DEFAULT_CACHE_SIZE)?
Expand Down Expand Up @@ -164,6 +172,15 @@ impl Config {
self.log_level
}

/// The address for the local HTTP server to listen for incoming requests.
///
/// # Returns
///
/// * `address` - The address. Defaults to 127.0.0.1.
pub fn http_address(&self) -> [u8; 4] {
self.http_address
}

/// The port for the local HTTP server to listen for incomming requests.
///
/// # Returns
Expand Down Expand Up @@ -263,6 +280,7 @@ impl Config {
let config = Config {
// Configurations that are allowed to be overridden.
log_level: LogLevel::from_str(config_file.log_level.as_str())?,
http_address: parse_address(&config_file.http_address, INVALID_HTTP_ADDRESS_ERR_MSG)?,
http_port: parse_num::<u16>(
&config_file.http_port,
INVALID_HTTP_PORT_ERR_MSG,
Expand Down Expand Up @@ -312,6 +330,32 @@ impl Config {
}
}

/// Private helper to convert a string to an array of u8 values for an IP address, returning a custom error on failure.
///
/// # Arguments
///
/// * `str_val` - The string to convert.
/// * `msg` - The custom error message.
///
/// # Returns
///
/// * `Ok(arr)` - When the string can be parsed into an IP address successfully.
/// * `Err(Error)` - The custom error message on failure.
///
/// # Example
///
/// ```
/// assert_eq!(parse_address(&String::from("127.0.0.1"), "Unable to parse IP"));
/// ```
#[doc(hidden)]
fn parse_address(str_val: &str, msg: &str) -> Result<[u8; 4], Box<dyn std::error::Error>> {
let addr = str_val.parse::<Ipv4Addr>();
match addr {
Ok(val) => Ok(val.octets()),
Err(_) => Err(msg)?,
}
}

/// Private helper to convert a string to number and perform range checks, returning a custom error on failure.
///
/// # Arguments
Expand Down Expand Up @@ -369,6 +413,7 @@ mod tests {
fn get_default_config_file() -> ConfigFile {
ConfigFile {
log_level: String::from(DEFAULT_LOG_LEVEL),
http_address: String::from(DEFAULT_HTTP_ADDRESS),
http_port: String::from(DEFAULT_HTTP_PORT),
ttl_seconds: String::from(DEFAULT_TTL_SECONDS),
cache_size: String::from(DEFAULT_CACHE_SIZE),
Expand All @@ -386,6 +431,7 @@ mod tests {
fn test_default_config() {
let config = Config::default();
assert_eq!(config.clone().log_level(), LogLevel::Info);
assert_eq!(config.clone().http_address(), [127, 0, 0, 1]);
assert_eq!(config.clone().http_port(), 2773);
assert_eq!(config.clone().ttl(), Duration::from_secs(300));
assert_eq!(
Expand All @@ -411,6 +457,7 @@ mod tests {
fn test_config_overrides() {
let config = Config::new(Some("tests/resources/configs/config_file_valid.toml")).unwrap();
assert_eq!(config.clone().log_level(), LogLevel::Debug);
assert_eq!(config.clone().http_address(), [0, 0, 0, 0]);
assert_eq!(config.clone().http_port(), 65535);
assert_eq!(config.clone().ttl(), Duration::from_secs(300));
assert_eq!(
Expand All @@ -433,15 +480,25 @@ mod tests {

/// Tests that an Err is returned when an invalid value is provided in one of the configurations.
#[test]
fn test_config_overrides_invalid_value() {
fn test_config_overrides_invalid_log_level() {
match Config::new(Some(
"tests/resources/configs/config_file_with_invalid_config.toml",
"tests/resources/configs/config_file_with_invalid_log_level.toml",
)) {
Ok(_) => panic!(),
Err(e) => assert_eq!(e.to_string(), INVALID_LOG_LEVEL_ERR_MSG),
};
}

#[test]
fn test_config_overrides_invalid_http_addr() {
match Config::new(Some(
"tests/resources/configs/config_file_with_invalid_http_address.toml",
)) {
Ok(_) => panic!(),
Err(e) => assert_eq!(e.to_string(), INVALID_HTTP_ADDRESS_ERR_MSG),
};
}

/// Tests that an valid log level values don't return an Err.
#[test]
fn test_validate_config_valid_log_level_values() {
Expand Down Expand Up @@ -487,6 +544,18 @@ mod tests {
};
}

#[test]
fn test_validate_config_invalid_http_address() {
let invalid_config = ConfigFile {
http_address: String::from("invalid"),
..get_default_config_file()
};
match Config::build(invalid_config) {
Ok(_) => panic!(),
Err(e) => assert_eq!(e.to_string(), INVALID_HTTP_ADDRESS_ERR_MSG),
};
}

/// Tests that an invalid http port value returns an Err
#[test]
fn test_validate_config_http_port_invalid_values() {
Expand Down
1 change: 1 addition & 0 deletions aws_secretsmanager_agent/src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/// User visible error messages.
pub const INVALID_LOG_LEVEL_ERR_MSG: &str = "The log level specified in the configuration file isn't valid. The log level must be DEBUG, INFO, WARN, ERROR, or NONE.";
pub const INVALID_HTTP_ADDRESS_ERR_MSG: &str = "The HTTP address specified in the configuration file isn't valid. The HTTP address must be an IP address like 127.0.0.1.";
pub const INVALID_HTTP_PORT_ERR_MSG: &str = "The HTTP port specified in the configuration file isn't valid. The HTTP port must be in the range 1024 to 65535.";
pub const INVALID_TTL_SECONDS_ERR_MSG: &str = "The TTL in seconds specified in the configuration file isn't valid. The TTL in seconds must be in the range 1 to 3600.";
pub const INVALID_CACHE_SIZE_ERR_MSG: &str = "The cache size specified in the configuration file isn't valid. The cache size must be in the range 1 to 1000.";
Expand Down
2 changes: 1 addition & 1 deletion aws_secretsmanager_agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async fn init(args: impl IntoIterator<Item = String>) -> (Config, TcpListener) {
}

// Bind the listener to the specified port
let addr: SocketAddr = ([127, 0, 0, 1], config.http_port()).into();
let addr: SocketAddr = (config.http_address(), config.http_port()).into();
let listener: TcpListener = match TcpListener::bind(addr).await {
Ok(x) => x,
Err(err) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# checking that all caps for log level is accpeted.
log_level = "DEBUG"
http_address = "0.0.0.0"
http_port = "65535"
ssrf_headers = ["X-Aws-Parameters-Secrets-Token"]
ssrf_env_variables = ["MY_TOKEN"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
http_address = "invalid"
http_port = "9999"
Loading