From ba40e3d7ad4181086b51d51f7387713f25813624 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 15 Mar 2024 06:28:22 +0800 Subject: [PATCH] Fix timeout issues (#25) * Fix timeout issues * Another fixing * minor changes * Prevent creating stream when incorrect packet received + add connection check to poll_flush * minor changes * minor changes * minor changes * minor changes --------- Co-authored-by: SajjadPourali --- src/lib.rs | 14 ++++--- src/stream/mod.rs | 14 +++---- src/stream/tcp.rs | 104 ++++++++++++++++++++++------------------------ src/stream/udp.rs | 2 +- 4 files changed, 65 insertions(+), 69 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d43f99f..b014b43 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,7 +111,7 @@ impl IpStack { Occupied(entry) =>{ if let Err(_x) = entry.get().send(packet){ #[cfg(feature = "log")] - trace!("{}", _x); + trace!("Send packet error \"{}\"", _x); } } Vacant(entry) => { @@ -119,15 +119,17 @@ impl IpStack { IpStackPacketProtocol::Tcp(h) => { match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{ Ok(stream) => { - if stream.is_closed(){ - continue; - } entry.insert(stream.stream_sender()); accept_sender.send(IpStackStream::Tcp(stream))?; } - Err(_e) => { + Err(e) => { + if matches!(e,IpStackError::InvalidTcpPacket){ + #[cfg(feature = "log")] + trace!("Invalid TCP packet"); + continue; + } #[cfg(feature = "log")] - error!("{}", _e); + error!("IpStackTcpStream::new failed \"{}\"", e); } } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 9878f99..4c9bd68 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,4 +1,4 @@ -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; pub use self::tcp::IpStackTcpStream; pub use self::udp::IpStackUdpStream; @@ -22,11 +22,11 @@ impl IpStackStream { IpStackStream::Tcp(tcp) => tcp.local_addr(), IpStackStream::Udp(udp) => udp.local_addr(), IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0)) + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) } IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { - std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), - std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), + IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), + IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), }, } } @@ -35,11 +35,11 @@ impl IpStackStream { IpStackStream::Tcp(tcp) => tcp.peer_addr(), IpStackStream::Udp(udp) => udp.peer_addr(), IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0)) + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) } IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { - std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), - std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), + IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), + IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), }, } } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index f230cc5..d856404 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -11,7 +11,7 @@ use std::{ io::{Error, ErrorKind}, net::SocketAddr, pin::Pin, - task::Waker, + task::{Context, Poll, Waker}, time::Duration, }; use tokio::{ @@ -70,7 +70,7 @@ impl IpStackTcpStream { ) -> Result { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - let mut stream = IpStackTcpStream { + let stream = IpStackTcpStream { src_addr, dst_addr, stream_sender, @@ -84,12 +84,11 @@ impl IpStackTcpStream { }; if !tcp.inner().syn { let flags = tcp_flags::RST | tcp_flags::ACK; - pkt_sender - .send(stream.create_rev_packet(flags, TTL, None, Vec::new())?) - .map_err(|_| IpStackError::InvalidTcpPacket)?; - stream.tcb.change_state(TcpState::Closed); + _ = pkt_sender.send(stream.create_rev_packet(flags, TTL, None, Vec::new())?); + Err(IpStackError::InvalidTcpPacket) + } else { + Ok(stream) } - Ok(stream) } pub(crate) fn stream_sender(&self) -> UnboundedSender { @@ -174,12 +173,12 @@ impl IpStackTcpStream { etherparse::NetHeaders::Ipv4(ref ip_header, _) => { tcp_header.checksum = tcp_header .calc_checksum_ipv4(ip_header, &payload) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .or(Err(ErrorKind::InvalidInput))?; } etherparse::NetHeaders::Ipv6(ref ip_header, _) => { tcp_header.checksum = tcp_header .calc_checksum_ipv6(ip_header, &payload) - .map_err(|_e| Error::from(ErrorKind::InvalidInput))?; + .or(Err(ErrorKind::InvalidInput))?; } } Ok(NetworkPacket { @@ -196,17 +195,14 @@ impl IpStackTcpStream { pub fn peer_addr(&self) -> SocketAddr { self.dst_addr } - pub(crate) fn is_closed(&self) -> bool { - matches!(self.tcb.get_state(), TcpState::Closed) - } } impl AsyncRead for IpStackTcpStream { fn poll_read( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { + ) -> Poll> { loop { if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) && self.packet_to_send.is_none() @@ -215,21 +211,20 @@ impl AsyncRead for IpStackTcpStream { Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); - return std::task::Poll::Ready(Ok(())); + return Poll::Ready(Ok(())); } let min = cmp::min(self.tcb.get_available_read_buffer_size() as u16, u16::MAX); self.tcb.change_recv_window(min); - if matches!( - Pin::new(&mut self.tcb.timeout).poll(cx), - std::task::Poll::Ready(_) - ) { + if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { #[cfg(feature = "log")] trace!("timeout reached for {:?}", self.dst_addr); let flags = tcp_flags::RST | tcp_flags::ACK; self.packet_sender .send(self.create_rev_packet(flags, TTL, None, Vec::new())?) - .map_err(|_| ErrorKind::UnexpectedEof)?; - return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); + .or(Err(ErrorKind::UnexpectedEof))?; + self.tcb.change_state(TcpState::Closed); + self.shutdown.ready(); + return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } self.tcb.reset_timeout(); @@ -244,10 +239,10 @@ impl AsyncRead for IpStackTcpStream { if let Some(packet) = self.packet_to_send.take() { self.packet_sender .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + .or(Err(ErrorKind::UnexpectedEof))?; if matches!(self.tcb.get_state(), TcpState::Closed) { self.shutdown.ready(); - return std::task::Poll::Ready(Ok(())); + return Poll::Ready(Ok(())); } } if let Some(b) = self.tcb.get_unordered_packets() { @@ -255,8 +250,8 @@ impl AsyncRead for IpStackTcpStream { buf.put_slice(&b); self.packet_sender .send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; - return std::task::Poll::Ready(Ok(())); + .or(Err(ErrorKind::UnexpectedEof))?; + return Poll::Ready(Ok(())); } if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) { let flags = tcp_flags::FIN | tcp_flags::ACK; @@ -274,7 +269,7 @@ impl AsyncRead for IpStackTcpStream { continue; } match self.stream_receiver.poll_recv(cx) { - std::task::Poll::Ready(Some(p)) => { + Poll::Ready(Some(p)) => { let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else { unreachable!() }; @@ -282,9 +277,8 @@ impl AsyncRead for IpStackTcpStream { self.packet_to_send = Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); - return std::task::Poll::Ready(Err(Error::from( - ErrorKind::ConnectionReset, - ))); + self.shutdown.ready(); + return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); } if matches!( self.tcb.check_pkt_type(&t, &p.payload), @@ -325,11 +319,8 @@ impl AsyncRead for IpStackTcpStream { PacketStatus::RetransmissionRequest => { self.tcb.change_send_window(t.inner().window_size); self.tcb.retransmission = Some(t.inner().acknowledgment_number); - if matches!( - self.as_mut().poll_flush(cx), - std::task::Poll::Pending - ) { - return std::task::Poll::Pending; + if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { + return Poll::Pending; } continue; } @@ -364,7 +355,7 @@ impl AsyncRead for IpStackTcpStream { self.write_notify = None; }; continue; - // return std::task::Poll::Ready(Ok(())); + // return Poll::Ready(Ok(())); } PacketStatus::Ack => { self.tcb.change_last_ack(t.inner().acknowledgment_number); @@ -409,7 +400,7 @@ impl AsyncRead for IpStackTcpStream { // None, // Vec::new(), // )?); - // return std::task::Poll::Ready(Ok(())); + // return Poll::Ready(Ok(())); self.tcb .add_unordered_packet(t.inner().sequence_number, &p.payload); continue; @@ -436,9 +427,9 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(false)); } } - std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())), - std::task::Poll::Pending => { - return std::task::Poll::Pending; + Poll::Ready(None) => return Poll::Ready(Ok(())), + Poll::Pending => { + return Poll::Pending; } } } @@ -448,11 +439,11 @@ impl AsyncRead for IpStackTcpStream { impl AsyncWrite for IpStackTcpStream { fn poll_write( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + cx: &mut Context<'_>, buf: &[u8], - ) -> std::task::Poll> { + ) -> Poll> { if !matches!(self.tcb.get_state(), TcpState::Established) { - return std::task::Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); + return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } self.tcb.reset_timeout(); @@ -460,13 +451,13 @@ impl AsyncWrite for IpStackTcpStream { || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); - return std::task::Poll::Pending; + return Poll::Pending; } if self.tcb.retransmission.is_some() { self.write_notify = Some(cx.waker().clone()); - if matches!(self.as_mut().poll_flush(cx), std::task::Poll::Pending) { - return std::task::Poll::Pending; + if matches!(self.as_mut().poll_flush(cx), Poll::Pending) { + return Poll::Pending; } } @@ -478,16 +469,19 @@ impl AsyncWrite for IpStackTcpStream { self.packet_sender .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + .or(Err(ErrorKind::UnexpectedEof))?; self.tcb.add_inflight_packet(seq, &payload); - std::task::Poll::Ready(Ok(payload_len)) + Poll::Ready(Ok(payload_len)) } fn poll_flush( mut self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + _cx: &mut Context<'_>, + ) -> Poll> { + if !matches!(self.tcb.get_state(), TcpState::Established) { + return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); + } if let Some(i) = self .tcb .retransmission @@ -499,7 +493,7 @@ impl AsyncWrite for IpStackTcpStream { self.packet_sender .send(packet) - .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; + .or(Err(ErrorKind::UnexpectedEof))?; self.tcb.retransmission = None; } else if let Some(_i) = self.tcb.retransmission { #[cfg(feature = "log")] @@ -515,18 +509,18 @@ impl AsyncWrite for IpStackTcpStream { } panic!("Please report these values at: https://github.com/narrowlink/ipstack/"); } - std::task::Poll::Ready(Ok(())) + Poll::Ready(Ok(())) } fn poll_shutdown( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + cx: &mut Context<'_>, + ) -> Poll> { match &self.shutdown { - Shutdown::Ready => std::task::Poll::Ready(Ok(())), + Shutdown::Ready => Poll::Ready(Ok(())), Shutdown::Pending(_) | Shutdown::None => { self.shutdown.pending(cx.waker().clone()); - std::task::Poll::Pending + Poll::Pending } } } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index 0eef9d4..fd75aad 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -160,7 +160,7 @@ impl AsyncWrite for IpStackUdpStream { let payload_len = packet.payload.len(); self.packet_sender .send(packet) - .map_err(|_| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?; + .or(Err(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) }