diff --git a/Cargo.lock b/Cargo.lock index ba38b97102e..0f5d06c87ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -805,14 +805,14 @@ dependencies = [ [[package]] name = "enum-as-inner" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116" +checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" dependencies = [ "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.29", ] [[package]] @@ -1192,6 +1192,51 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +[[package]] +name = "hickory-proto" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091a6fbccf4860009355e3efc52ff4acf37a63489aad7435372d44ceeb6fbbcf" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.4.0", + "ipnet", + "once_cell", + "rand", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "hickory-resolver" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35b8f021164e6a984c9030023544c57789c51760065cd510572fedcfb04164e8" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "lru-cache", + "once_cell", + "parking_lot", + "rand", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "hmac" version = "0.12.1" @@ -2383,7 +2428,6 @@ dependencies = [ "tokio-socks", "tokio-util", "tower-service", - "trust-dns-resolver", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -3268,51 +3312,6 @@ dependencies = [ "tracing-log", ] -[[package]] -name = "trust-dns-proto" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26" -dependencies = [ - "async-trait", - "cfg-if", - "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", - "idna 0.2.3", - "ipnet", - "lazy_static", - "rand", - "smallvec", - "thiserror", - "tinyvec", - "tokio", - "tracing", - "url", -] - -[[package]] -name = "trust-dns-resolver" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe" -dependencies = [ - "cfg-if", - "futures-util", - "ipconfig", - "lazy_static", - "lru-cache", - "parking_lot", - "resolv-conf", - "smallvec", - "thiserror", - "tokio", - "tracing", - "trust-dns-proto", -] - [[package]] name = "try-lock" version = "0.2.4" @@ -3464,7 +3463,9 @@ dependencies = [ "futures", "governor", "handlebars", + "hickory-resolver", "html5gum", + "hyper", "job_scheduler_ng", "jsonwebtoken", "lettre", diff --git a/Cargo.toml b/Cargo.toml index 309b379a541..3b6515334dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,7 +68,7 @@ dashmap = "5.5.3" # Async futures futures = "0.3.28" -tokio = { version = "1.32.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal"] } +tokio = { version = "1.32.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] } # A generic serialization/deserialization framework serde = { version = "1.0.188", features = ["derive"] } @@ -124,7 +124,9 @@ email_address = "0.2.4" handlebars = { version = "4.3.7", features = ["dir_source"] } # HTTP client (Used for favicons, version check, DUO and HIBP API) -reqwest = { version = "0.11.20", features = ["stream", "json", "deflate", "gzip", "brotli", "socks", "cookies", "trust-dns", "native-tls-alpn"] } +reqwest = { version = "0.11.20", features = ["stream", "json", "deflate", "gzip", "brotli", "socks", "cookies", "native-tls-alpn"] } +hyper = { version = "0.14.27", default-features = false } +hickory-resolver = "0.24.0" # Favicon extraction libraries html5gum = "0.5.7" diff --git a/src/api/icons.rs b/src/api/icons.rs index f47357bb310..ccd4ce451a2 100644 --- a/src/api/icons.rs +++ b/src/api/icons.rs @@ -1,6 +1,6 @@ use std::{ net::IpAddr, - sync::Arc, + sync::{Arc, Mutex}, time::{Duration, SystemTime}, }; @@ -16,14 +16,13 @@ use rocket::{http::ContentType, response::Redirect, Route}; use tokio::{ fs::{create_dir_all, remove_file, symlink_metadata, File}, io::{AsyncReadExt, AsyncWriteExt}, - net::lookup_host, }; use html5gum::{Emitter, HtmlString, InfallibleTokenizer, Readable, StringReader, Tokenizer}; use crate::{ error::Error, - util::{get_reqwest_client_builder, Cached}, + util::{get_reqwest_client_builder, Cached, CustomDnsResolver, CustomResolverError}, CONFIG, }; @@ -49,48 +48,35 @@ static CLIENT: Lazy = Lazy::new(|| { let icon_download_timeout = Duration::from_secs(CONFIG.icon_download_timeout()); let pool_idle_timeout = Duration::from_secs(10); // Reuse the client between requests - let client = get_reqwest_client_builder() + get_reqwest_client_builder() .cookie_provider(Arc::clone(&cookie_store)) .timeout(icon_download_timeout) .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections .pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds - .trust_dns(true) - .default_headers(default_headers.clone()); - - match client.build() { - Ok(client) => client, - Err(e) => { - error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'"); - get_reqwest_client_builder() - .cookie_provider(cookie_store) - .timeout(icon_download_timeout) - .pool_max_idle_per_host(5) // Configure the Hyper Pool to only have max 5 idle connections - .pool_idle_timeout(pool_idle_timeout) // Configure the Hyper Pool to timeout after 10 seconds - .trust_dns(false) - .default_headers(default_headers) - .build() - .expect("Failed to build client") - } - } + .dns_resolver(CustomDnsResolver::instance()) + .default_headers(default_headers.clone()) + .build() + .expect("Failed to build client") }); // Build Regex only once since this takes a lot of time. static ICON_SIZE_REGEX: Lazy = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap()); -// Special HashMap which holds the user defined Regex to speedup matching the regex. -static ICON_BLACKLIST_REGEX: Lazy> = Lazy::new(dashmap::DashMap::new); +// Compiled domain blacklist +static COMPILED_BLACKLIST: Mutex> = Mutex::new(None); -async fn icon_redirect(domain: &str, template: &str) -> Option { +#[get("//icon.png")] +fn icon_external(domain: &str) -> Option { if !is_valid_domain(domain) { warn!("Invalid domain: {}", domain); return None; } - if check_domain_blacklist_reason(domain).await.is_some() { + if is_domain_blacklisted(domain) { return None; } - let url = template.replace("{}", domain); + let url = CONFIG._icon_service_url().replace("{}", domain); match CONFIG.icon_redirect_code() { 301 => Some(Redirect::moved(url)), // legacy permanent redirect 302 => Some(Redirect::found(url)), // legacy temporary redirect @@ -103,11 +89,6 @@ async fn icon_redirect(domain: &str, template: &str) -> Option { } } -#[get("//icon.png")] -async fn icon_external(domain: &str) -> Option { - icon_redirect(domain, &CONFIG._icon_service_url()).await -} - #[get("//icon.png")] async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec)> { const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png"); @@ -166,153 +147,30 @@ fn is_valid_domain(domain: &str) -> bool { true } -/// TODO: This is extracted from IpAddr::is_global, which is unstable: -/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global -/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged -#[allow(clippy::nonminimal_bool)] -#[cfg(not(feature = "unstable"))] -fn is_global(ip: IpAddr) -> bool { - match ip { - IpAddr::V4(ip) => { - // check if this address is 192.0.0.9 or 192.0.0.10. These addresses are the only two - // globally routable addresses in the 192.0.0.0/24 range. - if u32::from(ip) == 0xc0000009 || u32::from(ip) == 0xc000000a { - return true; - } - !ip.is_private() - && !ip.is_loopback() - && !ip.is_link_local() - && !ip.is_broadcast() - && !ip.is_documentation() - && !(ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) - && !(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0) - && !(ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) - && !(ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) - // Make sure the address is not in 0.0.0.0/8 - && ip.octets()[0] != 0 - } - IpAddr::V6(ip) => { - if ip.is_multicast() && ip.segments()[0] & 0x000f == 14 { - true - } else { - !ip.is_multicast() - && !ip.is_loopback() - && !((ip.segments()[0] & 0xffc0) == 0xfe80) - && !((ip.segments()[0] & 0xfe00) == 0xfc00) - && !ip.is_unspecified() - && !((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) - } - } - } -} - -#[cfg(feature = "unstable")] -fn is_global(ip: IpAddr) -> bool { - ip.is_global() -} +pub fn is_domain_blacklisted(domain: &str) -> bool { + let Some(config_blacklist) = CONFIG.icon_blacklist_regex() else { + return false; + }; -/// These are some tests to check that the implementations match -/// The IPv4 can be all checked in 5 mins or so and they are correct as of nightly 2020-07-11 -/// The IPV6 can't be checked in a reasonable time, so we check about ten billion random ones, so far correct -/// Note that the is_global implementation is subject to change as new IP RFCs are created -/// -/// To run while showing progress output: -/// cargo test --features sqlite,unstable -- --nocapture --ignored -#[cfg(test)] -#[cfg(feature = "unstable")] -mod tests { - use super::*; - - #[test] - #[ignore] - fn test_ipv4_global() { - for a in 0..u8::MAX { - println!("Iter: {}/255", a); - for b in 0..u8::MAX { - for c in 0..u8::MAX { - for d in 0..u8::MAX { - let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)); - assert_eq!(ip.is_global(), is_global(ip)) - } - } - } - } - } + let mut guard = COMPILED_BLACKLIST.lock().unwrap(); - #[test] - #[ignore] - fn test_ipv6_global() { - use ring::rand::{SecureRandom, SystemRandom}; - let mut v = [0u8; 16]; - let rand = SystemRandom::new(); - for i in 0..1_000 { - println!("Iter: {}/1_000", i); - for _ in 0..10_000_000 { - rand.fill(&mut v).expect("Error generating random values"); - let ip = IpAddr::V6(std::net::Ipv6Addr::new( - (v[14] as u16) << 8 | v[15] as u16, - (v[12] as u16) << 8 | v[13] as u16, - (v[10] as u16) << 8 | v[11] as u16, - (v[8] as u16) << 8 | v[9] as u16, - (v[6] as u16) << 8 | v[7] as u16, - (v[4] as u16) << 8 | v[5] as u16, - (v[2] as u16) << 8 | v[3] as u16, - (v[0] as u16) << 8 | v[1] as u16, - )); - assert_eq!(ip.is_global(), is_global(ip)) - } + // If the stored regex is up to date, use it + if let Some((value, regex)) = &*guard { + if value == &config_blacklist && regex.is_match(domain) { + return true; } } -} - -#[derive(Clone)] -enum DomainBlacklistReason { - Regex, - IP, -} - -use cached::proc_macro::cached; -#[cached(key = "String", convert = r#"{ domain.to_string() }"#, size = 16, time = 60)] -async fn check_domain_blacklist_reason(domain: &str) -> Option { - // First check the blacklist regex if there is a match. - // This prevents the blocked domain(s) from being leaked via a DNS lookup. - if let Some(blacklist) = CONFIG.icon_blacklist_regex() { - // Use the pre-generate Regex stored in a Lazy HashMap if there's one, else generate it. - let is_match = if let Some(regex) = ICON_BLACKLIST_REGEX.get(&blacklist) { - regex.is_match(domain) - } else { - // Clear the current list if the previous key doesn't exists. - // To prevent growing of the HashMap after someone has changed it via the admin interface. - if ICON_BLACKLIST_REGEX.len() >= 1 { - ICON_BLACKLIST_REGEX.clear(); - } - - // Generate the regex to store in too the Lazy Static HashMap. - let blacklist_regex = Regex::new(&blacklist).unwrap(); - let is_match = blacklist_regex.is_match(domain); - ICON_BLACKLIST_REGEX.insert(blacklist.clone(), blacklist_regex); - - is_match - }; - if is_match { - debug!("Blacklisted domain: {} matched ICON_BLACKLIST_REGEX", domain); - return Some(DomainBlacklistReason::Regex); - } - } + // If we don't have a regex stored, or it's not up to date, recreate it + let regex = Regex::new(&config_blacklist).unwrap(); + let is_match = regex.is_match(domain); + *guard = Some((config_blacklist, regex)); - if CONFIG.icon_blacklist_non_global_ips() { - if let Ok(s) = lookup_host((domain, 0)).await { - for addr in s { - if !is_global(addr.ip()) { - debug!("IP {} for domain '{}' is not a global IP!", addr.ip(), domain); - return Some(DomainBlacklistReason::IP); - } - } - } + if is_match { + return true; } - None + false } async fn get_icon(domain: &str) -> Option<(Vec, String)> { @@ -342,6 +200,13 @@ async fn get_icon(domain: &str) -> Option<(Vec, String)> { Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string())) } Err(e) => { + // If this error comes from the resolver, this means this is a blacklisted domain + // or non global IP, don't save the miss file in this case to avoid leaking it + if let Some(error) = CustomResolverError::downcast_ref(&e) { + warn!("{error}"); + return None; + } + warn!("Unable to download icon: {:?}", e); let miss_indicator = path + ".miss"; save_icon(&miss_indicator, &[]).await; @@ -573,21 +438,12 @@ async fn get_page(url: &str) -> Result { } async fn get_page_with_referer(url: &str, referer: &str) -> Result { - match check_domain_blacklist_reason(url::Url::parse(url).unwrap().host_str().unwrap_or_default()).await { - Some(DomainBlacklistReason::Regex) => warn!("Favicon '{}' is from a blacklisted domain!", url), - Some(DomainBlacklistReason::IP) => warn!("Favicon '{}' is hosted on a non-global IP!", url), - None => (), - } - let mut client = CLIENT.get(url); if !referer.is_empty() { client = client.header("Referer", referer) } - match client.send().await { - Ok(c) => c.error_for_status().map_err(Into::into), - Err(e) => err_silent!(format!("{e}")), - } + Ok(client.send().await?.error_for_status()?) } /// Returns a Integer with the priority of the type of the icon which to prefer. @@ -670,12 +526,6 @@ fn parse_sizes(sizes: &str) -> (u16, u16) { } async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> { - match check_domain_blacklist_reason(domain).await { - Some(DomainBlacklistReason::Regex) => err_silent!("Domain is blacklisted", domain), - Some(DomainBlacklistReason::IP) => err_silent!("Host resolves to a non-global IP", domain), - None => (), - } - let icon_result = get_icon_url(domain).await?; let mut buffer = Bytes::new(); @@ -711,22 +561,19 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> { _ => debug!("Extracted icon from data:image uri is invalid"), }; } else { - match get_page_with_referer(&icon.href, &icon_result.referer).await { - Ok(res) => { - buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net) - - // Check if the icon type is allowed, else try an icon from the list. - icon_type = get_icon_type(&buffer); - if icon_type.is_none() { - buffer.clear(); - debug!("Icon from {}, is not a valid image type", icon.href); - continue; - } - info!("Downloaded icon from {}", icon.href); - break; - } - Err(e) => debug!("{:?}", e), - }; + let res = get_page_with_referer(&icon.href, &icon_result.referer).await?; + + buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net) + + // Check if the icon type is allowed, else try an icon from the list. + icon_type = get_icon_type(&buffer); + if icon_type.is_none() { + buffer.clear(); + debug!("Icon from {}, is not a valid image type", icon.href); + continue; + } + info!("Downloaded icon from {}", icon.href); + break; } } diff --git a/src/api/mod.rs b/src/api/mod.rs index fd181fda504..5bc375edd84 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -20,7 +20,7 @@ pub use crate::api::{ core::two_factor::send_incomplete_2fa_notifications, core::{emergency_notification_reminder_job, emergency_request_timeout_job}, core::{event_cleanup_job, events_routes as core_events_routes}, - icons::routes as icons_routes, + icons::{is_domain_blacklisted, routes as icons_routes}, identity::routes as identity_routes, notifications::routes as notifications_routes, notifications::{start_notification_server, AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS}, diff --git a/src/util.rs b/src/util.rs index 52f371e4558..9414a72a765 100644 --- a/src/util.rs +++ b/src/util.rs @@ -6,6 +6,7 @@ use std::{ ops::Deref, }; +use once_cell::sync::Lazy; use rocket::{ fairing::{Fairing, Info, Kind}, http::{ContentType, Header, HeaderMap, Method, Status}, @@ -680,14 +681,9 @@ where use reqwest::{header, Client, ClientBuilder}; -pub fn get_reqwest_client() -> Client { - match get_reqwest_client_builder().build() { - Ok(client) => client, - Err(e) => { - error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'"); - get_reqwest_client_builder().trust_dns(false).build().expect("Failed to build client") - } - } +pub fn get_reqwest_client() -> &'static Client { + static INSTANCE: Lazy = Lazy::new(|| get_reqwest_client_builder().build().expect("Failed to build client")); + &INSTANCE } pub fn get_reqwest_client_builder() -> ClientBuilder { @@ -738,3 +734,256 @@ pub fn convert_json_key_lcase_first(src_json: Value) -> Value { value => value, } } + +mod dns_resolver { + use std::{ + fmt, + net::{IpAddr, SocketAddr}, + sync::Arc, + }; + + use hickory_resolver::{system_conf::read_system_conf, TokioAsyncResolver}; + use hyper::client::connect::dns::Name; + use once_cell::sync::Lazy; + use reqwest::dns::{Resolve, Resolving}; + + use crate::{util::is_global, CONFIG}; + + #[derive(Debug, Clone)] + pub enum CustomResolverError { + Blacklist { + domain: String, + }, + NonGlobalIp { + domain: String, + ip: IpAddr, + }, + } + + impl CustomResolverError { + pub fn downcast_ref(e: &dyn std::error::Error) -> Option<&Self> { + let mut source = e.source(); + + while let Some(err) = source { + source = err.source(); + if let Some(err) = err.downcast_ref::() { + return Some(err); + } + } + None + } + } + + impl fmt::Display for CustomResolverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Blacklist { + domain, + } => write!(f, "Blacklisted domain: {domain} matched ICON_BLACKLIST_REGEX"), + Self::NonGlobalIp { + domain, + ip, + } => write!(f, "IP {ip} for domain '{domain}' is not a global IP!"), + } + } + } + + impl std::error::Error for CustomResolverError {} + + #[derive(Debug, Clone)] + pub enum CustomDnsResolver { + Default(), + Hickory(Arc), + } + type BoxError = Box; + + impl CustomDnsResolver { + pub fn instance() -> Arc { + static INSTANCE: Lazy> = Lazy::new(CustomDnsResolver::new); + Arc::clone(&*INSTANCE) + } + + fn new() -> Arc { + match read_system_conf() { + Ok((config, opts)) => { + let resolver = TokioAsyncResolver::tokio(config.clone(), opts.clone()); + Arc::new(Self::Hickory(Arc::new(resolver))) + } + Err(e) => { + warn!("Error creating Hickory resolver, falling back to default: {e:?}"); + Arc::new(Self::Default()) + } + } + } + + // Note that we get an iterator of addresses, but we only grab the first one for convenience + async fn resolve_domain(&self, name: &str) -> Result, BoxError> { + match self { + Self::Default() => { + let mut lookup = tokio::net::lookup_host(name).await?; + Ok(lookup.next()) + } + Self::Hickory(resolver) => { + let lookup = resolver.lookup_ip(name).await?; + Ok(lookup.into_iter().next().map(|a| SocketAddr::new(a, 0))) + } + } + } + } + + fn pre_resolve(name: &str) -> Result<(), CustomResolverError> { + if crate::api::is_domain_blacklisted(name) { + return Err(CustomResolverError::Blacklist { + domain: name.to_string(), + }); + } + + Ok(()) + } + + fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomResolverError> { + if CONFIG.icon_blacklist_non_global_ips() && !is_global(ip) { + let e = CustomResolverError::NonGlobalIp { + domain: name.to_string(), + ip, + }; + warn!("{e}"); + Err(e) + } else { + Ok(()) + } + } + + impl Resolve for CustomDnsResolver { + fn resolve(&self, name: Name) -> Resolving { + let this = self.clone(); + Box::pin(async move { + let name = name.as_str(); + + pre_resolve(name)?; + let result = this.resolve_domain(name).await?; + if let Some(addr) = &result { + let ip = addr.ip(); + post_resolve(name, ip)?; + } + + Ok::(Box::new(result.into_iter())) + }) + } + } +} + +pub use dns_resolver::{CustomDnsResolver, CustomResolverError}; + +/// TODO: This is extracted from IpAddr::is_global, which is unstable: +/// https://doc.rust-lang.org/nightly/std/net/enum.IpAddr.html#method.is_global +/// Remove once https://github.com/rust-lang/rust/issues/27709 is merged +#[allow(clippy::nonminimal_bool)] +#[cfg(any(not(feature = "unstable"), test))] +pub fn is_global_hardcoded(ip: std::net::IpAddr) -> bool { + match ip { + std::net::IpAddr::V4(ip) => { + !(ip.octets()[0] == 0 // "This network" + || ip.is_private() + || (ip.octets()[0] == 100 && (ip.octets()[1] & 0b1100_0000 == 0b0100_0000)) //ip.is_shared() + || ip.is_loopback() + || ip.is_link_local() + // addresses reserved for future protocols (`192.0.0.0/24`) + ||(ip.octets()[0] == 192 && ip.octets()[1] == 0 && ip.octets()[2] == 0) + || ip.is_documentation() + || (ip.octets()[0] == 198 && (ip.octets()[1] & 0xfe) == 18) // ip.is_benchmarking() + || (ip.octets()[0] & 240 == 240 && !ip.is_broadcast()) //ip.is_reserved() + || ip.is_broadcast()) + } + std::net::IpAddr::V6(ip) => { + !(ip.is_unspecified() + || ip.is_loopback() + // IPv4-mapped Address (`::ffff:0:0/96`) + || matches!(ip.segments(), [0, 0, 0, 0, 0, 0xffff, _, _]) + // IPv4-IPv6 Translat. (`64:ff9b:1::/48`) + || matches!(ip.segments(), [0x64, 0xff9b, 1, _, _, _, _, _]) + // Discard-Only Address Block (`100::/64`) + || matches!(ip.segments(), [0x100, 0, 0, 0, _, _, _, _]) + // IETF Protocol Assignments (`2001::/23`) + || (matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if b < 0x200) + && !( + // Port Control Protocol Anycast (`2001:1::1`) + u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0001 + // Traversal Using Relays around NAT Anycast (`2001:1::2`) + || u128::from_be_bytes(ip.octets()) == 0x2001_0001_0000_0000_0000_0000_0000_0002 + // AMT (`2001:3::/32`) + || matches!(ip.segments(), [0x2001, 3, _, _, _, _, _, _]) + // AS112-v6 (`2001:4:112::/48`) + || matches!(ip.segments(), [0x2001, 4, 0x112, _, _, _, _, _]) + // ORCHIDv2 (`2001:20::/28`) + || matches!(ip.segments(), [0x2001, b, _, _, _, _, _, _] if (0x20..=0x2F).contains(&b)) + )) + || ((ip.segments()[0] == 0x2001) && (ip.segments()[1] == 0xdb8)) // ip.is_documentation() + || ((ip.segments()[0] & 0xfe00) == 0xfc00) //ip.is_unique_local() + || ((ip.segments()[0] & 0xffc0) == 0xfe80)) //ip.is_unicast_link_local() + } + } +} + +#[cfg(not(feature = "unstable"))] +pub use is_global_hardcoded as is_global; + +#[cfg(feature = "unstable")] +#[inline(always)] +pub fn is_global(ip: std::net::IpAddr) -> bool { + ip.is_global() +} + +/// These are some tests to check that the implementations match +/// The IPv4 can be all checked in 30 seconds or so and they are correct as of nightly 2023-07-17 +/// The IPV6 can't be checked in a reasonable time, so we check over a hundred billion random ones, so far correct +/// Note that the is_global implementation is subject to change as new IP RFCs are created +/// +/// To run while showing progress output: +/// cargo +nightly test --release --features sqlite,unstable -- --nocapture --ignored +#[cfg(test)] +#[cfg(feature = "unstable")] +mod tests { + use super::*; + use std::net::IpAddr; + + #[test] + #[ignore] + fn test_ipv4_global() { + for a in 0..u8::MAX { + println!("Iter: {}/255", a); + for b in 0..u8::MAX { + for c in 0..u8::MAX { + for d in 0..u8::MAX { + let ip = IpAddr::V4(std::net::Ipv4Addr::new(a, b, c, d)); + assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {}", ip) + } + } + } + } + } + + #[test] + #[ignore] + fn test_ipv6_global() { + use rand::Rng; + + std::thread::scope(|s| { + for t in 0..16 { + let handle = s.spawn(move || { + let mut v = [0u8; 16]; + let mut rng = rand::thread_rng(); + + for i in 0..20 { + println!("Thread {t} Iter: {i}/50"); + for _ in 0..500_000_000 { + rng.fill(&mut v); + let ip = IpAddr::V6(std::net::Ipv6Addr::from(v)); + assert_eq!(ip.is_global(), is_global_hardcoded(ip), "IP mismatch: {ip}"); + } + } + }); + } + }); + } +}