From 077b735b484cf33e79f9d621db1d0c3a5827b81e Mon Sep 17 00:00:00 2001 From: Mark Logan <103447440+mystenmark@users.noreply.github.com> Date: Fri, 17 May 2024 08:12:49 -0700 Subject: [PATCH] Separate ports by protocol (#46) --- msim-tokio/Cargo.toml | 1 + msim-tokio/src/sim/net.rs | 4 +- msim/src/sim/net/mod.rs | 126 +++++++++++++++++++++--------------- msim/src/sim/net/network.rs | 82 +++++++++++++++++------ 4 files changed, 140 insertions(+), 73 deletions(-) diff --git a/msim-tokio/Cargo.toml b/msim-tokio/Cargo.toml index d70cff7..346cf09 100644 --- a/msim-tokio/Cargo.toml +++ b/msim-tokio/Cargo.toml @@ -65,3 +65,4 @@ real_tokio = { git = "https://github.com/mystenmark/tokio-madsim-fork.git", rev bytes = { version = "1.1" } futures = { version = "0.3.0", features = ["async-await"] } mio = { version = "0.8.1" } +libc = "0.2" diff --git a/msim-tokio/src/sim/net.rs b/msim-tokio/src/sim/net.rs index f28c727..4f921e8 100644 --- a/msim-tokio/src/sim/net.rs +++ b/msim-tokio/src/sim/net.rs @@ -350,7 +350,7 @@ impl TcpStream { } async fn connect_addr(addr: impl ToSocketAddrs) -> io::Result { - let ep = Arc::new(Endpoint::connect(addr).await?); + let ep = Arc::new(Endpoint::connect(libc::SOCK_STREAM, addr).await?); trace!("connect {:?}", ep.local_addr()); let remote_sock = ep.peer_addr()?; @@ -714,7 +714,7 @@ impl TcpSocket { } pub fn bind(&self, addr: StdSocketAddr) -> io::Result<()> { - let ep = Endpoint::bind_sync(addr)?; + let ep = Endpoint::bind_sync(libc::SOCK_STREAM, addr)?; *self.bind_addr.lock().unwrap() = Some(ep.into()); Ok(()) } diff --git a/msim/src/sim/net/mod.rs b/msim/src/sim/net/mod.rs index f9cfda9..afa567e 100644 --- a/msim/src/sim/net/mod.rs +++ b/msim/src/sim/net/mod.rs @@ -325,7 +325,7 @@ unsafe fn accept_impl( ) -> libc::c_int { let result = HostNetworkState::with_socket( sock_fd, - |socket| -> Result { + |socket| -> Result<(SocketAddr, libc::c_int), (libc::c_int, libc::c_int)> { let node = plugin::node(); let net = plugin::simulator::(); let network = net.network.lock().unwrap(); @@ -343,7 +343,8 @@ unsafe fn accept_impl( // We can't simulate blocking accept in a single-threaded simulator, so if there is no // connection waiting for us, just bail. network - .accept_connect(node, endpoint.addr) + .accept_connect(socket.ty, node, endpoint.addr) + .map(|addr| (addr, socket.ty)) .ok_or((-1, libc::ECONNABORTED)) }, ) @@ -352,18 +353,18 @@ unsafe fn accept_impl( Result::Err((-1, libc::ENOTSOCK)) }); - let remote_addr = match result { + let (remote_addr, proto) = match result { Err((ret, err)) => { trace!("error status: {} {}", ret, err); set_errno(err); return ret; } - Ok(addr) => addr, + Ok(res) => res, }; write_socket_addr(address, address_len, remote_addr); - let endpoint = Endpoint::connect_sync(remote_addr) + let endpoint = Endpoint::connect_sync(proto, remote_addr) .expect("connection failure should already have been detected"); let fd = alloc_fd(); @@ -396,7 +397,7 @@ define_sys_interceptor!( HostNetworkState::with_socket(sock_fd, |socket| { assert!(socket.endpoint.is_none(), "socket already bound"); - match Endpoint::bind_sync(socket_addr) { + match Endpoint::bind_sync(socket.ty, socket_addr) { Ok(ep) => { socket.endpoint = Some(Arc::new(ep)); 0 @@ -438,7 +439,7 @@ define_sys_interceptor!( return Err((-1, libc::EISCONN)); } - let ep = Endpoint::connect_sync(sock_addr).map_err(|e| match e.kind() { + let ep = Endpoint::connect_sync(socket.ty, sock_addr).map_err(|e| match e.kind() { io::ErrorKind::AddrInUse => (-1, libc::EADDRINUSE), io::ErrorKind::AddrNotAvailable => (-1, libc::EADDRNOTAVAIL), _ => { @@ -453,7 +454,7 @@ define_sys_interceptor!( // the other end goes away). let net = plugin::simulator::(); let network = net.network.lock().unwrap(); - if !network.signal_connect(ep.addr, sock_addr) { + if !network.signal_connect(socket.ty, ep.addr, sock_addr) { return Err((-1, libc::ECONNREFUSED)); } @@ -544,8 +545,7 @@ define_sys_interceptor!( match (level, name) { // called by anemo::Network::start (via socket2) // skip returning any value here since Sui only uses it to log an error anyway - (libc::SOL_SOCKET, libc::SO_RCVBUF) | - (libc::SOL_SOCKET, libc::SO_SNDBUF) => 0, + (libc::SOL_SOCKET, libc::SO_RCVBUF) | (libc::SOL_SOCKET, libc::SO_SNDBUF) => 0, _ => { warn!("unhandled getsockopt {} {}", level, name); @@ -1015,6 +1015,7 @@ pub struct Endpoint { net: Arc, node: NodeId, addr: SocketAddr, + proto: libc::c_int, peer: Option, live_tcp_ids: Mutex>, } @@ -1030,16 +1031,17 @@ impl std::fmt::Debug for Endpoint { } impl Endpoint { - /// Bind synchronously (for UDP) - pub fn bind_sync(addr: impl ToSocketAddrs) -> io::Result { + /// Bind synchronously + pub fn bind_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); let node = plugin::node(); let addr = addr.to_socket_addrs()?.next().unwrap(); - let addr = net.network.lock().unwrap().bind(node, addr)?; + let addr = net.network.lock().unwrap().bind(node, proto, addr)?; let ep = Endpoint { net, node, addr, + proto, peer: None, live_tcp_ids: Default::default(), }; @@ -1063,30 +1065,31 @@ impl Endpoint { } /// Creates a [`Endpoint`] from the given address. - pub async fn bind(addr: impl ToSocketAddrs) -> io::Result { + pub async fn bind(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); let node = plugin::node(); let addr = addr.to_socket_addrs()?.next().unwrap(); net.rand_delay().await; - let addr = net.network.lock().unwrap().bind(node, addr)?; + let addr = net.network.lock().unwrap().bind(node, proto, addr)?; Ok(Endpoint { net, node, addr, + proto, peer: None, live_tcp_ids: Default::default(), }) } /// Connects this [`Endpoint`] to a remote address. - pub async fn connect(addr: impl ToSocketAddrs) -> io::Result { + pub async fn connect(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); net.rand_delay().await; - Self::connect_sync(addr) + Self::connect_sync(proto, addr) } /// For libc::connect() - pub fn connect_sync(addr: impl ToSocketAddrs) -> io::Result { + pub fn connect_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); let node = plugin::node(); let peer = addr.to_socket_addrs()?.next().unwrap(); @@ -1095,11 +1098,12 @@ impl Endpoint { } else { SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) }; - let addr = net.network.lock().unwrap().bind(node, addr)?; + let addr = net.network.lock().unwrap().bind(node, proto, addr)?; Ok(Endpoint { net, node, addr, + proto, peer: Some(peer), live_tcp_ids: Default::default(), }) @@ -1128,7 +1132,7 @@ impl Endpoint { .network .lock() .unwrap() - .deregister_tcp_id(self.node, remote_sock, id); + .deregister_tcp_id(self.node, self.proto, remote_sock, id); } /// Returns the local socket address. @@ -1234,7 +1238,7 @@ impl Endpoint { .network .lock() .unwrap() - .send(plugin::node(), self.addr, dst, tag, data) + .send(plugin::node(), self.proto, self.addr, dst, tag, data) } /// Receives a raw message. @@ -1244,12 +1248,12 @@ impl Endpoint { #[cfg_attr(docsrs, doc(cfg(msim)))] pub async fn recv_from_raw(&self, tag: u64) -> io::Result<(Payload, SocketAddr)> { trace!("awaiting recv: {} tag={:x}", self.addr, tag); - let recver = self - .net - .network - .lock() - .unwrap() - .recv(plugin::node(), self.addr, tag); + let recver = + self.net + .network + .lock() + .unwrap() + .recv(plugin::node(), self.proto, self.addr, tag); let msg = recver .await .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "network is down"))?; @@ -1266,7 +1270,7 @@ impl Endpoint { .network .lock() .unwrap() - .recv_sync(plugin::node(), self.addr, tag) + .recv_sync(plugin::node(), self.proto, self.addr, tag) .ok_or_else(|| io::Error::new(io::ErrorKind::WouldBlock, "recv call would blck"))?; trace!( @@ -1320,12 +1324,13 @@ impl Endpoint { /// Check if there is a message waiting that can be received without blocking. /// If not, schedule a wakeup using the context. pub fn recv_ready(&self, cx: Option<&mut Context<'_>>, tag: u64) -> io::Result { - Ok(self - .net - .network - .lock() - .unwrap() - .recv_ready(cx, plugin::node(), self.addr, tag)) + Ok(self.net.network.lock().unwrap().recv_ready( + cx, + plugin::node(), + self.proto, + self.addr, + tag, + )) } } @@ -1338,7 +1343,7 @@ impl Drop for Endpoint { // avoid panic on panicking if let Ok(mut network) = self.net.network.lock() { - network.close(self.node, self.addr); + network.close(self.proto, self.node, self.addr); } } } @@ -1372,7 +1377,7 @@ mod tests { let barrier_ = barrier.clone(); node1.spawn(async move { - let net = Endpoint::bind(addr1).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); barrier_.wait().await; net.send_to(addr2, 1, payload!(vec![1])).await.unwrap(); @@ -1382,7 +1387,7 @@ mod tests { }); let f = node2.spawn(async move { - let net = Endpoint::bind(addr2).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap(); barrier.wait().await; let mut buf = vec![0; 0x10]; @@ -1411,14 +1416,14 @@ mod tests { let barrier_ = barrier.clone(); node1.spawn(async move { - let net = Endpoint::bind(addr1).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); barrier_.wait().await; net.send_to(addr2, 1, payload!(vec![1])).await.unwrap(); }); let f = node2.spawn(async move { - let net = Endpoint::bind(addr2).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap(); let mut buf = vec![0; 0x10]; timeout(Duration::from_secs(1), net.recv_from(1, &mut buf)) .await @@ -1443,7 +1448,7 @@ mod tests { let node1 = runtime.create_node().ip(addr1.ip()).build(); let f = node1.spawn(async move { - let net = Endpoint::bind(addr1).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); let err = net.recv_from(1, &mut []).await.unwrap_err(); assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe); // FIXME: should still error @@ -1466,36 +1471,47 @@ mod tests { let f = node.spawn(async move { // unspecified - let ep = Endpoint::bind("0.0.0.0:0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "0.0.0.0:0") + .await + .unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip(), ip); assert_ne!(addr.port(), 0); // unspecified v6 - let ep = Endpoint::bind(":::0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, ":::0").await.unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip(), ip); assert_ne!(addr.port(), 0); // localhost - let ep = Endpoint::bind("127.0.0.1:0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:0") + .await + .unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip().to_string(), "127.0.0.1"); assert_ne!(addr.port(), 0); // localhost v6 - let ep = Endpoint::bind("::1:0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "::1:0").await.unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip().to_string(), "::1"); assert_ne!(addr.port(), 0); // wrong IP - let err = Endpoint::bind("10.0.0.2:0").await.err().unwrap(); + let err = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.2:0") + .await + .err() + .unwrap(); assert_eq!(err.kind(), std::io::ErrorKind::AddrNotAvailable); // drop and reuse port - let _ = Endpoint::bind("10.0.0.1:100").await.unwrap(); - let _ = Endpoint::bind("10.0.0.1:100").await.unwrap(); + let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100") + .await + .unwrap(); + let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100") + .await + .unwrap(); }); runtime.block_on(f).unwrap(); } @@ -1512,8 +1528,12 @@ mod tests { let barrier_ = barrier.clone(); let f1 = node1.spawn(async move { - let ep1 = Endpoint::bind("127.0.0.1:1").await.unwrap(); - let ep2 = Endpoint::bind("10.0.0.1:2").await.unwrap(); + let ep1 = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1") + .await + .unwrap(); + let ep2 = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:2") + .await + .unwrap(); barrier_.wait().await; // FIXME: ep1 should not receive messages from other node @@ -1525,7 +1545,9 @@ mod tests { ep2.recv_from(1, &mut []).await.unwrap(); }); let f2 = node2.spawn(async move { - let ep = Endpoint::bind("127.0.0.1:1").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1") + .await + .unwrap(); barrier.wait().await; ep.send_to("10.0.0.1:1", 1, payload!(vec![1])) @@ -1550,7 +1572,7 @@ mod tests { let barrier_ = barrier.clone(); node1.spawn(async move { - let ep = Endpoint::bind(addr1).await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); assert_eq!(ep.local_addr().unwrap(), addr1); barrier_.wait().await; @@ -1565,7 +1587,7 @@ mod tests { let f = node2.spawn(async move { barrier.wait().await; - let ep = Endpoint::connect(addr1).await.unwrap(); + let ep = Endpoint::connect(libc::SOCK_STREAM, addr1).await.unwrap(); assert_eq!(ep.peer_addr().unwrap(), addr1); ep.send(1, payload!(b"ping".to_vec())).await.unwrap(); diff --git a/msim/src/sim/net/network.rs b/msim/src/sim/net/network.rs index ca0f68b..3a3ac02 100644 --- a/msim/src/sim/net/network.rs +++ b/msim/src/sim/net/network.rs @@ -33,7 +33,7 @@ struct Node { /// NOTE: now a node can have at most one IP address. ip: Option, /// Sockets in the node. - sockets: HashMap>>, + sockets: HashMap>>, /// live tcp connections. live_tcp_ids: HashSet, @@ -67,6 +67,17 @@ pub struct Stat { pub msg_count: u64, } +#[derive(Debug, Hash, Eq, PartialEq)] +struct SocketKey(u16, libc::c_int); + +fn proto_str(proto: libc::c_int) -> &'static str { + match proto { + libc::SOCK_STREAM => "tcp", + libc::SOCK_DGRAM => "udp", + _ => panic!("unsupported socket type {}", proto), + } +} + impl Network { pub fn new(rand: GlobalRng, time: TimeHandle, config: NetworkConfig) -> Self { Self { @@ -179,8 +190,13 @@ impl Network { self.clogged_link.remove(&(src, dst)); } - pub fn bind(&mut self, node_id: NodeId, mut addr: SocketAddr) -> io::Result { - debug!("binding: {addr} -> {node_id}"); + pub fn bind( + &mut self, + node_id: NodeId, + proto: libc::c_int, + mut addr: SocketAddr, + ) -> io::Result { + debug!("binding ({}): {addr} -> {node_id}", proto_str(proto)); let node = self.nodes.get_mut(&node_id).expect("node not found"); // resolve IP if unspecified if addr.ip().is_unspecified() { @@ -200,7 +216,7 @@ impl Network { if addr.port() == 0 { let next_ephemeral_port = node.next_ephemeral_port; let port = (next_ephemeral_port..=u16::MAX) - .find(|port| !node.sockets.contains_key(port)) + .find(|port| !node.sockets.contains_key(&SocketKey(*port, proto))) .ok_or_else(|| { warn!("ephemeral ports exhausted"); io::Error::new(io::ErrorKind::AddrInUse, "no available ephemeral port") @@ -210,7 +226,7 @@ impl Network { addr.set_port(port); } // insert socket - match node.sockets.entry(addr.port()) { + match node.sockets.entry(SocketKey(addr.port(), proto)) { Entry::Occupied(_) => { warn!("bind() error: address already in use: {addr:?}"); return Err(io::Error::new( @@ -239,7 +255,13 @@ impl Network { ); } - pub fn deregister_tcp_id(&mut self, node: NodeId, remote_addr: &SocketAddr, tcp_id: u32) { + pub fn deregister_tcp_id( + &mut self, + node: NodeId, + proto: libc::c_int, + remote_addr: &SocketAddr, + tcp_id: u32, + ) { trace!("deregistering tcp id {} for node {}", tcp_id, node); // node may have been deleted @@ -262,7 +284,7 @@ impl Network { if let Some(socket) = self .nodes .get_mut(node_id) - .map(|node| node.sockets.get(&remote_addr.port())) + .map(|node| node.sockets.get(&SocketKey(remote_addr.port(), proto))) .tap_none(|| debug!("No node found for {node_id}")) .flatten() { @@ -280,14 +302,14 @@ impl Network { } } - pub fn signal_connect(&self, src: SocketAddr, dst: SocketAddr) -> bool { + pub fn signal_connect(&self, proto: libc::c_int, src: SocketAddr, dst: SocketAddr) -> bool { let node = self.get_node_for_addr(&dst.ip()); if node.is_none() { return false; } let node = node.unwrap(); - let dst_socket = self.nodes[&node].sockets.get(&dst.port()); + let dst_socket = self.nodes[&node].sockets.get(&SocketKey(dst.port(), proto)); if let Some(dst_socket) = dst_socket { dst_socket.lock().unwrap().signal_connect(src); @@ -297,22 +319,31 @@ impl Network { } } - pub fn accept_connect(&self, node: NodeId, listening: SocketAddr) -> Option { - let socket = self.nodes[&node].sockets.get(&listening.port()).unwrap(); + pub fn accept_connect( + &self, + proto: libc::c_int, + node: NodeId, + listening: SocketAddr, + ) -> Option { + let socket = self.nodes[&node] + .sockets + .get(&SocketKey(listening.port(), proto)) + .unwrap(); socket.lock().unwrap().accept_connect() } - pub fn close(&mut self, node_id: NodeId, addr: SocketAddr) { + pub fn close(&mut self, proto: libc::c_int, node_id: NodeId, addr: SocketAddr) { if let Some(node) = self.nodes.get_mut(&node_id) { debug!("close: {node_id} {addr}"); // TODO: simulate TIME_WAIT? - node.sockets.remove(&addr.port()); + node.sockets.remove(&SocketKey(addr.port(), proto)); } } pub fn send( &mut self, node_id: NodeId, + proto: libc::c_int, src: SocketAddr, dst: SocketAddr, tag: u64, @@ -383,7 +414,7 @@ impl Network { } } - let mailbox = match node.sockets.get(&dst.port()) { + let mailbox = match node.sockets.get(&SocketKey(dst.port(), proto)) { Some(mailbox) => Arc::downgrade(mailbox), None => { debug!("destination port not available: {dst}"); @@ -420,15 +451,27 @@ impl Network { Ok(()) } - pub fn recv(&mut self, node: NodeId, dst: SocketAddr, tag: u64) -> oneshot::Receiver { - self.nodes[&node].sockets[&dst.port()] + pub fn recv( + &mut self, + node: NodeId, + proto: libc::c_int, + dst: SocketAddr, + tag: u64, + ) -> oneshot::Receiver { + self.nodes[&node].sockets[&SocketKey(dst.port(), proto)] .lock() .unwrap() .recv(tag) } - pub fn recv_sync(&mut self, node: NodeId, dst: SocketAddr, tag: u64) -> Option { - self.nodes[&node].sockets[&dst.port()] + pub fn recv_sync( + &mut self, + node: NodeId, + proto: libc::c_int, + dst: SocketAddr, + tag: u64, + ) -> Option { + self.nodes[&node].sockets[&SocketKey(dst.port(), proto)] .lock() .unwrap() .recv_sync(tag) @@ -438,10 +481,11 @@ impl Network { &self, cx: Option<&mut Context<'_>>, node: NodeId, + proto: libc::c_int, dst: SocketAddr, tag: u64, ) -> bool { - self.nodes[&node].sockets[&dst.port()] + self.nodes[&node].sockets[&SocketKey(dst.port(), proto)] .lock() .unwrap() .recv_ready(cx, tag)