From ade68d31a9f83669870a0412b9673fdeec626f85 Mon Sep 17 00:00:00 2001 From: Ivan Kalinin Date: Sat, 4 May 2024 19:50:23 +0200 Subject: [PATCH] feat(network): allow using raw sockets --- network/src/lib.rs | 4 +-- network/src/network/mod.rs | 67 +++++++++++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/network/src/lib.rs b/network/src/lib.rs index efd61227a..c43c6708c 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -13,8 +13,8 @@ pub use dht::{ }; pub use network::{ ActivePeers, Connection, KnownPeerHandle, KnownPeers, KnownPeersError, Network, NetworkBuilder, - NetworkConfig, Peer, PeerBannedError, QuicConfig, RecvStream, SendStream, WeakActivePeers, - WeakKnownPeerHandle, WeakNetwork, + NetworkConfig, Peer, PeerBannedError, QuicConfig, RecvStream, SendStream, ToSocket, + WeakActivePeers, WeakKnownPeerHandle, WeakNetwork, }; pub use types::{ service_datagram_fn, service_message_fn, service_query_fn, Address, BoxCloneService, diff --git a/network/src/network/mod.rs b/network/src/network/mod.rs index 3e7b0fba2..609630cfe 100644 --- a/network/src/network/mod.rs +++ b/network/src/network/mod.rs @@ -78,13 +78,11 @@ impl NetworkBuilder<(T1, ())> { } impl NetworkBuilder { - pub fn build(self, bind_address: T, service: S) -> Result + pub fn build(self, bind_address: T, service: S) -> Result where S: Send + Sync + Clone + 'static, S: Service, { - use socket2::{Domain, Protocol, Socket, Type}; - let config = self.optional_fields.config.unwrap_or_default(); let quic_config = config.quic.clone().unwrap_or_default(); let (service_name, private_key) = self.mandatory_fields; @@ -98,18 +96,7 @@ impl NetworkBuilder { .with_transport_config(quic_config.make_transport_config()) .build()?; - let socket = 'socket: { - let mut err = anyhow::anyhow!("no addresses to bind to"); - for addr in bind_address.to_socket_addrs()? { - let s = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?; - if let Err(e) = s.bind(&socket2::SockAddr::from(addr)) { - err = e.into(); - } else { - break 'socket s; - } - } - return Err(err); - }; + let socket = bind_address.to_socket().map(socket2::Socket::from)?; if let Some(send_buffer_size) = quic_config.socket_send_buffer_size { if let Err(e) = socket.set_send_buffer_size(send_buffer_size) { @@ -339,6 +326,56 @@ impl Drop for NetworkInner { } } +pub trait ToSocket { + fn to_socket(self) -> Result; +} + +impl ToSocket for std::net::UdpSocket { + fn to_socket(self) -> Result { + Ok(self) + } +} + +macro_rules! impl_to_socket_for_addr { + ($($ty:ty),*$(,)?) => {$( + impl ToSocket for $ty { + fn to_socket(self) -> Result { + bind_socket_to_addr(self) + } + } + )*}; +} + +impl_to_socket_for_addr! { + SocketAddr, + std::net::SocketAddrV4, + std::net::SocketAddrV6, + (std::net::IpAddr, u16), + (std::net::Ipv4Addr, u16), + (std::net::Ipv6Addr, u16), + (&str, u16), + (String, u16), + &str, + String, + &[SocketAddr], + Address, +} + +fn bind_socket_to_addr(bind_address: T) -> Result { + use socket2::{Domain, Protocol, Socket, Type}; + + let mut err = anyhow::anyhow!("no addresses to bind to"); + for addr in bind_address.to_socket_addrs()? { + let s = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?; + if let Err(e) = s.bind(&socket2::SockAddr::from(addr)) { + err = e.into(); + } else { + return Ok(s.into()); + } + } + return Err(err); +} + #[derive(thiserror::Error, Debug)] #[error("network has been shutdown")] struct NetworkShutdownError;