diff --git a/msim-tokio/src/sim/net.rs b/msim-tokio/src/sim/net.rs index 4022744..f28c727 100644 --- a/msim-tokio/src/sim/net.rs +++ b/msim-tokio/src/sim/net.rs @@ -3,7 +3,7 @@ use tracing::{debug, trace}; use std::{ future::Future, io, - net::SocketAddr, + net::SocketAddr as StdSocketAddr, os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, pin::Pin, sync::{ @@ -33,7 +33,7 @@ use crate::poller::Poller; pub struct TcpListener { fd: OwnedFd, ep: Arc, - poller: Poller>, + poller: Poller>, } impl std::fmt::Debug for TcpListener { @@ -68,22 +68,25 @@ impl TcpListener { })) } - async fn bind_addr(addr: SocketAddr) -> io::Result { + async fn bind_addr(addr: StdSocketAddr) -> io::Result { let tcp_sock = std::net::TcpListener::bind(addr)?; Self::from_std(tcp_sock) } /// poll_accept - pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_accept( + &self, + cx: &mut Context<'_>, + ) -> Poll> { self.poller .poll_with_fut(cx, || Self::poll_accept_internal(self.ep.clone())) } - pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { + pub async fn accept(&self) -> io::Result<(TcpStream, StdSocketAddr)> { Self::poll_accept_internal(self.ep.clone()).await } - async fn poll_accept_internal(ep: Arc) -> io::Result<(TcpStream, SocketAddr)> { + async fn poll_accept_internal(ep: Arc) -> io::Result<(TcpStream, StdSocketAddr)> { let (msg, from) = ep.recv_from_raw(0).await?; let remote_tcp_id = Message::new(msg).unwrap_tcp_id(); @@ -141,7 +144,7 @@ impl TcpListener { unsafe { Ok(std::net::TcpListener::from_raw_fd(fd.release())) } } - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { self.ep.local_addr() } @@ -262,7 +265,7 @@ struct TcpState { recv_seq: AtomicU32, local_tcp_id: u32, remote_tcp_id: u32, - remote_sock: SocketAddr, + remote_sock: StdSocketAddr, // not simulated, only present to return the correct value with getters/settters. nodelay: AtomicBool, @@ -273,7 +276,7 @@ impl TcpState { ep: Arc, local_tcp_id: u32, remote_tcp_id: u32, - remote_sock: SocketAddr, + remote_sock: StdSocketAddr, ) -> Self { Self { ep, @@ -393,16 +396,16 @@ impl TcpStream { Ok(Self::new(state)) } - pub fn peer_addr(&self) -> io::Result { + pub fn peer_addr(&self) -> io::Result { self.state.ep.peer_addr() } - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { self.state.ep.local_addr() } - pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { - split_owned(self) + pub fn into_split(self) -> (tcp::OwnedReadHalf, tcp::OwnedWriteHalf) { + tcp::split_owned(self) } fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { @@ -579,62 +582,66 @@ impl AsyncWrite for TcpStream { } } -pub struct OwnedWriteHalf { - inner: Arc, - // TODO: support this - _shutdown_on_drop: bool, -} +pub mod tcp { + use super::{io, Arc, AsyncRead, AsyncWrite, Context, Pin, Poll, ReadBuf, TcpStream}; -impl AsyncWrite for OwnedWriteHalf { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.inner.poll_write_priv(cx, buf) + pub struct OwnedWriteHalf { + pub(super) inner: Arc, + // TODO: support this + _shutdown_on_drop: bool, } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_flush_priv(cx) - } + impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.inner.poll_write_priv(cx, buf) + } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_shutdown_priv(cx) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush_priv(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_shutdown_priv(cx) + } } -} -pub struct OwnedReadHalf { - inner: Arc, -} + pub struct OwnedReadHalf { + pub(super) inner: Arc, + } -impl OwnedReadHalf { - pub async fn peek(&self, buf: &mut [u8]) -> io::Result { - self.inner.peek(buf).await + impl OwnedReadHalf { + pub async fn peek(&self, buf: &mut [u8]) -> io::Result { + self.inner.peek(buf).await + } } -} -impl AsyncRead for OwnedReadHalf { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - read: &mut ReadBuf<'_>, - ) -> Poll> { - self.inner - .poll_read_priv(false, cx, read) - .map(|r| r.map(|_| ())) + impl AsyncRead for OwnedReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + read: &mut ReadBuf<'_>, + ) -> Poll> { + self.inner + .poll_read_priv(false, cx, read) + .map(|r| r.map(|_| ())) + } } -} -fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) { - let arc = Arc::new(stream); - let read = OwnedReadHalf { - inner: Arc::clone(&arc), - }; - let write = OwnedWriteHalf { - inner: arc, - _shutdown_on_drop: true, - }; - (read, write) + pub(super) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) { + let arc = Arc::new(stream); + let read = OwnedReadHalf { + inner: Arc::clone(&arc), + }; + let write = OwnedWriteHalf { + inner: arc, + _shutdown_on_drop: true, + }; + (read, write) + } } pub struct TcpSocket { @@ -693,7 +700,7 @@ impl TcpSocket { todo!() } - pub fn local_addr(&self) -> io::Result { + pub fn local_addr(&self) -> io::Result { self.bind_addr .lock() .unwrap() @@ -706,13 +713,13 @@ impl TcpSocket { todo!() } - pub fn bind(&self, addr: SocketAddr) -> io::Result<()> { + pub fn bind(&self, addr: StdSocketAddr) -> io::Result<()> { let ep = Endpoint::bind_sync(addr)?; *self.bind_addr.lock().unwrap() = Some(ep.into()); Ok(()) } - pub async fn connect(self, addr: SocketAddr) -> io::Result { + pub async fn connect(self, addr: StdSocketAddr) -> io::Result { TcpStream::connect(addr).await } @@ -770,7 +777,7 @@ impl IntoRawFd for TcpStream { } } -pub async fn lookup_host(host: T) -> io::Result> +pub async fn lookup_host(host: T) -> io::Result> where T: ToSocketAddrs, { @@ -782,7 +789,10 @@ where #[cfg(test)] mod tests { - use super::{OwnedReadHalf, OwnedWriteHalf, TcpListener, TcpStream}; + use super::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpListener, TcpStream, + }; use bytes::{BufMut, BytesMut}; use futures::join; use msim::{