Skip to content

Commit

Permalink
Move Owned{Read,Write}Half to the correct place in the tokio module s…
Browse files Browse the repository at this point in the history
…tructure
  • Loading branch information
mystenmark committed Nov 1, 2023
1 parent 03b22b1 commit 1bdccd9
Showing 1 changed file with 73 additions and 63 deletions.
136 changes: 73 additions & 63 deletions msim-tokio/src/sim/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use tracing::{debug, trace};
use std::{
future::Future,
io,
net::SocketAddr,
net::SocketAddr as StdSocketAddr,
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
pin::Pin,
sync::{
Expand Down Expand Up @@ -33,7 +33,7 @@ use crate::poller::Poller;
pub struct TcpListener {
fd: OwnedFd,
ep: Arc<Endpoint>,
poller: Poller<io::Result<(TcpStream, SocketAddr)>>,
poller: Poller<io::Result<(TcpStream, StdSocketAddr)>>,
}

impl std::fmt::Debug for TcpListener {
Expand Down Expand Up @@ -68,22 +68,25 @@ impl TcpListener {
}))
}

async fn bind_addr(addr: SocketAddr) -> io::Result<Self> {
async fn bind_addr(addr: StdSocketAddr) -> io::Result<Self> {
let tcp_sock = std::net::TcpListener::bind(addr)?;
Self::from_std(tcp_sock)
}

/// poll_accept
pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
pub fn poll_accept(
&self,
cx: &mut Context<'_>,
) -> Poll<io::Result<(TcpStream, StdSocketAddr)>> {
self.poller
.poll_with_fut(cx, || Self::poll_accept_internal(self.ep.clone()))
}

pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
pub async fn accept(&self) -> io::Result<(TcpStream, StdSocketAddr)> {
Self::poll_accept_internal(self.ep.clone()).await
}

async fn poll_accept_internal(ep: Arc<Endpoint>) -> io::Result<(TcpStream, SocketAddr)> {
async fn poll_accept_internal(ep: Arc<Endpoint>) -> io::Result<(TcpStream, StdSocketAddr)> {
let (msg, from) = ep.recv_from_raw(0).await?;

let remote_tcp_id = Message::new(msg).unwrap_tcp_id();
Expand Down Expand Up @@ -141,7 +144,7 @@ impl TcpListener {
unsafe { Ok(std::net::TcpListener::from_raw_fd(fd.release())) }
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> io::Result<StdSocketAddr> {
self.ep.local_addr()
}

Expand Down Expand Up @@ -262,7 +265,7 @@ struct TcpState {
recv_seq: AtomicU32,
local_tcp_id: u32,
remote_tcp_id: u32,
remote_sock: SocketAddr,
remote_sock: StdSocketAddr,

// not simulated, only present to return the correct value with getters/settters.
nodelay: AtomicBool,
Expand All @@ -273,7 +276,7 @@ impl TcpState {
ep: Arc<Endpoint>,
local_tcp_id: u32,
remote_tcp_id: u32,
remote_sock: SocketAddr,
remote_sock: StdSocketAddr,
) -> Self {
Self {
ep,
Expand Down Expand Up @@ -393,16 +396,16 @@ impl TcpStream {
Ok(Self::new(state))
}

pub fn peer_addr(&self) -> io::Result<SocketAddr> {
pub fn peer_addr(&self) -> io::Result<StdSocketAddr> {
self.state.ep.peer_addr()
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> io::Result<StdSocketAddr> {
self.state.ep.local_addr()
}

pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
split_owned(self)
pub fn into_split(self) -> (tcp::OwnedReadHalf, tcp::OwnedWriteHalf) {
tcp::split_owned(self)
}

fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
Expand Down Expand Up @@ -579,62 +582,66 @@ impl AsyncWrite for TcpStream {
}
}

pub struct OwnedWriteHalf {
inner: Arc<TcpStream>,
// TODO: support this
_shutdown_on_drop: bool,
}
pub mod tcp {
use super::{io, Arc, AsyncRead, AsyncWrite, Context, Pin, Poll, ReadBuf, TcpStream};

impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.inner.poll_write_priv(cx, buf)
pub struct OwnedWriteHalf {
pub(super) inner: Arc<TcpStream>,
// TODO: support this
_shutdown_on_drop: bool,
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner.poll_flush_priv(cx)
}
impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.inner.poll_write_priv(cx, buf)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner.poll_shutdown_priv(cx)
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner.poll_flush_priv(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.inner.poll_shutdown_priv(cx)
}
}
}

pub struct OwnedReadHalf {
inner: Arc<TcpStream>,
}
pub struct OwnedReadHalf {
pub(super) inner: Arc<TcpStream>,
}

impl OwnedReadHalf {
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf).await
impl OwnedReadHalf {
pub async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf).await
}
}
}

impl AsyncRead for OwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
read: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.inner
.poll_read_priv(false, cx, read)
.map(|r| r.map(|_| ()))
impl AsyncRead for OwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
read: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.inner
.poll_read_priv(false, cx, read)
.map(|r| r.map(|_| ()))
}
}
}

fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
let arc = Arc::new(stream);
let read = OwnedReadHalf {
inner: Arc::clone(&arc),
};
let write = OwnedWriteHalf {
inner: arc,
_shutdown_on_drop: true,
};
(read, write)
pub(super) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
let arc = Arc::new(stream);
let read = OwnedReadHalf {
inner: Arc::clone(&arc),
};
let write = OwnedWriteHalf {
inner: arc,
_shutdown_on_drop: true,
};
(read, write)
}
}

pub struct TcpSocket {
Expand Down Expand Up @@ -693,7 +700,7 @@ impl TcpSocket {
todo!()
}

pub fn local_addr(&self) -> io::Result<SocketAddr> {
pub fn local_addr(&self) -> io::Result<StdSocketAddr> {
self.bind_addr
.lock()
.unwrap()
Expand All @@ -706,13 +713,13 @@ impl TcpSocket {
todo!()
}

pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
pub fn bind(&self, addr: StdSocketAddr) -> io::Result<()> {
let ep = Endpoint::bind_sync(addr)?;
*self.bind_addr.lock().unwrap() = Some(ep.into());
Ok(())
}

pub async fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
pub async fn connect(self, addr: StdSocketAddr) -> io::Result<TcpStream> {
TcpStream::connect(addr).await
}

Expand Down Expand Up @@ -770,7 +777,7 @@ impl IntoRawFd for TcpStream {
}
}

pub async fn lookup_host<T>(host: T) -> io::Result<impl Iterator<Item = SocketAddr>>
pub async fn lookup_host<T>(host: T) -> io::Result<impl Iterator<Item = StdSocketAddr>>
where
T: ToSocketAddrs,
{
Expand All @@ -782,7 +789,10 @@ where
#[cfg(test)]
mod tests {

use super::{OwnedReadHalf, OwnedWriteHalf, TcpListener, TcpStream};
use super::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpListener, TcpStream,
};
use bytes::{BufMut, BytesMut};
use futures::join;
use msim::{
Expand Down

0 comments on commit 1bdccd9

Please sign in to comment.