Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine code #59

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ udp-stream = { version = "0.0", default-features = false }
# Benchmarks
criterion = { version = "0.5" }

[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
tun = { version = "0.7.13", features = ["async"], default-features = false }

[target.'cfg(target_os = "windows")'.dev-dependencies]
wintun = { version = "0.5", default-features = false }

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ An asynchronous lightweight userspace implementation of TCP/IP stack for Tun dev
Unstable, under development.

[![Crates.io](https://img.shields.io/crates/v/ipstack.svg)](https://crates.io/crates/ipstack)
![ipstack](https://docs.rs/ipstack/badge.svg)
[![ipstack](https://docs.rs/ipstack/badge.svg)](https://docs.rs/ipstack)
[![Documentation](https://img.shields.io/badge/docs-release-brightgreen.svg?style=flat)](https://docs.rs/ipstack)
[![Download](https://img.shields.io/crates/d/ipstack.svg)](https://crates.io/crates/ipstack)
[![License](https://img.shields.io/crates/l/ipstack.svg?style=flat)](https://github.com/narrowlink/ipstack/blob/main/LICENSE)
Expand Down Expand Up @@ -86,4 +86,4 @@ async fn main() {
}
```

We also suggest that you take a look at the complete [examples](examples).
We also suggest that you take a look at the complete [examples](./examples).
8 changes: 2 additions & 6 deletions examples/tun2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
//!
use clap::Parser;
use etherparse::{IcmpEchoHeader, Icmpv4Header};
use etherparse::Icmpv4Header;
use ipstack::{stream::IpStackStream, IpNumber};
use std::net::{Ipv4Addr, SocketAddr};
use tokio::net::TcpStream;
Expand Down Expand Up @@ -154,12 +154,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let n = number;
if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP {
let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?;
if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type {
if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type {
log::info!("#{n} ICMPv4 echo");
let echo = IcmpEchoHeader {
id: req.id,
seq: req.seq,
};
let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo));
resp.update_checksum(req_payload);
let mut payload = resp.to_bytes().to_vec();
Expand Down
23 changes: 5 additions & 18 deletions examples/tun_wintun.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::net::{Ipv4Addr, SocketAddr};

use clap::Parser;
use etherparse::{IcmpEchoHeader, Icmpv4Header};
use etherparse::Icmpv4Header;
use ipstack::{stream::IpStackStream, IpNumber};
use tokio::net::TcpStream;
use udp_stream::UdpStream;
Expand Down Expand Up @@ -46,10 +46,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut ip_stack = ipstack::IpStack::new(ipstack_config, tun::create_as_async(&config)?);

#[cfg(target_os = "windows")]
let mut ip_stack = ipstack::IpStack::new(
ipstack_config,
wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0)),
);
let mut ip_stack = ipstack::IpStack::new(ipstack_config, wintun::WinTunDevice::new(ipv4, Ipv4Addr::new(255, 255, 255, 0)));

let server_addr = args.server_addr;

Expand Down Expand Up @@ -86,12 +83,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
IpStackStream::UnknownTransport(u) => {
if u.src_addr().is_ipv4() && u.ip_protocol() == IpNumber::ICMP {
let (icmp_header, req_payload) = Icmpv4Header::from_slice(u.payload())?;
if let etherparse::Icmpv4Type::EchoRequest(req) = icmp_header.icmp_type {
if let etherparse::Icmpv4Type::EchoRequest(echo) = icmp_header.icmp_type {
println!("ICMPv4 echo");
let echo = IcmpEchoHeader {
id: req.id,
seq: req.seq,
};
let mut resp = Icmpv4Header::new(etherparse::Icmpv4Type::EchoReply(echo));
resp.update_checksum(req_payload);
let mut payload = resp.to_bytes().to_vec();
Expand Down Expand Up @@ -178,17 +171,11 @@ mod wintun {
std::task::Poll::Ready(Ok(buf.len()))
}

fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
fn poll_flush(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}

fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
fn poll_shutdown(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
Expand Down
1 change: 1 addition & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
max_width = 140
4 changes: 2 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub enum IpStackError {
#[error("ValueTooBigError<usize> {0}")]
ValueTooBigErrorUsize(#[from] etherparse::err::ValueTooBigError<usize>),

#[error("Invalid Tcp packet")]
InvalidTcpPacket,
#[error("Invalid Tcp packet {0}")]
InvalidTcpPacket(crate::packet::TcpHeaderWrapper),

#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
Expand Down
80 changes: 31 additions & 49 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ mod packet;
pub mod stream;

pub use self::error::{IpStackError, Result};
pub use etherparse::IpNumber;
pub use self::packet::TcpHeaderWrapper;
pub use ::etherparse::IpNumber;

const DROP_TTL: u8 = 0;

Expand Down Expand Up @@ -93,35 +94,27 @@ pub struct IpStack {
}

impl IpStack {
pub fn new<D>(config: IpStackConfig, device: D) -> IpStack
pub fn new<Device>(config: IpStackConfig, device: Device) -> IpStack
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
Device: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (accept_sender, accept_receiver) = mpsc::unbounded_channel::<IpStackStream>();
let handle = run(config, device, accept_sender);

IpStack {
accept_receiver,
handle,
handle: run(config, device, accept_sender),
}
}

pub async fn accept(&mut self) -> Result<IpStackStream, IpStackError> {
self.accept_receiver
.recv()
.await
.ok_or(IpStackError::AcceptError)
self.accept_receiver.recv().await.ok_or(IpStackError::AcceptError)
}
}

fn run<D>(
fn run<Device: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: IpStackConfig,
mut device: D,
mut device: Device,
accept_sender: UnboundedSender<IpStackStream>,
) -> JoinHandle<Result<()>>
where
D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
) -> JoinHandle<Result<()>> {
let mut sessions: SessionCollection = AHashMap::new();
let pi = config.packet_information;
let offset = if pi && cfg!(unix) { 4 } else { 0 };
Expand Down Expand Up @@ -167,56 +160,43 @@ fn process_device_read(
};

if let IpStackPacketProtocol::Unknown = packet.transport_protocol() {
return Some(IpStackStream::UnknownTransport(
IpStackUnknownTransport::new(
packet.src_addr().ip(),
packet.dst_addr().ip(),
packet.payload,
&packet.ip,
config.mtu,
pkt_sender,
),
));
return Some(IpStackStream::UnknownTransport(IpStackUnknownTransport::new(
packet.src_addr().ip(),
packet.dst_addr().ip(),
packet.payload,
&packet.ip,
config.mtu,
pkt_sender,
)));
}

match sessions.entry(packet.network_tuple()) {
Occupied(mut entry) => {
if let Err(e) = entry.get().send(packet) {
trace!("New stream because: {}", e);
create_stream(e.0, config, pkt_sender).map(|s| {
entry.insert(s.0);
s.1
log::debug!("New stream \"{}\" because: \"{}\"", e.0.network_tuple(), e);
create_stream(e.0, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| {
entry.insert(packet_sender);
ip_stack_stream
})
} else {
None
}
}
Vacant(entry) => create_stream(packet, config, pkt_sender).map(|s| {
entry.insert(s.0);
s.1
Vacant(entry) => create_stream(packet, config, pkt_sender).map(|(packet_sender, ip_stack_stream)| {
entry.insert(packet_sender);
ip_stack_stream
}),
}
}

fn create_stream(
packet: NetworkPacket,
config: &IpStackConfig,
pkt_sender: PacketSender,
) -> Option<(PacketSender, IpStackStream)> {
fn create_stream(packet: NetworkPacket, config: &IpStackConfig, pkt_sender: PacketSender) -> Option<(PacketSender, IpStackStream)> {
match packet.transport_protocol() {
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(
packet.src_addr(),
packet.dst_addr(),
h,
pkt_sender,
config.mtu,
config.tcp_timeout,
) {
match IpStackTcpStream::new(packet.src_addr(), packet.dst_addr(), h, pkt_sender, config.mtu, config.tcp_timeout) {
Ok(stream) => Some((stream.stream_sender(), IpStackStream::Tcp(stream))),
Err(e) => {
if matches!(e, IpStackError::InvalidTcpPacket) {
trace!("Invalid TCP packet");
if matches!(e, IpStackError::InvalidTcpPacket(_)) {
log::debug!("{e}");
} else {
error!("IpStackTcpStream::new failed \"{}\"", e);
}
Expand Down Expand Up @@ -251,7 +231,9 @@ where
D: AsyncWrite + Unpin + 'static,
{
if packet.ttl() == 0 {
sessions.remove(&packet.reverse_network_tuple());
let network_tuple = packet.reverse_network_tuple();
sessions.remove(&network_tuple);
log::trace!("session removed: {}", network_tuple);
return Ok(());
}
#[allow(unused_mut)]
Expand Down
77 changes: 54 additions & 23 deletions src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ pub struct NetworkTuple {
pub dst: SocketAddr,
pub tcp: bool,
}

impl std::fmt::Display for NetworkTuple {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tcp = if self.tcp { "TCP" } else { "UDP" };
write!(f, "{} {} -> {}", tcp, self.src, self.dst)
}
}

pub mod tcp_flags {
pub const CWR: u8 = 0b10000000;
pub const ECE: u8 = 0b01000000;
Expand Down Expand Up @@ -53,32 +61,18 @@ impl NetworkPacket {
let ip = p.net.ok_or(IpStackError::InvalidPacket)?;

let (ip, ip_payload) = match ip {
NetSlice::Ipv4(ip) => (
IpHeader::Ipv4(ip.header().to_header()),
ip.payload().payload,
),
NetSlice::Ipv6(ip) => (
IpHeader::Ipv6(ip.header().to_header()),
ip.payload().payload,
),
NetSlice::Ipv4(ip) => (IpHeader::Ipv4(ip.header().to_header()), ip.payload().payload),
NetSlice::Ipv6(ip) => (IpHeader::Ipv6(ip.header().to_header()), ip.payload().payload),
NetSlice::Arp(_) => return Err(IpStackError::UnsupportedTransportProtocol),
};
let (transport, payload) = match p.transport {
Some(etherparse::TransportSlice::Tcp(h)) => {
(TransportHeader::Tcp(h.to_header()), h.payload())
}
Some(etherparse::TransportSlice::Udp(u)) => {
(TransportHeader::Udp(u.to_header()), u.payload())
}
Some(etherparse::TransportSlice::Tcp(h)) => (TransportHeader::Tcp(h.to_header()), h.payload()),
Some(etherparse::TransportSlice::Udp(u)) => (TransportHeader::Udp(u.to_header()), u.payload()),
_ => (TransportHeader::Unknown, ip_payload),
};
let payload = payload.to_vec();

Ok(NetworkPacket {
ip,
transport,
payload,
})
Ok(NetworkPacket { ip, transport, payload })
}
pub(crate) fn transport_protocol(&self) -> IpStackPacketProtocol {
match self.transport {
Expand Down Expand Up @@ -146,10 +140,49 @@ impl NetworkPacket {
}

#[derive(Debug, Clone)]
pub(super) struct TcpHeaderWrapper {
pub struct TcpHeaderWrapper {
header: TcpHeader,
}

impl std::fmt::Display for TcpHeaderWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut flags = String::new();
if self.header.cwr {
flags.push_str("CWR ");
}
if self.header.ece {
flags.push_str("ECE ");
}
if self.header.urg {
flags.push_str("URG ");
}
if self.header.ack {
flags.push_str("ACK ");
}
if self.header.psh {
flags.push_str("PSH ");
}
if self.header.rst {
flags.push_str("RST ");
}
if self.header.syn {
flags.push_str("SYN ");
}
if self.header.fin {
flags.push_str("FIN ");
}
write!(
f,
"TcpHeader {{ src_port: {}, dst_port: {}, seq: {}, ack: {}, flags: {} }}",
self.header.source_port,
self.header.destination_port,
self.header.sequence_number,
self.header.acknowledgment_number,
flags.trim()
)
}
}

impl TcpHeaderWrapper {
pub fn inner(&self) -> &TcpHeader {
&self.header
Expand Down Expand Up @@ -188,9 +221,7 @@ impl TcpHeaderWrapper {

impl From<&TcpHeader> for TcpHeaderWrapper {
fn from(header: &TcpHeader) -> Self {
TcpHeaderWrapper {
header: header.clone(),
}
TcpHeaderWrapper { header: header.clone() }
}
}

Expand Down
8 changes: 2 additions & 6 deletions src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ impl IpStackStream {
match self {
IpStackStream::Tcp(tcp) => tcp.local_addr(),
IpStackStream::Udp(udp) => udp.local_addr(),
IpStackStream::UnknownNetwork(_) => {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0))
}
IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() {
IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
Expand All @@ -35,9 +33,7 @@ impl IpStackStream {
match self {
IpStackStream::Tcp(tcp) => tcp.peer_addr(),
IpStackStream::Udp(udp) => udp.peer_addr(),
IpStackStream::UnknownNetwork(_) => {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0))
}
IpStackStream::UnknownNetwork(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() {
IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
Expand Down
Loading
Loading