diff --git a/Cargo.toml b/Cargo.toml index edbd11c..64d9971 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,8 +34,8 @@ udp-stream = { version = "0.0", default-features = false } # Benchmarks criterion = { version = "0.5" } -[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] tun = { version = "0.7.13", features = ["async"], default-features = false } + [target.'cfg(target_os = "windows")'.dev-dependencies] wintun = { version = "0.5", default-features = false } diff --git a/README.md b/README.md index 9517ed6..5ed93f8 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ An asynchronous lightweight userspace implementation of TCP/IP stack for Tun dev Unstable, under development. [![Crates.io](https://img.shields.io/crates/v/ipstack.svg)](https://crates.io/crates/ipstack) -![ipstack](https://docs.rs/ipstack/badge.svg) +[![ipstack](https://docs.rs/ipstack/badge.svg)](https://docs.rs/ipstack) [![Documentation](https://img.shields.io/badge/docs-release-brightgreen.svg?style=flat)](https://docs.rs/ipstack) [![Download](https://img.shields.io/crates/d/ipstack.svg)](https://crates.io/crates/ipstack) [![License](https://img.shields.io/crates/l/ipstack.svg?style=flat)](https://github.com/narrowlink/ipstack/blob/main/LICENSE) @@ -86,4 +86,4 @@ async fn main() { } ``` -We also suggest that you take a look at the complete [examples](examples). +We also suggest that you take a look at the complete [examples](./examples). diff --git a/examples/tun2.rs b/examples/tun2.rs index 3c17db3..535e02b 100644 --- a/examples/tun2.rs +++ b/examples/tun2.rs @@ -28,7 +28,7 @@ //! use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::Icmpv4Header; use ipstack::{stream::IpStackStream, IpNumber}; use std::net::{Ipv4Addr, SocketAddr}; use tokio::net::TcpStream; @@ -154,12 +154,8 @@ async fn main() -> Result<(), Box> { let n = number; if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { log::info!("#{n} ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index e55c8d2..ceaa96f 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -1,7 +1,7 @@ use std::net::{Ipv4Addr, SocketAddr}; use clap::Parser; -use etherparse::{IcmpEchoHeader, Icmpv4Header}; +use etherparse::Icmpv4Header; use ipstack::{stream::IpStackStream, IpNumber}; use tokio::net::TcpStream; use udp_stream::UdpStream; @@ -46,10 +46,7 @@ async fn main() -> Result<(), Box> { let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&config)?); #[cfg(target_os = "windows")] - let mut ip_stack = ipstack::IpStack::new( - ipstack_config, - wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0)), - ); + let mut ip_stack = ipstack::IpStack::new(ipstack_config, wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0))); let server_addr = args.server_addr; @@ -86,12 +83,8 @@ async fn main() -> Result<(), Box> { IpStackStream::UnknownTransport(u) => { if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP { let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?; - if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type { + if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type { println!("ICMPv4 echo"); - let echo = IcmpEchoHeader { - id: req.id, - seq: req.seq, - }; let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo)); resp.update_checksum(req_payload); let mut payload = resp.to_bytes().to_vec(); @@ -178,17 +171,11 @@ mod wintun { std::task::Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..8449be0 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 140 diff --git a/src/error.rs b/src/error.rs index 360badd..319a439 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,8 +13,8 @@ pub enum IpStackError { #[error("ValueTooBigError {0}")] ValueTooBigErrorUsize(#[from] etherparse::err::ValueTooBigError), - #[error("Invalid Tcp packet")] - InvalidTcpPacket, + #[error("Invalid Tcp packet {0}")] + InvalidTcpPacket(crate::packet::TcpHeaderWrapper), #[error("IO error: {0}")] IoError(#[from] std::io::Error), diff --git a/src/lib.rs b/src/lib.rs index 93ab830..2108e68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,8 @@ mod packet; pub mod stream; pub use self::error::{IpStackError, Result}; -pub use etherparse::IpNumber; +pub use self::packet::TcpHeaderWrapper; +pub use ::etherparse::IpNumber; const DROP_TTL: u8 = 0; @@ -93,35 +94,27 @@ pub struct IpStack { } impl IpStack { - pub fn new(config: IpStackConfig, device: D) -> IpStack + pub fn new(config: IpStackConfig, device: Device) -> IpStack where - D: AsyncRead + AsyncWrite + Unpin + Send + 'static, + Device: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - let handle = run(config, device, accept_sender); - IpStack { accept_receiver, - handle, + handle: run(config, device, accept_sender), } } pub async fn accept(&mut self) -> Result { - self.accept_receiver - .recv() - .await - .ok_or(IpStackError::AcceptError) + self.accept_receiver.recv().await.ok_or(IpStackError::AcceptError) } } -fn run( +fn run( config: IpStackConfig, - mut device: D, + mut device: Device, accept_sender: UnboundedSender, -) -> JoinHandle> -where - D: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ +) -> JoinHandle> { let mut sessions: SessionCollection = AHashMap::new(); let pi = config.packet_information; let offset = if pi && cfg!(unix) { 4 } else { 0 }; @@ -167,56 +160,43 @@ fn process_device_read( }; if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - return Some(IpStackStream::UnknownTransport( - IpStackUnknownTransport::new( - packet.src_addr().ip(), - packet.dst_addr().ip(), - packet.payload, - &packet.ip, - config.mtu, - pkt_sender, - ), - )); + return Some(IpStackStream::UnknownTransport(IpStackUnknownTransport::new( + packet.src_addr().ip(), + packet.dst_addr().ip(), + packet.payload, + &packet.ip, + config.mtu, + pkt_sender, + ))); } match sessions.entry(packet.network_tuple()) { Occupied(mut entry) => { if let Err(e) = entry.get().send(packet) { - trace!("New stream because: {}", e); - create_stream(e.0, config, pkt_sender).map(|s| { - entry.insert(s.0); - s.1 + log::debug!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e); + create_stream(e.0, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { + entry.insert(packet_sender); + ip_stack_stream }) } else { None } } - Vacant(entry) => create_stream(packet, config, pkt_sender).map(|s| { - entry.insert(s.0); - s.1 + Vacant(entry) => create_stream(packet, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| { + entry.insert(packet_sender); + ip_stack_stream }), } } -fn create_stream( - packet: NetworkPacket, - config: &IpStackConfig, - pkt_sender: PacketSender, -) -> Option<(PacketSender, IpStackStream)> { +fn create_stream(packet: NetworkPacket, config: &IpStackConfig, pkt_sender: PacketSender) -> Option<(PacketSender, IpStackStream)> { match packet.transport_protocol() { IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new( - packet.src_addr(), - packet.dst_addr(), - h, - pkt_sender, - config.mtu, - config.tcp_timeout, - ) { + match IpStackTcpStream::new(packet.src_addr(), packet.dst_addr(), h, pkt_sender, config.mtu, config.tcp_timeout) { Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))), Err(e) => { - if matches!(e, IpStackError::InvalidTcpPacket) { - trace!("Invalid TCP packet"); + if matches!(e, IpStackError::InvalidTcpPacket(_)) { + log::debug!("{e}"); } else { error!("IpStackTcpStream::new failed \"{}\"", e); } @@ -251,7 +231,9 @@ where D: AsyncWrite + Unpin + 'static, { if packet.ttl() == 0 { - sessions.remove(&packet.reverse_network_tuple()); + let network_tuple = packet.reverse_network_tuple(); + sessions.remove(&network_tuple); + log::trace!("session removed: {}", network_tuple); return Ok(()); } #[allow(unused_mut)] diff --git a/src/packet.rs b/src/packet.rs index 540d4ea..6e6fa60 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -8,6 +8,14 @@ pub struct NetworkTuple { pub dst: SocketAddr, pub tcp: bool, } + +impl std::fmt::Display for NetworkTuple { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let tcp = if self.tcp { "TCP" } else { "UDP" }; + write!(f, "{} {} -> {}", tcp, self.src, self.dst) + } +} + pub mod tcp_flags { pub const CWR: u8 = 0b10000000; pub const ECE: u8 = 0b01000000; @@ -53,32 +61,18 @@ impl NetworkPacket { let ip = p.net.ok_or(IpStackError::InvalidPacket)?; let (ip, ip_payload) = match ip { - NetSlice::Ipv4(ip) => ( - IpHeader::Ipv4(ip.header().to_header()), - ip.payload().payload, - ), - NetSlice::Ipv6(ip) => ( - IpHeader::Ipv6(ip.header().to_header()), - ip.payload().payload, - ), + NetSlice::Ipv4(ip) => (IpHeader::Ipv4(ip.header().to_header()), ip.payload().payload), + NetSlice::Ipv6(ip) => (IpHeader::Ipv6(ip.header().to_header()), ip.payload().payload), NetSlice::Arp(_) => return Err(IpStackError::UnsupportedTransportProtocol), }; let (transport, payload) = match p.transport { - Some(etherparse::TransportSlice::Tcp(h)) => { - (TransportHeader::Tcp(h.to_header()), h.payload()) - } - Some(etherparse::TransportSlice::Udp(u)) => { - (TransportHeader::Udp(u.to_header()), u.payload()) - } + Some(etherparse::TransportSlice::Tcp(h)) => (TransportHeader::Tcp(h.to_header()), h.payload()), + Some(etherparse::TransportSlice::Udp(u)) => (TransportHeader::Udp(u.to_header()), u.payload()), _ => (TransportHeader::Unknown, ip_payload), }; let payload = payload.to_vec(); - Ok(NetworkPacket { - ip, - transport, - payload, - }) + Ok(NetworkPacket { ip, transport, payload }) } pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol { match self.transport { @@ -146,10 +140,49 @@ impl NetworkPacket { } #[derive(Debug, Clone)] -pub(super) struct TcpHeaderWrapper { +pub struct TcpHeaderWrapper { header: TcpHeader, } +impl std::fmt::Display for TcpHeaderWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut flags = String::new(); + if self.header.cwr { + flags.push_str("CWR "); + } + if self.header.ece { + flags.push_str("ECE "); + } + if self.header.urg { + flags.push_str("URG "); + } + if self.header.ack { + flags.push_str("ACK "); + } + if self.header.psh { + flags.push_str("PSH "); + } + if self.header.rst { + flags.push_str("RST "); + } + if self.header.syn { + flags.push_str("SYN "); + } + if self.header.fin { + flags.push_str("FIN "); + } + write!( + f, + "TcpHeader {{ src_port: {}, dst_port: {}, seq: {}, ack: {}, flags: {} }}", + self.header.source_port, + self.header.destination_port, + self.header.sequence_number, + self.header.acknowledgment_number, + flags.trim() + ) + } +} + impl TcpHeaderWrapper { pub fn inner(&self) -> &TcpHeader { &self.header @@ -188,9 +221,7 @@ impl TcpHeaderWrapper { impl From<&TcpHeader> for TcpHeaderWrapper { fn from(header: &TcpHeader) -> Self { - TcpHeaderWrapper { - header: header.clone(), - } + TcpHeaderWrapper { header: header.clone() } } } diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 42632f4..944e98a 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -22,9 +22,7 @@ impl IpStackStream { match self { IpStackStream::Tcp(tcp) => tcp.local_addr(), IpStackStream::Udp(udp) => udp.local_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) - } + IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), @@ -35,9 +33,7 @@ impl IpStackStream { match self { IpStackStream::Tcp(tcp) => tcp.peer_addr(), IpStackStream::Udp(udp) => udp.peer_addr(), - IpStackStream::UnknownNetwork(_) => { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) - } + IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)), IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() { IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)), IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)), diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index ec0a671..9aad3fe 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -71,15 +71,10 @@ impl Tcb { if seq < self.ack { return; } - self.unordered_packets - .insert(seq, UnorderedPacket::new(buf)); + self.unordered_packets.insert(seq, UnorderedPacket::new(buf)); } pub(super) fn get_available_read_buffer_size(&self) -> usize { - READ_BUFFER_SIZE.saturating_sub( - self.unordered_packets - .iter() - .fold(0, |acc, (_, p)| acc + p.payload.len()), - ) + READ_BUFFER_SIZE.saturating_sub(self.unordered_packets.iter().fold(0, |acc, (_, p)| acc + p.payload.len())) } pub(super) fn get_unordered_packets(&mut self) -> Option> { // dbg!(self.ack); @@ -110,8 +105,7 @@ impl Tcb { self.state } pub(super) fn change_send_window(&mut self, window: u16) { - let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) - / (self.avg_send_window.1 + 1); + let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) / (self.avg_send_window.1 + 1); self.avg_send_window.0 = avg_send_window; self.avg_send_window.1 += 1; self.send_window = window; @@ -147,8 +141,7 @@ impl Tcb { let current_ack_distance = self.seq.wrapping_sub(self.last_ack); if received_ack_distance > current_ack_distance - || (tcp_header.acknowledgment_number != self.seq - && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) + || (tcp_header.acknowledgment_number != self.seq && self.seq.saturating_sub(tcp_header.acknowledgment_number) == 0) { PacketStatus::Invalid } else if self.last_ack == tcp_header.acknowledgment_number { diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 77e2547..bc5b7e8 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -83,7 +83,7 @@ impl IpStackTcpStream { warn!("Error sending RST/ACK packet: {:?}", err); } } - Err(IpStackError::InvalidTcpPacket) + Err(IpStackError::InvalidTcpPacket(tcp.clone())) } fn calculate_payload_len(&self, ip_header_size: u16, tcp_header_size: u16) -> u16 { @@ -93,13 +93,7 @@ impl IpStackTcpStream { ) } - fn create_rev_packet( - &self, - flags: u8, - ttl: u8, - seq: impl Into>, - mut payload: Vec, - ) -> Result { + fn create_rev_packet(&self, flags: u8, ttl: u8, seq: impl Into>, mut payload: Vec) -> Result { let mut tcp_header = etherparse::TcpHeader::new( self.dst_addr.port(), self.src_addr.port(), @@ -116,12 +110,8 @@ impl IpStackTcpStream { let ip_header = match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; - let payload_len = self.calculate_payload_len( - ip_h.header_len() as u16, - tcp_header.header_len() as u16, - ); + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::TCP, dst.octets(), src.octets()).map_err(IpStackError::from)?; + let payload_len = self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); ip_h.set_payload_len(payload.len() + tcp_header.header_len()) .map_err(IpStackError::from)?; @@ -138,10 +128,7 @@ impl IpStackTcpStream { source: dst.octets(), destination: src.octets(), }; - let payload_len = self.calculate_payload_len( - ip_h.header_len() as u16, - tcp_header.header_len() as u16, - ); + let payload_len = self.calculate_payload_len(ip_h.header_len() as u16, tcp_header.header_len() as u16); payload.truncate(payload_len as usize); let len = payload.len() + tcp_header.header_len(); ip_h.set_payload_length(len).map_err(IpStackError::from)?; @@ -172,11 +159,7 @@ impl IpStackTcpStream { } impl AsyncRead for IpStackTcpStream { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { + fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { if self.tcb.retransmission.is_some() { self.write_notify = Some(cx.waker().clone()); @@ -186,9 +169,7 @@ impl AsyncRead for IpStackTcpStream { } if let Some(packet) = self.packet_to_send.take() { - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; + self.packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; } if self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); @@ -196,8 +177,7 @@ impl AsyncRead for IpStackTcpStream { } if self.tcb.get_state() == TcpState::FinWait2(false) { - self.packet_to_send = - Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionAborted))); @@ -218,18 +198,13 @@ impl AsyncRead for IpStackTcpStream { self.tcb.reset_timeout(); if self.tcb.get_state() == TcpState::SynReceived(false) { - self.packet_to_send = - Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::SynReceived(true)); continue; } - if let Some(b) = self - .tcb - .get_unordered_packets() - .filter(|_| matches!(self.shutdown, Shutdown::None)) - { + if let Some(b) = self.tcb.get_unordered_packets().filter(|_| matches!(self.shutdown, Shutdown::None)) { self.tcb.add_ack(b.len() as u32); buf.put_slice(&b); self.packet_sender @@ -238,8 +213,7 @@ impl AsyncRead for IpStackTcpStream { return Poll::Ready(Ok(())); } if self.tcb.get_state() == TcpState::FinWait1(true) { - self.packet_to_send = - Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.add_ack(1); self.tcb.change_state(TcpState::FinWait2(true)); @@ -248,8 +222,7 @@ impl AsyncRead for IpStackTcpStream { && self.tcb.get_state() == TcpState::Established && self.tcb.get_last_ack() == self.tcb.get_seq() { - self.packet_to_send = - Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait1(false)); continue; @@ -260,8 +233,7 @@ impl AsyncRead for IpStackTcpStream { unreachable!() }; if t.flags() & RST != 0 { - self.packet_to_send = - Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset))); @@ -291,8 +263,7 @@ impl AsyncRead for IpStackTcpStream { PacketStatus::KeepAlive => { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_send_window(t.inner().window_size); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); continue; } PacketStatus::RetransmissionRequest => { @@ -316,8 +287,7 @@ impl AsyncRead for IpStackTcpStream { // } self.tcb.change_last_ack(t.inner().acknowledgment_number); - self.tcb - .add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.add_unordered_packet(t.inner().sequence_number, p.payload); self.tcb.change_send_window(t.inner().window_size); if let Some(ref n) = self.write_notify { @@ -339,30 +309,23 @@ impl AsyncRead for IpStackTcpStream { } if t.flags() == (FIN | ACK) { self.tcb.add_ack(1); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait1(true)); continue; } if t.flags() == (PSH | ACK) { - if !matches!( - self.tcb.check_pkt_type(&t, &p.payload), - PacketStatus::NewPacket - ) { + if !matches!(self.tcb.check_pkt_type(&t, &p.payload), PacketStatus::NewPacket) { continue; } self.tcb.change_last_ack(t.inner().acknowledgment_number); - if p.payload.is_empty() - || self.tcb.get_ack() != t.inner().sequence_number - { + if p.payload.is_empty() || self.tcb.get_ack() != t.inner().sequence_number { continue; } self.tcb.change_send_window(t.inner().window_size); - self.tcb - .add_unordered_packet(t.inner().sequence_number, p.payload); + self.tcb.add_unordered_packet(t.inner().sequence_number, p.payload); continue; } } else if self.tcb.get_state() == TcpState::FinWait1(false) { @@ -373,8 +336,7 @@ impl AsyncRead for IpStackTcpStream { continue; } else if t.flags() == (FIN | ACK) { self.tcb.add_ack(1); - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::FinWait2(true)); continue; @@ -383,8 +345,7 @@ impl AsyncRead for IpStackTcpStream { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { - self.packet_to_send = - Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); + self.packet_to_send = Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?); self.tcb.change_state(TcpState::FinWait2(false)); } } @@ -397,19 +358,13 @@ impl AsyncRead for IpStackTcpStream { } impl AsyncWrite for IpStackTcpStream { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } self.tcb.reset_timeout(); - if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 - || self.tcb.is_send_buffer_full() - { + if (self.tcb.get_send_window() as u64) < self.tcb.get_avg_send_window() / 2 || self.tcb.is_send_buffer_full() { self.write_notify = Some(cx.waker().clone()); return Poll::Pending; } @@ -425,30 +380,22 @@ impl AsyncWrite for IpStackTcpStream { let seq = self.tcb.get_seq(); let payload_len = packet.payload.len(); let payload = packet.payload.clone(); - self.packet_sender - .send(packet) - .or(Err(ErrorKind::UnexpectedEof))?; + self.packet_sender.send(packet).or(Err(ErrorKind::UnexpectedEof))?; self.tcb.add_inflight_packet(seq, payload); Poll::Ready(Ok(payload_len)) } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(mut self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } if let Some(s) = self.tcb.retransmission.take() { if let Some(packet) = self.tcb.inflight_packets.iter().find(|p| p.seq == s) { - let rev_packet = - self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; + let rev_packet = self.create_rev_packet(PSH | ACK, TTL, packet.seq, packet.payload.clone())?; - self.packet_sender - .send(rev_packet) - .or(Err(ErrorKind::UnexpectedEof))?; + self.packet_sender.send(rev_packet).or(Err(ErrorKind::UnexpectedEof))?; } else { error!("Packet {} not found in inflight_packets", s); error!("seq: {}", self.tcb.get_seq()); @@ -465,19 +412,13 @@ impl AsyncWrite for IpStackTcpStream { Poll::Ready(Ok(())) } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if matches!(self.shutdown, Shutdown::Ready) { return Poll::Ready(Ok(())); } else if matches!(self.shutdown, Shutdown::None) { self.shutdown.pending(cx.waker().clone()); } - self.poll_read( - cx, - &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()]), - ) + self.poll_read(cx, &mut tokio::io::ReadBuf::uninit(&mut [MaybeUninit::::uninit()])) } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index e6653b9..a3695f5 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -23,20 +23,13 @@ impl IpStackTcpStream { tcp_timeout: Duration, ) -> Result { let (stream_sender, stream_receiver) = mpsc::unbounded_channel::(); - IpStackTcpStreamInner::new( - local_addr, - peer_addr, - tcp, - pkt_sender, - stream_receiver, - mtu, - tcp_timeout, - ) - .map(|inner| IpStackTcpStream { - inner: Some(Box::new(inner)), - peer_addr, - local_addr, - stream_sender, + IpStackTcpStreamInner::new(local_addr, peer_addr, tcp, pkt_sender, stream_receiver, mtu, tcp_timeout).map(|inner| { + IpStackTcpStream { + inner: Some(Box::new(inner)), + peer_addr, + local_addr, + stream_sender, + } }) } pub fn local_addr(&self) -> SocketAddr { @@ -58,9 +51,7 @@ impl tokio::io::AsyncRead for IpStackTcpStream { ) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } } @@ -73,31 +64,19 @@ impl tokio::io::AsyncWrite for IpStackTcpStream { ) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_flush(cx), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { match self.inner.as_mut() { Some(mut inner) => Pin::new(&mut inner).poll_shutdown(cx), - None => { - std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))) - } + None => std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected))), } } } @@ -105,10 +84,13 @@ impl tokio::io::AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Some(mut inner) = self.inner.take() { + let local_addr = self.local_addr(); + let peer_addr = self.peer_addr(); tokio::spawn(async move { if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await { log::warn!("Error while dropping IpStackTcpStream: {:?}", err); } + log::trace!("TCP Stream closed: {} -> {}", local_addr, peer_addr); }); } } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index ad8086c..4a8bce1 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -55,19 +55,12 @@ impl IpStackUdpStream { const UHS: usize = 8; // udp header size is 8 match (self.dst_addr.ip(), self.src_addr.ip()) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()) - .map_err(IpStackError::from)?; + let mut ip_h = Ipv4Header::new(0, ttl, IpNumber::UDP, dst.octets(), src.octets()).map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub((ip_h.header_len() + UHS) as u16); payload.truncate(line_buffer as usize); - ip_h.set_payload_len(payload.len() + UHS) + ip_h.set_payload_len(payload.len() + UHS).map_err(IpStackError::from)?; + let udp_header = UdpHeader::with_ipv4_checksum(self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload) .map_err(IpStackError::from)?; - let udp_header = UdpHeader::with_ipv4_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv4(ip_h), transport: TransportHeader::Udp(udp_header), @@ -89,13 +82,8 @@ impl IpStackUdpStream { payload.truncate(line_buffer as usize); ip_h.payload_length = (payload.len() + UHS) as u16; - let udp_header = UdpHeader::with_ipv6_checksum( - self.dst_addr.port(), - self.src_addr.port(), - &ip_h, - &payload, - ) - .map_err(IpStackError::from)?; + let udp_header = UdpHeader::with_ipv6_checksum(self.dst_addr.port(), self.src_addr.port(), &ip_h, &payload) + .map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Udp(udp_header), @@ -148,31 +136,19 @@ impl AsyncRead for IpStackUdpStream { } impl AsyncWrite for IpStackUdpStream { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { + fn poll_write(mut self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>, buf: &[u8]) -> std::task::Poll> { self.reset_timeout(); let packet = self.create_rev_packet(TTL, buf.to_vec())?; let payload_len = packet.payload.len(); - self.pkt_sender - .send(packet) - .or(Err(std::io::ErrorKind::UnexpectedEof))?; + self.pkt_sender.send(packet).or(Err(std::io::ErrorKind::UnexpectedEof))?; std::task::Poll::Ready(Ok(payload_len)) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { std::task::Poll::Ready(Ok(())) } } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 838d93f..173dfce 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,9 +1,9 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - PacketSender, TTL, + IpStackError, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header}; -use std::{io::Error, mem, net::IpAddr}; +use std::net::IpAddr; pub struct IpStackUnknownTransport { src_addr: IpAddr, @@ -15,14 +15,7 @@ pub struct IpStackUnknownTransport { } impl IpStackUnknownTransport { - pub(crate) fn new( - src_addr: IpAddr, - dst_addr: IpAddr, - payload: Vec, - ip: &IpHeader, - mtu: u16, - packet_sender: PacketSender, - ) -> Self { + pub(crate) fn new(src_addr: IpAddr, dst_addr: IpAddr, payload: Vec, ip: &IpHeader, mtu: u16, packet_sender: PacketSender) -> Self { let protocol = match ip { IpHeader::Ipv4(ip) => ip.protocol, IpHeader::Ipv6(ip) => ip.next_header, @@ -48,32 +41,30 @@ impl IpStackUnknownTransport { pub fn ip_protocol(&self) -> IpNumber { self.protocol } - pub fn send(&self, mut payload: Vec) -> Result<(), Error> { + pub fn send(&self, mut payload: Vec) -> std::io::Result<()> { loop { let packet = self.create_rev_packet(&mut payload)?; self.packet_sender .send(packet) - .map_err(|_| Error::new(std::io::ErrorKind::Other, "send error"))?; + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("send error: {}", e)))?; if payload.is_empty() { return Ok(()); } } } - pub fn create_rev_packet(&self, payload: &mut Vec) -> Result { + pub fn create_rev_packet(&self, payload: &mut Vec) -> std::io::Result { match (self.dst_addr, self.src_addr) { (std::net::IpAddr::V4(dst), std::net::IpAddr::V4(src)) => { - let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()) - .map_err(crate::IpStackError::from)?; + let mut ip_h = Ipv4Header::new(0, TTL, self.protocol, dst.octets(), src.octets()).map_err(IpStackError::from)?; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); let p = if payload.len() > line_buffer as usize { payload.drain(0..line_buffer as usize).collect::>() } else { - mem::take(payload) + std::mem::take(payload) }; - ip_h.set_payload_len(p.len()) - .map_err(crate::IpStackError::from)?; + ip_h.set_payload_len(p.len()).map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv4(ip_h), transport: TransportHeader::Unknown, @@ -91,13 +82,12 @@ impl IpStackUnknownTransport { destination: src.octets(), }; let line_buffer = self.mtu.saturating_sub(ip_h.header_len() as u16); - payload.truncate(line_buffer as usize); - ip_h.payload_length = payload.len() as u16; let p = if payload.len() > line_buffer as usize { payload.drain(0..line_buffer as usize).collect::>() } else { - mem::take(payload) + std::mem::take(payload) }; + ip_h.set_payload_length(p.len()).map_err(IpStackError::from)?; Ok(NetworkPacket { ip: IpHeader::Ipv6(ip_h), transport: TransportHeader::Unknown,