From 4f0e506a10896881d779c3c70d3d773adb6ff26e Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Thu, 28 Mar 2024 16:44:08 +0800 Subject: [PATCH 1/3] fix automatically drop IpstackTcpStream --- src/lib.rs | 18 +++++- src/stream/mod.rs | 1 + src/stream/tcp.rs | 140 +++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 144 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3c76e04..684c36e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,7 +17,10 @@ use log::{error, trace}; use crate::{ packet::IpStackPacketProtocol, - stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, + stream::{ + IpStackStream, IpStackTcpStream, IpStackTcpStreamInner, IpStackUdpStream, + IpStackUnknownTransport, + }, }; mod error; mod packet; @@ -88,6 +91,17 @@ impl IpStack { { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); + let (drop_sender, mut drop_receiver) = mpsc::unbounded_channel::(); + tokio::spawn(async move { + while let Some(mut inner) = drop_receiver.recv().await { + tokio::spawn(async move { + if let Err(e) = inner.shutdown().await { + trace!("fail to drop {e:?}"); + } + }); + } + }); + tokio::spawn(async move { let mut streams: HashMap> = HashMap::new(); let mut buffer = [0u8; u16::MAX as usize]; @@ -116,7 +130,7 @@ impl IpStack { Vacant(entry) => { match packet.transport_protocol(){ IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ + match IpStackTcpStream::new(drop_sender.clone(),packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ Ok(stream) => { entry.insert(stream.stream_sender()); accept_sender.send(IpStackStream::Tcp(stream))?; diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 4c9bd68..04d7087 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,6 +1,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; pub use self::tcp::IpStackTcpStream; +pub(crate) use self::tcp::IpStackTcpStreamInner; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 8c396ad..fbf7a19 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -50,7 +50,7 @@ impl Shutdown { } #[derive(Debug)] -pub struct IpStackTcpStream { +pub(crate) struct IpStackTcpStreamInner { src_addr: SocketAddr, dst_addr: SocketAddr, stream_sender: UnboundedSender, @@ -63,7 +63,7 @@ pub struct IpStackTcpStream { write_notify: Option, } -impl IpStackTcpStream { +impl IpStackTcpStreamInner { pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, @@ -71,10 +71,10 @@ impl IpStackTcpStream { pkt_sender: UnboundedSender, mtu: u16, tcp_timeout: Duration, - ) -> Result { + ) -> Result { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - let stream = IpStackTcpStream { + let stream = IpStackTcpStreamInner { src_addr, dst_addr, stream_sender, @@ -191,16 +191,16 @@ impl IpStackTcpStream { }) } - pub fn local_addr(&self) -> SocketAddr { - self.src_addr - } + // pub fn local_addr(&self) -> SocketAddr { + // self.src_addr + // } - pub fn peer_addr(&self) -> SocketAddr { - self.dst_addr - } + // pub fn peer_addr(&self) -> SocketAddr { + // self.dst_addr + // } } -impl AsyncRead for IpStackTcpStream { +impl AsyncRead for IpStackTcpStreamInner { fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -439,7 +439,7 @@ impl AsyncRead for IpStackTcpStream { } } -impl AsyncWrite for IpStackTcpStream { +impl AsyncWrite for IpStackTcpStreamInner { fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -528,10 +528,124 @@ impl AsyncWrite for IpStackTcpStream { } } -impl Drop for IpStackTcpStream { +impl Drop for IpStackTcpStreamInner { fn drop(&mut self) { if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { _ = self.packet_sender.send(p); } } } + +#[derive(Debug)] +pub struct IpStackTcpStream { + inner: Option, + drop_sender: UnboundedSender, + src_addr: SocketAddr, + dst_addr: SocketAddr, + stream_sender: UnboundedSender, +} + +impl IpStackTcpStream { + pub(crate) fn new( + drop_sender: UnboundedSender, + src_addr: SocketAddr, + dst_addr: SocketAddr, + tcp: TcpPacket, + pkt_sender: UnboundedSender, + mtu: u16, + tcp_timeout: Duration, + ) -> Result { + let stream = + IpStackTcpStreamInner::new(src_addr, dst_addr, tcp, pkt_sender, mtu, tcp_timeout)?; + Ok(IpStackTcpStream { + stream_sender: stream.stream_sender(), + inner: Some(stream), + drop_sender, + src_addr, + dst_addr, + }) + } + pub(crate) fn stream_sender(&self) -> UnboundedSender { + self.stream_sender.clone() + } + + pub fn local_addr(&self) -> SocketAddr { + self.src_addr + } + + pub fn peer_addr(&self) -> SocketAddr { + self.dst_addr + } +} + +impl AsyncRead for IpStackTcpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_read(cx, buf) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } +} + +impl AsyncWrite for IpStackTcpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_write(cx, buf) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_flush(cx) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = &mut self.inner { + Pin::new(inner).poll_shutdown(cx) + } else { + Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "", + ))) + } + } +} + +impl Drop for IpStackTcpStream { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + if let Err(e) = self.drop_sender.send(inner) { + trace!("fail to send IpStackTcpStreamInner to drop {:?}", e); + } + } + } +} From 1ebfc2143a9e608384e3a77977b0e7521d7e2e6e Mon Sep 17 00:00:00 2001 From: xmh0511 <970252187@qq.com> Date: Thu, 28 Mar 2024 17:46:33 +0800 Subject: [PATCH 2/3] IpStackTcpStreamInner is so larger and place it to heap --- src/lib.rs | 3 ++- src/stream/tcp.rs | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 684c36e..1e94f00 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,7 +91,8 @@ impl IpStack { { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - let (drop_sender, mut drop_receiver) = mpsc::unbounded_channel::(); + let (drop_sender, mut drop_receiver) = + mpsc::unbounded_channel::>(); tokio::spawn(async move { while let Some(mut inner) = drop_receiver.recv().await { tokio::spawn(async move { diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index fbf7a19..0b4c757 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -539,7 +539,7 @@ impl Drop for IpStackTcpStreamInner { #[derive(Debug)] pub struct IpStackTcpStream { inner: Option, - drop_sender: UnboundedSender, + drop_sender: UnboundedSender>, src_addr: SocketAddr, dst_addr: SocketAddr, stream_sender: UnboundedSender, @@ -547,7 +547,7 @@ pub struct IpStackTcpStream { impl IpStackTcpStream { pub(crate) fn new( - drop_sender: UnboundedSender, + drop_sender: UnboundedSender>, src_addr: SocketAddr, dst_addr: SocketAddr, tcp: TcpPacket, @@ -643,7 +643,7 @@ impl AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Some(inner) = self.inner.take() { - if let Err(e) = self.drop_sender.send(inner) { + if let Err(e) = self.drop_sender.send(Box::new(inner)) { trace!("fail to send IpStackTcpStreamInner to drop {:?}", e); } } From cee8e19d832367948d0a0986313d8ead59795e27 Mon Sep 17 00:00:00 2001 From: SajjadPourali Date: Fri, 29 Mar 2024 20:29:31 -0400 Subject: [PATCH 3/3] Code refactor --- src/lib.rs | 24 +----- src/stream/mod.rs | 4 +- src/stream/tcb.rs | 8 +- src/stream/tcp.rs | 154 +++----------------------------------- src/stream/tcp_wrapper.rs | 121 ++++++++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 166 deletions(-) create mode 100644 src/stream/tcp_wrapper.rs diff --git a/src/lib.rs b/src/lib.rs index 1e94f00..355a44b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,10 +17,7 @@ use log::{error, trace}; use crate::{ packet::IpStackPacketProtocol, - stream::{ - IpStackStream, IpStackTcpStream, IpStackTcpStreamInner, IpStackUdpStream, - IpStackUnknownTransport, - }, + stream::{IpStackStream, IpStackTcpStream, IpStackUdpStream, IpStackUnknownTransport}, }; mod error; mod packet; @@ -91,25 +88,12 @@ impl IpStack { { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - let (drop_sender, mut drop_receiver) = - mpsc::unbounded_channel::>(); - tokio::spawn(async move { - while let Some(mut inner) = drop_receiver.recv().await { - tokio::spawn(async move { - if let Err(e) = inner.shutdown().await { - trace!("fail to drop {e:?}"); - } - }); - } - }); - tokio::spawn(async move { let mut streams: HashMap> = HashMap::new(); let mut buffer = [0u8; u16::MAX as usize]; let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); loop { - // dbg!(streams.len()); select! { Ok(n) = device.read(&mut buffer) => { let offset = if config.packet_information && cfg!(unix) {4} else {0}; @@ -124,14 +108,14 @@ impl IpStack { match streams.entry(packet.network_tuple()){ Occupied(entry) =>{ - if let Err(_x) = entry.get().send(packet){ - trace!("Send packet error \"{}\"", _x); + if let Err(e) = entry.get().send(packet){ + trace!("Send packet error \"{}\"", e); } } Vacant(entry) => { match packet.transport_protocol(){ IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new(drop_sender.clone(),packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ + match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ Ok(stream) => { entry.insert(stream.stream_sender()); accept_sender.send(IpStackStream::Tcp(stream))?; diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 04d7087..42632f4 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,12 +1,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -pub use self::tcp::IpStackTcpStream; -pub(crate) use self::tcp::IpStackTcpStreamInner; +pub use self::tcp_wrapper::IpStackTcpStream; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; mod tcb; mod tcp; +mod tcp_wrapper; mod udp; mod unknown; diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 0e7383b..4e2e281 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -176,6 +176,7 @@ impl Tcb { } pub(super) fn change_last_ack(&mut self, ack: u32) { let distance = ack.wrapping_sub(self.last_ack); + self.last_ack = self.last_ack.wrapping_add(distance); if matches!(self.state, TcpState::Established) { if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { @@ -187,9 +188,12 @@ impl Tcb { self.inflight_packets.push(inflight_packet); } } + self.inflight_packets.retain(|p| { + let last_byte = p.seq.wrapping_add(p.payload.len() as u32); + last_byte.saturating_sub(self.last_ack) > 0 + && self.seq.saturating_sub(last_byte) > 0 + }); } - - self.last_ack = self.last_ack.wrapping_add(distance); } pub fn is_send_buffer_full(&self) -> bool { self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 0b4c757..c884cf1 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -20,7 +20,7 @@ use std::{ }; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + sync::mpsc::{UnboundedReceiver, UnboundedSender}, }; use log::{trace, warn}; @@ -50,10 +50,9 @@ impl Shutdown { } #[derive(Debug)] -pub(crate) struct IpStackTcpStreamInner { +pub(crate) struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, - stream_sender: UnboundedSender, stream_receiver: UnboundedReceiver, packet_sender: UnboundedSender, packet_to_send: Option, @@ -63,21 +62,19 @@ pub(crate) struct IpStackTcpStreamInner { write_notify: Option, } -impl IpStackTcpStreamInner { +impl IpStackTcpStream { pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, tcp: TcpPacket, pkt_sender: UnboundedSender, + stream_receiver: UnboundedReceiver, mtu: u16, tcp_timeout: Duration, - ) -> Result { - let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - - let stream = IpStackTcpStreamInner { + ) -> Result { + let stream = IpStackTcpStream { src_addr, dst_addr, - stream_sender, stream_receiver, packet_sender: pkt_sender.clone(), packet_to_send: None, @@ -94,10 +91,6 @@ impl IpStackTcpStreamInner { } } - pub(crate) fn stream_sender(&self) -> UnboundedSender { - self.stream_sender.clone() - } - fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), @@ -190,17 +183,9 @@ impl IpStackTcpStreamInner { payload, }) } - - // pub fn local_addr(&self) -> SocketAddr { - // self.src_addr - // } - - // pub fn peer_addr(&self) -> SocketAddr { - // self.dst_addr - // } } -impl AsyncRead for IpStackTcpStreamInner { +impl AsyncRead for IpStackTcpStream { fn poll_read( mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -263,6 +248,7 @@ impl AsyncRead for IpStackTcpStreamInner { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); + self.tcb.add_ack(1); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) @@ -410,22 +396,21 @@ impl AsyncRead for IpStackTcpStreamInner { } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.add_ack(1); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if t.flags() == (FIN | ACK) { - self.tcb.add_seq_one(); self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); - self.tcb.change_state(TcpState::FinWait2(false)); + self.tcb.change_state(TcpState::FinWait2(true)); continue; } } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { - self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); @@ -439,7 +424,7 @@ impl AsyncRead for IpStackTcpStreamInner { } } -impl AsyncWrite for IpStackTcpStreamInner { +impl AsyncWrite for IpStackTcpStream { fn poll_write( mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -468,7 +453,6 @@ impl AsyncWrite for IpStackTcpStreamInner { let seq = self.tcb.seq; let payload_len = packet.payload.len(); let payload = packet.payload.clone(); - self.packet_sender .send(packet) .or(Err(ErrorKind::UnexpectedEof))?; @@ -528,124 +512,10 @@ impl AsyncWrite for IpStackTcpStreamInner { } } -impl Drop for IpStackTcpStreamInner { +impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { _ = self.packet_sender.send(p); } } } - -#[derive(Debug)] -pub struct IpStackTcpStream { - inner: Option, - drop_sender: UnboundedSender>, - src_addr: SocketAddr, - dst_addr: SocketAddr, - stream_sender: UnboundedSender, -} - -impl IpStackTcpStream { - pub(crate) fn new( - drop_sender: UnboundedSender>, - src_addr: SocketAddr, - dst_addr: SocketAddr, - tcp: TcpPacket, - pkt_sender: UnboundedSender, - mtu: u16, - tcp_timeout: Duration, - ) -> Result { - let stream = - IpStackTcpStreamInner::new(src_addr, dst_addr, tcp, pkt_sender, mtu, tcp_timeout)?; - Ok(IpStackTcpStream { - stream_sender: stream.stream_sender(), - inner: Some(stream), - drop_sender, - src_addr, - dst_addr, - }) - } - pub(crate) fn stream_sender(&self) -> UnboundedSender { - self.stream_sender.clone() - } - - pub fn local_addr(&self) -> SocketAddr { - self.src_addr - } - - pub fn peer_addr(&self) -> SocketAddr { - self.dst_addr - } -} - -impl AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - if let Some(inner) = &mut self.inner { - Pin::new(inner).poll_read(cx, buf) - } else { - Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "", - ))) - } - } -} - -impl AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if let Some(inner) = &mut self.inner { - Pin::new(inner).poll_write(cx, buf) - } else { - Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "", - ))) - } - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - if let Some(inner) = &mut self.inner { - Pin::new(inner).poll_flush(cx) - } else { - Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "", - ))) - } - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - if let Some(inner) = &mut self.inner { - Pin::new(inner).poll_shutdown(cx) - } else { - Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::NotConnected, - "", - ))) - } - } -} - -impl Drop for IpStackTcpStream { - fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - if let Err(e) = self.drop_sender.send(Box::new(inner)) { - trace!("fail to send IpStackTcpStreamInner to drop {:?}", e); - } - } - } -} diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs new file mode 100644 index 0000000..e83241e --- /dev/null +++ b/src/stream/tcp_wrapper.rs @@ -0,0 +1,121 @@ +use std::{net::SocketAddr, pin::Pin, time::Duration}; + +use tokio::{ + io::AsyncWriteExt, + sync::mpsc::{self, UnboundedSender}, + time::timeout, +}; + +use crate::{ + packet::{NetworkPacket, TcpPacket}, + IpStackError, +}; + +use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; + +pub struct IpStackTcpStream { + inner: Option>, + peer_addr: SocketAddr, + local_addr: SocketAddr, + stream_sender: mpsc::UnboundedSender, +} + +impl IpStackTcpStream { + pub(crate) fn new( + local_addr: SocketAddr, + peer_addr: SocketAddr, + tcp: TcpPacket, + pkt_sender: UnboundedSender, + mtu: u16, + tcp_timeout: Duration, + ) -> Result { + let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); + IpStackTcpStreamInner::new( + local_addr, + peer_addr, + tcp, + pkt_sender, + stream_receiver, + mtu, + tcp_timeout, + ) + .map(Box::new) + .map(|inner| IpStackTcpStream { + inner: Some(inner), + peer_addr, + local_addr, + stream_sender, + }) + } + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + pub fn peer_addr(&self) -> SocketAddr { + self.peer_addr + } + pub fn stream_sender(&self) -> UnboundedSender { + self.stream_sender.clone() + } +} + +impl tokio::io::AsyncRead for IpStackTcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } +} + +impl tokio::io::AsyncWrite for IpStackTcpStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_flush(cx), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } +} + +impl Drop for IpStackTcpStream { + fn drop(&mut self) { + if let Some(mut inner) = self.inner.take() { + tokio::spawn(async move { + _ = timeout(Duration::from_secs(2), inner.shutdown()).await; + }); + } + } +}