Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix automatically drop IpstackTcpStream #32

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ impl IpStack {

let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
loop {
// dbg!(streams.len());
select! {
Ok(n) = device.read(&mut buffer) => {
let offset = if config.packet_information && cfg!(unix) {4} else {0};
Expand All @@ -109,8 +108,8 @@ impl IpStack {

match streams.entry(packet.network_tuple()){
Occupied(entry) =>{
if let Err(_x) = entry.get().send(packet){
trace!("Send packet error \"{}\"", _x);
if let Err(e) = entry.get().send(packet){
trace!("Send packet error \"{}\"", e);
}
}
Vacant(entry) => {
Expand Down
3 changes: 2 additions & 1 deletion src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

pub use self::tcp::IpStackTcpStream;
pub use self::tcp_wrapper::IpStackTcpStream;
pub use self::udp::IpStackUdpStream;
pub use self::unknown::IpStackUnknownTransport;

mod tcb;
mod tcp;
mod tcp_wrapper;
mod udp;
mod unknown;

Expand Down
8 changes: 6 additions & 2 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ impl Tcb {
}
pub(super) fn change_last_ack(&mut self, ack: u32) {
let distance = ack.wrapping_sub(self.last_ack);
self.last_ack = self.last_ack.wrapping_add(distance);

if matches!(self.state, TcpState::Established) {
if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) {
Expand All @@ -187,9 +188,12 @@ impl Tcb {
self.inflight_packets.push(inflight_packet);
}
}
self.inflight_packets.retain(|p| {
let last_byte = p.seq.wrapping_add(p.payload.len() as u32);
last_byte.saturating_sub(self.last_ack) > 0
&& self.seq.saturating_sub(last_byte) > 0
});
}

self.last_ack = self.last_ack.wrapping_add(distance);
}
pub fn is_send_buffer_full(&self) -> bool {
self.seq.wrapping_sub(self.last_ack) >= MAX_UNACK
Expand Down
28 changes: 6 additions & 22 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
};

use log::{trace, warn};
Expand Down Expand Up @@ -50,10 +50,9 @@ impl Shutdown {
}

#[derive(Debug)]
pub struct IpStackTcpStream {
pub(crate) struct IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
packet_sender: UnboundedSender<NetworkPacket>,
packet_to_send: Option<NetworkPacket>,
Expand All @@ -69,15 +68,13 @@ impl IpStackTcpStream {
dst_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

let stream = IpStackTcpStream {
src_addr,
dst_addr,
stream_sender,
stream_receiver,
packet_sender: pkt_sender.clone(),
packet_to_send: None,
Expand All @@ -94,10 +91,6 @@ impl IpStackTcpStream {
}
}

pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}

fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 {
cmp::min(
self.tcb.get_send_window(),
Expand Down Expand Up @@ -190,14 +183,6 @@ impl IpStackTcpStream {
payload,
})
}

pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}

pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
}

impl AsyncRead for IpStackTcpStream {
Expand Down Expand Up @@ -263,6 +248,7 @@ impl AsyncRead for IpStackTcpStream {
self.packet_to_send =
Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.add_ack(1);
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if matches!(self.shutdown, Shutdown::Pending(_))
Expand Down Expand Up @@ -410,22 +396,21 @@ impl AsyncRead for IpStackTcpStream {
} else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.add_ack(1);
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if t.flags() == (FIN | ACK) {
self.tcb.add_seq_one();
self.tcb.add_ack(1);
self.packet_to_send =
Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.change_state(TcpState::FinWait2(false));
self.tcb.change_state(TcpState::FinWait2(true));
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) {
if t.flags() == ACK {
self.tcb.change_state(TcpState::FinWait2(false));
} else if t.flags() == (FIN | ACK) {
self.tcb.add_ack(1);
self.packet_to_send =
Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::FinWait2(false));
Expand Down Expand Up @@ -468,7 +453,6 @@ impl AsyncWrite for IpStackTcpStream {
let seq = self.tcb.seq;
let payload_len = packet.payload.len();
let payload = packet.payload.clone();

self.packet_sender
.send(packet)
.or(Err(ErrorKind::UnexpectedEof))?;
Expand Down
121 changes: 121 additions & 0 deletions src/stream/tcp_wrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::{net::SocketAddr, pin::Pin, time::Duration};

use tokio::{
io::AsyncWriteExt,
sync::mpsc::{self, UnboundedSender},
time::timeout,
};

use crate::{
packet::{NetworkPacket, TcpPacket},
IpStackError,
};

use super::tcp::IpStackTcpStream as IpStackTcpStreamInner;

pub struct IpStackTcpStream {
inner: Option<Box<IpStackTcpStreamInner>>,
peer_addr: SocketAddr,
local_addr: SocketAddr,
stream_sender: mpsc::UnboundedSender<NetworkPacket>,
}

impl IpStackTcpStream {
pub(crate) fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
IpStackTcpStreamInner::new(
local_addr,
peer_addr,
tcp,
pkt_sender,
stream_receiver,
mtu,
tcp_timeout,
)
.map(Box::new)
.map(|inner| IpStackTcpStream {
inner: Some(inner),
peer_addr,
local_addr,
stream_sender,
})
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
pub fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}
}

impl tokio::io::AsyncRead for IpStackTcpStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
}

impl tokio::io::AsyncWrite for IpStackTcpStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_flush(cx),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
}

impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Some(mut inner) = self.inner.take() {
tokio::spawn(async move {
_ = timeout(Duration::from_secs(2), inner.shutdown()).await;
});
}
}
}
Loading