diff --git a/iroh-net/src/magic_endpoint.rs b/iroh-net/src/magic_endpoint.rs index 19837cd4e3..c9e4f69c54 100644 --- a/iroh-net/src/magic_endpoint.rs +++ b/iroh-net/src/magic_endpoint.rs @@ -15,7 +15,7 @@ use crate::{ discovery::{Discovery, DiscoveryTask}, dns::{default_resolver, DnsResolver}, key::{PublicKey, SecretKey}, - magicsock::{self, MagicSock}, + magicsock::{self, ConnectionTypeStream, MagicSock}, relay::{RelayMap, RelayMode, RelayUrl}, tls, NodeId, }; @@ -402,6 +402,16 @@ impl MagicEndpoint { self.connect(addr, alpn).await } + /// Returns a stream that reports changes in the [`crate::magicsock::ConnectionType`] + /// for the given `node_id`. + /// + /// # Errors + /// + /// Will error if we do not have any address information for the given `node_id` + pub fn conn_type_stream(&self, node_id: &PublicKey) -> Result { + self.msock.conn_type_stream(node_id) + } + /// Connect to a remote endpoint. /// /// A [`NodeAddr`] is required. It must contain the [`NodeId`] to dial and may also contain a @@ -630,7 +640,7 @@ mod tests { use rand_core::SeedableRng; use tracing::{error_span, info, info_span, Instrument}; - use crate::test_utils::run_relay_server; + use crate::{magicsock::ConnectionType, test_utils::run_relay_server}; use super::*; @@ -971,4 +981,100 @@ mod tests { p1_connect.await.unwrap(); p2_connect.await.unwrap(); } + + #[tokio::test] + async fn magic_endpoint_conn_type_stream() { + let _logging_guard = iroh_test::logging::setup(); + let (relay_map, relay_url, _relay_guard) = run_relay_server().await.unwrap(); + let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); + let ep1_secret_key = SecretKey::generate_with_rng(&mut rng); + let ep2_secret_key = SecretKey::generate_with_rng(&mut rng); + let ep1 = MagicEndpoint::builder() + .secret_key(ep1_secret_key) + .insecure_skip_relay_cert_verify(true) + .alpns(vec![TEST_ALPN.to_vec()]) + .relay_mode(RelayMode::Custom(relay_map.clone())) + .bind(0) + .await + .unwrap(); + let ep2 = MagicEndpoint::builder() + .secret_key(ep2_secret_key) + .insecure_skip_relay_cert_verify(true) + .alpns(vec![TEST_ALPN.to_vec()]) + .relay_mode(RelayMode::Custom(relay_map)) + .bind(0) + .await + .unwrap(); + + async fn handle_direct_conn(ep: MagicEndpoint, node_id: PublicKey) -> Result<()> { + let node_addr = NodeAddr::new(node_id); + ep.add_node_addr(node_addr)?; + let stream = ep.conn_type_stream(&node_id)?; + async fn get_direct_event( + src: &PublicKey, + dst: &PublicKey, + mut stream: ConnectionTypeStream, + ) -> Result<()> { + let src = src.fmt_short(); + let dst = dst.fmt_short(); + while let Some(conn_type) = stream.next().await { + tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type); + if matches!(conn_type, ConnectionType::Direct(_)) { + return Ok(()); + } + } + anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`"); + } + tokio::time::timeout( + Duration::from_secs(15), + get_direct_event(&ep.node_id(), &node_id, stream), + ) + .await??; + Ok(()) + } + + let ep1_nodeid = ep1.node_id(); + let ep2_nodeid = ep2.node_id(); + + let ep1_nodeaddr = ep1.my_addr().await.unwrap(); + tracing::info!( + "node id 1 {ep1_nodeid}, relay URL {:?}", + ep1_nodeaddr.relay_url() + ); + tracing::info!("node id 2 {ep2_nodeid}"); + + let res_ep1 = tokio::spawn(handle_direct_conn(ep1.clone(), ep2_nodeid)); + + let ep1_abort_handle = res_ep1.abort_handle(); + let _ep1_guard = CallOnDrop::new(move || { + ep1_abort_handle.abort(); + }); + + let res_ep2 = tokio::spawn(handle_direct_conn(ep2.clone(), ep1_nodeid)); + let ep2_abort_handle = res_ep2.abort_handle(); + let _ep2_guard = CallOnDrop::new(move || { + ep2_abort_handle.abort(); + }); + async fn accept(ep: MagicEndpoint) -> (PublicKey, String, quinn::Connection) { + let incoming = ep.accept().await.unwrap(); + accept_conn(incoming).await.unwrap() + } + + // create a node addr with no direct connections + let ep1_nodeaddr = NodeAddr::from_parts(ep1_nodeid, Some(relay_url), vec![]); + + let accept_res = tokio::spawn(accept(ep1.clone())); + let accept_abort_handle = accept_res.abort_handle(); + let _accept_guard = CallOnDrop::new(move || { + accept_abort_handle.abort(); + }); + + let _conn_2 = ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap(); + + let (got_id, _, _conn) = accept_res.await.unwrap(); + assert_eq!(ep2_nodeid, got_id); + + res_ep1.await.unwrap().unwrap(); + res_ep2.await.unwrap().unwrap(); + } } diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index b243cc1a5a..a35e179df9 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -80,7 +80,9 @@ mod udp_conn; pub use crate::net::UdpSocket; pub use self::metrics::Metrics; -pub use self::node_map::{ConnectionType, ControlMsg, DirectAddrInfo, EndpointInfo}; +pub use self::node_map::{ + ConnectionType, ConnectionTypeStream, ControlMsg, DirectAddrInfo, EndpointInfo, +}; pub use self::timer::Timer; /// How long we consider a STUN-derived endpoint valid for. UDP NAT mappings typically @@ -1349,6 +1351,23 @@ impl MagicSock { } } + /// Returns a stream that reports the [`ConnectionType`] we have to the + /// given `node_id`. + /// + /// The `NodeMap` continuously monitors the `node_id`'s endpoint for + /// [`ConnectionType`] changes, and sends the latest [`ConnectionType`] + /// on the stream. + /// + /// The current [`ConnectionType`] will the the initial entry on the stream. + /// + /// # Errors + /// + /// Will return an error if there is no address information known about the + /// given `node_id`. + pub fn conn_type_stream(&self, node_id: &PublicKey) -> Result { + self.inner.node_map.conn_type_stream(node_id) + } + /// Get the cached version of the Ipv4 and Ipv6 addrs of the current connection. pub fn local_addr(&self) -> Result<(SocketAddr, Option)> { Ok(self.inner.local_addr()) diff --git a/iroh-net/src/magicsock/node_map.rs b/iroh-net/src/magicsock/node_map.rs index 90c214b67e..99fdf52e2c 100644 --- a/iroh-net/src/magicsock/node_map.rs +++ b/iroh-net/src/magicsock/node_map.rs @@ -3,10 +3,13 @@ use std::{ hash::Hash, net::{IpAddr, SocketAddr}, path::Path, + pin::Pin, + task::{Context, Poll}, time::Instant, }; -use anyhow::{ensure, Context}; +use anyhow::{ensure, Context as _}; +use futures::Stream; use iroh_metrics::inc; use parking_lot::Mutex; use stun_rs::TransactionId; @@ -209,6 +212,19 @@ impl NodeMap { self.inner.lock().endpoint_infos(now) } + /// Returns a stream of [`ConnectionType`]. + /// + /// Sends the current [`ConnectionType`] whenever any changes to the + /// connection type for `public_key` has occured. + /// + /// # Errors + /// + /// Will return an error if there is not an entry in the [`NodeMap`] for + /// the `public_key` + pub fn conn_type_stream(&self, public_key: &PublicKey) -> anyhow::Result { + self.inner.lock().conn_type_stream(public_key) + } + /// Get the [`EndpointInfo`]s for each endpoint pub fn endpoint_info(&self, public_key: &PublicKey) -> Option { self.inner.lock().endpoint_info(public_key) @@ -389,6 +405,25 @@ impl NodeMapInner { .map(|ep| ep.info(Instant::now())) } + /// Returns a stream of [`ConnectionType`]. + /// + /// Sends the current [`ConnectionType`] whenever any changes to the + /// connection type for `public_key` has occured. + /// + /// # Errors + /// + /// Will return an error if there is not an entry in the [`NodeMap`] for + /// the `public_key` + fn conn_type_stream(&self, public_key: &PublicKey) -> anyhow::Result { + match self.get(EndpointId::NodeKey(public_key)) { + Some(ep) => Ok(ConnectionTypeStream { + initial: Some(ep.conn_type.get()), + inner: ep.conn_type.watch().into_stream(), + }), + None => anyhow::bail!("No endpoint for {public_key:?} found"), + } + } + fn handle_pong(&mut self, sender: PublicKey, src: &DiscoMessageSource, pong: Pong) { if let Some(ep) = self.get_mut(EndpointId::NodeKey(&sender)).as_mut() { let insert = ep.handle_pong(&pong, src.into()); @@ -536,6 +571,25 @@ impl NodeMapInner { } } +/// Stream returning `ConnectionTypes` +#[derive(Debug)] +pub struct ConnectionTypeStream { + initial: Option, + inner: watchable::WatcherStream, +} + +impl Stream for ConnectionTypeStream { + type Item = ConnectionType; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + if let Some(initial_conn_type) = this.initial.take() { + return Poll::Ready(Some(initial_conn_type)); + } + Pin::new(&mut this.inner).poll_next(cx) + } +} + /// An (Ip, Port) pair. /// /// NOTE: storing an [`IpPort`] is safer than storing a [`SocketAddr`] because for IPv6 socket diff --git a/iroh-net/src/magicsock/node_map/endpoint.rs b/iroh-net/src/magicsock/node_map/endpoint.rs index 7c9038b615..391352bf8a 100644 --- a/iroh-net/src/magicsock/node_map/endpoint.rs +++ b/iroh-net/src/magicsock/node_map/endpoint.rs @@ -10,6 +10,7 @@ use rand::seq::IteratorRandom; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tracing::{debug, info, instrument, trace, warn}; +use watchable::Watchable; use crate::{ disco::{self, SendAddr}, @@ -141,6 +142,8 @@ pub(super) struct Endpoint { /// the [`Endpoint::stayin_alive`] function is called, which will trigger new /// call-me-maybe messages as backup. last_call_me_maybe: Option, + /// The type of connection we have to the node, either direct, relay, mixed, or none. + pub conn_type: Watchable, } #[derive(Debug)] @@ -171,6 +174,7 @@ impl Endpoint { direct_addr_state: BTreeMap::new(), last_used: options.active.then(Instant::now), last_call_me_maybe: None, + conn_type: Watchable::new(ConnectionType::None), } } @@ -188,24 +192,30 @@ impl Endpoint { /// Returns info about this endpoint pub(super) fn info(&self, now: Instant) -> EndpointInfo { - use best_addr::State::*; - // Report our active connection. This replicates the logic of [`Endpoint::addr_for_send`] - // without choosing a random candidate address if no best_addr is set. - let (conn_type, latency) = match (self.best_addr.state(now), self.relay_url.as_ref()) { - (Valid(addr), _) | (Outdated(addr), None) => { - (ConnectionType::Direct(addr.addr), Some(addr.latency)) - } - (Outdated(addr), Some((url, relay_state))) => { - let latency = relay_state - .latency() - .map(|l| l.min(addr.latency)) - .unwrap_or(addr.latency); - (ConnectionType::Mixed(addr.addr, url.clone()), Some(latency)) - } - (Empty, Some((url, relay_state))) => { - (ConnectionType::Relay(url.clone()), relay_state.latency()) + let conn_type = self.conn_type.get(); + let latency = match conn_type { + ConnectionType::Direct(addr) => self + .direct_addr_state + .get(&addr.into()) + .and_then(|state| state.latency()), + ConnectionType::Relay(ref url) => self + .relay_url + .as_ref() + .filter(|(relay_url, _)| relay_url == url) + .and_then(|(_, state)| state.latency()), + ConnectionType::Mixed(addr, ref url) => { + let addr_latency = self + .direct_addr_state + .get(&addr.into()) + .and_then(|state| state.latency()); + let relay_latency = self + .relay_url + .as_ref() + .filter(|(relay_url, _)| relay_url == url) + .and_then(|(_, state)| state.latency()); + addr_latency.min(relay_latency) } - (Empty, None) => (ConnectionType::None, None), + ConnectionType::None => None, }; let addrs = self .direct_addr_state @@ -252,7 +262,7 @@ impl Endpoint { // Update our best addr from candidate addresses (only if it is empty and if we have // recent pongs). self.assign_best_addr_from_candidates_if_empty(); - match self.best_addr.state(*now) { + let (best_addr, relay_url) = match self.best_addr.state(*now) { best_addr::State::Valid(best_addr) => { // If we have a valid address we use it. trace!(addr = %best_addr.addr, latency = ?best_addr.latency, @@ -283,7 +293,24 @@ impl Endpoint { trace!(udp_addr = ?addr, "best_addr is unset, use candidate addr and relay"); (addr, self.relay_url()) } + }; + match (best_addr, relay_url.clone()) { + (Some(best_addr), Some(relay_url)) => { + let _ = self + .conn_type + .update(ConnectionType::Mixed(best_addr, relay_url)); + } + (Some(best_addr), None) => { + let _ = self.conn_type.update(ConnectionType::Direct(best_addr)); + } + (None, Some(relay_url)) => { + let _ = self.conn_type.update(ConnectionType::Relay(relay_url)); + } + (None, None) => { + let _ = self.conn_type.update(ConnectionType::None); + } } + (best_addr, relay_url) } /// Fixup best_addr from candidates. @@ -329,7 +356,7 @@ impl Endpoint { best_addr::Source::BestCandidate, pong.pong_at, self.relay_url.is_some(), - ); + ) } } } @@ -1438,6 +1465,7 @@ mod tests { sent_pings: HashMap::new(), last_used: Some(now), last_call_me_maybe: None, + conn_type: Watchable::new(ConnectionType::Direct(ip_port.into())), }, ip_port.into(), ) @@ -1463,6 +1491,7 @@ mod tests { sent_pings: HashMap::new(), last_used: Some(now), last_call_me_maybe: None, + conn_type: Watchable::new(ConnectionType::Relay(send_addr.clone())), } }; @@ -1482,6 +1511,7 @@ mod tests { sent_pings: HashMap::new(), last_used: Some(now), last_call_me_maybe: None, + conn_type: Watchable::new(ConnectionType::Relay(send_addr.clone())), } }; @@ -1522,6 +1552,10 @@ mod tests { sent_pings: HashMap::new(), last_used: Some(now), last_call_me_maybe: None, + conn_type: Watchable::new(ConnectionType::Mixed( + socket_addr, + send_addr.clone(), + )), }, socket_addr, )