Skip to content

Commit

Permalink
fix automatically drop IpstackTcpStream (#32)
Browse files Browse the repository at this point in the history
* fix automatically drop IpstackTcpStream

* IpStackTcpStreamInner is so larger and place it to heap

* Code refactor

---------

Co-authored-by: SajjadPourali
  • Loading branch information
xmh0511 authored Mar 30, 2024
1 parent 8a30d36 commit 8876511
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 22 deletions.
5 changes: 2 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ impl IpStack {

let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
loop {
// dbg!(streams.len());
select! {
Ok(n) = device.read(&mut buffer) => {
let offset = if config.packet_information && cfg!(unix) {4} else {0};
Expand All @@ -109,8 +108,8 @@ impl IpStack {

match streams.entry(packet.network_tuple()){
Occupied(entry) =>{
if let Err(_x) = entry.get().send(packet){
trace!("Send packet error \"{}\"", _x);
if let Err(e) = entry.get().send(packet){
trace!("Send packet error \"{}\"", e);
}
}
Vacant(entry) => {
Expand Down
3 changes: 2 additions & 1 deletion src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

pub use self::tcp::IpStackTcpStream;
pub use self::tcp_wrapper::IpStackTcpStream;
pub use self::udp::IpStackUdpStream;
pub use self::unknown::IpStackUnknownTransport;

mod tcb;
mod tcp;
mod tcp_wrapper;
mod udp;
mod unknown;

Expand Down
21 changes: 3 additions & 18 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
};

use log::{trace, warn};
Expand Down Expand Up @@ -50,10 +50,9 @@ impl Shutdown {
}

#[derive(Debug)]
pub struct IpStackTcpStream {
pub(crate) struct IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
stream_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
packet_sender: UnboundedSender<NetworkPacket>,
packet_to_send: Option<NetworkPacket>,
Expand All @@ -69,15 +68,13 @@ impl IpStackTcpStream {
dst_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
stream_receiver: UnboundedReceiver<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

let stream = IpStackTcpStream {
src_addr,
dst_addr,
stream_sender,
stream_receiver,
packet_sender: pkt_sender.clone(),
packet_to_send: None,
Expand All @@ -94,10 +91,6 @@ impl IpStackTcpStream {
}
}

pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}

fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 {
cmp::min(
self.tcb.get_send_window(),
Expand Down Expand Up @@ -190,14 +183,6 @@ impl IpStackTcpStream {
payload,
})
}

pub fn local_addr(&self) -> SocketAddr {
self.src_addr
}

pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
}

impl AsyncRead for IpStackTcpStream {
Expand Down
121 changes: 121 additions & 0 deletions src/stream/tcp_wrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::{net::SocketAddr, pin::Pin, time::Duration};

use tokio::{
io::AsyncWriteExt,
sync::mpsc::{self, UnboundedSender},
time::timeout,
};

use crate::{
packet::{NetworkPacket, TcpPacket},
IpStackError,
};

use super::tcp::IpStackTcpStream as IpStackTcpStreamInner;

pub struct IpStackTcpStream {
inner: Option<Box<IpStackTcpStreamInner>>,
peer_addr: SocketAddr,
local_addr: SocketAddr,
stream_sender: mpsc::UnboundedSender<NetworkPacket>,
}

impl IpStackTcpStream {
pub(crate) fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
tcp: TcpPacket,
pkt_sender: UnboundedSender<NetworkPacket>,
mtu: u16,
tcp_timeout: Duration,
) -> Result<IpStackTcpStream, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();
IpStackTcpStreamInner::new(
local_addr,
peer_addr,
tcp,
pkt_sender,
stream_receiver,
mtu,
tcp_timeout,
)
.map(Box::new)
.map(|inner| IpStackTcpStream {
inner: Some(inner),
peer_addr,
local_addr,
stream_sender,
})
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
pub fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
self.stream_sender.clone()
}
}

impl tokio::io::AsyncRead for IpStackTcpStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
}

impl tokio::io::AsyncWrite for IpStackTcpStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_flush(cx),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.inner.as_mut() {
Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx),
None => {
std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
}
}
}
}

impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Some(mut inner) = self.inner.take() {
tokio::spawn(async move {
_ = timeout(Duration::from_secs(2), inner.shutdown()).await;
});
}
}
}

0 comments on commit 8876511

Please sign in to comment.