diff --git a/Cargo.lock b/Cargo.lock index a30aadf..f5ade68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -167,6 +167,15 @@ dependencies = [ "log", ] +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "heck" version = "0.3.3" @@ -191,6 +200,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -218,6 +237,12 @@ version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "predicates" version = "3.1.0" @@ -423,12 +448,42 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-segmentation" version = "1.11.0" @@ -441,6 +496,17 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "utf8parse" version = "0.2.1" @@ -474,6 +540,7 @@ dependencies = [ "log", "structopt", "thiserror", + "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4abd8e4..a112e41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ log = { version = "0.4.20", features = ["kv_unstable"] } structopt = "0.3" env_logger = "0.11.3" thiserror = "1.0.59" +url = "2.5.0" [dev-dependencies] assert_cmd = "2.0.12" diff --git a/examples/main.rs b/examples/main.rs index d2c7607..875e6fe 100644 --- a/examples/main.rs +++ b/examples/main.rs @@ -1,7 +1,8 @@ use wait_for_rs::WaitService; fn main() { - let urls = ["google.com:443".to_string(), "github.com:443".to_string()].to_vec(); + // let urls = ["google.com:443".to_string(), "github.com:443".to_string()].to_vec(); + let urls = ["google.com:443".to_string()].to_vec(); let timeout = 10; let wait_service = WaitService::new(urls, timeout).unwrap(); diff --git a/src/errors.rs b/src/errors.rs index fe76c33..256b583 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,3 +1,5 @@ +use std::net::AddrParseError; + use thiserror::Error; /// Custom error type for the WaitService. @@ -11,6 +13,14 @@ pub enum WaitServiceError { #[error("Urls are empty")] UrlsEmpty, + /// Url parse error. + #[error("{}", _0)] + Url(#[from] url::ParseError), + + /// Address parse error. + #[error("{}", _0)] + Address(#[from] AddrParseError), + /// IO error. #[error("{}", _0)] Io(#[from] std::io::Error), diff --git a/src/wait.rs b/src/wait.rs index fb0189f..7a500e1 100644 --- a/src/wait.rs +++ b/src/wait.rs @@ -7,11 +7,16 @@ use std::{ thread::{sleep, spawn}, time::{Duration, Instant}, }; +use url::Url; /// The default interval in milliseconds between pings for the same url const DEFAULT_INTERVAL: u64 = 500; /// The default connection timeout in seconds const DEAFULT_CONNECTION_TIMEOUT: u64 = 1; +/// The default port for https to use when port is not provided +const DEFAULT_HTTPS_PORT: u16 = 443; +/// The default port for http to use when port is not provided +const DEFAULT_HTTP_PORT: u16 = 80; pub struct WaitService { addresses: HashMap, @@ -64,12 +69,72 @@ impl WaitService { } fn resolve_address(url: &str) -> Result { - match url.to_socket_addrs() { - Ok(mut addr) => { - return Ok(addr.next().unwrap()); + // Try parsing the input as a URL + if let Ok(parsed_url) = Url::parse(url) { + if parsed_url.has_host() { + // Check if the URL includes a port + let port = match parsed_url.port() { + Some(port) => port, + None => get_default_port(&parsed_url)?, + }; + + let host = parsed_url + .host_str() + .ok_or(WaitServiceError::UrlNotParsed)?; + + // Construct the address + let addr = socket_addr_from_host_and_port(&host, port)?; + + return Ok(addr); + } else { + // If there is no host then it is only domain + let addr = socket_addr_from_domain(url)?; + return Ok(addr); } - Err(e) => return Err(WaitServiceError::Io(e)), } + let addr = socket_addr_from_tcp(url)?; + + return Ok(addr); +} + +/// Default port based on scheme +fn get_default_port(url: &Url) -> Result { + match url.scheme() { + "https" => Ok(DEFAULT_HTTPS_PORT), + "http" => Ok(DEFAULT_HTTP_PORT), + _ => return Err(WaitServiceError::UrlNotParsed), + } +} + +/// Construct a `SocketAddr` from a host and a port +fn socket_addr_from_host_and_port(host: &str, port: u16) -> Result { + let addr = (host, port) + .to_socket_addrs() + .map_err(|e| WaitServiceError::Io(e))? + .next() + .ok_or(WaitServiceError::UrlNotParsed)?; + + Ok(addr) +} + +/// Construct a `SocketAddr` from a url +fn socket_addr_from_domain(url: &str) -> Result { + let addr = url + .to_socket_addrs() + .map_err(|e| WaitServiceError::Io(e))? + .next() + .ok_or(WaitServiceError::UrlNotParsed)?; + + Ok(addr) +} + +/// Construct a `SocketAddr` from a tcp address +fn socket_addr_from_tcp(address: &str) -> Result { + let addr = address + .parse::() + .map_err(|e| WaitServiceError::Address(e))?; + + Ok(addr) } fn wait_for_tcp_socket( diff --git a/tests/wait.rs b/tests/wait.rs index ab6a11a..0cbe0ab 100644 --- a/tests/wait.rs +++ b/tests/wait.rs @@ -96,3 +96,16 @@ fn multiple_addresses_with_timeout_works() { drop(server1); drop(server2); } + +#[test] +fn one_address_one_https_works() { + let server1 = utils::TestServer::new(4000, Duration::from_millis(10)); + + Command::cargo_bin("wait-for-rs") + .unwrap() + .args(&["--urls", "127.0.0.1:4000", "https://google.com"]) + .assert() + .success(); + + drop(server1); +}