diff --git a/Cargo.lock b/Cargo.lock index 1d68ef20b..4e4b04905 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,37 +99,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" -[[package]] -name = "argh" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219" -dependencies = [ - "argh_derive", - "argh_shared", -] - -[[package]] -name = "argh_derive" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a" -dependencies = [ - "argh_shared", - "proc-macro2", - "quote", - "syn 2.0.57", -] - -[[package]] -name = "argh_shared" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531" -dependencies = [ - "serde", -] - [[package]] name = "asn1-rs" version = "0.5.2" @@ -177,7 +146,7 @@ checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -246,7 +215,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -272,9 +241,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytecount" @@ -345,16 +314,16 @@ dependencies = [ [[package]] name = "castaway" version = "0.2.3" -source = "git+https://github.com/sagebind/castaway.git#564b11fb3394802b895f44fe42a7bba7b17df69b" +source = "git+https://github.com/sagebind/castaway.git#7e15c4627055c582d45c30a75ac275e20fa69a69" dependencies = [ "rustversion", ] [[package]] name = "cc" -version = "1.0.90" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" +checksum = "2678b2e3449475e95b0aa6f9b506a28e61b3dc8996592b983695e8ebb58a8b41" dependencies = [ "jobserver", "libc", @@ -417,7 +386,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -455,9 +424,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.0.1" +version = "3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86ec7a15cbe22e59248fc7eadb1907dab5ba09372595da4d73dd805ed4417dfe" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" dependencies = [ "crc-catalog", ] @@ -545,7 +514,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -618,7 +587,7 @@ checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -705,7 +674,7 @@ checksum = "323d8b61c76be2c16eb2d72d007f1542fdeb3760fdf2e2cae219fc0da3db0c09" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -753,7 +722,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -795,9 +764,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", "libc", @@ -1155,11 +1124,11 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "pem" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" dependencies = [ - "base64 0.21.7", + "base64 0.22.0", "serde", ] @@ -1194,7 +1163,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -1261,7 +1230,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" dependencies = [ "proc-macro2", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -1607,9 +1576,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" [[package]] name = "ryu" @@ -1668,7 +1637,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -1796,9 +1765,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "strsim" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "subtle" @@ -1819,9 +1788,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.57" +version = "2.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11a6ae1e52eb25aab8f3fb9fca13be982a373b8f1157ca14b897a825ba4a2d35" +checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" dependencies = [ "proc-macro2", "quote", @@ -1842,9 +1811,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.30.7" +version = "0.30.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c385888ef380a852a16209afc8cfad22795dd8873d69c9a14d2e2088f118d18" +checksum = "e9a84fe4cfc513b41cb2596b624e561ec9e7e1c4b46328e496ed56a53514ef2a" dependencies = [ "cfg-if", "core-foundation-sys", @@ -1890,7 +1859,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -1982,7 +1951,7 @@ dependencies = [ "proc-macro2", "quote", "rustc-hash", - "syn 2.0.57", + "syn 2.0.58", "tl-scheme", ] @@ -2026,7 +1995,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -2074,7 +2043,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] @@ -2226,10 +2195,10 @@ dependencies = [ "ahash", "anyhow", "arc-swap", - "argh", "base64 0.21.7", "bytes", "castaway", + "clap", "dashmap", "ed25519", "everscale-crypto", @@ -2446,7 +2415,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", "wasm-bindgen-shared", ] @@ -2468,7 +2437,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2726,7 +2695,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.58", ] [[package]] diff --git a/network/Cargo.toml b/network/Cargo.toml index e72ae9221..2aefe041e 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -46,7 +46,7 @@ x509-parser = "0.15" tycho-util = { path = "../util", version = "=0.0.1" } [dev-dependencies] -argh = "0.1" +clap = { version = "4.5.3", features = ["derive"] } serde_json = "1.0" tokio = { version = "1", features = ["rt-multi-thread"] } tracing-appender = "0.2.3" diff --git a/network/examples/network_node.rs b/network/examples/network_node.rs index d668f2fad..4e4eda0af 100644 --- a/network/examples/network_node.rs +++ b/network/examples/network_node.rs @@ -8,7 +8,7 @@ use std::net::SocketAddr; use std::sync::Arc; use anyhow::Result; -use argh::FromArgs; +use clap::{Parser, Subcommand}; use everscale_crypto::ed25519; use serde::{Deserialize, Serialize}; use tracing_subscriber::layer::SubscriberExt; @@ -20,18 +20,17 @@ use tycho_util::time::now_sec; #[tokio::main] async fn main() -> Result<()> { - let app: App = argh::from_env(); - app.run().await + Cli::parse().run().await } /// Tycho network node. -#[derive(FromArgs)] -struct App { - #[argh(subcommand)] +#[derive(Parser)] +struct Cli { + #[clap(subcommand)] cmd: Cmd, } -impl App { +impl Cli { async fn run(self) -> Result<()> { let enable_persistent_logs = std::env::var("TYCHO_PERSISTENT_LOGS").is_ok(); @@ -67,8 +66,7 @@ impl App { } } -#[derive(FromArgs)] -#[argh(subcommand)] +#[derive(Subcommand)] enum Cmd { Run(CmdRun), GenKey(CmdGenKey), @@ -76,23 +74,21 @@ enum Cmd { } /// run a node -#[derive(FromArgs)] -#[argh(subcommand, name = "run")] +#[derive(Parser)] struct CmdRun { /// local node address - #[argh(positional)] addr: SocketAddr, /// node secret key - #[argh(option)] + #[clap(long)] key: String, /// path to the node config - #[argh(option)] + #[clap(long)] config: Option, /// path to the global config - #[argh(option)] + #[clap(long)] global_config: String, } @@ -125,8 +121,7 @@ impl CmdRun { } /// generate a key -#[derive(FromArgs)] -#[argh(subcommand, name = "genkey")] +#[derive(Parser)] struct CmdGenKey {} impl CmdGenKey { @@ -150,19 +145,17 @@ impl CmdGenKey { } /// generate a dht node info -#[derive(FromArgs)] -#[argh(subcommand, name = "gendht")] +#[derive(Parser)] struct CmdGenDht { /// local node address - #[argh(positional)] addr: SocketAddr, /// node secret key - #[argh(option)] + #[clap(long)] key: String, /// time to live in seconds (default: unlimited) - #[argh(option)] + #[clap(long)] ttl: Option, } diff --git a/network/src/dht/background_tasks.rs b/network/src/dht/background_tasks.rs new file mode 100644 index 000000000..816079aaf --- /dev/null +++ b/network/src/dht/background_tasks.rs @@ -0,0 +1,216 @@ +use std::collections::hash_map; +use std::sync::Arc; + +use anyhow::Result; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; +use tokio::sync::Semaphore; +use tokio::task::JoinHandle; +use tycho_util::time::{now_sec, shifted_interval}; + +use crate::dht::{random_key_at_distance, DhtInner, DhtQueryMode, Query}; +use crate::network::{Network, WeakNetwork}; +use crate::proto::dht::{PeerValueKeyName, ValueRef}; +use crate::types::PeerInfo; + +impl DhtInner { + pub(crate) fn start_background_tasks(self: &Arc, network: WeakNetwork) { + enum Action { + RefreshLocalPeerInfo, + AnnounceLocalPeerInfo, + RefreshRoutingTable, + AddPeer(Arc), + } + + let mut refresh_peer_info_interval = + tokio::time::interval(self.config.local_info_refresh_period); + let mut announce_peer_info_interval = shifted_interval( + self.config.local_info_announce_period, + self.config.local_info_announce_period_max_jitter, + ); + let mut refresh_routing_table_interval = shifted_interval( + self.config.routing_table_refresh_period, + self.config.routing_table_refresh_period_max_jitter, + ); + + let mut announced_peers = self.announced_peers.subscribe(); + + let this = Arc::downgrade(self); + tokio::spawn(async move { + tracing::debug!("background DHT loop started"); + + let mut prev_refresh_routing_table_fut = None::>; + loop { + let action = tokio::select! { + _ = refresh_peer_info_interval.tick() => Action::RefreshLocalPeerInfo, + _ = announce_peer_info_interval.tick() => Action::AnnounceLocalPeerInfo, + _ = refresh_routing_table_interval.tick() => Action::RefreshRoutingTable, + peer = announced_peers.recv() => match peer { + Ok(peer) => Action::AddPeer(peer), + Err(_) => continue, + } + }; + + let (Some(this), Some(network)) = (this.upgrade(), network.upgrade()) else { + break; + }; + + match action { + Action::RefreshLocalPeerInfo => { + this.refresh_local_peer_info(&network); + } + Action::AnnounceLocalPeerInfo => { + // Peer info is always refreshed before announcing + refresh_peer_info_interval.reset(); + + if let Err(e) = this.announce_local_peer_info(&network).await { + tracing::error!("failed to announce local DHT node info: {e}"); + } + } + Action::RefreshRoutingTable => { + if let Some(fut) = prev_refresh_routing_table_fut.take() { + if let Err(e) = fut.await { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + } + } + + prev_refresh_routing_table_fut = Some(tokio::spawn(async move { + this.refresh_routing_table(&network).await; + })); + } + Action::AddPeer(peer_info) => { + let peer_id = peer_info.id; + let added = this.add_peer_info(&network, peer_info); + tracing::debug!( + local_id = %this.local_id, + %peer_id, + ?added, + "received peer info", + ); + + if let Err(e) = added { + tracing::error!("failed to add peer to the routing table: {e}"); + } + } + } + } + tracing::debug!("background DHT loop finished"); + }); + } + + fn refresh_local_peer_info(&self, network: &Network) { + let peer_info = self.make_local_peer_info(network, now_sec()); + *self.local_peer_info.lock().unwrap() = Some(peer_info); + } + + #[tracing::instrument(level = "debug", skip_all, fields(local_id = %self.local_id))] + async fn announce_local_peer_info(&self, network: &Network) -> Result<()> { + let now = now_sec(); + let data = { + let peer_info = self.make_local_peer_info(network, now); + let data = tl_proto::serialize(&peer_info); + *self.local_peer_info.lock().unwrap() = Some(peer_info); + data + }; + + let mut value = self.make_unsigned_peer_value( + PeerValueKeyName::NodeInfo, + &data, + now + self.config.max_peer_info_ttl.as_secs() as u32, + ); + let signature = network.sign_tl(&value); + value.signature = &signature; + + self.store_value(network, &ValueRef::Peer(value), true) + .await + } + + #[tracing::instrument(level = "debug", skip_all, fields(local_id = %self.local_id))] + async fn refresh_routing_table(&self, network: &Network) { + const PARALLEL_QUERIES: usize = 3; + const MAX_BUCKETS: usize = 15; + const QUERY_DEPTH: usize = 3; + + // Prepare futures for each bucket + let semaphore = Semaphore::new(PARALLEL_QUERIES); + let mut futures = FuturesUnordered::new(); + { + let rng = &mut rand::thread_rng(); + + let mut routing_table = self.routing_table.lock().unwrap(); + + // Filter out expired nodes + let now = now_sec(); + for (_, bucket) in routing_table.buckets.iter_mut() { + bucket.retain_nodes(|node| !node.is_expired(now, &self.config.max_peer_info_ttl)); + } + + // Iterate over the first non-empty buckets (at most `MAX_BUCKETS`) + for (&distance, _) in routing_table + .buckets + .iter() + .filter(|(&distance, bucket)| distance > 0 && !bucket.is_empty()) + .take(MAX_BUCKETS) + { + // Query the K closest nodes for a random ID at the specified distance from the local ID. + let random_id = random_key_at_distance(&routing_table.local_id, distance, rng); + let query = Query::new( + network.clone(), + &routing_table, + random_id.as_bytes(), + self.config.max_k, + DhtQueryMode::Closest, + ); + + futures.push(async { + let _permit = semaphore.acquire().await.unwrap(); + query.find_peers(Some(QUERY_DEPTH)).await + }); + } + } + + // Receive initial set of peers + let Some(mut peers) = futures.next().await else { + tracing::debug!("no new peers found"); + return; + }; + + // Merge new peers into the result set + while let Some(new_peers) = futures.next().await { + for (peer_id, peer) in new_peers { + match peers.entry(peer_id) { + // Just insert the peer if it's new + hash_map::Entry::Vacant(entry) => { + entry.insert(peer); + } + // Replace the peer if it's newer (by creation time) + hash_map::Entry::Occupied(mut entry) => { + if entry.get().created_at < peer.created_at { + entry.insert(peer); + } + } + } + } + } + + let mut routing_table = self.routing_table.lock().unwrap(); + let mut count = 0usize; + for peer in peers.into_values() { + if peer.id == self.local_id { + continue; + } + + let is_new = routing_table.add( + peer.clone(), + self.config.max_k, + &self.config.max_peer_info_ttl, + |peer_info| network.known_peers().insert(peer_info, false).ok(), + ); + count += is_new as usize; + } + + tracing::debug!(count, "found new peers"); + } +} diff --git a/network/src/dht/config.rs b/network/src/dht/config.rs index 0552bb5ee..cb5bf6cd6 100644 --- a/network/src/dht/config.rs +++ b/network/src/dht/config.rs @@ -51,7 +51,7 @@ pub struct DhtConfig { /// /// Default: 1 minute. #[serde(with = "serde_helpers::humantime")] - pub max_local_info_announce_period_jitter: Duration, + pub local_info_announce_period_max_jitter: Duration, /// A period of updating and populating the routing table. /// @@ -63,7 +63,7 @@ pub struct DhtConfig { /// /// Default: 1 minutes. #[serde(with = "serde_helpers::humantime")] - pub max_routing_table_refresh_period_jitter: Duration, + pub routing_table_refresh_period_max_jitter: Duration, /// The capacity of the announced peers channel. /// @@ -81,9 +81,9 @@ impl Default for DhtConfig { storage_item_time_to_idle: None, local_info_refresh_period: Duration::from_secs(60), local_info_announce_period: Duration::from_secs(600), - max_local_info_announce_period_jitter: Duration::from_secs(60), + local_info_announce_period_max_jitter: Duration::from_secs(60), routing_table_refresh_period: Duration::from_secs(600), - max_routing_table_refresh_period_jitter: Duration::from_secs(60), + routing_table_refresh_period_max_jitter: Duration::from_secs(60), announced_peers_channel_capacity: 10, } } diff --git a/network/src/dht/mod.rs b/network/src/dht/mod.rs index eccc346e1..23d2163a2 100644 --- a/network/src/dht/mod.rs +++ b/network/src/dht/mod.rs @@ -1,21 +1,17 @@ -use std::collections::hash_map; use std::sync::{Arc, Mutex}; use anyhow::Result; use bytes::{Buf, Bytes}; -use futures_util::stream::FuturesUnordered; -use futures_util::StreamExt; use rand::RngCore; use tl_proto::TlRead; -use tokio::sync::{broadcast, Semaphore}; -use tokio::task::JoinHandle; +use tokio::sync::broadcast; use tycho_util::realloc_box_enum; -use tycho_util::time::{now_sec, shifted_interval}; +use tycho_util::time::now_sec; use self::query::{Query, QueryCache, StoreValue}; use self::routing::HandlesRoutingTable; use self::storage::Storage; -use crate::network::{Network, WeakNetwork}; +use crate::network::Network; use crate::proto::dht::{ rpc, NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef, PeerValueRef, Value, ValueRef, ValueResponseRaw, @@ -25,8 +21,10 @@ use crate::util::{NetworkExt, Routable}; pub use self::config::DhtConfig; pub use self::peer_resolver::{PeerResolver, PeerResolverBuilder, PeerResolverHandle}; -pub use self::storage::{OverlayValueMerger, StorageError}; +pub use self::query::DhtQueryMode; +pub use self::storage::{DhtValueMerger, DhtValueSource, StorageError}; +mod background_tasks; mod config; mod peer_resolver; mod query; @@ -71,6 +69,13 @@ impl DhtClient { idx: 0, } } + + /// Find a value by its key hash. + /// + /// This is quite a low-level method, so it is recommended to use [`DhtClient::entry`]. + pub async fn find_value(&self, key_hash: &[u8; 32], mode: DhtQueryMode) -> Option> { + self.inner.find_value(&self.network, key_hash, mode).await + } } #[derive(Clone, Copy)] @@ -97,12 +102,16 @@ impl<'a> DhtQueryBuilder<'a> { peer_id, }); - match self.inner.find_value(self.network, &key_hash).await { + match self + .inner + .find_value(self.network, &key_hash, DhtQueryMode::Closest) + .await + { Some(value) => match value.as_ref() { Value::Peer(value) => { tl_proto::deserialize(&value.data).map_err(FindValueError::InvalidData) } - Value::Overlay(_) => Err(FindValueError::InvalidData( + Value::Merged(_) => Err(FindValueError::InvalidData( tl_proto::TlError::UnknownConstructor, )), }, @@ -119,11 +128,15 @@ impl<'a> DhtQueryBuilder<'a> { peer_id, }); - match self.inner.find_value(self.network, &key_hash).await { + match self + .inner + .find_value(self.network, &key_hash, DhtQueryMode::Closest) + .await + { Some(value) => { realloc_box_enum!(value, { Value::Peer(value) => Box::new(value) => Ok(value), - Value::Overlay(_) => Err(FindValueError::InvalidData( + Value::Merged(_) => Err(FindValueError::InvalidData( tl_proto::TlError::UnknownConstructor, )), }) @@ -174,22 +187,25 @@ impl DhtQueryWithDataBuilder<'_> { let dht = self.inner.inner; let network = self.inner.network; - let mut value = PeerValueRef { - key: PeerValueKeyRef { - name: self.inner.name, - peer_id: &dht.local_id, - }, - data: &self.data, - expires_at: self.at.unwrap_or_else(now_sec) + self.ttl, - signature: &[0; 64], - }; + let mut value = self.make_unsigned_value_ref(); let signature = network.sign_tl(&value); value.signature = &signature; - dht.store_value(network, ValueRef::Peer(value), self.with_peer_info) + dht.store_value(network, &ValueRef::Peer(value), self.with_peer_info) .await } + pub fn store_locally(&self) -> Result { + let dht = self.inner.inner; + let network = self.inner.network; + + let mut value = self.make_unsigned_value_ref(); + let signature = network.sign_tl(&value); + value.signature = &signature; + + dht.store_value_locally(&ValueRef::Peer(value)) + } + pub fn into_signed_value(self) -> PeerValue { let dht = self.inner.inner; let network = self.inner.network; @@ -206,6 +222,18 @@ impl DhtQueryWithDataBuilder<'_> { *value.signature = network.sign_tl(&value); value } + + fn make_unsigned_value_ref(&self) -> PeerValueRef<'_> { + PeerValueRef { + key: PeerValueKeyRef { + name: self.inner.name, + peer_id: &self.inner.inner.local_id, + }, + data: &self.data, + expires_at: self.at.unwrap_or_else(now_sec) + self.ttl, + signature: &[0; 64], + } + } } impl<'a> std::ops::Deref for DhtQueryWithDataBuilder<'a> { @@ -238,7 +266,6 @@ impl DhtServiceBackgroundTasks { pub struct DhtServiceBuilder { local_id: PeerId, config: Option, - overlay_merger: Option>, } impl DhtServiceBuilder { @@ -247,11 +274,6 @@ impl DhtServiceBuilder { self } - pub fn with_overlay_value_merger(mut self, merger: Arc) -> Self { - self.overlay_merger = Some(merger); - self - } - pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) { let config = self.config.unwrap_or_default(); @@ -264,10 +286,6 @@ impl DhtServiceBuilder { builder = builder.with_max_idle(time_to_idle); } - if let Some(ref merger) = self.overlay_merger { - builder = builder.with_overlay_value_merger(merger); - } - builder.build() }; @@ -306,7 +324,6 @@ impl DhtService { DhtServiceBuilder { local_id, config: None, - overlay_merger: None, } } @@ -324,6 +341,22 @@ impl DhtService { pub fn has_peer(&self, peer_id: &PeerId) -> bool { self.0.routing_table.lock().unwrap().contains(peer_id) } + + pub fn store_value_locally(&self, value: &ValueRef<'_>) -> Result { + self.0.store_value_locally(value) + } + + pub fn insert_merger( + &self, + group_id: &[u8; 32], + merger: Arc, + ) -> Option> { + self.0.storage.insert_merger(group_id, merger) + } + + pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option> { + self.0.storage.remove_merger(group_id) + } } impl Service for DhtService { @@ -438,205 +471,12 @@ struct DhtInner { } impl DhtInner { - fn start_background_tasks(self: &Arc, network: WeakNetwork) { - enum Action { - RefreshLocalPeerInfo, - AnnounceLocalPeerInfo, - RefreshRoutingTable, - AddPeer(Arc), - } - - let mut refresh_peer_info_interval = - tokio::time::interval(self.config.local_info_refresh_period); - let mut announce_peer_info_interval = shifted_interval( - self.config.local_info_announce_period, - self.config.max_local_info_announce_period_jitter, - ); - let mut refresh_routing_table_interval = shifted_interval( - self.config.routing_table_refresh_period, - self.config.max_routing_table_refresh_period_jitter, - ); - - let mut announced_peers = self.announced_peers.subscribe(); - - let this = Arc::downgrade(self); - tokio::spawn(async move { - tracing::debug!("background DHT loop started"); - - let mut prev_refresh_routing_table_fut = None::>; - loop { - let action = tokio::select! { - _ = refresh_peer_info_interval.tick() => Action::RefreshLocalPeerInfo, - _ = announce_peer_info_interval.tick() => Action::AnnounceLocalPeerInfo, - _ = refresh_routing_table_interval.tick() => Action::RefreshRoutingTable, - peer = announced_peers.recv() => match peer { - Ok(peer) => Action::AddPeer(peer), - Err(_) => continue, - } - }; - - let (Some(this), Some(network)) = (this.upgrade(), network.upgrade()) else { - break; - }; - - match action { - Action::RefreshLocalPeerInfo => { - this.refresh_local_peer_info(&network); - } - Action::AnnounceLocalPeerInfo => { - // Peer info is always refreshed before announcing - refresh_peer_info_interval.reset(); - - if let Err(e) = this.announce_local_peer_info(&network).await { - tracing::error!("failed to announce local DHT node info: {e}"); - } - } - Action::RefreshRoutingTable => { - if let Some(fut) = prev_refresh_routing_table_fut.take() { - if let Err(e) = fut.await { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } - } - } - - prev_refresh_routing_table_fut = Some(tokio::spawn(async move { - this.refresh_routing_table(&network).await; - })); - } - Action::AddPeer(peer_info) => { - let peer_id = peer_info.id; - let added = this.add_peer_info(&network, peer_info); - tracing::debug!( - local_id = %this.local_id, - %peer_id, - ?added, - "received peer info", - ); - - if let Err(e) = added { - tracing::error!("failed to add peer to the routing table: {e}"); - } - } - } - } - tracing::debug!("background DHT loop finished"); - }); - } - - fn refresh_local_peer_info(&self, network: &Network) { - let peer_info = self.make_local_peer_info(network, now_sec()); - *self.local_peer_info.lock().unwrap() = Some(peer_info); - } - - #[tracing::instrument(level = "debug", skip_all, fields(local_id = %self.local_id))] - async fn announce_local_peer_info(&self, network: &Network) -> Result<()> { - let now = now_sec(); - let data = { - let peer_info = self.make_local_peer_info(network, now); - let data = tl_proto::serialize(&peer_info); - *self.local_peer_info.lock().unwrap() = Some(peer_info); - data - }; - - let mut value = self.make_unsigned_peer_value( - PeerValueKeyName::NodeInfo, - &data, - now + self.config.max_peer_info_ttl.as_secs() as u32, - ); - let signature = network.sign_tl(&value); - value.signature = &signature; - - self.store_value(network, ValueRef::Peer(value), true).await - } - - #[tracing::instrument(level = "debug", skip_all, fields(local_id = %self.local_id))] - async fn refresh_routing_table(&self, network: &Network) { - const PARALLEL_QUERIES: usize = 3; - const MAX_BUCKETS: usize = 15; - const QUERY_DEPTH: usize = 3; - - // Prepare futures for each bucket - let semaphore = Semaphore::new(PARALLEL_QUERIES); - let mut futures = FuturesUnordered::new(); - { - let rng = &mut rand::thread_rng(); - - let mut routing_table = self.routing_table.lock().unwrap(); - - // Filter out expired nodes - let now = now_sec(); - for (_, bucket) in routing_table.buckets.iter_mut() { - bucket.retain_nodes(|node| !node.is_expired(now, &self.config.max_peer_info_ttl)); - } - - // Iterate over the first non-empty buckets (at most `MAX_BUCKETS`) - for (&distance, _) in routing_table - .buckets - .iter() - .filter(|(&distance, bucket)| distance > 0 && !bucket.is_empty()) - .take(MAX_BUCKETS) - { - // Query the K closest nodes for a random ID at the specified distance from the local ID. - let random_id = random_key_at_distance(&routing_table.local_id, distance, rng); - let query = Query::new( - network.clone(), - &routing_table, - random_id.as_bytes(), - self.config.max_k, - ); - - futures.push(async { - let _permit = semaphore.acquire().await.unwrap(); - query.find_peers(Some(QUERY_DEPTH)).await - }); - } - } - - // Receive initial set of peers - let Some(mut peers) = futures.next().await else { - tracing::debug!("no new peers found"); - return; - }; - - // Merge new peers into the result set - while let Some(new_peers) = futures.next().await { - for (peer_id, peer) in new_peers { - match peers.entry(peer_id) { - // Just insert the peer if it's new - hash_map::Entry::Vacant(entry) => { - entry.insert(peer); - } - // Replace the peer if it's newer (by creation time) - hash_map::Entry::Occupied(mut entry) => { - if entry.get().created_at < peer.created_at { - entry.insert(peer); - } - } - } - } - } - - let mut routing_table = self.routing_table.lock().unwrap(); - let mut count = 0usize; - for peer in peers.into_values() { - if peer.id == self.local_id { - continue; - } - - let is_new = routing_table.add( - peer.clone(), - self.config.max_k, - &self.config.max_peer_info_ttl, - |peer_info| network.known_peers().insert(peer_info, false).ok(), - ); - count += is_new as usize; - } - - tracing::debug!(count, "found new peers"); - } - - async fn find_value(&self, network: &Network, key_hash: &[u8; 32]) -> Option> { + async fn find_value( + &self, + network: &Network, + key_hash: &[u8; 32], + mode: DhtQueryMode, + ) -> Option> { self.find_value_queries .run(key_hash, || { let query = Query::new( @@ -644,6 +484,7 @@ impl DhtInner { &self.routing_table.lock().unwrap(), key_hash, self.config.max_k, + mode, ); // NOTE: expression is intentionally split to drop the routing table guard @@ -655,10 +496,10 @@ impl DhtInner { async fn store_value( &self, network: &Network, - value: ValueRef<'_>, + value: &ValueRef<'_>, with_peer_info: bool, ) -> Result<()> { - self.storage.insert(&value)?; + self.storage.insert(DhtValueSource::Local, value)?; let local_peer_info = if with_peer_info { let mut node_info = self.local_peer_info.lock().unwrap(); @@ -684,6 +525,10 @@ impl DhtInner { Ok(()) } + fn store_value_locally(&self, value: &ValueRef<'_>) -> Result { + self.storage.insert(DhtValueSource::Local, value) + } + fn add_peer_info(&self, network: &Network, peer_info: Arc) -> Result { anyhow::ensure!(peer_info.is_valid(now_sec()), "invalid peer info"); @@ -735,7 +580,7 @@ impl DhtInner { peer_info.id == req.metadata.peer_id, "suggested peer ID does not belong to the sender" ); - self.announced_peers.send(peer_info).ok(); + self.announced_peers.send(Arc::new(peer_info)).ok(); body = &body[offset..]; anyhow::ensure!(body.len() >= 4, tl_proto::TlError::UnexpectedEof); @@ -748,7 +593,7 @@ impl DhtInner { } fn handle_store(&self, req: &rpc::StoreRef<'_>) -> Result { - self.storage.insert(&req.value) + self.storage.insert(DhtValueSource::Remote, &req.value) } fn handle_find_node(&self, req: &rpc::FindNode) -> NodeResponse { diff --git a/network/src/dht/query.rs b/network/src/dht/query.rs index dd6ed2cf9..19fe5188e 100644 --- a/network/src/dht/query.rs +++ b/network/src/dht/query.rs @@ -108,6 +108,13 @@ impl Default for QueryCache { type WeakSpawnedFut = WeakShared>; +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum DhtQueryMode { + #[default] + Closest, + Random, +} + pub struct Query { network: Network, candidates: SimpleRoutingTable, @@ -120,9 +127,20 @@ impl Query { routing_table: &HandlesRoutingTable, target_id: &[u8; 32], max_k: usize, + mode: DhtQueryMode, ) -> Self { let mut candidates = SimpleRoutingTable::new(PeerId(*target_id)); - routing_table.visit_closest(target_id, max_k, |node| { + + let random_id; + let target_id_for_full = match mode { + DhtQueryMode::Closest => target_id, + DhtQueryMode::Random => { + random_id = rand::random(); + &random_id + } + }; + + routing_table.visit_closest(target_id_for_full, max_k, |node| { candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some); }); @@ -384,20 +402,21 @@ impl StoreValue<()> { pub fn new( network: Network, routing_table: &HandlesRoutingTable, - value: ValueRef<'_>, + value: &ValueRef<'_>, max_k: usize, local_peer_info: Option<&PeerInfo>, ) -> StoreValue, Option>)> + Send> { - let key_hash = match &value { + let key_hash = match value { ValueRef::Peer(value) => tl_proto::hash(&value.key), - ValueRef::Overlay(value) => tl_proto::hash(&value.key), + ValueRef::Merged(value) => tl_proto::hash(&value.key), }; let request_body = Bytes::from(match local_peer_info { - Some(peer_info) => { - tl_proto::serialize((rpc::WithPeerInfoRef { peer_info }, rpc::StoreRef { value })) - } - None => tl_proto::serialize(rpc::StoreRef { value }), + Some(peer_info) => tl_proto::serialize(( + rpc::WithPeerInfo::wrap(peer_info), + rpc::StoreRef::wrap(value), + )), + None => tl_proto::serialize(rpc::StoreRef::wrap(value)), }); let semaphore = Arc::new(Semaphore::new(10)); diff --git a/network/src/dht/storage.rs b/network/src/dht/storage.rs index 76041b389..28e76e211 100644 --- a/network/src/dht/storage.rs +++ b/network/src/dht/storage.rs @@ -1,5 +1,5 @@ use std::cell::RefCell; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use std::time::Duration; use anyhow::Result; @@ -8,29 +8,37 @@ use moka::sync::{Cache, CacheBuilder}; use moka::Expiry; use tl_proto::TlWrite; use tycho_util::time::now_sec; +use tycho_util::FastDashMap; -use crate::proto::dht::{OverlayValue, OverlayValueRef, PeerValueRef, ValueRef}; +use crate::proto::dht::{MergedValue, MergedValueRef, PeerValueRef, ValueRef}; type DhtCache = Cache; type DhtCacheBuilder = CacheBuilder>; -pub trait OverlayValueMerger: Send + Sync + 'static { - fn check_value(&self, new: &OverlayValueRef<'_>) -> Result<(), StorageError>; - fn merge_value(&self, new: &OverlayValueRef<'_>, stored: &mut OverlayValue) -> bool; +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +pub enum DhtValueSource { + Local, + Remote, } -impl OverlayValueMerger for () { - fn check_value(&self, _new: &OverlayValueRef<'_>) -> Result<(), StorageError> { - Err(StorageError::InvalidKey) - } - fn merge_value(&self, _new: &OverlayValueRef<'_>, _stored: &mut OverlayValue) -> bool { - false - } +pub trait DhtValueMerger: Send + Sync + 'static { + fn check_value( + &self, + source: DhtValueSource, + new: &MergedValueRef<'_>, + ) -> Result<(), StorageError>; + + fn merge_value( + &self, + source: DhtValueSource, + new: &MergedValueRef<'_>, + stored: &mut MergedValue, + ) -> bool; } pub(crate) struct StorageBuilder { cache_builder: DhtCacheBuilder, - overlay_value_merger: Weak, + value_mergers: FastDashMap<[u8; 32], Arc>, max_ttl: Duration, } @@ -38,7 +46,7 @@ impl Default for StorageBuilder { fn default() -> Self { Self { cache_builder: Default::default(), - overlay_value_merger: Weak::<()>::new(), + value_mergers: Default::default(), max_ttl: Duration::from_secs(3600), } } @@ -59,13 +67,18 @@ impl StorageBuilder { .weigher(weigher) .expire_after(ValueExpiry) .build_with_hasher(ahash::RandomState::default()), - overlay_value_merger: self.overlay_value_merger, + value_mergers: self.value_mergers, max_ttl_sec: self.max_ttl.as_secs().try_into().unwrap_or(u32::MAX), } } - pub fn with_overlay_value_merger(mut self, merger: &Arc) -> Self { - self.overlay_value_merger = Arc::downgrade(merger); + #[allow(unused)] + pub fn with_value_merger( + self, + group_id: &[u8; 32], + value_merger: Arc, + ) -> Self { + self.value_mergers.insert(*group_id, value_merger); self } @@ -87,7 +100,7 @@ impl StorageBuilder { pub(crate) struct Storage { cache: DhtCache, - overlay_value_merger: Weak, + value_mergers: FastDashMap<[u8; 32], Arc>, max_ttl_sec: u32, } @@ -96,12 +109,30 @@ impl Storage { StorageBuilder::default() } + pub fn insert_merger( + &self, + group_id: &[u8; 32], + merger: Arc, + ) -> Option> { + self.value_mergers.insert(*group_id, merger) + } + + pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option> { + self.value_mergers + .remove(group_id) + .map(|(_, merger)| merger) + } + pub fn get(&self, key: &[u8; 32]) -> Option { let stored_value = self.cache.get(key)?; (stored_value.expires_at > now_sec()).then_some(stored_value.data) } - pub fn insert(&self, value: &ValueRef<'_>) -> Result { + pub fn insert( + &self, + source: DhtValueSource, + value: &ValueRef<'_>, + ) -> Result { match value.expires_at().checked_sub(now_sec()) { Some(0) | None => return Err(StorageError::ValueExpired), Some(remaining_ttl) if remaining_ttl > self.max_ttl_sec => { @@ -112,7 +143,7 @@ impl Storage { match value { ValueRef::Peer(value) => self.insert_signed_value(value), - ValueRef::Overlay(value) => self.insert_overlay_value(value), + ValueRef::Merged(value) => self.insert_merged_value(source, value), } } @@ -138,19 +169,24 @@ impl Storage { .is_fresh()) } - fn insert_overlay_value(&self, value: &OverlayValueRef<'_>) -> Result { - let Some(merger) = self.overlay_value_merger.upgrade() else { - return Ok(false); + fn insert_merged_value( + &self, + source: DhtValueSource, + value: &MergedValueRef<'_>, + ) -> Result { + let merger = match self.value_mergers.get(value.key.group_id) { + Some(merger) => merger.clone(), + None => return Ok(false), }; - merger.check_value(value)?; + merger.check_value(source, value)?; - enum OverlayValueCow<'a, 'b> { - Borrowed(&'a OverlayValueRef<'b>), - Owned(OverlayValue), + enum MergedValueCow<'a, 'b> { + Borrowed(&'a MergedValueRef<'b>), + Owned(MergedValue), } - impl OverlayValueCow<'_, '_> { + impl MergedValueCow<'_, '_> { fn make_stored_value(&self) -> StoredValue { match self { Self::Borrowed(value) => StoredValue::new(*value, value.expires_at), @@ -159,7 +195,7 @@ impl Storage { } } - let new_value = RefCell::new(OverlayValueCow::Borrowed(value)); + let new_value = RefCell::new(MergedValueCow::Borrowed(value)); Ok(self .cache @@ -170,13 +206,13 @@ impl Storage { value.make_stored_value() }, |prev| { - let Ok(mut prev) = tl_proto::deserialize::(&prev.data) else { + let Ok(mut prev) = tl_proto::deserialize::(&prev.data) else { // Invalid values are always replaced with new values return true; }; - if merger.merge_value(value, &mut prev) { - *new_value.borrow_mut() = OverlayValueCow::Owned(prev); + if merger.merge_value(source, value, &mut prev) { + *new_value.borrow_mut() = MergedValueCow::Owned(prev); true } else { false @@ -250,4 +286,6 @@ pub enum StorageError { InvalidSignature, #[error("value too big")] ValueTooBig, + #[error("invalid source")] + InvalidSource, } diff --git a/network/src/lib.rs b/network/src/lib.rs index f701fc434..c276c25cf 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -6,9 +6,9 @@ pub use self::overlay::{ }; pub use self::util::{check_peer_signature, NetworkExt, Routable, Router, RouterBuilder}; pub use dht::{ - xor_distance, DhtClient, DhtConfig, DhtQueryBuilder, DhtQueryWithDataBuilder, DhtService, - DhtServiceBackgroundTasks, DhtServiceBuilder, FindValueError, OverlayValueMerger, PeerResolver, - PeerResolverBuilder, PeerResolverHandle, StorageError, + xor_distance, DhtClient, DhtConfig, DhtQueryBuilder, DhtQueryMode, DhtQueryWithDataBuilder, + DhtService, DhtServiceBackgroundTasks, DhtServiceBuilder, DhtValueMerger, DhtValueSource, + FindValueError, PeerResolver, PeerResolverBuilder, PeerResolverHandle, StorageError, }; pub use network::{ ActivePeers, Connection, KnownPeerHandle, KnownPeers, KnownPeersError, Network, NetworkBuilder, diff --git a/network/src/overlay/background_tasks.rs b/network/src/overlay/background_tasks.rs new file mode 100644 index 000000000..58bdae552 --- /dev/null +++ b/network/src/overlay/background_tasks.rs @@ -0,0 +1,421 @@ +use std::sync::Arc; + +use anyhow::Result; +use rand::Rng; +use tycho_util::time::{now_sec, shifted_interval}; + +use crate::dht::{DhtClient, DhtQueryMode, DhtService}; +use crate::network::{KnownPeerHandle, Network, WeakNetwork}; +use crate::overlay::tasks_stream::TasksStream; +use crate::overlay::{OverlayId, OverlayServiceInner, PublicEntry, PublicOverlayEntries}; +use crate::proto::dht::{MergedValueKeyName, MergedValueKeyRef, Value}; +use crate::proto::overlay::{rpc, PublicEntriesResponse, PublicEntryToSign}; +use crate::types::Request; +use crate::util::NetworkExt; + +impl OverlayServiceInner { + pub(crate) fn start_background_tasks( + self: &Arc, + network: WeakNetwork, + dht_service: Option, + ) { + enum Action<'a> { + UpdatePublicOverlaysList(&'a mut PublicOverlaysState), + ExchangePublicOverlayEntries { + overlay_id: OverlayId, + tasks: &'a mut TasksStream, + }, + DiscoverPublicOverlayEntries { + overlay_id: OverlayId, + tasks: &'a mut TasksStream, + }, + StorePublicEntries { + overlay_id: OverlayId, + tasks: &'a mut TasksStream, + }, + } + + struct PublicOverlaysState { + exchange: TasksStream, + discover: TasksStream, + store: TasksStream, + } + + let public_overlays_notify = self.public_overlays_changed.clone(); + + let this = Arc::downgrade(self); + tokio::spawn(async move { + tracing::debug!("background overlay loop started"); + + let mut public_overlays_changed = Box::pin(public_overlays_notify.notified()); + let mut public_overlays_state = None::; + + loop { + let action = match &mut public_overlays_state { + // Initial update for public overlays list + None => Action::UpdatePublicOverlaysList(public_overlays_state.insert( + PublicOverlaysState { + exchange: TasksStream::new("exchange public overlay peers"), + discover: TasksStream::new("discover public overlay entries in DHT"), + store: TasksStream::new("store public overlay entries in DHT"), + }, + )), + // Default actions + Some(public_overlays_state) => { + tokio::select! { + _ = &mut public_overlays_changed => { + public_overlays_changed = Box::pin(public_overlays_notify.notified()); + Action::UpdatePublicOverlaysList(public_overlays_state) + }, + overlay_id = public_overlays_state.exchange.next() => match overlay_id { + Some(id) => Action::ExchangePublicOverlayEntries { + overlay_id: id, + tasks: &mut public_overlays_state.exchange, + }, + None => continue, + }, + overlay_id = public_overlays_state.discover.next() => match overlay_id { + Some(id) => Action::DiscoverPublicOverlayEntries { + overlay_id: id, + tasks: &mut public_overlays_state.discover, + }, + None => continue, + }, + overlay_id = public_overlays_state.store.next() => match overlay_id { + Some(id) => Action::StorePublicEntries { + overlay_id: id, + tasks: &mut public_overlays_state.store, + }, + None => continue, + }, + } + } + }; + + let (Some(this), Some(network)) = (this.upgrade(), network.upgrade()) else { + break; + }; + + match action { + Action::UpdatePublicOverlaysList(PublicOverlaysState { + exchange, + discover, + store, + }) => { + let iter = this.public_overlays.iter().map(|item| *item.key()); + exchange.rebuild(iter.clone(), |_| { + shifted_interval( + this.config.public_overlay_peer_exchange_period, + this.config.public_overlay_peer_exchange_max_jitter, + ) + }); + discover.rebuild(iter.clone(), |_| { + shifted_interval( + this.config.public_overlay_peer_discovery_period, + this.config.public_overlay_peer_discovery_max_jitter, + ) + }); + store.rebuild_ext( + iter, + |overlay_id| { + // Insert merger for new overlays + if let Some(dht) = &dht_service { + dht.insert_merger( + overlay_id.as_bytes(), + this.public_entries_merger.clone(), + ); + } + + shifted_interval( + this.config.public_overlay_peer_store_period, + this.config.public_overlay_peer_store_max_jitter, + ) + }, + |overlay_id| { + // Remove merger for removed overlays + if let Some(dht) = &dht_service { + dht.remove_merger(overlay_id.as_bytes()); + } + }, + ); + } + Action::ExchangePublicOverlayEntries { overlay_id, tasks } => { + tasks.spawn(&overlay_id, move || async move { + this.exchange_public_entries(&network, &overlay_id).await + }); + } + Action::DiscoverPublicOverlayEntries { overlay_id, tasks } => { + let Some(dht_service) = dht_service.clone() else { + continue; + }; + + tasks.spawn(&overlay_id, move || async move { + this.discover_public_entries( + &dht_service.make_client(&network), + &overlay_id, + ) + .await + }); + } + Action::StorePublicEntries { overlay_id, tasks } => { + let Some(dht_service) = dht_service.clone() else { + continue; + }; + + tasks.spawn(&overlay_id, move || async move { + this.store_public_entries( + &dht_service.make_client(&network), + &overlay_id, + ) + .await + }); + } + } + } + + tracing::debug!("background overlay loop stopped"); + }); + } + + #[tracing::instrument( + level = "debug", + skip_all, + fields(local_id = %self.local_id, overlay_id = %overlay_id), + )] + async fn exchange_public_entries( + &self, + network: &Network, + overlay_id: &OverlayId, + ) -> Result<()> { + let overlay = if let Some(overlay) = self.public_overlays.get(overlay_id) { + overlay.value().clone() + } else { + tracing::warn!("overlay not found"); + return Ok(()); + }; + + overlay.remove_invalid_entries(now_sec()); + + let n = std::cmp::max(self.config.exchange_public_entries_batch, 1); + let mut entries = Vec::with_capacity(n); + + // Always include us in the response + entries.push(Arc::new(self.make_local_public_overlay_entry( + network, + overlay_id, + now_sec(), + ))); + + // Choose a random target to send the request and additional random entries + let target_peer_handle; + let target_peer_id; + { + let rng = &mut rand::thread_rng(); + + let all_entries = overlay.read_entries(); + + match choose_random_resolved_peer(&all_entries, rng) { + Some(handle) => { + target_peer_handle = handle; + target_peer_id = target_peer_handle.load_peer_info().id; + } + None => { + tracing::warn!("no resolved peers in the overlay to exchange entries with"); + return Ok(()); + } + } + + // Add additional random entries to the response. + // NOTE: `n` instead of `n - 1` because we might ignore the target peer + entries.extend( + all_entries + .choose_multiple(rng, n) + .filter(|&item| (item.entry.peer_id != target_peer_id)) + .map(|item| item.entry.clone()) + .take(n - 1), + ); + }; + + // Send request + let response = network + .query( + &target_peer_id, + Request::from_tl(rpc::ExchangeRandomPublicEntries { + overlay_id: overlay_id.to_bytes(), + entries, + }), + ) + .await? + .parse_tl::()?; + + // NOTE: Ensure that resolved peer handle is alive for enough time + drop(target_peer_handle); + + // Populate the overlay with the response + match response { + PublicEntriesResponse::PublicEntries(entries) => { + tracing::debug!( + peer_id = %target_peer_id, + count = entries.len(), + "received public entries" + ); + overlay.add_untrusted_entries(&entries, now_sec()); + } + PublicEntriesResponse::OverlayNotFound => { + tracing::debug!( + peer_id = %target_peer_id, + "peer does not have the overlay", + ); + } + } + + // Done + Ok(()) + } + + #[tracing::instrument( + level = "debug", + skip_all, + fields(local_id = %self.local_id, overlay_id = %overlay_id), + )] + async fn discover_public_entries( + &self, + dht_client: &DhtClient, + overlay_id: &OverlayId, + ) -> Result<()> { + let overlay = if let Some(overlay) = self.public_overlays.get(overlay_id) { + overlay.value().clone() + } else { + tracing::warn!(%overlay_id, "overlay not found"); + return Ok(()); + }; + + let key_hash = tl_proto::hash(MergedValueKeyRef { + name: MergedValueKeyName::PublicOverlayEntries, + group_id: overlay_id.as_bytes(), + }); + + let entries = match dht_client.find_value(&key_hash, DhtQueryMode::Random).await { + Some(value) => match &*value { + Value::Merged(value) => { + tl_proto::deserialize::>>(&value.data)? + } + Value::Peer(_) => { + tracing::warn!("expected a `Value::Merged`, but got a `Value::Peer`"); + return Ok(()); + } + }, + None => { + tracing::debug!("no public entries found in the DHT"); + return Ok(()); + } + }; + + overlay.add_untrusted_entries(&entries, now_sec()); + + tracing::debug!(count = entries.len(), "discovered public entries"); + Ok(()) + } + + #[tracing::instrument( + level = "debug", + skip_all, + fields(local_id = %self.local_id, overlay_id = %overlay_id), + )] + async fn store_public_entries( + &self, + dht_client: &DhtClient, + overlay_id: &OverlayId, + ) -> Result<()> { + use crate::proto::dht; + + const DEFAULT_TTL: u32 = 3600; // 1 hour + + let overlay = if let Some(overlay) = self.public_overlays.get(overlay_id) { + overlay.value().clone() + } else { + tracing::warn!(%overlay_id, "overlay not found"); + return Ok(()); + }; + + let now = now_sec(); + let mut n = std::cmp::max(self.config.public_overlay_peer_store_max_entries, 1); + + let data = { + let rng = &mut rand::thread_rng(); + + let mut entries = Vec::>::with_capacity(n); + + // Always include us in the list + entries.push(Arc::new(self.make_local_public_overlay_entry( + dht_client.network(), + overlay_id, + now, + ))); + + // Fill with random entries + entries.extend( + overlay + .read_entries() + .choose_multiple(rng, n - 1) + .map(|item| item.entry.clone()), + ); + + n = entries.len(); + + // Serialize entries + tl_proto::serialize(&entries) + }; + + // Store entries in the DHT + let value = dht::ValueRef::Merged(dht::MergedValueRef { + key: dht::MergedValueKeyRef { + name: dht::MergedValueKeyName::PublicOverlayEntries, + group_id: overlay_id.as_bytes(), + }, + data: &data, + expires_at: now + DEFAULT_TTL, + }); + + // TODO: Store the value on other nodes as well? + dht_client.service().store_value_locally(&value)?; + + tracing::debug!(count = n, "stored public entries in the DHT",); + Ok(()) + } + + fn make_local_public_overlay_entry( + &self, + network: &Network, + overlay_id: &OverlayId, + now: u32, + ) -> PublicEntry { + let signature = Box::new(network.sign_tl(PublicEntryToSign { + overlay_id: overlay_id.as_bytes(), + peer_id: &self.local_id, + created_at: now, + })); + PublicEntry { + peer_id: self.local_id, + created_at: now, + signature, + } + } +} + +fn choose_random_resolved_peer( + entries: &PublicOverlayEntries, + rng: &mut R, +) -> Option +where + R: Rng + ?Sized, +{ + entries + .choose_all(rng) + .find(|item| item.resolver_handle.is_resolved()) + .map(|item| { + item.resolver_handle + .load_handle() + .expect("invalid resolved flag state") + }) +} diff --git a/network/src/overlay/config.rs b/network/src/overlay/config.rs index f1094ee32..ee9fc6496 100644 --- a/network/src/overlay/config.rs +++ b/network/src/overlay/config.rs @@ -6,6 +6,23 @@ use tycho_util::serde_helpers; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct OverlayConfig { + /// A period of storing public overlay entries in local DHT. + /// + /// Default: 3 minutes. + #[serde(with = "serde_helpers::humantime")] + pub public_overlay_peer_store_period: Duration, + + /// A maximum value of a random jitter for the entries store period. + /// + /// Default: 30 seconds. + #[serde(with = "serde_helpers::humantime")] + pub public_overlay_peer_store_max_jitter: Duration, + + /// A maximum number of public overlay entries to store. + /// + /// Default: 20. + pub public_overlay_peer_store_max_entries: usize, + /// A period of exchanging public overlay peers. /// /// Default: 3 minutes. @@ -15,8 +32,21 @@ pub struct OverlayConfig { /// A maximum value of a random jitter for the peer exchange period. /// /// Default: 30 seconds. + #[serde(with = "serde_helpers::humantime")] pub public_overlay_peer_exchange_max_jitter: Duration, + /// A period of discovering public overlay peers. + /// + /// Default: 3 minutes. + #[serde(with = "serde_helpers::humantime")] + pub public_overlay_peer_discovery_period: Duration, + + /// A maximum value of a random jitter for the peer discovery period. + /// + /// Default: 30 seconds. + #[serde(with = "serde_helpers::humantime")] + pub public_overlay_peer_discovery_max_jitter: Duration, + /// Number of peers to send during entries exchange request. /// /// Default: 20. @@ -26,8 +56,13 @@ pub struct OverlayConfig { impl Default for OverlayConfig { fn default() -> Self { Self { + public_overlay_peer_store_period: Duration::from_secs(3 * 60), + public_overlay_peer_store_max_jitter: Duration::from_secs(30), + public_overlay_peer_store_max_entries: 20, public_overlay_peer_exchange_period: Duration::from_secs(3 * 60), public_overlay_peer_exchange_max_jitter: Duration::from_secs(30), + public_overlay_peer_discovery_period: Duration::from_secs(3 * 60), + public_overlay_peer_discovery_max_jitter: Duration::from_secs(30), exchange_public_entries_batch: 20, } } diff --git a/network/src/overlay/entries_merger.rs b/network/src/overlay/entries_merger.rs new file mode 100644 index 000000000..ae353732e --- /dev/null +++ b/network/src/overlay/entries_merger.rs @@ -0,0 +1,35 @@ +use crate::dht::{DhtValueMerger, DhtValueSource, StorageError}; +use crate::proto::dht::{MergedValue, MergedValueRef}; + +/// Allows only local values to be stored. +/// Always overwrites the stored value with the new value. +#[derive(Debug, Default, Clone, Copy)] +pub struct PublicOverlayEntriesMerger; + +impl DhtValueMerger for PublicOverlayEntriesMerger { + fn check_value( + &self, + source: DhtValueSource, + _: &MergedValueRef<'_>, + ) -> Result<(), StorageError> { + if source != DhtValueSource::Local { + return Err(StorageError::InvalidSource); + } + + Ok(()) + } + + fn merge_value( + &self, + source: DhtValueSource, + new: &MergedValueRef<'_>, + stored: &mut MergedValue, + ) -> bool { + if source != DhtValueSource::Local { + return false; + } + + *stored = new.as_owned(); + true + } +} diff --git a/network/src/overlay/mod.rs b/network/src/overlay/mod.rs index 78f0ccb5d..c8abd000d 100644 --- a/network/src/overlay/mod.rs +++ b/network/src/overlay/mod.rs @@ -1,24 +1,18 @@ -use std::collections::hash_map; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll, Waker}; -use anyhow::Result; use bytes::Buf; -use futures_util::{Stream, StreamExt}; use tl_proto::{TlError, TlRead}; use tokio::sync::Notify; -use tokio::task::{AbortHandle, JoinSet}; use tycho_util::futures::BoxFutureOrNoop; -use tycho_util::time::{now_sec, shifted_interval}; -use tycho_util::{FastDashMap, FastHashMap, FastHashSet}; +use tycho_util::time::now_sec; +use tycho_util::{FastDashMap, FastHashSet}; +use self::entries_merger::PublicOverlayEntriesMerger; use crate::dht::DhtService; -use crate::network::{Network, WeakNetwork}; -use crate::proto::overlay::{rpc, PublicEntriesResponse, PublicEntry, PublicEntryToSign}; -use crate::types::{PeerId, Request, Response, Service, ServiceRequest}; -use crate::util::{NetworkExt, Routable}; +use crate::network::Network; +use crate::proto::overlay::{rpc, PublicEntriesResponse, PublicEntry}; +use crate::types::{PeerId, Response, Service, ServiceRequest}; +use crate::util::Routable; pub use self::config::OverlayConfig; pub use self::overlay_id::OverlayId; @@ -30,10 +24,13 @@ pub use self::public_overlay::{ PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries, PublicOverlayEntriesReadGuard, }; +mod background_tasks; mod config; +mod entries_merger; mod overlay_id; mod private_overlay; mod public_overlay; +mod tasks_stream; pub struct OverlayServiceBackgroundTasks { inner: Arc, @@ -74,6 +71,7 @@ impl OverlayServiceBuilder { public_overlays: Default::default(), public_overlays_changed: Arc::new(Notify::new()), private_overlays_changed: Arc::new(Notify::new()), + public_entries_merger: Arc::new(PublicOverlayEntriesMerger), }); let background_tasks = OverlayServiceBackgroundTasks { @@ -253,187 +251,10 @@ struct OverlayServiceInner { private_overlays: FastDashMap, public_overlays_changed: Arc, private_overlays_changed: Arc, + public_entries_merger: Arc, } impl OverlayServiceInner { - fn start_background_tasks(self: &Arc, network: WeakNetwork, _dht: Option) { - // TODO: Store public overlay entries in the DHT. - - enum Action<'a> { - UpdatePublicOverlaysList(&'a mut PublicOverlaysState), - ExchangePublicOverlayEntries { - overlay_id: OverlayId, - exchange: &'a mut OverlayTaskSet, - }, - } - - struct PublicOverlaysState { - exchange: OverlayTaskSet, - } - - let public_overlays_notify = self.public_overlays_changed.clone(); - - let this = Arc::downgrade(self); - tokio::spawn(async move { - tracing::debug!("background overlay loop started"); - - let mut public_overlays_changed = Box::pin(public_overlays_notify.notified()); - - let mut public_overlays_state = None::; - - loop { - let action = match &mut public_overlays_state { - // Initial update for public overlays list - None => Action::UpdatePublicOverlaysList(public_overlays_state.insert( - PublicOverlaysState { - exchange: OverlayTaskSet::new("exchange public overlay peers"), - }, - )), - // Default actions - Some(public_overlays_state) => { - tokio::select! { - _ = &mut public_overlays_changed => { - public_overlays_changed = Box::pin(public_overlays_notify.notified()); - Action::UpdatePublicOverlaysList(public_overlays_state) - }, - overlay_id = public_overlays_state.exchange.next() => match overlay_id { - Some(id) => Action::ExchangePublicOverlayEntries { - overlay_id: id, - exchange: &mut public_overlays_state.exchange, - }, - None => continue, - }, - } - } - }; - - let (Some(this), Some(network)) = (this.upgrade(), network.upgrade()) else { - break; - }; - - match action { - Action::UpdatePublicOverlaysList(PublicOverlaysState { exchange }) => { - let iter = this.public_overlays.iter().map(|item| *item.key()); - exchange.rebuild(iter.clone(), |_| { - shifted_interval( - this.config.public_overlay_peer_exchange_period, - this.config.public_overlay_peer_exchange_max_jitter, - ) - }); - } - Action::ExchangePublicOverlayEntries { - exchange: exchange_state, - overlay_id, - } => { - exchange_state.spawn(&overlay_id, move || async move { - this.exchange_public_entries(&network, &overlay_id).await - }); - } - } - } - - tracing::debug!("background overlay loop stopped"); - }); - } - - #[tracing::instrument( - level = "debug", - skip_all, - fields(local_id = %self.local_id, overlay_id = %overlay_id), - )] - async fn exchange_public_entries( - &self, - network: &Network, - overlay_id: &OverlayId, - ) -> Result<()> { - let overlay = if let Some(overlay) = self.public_overlays.get(overlay_id) { - overlay.value().clone() - } else { - tracing::debug!(%overlay_id, "overlay not found"); - return Ok(()); - }; - - overlay.remove_invalid_entries(now_sec()); - - let n = std::cmp::max(self.config.exchange_public_entries_batch, 1); - let mut entries = Vec::with_capacity(n); - - // Always include us in the response - entries.push(Arc::new(self.make_local_public_overlay_entry( - network, - overlay_id, - now_sec(), - ))); - - // Choose a random target to send the request and additional random entries - let peer_id = { - let rng = &mut rand::thread_rng(); - - let all_entries = overlay.read_entries(); - let mut iter = all_entries.choose_multiple(rng, n); - - // TODO: search for target in known peers. This is a stub which will not work. - let peer_id = match iter.next() { - Some(item) => item.entry.peer_id, - None => anyhow::bail!("empty overlay, no peers to exchange entries with"), - }; - - // Add additional random entries to the response - entries.extend(iter.map(|item| item.entry.clone())); - - // Use this peer id for the request - peer_id - }; - - // Send request - let response = network - .query( - &peer_id, - Request::from_tl(rpc::ExchangeRandomPublicEntries { - overlay_id: overlay_id.to_bytes(), - entries, - }), - ) - .await? - .parse_tl::()?; - - // Populate the overlay with the response - match response { - PublicEntriesResponse::PublicEntries(entries) => { - tracing::debug!( - %peer_id, - count = entries.len(), - "received public entries" - ); - overlay.add_untrusted_entries(&entries, now_sec()); - } - PublicEntriesResponse::OverlayNotFound => { - tracing::debug!(%peer_id, "overlay not found"); - } - } - - // Done - Ok(()) - } - - fn make_local_public_overlay_entry( - &self, - network: &Network, - overlay_id: &OverlayId, - now: u32, - ) -> PublicEntry { - let signature = Box::new(network.sign_tl(PublicEntryToSign { - overlay_id: overlay_id.as_bytes(), - peer_id: &self.local_id, - created_at: now, - })); - PublicEntry { - peer_id: self.local_id, - created_at: now, - signature, - } - } - fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool { use dashmap::mapref::entry::Entry; @@ -523,168 +344,3 @@ impl OverlayServiceInner { PublicEntriesResponse::PublicEntries(entries) } } - -struct OverlayTaskSet { - name: &'static str, - stream: OverlayActionsStream, - handles: FastHashMap, - join_set: JoinSet, -} - -impl OverlayTaskSet { - fn new(name: &'static str) -> Self { - Self { - name, - stream: Default::default(), - handles: Default::default(), - join_set: Default::default(), - } - } - - async fn next(&mut self) -> Option { - use futures_util::future::{select, Either}; - - loop { - // Wait until the next interval or completed task - let res = { - let next = std::pin::pin!(self.stream.next()); - let joined = std::pin::pin!(self.join_set.join_next()); - match select(next, joined).await { - // Handle interval events first - Either::Left((id, _)) => return id, - // Handled task completion otherwise - Either::Right((joined, fut)) => match joined { - Some(res) => res, - None => return fut.await, - }, - } - }; - - // If some task was joined - match res { - // Task was completed successfully - Ok(overlay_id) => { - return if matches!(self.handles.remove(&overlay_id), Some((_, true))) { - // Reset interval and execute task immediately - self.stream.reset_interval(&overlay_id); - Some(overlay_id) - } else { - None - }; - } - // Propagate task panic - Err(e) if e.is_panic() => { - tracing::error!(task = self.name, "task panicked"); - std::panic::resume_unwind(e.into_panic()); - } - // Task cancelled, loop once more with the next task - Err(_) => continue, - } - } - } - - fn rebuild(&mut self, iter: I, f: F) - where - I: Iterator, - for<'a> F: FnMut(&'a OverlayId) -> tokio::time::Interval, - { - self.stream.rebuild(iter, f, |overlay_id| { - if let Some((handle, _)) = self.handles.remove(overlay_id) { - tracing::debug!(task = self.name, %overlay_id, "task cancelled"); - handle.abort(); - } - }); - } - - fn spawn(&mut self, overlay_id: &OverlayId, f: F) - where - F: FnOnce() -> Fut, - Fut: Future> + Send + 'static, - { - match self.handles.entry(*overlay_id) { - hash_map::Entry::Vacant(entry) => { - let fut = { - let fut = f(); - let task = self.name; - let overlay_id = *overlay_id; - async move { - if let Err(e) = fut.await { - tracing::error!(task, %overlay_id, "task failed: {e:?}"); - } - overlay_id - } - }; - entry.insert((self.join_set.spawn(fut), false)); - } - hash_map::Entry::Occupied(mut entry) => { - tracing::warn!( - task = self.name, - %overlay_id, - "task is running longer than expected", - ); - entry.get_mut().1 = true; - } - } - } -} - -#[derive(Default)] -struct OverlayActionsStream { - intervals: Vec<(tokio::time::Interval, OverlayId)>, - waker: Option, -} - -impl OverlayActionsStream { - fn reset_interval(&mut self, overlay_id: &OverlayId) { - if let Some((interval, _)) = self.intervals.iter_mut().find(|(_, id)| id == overlay_id) { - interval.reset(); - } - } - - fn rebuild, A, R>( - &mut self, - iter: I, - mut on_add: A, - mut on_remove: R, - ) where - for<'a> A: FnMut(&'a OverlayId) -> tokio::time::Interval, - for<'a> R: FnMut(&'a OverlayId), - { - let mut new_overlays = iter.collect::>(); - self.intervals.retain(|(_, id)| { - let retain = new_overlays.remove(id); - if !retain { - on_remove(id); - } - retain - }); - - for id in new_overlays { - self.intervals.push((on_add(&id), id)); - } - - if let Some(waker) = &self.waker { - waker.wake_by_ref(); - } - } -} - -impl Stream for OverlayActionsStream { - type Item = OverlayId; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Always register the waker to resume the stream even if there were - // changes in the intervals. - if !matches!(&self.waker, Some(waker) if cx.waker().will_wake(waker)) { - self.waker = Some(cx.waker().clone()); - } - - for (interval, data) in self.intervals.iter_mut() { - if interval.poll_tick(cx).is_ready() { - return Poll::Ready(Some(*data)); - } - } - - Poll::Pending - } -} diff --git a/network/src/overlay/public_overlay.rs b/network/src/overlay/public_overlay.rs index 5a7de4221..4c9f85880 100644 --- a/network/src/overlay/public_overlay.rs +++ b/network/src/overlay/public_overlay.rs @@ -74,6 +74,12 @@ impl PublicOverlayBuilder { overlay_id: self.overlay_id.as_bytes(), }); + let entries = PublicOverlayEntries { + peer_id_to_index: Default::default(), + data: Default::default(), + peer_resolver: self.peer_resolver, + }; + let entry_ttl_sec = self.entry_ttl.as_secs().try_into().unwrap_or(u32::MAX); PublicOverlay { @@ -81,7 +87,7 @@ impl PublicOverlayBuilder { overlay_id: self.overlay_id, min_capacity: self.min_capacity, entry_ttl_sec, - entries: RwLock::new(Default::default()), + entries: RwLock::new(entries), entry_count: AtomicUsize::new(0), banned_peer_ids: self.banned_peer_ids, service: service.boxed(), @@ -307,7 +313,6 @@ struct Inner { request_prefix: Box<[u8]>, } -#[derive(Default)] pub struct PublicOverlayEntries { peer_id_to_index: FastHashMap, data: Vec, @@ -315,6 +320,28 @@ pub struct PublicOverlayEntries { } impl PublicOverlayEntries { + /// Returns `true` if the set contains no elements. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Returns the number of elements in the set, also referred to as its 'length'. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns true if the set contains the specified peer id. + pub fn contains(&self, peer_id: &PeerId) -> bool { + self.peer_id_to_index.contains_key(peer_id) + } + + /// Returns an iterator over the entries. + /// + /// The order is not random, but is not defined. + pub fn iter(&self) -> std::slice::Iter<'_, PublicOverlayEntryData> { + self.data.iter() + } + /// Returns a reference to one random element of the slice, /// or `None` if the slice is empty. pub fn choose(&self, rng: &mut R) -> Option<&PublicOverlayEntryData> @@ -337,6 +364,18 @@ impl PublicOverlayEntries { self.data.choose_multiple(rng, n) } + /// Chooses all entries from the set, without repetition, + /// and in random order. + pub fn choose_all( + &self, + rng: &mut R, + ) -> rand::seq::SliceChooseIter<'_, [PublicOverlayEntryData], PublicOverlayEntryData> + where + R: Rng + ?Sized, + { + self.data.choose_multiple(rng, self.data.len()) + } + fn insert(&mut self, item: &PublicEntry) -> bool { match self.peer_id_to_index.entry(item.peer_id) { // No entry for the peer_id, insert a new one diff --git a/network/src/overlay/tasks_stream.rs b/network/src/overlay/tasks_stream.rs new file mode 100644 index 000000000..fdf546304 --- /dev/null +++ b/network/src/overlay/tasks_stream.rs @@ -0,0 +1,186 @@ +use std::collections::hash_map; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; + +use anyhow::Result; +use futures_util::{Future, Stream, StreamExt}; +use tokio::task::{AbortHandle, JoinSet}; +use tycho_util::{FastHashMap, FastHashSet}; + +use crate::overlay::OverlayId; + +pub(crate) struct TasksStream { + name: &'static str, + stream: IdsStream, + handles: FastHashMap, + join_set: JoinSet, +} + +impl TasksStream { + pub fn new(name: &'static str) -> Self { + Self { + name, + stream: Default::default(), + handles: Default::default(), + join_set: Default::default(), + } + } + + pub async fn next(&mut self) -> Option { + use futures_util::future::{select, Either}; + + loop { + // Wait until the next interval or completed task + let res = { + let next = std::pin::pin!(self.stream.next()); + let joined = std::pin::pin!(self.join_set.join_next()); + match select(next, joined).await { + // Handle interval events first + Either::Left((id, _)) => return id, + // Handled task completion otherwise + Either::Right((joined, fut)) => match joined { + Some(res) => res, + None => return fut.await, + }, + } + }; + + // If some task was joined + match res { + // Task was completed successfully + Ok(overlay_id) => { + return if matches!(self.handles.remove(&overlay_id), Some((_, true))) { + // Reset interval and execute task immediately + self.stream.reset_interval(&overlay_id); + Some(overlay_id) + } else { + None + }; + } + // Propagate task panic + Err(e) if e.is_panic() => { + tracing::error!(task = self.name, "task panicked"); + std::panic::resume_unwind(e.into_panic()); + } + // Task cancelled, loop once more with the next task + Err(_) => continue, + } + } + } + + pub fn rebuild(&mut self, iter: I, f: F) + where + I: Iterator, + for<'a> F: FnMut(&'a OverlayId) -> tokio::time::Interval, + { + self.rebuild_ext(iter, f, |_| {}) + } + + pub fn rebuild_ext(&mut self, iter: I, on_add: F, mut on_remove: R) + where + I: Iterator, + for<'a> F: FnMut(&'a OverlayId) -> tokio::time::Interval, + for<'a> R: FnMut(&'a OverlayId), + { + self.stream.rebuild(iter, on_add, |overlay_id| { + on_remove(overlay_id); + + if let Some((handle, _)) = self.handles.remove(overlay_id) { + tracing::debug!(task = self.name, %overlay_id, "task cancelled"); + handle.abort(); + } + }); + } + + pub fn spawn(&mut self, overlay_id: &OverlayId, f: F) + where + F: FnOnce() -> Fut, + Fut: Future> + Send + 'static, + { + match self.handles.entry(*overlay_id) { + hash_map::Entry::Vacant(entry) => { + let fut = { + let fut = f(); + let task = self.name; + let overlay_id = *overlay_id; + async move { + if let Err(e) = fut.await { + tracing::error!(task, %overlay_id, "task failed: {e:?}"); + } + overlay_id + } + }; + entry.insert((self.join_set.spawn(fut), false)); + } + hash_map::Entry::Occupied(mut entry) => { + tracing::warn!( + task = self.name, + %overlay_id, + "task is running longer than expected", + ); + entry.get_mut().1 = true; + } + } + } +} + +#[derive(Default)] +struct IdsStream { + intervals: Vec<(tokio::time::Interval, OverlayId)>, + waker: Option, +} + +impl IdsStream { + fn reset_interval(&mut self, overlay_id: &OverlayId) { + if let Some((interval, _)) = self.intervals.iter_mut().find(|(_, id)| id == overlay_id) { + interval.reset(); + } + } + + fn rebuild, A, R>( + &mut self, + iter: I, + mut on_add: A, + mut on_remove: R, + ) where + for<'a> A: FnMut(&'a OverlayId) -> tokio::time::Interval, + for<'a> R: FnMut(&'a OverlayId), + { + let mut new_overlays = iter.collect::>(); + self.intervals.retain(|(_, id)| { + let retain = new_overlays.remove(id); + if !retain { + on_remove(id); + } + retain + }); + + for id in new_overlays { + self.intervals.push((on_add(&id), id)); + } + + if let Some(waker) = &self.waker { + waker.wake_by_ref(); + } + } +} + +impl Stream for IdsStream { + type Item = OverlayId; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Always register the waker to resume the stream even if there were + // changes in the intervals. + if !matches!(&self.waker, Some(waker) if cx.waker().will_wake(waker)) { + self.waker = Some(cx.waker().clone()); + } + + for (interval, data) in self.intervals.iter_mut() { + if interval.poll_tick(cx).is_ready() { + return Poll::Ready(Some(*data)); + } + } + + Poll::Pending + } +} diff --git a/network/src/proto.tl b/network/src/proto.tl index 60ad4f88e..1866c6f01 100644 --- a/network/src/proto.tl +++ b/network/src/proto.tl @@ -52,22 +52,22 @@ dht.peerValueKey = dht.Key; /** -* Key for the overlay-managed value +* Key for the group-managed value * * @param name key name enum -* @param overlay_id overlay id +* @param group_id group id */ -dht.overlayValueKey - name:dht.OverlayValueKeyName - overlay_id:int256 +dht.mergedValueKey + name:dht.MergedValueKeyName + group_id:int256 = dht.Key; // Peer value key names { dht.peerValueKeyName.nodeInfo = dht.PeerValueKeyName; // } -// Overlay value key names { -dht.overlayValueKeyName.peersList = dht.OverlayValueKeyName; +// Merged value key names { +dht.mergedValueKeyName.publicOverlayEntries = dht.MergedValueKeyName; // } /** @@ -80,13 +80,13 @@ dht.overlayValueKeyName.peersList = dht.OverlayValueKeyName; dht.peerValue key:dht.peerValueKey data:bytes expires_at:int signature:bytes = dht.Value; /** -* An overlay-managed value +* An group-managed value * -* @param key overlay key +* @param key key info * @param value any data * @param expires_at unix timestamp up to which this value is valid */ -dht.overlayValue key:dht.overlayValueKey data:bytes expires_at:int = dht.Value; +dht.mergedValue key:dht.mergedValueKey data:bytes expires_at:int = dht.Value; /** diff --git a/network/src/proto/dht.rs b/network/src/proto/dht.rs index 312379c90..0ef19a3c1 100644 --- a/network/src/proto/dht.rs +++ b/network/src/proto/dht.rs @@ -15,9 +15,9 @@ pub enum PeerValueKeyName { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, TlRead, TlWrite)] #[tl(boxed, scheme = "proto.tl")] -pub enum OverlayValueKeyName { - #[tl(id = "dht.overlayValueKeyName.peersList")] - PeersList, +pub enum MergedValueKeyName { + #[tl(id = "dht.mergedValueKeyName.publicOverlayEntries")] + PublicOverlayEntries, } /// Key for values that can only be updated by the owner. @@ -53,35 +53,35 @@ impl PeerValueKeyRef<'_> { } } -/// Key for overlay-managed values. +/// Key for group-managed values. /// -/// See [`OverlayValueKeyRef`] for the non-owned version of the struct. +/// See [`MergedValueKeyRef`] for the non-owned version of the struct. #[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)] -#[tl(boxed, id = "dht.overlayValueKey", scheme = "proto.tl")] -pub struct OverlayValueKey { +#[tl(boxed, id = "dht.mergedValueKey", scheme = "proto.tl")] +pub struct MergedValueKey { /// Key name. - pub name: OverlayValueKeyName, - /// Overlay id. - pub overlay_id: [u8; 32], + pub name: MergedValueKeyName, + /// Group id. + pub group_id: [u8; 32], } -/// Key for overlay-managed values. +/// Key for group-managed values. /// -/// See [`OverlayValueKey`] for the owned version of the struct. +/// See [`MergedValueKey`] for the owned version of the struct. #[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)] -#[tl(boxed, id = "dht.overlayValueKey", scheme = "proto.tl")] -pub struct OverlayValueKeyRef<'tl> { +#[tl(boxed, id = "dht.mergedValueKey", scheme = "proto.tl")] +pub struct MergedValueKeyRef<'tl> { /// Key name. - pub name: OverlayValueKeyName, - /// Overlay id. - pub overlay_id: &'tl [u8; 32], + pub name: MergedValueKeyName, + /// Group id. + pub group_id: &'tl [u8; 32], } -impl OverlayValueKeyRef<'_> { - pub fn as_owned(&self) -> OverlayValueKey { - OverlayValueKey { +impl MergedValueKeyRef<'_> { + pub fn as_owned(&self) -> MergedValueKey { + MergedValueKey { name: self.name, - overlay_id: *self.overlay_id, + group_id: *self.group_id, } } } @@ -131,37 +131,37 @@ impl PeerValueRef<'_> { } } -/// Overlay-managed value. +/// Group-managed value. /// -/// See [`OverlayValueRef`] for the non-owned version of the struct. +/// See [`MergedValueRef`] for the non-owned version of the struct. #[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)] -#[tl(boxed, id = "dht.overlayValue", scheme = "proto.tl")] -pub struct OverlayValue { - /// Overlay key. - pub key: OverlayValueKey, +#[tl(boxed, id = "dht.mergedValue", scheme = "proto.tl")] +pub struct MergedValue { + /// Key info. + pub key: MergedValueKey, /// Any data. pub data: Box<[u8]>, /// Unix timestamp up to which this value is valid. pub expires_at: u32, } -/// Overlay-managed value. +/// Group-managed value. /// -/// See [`OverlayValue`] for the owned version of the struct. +/// See [`MergedValue`] for the owned version of the struct. #[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)] -#[tl(boxed, id = "dht.overlayValue", scheme = "proto.tl")] -pub struct OverlayValueRef<'tl> { - /// Overlay key. - pub key: OverlayValueKeyRef<'tl>, +#[tl(boxed, id = "dht.mergedValue", scheme = "proto.tl")] +pub struct MergedValueRef<'tl> { + /// Key info. + pub key: MergedValueKeyRef<'tl>, /// Any data. pub data: &'tl [u8], /// Unix timestamp up to which this value is valid. pub expires_at: u32, } -impl OverlayValueRef<'_> { - pub fn as_owned(&self) -> OverlayValue { - OverlayValue { +impl MergedValueRef<'_> { + pub fn as_owned(&self) -> MergedValue { + MergedValue { key: self.key.as_owned(), data: Box::from(self.data), expires_at: self.expires_at, @@ -176,8 +176,8 @@ impl OverlayValueRef<'_> { pub enum Value { /// Value with a known owner. Peer(PeerValue), - /// Overlay-managed value. - Overlay(OverlayValue), + /// Group-managed value. + Merged(MergedValue), } impl Value { @@ -188,7 +188,7 @@ impl Value { && key_hash == &tl_proto::hash(&value.key) && check_peer_signature(&value.key.peer_id, &value.signature, value) } - Self::Overlay(value) => { + Self::Merged(value) => { value.expires_at >= at && key_hash == &tl_proto::hash(&value.key) } } @@ -197,7 +197,7 @@ impl Value { pub const fn expires_at(&self) -> u32 { match self { Self::Peer(value) => value.expires_at, - Self::Overlay(value) => value.expires_at, + Self::Merged(value) => value.expires_at, } } } @@ -208,7 +208,7 @@ impl TlWrite for Value { fn max_size_hint(&self) -> usize { match self { Self::Peer(value) => value.max_size_hint(), - Self::Overlay(value) => value.max_size_hint(), + Self::Merged(value) => value.max_size_hint(), } } @@ -218,7 +218,7 @@ impl TlWrite for Value { { match self { Self::Peer(value) => value.write_to(packet), - Self::Overlay(value) => value.write_to(packet), + Self::Merged(value) => value.write_to(packet), } } } @@ -231,7 +231,7 @@ impl<'a> TlRead<'a> for Value { *offset -= 4; match id { PeerValue::TL_ID => PeerValue::read_from(packet, offset).map(Self::Peer), - OverlayValue::TL_ID => OverlayValue::read_from(packet, offset).map(Self::Overlay), + MergedValue::TL_ID => MergedValue::read_from(packet, offset).map(Self::Merged), _ => Err(tl_proto::TlError::UnknownConstructor), } } @@ -244,8 +244,8 @@ impl<'a> TlRead<'a> for Value { pub enum ValueRef<'tl> { /// Value with a known owner. Peer(PeerValueRef<'tl>), - /// Overlay-managed value. - Overlay(OverlayValueRef<'tl>), + /// Group-managed value. + Merged(MergedValueRef<'tl>), } impl ValueRef<'_> { @@ -256,7 +256,7 @@ impl ValueRef<'_> { && key_hash == &tl_proto::hash(&value.key) && check_peer_signature(value.key.peer_id, value.signature, value) } - Self::Overlay(value) => { + Self::Merged(value) => { value.expires_at >= at && key_hash == &tl_proto::hash(&value.key) } } @@ -265,7 +265,7 @@ impl ValueRef<'_> { pub const fn expires_at(&self) -> u32 { match self { Self::Peer(value) => value.expires_at, - Self::Overlay(value) => value.expires_at, + Self::Merged(value) => value.expires_at, } } } @@ -276,7 +276,7 @@ impl TlWrite for ValueRef<'_> { fn max_size_hint(&self) -> usize { match self { Self::Peer(value) => value.max_size_hint(), - Self::Overlay(value) => value.max_size_hint(), + Self::Merged(value) => value.max_size_hint(), } } @@ -286,7 +286,7 @@ impl TlWrite for ValueRef<'_> { { match self { Self::Peer(value) => value.write_to(packet), - Self::Overlay(value) => value.write_to(packet), + Self::Merged(value) => value.write_to(packet), } } } @@ -299,7 +299,7 @@ impl<'a> TlRead<'a> for ValueRef<'a> { *offset -= 4; match id { PeerValue::TL_ID => PeerValueRef::read_from(packet, offset).map(Self::Peer), - OverlayValue::TL_ID => OverlayValueRef::read_from(packet, offset).map(Self::Overlay), + MergedValue::TL_ID => MergedValueRef::read_from(packet, offset).map(Self::Merged), _ => Err(tl_proto::TlError::UnknownConstructor), } } @@ -378,22 +378,23 @@ pub mod rpc { /// Query wrapper with an announced peer info. #[derive(Debug, Clone, TlRead, TlWrite)] #[tl(boxed, id = "dht.withPeerInfo", scheme = "proto.tl")] + #[repr(transparent)] pub struct WithPeerInfo { /// A signed info of the sender. - pub peer_info: Arc, + pub peer_info: PeerInfo, } - /// Query wrapper with an announced peer info. - #[derive(Debug, Clone, TlWrite)] - #[tl(boxed, id = "dht.withPeerInfo", scheme = "proto.tl")] - pub struct WithPeerInfoRef<'tl> { - /// A signed info of the sender. - pub peer_info: &'tl PeerInfo, + impl WithPeerInfo { + pub fn wrap(value: &'_ PeerInfo) -> &'_ Self { + // SAFETY: `rpc::WithPeerInfo` has the same memory layout as `PeerInfo`. + unsafe { &*(value as *const PeerInfo).cast() } + } } /// Suggest a node to store that value. #[derive(Debug, Clone, TlRead, TlWrite)] #[tl(boxed, id = "dht.store", scheme = "proto.tl")] + #[repr(transparent)] pub struct Store { /// A value to store. pub value: Value, @@ -402,12 +403,20 @@ pub mod rpc { /// Suggest a node to store that value. #[derive(Debug, Clone, TlRead, TlWrite)] #[tl(boxed, id = "dht.store", scheme = "proto.tl")] + #[repr(transparent)] pub struct StoreRef<'tl> { /// A value to store. pub value: ValueRef<'tl>, } - /// Search for `k` closest nodes. + impl<'tl> StoreRef<'tl> { + pub fn wrap<'a>(value: &'a ValueRef<'tl>) -> &'a Self { + // SAFETY: `rpc::StoreRef` has the same memory layout as `ValueRef`. + unsafe { &*(value as *const ValueRef<'tl>).cast() } + } + } + + /// Search for `k` the closest nodes. /// /// See [`NodeResponse`]. #[derive(Debug, Clone, TlRead, TlWrite)] @@ -419,7 +428,7 @@ pub mod rpc { pub k: u32, } - /// Search for a value if stored or `k` closest nodes. + /// Search for a value if stored or `k` the closest nodes. /// /// See [`ValueResponse`]. #[derive(Debug, Clone, TlRead, TlWrite)] diff --git a/network/tests/common/mod.rs b/network/tests/common/mod.rs new file mode 100644 index 000000000..7d4d85d84 --- /dev/null +++ b/network/tests/common/mod.rs @@ -0,0 +1,136 @@ +use std::net::Ipv4Addr; +use std::time::Duration; + +use everscale_crypto::ed25519; +use tl_proto::{TlRead, TlWrite}; +use tycho_network::{ + DhtConfig, DhtService, Network, OverlayConfig, OverlayService, PeerResolver, Response, Router, + Service, ServiceRequest, +}; + +pub fn init_logger() { + tracing_subscriber::fmt::try_init().ok(); + tracing::info!("bootstrap_nodes_accessible"); + + std::panic::set_hook(Box::new(|info| { + use std::io::Write; + + tracing::error!("{}", info); + std::io::stderr().flush().ok(); + std::io::stdout().flush().ok(); + std::process::exit(1); + })); +} + +pub struct NodeBase { + pub network: Network, + pub dht_service: DhtService, + pub overlay_service: OverlayService, + pub peer_resolver: PeerResolver, +} + +impl NodeBase { + pub fn with_random_key() -> Self { + let key = ed25519::SecretKey::generate(&mut rand::thread_rng()); + let local_id = ed25519::PublicKey::from(&key).into(); + + let (dht_tasks, dht_service) = DhtService::builder(local_id) + .with_config(make_fast_dht_config()) + .build(); + + let (overlay_tasks, overlay_service) = OverlayService::builder(local_id) + .with_config(make_fast_overlay_config()) + .with_dht_service(dht_service.clone()) + .build(); + + let router = Router::builder() + .route(dht_service.clone()) + .route(overlay_service.clone()) + .build(); + + let network = Network::builder() + .with_private_key(key.to_bytes()) + .with_service_name("test-service") + .build((Ipv4Addr::LOCALHOST, 0), router) + .unwrap(); + + dht_tasks.spawn(&network); + overlay_tasks.spawn(&network); + + let peer_resolver = dht_service.make_peer_resolver().build(&network); + + Self { + network, + dht_service, + overlay_service, + peer_resolver, + } + } +} + +pub fn make_fast_dht_config() -> DhtConfig { + DhtConfig { + local_info_announce_period: Duration::from_secs(1), + local_info_announce_period_max_jitter: Duration::from_secs(1), + routing_table_refresh_period: Duration::from_secs(1), + routing_table_refresh_period_max_jitter: Duration::from_secs(1), + ..Default::default() + } +} + +pub fn make_fast_overlay_config() -> OverlayConfig { + OverlayConfig { + public_overlay_peer_store_period: Duration::from_secs(1), + public_overlay_peer_store_max_jitter: Duration::from_secs(1), + public_overlay_peer_exchange_period: Duration::from_secs(1), + public_overlay_peer_exchange_max_jitter: Duration::from_secs(1), + public_overlay_peer_discovery_period: Duration::from_secs(1), + public_overlay_peer_discovery_max_jitter: Duration::from_secs(1), + ..Default::default() + } +} + +pub struct PingPongService; + +impl Service for PingPongService { + type QueryResponse = Response; + type OnQueryFuture = futures_util::future::Ready>; + type OnMessageFuture = futures_util::future::Ready<()>; + type OnDatagramFuture = futures_util::future::Ready<()>; + + fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture { + futures_util::future::ready(match req.parse_tl() { + Ok(Ping { value }) => Some(Response::from_tl(Pong { value })), + Err(e) => { + tracing::error!( + peer_id = %req.metadata.peer_id, + addr = %req.metadata.remote_address, + "invalid request: {e:?}", + ); + None + } + }) + } + + #[inline] + fn on_message(&self, _req: ServiceRequest) -> Self::OnMessageFuture { + futures_util::future::ready(()) + } + + #[inline] + fn on_datagram(&self, _req: ServiceRequest) -> Self::OnDatagramFuture { + futures_util::future::ready(()) + } +} + +#[derive(Debug, Copy, Clone, TlRead, TlWrite)] +#[tl(boxed, id = 0x11223344)] +pub struct Ping { + pub value: u64, +} + +#[derive(Debug, Copy, Clone, TlRead, TlWrite)] +#[tl(boxed, id = 0x55667788)] +pub struct Pong { + pub value: u64, +} diff --git a/network/tests/dht.rs b/network/tests/dht.rs index 7b7303ac1..e720db019 100644 --- a/network/tests/dht.rs +++ b/network/tests/dht.rs @@ -30,9 +30,9 @@ impl Node { .with_config(DhtConfig { max_k: 20, routing_table_refresh_period: Duration::from_secs(1), - max_routing_table_refresh_period_jitter: Duration::from_secs(1), + routing_table_refresh_period_max_jitter: Duration::from_secs(1), local_info_announce_period: Duration::from_secs(1), - max_local_info_announce_period_jitter: Duration::from_secs(1), + local_info_announce_period_max_jitter: Duration::from_secs(1), ..Default::default() }) .build(); diff --git a/network/tests/overlay.rs b/network/tests/overlay.rs deleted file mode 100644 index 0c6e79494..000000000 --- a/network/tests/overlay.rs +++ /dev/null @@ -1,208 +0,0 @@ -//! Run tests with this env: -//! ```text -//! RUST_LOG=info,tycho_network=trace -//! ``` - -use std::net::Ipv4Addr; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use everscale_crypto::ed25519; -use futures_util::stream::FuturesUnordered; -use futures_util::StreamExt; -use tl_proto::{TlRead, TlWrite}; -use tycho_network::{ - DhtClient, DhtConfig, DhtService, Network, OverlayId, OverlayService, PeerId, PrivateOverlay, - Request, Response, Router, Service, ServiceRequest, -}; - -struct Node { - network: Network, - private_overlay: PrivateOverlay, - dht_client: DhtClient, -} - -impl Node { - fn with_random_key() -> Self { - let key = ed25519::SecretKey::generate(&mut rand::thread_rng()); - let local_id = ed25519::PublicKey::from(&key).into(); - - let (dht_tasks, dht_service) = DhtService::builder(local_id) - .with_config(DhtConfig { - local_info_announce_period: Duration::from_secs(1), - max_local_info_announce_period_jitter: Duration::from_secs(1), - routing_table_refresh_period: Duration::from_secs(1), - max_routing_table_refresh_period_jitter: Duration::from_secs(1), - ..Default::default() - }) - .build(); - - let (overlay_tasks, overlay_service) = OverlayService::builder(local_id) - .with_dht_service(dht_service.clone()) - .build(); - - let router = Router::builder() - .route(dht_service.clone()) - .route(overlay_service.clone()) - .build(); - - let network = Network::builder() - .with_private_key(key.to_bytes()) - .with_service_name("test-service") - .build((Ipv4Addr::LOCALHOST, 0), router) - .unwrap(); - - dht_tasks.spawn(&network); - overlay_tasks.spawn(&network); - - let dht_client = dht_service.make_client(&network); - let peer_resolver = dht_service.make_peer_resolver().build(&network); - - let private_overlay = PrivateOverlay::builder(PRIVATE_OVERLAY_ID) - .with_peer_resolver(peer_resolver) - .build(PingPongService); - overlay_service.add_private_overlay(&private_overlay); - - Self { - network, - dht_client, - private_overlay, - } - } - - async fn private_overlay_query(&self, peer_id: &PeerId, req: Q) -> Result - where - Q: tl_proto::TlWrite, - for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>, - { - self.private_overlay - .query(&self.network, peer_id, Request::from_tl(req)) - .await? - .parse_tl::() - .map_err(Into::into) - } -} - -fn make_network(node_count: usize) -> Vec { - let nodes = (0..node_count) - .map(|_| Node::with_random_key()) - .collect::>(); - - let common_peer_info = nodes.first().unwrap().network.sign_peer_info(0, u32::MAX); - - for node in &nodes { - node.dht_client - .add_peer(Arc::new(common_peer_info.clone())) - .unwrap(); - - let mut private_overlay_entries = node.private_overlay.write_entries(); - - for peer_id in nodes.iter().map(|node| node.network.peer_id()) { - if peer_id == node.network.peer_id() { - continue; - } - private_overlay_entries.insert(peer_id); - } - } - - nodes -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn private_overlays_accessible() -> Result<()> { - tracing_subscriber::fmt::try_init().ok(); - tracing::info!("bootstrap_nodes_accessible"); - - std::panic::set_hook(Box::new(|info| { - use std::io::Write; - - tracing::error!("{}", info); - std::io::stderr().flush().ok(); - std::io::stdout().flush().ok(); - std::process::exit(1); - })); - - let nodes = make_network(20); - - for node in &nodes { - let resolved = FuturesUnordered::new(); - for entry in node.private_overlay.read_entries().iter() { - let handle = entry.resolver_handle.clone(); - resolved.push(async move { handle.wait_resolved().await }); - } - - // Ensure all entries are resolved. - resolved.collect::>().await; - tracing::info!( - peer_id = %node.network.peer_id(), - "all entries resolved", - ); - } - - for i in 0..nodes.len() { - for j in 0..nodes.len() { - if i == j { - continue; - } - - let left = &nodes[i]; - let right = &nodes[j]; - - let value = (i * 1000 + j) as u64; - let Pong { value: received } = left - .private_overlay_query(right.network.peer_id(), Ping { value }) - .await?; - assert_eq!(received, value); - } - } - - Ok(()) -} - -struct PingPongService; - -impl Service for PingPongService { - type QueryResponse = Response; - type OnQueryFuture = futures_util::future::Ready>; - type OnMessageFuture = futures_util::future::Ready<()>; - type OnDatagramFuture = futures_util::future::Ready<()>; - - fn on_query(&self, req: ServiceRequest) -> Self::OnQueryFuture { - futures_util::future::ready(match req.parse_tl() { - Ok(Ping { value }) => Some(Response::from_tl(Pong { value })), - Err(e) => { - tracing::error!( - peer_id = %req.metadata.peer_id, - addr = %req.metadata.remote_address, - "invalid request: {e:?}", - ); - None - } - }) - } - - #[inline] - fn on_message(&self, _req: ServiceRequest) -> Self::OnMessageFuture { - futures_util::future::ready(()) - } - - #[inline] - fn on_datagram(&self, _req: ServiceRequest) -> Self::OnDatagramFuture { - futures_util::future::ready(()) - } -} - -#[derive(Debug, Copy, Clone, TlRead, TlWrite)] -#[tl(boxed, id = 0x11223344)] -struct Ping { - value: u64, -} - -#[derive(Debug, Copy, Clone, TlRead, TlWrite)] -#[tl(boxed, id = 0x55667788)] -struct Pong { - value: u64, -} - -static PRIVATE_OVERLAY_ID: OverlayId = OverlayId([0; 32]); diff --git a/network/tests/private_overlay.rs b/network/tests/private_overlay.rs new file mode 100644 index 000000000..8433288f0 --- /dev/null +++ b/network/tests/private_overlay.rs @@ -0,0 +1,126 @@ +//! Run tests with this env: +//! ```text +//! RUST_LOG=info,tycho_network=trace +//! ``` + +use std::sync::Arc; + +use anyhow::Result; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; +use tycho_network::{DhtClient, Network, OverlayId, PeerId, PrivateOverlay, Request}; + +use self::common::{init_logger, NodeBase, Ping, PingPongService, Pong}; + +mod common; + +struct Node { + network: Network, + private_overlay: PrivateOverlay, + dht_client: DhtClient, +} + +impl Node { + fn with_random_key() -> Self { + let NodeBase { + network, + dht_service, + overlay_service, + peer_resolver, + } = NodeBase::with_random_key(); + + let private_overlay = PrivateOverlay::builder(PRIVATE_OVERLAY_ID) + .with_peer_resolver(peer_resolver) + .build(PingPongService); + overlay_service.add_private_overlay(&private_overlay); + + let dht_client = dht_service.make_client(&network); + + Self { + network, + dht_client, + private_overlay, + } + } + + async fn private_overlay_query(&self, peer_id: &PeerId, req: Q) -> Result + where + Q: tl_proto::TlWrite, + for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>, + { + self.private_overlay + .query(&self.network, peer_id, Request::from_tl(req)) + .await? + .parse_tl::() + .map_err(Into::into) + } +} + +fn make_network(node_count: usize) -> Vec { + let nodes = (0..node_count) + .map(|_| Node::with_random_key()) + .collect::>(); + + let common_peer_info = nodes.first().unwrap().network.sign_peer_info(0, u32::MAX); + + for node in &nodes { + node.dht_client + .add_peer(Arc::new(common_peer_info.clone())) + .unwrap(); + + let mut private_overlay_entries = node.private_overlay.write_entries(); + + for peer_id in nodes.iter().map(|node| node.network.peer_id()) { + if peer_id == node.network.peer_id() { + continue; + } + private_overlay_entries.insert(peer_id); + } + } + + nodes +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn private_overlays_accessible() -> Result<()> { + init_logger(); + tracing::info!("private_overlays_accessible"); + + let nodes = make_network(20); + + for node in &nodes { + let resolved = FuturesUnordered::new(); + for entry in node.private_overlay.read_entries().iter() { + let handle = entry.resolver_handle.clone(); + resolved.push(async move { handle.wait_resolved().await }); + } + + // Ensure all entries are resolved. + resolved.collect::>().await; + tracing::info!( + peer_id = %node.network.peer_id(), + "all entries resolved", + ); + } + + for i in 0..nodes.len() { + for j in 0..nodes.len() { + if i == j { + continue; + } + + let left = &nodes[i]; + let right = &nodes[j]; + + let value = (i * 1000 + j) as u64; + let Pong { value: received } = left + .private_overlay_query(right.network.peer_id(), Ping { value }) + .await?; + assert_eq!(received, value); + } + } + + Ok(()) +} + +static PRIVATE_OVERLAY_ID: OverlayId = OverlayId([0; 32]); diff --git a/network/tests/public_overlay.rs b/network/tests/public_overlay.rs new file mode 100644 index 000000000..69995dd4c --- /dev/null +++ b/network/tests/public_overlay.rs @@ -0,0 +1,167 @@ +//! Run tests with this env: +//! ```text +//! RUST_LOG=info,tycho_network=trace +//! ``` + +use std::collections::BTreeMap; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; +use tycho_network::{DhtClient, Network, OverlayId, PeerId, PublicOverlay, Request}; + +use self::common::{init_logger, NodeBase, Ping, PingPongService, Pong}; + +mod common; + +struct Node { + network: Network, + public_overlay: PublicOverlay, + dht_client: DhtClient, +} + +impl Node { + fn with_random_key() -> Self { + let NodeBase { + network, + dht_service, + overlay_service, + peer_resolver, + } = NodeBase::with_random_key(); + + let public_overlay = PublicOverlay::builder(PUBLIC_OVERLAY_ID) + .with_peer_resolver(peer_resolver) + .build(PingPongService); + overlay_service.add_public_overlay(&public_overlay); + + let dht_client = dht_service.make_client(&network); + + Self { + network, + public_overlay, + dht_client, + } + } + + async fn public_overlay_query(&self, peer_id: &PeerId, req: Q) -> Result + where + Q: tl_proto::TlWrite, + for<'a> A: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>, + { + self.public_overlay + .query(&self.network, peer_id, Request::from_tl(req)) + .await? + .parse_tl::() + .map_err(Into::into) + } +} + +fn make_network(node_count: usize) -> Vec { + let nodes = (0..node_count) + .map(|_| Node::with_random_key()) + .collect::>(); + + let common_peer_info = nodes.first().unwrap().network.sign_peer_info(0, u32::MAX); + + for node in &nodes { + node.dht_client + .add_peer(Arc::new(common_peer_info.clone())) + .unwrap(); + } + + nodes +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn public_overlays_accessible() -> Result<()> { + init_logger(); + tracing::info!("public_overlays_accessible"); + + #[derive(Debug, Default)] + struct PeerState { + knows_about: usize, + known_by: usize, + } + + let nodes = make_network(20); + + tracing::info!("discovering nodes"); + loop { + tokio::time::sleep(Duration::from_secs(1)).await; + + let mut peer_states = BTreeMap::<&PeerId, PeerState>::new(); + + for (i, left) in nodes.iter().enumerate() { + for (j, right) in nodes.iter().enumerate() { + if i == j { + continue; + } + + let left_id = left.network.peer_id(); + let right_id = right.network.peer_id(); + + if left.public_overlay.read_entries().contains(right_id) { + peer_states.entry(left_id).or_default().knows_about += 1; + peer_states.entry(right_id).or_default().known_by += 1; + } + } + } + + tracing::info!("{peer_states:#?}"); + + let total_filled = peer_states + .values() + .filter(|state| state.knows_about == nodes.len() - 1) + .count(); + + tracing::info!( + "peers with filled overlay: {} / {}", + total_filled, + nodes.len() + ); + if total_filled == nodes.len() { + break; + } + } + + tracing::info!("resolving entries..."); + for node in &nodes { + let resolved = FuturesUnordered::new(); + for entry in node.public_overlay.read_entries().iter() { + let handle = entry.resolver_handle.clone(); + resolved.push(async move { handle.wait_resolved().await }); + } + + // Ensure all entries are resolved. + resolved.collect::>().await; + tracing::info!( + peer_id = %node.network.peer_id(), + "all entries resolved", + ); + } + + tracing::info!("checking connectivity..."); + for i in 0..nodes.len() { + for j in 0..nodes.len() { + if i == j { + continue; + } + + let left = &nodes[i]; + let right = &nodes[j]; + + let value = (i * 1000 + j) as u64; + let Pong { value: received } = left + .public_overlay_query(right.network.peer_id(), Ping { value }) + .await?; + assert_eq!(received, value); + } + } + + tracing::info!("done!"); + Ok(()) +} + +static PUBLIC_OVERLAY_ID: OverlayId = OverlayId([1; 32]); diff --git a/simulator/src/node.rs b/simulator/src/node.rs index 494019ad3..2c970dea9 100644 --- a/simulator/src/node.rs +++ b/simulator/src/node.rs @@ -24,7 +24,7 @@ impl Node { .arg("--example") .arg("network-node") .arg("--") - .arg("gendht") + .arg("gen-dht") .arg(format!("{ip}:{port}")) .arg("--key") .arg(&private_key)