diff --git a/src/lib.rs b/src/lib.rs index 3c76e04..355a44b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,7 +94,6 @@ impl IpStack { let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); loop { - // dbg!(streams.len()); select! { Ok(n) = device.read(&mut buffer) => { let offset = if config.packet_information && cfg!(unix) {4} else {0}; @@ -109,8 +108,8 @@ impl IpStack { match streams.entry(packet.network_tuple()){ Occupied(entry) =>{ - if let Err(_x) = entry.get().send(packet){ - trace!("Send packet error \"{}\"", _x); + if let Err(e) = entry.get().send(packet){ + trace!("Send packet error \"{}\"", e); } } Vacant(entry) => { diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 4c9bd68..42632f4 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -1,11 +1,12 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -pub use self::tcp::IpStackTcpStream; +pub use self::tcp_wrapper::IpStackTcpStream; pub use self::udp::IpStackUdpStream; pub use self::unknown::IpStackUnknownTransport; mod tcb; mod tcp; +mod tcp_wrapper; mod udp; mod unknown; diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 0e7383b..4e2e281 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -176,6 +176,7 @@ impl Tcb { } pub(super) fn change_last_ack(&mut self, ack: u32) { let distance = ack.wrapping_sub(self.last_ack); + self.last_ack = self.last_ack.wrapping_add(distance); if matches!(self.state, TcpState::Established) { if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { @@ -187,9 +188,12 @@ impl Tcb { self.inflight_packets.push(inflight_packet); } } + self.inflight_packets.retain(|p| { + let last_byte = p.seq.wrapping_add(p.payload.len() as u32); + last_byte.saturating_sub(self.last_ack) > 0 + && self.seq.saturating_sub(last_byte) > 0 + }); } - - self.last_ack = self.last_ack.wrapping_add(distance); } pub fn is_send_buffer_full(&self) -> bool { self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 8c396ad..c884cf1 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -20,7 +20,7 @@ use std::{ }; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + sync::mpsc::{UnboundedReceiver, UnboundedSender}, }; use log::{trace, warn}; @@ -50,10 +50,9 @@ impl Shutdown { } #[derive(Debug)] -pub struct IpStackTcpStream { +pub(crate) struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, - stream_sender: UnboundedSender, stream_receiver: UnboundedReceiver, packet_sender: UnboundedSender, packet_to_send: Option, @@ -69,15 +68,13 @@ impl IpStackTcpStream { dst_addr: SocketAddr, tcp: TcpPacket, pkt_sender: UnboundedSender, + stream_receiver: UnboundedReceiver, mtu: u16, tcp_timeout: Duration, ) -> Result { - let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - let stream = IpStackTcpStream { src_addr, dst_addr, - stream_sender, stream_receiver, packet_sender: pkt_sender.clone(), packet_to_send: None, @@ -94,10 +91,6 @@ impl IpStackTcpStream { } } - pub(crate) fn stream_sender(&self) -> UnboundedSender { - self.stream_sender.clone() - } - fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { cmp::min( self.tcb.get_send_window(), @@ -190,14 +183,6 @@ impl IpStackTcpStream { payload, }) } - - pub fn local_addr(&self) -> SocketAddr { - self.src_addr - } - - pub fn peer_addr(&self) -> SocketAddr { - self.dst_addr - } } impl AsyncRead for IpStackTcpStream { @@ -263,6 +248,7 @@ impl AsyncRead for IpStackTcpStream { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); + self.tcb.add_ack(1); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) @@ -410,22 +396,21 @@ impl AsyncRead for IpStackTcpStream { } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); + self.tcb.add_ack(1); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if t.flags() == (FIN | ACK) { - self.tcb.add_seq_one(); self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); - self.tcb.change_state(TcpState::FinWait2(false)); + self.tcb.change_state(TcpState::FinWait2(true)); continue; } } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { - self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); @@ -468,7 +453,6 @@ impl AsyncWrite for IpStackTcpStream { let seq = self.tcb.seq; let payload_len = packet.payload.len(); let payload = packet.payload.clone(); - self.packet_sender .send(packet) .or(Err(ErrorKind::UnexpectedEof))?; diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs new file mode 100644 index 0000000..e83241e --- /dev/null +++ b/src/stream/tcp_wrapper.rs @@ -0,0 +1,121 @@ +use std::{net::SocketAddr, pin::Pin, time::Duration}; + +use tokio::{ + io::AsyncWriteExt, + sync::mpsc::{self, UnboundedSender}, + time::timeout, +}; + +use crate::{ + packet::{NetworkPacket, TcpPacket}, + IpStackError, +}; + +use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; + +pub struct IpStackTcpStream { + inner: Option>, + peer_addr: SocketAddr, + local_addr: SocketAddr, + stream_sender: mpsc::UnboundedSender, +} + +impl IpStackTcpStream { + pub(crate) fn new( + local_addr: SocketAddr, + peer_addr: SocketAddr, + tcp: TcpPacket, + pkt_sender: UnboundedSender, + mtu: u16, + tcp_timeout: Duration, + ) -> Result { + let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); + IpStackTcpStreamInner::new( + local_addr, + peer_addr, + tcp, + pkt_sender, + stream_receiver, + mtu, + tcp_timeout, + ) + .map(Box::new) + .map(|inner| IpStackTcpStream { + inner: Some(inner), + peer_addr, + local_addr, + stream_sender, + }) + } + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + pub fn peer_addr(&self) -> SocketAddr { + self.peer_addr + } + pub fn stream_sender(&self) -> UnboundedSender { + self.stream_sender.clone() + } +} + +impl tokio::io::AsyncRead for IpStackTcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } +} + +impl tokio::io::AsyncWrite for IpStackTcpStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_flush(cx), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.as_mut() { + Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx), + None => { + std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) + } + } + } +} + +impl Drop for IpStackTcpStream { + fn drop(&mut self) { + if let Some(mut inner) = self.inner.take() { + tokio::spawn(async move { + _ = timeout(Duration::from_secs(2), inner.shutdown()).await; + }); + } + } +}