diff --git a/Cargo.toml b/Cargo.toml index 1f46bc6..a270033 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,19 +16,20 @@ tokio = { version = "1.36", features = [ "time", "io-util", "macros", + "rt-multi-thread", ], default-features = false } etherparse = { version = "0.14", default-features = false, features = ["std"] } thiserror = { version = "1.0", default-features = false } -log = { version = "0.4", default-features = false} -rand = {version = "0.8.5", default-features = false, features = ["std","std_rng"] } +log = { version = "0.4", default-features = false } +rand = { version = "0.8.5", default-features = false, features = [ + "std", + "std_rng", +] } [dev-dependencies] clap = { version = "4.5", features = ["derive"] } env_logger = "0.11" udp-stream = { version = "0.0", default-features = false } -tokio = { version = "1.36", features = [ - "rt-multi-thread", -], default-features = false } #tun2.rs example @@ -52,6 +53,3 @@ incremental = false # Disable incremental compilation. overflow-checks = false # Disable overflow checks. strip = true # Automatically strip symbols from the binary. -[[example]] -name = "tun2" -required-features = ["log"] diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index 9984805..b6e9211 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -17,7 +17,7 @@ struct Args { server_addr: SocketAddr, } -#[tokio::main(flavor = "current_thread")] +#[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { let args = Args::parse(); diff --git a/src/lib.rs b/src/lib.rs index e212e39..3c76e04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,7 +116,7 @@ impl IpStack { Vacant(entry) => { match packet.transport_protocol(){ IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{ + match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){ Ok(stream) => { entry.insert(stream.stream_sender()); accept_sender.send(IpStackStream::Tcp(stream))?; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index d928d05..ae6f014 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -12,14 +12,17 @@ use std::{ cmp, future::Future, io::{Error, ErrorKind}, + mem::MaybeUninit, net::SocketAddr, pin::Pin, task::{Context, Poll, Waker}, time::Duration, }; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, + runtime::Handle, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + task, }; use log::{trace, warn}; @@ -63,7 +66,7 @@ pub struct IpStackTcpStream { } impl IpStackTcpStream { - pub(crate) async fn new( + pub(crate) fn new( src_addr: SocketAddr, dst_addr: SocketAddr, tcp: TcpPacket, @@ -244,8 +247,13 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Ok(())); } + continue; } - if let Some(b) = self.tcb.get_unordered_packets() { + if let Some(b) = self + .tcb + .get_unordered_packets() + .filter(|_| matches!(self.shutdown, Shutdown::None)) + { self.tcb.add_ack(b.len() as u32); buf.put_slice(&b); self.packet_sender @@ -265,6 +273,7 @@ impl AsyncRead for IpStackTcpStream { { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait1(false)); continue; } @@ -402,6 +411,7 @@ impl AsyncRead for IpStackTcpStream { } } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { if t.flags() == ACK { + self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if t.flags() == (FIN | ACK) { @@ -417,6 +427,7 @@ impl AsyncRead for IpStackTcpStream { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { + self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); @@ -424,9 +435,7 @@ impl AsyncRead for IpStackTcpStream { } } Poll::Ready(None) => return Poll::Ready(Ok(())), - Poll::Pending => { - return Poll::Pending; - } + Poll::Pending => return Poll::Pending, } } } @@ -509,21 +518,30 @@ impl AsyncWrite for IpStackTcpStream { mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::io::Result<()>> { - match &self.shutdown { - Shutdown::Ready => Poll::Ready(Ok(())), - Shutdown::Pending(_) => Poll::Pending, - Shutdown::None => { - self.shutdown.pending(cx.waker().clone()); - Poll::Pending - } + if matches!(self.shutdown, Shutdown::Ready) { + return Poll::Ready(Ok(())); + } else if matches!(self.shutdown, Shutdown::None) { + self.shutdown.pending(cx.waker().clone()); } + self.poll_read( + cx, + &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::<u8>::uninit()]), + ) } } impl Drop for IpStackTcpStream { fn drop(&mut self) { - if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - _ = self.packet_sender.send(p); - } + task::block_in_place(move || { + Handle::current().block_on(async move { + if matches!(self.tcb.get_state(), TcpState::Established) { + _ = self.shutdown().await; + } + + if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { + _ = self.packet_sender.send(p); + } + }); + }); } }