From 74593acec7ff06f3c43ef34ec36701f72523aee9 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Fri, 30 Aug 2024 20:54:40 +0800 Subject: [PATCH] async implement --- Cargo.toml | 12 +- README.md | 19 ++- examples/udp-echo-async.rs | 282 +++++++++++++++++++++++++++++++++++++ src/adapter.rs | 14 +- src/async_session.rs | 116 +++++++++++++++ src/handle.rs | 62 ++++++++ src/lib.rs | 7 + src/session.rs | 29 ++-- src/util.rs | 9 -- 9 files changed, 509 insertions(+), 41 deletions(-) create mode 100644 examples/udp-echo-async.rs create mode 100644 src/async_session.rs create mode 100644 src/handle.rs diff --git a/Cargo.toml b/Cargo.toml index 47b8e87..7b63dec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,12 +25,16 @@ targets = [ [features] default = [] -# default = ["verify_binary_signature", "panic_on_unsent_packets"] +# default = ["verify_binary_signature", "panic_on_unsent_packets", "async"] +async = ["async-task", "blocking", "futures"] panic_on_unsent_packets = [] verify_binary_signature = [] [dependencies] +async-task = { version = "4", optional = true } +blocking = { version = "1", optional = true } c2rust-bitfields = "0.18" +futures = { version = "0.3", optional = true } libloading = "0.8" log = "0.4" thiserror = "1" @@ -59,3 +63,9 @@ packet = "0.1" pcap-file = "2" serde_json = "1" subprocess = "0.2" +tokio = { version = "1", features = ["full"] } + +[[example]] +name = "udp-echo-async" +path = "examples/udp-echo-async.rs" +required-features = ["async"] diff --git a/README.md b/README.md index 38e53aa..6d0da17 100644 --- a/README.md +++ b/README.md @@ -78,10 +78,19 @@ wintun's internal ring buffer. - `verify_binary_signature`: Verifies the signature of the wintun dll file before loading it. -## TODO: -- Add async support -Requires hooking into a windows specific reactor and registering read interest on wintun's read -handle. Asyncify other slow operations via tokio::spawn_blocking. As always, PR's are welcome! - +- `async`: Enables async support for the library. + Just add `async` feature to your `Cargo.toml`: + ```toml + [dependencies] + wintun-bindings = { version = "0.1", features = ["async"] } + ``` + And simply transform your `Session` into an `AsyncSession`: + ```rust + // ... + let session = Arc::new(adapter.start_session(MAX_RING_CAPACITY)?); + let mut reader_session: AsyncSession = session.clone().try_into()?; + let mut writer_session: AsyncSession = session.clone().try_into()?; + // ... + ``` License: MIT diff --git a/examples/udp-echo-async.rs b/examples/udp-echo-async.rs new file mode 100644 index 0000000..52288ac --- /dev/null +++ b/examples/udp-echo-async.rs @@ -0,0 +1,282 @@ +//! This example demonstrates how to use Wintun to create a simple UDP echo server. +//! +//! You can see packets being received by wintun by runnig: `nc -u 10.28.13.100 4321` +//! and sending lines of text. + +use futures::{AsyncReadExt, AsyncWriteExt}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use tokio::sync::mpsc::channel; +use windows_sys::Win32::{ + Foundation::FALSE, + Security::Cryptography::{CryptAcquireContextW, CryptGenRandom, CryptReleaseContext, PROV_RSA_FULL}, +}; +use wintun_bindings::{ + get_active_network_interface_gateways, get_running_driver_version, get_wintun_bin_pattern_path, load_from_path, + run_command, Adapter, AsyncSession, BoxError, Error, MAX_RING_CAPACITY, +}; + +#[derive(Debug)] +struct NaiveUdpPacket { + src_addr: SocketAddr, + dst_addr: SocketAddr, + data: Vec, +} + +impl NaiveUdpPacket { + fn new(src_addr: SocketAddr, dst_addr: SocketAddr, data: &[u8]) -> Self { + Self { + src_addr, + dst_addr, + data: data.to_vec(), + } + } +} + +impl std::fmt::Display for NaiveUdpPacket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "src=\"{}\", dst=\"{}\", data length {}", + self.src_addr, + self.dst_addr, + self.data.len() + ) + } +} + +#[tokio::main] +async fn main() -> Result<(), BoxError> { + dotenvy::dotenv().ok(); + env_logger::init(); + // Loading wintun + let dll_path = get_wintun_bin_pattern_path()?; + let wintun = unsafe { load_from_path(dll_path)? }; + + let version = get_running_driver_version(&wintun); + println!("Wintun version: {:?}", version); + + let adapter_name = "Demo"; + let guid = 2131231231231231231_u128; + + // Open or create a new adapter + let adapter = match Adapter::open(&wintun, adapter_name) { + Ok(a) => a, + Err(_) => Adapter::create(&wintun, adapter_name, "MyTunnelType", Some(guid))?, + }; + + let version = get_running_driver_version(&wintun)?; + println!("Wintun version: {}", version); + + // set metric command: `netsh interface ipv4 set interface adapter_name metric=255` + let args = &["interface", "ipv4", "set", "interface", adapter_name, "metric=255"]; + run_command("netsh", args)?; + println!("netsh {}", args.join(" ")); + + // Execute the network card initialization command, setting virtual network card information + // ip = 10.28.13.2 mask = 255.255.255.0 gateway = 10.28.13.1 + // command: `netsh interface ipv4 set address adapter_name static 10.28.13.2/24 gateway=10.28.13.1` + let args = &[ + "interface", + "ipv4", + "set", + "address", + adapter_name, + "static", + "10.28.13.2/24", + "gateway=10.28.13.1", + ]; + run_command("netsh", args)?; + println!("netsh {}", args.join(" ")); + + let dns = "8.8.8.8".parse::().unwrap(); + let dns2 = "8.8.4.4".parse::().unwrap(); + adapter.set_dns_servers(&[dns, dns2])?; + + let v = adapter.get_addresses()?; + for addr in &v { + let mask = adapter.get_netmask_of_address(addr)?; + println!("address {} netmask: {}", addr, mask); + } + + let gateways = adapter.get_gateways()?; + println!("adapter gateways: {gateways:?}"); + + // adapter.set_name("MyNewName")?; + // println!("adapter name: {}", adapter.get_name()?); + + // adapter.set_address("10.28.13.2".parse()?)?; + + println!("adapter mtu: {}", adapter.get_mtu()?); + + println!( + "active adapter gateways: {:?}", + get_active_network_interface_gateways()? + ); + + let session = Arc::new(adapter.start_session(MAX_RING_CAPACITY)?); + + let mut reader_session: AsyncSession = session.clone().try_into()?; + let mut writer_session: AsyncSession = session.clone().try_into()?; + + let (tx, mut rx) = channel::(1000); + + // Global flag to stop the session + static RUNNING: AtomicBool = AtomicBool::new(true); + + let reader = tokio::task::spawn(async move { + let block = async { + while RUNNING.load(Ordering::Relaxed) { + let mut bytes = [0u8; 1500]; + + // recieved IP packet + let len = reader_session.read(&mut bytes).await?; + if len == 0 { + break; + } + + let udp_packet = extract_udp_packet(&bytes[..len]); + if let Err(err) = udp_packet { + println!("{}", err); + continue; + } + + // swap src and dst + let mut udp_packet = udp_packet?; + let src_addr = udp_packet.src_addr; + let dst_addr = udp_packet.dst_addr; + udp_packet.src_addr = dst_addr; + udp_packet.dst_addr = src_addr; + + // send to writer + tx.send(udp_packet).await?; + } + Ok::<(), BoxError>(()) + }; + if let Err(err) = block.await { + println!("Reader {}", err); + } + }); + + let writer = tokio::task::spawn(async move { + let block = async { + while RUNNING.load(Ordering::Relaxed) { + let resp = rx.recv().await.ok_or("Channel closed")?; + + let src_addr = match resp.src_addr.ip() { + IpAddr::V4(addr) => addr, + IpAddr::V6(_) => return Err("IPv6 addresses are not supported".into()), + }; + + let dst_addr = match resp.dst_addr.ip() { + IpAddr::V4(addr) => addr, + IpAddr::V6(_) => return Err("IPv6 addresses are not supported".into()), + }; + + let v = generate_random_bytes(2)?; + let id = u16::from_ne_bytes([v[0], v[1]]); + + // build response IP packet + use packet::Builder; + let ip_packet = packet::ip::v4::Builder::default() + .id(id)? + .ttl(64)? + .source(src_addr)? + .destination(dst_addr)? + .udp()? + .source(resp.src_addr.port())? + .destination(resp.dst_addr.port())? + .payload(&resp.data)? + .build()?; + + // // The following code will be better than above, the `ipv4_udp_build` function link is + // // + // // https://github.com/pysrc/study-udp/blob/59d7ba210a022d207c60ad5370de37110fefaefb/src/protocol.rs#L157-L252 + // // + // let mut ip_packet = vec![0u8; 28 + resp.data.len()]; + // protocol::ipv4_udp_build( + // &mut ip_packet, + // &src_addr.octets(), + // resp.src_addr.port(), + // &dst_addr.octets(), + // resp.dst_addr.port(), + // &resp.data, + // ); + + writer_session.write_all(&ip_packet).await?; + } + Ok::<(), BoxError>(()) + }; + if let Err(err) = block.await { + println!("Writer {}", err); + } + }); + + println!("Press enter to stop session"); + + let mut line = String::new(); + let _ = std::io::stdin().read_line(&mut line); + println!("Shutting down session"); + RUNNING.store(false, Ordering::Relaxed); + session.shutdown()?; + let _ = reader.await; + let _ = writer.await; + Ok(()) +} + +fn extract_udp_packet(packet: &[u8]) -> Result { + use packet::{ip, udp, AsPacket, Packet}; + let packet: ip::Packet<_> = packet.as_packet().map_err(|err| format!("{}", err))?; + let info: String; + match packet { + ip::Packet::V4(a) => { + let src_addr = a.source(); + let dst_addr = a.destination(); + let protocol = a.protocol(); + let payload = a.payload(); + match protocol { + ip::Protocol::Udp => { + let udp = udp::Packet::new(payload).map_err(|err| format!("{}", err))?; + let src_port = udp.source(); + let dst_port = udp.destination(); + let src_addr = SocketAddr::new(src_addr.into(), src_port); + let dst_addr = SocketAddr::new(dst_addr.into(), dst_port); + let data = udp.payload(); + let udp_packet = NaiveUdpPacket::new(src_addr, dst_addr, data); + log::trace!("{protocol:?} {}", udp_packet); + return Ok(udp_packet); + } + _ => { + info = format!("{:?} src={}, dst={}", protocol, src_addr, dst_addr); + } + } + } + ip::Packet::V6(a) => { + info = format!("{:?}", a); + } + } + Err(info.into()) +} + +fn generate_random_bytes(len: usize) -> std::io::Result> { + let mut buf = vec![0u8; len]; + unsafe { + let mut h_prov = 0_usize; + let null = std::ptr::null_mut(); + if FALSE == CryptAcquireContextW(&mut h_prov, null, null, PROV_RSA_FULL, 0) { + return Err(std::io::Error::last_os_error()); + } + if FALSE == CryptGenRandom(h_prov, buf.len() as _, buf.as_mut_ptr()) { + return Err(std::io::Error::last_os_error()); + } + if FALSE == CryptReleaseContext(h_prov, 0) { + return Err(std::io::Error::last_os_error()); + } + }; + Ok(buf) +} diff --git a/src/adapter.rs b/src/adapter.rs index de9f8b2..4cb83fe 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -5,8 +5,9 @@ /// wintun functionality use crate::{ error::{Error, OutOfRangeData}, + handle::{SafeEvent, UnsafeHandle}, session, - util::{self, UnsafeHandle}, + util::{self}, wintun_raw, Wintun, }; use std::{ @@ -20,11 +21,7 @@ use std::{ }; use windows_sys::{ core::GUID, - Win32::{ - Foundation::FALSE, - NetworkManagement::{IpHelper::ConvertLengthToIpv4Mask, Ndis::NET_LUID_LH}, - System::Threading::CreateEventA, - }, + Win32::NetworkManagement::{IpHelper::ConvertLengthToIpv4Mask, Ndis::NET_LUID_LH}, }; /// Wrapper around a @@ -157,11 +154,12 @@ impl Adapter { if result.is_null() { return Err("WintunStartSession failed".into()); } - let shutdown_event = unsafe { CreateEventA(std::ptr::null_mut(), FALSE, FALSE, std::ptr::null_mut()) }; + // Manual reset, because we use this event once and it must fire on all threads + let shutdown_event = SafeEvent::new(true, false)?; Ok(session::Session { session: UnsafeHandle(result), read_event: OnceLock::new(), - shutdown_event: UnsafeHandle(shutdown_event), + shutdown_event: Arc::new(shutdown_event), adapter: self.clone(), }) } diff --git a/src/async_session.rs b/src/async_session.rs new file mode 100644 index 0000000..12792f9 --- /dev/null +++ b/src/async_session.rs @@ -0,0 +1,116 @@ +use crate::{handle::UnsafeHandle, Error}; +use futures::{AsyncRead, AsyncWrite}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use windows_sys::Win32::{ + Foundation::{FALSE, HANDLE, WAIT_ABANDONED_0, WAIT_EVENT, WAIT_OBJECT_0}, + System::Threading::{WaitForMultipleObjects, INFINITE}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum WaitingStopReason { + Shutdown, + Ready, +} + +#[derive(Debug)] +enum ReadState { + Waiting(Option>), + Idle, + Closed, +} + +pub struct AsyncSession { + session: Arc, + read_state: ReadState, +} + +impl TryFrom> for AsyncSession { + type Error = Error; + + fn try_from(session: Arc) -> Result { + Ok(Self { + session, + read_state: ReadState::Idle, + }) + } +} + +impl Drop for AsyncSession { + fn drop(&mut self) { + self.session.shutdown().ok(); + } +} + +impl AsyncSession { + fn wait_for_read(read_event: UnsafeHandle, shutdown_event: UnsafeHandle) -> WaitingStopReason { + const WAIT_OBJECT_1: WAIT_EVENT = WAIT_OBJECT_0 + 1; + const WAIT_ABANDONED_1: WAIT_EVENT = WAIT_ABANDONED_0 + 1; + let handles = [shutdown_event.0, read_event.0]; + match unsafe { WaitForMultipleObjects(handles.len() as u32, &handles as _, FALSE, INFINITE) } { + WAIT_OBJECT_0 | WAIT_ABANDONED_0 => WaitingStopReason::Shutdown, + WAIT_OBJECT_1 => WaitingStopReason::Ready, + WAIT_ABANDONED_1 => panic!("Read event deleted unexpectedly"), + e => panic!("WaitForMultipleObjects returned unexpected value {:?}", e), + } + } +} + +impl AsyncRead for AsyncSession { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + loop { + match &mut self.read_state { + ReadState::Idle => match self.session.try_receive() { + Ok(Some(packet)) => { + let size = packet.bytes.len().min(buf.len()); + buf[..size].copy_from_slice(&packet.bytes[..size]); + return Poll::Ready(Ok(size)); + } + Ok(None) => { + let read_event = self + .session + .get_read_wait_event() + .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?; + let shutdown_event = self.session.shutdown_event.get_handle(); + self.read_state = ReadState::Waiting(Some(blocking::unblock(move || { + Self::wait_for_read(read_event, shutdown_event) + }))); + } + Err(err) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err))), + }, + ReadState::Waiting(task) => { + let mut task = task.take().unwrap(); + self.read_state = match Pin::new(&mut task).poll(cx) { + Poll::Ready(WaitingStopReason::Shutdown) => ReadState::Closed, + Poll::Ready(WaitingStopReason::Ready) => ReadState::Idle, + Poll::Pending => ReadState::Waiting(Some(task)), + }; + if let ReadState::Waiting(_) = self.read_state { + return Poll::Pending; + } + } + ReadState::Closed => return Poll::Ready(Ok(0)), + } + } + } +} + +impl AsyncWrite for AsyncSession { + fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let packet = self.session.allocate_send_packet(buf.len() as _)?; + packet.bytes.copy_from_slice(buf); + self.session.send_packet(packet); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.session.shutdown()?; + Poll::Ready(Ok(())) + } +} diff --git a/src/handle.rs b/src/handle.rs new file mode 100644 index 0000000..cd4daf1 --- /dev/null +++ b/src/handle.rs @@ -0,0 +1,62 @@ +use windows_sys::Win32::{ + Foundation::{CloseHandle, FALSE, HANDLE}, + System::Threading::{CreateEventW, SetEvent}, +}; + +use crate::{util::get_last_error, Error}; + +/// A wrapper struct that allows a type to be Send and Sync +#[derive(Copy, Clone, Debug)] +pub struct UnsafeHandle(pub T); + +/// We never read from the pointer. It only serves as a handle we pass to the kernel or C code that +/// doesn't have the same mutable aliasing restrictions we have in Rust +unsafe impl Send for UnsafeHandle {} +unsafe impl Sync for UnsafeHandle {} + +#[derive(Debug)] +pub(crate) struct SafeEvent(pub UnsafeHandle); + +impl From> for SafeEvent { + fn from(handle: UnsafeHandle) -> Self { + Self(handle) + } +} + +impl SafeEvent { + pub(crate) fn new(manual_reset: bool, initial_state: bool) -> Result { + let null = std::ptr::null(); + let handle = unsafe { CreateEventW(null, manual_reset as _, initial_state as _, std::ptr::null()) }; + if handle.is_null() { + return Err(get_last_error()?.into()); + } + Ok(Self(UnsafeHandle(handle))) + } + + pub(crate) fn set_event(&self) -> Result<(), Error> { + if unsafe { SetEvent(self.0 .0) } == FALSE { + return Err(get_last_error()?.into()); + } + Ok(()) + } + + pub(crate) fn close_handle(&self) -> Result<(), Error> { + if !self.0 .0.is_null() && unsafe { CloseHandle(self.0 .0) } == FALSE { + return Err(get_last_error()?.into()); + } + Ok(()) + } + + #[allow(dead_code)] + pub(crate) fn get_handle(&self) -> UnsafeHandle { + self.0 + } +} + +impl Drop for SafeEvent { + fn drop(&mut self) { + if let Err(e) = self.close_handle() { + log::trace!("Failed to close event handle: {}", e); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 1836d51..f24c364 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,8 +76,11 @@ //! mod adapter; +#[cfg(feature = "async")] +mod async_session; mod error; mod ffi; +mod handle; mod log; mod packet; mod session; @@ -92,9 +95,13 @@ pub(crate) const WINTUN_PROVIDER: &str = "WireGuard LLC"; #[allow(dead_code, unused_variables, deref_nullptr, clippy::all)] mod wintun_raw; +#[cfg(feature = "async")] +pub use crate::async_session::AsyncSession; + pub use crate::{ adapter::Adapter, error::{BoxError, Error, OutOfRangeData, Result}, + handle::UnsafeHandle, log::{default_logger, reset_logger, set_logger}, packet::Packet, session::Session, diff --git a/src/session.rs b/src/session.rs index 694c7fb..b202813 100644 --- a/src/session.rs +++ b/src/session.rs @@ -1,14 +1,11 @@ use crate::{ - packet, - util::{self, UnsafeHandle}, - wintun_raw, Adapter, Error, Wintun, + handle::{SafeEvent, UnsafeHandle}, + packet, util, wintun_raw, Adapter, Error, Wintun, }; use std::{ptr, slice, sync::Arc, sync::OnceLock}; use windows_sys::Win32::{ - Foundation::{ - CloseHandle, GetLastError, ERROR_NO_MORE_ITEMS, FALSE, HANDLE, WAIT_EVENT, WAIT_FAILED, WAIT_OBJECT_0, - }, - System::Threading::{SetEvent, WaitForMultipleObjects, INFINITE}, + Foundation::{GetLastError, ERROR_NO_MORE_ITEMS, FALSE, HANDLE, WAIT_EVENT, WAIT_FAILED, WAIT_OBJECT_0}, + System::Threading::{WaitForMultipleObjects, INFINITE}, }; /// Wrapper around a @@ -22,7 +19,7 @@ pub struct Session { /// Windows event handle that is signaled when [`Session::shutdown`] is called force blocking /// readers to exit - pub(crate) shutdown_event: UnsafeHandle, + pub(crate) shutdown_event: Arc, /// The adapter that owns this session pub(crate) adapter: Arc, @@ -98,12 +95,11 @@ impl Session { /// # Safety /// Returns the low level read event handle that is signaled when more data becomes available /// to read - pub unsafe fn get_read_wait_event(&self) -> Result { + pub fn get_read_wait_event(&self) -> Result, Error> { let wintun = self.get_wintun(); - Ok(self + Ok(*self .read_event - .get_or_init(|| UnsafeHandle(wintun.WintunGetReadWaitEvent(self.session.0) as _)) - .0) + .get_or_init(|| UnsafeHandle(unsafe { wintun.WintunGetReadWaitEvent(self.session.0) as _ }))) } /// Blocks until a packet is available, returning the next packet in the receive queue once this happens. @@ -124,7 +120,7 @@ impl Session { } } //Wait on both the read handle and the shutdown handle so that we stop when requested - let handles = [unsafe { self.get_read_wait_event()? }, self.shutdown_event.0]; + let handles = [self.get_read_wait_event()?.0, self.shutdown_event.0 .0]; let result = unsafe { //SAFETY: We abide by the requirements of WaitForMultipleObjects, handles is a //pointer to valid, aligned, stack memory @@ -151,17 +147,14 @@ impl Session { /// Cancels any active calls to [`Session::receive_blocking`] making them instantly return Err(_) so that session can be shutdown cleanly pub fn shutdown(&self) -> Result<(), Error> { - if FALSE == unsafe { SetEvent(self.shutdown_event.0) } { - return Err(util::get_last_error()?.into()); - } + self.shutdown_event.set_event()?; Ok(()) } } impl Drop for Session { fn drop(&mut self) { - if FALSE == unsafe { CloseHandle(self.shutdown_event.0) } { - let err = util::get_last_error(); + if let Err(err) = self.shutdown_event.close_handle() { log::error!("Failed to close handle of shutdown event: {:?}", err); } diff --git a/src/util.rs b/src/util.rs index 3aad16c..f4f7801 100644 --- a/src/util.rs +++ b/src/util.rs @@ -75,15 +75,6 @@ pub(crate) unsafe fn win_pwstr_to_string(pwstr: ::windows_sys::core::PWSTR) -> R .map_err(|e| format!("Invalid UTF-8 sequence: {:?}", e).into()) } -/// A wrapper struct that allows a type to be Send and Sync -#[derive(Copy, Clone, Debug)] -pub(crate) struct UnsafeHandle(pub T); - -/// We never read from the pointer. It only serves as a handle we pass to the kernel or C code that -/// doesn't have the same mutable aliasing restrictions we have in Rust -unsafe impl Send for UnsafeHandle {} -unsafe impl Sync for UnsafeHandle {} - pub(crate) fn guid_to_win_style_string(guid: &GUID) -> Result { let mut buffer = [0u16; 40]; unsafe { StringFromGUID2(guid, &mut buffer as *mut u16, buffer.len() as i32) };