Skip to content

Commit

Permalink
Separate ports by protocol (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
mystenmark authored May 17, 2024
1 parent 6f88ec8 commit 077b735
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 73 deletions.
1 change: 1 addition & 0 deletions msim-tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ real_tokio = { git = "https://github.com/mystenmark/tokio-madsim-fork.git", rev
bytes = { version = "1.1" }
futures = { version = "0.3.0", features = ["async-await"] }
mio = { version = "0.8.1" }
libc = "0.2"
4 changes: 2 additions & 2 deletions msim-tokio/src/sim/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ impl TcpStream {
}

async fn connect_addr(addr: impl ToSocketAddrs) -> io::Result<TcpStream> {
let ep = Arc::new(Endpoint::connect(addr).await?);
let ep = Arc::new(Endpoint::connect(libc::SOCK_STREAM, addr).await?);
trace!("connect {:?}", ep.local_addr());

let remote_sock = ep.peer_addr()?;
Expand Down Expand Up @@ -714,7 +714,7 @@ impl TcpSocket {
}

pub fn bind(&self, addr: StdSocketAddr) -> io::Result<()> {
let ep = Endpoint::bind_sync(addr)?;
let ep = Endpoint::bind_sync(libc::SOCK_STREAM, addr)?;
*self.bind_addr.lock().unwrap() = Some(ep.into());
Ok(())
}
Expand Down
126 changes: 74 additions & 52 deletions msim/src/sim/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ unsafe fn accept_impl(
) -> libc::c_int {
let result = HostNetworkState::with_socket(
sock_fd,
|socket| -> Result<SocketAddr, (libc::c_int, libc::c_int)> {
|socket| -> Result<(SocketAddr, libc::c_int), (libc::c_int, libc::c_int)> {
let node = plugin::node();
let net = plugin::simulator::<NetSim>();
let network = net.network.lock().unwrap();
Expand All @@ -343,7 +343,8 @@ unsafe fn accept_impl(
// We can't simulate blocking accept in a single-threaded simulator, so if there is no
// connection waiting for us, just bail.
network
.accept_connect(node, endpoint.addr)
.accept_connect(socket.ty, node, endpoint.addr)
.map(|addr| (addr, socket.ty))
.ok_or((-1, libc::ECONNABORTED))
},
)
Expand All @@ -352,18 +353,18 @@ unsafe fn accept_impl(
Result::Err((-1, libc::ENOTSOCK))
});

let remote_addr = match result {
let (remote_addr, proto) = match result {
Err((ret, err)) => {
trace!("error status: {} {}", ret, err);
set_errno(err);
return ret;
}
Ok(addr) => addr,
Ok(res) => res,
};

write_socket_addr(address, address_len, remote_addr);

let endpoint = Endpoint::connect_sync(remote_addr)
let endpoint = Endpoint::connect_sync(proto, remote_addr)
.expect("connection failure should already have been detected");

let fd = alloc_fd();
Expand Down Expand Up @@ -396,7 +397,7 @@ define_sys_interceptor!(

HostNetworkState::with_socket(sock_fd, |socket| {
assert!(socket.endpoint.is_none(), "socket already bound");
match Endpoint::bind_sync(socket_addr) {
match Endpoint::bind_sync(socket.ty, socket_addr) {
Ok(ep) => {
socket.endpoint = Some(Arc::new(ep));
0
Expand Down Expand Up @@ -438,7 +439,7 @@ define_sys_interceptor!(
return Err((-1, libc::EISCONN));
}

let ep = Endpoint::connect_sync(sock_addr).map_err(|e| match e.kind() {
let ep = Endpoint::connect_sync(socket.ty, sock_addr).map_err(|e| match e.kind() {
io::ErrorKind::AddrInUse => (-1, libc::EADDRINUSE),
io::ErrorKind::AddrNotAvailable => (-1, libc::EADDRNOTAVAIL),
_ => {
Expand All @@ -453,7 +454,7 @@ define_sys_interceptor!(
// the other end goes away).
let net = plugin::simulator::<NetSim>();
let network = net.network.lock().unwrap();
if !network.signal_connect(ep.addr, sock_addr) {
if !network.signal_connect(socket.ty, ep.addr, sock_addr) {
return Err((-1, libc::ECONNREFUSED));
}

Expand Down Expand Up @@ -544,8 +545,7 @@ define_sys_interceptor!(
match (level, name) {
// called by anemo::Network::start (via socket2)
// skip returning any value here since Sui only uses it to log an error anyway
(libc::SOL_SOCKET, libc::SO_RCVBUF) |
(libc::SOL_SOCKET, libc::SO_SNDBUF) => 0,
(libc::SOL_SOCKET, libc::SO_RCVBUF) | (libc::SOL_SOCKET, libc::SO_SNDBUF) => 0,

_ => {
warn!("unhandled getsockopt {} {}", level, name);
Expand Down Expand Up @@ -1015,6 +1015,7 @@ pub struct Endpoint {
net: Arc<NetSim>,
node: NodeId,
addr: SocketAddr,
proto: libc::c_int,
peer: Option<SocketAddr>,
live_tcp_ids: Mutex<HashSet<u32>>,
}
Expand All @@ -1030,16 +1031,17 @@ impl std::fmt::Debug for Endpoint {
}

impl Endpoint {
/// Bind synchronously (for UDP)
pub fn bind_sync(addr: impl ToSocketAddrs) -> io::Result<Self> {
/// Bind synchronously
pub fn bind_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
let net = plugin::simulator::<NetSim>();
let node = plugin::node();
let addr = addr.to_socket_addrs()?.next().unwrap();
let addr = net.network.lock().unwrap().bind(node, addr)?;
let addr = net.network.lock().unwrap().bind(node, proto, addr)?;
let ep = Endpoint {
net,
node,
addr,
proto,
peer: None,
live_tcp_ids: Default::default(),
};
Expand All @@ -1063,30 +1065,31 @@ impl Endpoint {
}

/// Creates a [`Endpoint`] from the given address.
pub async fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
pub async fn bind(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
let net = plugin::simulator::<NetSim>();
let node = plugin::node();
let addr = addr.to_socket_addrs()?.next().unwrap();
net.rand_delay().await;
let addr = net.network.lock().unwrap().bind(node, addr)?;
let addr = net.network.lock().unwrap().bind(node, proto, addr)?;
Ok(Endpoint {
net,
node,
addr,
proto,
peer: None,
live_tcp_ids: Default::default(),
})
}

/// Connects this [`Endpoint`] to a remote address.
pub async fn connect(addr: impl ToSocketAddrs) -> io::Result<Self> {
pub async fn connect(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
let net = plugin::simulator::<NetSim>();
net.rand_delay().await;
Self::connect_sync(addr)
Self::connect_sync(proto, addr)
}

/// For libc::connect()
pub fn connect_sync(addr: impl ToSocketAddrs) -> io::Result<Self> {
pub fn connect_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
let net = plugin::simulator::<NetSim>();
let node = plugin::node();
let peer = addr.to_socket_addrs()?.next().unwrap();
Expand All @@ -1095,11 +1098,12 @@ impl Endpoint {
} else {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
};
let addr = net.network.lock().unwrap().bind(node, addr)?;
let addr = net.network.lock().unwrap().bind(node, proto, addr)?;
Ok(Endpoint {
net,
node,
addr,
proto,
peer: Some(peer),
live_tcp_ids: Default::default(),
})
Expand Down Expand Up @@ -1128,7 +1132,7 @@ impl Endpoint {
.network
.lock()
.unwrap()
.deregister_tcp_id(self.node, remote_sock, id);
.deregister_tcp_id(self.node, self.proto, remote_sock, id);
}

/// Returns the local socket address.
Expand Down Expand Up @@ -1234,7 +1238,7 @@ impl Endpoint {
.network
.lock()
.unwrap()
.send(plugin::node(), self.addr, dst, tag, data)
.send(plugin::node(), self.proto, self.addr, dst, tag, data)
}

/// Receives a raw message.
Expand All @@ -1244,12 +1248,12 @@ impl Endpoint {
#[cfg_attr(docsrs, doc(cfg(msim)))]
pub async fn recv_from_raw(&self, tag: u64) -> io::Result<(Payload, SocketAddr)> {
trace!("awaiting recv: {} tag={:x}", self.addr, tag);
let recver = self
.net
.network
.lock()
.unwrap()
.recv(plugin::node(), self.addr, tag);
let recver =
self.net
.network
.lock()
.unwrap()
.recv(plugin::node(), self.proto, self.addr, tag);
let msg = recver
.await
.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "network is down"))?;
Expand All @@ -1266,7 +1270,7 @@ impl Endpoint {
.network
.lock()
.unwrap()
.recv_sync(plugin::node(), self.addr, tag)
.recv_sync(plugin::node(), self.proto, self.addr, tag)
.ok_or_else(|| io::Error::new(io::ErrorKind::WouldBlock, "recv call would blck"))?;

trace!(
Expand Down Expand Up @@ -1320,12 +1324,13 @@ impl Endpoint {
/// Check if there is a message waiting that can be received without blocking.
/// If not, schedule a wakeup using the context.
pub fn recv_ready(&self, cx: Option<&mut Context<'_>>, tag: u64) -> io::Result<bool> {
Ok(self
.net
.network
.lock()
.unwrap()
.recv_ready(cx, plugin::node(), self.addr, tag))
Ok(self.net.network.lock().unwrap().recv_ready(
cx,
plugin::node(),
self.proto,
self.addr,
tag,
))
}
}

Expand All @@ -1338,7 +1343,7 @@ impl Drop for Endpoint {

// avoid panic on panicking
if let Ok(mut network) = self.net.network.lock() {
network.close(self.node, self.addr);
network.close(self.proto, self.node, self.addr);
}
}
}
Expand Down Expand Up @@ -1372,7 +1377,7 @@ mod tests {

let barrier_ = barrier.clone();
node1.spawn(async move {
let net = Endpoint::bind(addr1).await.unwrap();
let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
barrier_.wait().await;

net.send_to(addr2, 1, payload!(vec![1])).await.unwrap();
Expand All @@ -1382,7 +1387,7 @@ mod tests {
});

let f = node2.spawn(async move {
let net = Endpoint::bind(addr2).await.unwrap();
let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap();
barrier.wait().await;

let mut buf = vec![0; 0x10];
Expand Down Expand Up @@ -1411,14 +1416,14 @@ mod tests {

let barrier_ = barrier.clone();
node1.spawn(async move {
let net = Endpoint::bind(addr1).await.unwrap();
let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
barrier_.wait().await;

net.send_to(addr2, 1, payload!(vec![1])).await.unwrap();
});

let f = node2.spawn(async move {
let net = Endpoint::bind(addr2).await.unwrap();
let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap();
let mut buf = vec![0; 0x10];
timeout(Duration::from_secs(1), net.recv_from(1, &mut buf))
.await
Expand All @@ -1443,7 +1448,7 @@ mod tests {
let node1 = runtime.create_node().ip(addr1.ip()).build();

let f = node1.spawn(async move {
let net = Endpoint::bind(addr1).await.unwrap();
let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
let err = net.recv_from(1, &mut []).await.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
// FIXME: should still error
Expand All @@ -1466,36 +1471,47 @@ mod tests {

let f = node.spawn(async move {
// unspecified
let ep = Endpoint::bind("0.0.0.0:0").await.unwrap();
let ep = Endpoint::bind(libc::SOCK_STREAM, "0.0.0.0:0")
.await
.unwrap();
let addr = ep.local_addr().unwrap();
assert_eq!(addr.ip(), ip);
assert_ne!(addr.port(), 0);

// unspecified v6
let ep = Endpoint::bind(":::0").await.unwrap();
let ep = Endpoint::bind(libc::SOCK_STREAM, ":::0").await.unwrap();
let addr = ep.local_addr().unwrap();
assert_eq!(addr.ip(), ip);
assert_ne!(addr.port(), 0);

// localhost
let ep = Endpoint::bind("127.0.0.1:0").await.unwrap();
let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:0")
.await
.unwrap();
let addr = ep.local_addr().unwrap();
assert_eq!(addr.ip().to_string(), "127.0.0.1");
assert_ne!(addr.port(), 0);

// localhost v6
let ep = Endpoint::bind("::1:0").await.unwrap();
let ep = Endpoint::bind(libc::SOCK_STREAM, "::1:0").await.unwrap();
let addr = ep.local_addr().unwrap();
assert_eq!(addr.ip().to_string(), "::1");
assert_ne!(addr.port(), 0);

// wrong IP
let err = Endpoint::bind("10.0.0.2:0").await.err().unwrap();
let err = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.2:0")
.await
.err()
.unwrap();
assert_eq!(err.kind(), std::io::ErrorKind::AddrNotAvailable);

// drop and reuse port
let _ = Endpoint::bind("10.0.0.1:100").await.unwrap();
let _ = Endpoint::bind("10.0.0.1:100").await.unwrap();
let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100")
.await
.unwrap();
let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100")
.await
.unwrap();
});
runtime.block_on(f).unwrap();
}
Expand All @@ -1512,8 +1528,12 @@ mod tests {

let barrier_ = barrier.clone();
let f1 = node1.spawn(async move {
let ep1 = Endpoint::bind("127.0.0.1:1").await.unwrap();
let ep2 = Endpoint::bind("10.0.0.1:2").await.unwrap();
let ep1 = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1")
.await
.unwrap();
let ep2 = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:2")
.await
.unwrap();
barrier_.wait().await;

// FIXME: ep1 should not receive messages from other node
Expand All @@ -1525,7 +1545,9 @@ mod tests {
ep2.recv_from(1, &mut []).await.unwrap();
});
let f2 = node2.spawn(async move {
let ep = Endpoint::bind("127.0.0.1:1").await.unwrap();
let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1")
.await
.unwrap();
barrier.wait().await;

ep.send_to("10.0.0.1:1", 1, payload!(vec![1]))
Expand All @@ -1550,7 +1572,7 @@ mod tests {

let barrier_ = barrier.clone();
node1.spawn(async move {
let ep = Endpoint::bind(addr1).await.unwrap();
let ep = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
assert_eq!(ep.local_addr().unwrap(), addr1);
barrier_.wait().await;

Expand All @@ -1565,7 +1587,7 @@ mod tests {

let f = node2.spawn(async move {
barrier.wait().await;
let ep = Endpoint::connect(addr1).await.unwrap();
let ep = Endpoint::connect(libc::SOCK_STREAM, addr1).await.unwrap();
assert_eq!(ep.peer_addr().unwrap(), addr1);

ep.send(1, payload!(b"ping".to_vec())).await.unwrap();
Expand Down
Loading

0 comments on commit 077b735

Please sign in to comment.