Skip to content

Commit

Permalink
Fix manual shutdown #27 #28 (#30)
Browse files Browse the repository at this point in the history
* Fix #27 #28

* Fix the dependency issue

* Fix the uninitialized buffer during shutdown

* Remove unnecessary async function

* Fix the block when invalid packet received

* Fix the block when invalid packet received

* Remove the test file
  • Loading branch information
SajjadPourali authored Mar 28, 2024
1 parent 20cb3e0 commit 437ef17
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 26 deletions.
14 changes: 6 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@ tokio = { version = "1.36", features = [
"time",
"io-util",
"macros",
"rt-multi-thread",
], default-features = false }
etherparse = { version = "0.14", default-features = false, features = ["std"] }
thiserror = { version = "1.0", default-features = false }
log = { version = "0.4", default-features = false}
rand = {version = "0.8.5", default-features = false, features = ["std","std_rng"] }
log = { version = "0.4", default-features = false }
rand = { version = "0.8.5", default-features = false, features = [
"std",
"std_rng",
] }

[dev-dependencies]
clap = { version = "4.5", features = ["derive"] }
env_logger = "0.11"
udp-stream = { version = "0.0", default-features = false }
tokio = { version = "1.36", features = [
"rt-multi-thread",
], default-features = false }


#tun2.rs example
Expand All @@ -52,6 +53,3 @@ incremental = false # Disable incremental compilation.
overflow-checks = false # Disable overflow checks.
strip = true # Automatically strip symbols from the binary.

[[example]]
name = "tun2"
required-features = ["log"]
2 changes: 1 addition & 1 deletion examples/tun_wintun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct Args {
server_addr: SocketAddr,
}

#[tokio::main(flavor = "current_thread")]
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl IpStack {
Vacant(entry) => {
match packet.transport_protocol(){
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{
match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout){
Ok(stream) => {
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Tcp(stream))?;
Expand Down
50 changes: 34 additions & 16 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ use std::{
cmp,
future::Future,
io::{Error, ErrorKind},
mem::MaybeUninit,
net::SocketAddr,
pin::Pin,
task::{Context, Poll, Waker},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
runtime::Handle,
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
task,
};

use log::{trace, warn};
Expand Down Expand Up @@ -63,7 +66,7 @@ pub struct IpStackTcpStream {
}

impl IpStackTcpStream {
pub(crate) async fn new(
pub(crate) fn new(
src_addr: SocketAddr,
dst_addr: SocketAddr,
tcp: TcpPacket,
Expand Down Expand Up @@ -244,8 +247,13 @@ impl AsyncRead for IpStackTcpStream {
self.shutdown.ready();
return Poll::Ready(Ok(()));
}
continue;
}
if let Some(b) = self.tcb.get_unordered_packets() {
if let Some(b) = self
.tcb
.get_unordered_packets()
.filter(|_| matches!(self.shutdown, Shutdown::None))
{
self.tcb.add_ack(b.len() as u32);
buf.put_slice(&b);
self.packet_sender
Expand All @@ -265,6 +273,7 @@ impl AsyncRead for IpStackTcpStream {
{
self.packet_to_send =
Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait1(false));
continue;
}
Expand Down Expand Up @@ -402,6 +411,7 @@ impl AsyncRead for IpStackTcpStream {
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if t.flags() == (FIN | ACK) {
Expand All @@ -417,16 +427,15 @@ impl AsyncRead for IpStackTcpStream {
if t.flags() == ACK {
self.tcb.change_state(TcpState::FinWait2(false));
} else if t.flags() == (FIN | ACK) {
self.tcb.add_ack(1);
self.packet_to_send =
Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::FinWait2(false));
}
}
}
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Pending => {
return Poll::Pending;
}
Poll::Pending => return Poll::Pending,
}
}
}
Expand Down Expand Up @@ -509,21 +518,30 @@ impl AsyncWrite for IpStackTcpStream {
mut self: std::pin::Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
match &self.shutdown {
Shutdown::Ready => Poll::Ready(Ok(())),
Shutdown::Pending(_) => Poll::Pending,
Shutdown::None => {
self.shutdown.pending(cx.waker().clone());
Poll::Pending
}
if matches!(self.shutdown, Shutdown::Ready) {
return Poll::Ready(Ok(()));
} else if matches!(self.shutdown, Shutdown::None) {
self.shutdown.pending(cx.waker().clone());
}
self.poll_read(
cx,
&mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::<u8>::uninit()]),
)
}
}

impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) {
_ = self.packet_sender.send(p);
}
task::block_in_place(move || {
Handle::current().block_on(async move {
if matches!(self.tcb.get_state(), TcpState::Established) {
_ = self.shutdown().await;
}

if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) {
_ = self.packet_sender.send(p);
}
});
});
}
}

0 comments on commit 437ef17

Please sign in to comment.