diff --git a/cli/src/node/config.rs b/cli/src/node/config.rs index 02f55336e..946096be0 100644 --- a/cli/src/node/config.rs +++ b/cli/src/node/config.rs @@ -1,4 +1,4 @@ -use std::net::Ipv4Addr; +use std::net::{IpAddr, Ipv4Addr}; use std::path::Path; use anyhow::Result; @@ -32,12 +32,12 @@ pub struct NodeConfig { /// Public IP address of the node. /// /// Default: resolved automatically. - pub public_ip: Option, + pub public_ip: Option, /// Ip address to listen on. /// /// Default: 0.0.0.0 - pub local_ip: Ipv4Addr, + pub local_ip: IpAddr, /// Default: 30000. pub port: u16, @@ -63,7 +63,7 @@ impl Default for NodeConfig { fn default() -> Self { Self { public_ip: None, - local_ip: Ipv4Addr::UNSPECIFIED, + local_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), port: 30000, network: NetworkConfig::default(), dht: DhtConfig::default(), diff --git a/cli/src/node/mod.rs b/cli/src/node/mod.rs index f1636f72e..640e9d666 100644 --- a/cli/src/node/mod.rs +++ b/cli/src/node/mod.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -218,11 +218,11 @@ fn init_logger(logger_config: Option) -> Result<()> { Ok(()) } -async fn resolve_public_ip(ip: Option) -> Result { +async fn resolve_public_ip(ip: Option) -> Result { match ip { Some(address) => Ok(address), None => match public_ip::addr_v4().await { - Some(address) => Ok(address), + Some(address) => Ok(IpAddr::V4(address)), None => anyhow::bail!("failed to resolve public IP address"), }, } diff --git a/consensus/examples/consensus_node.rs b/consensus/examples/consensus_node.rs index e9c56c21d..c2f658953 100644 --- a/consensus/examples/consensus_node.rs +++ b/consensus/examples/consensus_node.rs @@ -17,7 +17,7 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::{fmt, EnvFilter, Layer}; use tycho_consensus::test_utils::drain_anchors; use tycho_consensus::{Engine, InputBufferStub}; -use tycho_network::{DhtConfig, NetworkConfig, PeerId, PeerInfo}; +use tycho_network::{Address, DhtConfig, NetworkConfig, PeerId, PeerInfo}; use tycho_util::time::now_sec; #[tokio::main] @@ -179,8 +179,9 @@ impl CmdGenKey { /// generate a dht node info #[derive(Parser)] struct CmdGenDht { - /// local node address - addr: SocketAddr, + /// a list of node addresses + #[clap(required = true)] + addr: Vec
, /// node secret key #[clap(long)] @@ -195,8 +196,7 @@ impl CmdGenDht { fn run(self) -> Result<()> { let secret_key = parse_key(&self.key)?; let key_pair = KeyPair::from(&secret_key); - let entry = - tycho_consensus::test_utils::make_peer_info(&key_pair, self.addr.into(), self.ttl); + let entry = tycho_consensus::test_utils::make_peer_info(&key_pair, self.addr, self.ttl); let output = if std::io::stdin().is_terminal() { serde_json::to_string_pretty(&entry) } else { diff --git a/consensus/src/test_utils.rs b/consensus/src/test_utils.rs index 485574359..eb268f8b8 100644 --- a/consensus/src/test_utils.rs +++ b/consensus/src/test_utils.rs @@ -32,13 +32,13 @@ pub fn genesis() -> Arc { }) } -pub fn make_peer_info(keypair: &KeyPair, address: Address, ttl: Option) -> PeerInfo { +pub fn make_peer_info(keypair: &KeyPair, address_list: Vec
, ttl: Option) -> PeerInfo { let peer_id = PeerId::from(keypair.public_key); let now = now_sec(); let mut peer_info = PeerInfo { id: peer_id, - address_list: vec![address.clone()].into_boxed_slice(), + address_list: address_list.into_boxed_slice(), created_at: now, expires_at: ttl.unwrap_or(u32::MAX), signature: Box::new([0; 64]), @@ -136,7 +136,9 @@ mod tests { let peer_info = keys .iter() .zip(addresses.iter()) - .map(|((_, key_pair), addr)| Arc::new(make_peer_info(key_pair, addr.clone(), None))) + .map(|((_, key_pair), addr)| { + Arc::new(make_peer_info(key_pair, vec![addr.clone()], None)) + }) .collect::>(); let mut handles = vec![]; diff --git a/network/examples/network_node.rs b/network/examples/network_node.rs index ccd14fca3..633e0ff64 100644 --- a/network/examples/network_node.rs +++ b/network/examples/network_node.rs @@ -147,8 +147,9 @@ impl CmdGenKey { /// generate a dht node info #[derive(Parser)] struct CmdGenDht { - /// local node address - addr: SocketAddr, + /// a list of node addresses + #[clap(required = true)] + addr: Vec
, /// node secret key #[clap(long)] @@ -161,7 +162,7 @@ struct CmdGenDht { impl CmdGenDht { fn run(self) -> Result<()> { - let entry = Node::make_peer_info(parse_key(&self.key)?, self.addr.into(), self.ttl); + let entry = Node::make_peer_info(parse_key(&self.key)?, self.addr, self.ttl); let output = if std::io::stdin().is_terminal() { serde_json::to_string_pretty(&entry) } else { @@ -235,14 +236,18 @@ impl Node { Ok(Self { network, dht }) } - fn make_peer_info(key: ed25519::SecretKey, address: Address, ttl: Option) -> PeerInfo { + fn make_peer_info( + key: ed25519::SecretKey, + address_list: Vec
, + ttl: Option, + ) -> PeerInfo { let keypair = ed25519::KeyPair::from(&key); let peer_id = PeerId::from(keypair.public_key); let now = now_sec(); let mut node_info = PeerInfo { id: peer_id, - address_list: vec![address].into_boxed_slice(), + address_list: address_list.into_boxed_slice(), created_at: now, expires_at: ttl.unwrap_or(u32::MAX), signature: Box::new([0; 64]), diff --git a/network/src/network/connection_manager.rs b/network/src/network/connection_manager.rs index 2ab01095f..5c0db5886 100644 --- a/network/src/network/connection_manager.rs +++ b/network/src/network/connection_manager.rs @@ -505,13 +505,15 @@ impl ConnectionManager { fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) { async fn dial_peer_task( seqno: u32, - connecting: Result, + endpoint: Arc, address: Address, peer_id: PeerId, config: Arc, ) -> ConnectingOutput { let fut = async { - let connection = ConnectionClosedOnDrop::new(connecting?.await?); + let address = address.resolve().await?; + let connecting = endpoint.connect_with_expected_id(&address, &peer_id)?; + let connection = ConnectionClosedOnDrop::new(connecting.await?); handshake(&connection).await?; Ok(connection) }; @@ -580,14 +582,10 @@ impl ConnectionManager { }; if let Some(entry) = entry { - let target_address = address.clone(); - let connecting = self - .endpoint - .connect_with_expected_id(address.clone(), peer_id); entry.abort_handle = Some(self.pending_connections.spawn(dial_peer_task( entry.last_seqno, - connecting, - target_address, + self.endpoint.clone(), + address.clone(), *peer_id, self.config.clone(), ))); diff --git a/network/src/network/endpoint.rs b/network/src/network/endpoint.rs index bf6ce84e4..cd83433e3 100644 --- a/network/src/network/endpoint.rs +++ b/network/src/network/endpoint.rs @@ -9,7 +9,7 @@ use anyhow::Result; use crate::network::config::EndpointConfig; use crate::network::connection::{parse_peer_identity, Connection}; -use crate::types::{Address, Direction, PeerId}; +use crate::types::{Direction, PeerId}; pub(crate) struct Endpoint { inner: quinn::Endpoint, @@ -75,7 +75,7 @@ impl Endpoint { /// Connect to a remote endpoint expecting it to have the provided peer id. pub fn connect_with_expected_id( &self, - address: Address, + address: &SocketAddr, peer_id: &PeerId, ) -> Result { let config = self.config.make_client_config_for_peer_id(peer_id)?; @@ -86,12 +86,10 @@ impl Endpoint { fn connect_with_client_config( &self, config: quinn::ClientConfig, - address: Address, + address: &SocketAddr, ) -> Result { - let address = address.resolve()?; - self.inner - .connect_with(config, address, &self.config.service_name) + .connect_with(config, *address, &self.config.service_name) .map_err(Into::into) .map(Connecting::new_outbound) } diff --git a/network/src/proto.tl b/network/src/proto.tl index e04b43437..6362324c0 100644 --- a/network/src/proto.tl +++ b/network/src/proto.tl @@ -13,6 +13,7 @@ transport.peerId key:int256 = transport.PeerId; transport.address.ipv4 ip:int port:int = transport.Address; transport.address.ipv6 ip:int128 port:int = transport.Address; +transport.address.dns hostname:bytes post:int = transport.Address; // DHT //////////////////////////////////////////////////////////////////////////////// @@ -209,4 +210,4 @@ overlay.exchangeRandomPublicEntries * * @param overlay_id overlay id */ -overlay.prefix overlay_id:int256 = True; \ No newline at end of file +overlay.prefix overlay_id:int256 = True; diff --git a/network/src/types/address.rs b/network/src/types/address.rs index f772054c0..49d487563 100644 --- a/network/src/types/address.rs +++ b/network/src/types/address.rs @@ -1,34 +1,103 @@ -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{IpAddr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::str::FromStr; +use std::sync::Arc; use serde::{Deserialize, Serialize}; use tl_proto::{TlRead, TlWrite}; +use tycho_util::serde_helpers::StrVisitor; -#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -#[serde(transparent)] -pub struct Address(#[serde(with = "tycho_util::serde_helpers::socket_addr")] SocketAddr); +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum Address { + Ip(SocketAddr), + Dns { hostname: Arc, port: u16 }, +} impl Address { - pub fn resolve(&self) -> std::io::Result { - std::net::ToSocketAddrs::to_socket_addrs(&self).and_then(|mut iter| { - iter.next().ok_or_else(|| { - std::io::Error::new(std::io::ErrorKind::NotFound, "unable to resolve host") + pub async fn resolve(&self) -> std::io::Result { + match self { + Self::Ip(addr) => Ok(*addr), + Self::Dns { hostname, port } => { + let mut iter = tokio::net::lookup_host((hostname.as_ref(), *port)).await?; + iter.next().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::NotFound, "unable to resolve host") + }) + } + } + } +} + +impl Serialize for Address { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + enum Address<'a> { + Ip(&'a SocketAddr), + Dns { hostname: &'a str, port: u16 }, + } + + if serializer.is_human_readable() { + serializer.collect_str(self) + } else { + match self { + Self::Ip(addr) => Address::Ip(addr), + Self::Dns { hostname, port } => Address::Dns { + hostname: hostname.as_ref(), + port: *port, + }, + } + .serialize(serializer) + } + } +} + +impl<'de> Deserialize<'de> for Address { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + enum Address { + Ip(SocketAddr), + Dns { hostname: String, port: u16 }, + } + + if deserializer.is_human_readable() { + deserializer.deserialize_str(StrVisitor::new()) + } else { + let addr = Address::deserialize(deserializer)?; + Ok(match addr { + Address::Ip(addr) => Self::Ip(addr), + Address::Dns { hostname, port } => Self::Dns { + hostname: hostname.into(), + port, + }, }) - }) + } } } impl std::fmt::Display for Address { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(&self.0, f) + match self { + Self::Ip(addr) => std::fmt::Display::fmt(addr, f), + Self::Dns { hostname, port } => write!(f, "{}:{port}", hostname.as_ref()), + } } } impl std::net::ToSocketAddrs for Address { - type Iter = ::Iter; + type Iter = std::option::IntoIter; fn to_socket_addrs(&self) -> std::io::Result { - self.0.to_socket_addrs() + match self { + Self::Ip(addr) => addr.to_socket_addrs(), + Self::Dns { hostname, port } => { + let resolved = (hostname.as_ref(), *port).to_socket_addrs()?; + Ok(resolved.into_iter().next().into_iter()) + } + } } } @@ -36,27 +105,38 @@ impl TlWrite for Address { type Repr = tl_proto::Boxed; fn max_size_hint(&self) -> usize { - 4 + match &self.0 { - SocketAddr::V4(_) => 4 + 4, - SocketAddr::V6(_) => 16 + 4, - } + let len = match self { + Self::Ip(SocketAddr::V4(_)) => 4, + Self::Ip(SocketAddr::V6(_)) => 16, + Self::Dns { hostname: host, .. } => host.as_bytes().max_size_hint(), + }; + // Constructor + len + port + 4 + len + 4 } fn write_to

(&self, packet: &mut P) where P: tl_proto::TlPacket, { - match &self.0 { - SocketAddr::V4(addr) => { + match self { + Self::Ip(SocketAddr::V4(addr)) => { packet.write_u32(ADDRESS_V4_TL_ID); packet.write_u32(u32::from(*addr.ip())); packet.write_u32(addr.port() as u32); } - SocketAddr::V6(addr) => { + Self::Ip(SocketAddr::V6(addr)) => { packet.write_u32(ADDRESS_V6_TL_ID); packet.write_raw_slice(&addr.ip().octets()); packet.write_u32(addr.port() as u32); } + Self::Dns { + hostname: host, + port, + } => { + packet.write_u32(ADDRESS_DNS_TL_ID); + host.as_bytes().write_to(packet); + packet.write_u32(*port as u32); + } }; } } @@ -67,58 +147,77 @@ impl<'a> TlRead<'a> for Address { fn read_from(packet: &'a [u8], offset: &mut usize) -> tl_proto::TlResult { use tl_proto::TlError; - Ok(Address(match u32::read_from(packet, offset)? { + Ok(match u32::read_from(packet, offset)? { ADDRESS_V4_TL_ID => { let ip = u32::read_from(packet, offset)?; let Ok(port) = u32::read_from(packet, offset)?.try_into() else { return Err(TlError::InvalidData); }; - SocketAddr::V4(SocketAddrV4::new(ip.into(), port)) + Self::Ip(SocketAddr::V4(SocketAddrV4::new(ip.into(), port))) } ADDRESS_V6_TL_ID => { let octets = <[u8; 16]>::read_from(packet, offset)?; let Ok(port) = u32::read_from(packet, offset)?.try_into() else { return Err(TlError::InvalidData); }; - SocketAddr::V6(SocketAddrV6::new(octets.into(), port, 0, 0)) + Self::Ip(SocketAddr::V6(SocketAddrV6::new(octets.into(), port, 0, 0))) + } + ADDRESS_DNS_TL_ID => { + let hostname = <&[u8]>::read_from(packet, offset)?; + let Some(hostname) = validate_hostname(hostname) else { + return Err(TlError::InvalidData); + }; + + let Ok(port) = u32::read_from(packet, offset)?.try_into() else { + return Err(TlError::InvalidData); + }; + + if hostname.parse::().is_ok() { + return Err(TlError::InvalidData); + } + + Self::Dns { + hostname: hostname.into(), + port, + } } _ => return Err(TlError::UnknownConstructor), - })) + }) } } impl From for Address { #[inline] fn from(value: SocketAddr) -> Self { - Self(value) + Self::Ip(value) } } impl From for Address { #[inline] fn from(value: SocketAddrV4) -> Self { - Self(SocketAddr::V4(value)) + Self::Ip(SocketAddr::V4(value)) } } impl From for Address { #[inline] fn from(value: SocketAddrV6) -> Self { - Self(SocketAddr::V6(value)) + Self::Ip(SocketAddr::V6(value)) } } impl From<(std::net::Ipv4Addr, u16)> for Address { #[inline] fn from((ip, port): (std::net::Ipv4Addr, u16)) -> Self { - Self(SocketAddr::V4(SocketAddrV4::new(ip, port))) + Self::Ip(SocketAddr::V4(SocketAddrV4::new(ip, port))) } } impl From<(std::net::Ipv6Addr, u16)> for Address { #[inline] fn from((ip, port): (std::net::Ipv6Addr, u16)) -> Self { - Self(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))) + Self::Ip(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))) } } @@ -127,23 +226,103 @@ impl FromStr for Address { #[inline] fn from_str(s: &str) -> Result { - SocketAddr::from_str(s).map(Self) + match SocketAddr::from_str(s) { + Ok(addr) => Ok(Self::Ip(addr)), + Err(e) => { + 'host: { + let Some((hostname, port)) = s.split_once(':') else { + break 'host; + }; + + let Ok(port) = port.parse::() else { + break 'host; + }; + + let Some(hostname) = validate_hostname(hostname.as_bytes()) else { + break 'host; + }; + + return Ok(Self::Dns { + hostname: hostname.into(), + port, + }); + } + + Err(e) + } + } } } +/// Validates a hostname according to [IETF RFC 1123](https://tools.ietf.org/html/rfc1123). +/// +/// A hostname is valid if the following conditions are true: +/// +/// - It does not start or end with `-` or `.`. +/// - It does not contain any characters outside of the alphanumeric range, except for `-` and `.`. +/// - It is not empty. +/// - It is 253 or fewer characters. +/// - Its labels (characters separated by `.`) are not empty. +/// - Its labels are 63 or fewer characters. +/// - Its labels do not start or end with '-' or '.'. +fn validate_hostname(hostname: &[u8]) -> Option<&str> { + if hostname.is_empty() || hostname.len() > 253 { + return None; + } + + let mut label_length = 0; + let mut previous_char = b'.'; // assume the previous character is a dot + + for &byte in hostname { + match byte { + b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' => { + label_length += 1; + } + b'-' => { + if label_length == 0 { + return None; // invalid label + } + label_length += 1; + } + b'.' => { + if label_length == 0 || previous_char == b'-' { + return None; // invalid label + } + label_length = 0; // reset label length after each dot + } + _ => return None, + } + + if label_length > 63 { + return None; // invalid label + } + + previous_char = byte; + } + + if label_length == 0 || previous_char == b'-' { + return None; + } + + // SAFETY: `hostname` is guaranteed to contain only valid UTF-8 characters. + Some(unsafe { std::str::from_utf8_unchecked(hostname) }) +} + const ADDRESS_V4_TL_ID: u32 = tl_proto::id!("transport.address.ipv4", scheme = "proto.tl"); const ADDRESS_V6_TL_ID: u32 = tl_proto::id!("transport.address.ipv6", scheme = "proto.tl"); +const ADDRESS_DNS_TL_ID: u32 = tl_proto::id!("transport.address.dns", scheme = "proto.tl"); #[cfg(test)] mod tests { use super::*; + const SOME_ADDR_V4: &str = "101.102.103.104:12345"; + const SOME_ADDR_V6: &str = "[2345:0425:2CA1:0:0:0567:5673:23b5]:12345"; + const SOME_ADDR_DNS: &str = "node-1.example.com:12345"; + #[test] fn serde() { - const SOME_ADDR_V4: &str = "101.102.103.104:12345"; - const SOME_ADDR_V6: &str = "[2345:0425:2CA1:0:0:0567:5673:23b5]:12345"; - - for addr in [SOME_ADDR_V4, SOME_ADDR_V6] { + for addr in [SOME_ADDR_V4, SOME_ADDR_V6, SOME_ADDR_DNS] { let from_json: Address = serde_json::from_str(&format!("\"{addr}\"")).unwrap(); let from_str = Address::from_str(addr).unwrap(); assert_eq!(from_json, from_str); @@ -153,4 +332,95 @@ mod tests { assert_eq!(from_json, from_str); } } + + #[test] + fn tl() { + // Valid + let addrs = [ + Address::Ip(SocketAddr::from_str(SOME_ADDR_V4).unwrap()), + Address::Ip(SocketAddr::from_str(SOME_ADDR_V6).unwrap()), + Address::Dns { + hostname: "node-1.example.com".into(), + port: 12345, + }, + ]; + + for addr in addrs { + let bytes = tl_proto::serialize(&addr); + let parsed = tl_proto::deserialize::

(&bytes).unwrap(); + assert_eq!(addr, parsed); + } + + // Invalid + let addrs = [ + Address::Dns { + hostname: "test.com:12345".into(), + port: 12345, + }, + Address::Dns { + hostname: "".into(), + port: 12345, + }, + Address::Dns { + hostname: "...".into(), + port: 12345, + }, + Address::Dns { + hostname: "127.0.0.1".into(), + port: 12345, + }, + Address::Dns { + hostname: SOME_ADDR_V6.into(), + port: 12345, + }, + ]; + + for addr in addrs { + assert!(matches!( + tl_proto::deserialize::
(&tl_proto::serialize(addr)), + Err(tl_proto::TlError::InvalidData) + )); + } + } + + #[test] + fn valid_hostnames() { + for hostname in &[ + "VaLiD-HoStNaMe", + "50-name", + "235235", + "example.com", + "VaLid.HoStNaMe", + "123.456", + ] { + assert!( + validate_hostname(hostname.as_bytes()).is_some(), + "{} is not valid", + hostname + ); + } + } + + #[test] + fn invalid_hostnames() { + for hostname in &[ + "-invalid-name", + "also-invalid-", + "asdf@fasd", + "@asdfl", + "asd f@", + ".invalid", + "invalid.name.", + "foo.label-is-way-to-longgggggggggggggggggggggggggggggggggggggggggggg.org", + "invalid.-starting.char", + "invalid.ending-.char", + "empty..label", + ] { + assert!( + validate_hostname(hostname.as_bytes()).is_none(), + "{} should not be valid", + hostname + ); + } + } } diff --git a/util/src/serde_helpers.rs b/util/src/serde_helpers.rs index d47af4752..56458c507 100644 --- a/util/src/serde_helpers.rs +++ b/util/src/serde_helpers.rs @@ -322,21 +322,3 @@ impl<'de, const M: usize> Visitor<'de> for BytesVisitor { array_from_iterator(SeqIter::new(seq), &self) } } - -struct HexVisitor; - -impl<'de> Visitor<'de> for HexVisitor { - type Value = Vec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("hex-encoded byte array") - } - - fn visit_str(self, value: &str) -> Result { - hex::decode(value).map_err(|_e| E::invalid_type(serde::de::Unexpected::Str(value), &self)) - } - - fn visit_bytes(self, value: &[u8]) -> Result { - Ok(value.to_vec()) - } -}