From 6b23ebf28151027bbdbb34778c1e902ad955e22b Mon Sep 17 00:00:00 2001
From: SajjadPourali <sajjad@pourali.com>
Date: Mon, 4 Mar 2024 18:05:57 -0500
Subject: [PATCH] Fix fin race condition

---
 examples/tun_wintun.rs |  2 +-
 src/stream/tcp.rs      | 49 ++++++++++++++++++++++++------------------
 2 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs
index 397cc85..1df05c0 100644
--- a/examples/tun_wintun.rs
+++ b/examples/tun_wintun.rs
@@ -63,7 +63,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                 };
                 println!("==== New TCP connection ====");
                 tokio::spawn(async move {
-                    let _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await;
+                    _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await;
                     println!("====== end tcp connection ======");
                 });
             }
diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs
index 54a0f00..f230cc5 100644
--- a/src/stream/tcp.rs
+++ b/src/stream/tcp.rs
@@ -208,7 +208,9 @@ impl AsyncRead for IpStackTcpStream {
         buf: &mut tokio::io::ReadBuf<'_>,
     ) -> std::task::Poll<std::io::Result<()>> {
         loop {
-            if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) {
+            if matches!(self.tcb.get_state(), TcpState::FinWait2(false))
+                && self.packet_to_send.is_none()
+            {
                 self.packet_to_send =
                     Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
                 self.tcb.change_state(TcpState::Closed);
@@ -256,7 +258,21 @@ impl AsyncRead for IpStackTcpStream {
                     .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
                 return std::task::Poll::Ready(Ok(()));
             }
+            if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
+                let flags = tcp_flags::FIN | tcp_flags::ACK;
+                self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
+                self.tcb.add_seq_one();
+                self.tcb.change_state(TcpState::FinWait2(true));
+                continue;
+            } else if matches!(self.shutdown, Shutdown::Pending(_))
+                && matches!(self.tcb.get_state(), TcpState::Established)
+            {
+                let flags = tcp_flags::FIN | tcp_flags::ACK;
+                self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
+                self.tcb.change_state(TcpState::FinWait1(false));
 
+                continue;
+            }
             match self.stream_receiver.poll_recv(cx) {
                 std::task::Poll::Ready(Some(p)) => {
                     let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else {
@@ -366,7 +382,6 @@ impl AsyncRead for IpStackTcpStream {
                             let flags = tcp_flags::ACK;
                             self.packet_to_send =
                                 Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
-                            // self.tcb.add_seq_one();
                             self.tcb.change_state(TcpState::FinWait1(true));
                             continue;
                         }
@@ -400,12 +415,18 @@ impl AsyncRead for IpStackTcpStream {
                             continue;
                         }
                     } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) {
-                        if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
+                        if t.flags() == tcp_flags::ACK {
+                            // panic!("ACK received in FinWait1");
+                            self.tcb.add_ack(1);
+                            self.tcb.change_state(TcpState::FinWait1(true));
+                            continue;
+                        } else if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
                             let flags = tcp_flags::ACK;
+                            self.tcb.add_seq_one();
+                            self.tcb.add_ack(1);
                             self.packet_to_send =
                                 Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
                             self.tcb.change_send_window(t.inner().window_size);
-                            // self.tcb.add_seq_one();
                             self.tcb.change_state(TcpState::FinWait2(false));
                             continue;
                         }
@@ -417,23 +438,6 @@ impl AsyncRead for IpStackTcpStream {
                 }
                 std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
                 std::task::Poll::Pending => {
-                    if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
-                        let flags = tcp_flags::FIN | tcp_flags::ACK;
-                        self.packet_to_send =
-                            Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
-
-                        self.tcb.change_state(TcpState::FinWait2(false));
-                        continue;
-                    } else if matches!(self.shutdown, Shutdown::Pending(_))
-                        && matches!(self.tcb.get_state(), TcpState::Established)
-                    {
-                        let flags = tcp_flags::FIN | tcp_flags::ACK;
-                        self.packet_to_send =
-                            Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
-                        self.tcb.change_state(TcpState::FinWait1(false));
-
-                        continue;
-                    }
                     return std::task::Poll::Pending;
                 }
             }
@@ -447,6 +451,9 @@ impl AsyncWrite for IpStackTcpStream {
         cx: &mut std::task::Context<'_>,
         buf: &[u8],
     ) -> std::task::Poll<Result<usize, std::io::Error>> {
+        if !matches!(self.tcb.get_state(), TcpState::Established) {
+            return std::task::Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
+        }
         self.tcb.reset_timeout();
 
         if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2