From 047aa6393e533082c1a0625891c9646debf3beac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 02:40:36 +0800 Subject: [PATCH 1/7] feat: add ProcessSocketNotifications support - add `iocp-psn` feature gate - a new Poller with ProcessSocketNotifications - wait for waitable objects with NtAssociateWaitCompletionPacket --- Cargo.toml | 4 + src/iocp/mod.rs | 1379 +-------------------------------- src/iocp/psn/mod.rs | 576 ++++++++++++++ src/iocp/psn/wait.rs | 93 +++ src/iocp/{ => wepoll}/afd.rs | 0 src/iocp/wepoll/mod.rs | 1375 ++++++++++++++++++++++++++++++++ src/iocp/{ => wepoll}/port.rs | 0 tests/multiple_pollers.rs | 3 + tests/other_modes.rs | 17 +- 9 files changed, 2075 insertions(+), 1372 deletions(-) create mode 100644 src/iocp/psn/mod.rs create mode 100644 src/iocp/psn/wait.rs rename src/iocp/{ => wepoll}/afd.rs (100%) create mode 100644 src/iocp/wepoll/mod.rs rename src/iocp/{ => wepoll}/port.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index e045958..a1d368a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,3 +61,7 @@ libc = "0.2" [target.'cfg(all(unix, not(target_os="vita")))'.dev-dependencies] signal-hook = "0.3.17" + +[features] +default = [] +iocp-psn = [] diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index 36609f4..4bcfddb 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -1,1321 +1,16 @@ -//! Bindings to Windows I/O Completion Ports. -//! -//! I/O Completion Ports is a completion-based API rather than a polling-based API, like -//! epoll or kqueue. Therefore, we have to adapt the IOCP API to the crate's API. -//! -//! WinSock is powered by the Auxillary Function Driver (AFD) subsystem, which can be -//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not -//! available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is -//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready", -//! a completion packet is queued to an I/O completion port. -//! -//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API -//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL -//! operation is started and queued to the IOCP. To modify a currently registered device -//! (e.g. with `modify()` or `delete()`), the ongoing POLL is cancelled and then restarted -//! with new parameters. Whn the POLL eventually completes, the packet is posted to the IOCP. -//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets -//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can -//! simply post a packet to the IOCP to wake it up. -//! -//! The main disadvantage of this strategy is that it relies on unstable Windows APIs. -//! However, as `libuv` (the backing I/O library for Node.JS) relies on the same unstable -//! AFD strategy, it is unlikely to be broken without plenty of advanced warning. -//! -//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar -//! AFD-based strategy for polling. - -mod afd; -mod port; - -use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock}; -use port::{IoCompletionPort, OverlappedEntry}; - -use windows_sys::Win32::Foundation::{ - BOOLEAN, ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED, -}; -use windows_sys::Win32::System::Threading::{ - RegisterWaitForSingleObject, UnregisterWait, INFINITE, WT_EXECUTELONGFUNCTION, - WT_EXECUTEONLYONCE, -}; - -use crate::{Event, PollMode}; - -use concurrent_queue::ConcurrentQueue; -use pin_project_lite::pin_project; - -use std::cell::UnsafeCell; -use std::collections::hash_map::{Entry, HashMap}; -use std::ffi::c_void; -use std::fmt; -use std::io; -use std::marker::PhantomPinned; -use std::mem::{forget, MaybeUninit}; -use std::os::windows::io::{ - AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket, -}; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; -use std::time::{Duration, Instant}; - -/// Macro to lock and ignore lock poisoning. -macro_rules! lock { - ($lock_result:expr) => {{ - $lock_result.unwrap_or_else(|e| e.into_inner()) - }}; -} - -/// Interface to I/O completion ports. -#[derive(Debug)] -pub(super) struct Poller { - /// The I/O completion port. - port: Arc>, - - /// List of currently active AFD instances. - /// - /// AFD acts as the actual source of the socket events. It's essentially running `WSAPoll` on - /// the sockets and then posting the events to the IOCP. - /// - /// AFD instances can be keyed to an unlimited number of sockets. However, each AFD instance - /// polls their sockets linearly. Therefore, it is best to limit the number of sockets each AFD - /// instance is responsible for. The limit of 32 is chosen because that's what `wepoll` uses. - /// - /// Weak references are kept here so that the AFD handle is automatically dropped when the last - /// associated socket is dropped. - afd: Mutex>>>, - - /// The state of the sources registered with this poller. - /// - /// Each source is keyed by its raw socket ID. - sources: RwLock>, - - /// The state of the waitable handles registered with this poller. - waitables: RwLock>, - - /// Sockets with pending updates. - /// - /// This list contains packets with sockets that need to have their AFD state adjusted by - /// calling the `update()` function on them. It's best to queue up packets as they need to - /// be updated and then run all of the updates before we start waiting on the IOCP, rather than - /// updating them as we come. If we're waiting on the IOCP updates should be run immediately. - pending_updates: ConcurrentQueue, - - /// Are we currently polling? - /// - /// This indicates whether or not we are blocking on the IOCP, and is used to determine - /// whether pending updates should be run immediately or queued. - polling: AtomicBool, - - /// The packet used to notify the poller. - /// - /// This is a special-case packet that is used to wake up the poller when it is waiting. - notifier: Packet, -} - -unsafe impl Send for Poller {} -unsafe impl Sync for Poller {} - -impl Poller { - /// Creates a new poller. - pub(super) fn new() -> io::Result { - // Make sure AFD is able to be used. - if let Err(e) = afd::NtdllImports::force_load() { - return Err(io::Error::new( - io::ErrorKind::Unsupported, - AfdError::new("failed to initialize unstable Windows functions", e), - )); - } - - // Create and destroy a single AFD to test if we support it. - Afd::::new().map_err(|e| { - io::Error::new( - io::ErrorKind::Unsupported, - AfdError::new("failed to initialize \\Device\\Afd", e), - ) - })?; - - let port = IoCompletionPort::new(0)?; - tracing::trace!(handle = ?port, "new"); - - Ok(Poller { - port: Arc::new(port), - afd: Mutex::new(vec![]), - sources: RwLock::new(HashMap::new()), - waitables: RwLock::new(HashMap::new()), - pending_updates: ConcurrentQueue::bounded(1024), - polling: AtomicBool::new(false), - notifier: Arc::pin( - PacketInner::Wakeup { - _pinned: PhantomPinned, - } - .into(), - ), - }) - } - - /// Whether this poller supports level-triggered events. - pub(super) fn supports_level(&self) -> bool { - true - } - - /// Whether this poller supports edge-triggered events. - pub(super) fn supports_edge(&self) -> bool { - false - } - - /// Add a new source to the poller. - /// - /// # Safety - /// - /// The socket must be a valid socket and must last until it is deleted. - pub(super) unsafe fn add( - &self, - socket: RawSocket, - interest: Event, - mode: PollMode, - ) -> io::Result<()> { - let span = tracing::trace_span!( - "add", - handle = ?self.port, - sock = ?socket, - ev = ?interest, - ); - let _enter = span.enter(); - - // We don't support edge-triggered events. - if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "edge-triggered events are not supported", - )); - } - - // Create a new packet. - let socket_state = { - // Create a new socket state and assign an AFD handle to it. - let state = SocketState { - socket, - base_socket: base_socket(socket)?, - interest, - interest_error: true, - afd: self.afd_handle()?, - mode, - waiting_on_delete: false, - status: SocketStatus::Idle, - }; - - // We wrap this socket state in a Packet so the IOCP can use it. - Arc::pin(IoStatusBlock::from(PacketInner::Socket { - packet: UnsafeCell::new(AfdPollInfo::default()), - socket: Mutex::new(state), - })) - }; - - // Keep track of the source in the poller. - { - let mut sources = lock!(self.sources.write()); - - match sources.entry(socket) { - Entry::Vacant(v) => { - v.insert(Pin::>::clone(&socket_state)); - } - - Entry::Occupied(_) => { - return Err(io::Error::from(io::ErrorKind::AlreadyExists)); - } - } - } - - // Update the packet. - self.update_packet(socket_state) - } - - /// Update a source in the poller. - pub(super) fn modify( - &self, - socket: BorrowedSocket<'_>, - interest: Event, - mode: PollMode, - ) -> io::Result<()> { - let span = tracing::trace_span!( - "modify", - handle = ?self.port, - sock = ?socket, - ev = ?interest, - ); - let _enter = span.enter(); - - // We don't support edge-triggered events. - if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "edge-triggered events are not supported", - )); - } - - // Get a reference to the source. - let source = { - let sources = lock!(self.sources.read()); - - sources - .get(&socket.as_raw_socket()) - .cloned() - .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? - }; - - // Set the new event. - if source.as_ref().set_events(interest, mode) { - // The packet needs to be updated. - self.update_packet(source)?; - } - - Ok(()) - } - - /// Delete a source from the poller. - pub(super) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> { - let span = tracing::trace_span!( - "remove", - handle = ?self.port, - sock = ?socket, - ); - let _enter = span.enter(); - - // Remove the source from our associative map. - let source = { - let mut sources = lock!(self.sources.write()); - - match sources.remove(&socket.as_raw_socket()) { - Some(s) => s, - None => { - // If the source has already been removed, then we can just return. - return Ok(()); - } - } - }; - - // Indicate to the source that it is being deleted. - // This cancels any ongoing AFD_IOCTL_POLL operations. - source.begin_delete() - } - - /// Add a new waitable to the poller. - pub(super) fn add_waitable( - &self, - handle: RawHandle, - interest: Event, - mode: PollMode, - ) -> io::Result<()> { - tracing::trace!( - "add_waitable: handle={:?}, waitable={:p}, ev={:?}", - self.port, - handle, - interest - ); - - // We don't support edge-triggered events. - if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "edge-triggered events are not supported", - )); - } - - // Create a new packet. - let handle_state = { - let state = WaitableState { - handle, - port: Arc::downgrade(&self.port), - interest, - mode, - status: WaitableStatus::Idle, - }; - - Arc::pin(IoStatusBlock::from(PacketInner::Waitable { - handle: Mutex::new(state), - })) - }; - - // Keep track of the source in the poller. - { - let mut sources = lock!(self.waitables.write()); - - match sources.entry(handle) { - Entry::Vacant(v) => { - v.insert(Pin::>::clone(&handle_state)); - } - - Entry::Occupied(_) => { - return Err(io::Error::from(io::ErrorKind::AlreadyExists)); - } - } - } - - // Update the packet. - self.update_packet(handle_state) - } - - /// Update a waitable in the poller. - pub(crate) fn modify_waitable( - &self, - waitable: RawHandle, - interest: Event, - mode: PollMode, - ) -> io::Result<()> { - tracing::trace!( - "modify_waitable: handle={:?}, waitable={:p}, ev={:?}", - self.port, - waitable, - interest - ); - - // We don't support edge-triggered events. - if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "edge-triggered events are not supported", - )); - } - - // Get a reference to the source. - let source = { - let sources = lock!(self.waitables.read()); - - sources - .get(&waitable) - .cloned() - .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? - }; - - // Set the new event. - if source.as_ref().set_events(interest, mode) { - self.update_packet(source)?; - } - - Ok(()) - } - - /// Delete a waitable from the poller. - pub(super) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> { - tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable); - - // Get a reference to the source. - let source = { - let mut sources = lock!(self.waitables.write()); - - match sources.remove(&waitable) { - Some(s) => s, - None => { - // If the source has already been removed, then we can just return. - return Ok(()); - } - } - }; - - // Indicate to the source that it is being deleted. - // This cancels any ongoing AFD_IOCTL_POLL operations. - source.begin_delete() - } - - /// Wait for events. - pub(super) fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { - let span = tracing::trace_span!( - "wait", - handle = ?self.port, - ?timeout, - ); - let _enter = span.enter(); - - // Make sure we have a consistent timeout. - let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout)); - let mut notified = false; - events.packets.clear(); - - loop { - let mut new_events = 0; - - // Indicate that we are now polling. - let was_polling = self.polling.swap(true, Ordering::SeqCst); - debug_assert!(!was_polling); - - // Even if we panic, we want to make sure we indicate that polling has stopped. - let guard = CallOnDrop(|| { - let was_polling = self.polling.swap(false, Ordering::SeqCst); - debug_assert!(was_polling); - }); - - // Process every entry in the queue before we start polling. - self.drain_update_queue(false)?; - - // Get the time to wait for. - let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now())); - - // Wait for I/O events. - let len = self.port.wait(&mut events.completions, timeout)?; - tracing::trace!( - handle = ?self.port, - res = ?len, - "new events"); - - // We are no longer polling. - drop(guard); - - // Process all of the events. - for entry in events.completions.drain(..) { - let packet = entry.into_packet(); - - // Feed the event into the packet. - match packet.feed_event(self)? { - FeedEventResult::NoEvent => {} - FeedEventResult::Event(event) => { - events.packets.push(event); - new_events += 1; - } - FeedEventResult::Notified => { - notified = true; - } - } - } - - // Break if there was a notification or at least one event, or if deadline is reached. - let timeout_is_empty = - timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0); - if notified || new_events > 0 || timeout_is_empty { - break; - } - - tracing::trace!("wait: no events found, re-entering polling loop"); - } - - Ok(()) - } - - /// Notify this poller. - pub(super) fn notify(&self) -> io::Result<()> { - // Push the notify packet into the IOCP. - self.port.post(0, 0, self.notifier.clone()) - } - - /// Push an IOCP packet into the queue. - pub(super) fn post(&self, packet: CompletionPacket) -> io::Result<()> { - self.port.post(0, 0, packet.0) - } - - /// Run an update on a packet. - fn update_packet(&self, mut packet: Packet) -> io::Result<()> { - loop { - // If we are currently polling, we need to update the packet immediately. - if self.polling.load(Ordering::Acquire) { - packet.update()?; - return Ok(()); - } - - // Try to queue the update. - match self.pending_updates.push(packet) { - Ok(()) => return Ok(()), - Err(p) => packet = p.into_inner(), - } - - // If we failed to queue the update, we need to drain the queue first. - self.drain_update_queue(true)?; - - // Loop back and try again. - } - } - - /// Drain the update queue. - fn drain_update_queue(&self, limit: bool) -> io::Result<()> { - // Determine how many packets to process. - let max = if limit { - // Only drain the queue's capacity, since this could in theory run forever. - self.pending_updates.capacity().unwrap() - } else { - // Less of a concern if we're draining the queue prior to a poll operation. - std::usize::MAX - }; - - self.pending_updates - .try_iter() - .take(max) - .try_for_each(|packet| packet.update()) - } - - /// Get a handle to the AFD reference. - /// - /// This finds an AFD handle with less than 32 associated sockets, or creates a new one if - /// one does not exist. - fn afd_handle(&self) -> io::Result>> { - const AFD_MAX_SIZE: usize = 32; - - // Crawl the list and see if there are any existing AFD instances that we can use. - // While we're here, remove any unused AFD pointers. - let mut afd_handles = lock!(self.afd.lock()); - let mut i = 0; - while i < afd_handles.len() { - // Get the reference count of the AFD instance. - let refcount = Weak::strong_count(&afd_handles[i]); - - match refcount { - 0 => { - // Prune the AFD pointer if it has no references. - afd_handles.swap_remove(i); - } - - refcount if refcount >= AFD_MAX_SIZE => { - // Skip this one, since it is already at the maximum size. - i += 1; - } - - _ => { - // We can use this AFD instance. - match afd_handles[i].upgrade() { - Some(afd) => return Ok(afd), - None => { - // The last socket dropped the AFD before we could acquire it. - // Prune the AFD pointer and continue. - afd_handles.swap_remove(i); - } - } - } - } - } - - // No available handles, create a new AFD instance. - let afd = Arc::new(Afd::new()?); - - // Register the AFD instance with the I/O completion port. - self.port.register(&*afd, true)?; - - // Insert a weak pointer to the AFD instance into the list for other sockets. - afd_handles.push(Arc::downgrade(&afd)); - - Ok(afd) - } -} - -impl AsRawHandle for Poller { - fn as_raw_handle(&self) -> RawHandle { - self.port.as_raw_handle() - } -} - -impl AsHandle for Poller { - fn as_handle(&self) -> BorrowedHandle<'_> { - unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) } - } -} - -/// The container for events. -pub(super) struct Events { - /// List of IOCP packets. - packets: Vec, - - /// Buffer for completion packets. - completions: Vec>, -} - -unsafe impl Send for Events {} - -impl Events { - /// Creates an empty list of events. - pub fn with_capacity(cap: usize) -> Events { - Events { - packets: Vec::with_capacity(cap), - completions: Vec::with_capacity(cap), - } - } - - /// Iterate over I/O events. - pub fn iter(&self) -> impl Iterator + '_ { - self.packets.iter().copied() - } - - /// Clear the list of events. - pub fn clear(&mut self) { - self.packets.clear(); - } - - /// The capacity of the list of events. - pub fn capacity(&self) -> usize { - self.packets.capacity() - } -} - -/// Extra information about an event. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct EventExtra { - /// Flags associated with this event. - flags: AfdPollMask, -} - -impl EventExtra { - /// Create a new, empty version of this struct. - #[inline] - pub const fn empty() -> EventExtra { - EventExtra { - flags: AfdPollMask::empty(), - } - } - - /// Is this a HUP event? - #[inline] - pub fn is_hup(&self) -> bool { - self.flags.intersects(AfdPollMask::ABORT) - } - - /// Is this a PRI event? - #[inline] - pub fn is_pri(&self) -> bool { - self.flags.intersects(AfdPollMask::RECEIVE_EXPEDITED) - } - - /// Set up a listener for HUP events. - #[inline] - pub fn set_hup(&mut self, active: bool) { - self.flags.set(AfdPollMask::ABORT, active); - } - - /// Set up a listener for PRI events. - #[inline] - pub fn set_pri(&mut self, active: bool) { - self.flags.set(AfdPollMask::RECEIVE_EXPEDITED, active); - } - - /// Check if TCP connect failed. Deprecated. - #[inline] - pub fn is_connect_failed(&self) -> Option { - Some(self.flags.intersects(AfdPollMask::CONNECT_FAIL)) - } - - /// Check if TCP connect failed. - #[inline] - pub fn is_err(&self) -> Option { - Some(self.flags.intersects(AfdPollMask::CONNECT_FAIL)) - } -} - -/// A packet used to wake up the poller with an event. -#[derive(Debug, Clone)] -pub struct CompletionPacket(Packet); - -impl CompletionPacket { - /// Create a new completion packet with a custom event. - pub fn new(event: Event) -> Self { - Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event }))) - } - - /// Get the event associated with this packet. - pub fn event(&self) -> &Event { - let data = self.0.as_ref().data().project_ref(); - - match data { - PacketInnerProj::Custom { event } => event, - _ => unreachable!(), - } - } -} - -/// The type of our completion packet. -/// -/// It needs to be pinned, since it contains data that is expected by IOCP not to be moved. -type Packet = Pin>; -type PacketUnwrapped = IoStatusBlock; - -pin_project! { - /// The inner type of the packet. - #[project_ref = PacketInnerProj] - #[project = PacketInnerProjMut] - enum PacketInner { - // A packet for a socket. - Socket { - // The AFD packet state. - #[pin] - packet: UnsafeCell, - - // The socket state. - socket: Mutex - }, - - /// A packet for a waitable handle. - Waitable { - handle: Mutex - }, - - /// A custom event sent by the user. - Custom { - event: Event, - }, - - // A packet used to wake up the poller. - Wakeup { #[pin] _pinned: PhantomPinned }, +cfg_if::cfg_if! { + if #[cfg(feature = "iocp-psn")] { + mod psn; + pub use psn::*; + } else { + mod wepoll; + pub use wepoll::*; } } -unsafe impl Send for PacketInner {} -unsafe impl Sync for PacketInner {} +use std::time::Duration; -impl fmt::Debug for PacketInner { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Wakeup { .. } => f.write_str("Wakeup { .. }"), - Self::Custom { event } => f.debug_struct("Custom").field("event", event).finish(), - Self::Socket { socket, .. } => f - .debug_struct("Socket") - .field("packet", &"..") - .field("socket", socket) - .finish(), - Self::Waitable { handle } => { - f.debug_struct("Waitable").field("handle", handle).finish() - } - } - } -} - -impl HasAfdInfo for PacketInner { - fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { - match self.project_ref() { - PacketInnerProj::Socket { packet, .. } => packet, - _ => unreachable!(), - } - } -} - -impl PacketUnwrapped { - /// Set the new events that this socket is waiting on. - /// - /// Returns `true` if we need to be updated. - fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool { - match self.data().project_ref() { - PacketInnerProj::Socket { socket, .. } => { - let mut socket = lock!(socket.lock()); - socket.interest = interest; - socket.mode = mode; - socket.interest_error = true; - - // If there was a change, indicate that we need an update. - match socket.status { - SocketStatus::Polling { flags } => { - let our_flags = event_to_afd_mask(socket.interest, socket.interest_error); - our_flags != flags - } - _ => true, - } - } - PacketInnerProj::Waitable { handle } => { - let mut handle = lock!(handle.lock()); - - // Set the new interest. - handle.interest = interest; - handle.mode = mode; - - // Update if there is no ongoing wait. - handle.status.is_idle() - } - _ => true, - } - } - - /// Update the socket and install the new status in AFD. - /// - /// This function does one of the following: - /// - /// - Nothing, if the packet is waiting on being dropped anyways. - /// - Cancels the ongoing poll, if we want to poll for different events than we are currently - /// polling for. - /// - Starts a new AFD_POLL operation, if we are not currently polling. - fn update(self: Pin>) -> io::Result<()> { - let mut socket = match self.as_ref().data().project_ref() { - PacketInnerProj::Socket { socket, .. } => lock!(socket.lock()), - PacketInnerProj::Waitable { handle } => { - let mut handle = lock!(handle.lock()); - - // If there is no interests, or if we have been cancelled, we don't need to update. - if !handle.interest.readable && !handle.interest.writable { - return Ok(()); - } - - // If we are idle, we need to update. - if !handle.status.is_idle() { - return Ok(()); - } - - // Start a new wait. - let packet = self.clone(); - let wait_handle = WaitHandle::new( - handle.handle, - move || { - let mut handle = match packet.as_ref().data().project_ref() { - PacketInnerProj::Waitable { handle } => lock!(handle.lock()), - _ => unreachable!(), - }; - - // Try to get the IOCP. - let iocp = match handle.port.upgrade() { - Some(iocp) => iocp, - None => return, - }; - - // Set us back into the idle state. - handle.status = WaitableStatus::Idle; - - // Push this packet. - drop(handle); - if let Err(e) = iocp.post(0, 0, packet) { - tracing::error!("failed to post completion packet: {}", e); - } - }, - None, - false, - )?; - - // Set the new status. - handle.status = WaitableStatus::Waiting(wait_handle); - - return Ok(()); - } - _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")), - }; - - // If we are waiting on a delete, just return, dropping the packet. - if socket.waiting_on_delete { - return Ok(()); - } - - // Check the current status. - match socket.status { - SocketStatus::Polling { flags } => { - // If we need to poll for events aside from what we are currently polling, we need - // to update the packet. Cancel the ongoing poll. - let our_flags = event_to_afd_mask(socket.interest, socket.interest_error); - if our_flags != flags { - return self.cancel(socket); - } - - // All events that we are currently waiting on are accounted for. - Ok(()) - } - - SocketStatus::Cancelled => { - // The ongoing operation was cancelled, and we're still waiting for it to return. - // For now, wait until the top-level loop calls feed_event(). - Ok(()) - } - - SocketStatus::Idle => { - // Start a new poll. - let mask = event_to_afd_mask(socket.interest, socket.interest_error); - let result = socket.afd.poll(self.clone(), socket.base_socket, mask); - - match result { - Ok(()) => {} - - Err(err) - if err.raw_os_error() == Some(ERROR_IO_PENDING as i32) - || err.kind() == io::ErrorKind::WouldBlock => - { - // The operation is pending. - } - - Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => { - // The socket was closed. We need to delete it. - // This should happen after we drop it here. - } - - Err(err) => return Err(err), - } - - // We are now polling for the current events. - socket.status = SocketStatus::Polling { flags: mask }; - - Ok(()) - } - } - } - - /// This socket state was notified; see if we need to update it. - /// - /// This indicates that this packet was indicated as "ready" by the IOCP and needs to be - /// processed. - fn feed_event(self: Pin>, poller: &Poller) -> io::Result { - let inner = self.as_ref().data().project_ref(); - - let (afd_info, socket) = match inner { - PacketInnerProj::Socket { packet, socket } => (packet, socket), - PacketInnerProj::Custom { event } => { - // This is a custom event. - return Ok(FeedEventResult::Event(*event)); - } - PacketInnerProj::Wakeup { .. } => { - // The poller was notified. - return Ok(FeedEventResult::Notified); - } - PacketInnerProj::Waitable { handle } => { - let mut handle = lock!(handle.lock()); - let event = handle.interest; - - // Clear the events if we are in one-shot mode. - if matches!(handle.mode, PollMode::Oneshot) { - handle.interest = Event::none(handle.interest.key); - } - - // Submit for an update. - drop(handle); - poller.update_packet(self)?; - - return Ok(FeedEventResult::Event(event)); - } - }; - - let mut socket_state = lock!(socket.lock()); - let mut event = Event::none(socket_state.interest.key); - - // Put ourselves into the idle state. - socket_state.status = SocketStatus::Idle; - - // If we are waiting to be deleted, just return and let the drop handler do their thing. - if socket_state.waiting_on_delete { - return Ok(FeedEventResult::NoEvent); - } - - unsafe { - // SAFETY: The packet is not in transit. - let iosb = &mut *self.as_ref().iosb().get(); - - // Check the status. - match iosb.Anonymous.Status { - STATUS_CANCELLED => { - // Poll request was cancelled. - } - - status if status < 0 => { - // There was an error, so we signal both ends. - event.readable = true; - event.writable = true; - } - - _ => { - // Check in on the AFD data. - let afd_data = &*afd_info.get(); - - // There was at least one event. - if afd_data.handle_count() >= 1 { - let events = afd_data.events(); - - // If we closed the socket, remove it from being polled. - if events.intersects(AfdPollMask::LOCAL_CLOSE) { - let source = lock!(poller.sources.write()) - .remove(&socket_state.socket) - .unwrap(); - return source.begin_delete().map(|()| FeedEventResult::NoEvent); - } - - // Report socket-related events. - let (readable, writable) = afd_mask_to_event(events); - event.readable = readable; - event.writable = writable; - event.extra.flags = events; - } - } - } - } - - // Filter out events that the user didn't ask for. - event.readable &= socket_state.interest.readable; - event.writable &= socket_state.interest.writable; - - // If this event doesn't have anything that interests us, don't return or - // update the oneshot state. - let return_value = if event.readable - || event.writable - || event - .extra - .flags - .intersects(socket_state.interest.extra.flags) - { - // If we are in oneshot mode, remove the interest. - if matches!(socket_state.mode, PollMode::Oneshot) { - socket_state.interest = Event::none(socket_state.interest.key); - socket_state.interest_error = false; - } - - FeedEventResult::Event(event) - } else { - FeedEventResult::NoEvent - }; - - // Put ourselves in the update queue. - drop(socket_state); - poller.update_packet(self)?; - - // Return the event. - Ok(return_value) - } - - /// Begin deleting this socket. - fn begin_delete(self: Pin>) -> io::Result<()> { - // If we aren't already being deleted, start deleting. - let mut socket = match self.as_ref().data().project_ref() { - PacketInnerProj::Socket { socket, .. } => lock!(socket.lock()), - PacketInnerProj::Waitable { handle } => { - let mut handle = lock!(handle.lock()); - - // Set the status to be cancelled. This drops the wait handle and prevents - // any further updates. - handle.status = WaitableStatus::Cancelled; - - return Ok(()); - } - _ => panic!("can't delete packet that doesn't belong to a socket"), - }; - if !socket.waiting_on_delete { - socket.waiting_on_delete = true; - - if matches!(socket.status, SocketStatus::Polling { .. }) { - // Cancel the ongoing poll. - self.cancel(socket)?; - } - } - - // Either drop it now or wait for it to be dropped later. - Ok(()) - } - - fn cancel(self: &Pin>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> { - assert!(matches!(socket.status, SocketStatus::Polling { .. })); - - // Send the cancel request. - unsafe { - socket.afd.cancel(self)?; - } - - // Move state to cancelled. - socket.status = SocketStatus::Cancelled; - - Ok(()) - } -} - -/// Per-socket state. -#[derive(Debug)] -struct SocketState { - /// The raw socket handle. - socket: RawSocket, - - /// The base socket handle. - base_socket: RawSocket, - - /// The event that this socket is currently waiting on. - interest: Event, - - /// Whether to listen for error events. - interest_error: bool, - - /// The current poll mode. - mode: PollMode, - - /// The AFD instance that this socket is registered with. - afd: Arc>, - - /// Whether this socket is waiting to be deleted. - waiting_on_delete: bool, - - /// The current status of the socket. - status: SocketStatus, -} - -/// The mode that a socket can be in. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum SocketStatus { - /// We are currently not polling. - Idle, - - /// We are currently polling these events. - Polling { - /// The flags we are currently polling for. - flags: AfdPollMask, - }, - - /// The last poll operation was cancelled, and we're waiting for it to - /// complete. - Cancelled, -} - -/// Per-waitable handle state. -#[derive(Debug)] -struct WaitableState { - /// The handle that this state is for. - handle: RawHandle, - - /// The IO completion port that this handle is registered with. - port: Weak>, - - /// The event that this handle will report. - interest: Event, - - /// The current poll mode. - mode: PollMode, - - /// The status of this waitable. - status: WaitableStatus, -} - -#[derive(Debug)] -enum WaitableStatus { - /// We are not polling. - Idle, - - /// We are waiting on this handle to become signaled. - Waiting(#[allow(dead_code)] WaitHandle), - - /// This handle has been cancelled. - Cancelled, -} - -impl WaitableStatus { - fn is_idle(&self) -> bool { - matches!(self, WaitableStatus::Idle) - } -} - -/// The result of calling `feed_event`. -#[derive(Debug)] -enum FeedEventResult { - /// No event was yielded. - NoEvent, - - /// An event was yielded. - Event(Event), - - /// The poller has been notified. - Notified, -} - -/// A handle for an ongoing wait operation. -#[derive(Debug)] -struct WaitHandle(RawHandle); - -impl Drop for WaitHandle { - fn drop(&mut self) { - unsafe { - UnregisterWait(self.0 as _); - } - } -} - -impl WaitHandle { - /// Wait for a waitable handle to become signaled. - fn new( - handle: RawHandle, - callback: F, - timeout: Option, - long_wait: bool, - ) -> io::Result - where - F: FnOnce() + Send + Sync + 'static, - { - // Make sure a panic in the callback doesn't propagate to the OS. - struct AbortOnDrop; - - impl Drop for AbortOnDrop { - fn drop(&mut self) { - std::process::abort(); - } - } - - unsafe extern "system" fn wait_callback( - context: *mut c_void, - _timer_fired: BOOLEAN, - ) { - let _guard = AbortOnDrop; - let callback = Box::from_raw(context as *mut F); - callback(); - - // We executed without panicking, so don't abort. - forget(_guard); - } - - let mut wait_handle = MaybeUninit::::uninit(); - - let mut flags = WT_EXECUTEONLYONCE; - if long_wait { - flags |= WT_EXECUTELONGFUNCTION; - } - - let res = unsafe { - RegisterWaitForSingleObject( - wait_handle.as_mut_ptr().cast::<_>(), - handle as _, - Some(wait_callback::), - Box::into_raw(Box::new(callback)) as _, - timeout.map_or(INFINITE, dur2timeout), - flags, - ) - }; - - if res == 0 { - return Err(io::Error::last_os_error()); - } - - let wait_handle = unsafe { wait_handle.assume_init() }; - Ok(Self(wait_handle)) - } -} - -/// Translate an event to the mask expected by AFD. -#[inline] -fn event_to_afd_mask(event: Event, error: bool) -> afd::AfdPollMask { - event_properties_to_afd_mask(event.readable, event.writable, error) | event.extra.flags -} - -/// Translate an event to the mask expected by AFD. -#[inline] -fn event_properties_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask { - use afd::AfdPollMask as AfdPoll; - - let mut mask = AfdPoll::empty(); - - if error || readable || writable { - mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL; - } - - if readable { - mask |= - AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED; - } - - if writable { - mask |= AfdPoll::SEND; - } - - mask -} - -/// Convert the mask reported by AFD to an event. -#[inline] -fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) { - use afd::AfdPollMask as AfdPoll; - - let mut readable = false; - let mut writable = false; - - if mask.intersects( - AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED, - ) { - readable = true; - } - - if mask.intersects(AfdPoll::SEND) { - writable = true; - } - - if mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL) { - readable = true; - writable = true; - } - - (readable, writable) -} +use windows_sys::Win32::System::Threading::INFINITE; // Implementation taken from https://github.com/rust-lang/rust/blob/db5476571d9b27c862b95c1e64764b0ac8980e23/src/libstd/sys/windows/mod.rs fn dur2timeout(dur: Duration) -> u32 { @@ -1339,59 +34,3 @@ fn dur2timeout(dur: Duration) -> u32 { .and_then(|x| u32::try_from(x).ok()) .unwrap_or(INFINITE) } - -/// An error type that wraps around failing to open AFD. -struct AfdError { - /// String description of what happened. - description: &'static str, - - /// The underlying system error. - system: io::Error, -} - -impl AfdError { - #[inline] - fn new(description: &'static str, system: io::Error) -> Self { - Self { - description, - system, - } - } -} - -impl fmt::Debug for AfdError { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("AfdError") - .field("description", &self.description) - .field("system", &self.system) - .field("note", &"probably caused by old Windows or Wine") - .finish() - } -} - -impl fmt::Display for AfdError { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}: {}\nThis error is usually caused by running on old Windows or Wine", - self.description, &self.system - ) - } -} - -impl std::error::Error for AfdError { - #[inline] - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.system) - } -} - -struct CallOnDrop(F); - -impl Drop for CallOnDrop { - fn drop(&mut self) { - (self.0)(); - } -} diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs new file mode 100644 index 0000000..892c446 --- /dev/null +++ b/src/iocp/psn/mod.rs @@ -0,0 +1,576 @@ +mod wait; + +use std::collections::HashMap; +use std::io; +use std::os::windows::io::{ + AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, FromRawHandle, OwnedHandle, + RawHandle, RawSocket, +}; +use std::ptr::null_mut; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +use wait::WaitCompletionPacket; +use windows_sys::Win32::Foundation::{ERROR_SUCCESS, INVALID_HANDLE_VALUE, WAIT_TIMEOUT}; +use windows_sys::Win32::Networking::WinSock::{ + ProcessSocketNotifications, SOCK_NOTIFY_EVENT_ERR, SOCK_NOTIFY_EVENT_HANGUP, + SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_EVENT_REMOVE, SOCK_NOTIFY_OP_DISABLE, + SOCK_NOTIFY_OP_ENABLE, SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, + SOCK_NOTIFY_REGISTER_EVENT_IN, SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, + SOCK_NOTIFY_REGISTRATION, SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, + SOCK_NOTIFY_TRIGGER_ONESHOT, SOCK_NOTIFY_TRIGGER_PERSISTENT, +}; +use windows_sys::Win32::System::Threading::INFINITE; +use windows_sys::Win32::System::IO::{ + CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED_ENTRY, +}; + +use super::dur2timeout; +use crate::{Event, PollMode, NOTIFY_KEY}; + +/// Interface to kqueue. +#[derive(Debug)] +pub struct Poller { + /// The I/O completion port. + port: Arc, + sources: RwLock>, +} + +#[derive(Debug)] +pub(crate) enum SourceAttr { + Socket { + key: usize, + }, + Waitable { + key: usize, + packet: wait::WaitCompletionPacket, + }, +} + +impl Poller { + /// Creates a new poller. + pub fn new() -> io::Result { + let handle = unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 0) }; + if handle == 0 { + return Err(io::Error::last_os_error()); + } + + tracing::trace!(port = ?handle, "new"); + let port = Arc::new(unsafe { OwnedHandle::from_raw_handle(handle as _) }); + Ok(Poller { + port, + sources: RwLock::default(), + }) + } + + /// Whether this poller supports level-triggered events. + pub fn supports_level(&self) -> bool { + true + } + + /// Whether this poller supports edge-triggered events. + pub fn supports_edge(&self) -> bool { + true + } + + /// Adds a new socket. + /// + /// # Safety + /// + /// The socket must be valid and it must last until it is deleted. + pub unsafe fn add(&self, socket: RawSocket, interest: Event, mode: PollMode) -> io::Result<()> { + let span = tracing::trace_span!( + "add", + handle = ?self.port, + sock = ?socket, + ev = ?interest, + ); + let _enter = span.enter(); + + self.add_source( + socket as _, + SourceAttr::Socket { key: interest.key }, + |_| Ok(()), + )?; + + let info = create_registration(socket, interest, mode, true); + self.update_source(info) + } + + /// Modifies an existing socket. + pub fn modify( + &self, + socket: BorrowedSocket<'_>, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + let span = tracing::trace_span!( + "modify", + handle = ?self.port, + sock = ?socket, + ev = ?interest, + ); + let _enter = span.enter(); + + let socket = socket.as_raw_socket(); + + self.has_socket(socket as _)?; + + let info = create_registration(socket, interest, mode, true); + unsafe { self.update_source(info) } + } + + /// Deletes a socket. + pub fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> { + let span = tracing::trace_span!( + "delete", + handle = ?self.port, + sock = ?socket + ); + let _enter = span.enter(); + + let socket = socket.as_raw_socket(); + + if let SourceAttr::Socket { key } = self.remove_source(socket as _)? { + let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); + unsafe { self.update_source(info) } + } else { + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + + pub(crate) fn add_waitable( + &self, + handle: RawHandle, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + tracing::trace!( + "add_waitable: handle={:?}, waitable={:p}, ev={:?}", + self.port, + handle, + interest + ); + + if !matches!(mode, PollMode::Oneshot) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "only support oneshot events", + )); + } + + let key = interest.key; + + let packet = wait::WaitCompletionPacket::new()?; + self.add_source( + handle as _, + SourceAttr::Waitable { key, packet }, + |source| { + if let SourceAttr::Waitable { key, packet } = source { + packet.associate( + self.port.as_raw_handle(), + handle, + *key, + interest_to_events(&interest) as _, + ) + } else { + unreachable!() + } + }, + ) + } + + pub(crate) fn modify_waitable( + &self, + waitable: RawHandle, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + tracing::trace!( + "modify_waitable: handle={:?}, waitable={:p}, ev={:?}", + self.port, + waitable, + interest + ); + + if !matches!(mode, PollMode::Oneshot) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "only support oneshot events", + )); + } + + self.has_waitable(waitable as _, |key, packet| { + let cancelled = packet.cancel()?; + if !cancelled { + // The packet could not be reused, create a new one. + *packet = WaitCompletionPacket::new()?; + } + packet.associate( + self.port.as_raw_handle(), + waitable, + key, + interest_to_events(&interest) as _, + ) + }) + } + + pub(crate) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> { + tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable); + + if let SourceAttr::Waitable { mut packet, .. } = self.remove_source(waitable as _)? { + packet.cancel()?; + Ok(()) + } else { + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + + /// Add a source to the sources set. + #[inline] + pub(crate) fn add_source( + &self, + handle: usize, + source: SourceAttr, + handler: impl FnOnce(&mut SourceAttr) -> io::Result<()>, + ) -> io::Result<()> { + let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner()); + if sources.contains_key(&handle) { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + let source = sources.entry(handle).or_insert(source); + handler(source) + } + + /// Tell if a socket is currently inside the set. + #[inline] + pub(crate) fn has_socket(&self, handle: usize) -> io::Result { + if let Some(SourceAttr::Socket { key }) = self + .sources + .read() + .unwrap_or_else(|e| e.into_inner()) + .get(&handle) + { + Ok(*key) + } else { + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + + /// Tell if a waitable is currently inside the set. + #[inline] + pub(crate) fn has_waitable( + &self, + handle: usize, + handler: impl FnOnce(usize, &mut WaitCompletionPacket) -> io::Result<()>, + ) -> io::Result<()> { + if let Some(SourceAttr::Waitable { key, packet }) = self + .sources + .write() + .unwrap_or_else(|e| e.into_inner()) + .get_mut(&handle) + { + handler(*key, packet) + } else { + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + + /// Remove a source from the sources set. + #[inline] + pub(crate) fn remove_source(&self, handle: usize) -> io::Result { + self.sources + .write() + .unwrap_or_else(|e| e.into_inner()) + .remove(&handle) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) + } + + unsafe fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> { + let res = unsafe { + ProcessSocketNotifications( + self.port.as_raw_handle() as _, + 1, + &mut reg, + 0, + 0, + null_mut(), + null_mut(), + ) + }; + if res == ERROR_SUCCESS { + if reg.registrationResult == ERROR_SUCCESS { + Ok(()) + } else { + Err(io::Error::from_raw_os_error(reg.registrationResult as _)) + } + } else { + Err(io::Error::from_raw_os_error(res as _)) + } + } + + /// Waits for I/O events with an optional timeout. + pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { + let span = tracing::trace_span!( + "wait", + handle = ?self.port, + ?timeout, + ); + let _enter = span.enter(); + + let timeout = timeout.map_or(INFINITE, dur2timeout); + let spare_entries = events.list.spare_capacity_mut(); + let mut received = 0; + let res = unsafe { + ProcessSocketNotifications( + self.port.as_raw_handle() as _, + 0, + null_mut(), + timeout, + spare_entries.len() as _, + spare_entries.as_mut_ptr().cast(), + &mut received, + ) + }; + + if res == ERROR_SUCCESS { + tracing::trace!( + handle = ?self.port, + received, + "new events", + ); + unsafe { events.list.set_len(events.list.len() + received as usize) }; + Ok(()) + } else if res == WAIT_TIMEOUT { + Ok(()) + } else { + Err(io::Error::from_raw_os_error(res as _)) + } + } + + /// Sends a notification to wake up the current or next `wait()` call. + pub fn notify(&self) -> io::Result<()> { + self.post(CompletionPacket::new(Event::none(NOTIFY_KEY))) + } + + pub fn post(&self, packet: CompletionPacket) -> io::Result<()> { + let span = tracing::trace_span!( + "post", + handle = ?self.port, + key = ?packet.0.key, + ); + let _enter = span.enter(); + + let event = packet.event(); + let res = unsafe { + PostQueuedCompletionStatus( + self.port.as_raw_handle() as _, + interest_to_events(event), + event.key, + null_mut(), + ) + }; + if res == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } +} + +impl AsRawHandle for Poller { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +impl AsHandle for Poller { + fn as_handle(&self) -> BorrowedHandle<'_> { + self.port.as_handle() + } +} + +/// A list of reported I/O events. +pub struct Events { + list: Vec, +} + +unsafe impl Send for Events {} + +impl Events { + /// Creates an empty list. + pub fn with_capacity(cap: usize) -> Events { + Events { + list: Vec::with_capacity(cap), + } + } + + /// Iterates over I/O events. + pub fn iter(&self) -> impl Iterator + '_ { + self.list.iter().filter_map(|ev| { + let key = ev.lpCompletionKey; + // The post CompletionPacket. + if key == NOTIFY_KEY { + return None; + } + let events = ev.dwNumberOfBytesTransferred; + // Just ignore the remove event. + if events == SOCK_NOTIFY_EVENT_REMOVE { + return None; + } + Some(Event { + key: ev.lpCompletionKey, + readable: (events & SOCK_NOTIFY_EVENT_IN) != 0, + writable: (events & SOCK_NOTIFY_EVENT_OUT) != 0, + extra: EventExtra { + hup: (events & SOCK_NOTIFY_EVENT_HANGUP) != 0, + err: (events & SOCK_NOTIFY_EVENT_ERR) != 0, + }, + }) + }) + } + + /// Clears the list. + pub fn clear(&mut self) { + self.list.clear(); + } + + /// Get the capacity of the list. + pub fn capacity(&self) -> usize { + self.list.capacity() + } +} + +/// Extra information associated with an event. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct EventExtra { + hup: bool, + err: bool, +} + +impl EventExtra { + /// Create a new, empty version of this struct. + #[inline] + pub const fn empty() -> EventExtra { + EventExtra { + hup: false, + err: false, + } + } + + /// Set the interrupt flag. + #[inline] + pub fn set_hup(&mut self, value: bool) { + self.hup = value; + } + + /// Set the priority flag. + #[inline] + pub fn set_pri(&mut self, _value: bool) { + // No-op. + } + + /// Is the interrupt flag set? + #[inline] + pub fn is_hup(&self) -> bool { + self.hup + } + + /// Is the priority flag set? + #[inline] + pub fn is_pri(&self) -> bool { + false + } + + #[inline] + pub fn is_connect_failed(&self) -> Option { + None + } + + #[inline] + pub fn is_err(&self) -> Option { + Some(self.err) + } +} + +/// A packet used to wake up the poller with an event. +#[derive(Debug, Clone)] +pub struct CompletionPacket(Event); + +impl CompletionPacket { + /// Create a new completion packet with a custom event. + pub fn new(event: Event) -> Self { + Self(event) + } + + /// Get the event associated with this packet. + pub fn event(&self) -> &Event { + &self.0 + } +} + +pub(crate) fn interest_to_filter(interest: &Event) -> u16 { + let mut filter = SOCK_NOTIFY_REGISTER_EVENT_NONE; + if interest.readable { + filter |= SOCK_NOTIFY_REGISTER_EVENT_IN; + } + if interest.writable { + filter |= SOCK_NOTIFY_REGISTER_EVENT_OUT; + } + if interest.extra.hup { + filter |= SOCK_NOTIFY_REGISTER_EVENT_HANGUP; + } + filter as _ +} + +pub(crate) fn interest_to_events(interest: &Event) -> u32 { + let mut events = 0; + if interest.readable { + events |= SOCK_NOTIFY_EVENT_IN; + } + if interest.writable { + events |= SOCK_NOTIFY_EVENT_OUT; + } + if interest.extra.hup { + events |= SOCK_NOTIFY_EVENT_HANGUP; + } + if interest.extra.err { + events |= SOCK_NOTIFY_EVENT_ERR; + } + events +} + +pub(crate) fn mode_to_flags(mode: PollMode) -> u8 { + let flags = match mode { + PollMode::Oneshot => SOCK_NOTIFY_TRIGGER_ONESHOT | SOCK_NOTIFY_TRIGGER_LEVEL, + PollMode::Level => SOCK_NOTIFY_TRIGGER_PERSISTENT | SOCK_NOTIFY_TRIGGER_LEVEL, + PollMode::Edge => SOCK_NOTIFY_TRIGGER_PERSISTENT | SOCK_NOTIFY_TRIGGER_EDGE, + PollMode::EdgeOneshot => SOCK_NOTIFY_TRIGGER_ONESHOT | SOCK_NOTIFY_TRIGGER_EDGE, + }; + flags as u8 +} + +pub(crate) fn create_registration( + socket: RawSocket, + interest: Event, + mode: PollMode, + enable: bool, +) -> SOCK_NOTIFY_REGISTRATION { + let filter = interest_to_filter(&interest); + SOCK_NOTIFY_REGISTRATION { + socket: socket as _, + completionKey: interest.key as _, + eventFilter: filter, + operation: if enable { + if filter == SOCK_NOTIFY_REGISTER_EVENT_NONE as _ { + SOCK_NOTIFY_OP_DISABLE as _ + } else { + SOCK_NOTIFY_OP_ENABLE as _ + } + } else { + SOCK_NOTIFY_OP_REMOVE as _ + }, + triggerFlags: mode_to_flags(mode), + registrationResult: 0, + } +} diff --git a/src/iocp/psn/wait.rs b/src/iocp/psn/wait.rs new file mode 100644 index 0000000..30a3749 --- /dev/null +++ b/src/iocp/psn/wait.rs @@ -0,0 +1,93 @@ +use std::ffi::c_void; +use std::io; +use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}; +use std::ptr::null_mut; + +use windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES; +use windows_sys::Win32::Foundation::{ + RtlNtStatusToDosError, BOOLEAN, GENERIC_READ, GENERIC_WRITE, HANDLE, NTSTATUS, + STATUS_CANCELLED, STATUS_PENDING, STATUS_SUCCESS, +}; + +extern "system" { + fn NtCreateWaitCompletionPacket( + WaitCompletionPacketHandle: *mut HANDLE, + DesiredAccess: u32, + ObjectAttributes: *mut OBJECT_ATTRIBUTES, + ) -> NTSTATUS; + + fn NtAssociateWaitCompletionPacket( + WaitCompletionPacketHandle: HANDLE, + IoCompletionHandle: HANDLE, + TargetObjectHandle: HANDLE, + KeyContext: *mut c_void, + ApcContext: *mut c_void, + IoStatus: NTSTATUS, + IoStatusInformation: usize, + AlreadySignaled: *mut BOOLEAN, + ) -> NTSTATUS; + + fn NtCancelWaitCompletionPacket( + WaitCompletionPacketHandle: HANDLE, + RemoveSignaledPacket: BOOLEAN, + ) -> NTSTATUS; +} + +#[derive(Debug)] +pub struct WaitCompletionPacket { + handle: OwnedHandle, +} + +fn check_status(status: NTSTATUS) -> io::Result<()> { + if status == STATUS_SUCCESS { + Ok(()) + } else { + Err(io::Error::from_raw_os_error(unsafe { + RtlNtStatusToDosError(status) as _ + })) + } +} + +impl WaitCompletionPacket { + pub fn new() -> io::Result { + let mut handle = 0; + check_status(unsafe { + NtCreateWaitCompletionPacket(&mut handle, GENERIC_READ | GENERIC_WRITE, null_mut()) + })?; + let handle = unsafe { OwnedHandle::from_raw_handle(handle as _) }; + Ok(Self { handle }) + } + + pub fn associate( + &mut self, + port: RawHandle, + event: RawHandle, + key: usize, + info: usize, + ) -> io::Result<()> { + check_status(unsafe { + NtAssociateWaitCompletionPacket( + self.handle.as_raw_handle() as _, + port as _, + event as _, + key as _, + null_mut(), + STATUS_SUCCESS, + info, + null_mut(), + ) + })?; + Ok(()) + } + + pub fn cancel(&mut self) -> io::Result { + let status = unsafe { NtCancelWaitCompletionPacket(self.handle.as_raw_handle() as _, 0) }; + match status { + STATUS_SUCCESS | STATUS_CANCELLED => Ok(true), + STATUS_PENDING => Ok(false), + _ => Err(io::Error::from_raw_os_error(unsafe { + RtlNtStatusToDosError(status) as _ + })), + } + } +} diff --git a/src/iocp/afd.rs b/src/iocp/wepoll/afd.rs similarity index 100% rename from src/iocp/afd.rs rename to src/iocp/wepoll/afd.rs diff --git a/src/iocp/wepoll/mod.rs b/src/iocp/wepoll/mod.rs new file mode 100644 index 0000000..38550f1 --- /dev/null +++ b/src/iocp/wepoll/mod.rs @@ -0,0 +1,1375 @@ +//! Bindings to Windows I/O Completion Ports. +//! +//! I/O Completion Ports is a completion-based API rather than a polling-based API, like +//! epoll or kqueue. Therefore, we have to adapt the IOCP API to the crate's API. +//! +//! WinSock is powered by the Auxillary Function Driver (AFD) subsystem, which can be +//! accessed directly by using unstable `ntdll` functions. AFD exposes features that are not +//! available through the normal WinSock interface, such as IOCTL_AFD_POLL. This function is +//! similar to the exposed `WSAPoll` method. However, once the targeted socket is "ready", +//! a completion packet is queued to an I/O completion port. +//! +//! We take advantage of IOCTL_AFD_POLL to "translate" this crate's polling-based API +//! to the one Windows expects. When a device is added to the `Poller`, an IOCTL_AFD_POLL +//! operation is started and queued to the IOCP. To modify a currently registered device +//! (e.g. with `modify()` or `delete()`), the ongoing POLL is cancelled and then restarted +//! with new parameters. Whn the POLL eventually completes, the packet is posted to the IOCP. +//! From here it's a simple matter of using `GetQueuedCompletionStatusEx` to read the packets +//! from the IOCP and react accordingly. Notifying the poller is trivial, because we can +//! simply post a packet to the IOCP to wake it up. +//! +//! The main disadvantage of this strategy is that it relies on unstable Windows APIs. +//! However, as `libuv` (the backing I/O library for Node.JS) relies on the same unstable +//! AFD strategy, it is unlikely to be broken without plenty of advanced warning. +//! +//! Previously, this crate used the `wepoll` library for polling. `wepoll` uses a similar +//! AFD-based strategy for polling. + +mod afd; +mod port; + +use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock}; +use port::{IoCompletionPort, OverlappedEntry}; + +use windows_sys::Win32::Foundation::{ + BOOLEAN, ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED, +}; +use windows_sys::Win32::System::Threading::{ + RegisterWaitForSingleObject, UnregisterWait, INFINITE, WT_EXECUTELONGFUNCTION, + WT_EXECUTEONLYONCE, +}; + +use super::dur2timeout; +use crate::{Event, PollMode}; + +use concurrent_queue::ConcurrentQueue; +use pin_project_lite::pin_project; + +use std::cell::UnsafeCell; +use std::collections::hash_map::{Entry, HashMap}; +use std::ffi::c_void; +use std::fmt; +use std::io; +use std::marker::PhantomPinned; +use std::mem::{forget, MaybeUninit}; +use std::os::windows::io::{ + AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket, +}; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; +use std::time::{Duration, Instant}; + +/// Macro to lock and ignore lock poisoning. +macro_rules! lock { + ($lock_result:expr) => {{ + $lock_result.unwrap_or_else(|e| e.into_inner()) + }}; +} + +/// Interface to I/O completion ports. +#[derive(Debug)] +pub(crate) struct Poller { + /// The I/O completion port. + port: Arc>, + + /// List of currently active AFD instances. + /// + /// AFD acts as the actual source of the socket events. It's essentially running `WSAPoll` on + /// the sockets and then posting the events to the IOCP. + /// + /// AFD instances can be keyed to an unlimited number of sockets. However, each AFD instance + /// polls their sockets linearly. Therefore, it is best to limit the number of sockets each AFD + /// instance is responsible for. The limit of 32 is chosen because that's what `wepoll` uses. + /// + /// Weak references are kept here so that the AFD handle is automatically dropped when the last + /// associated socket is dropped. + afd: Mutex>>>, + + /// The state of the sources registered with this poller. + /// + /// Each source is keyed by its raw socket ID. + sources: RwLock>, + + /// The state of the waitable handles registered with this poller. + waitables: RwLock>, + + /// Sockets with pending updates. + /// + /// This list contains packets with sockets that need to have their AFD state adjusted by + /// calling the `update()` function on them. It's best to queue up packets as they need to + /// be updated and then run all of the updates before we start waiting on the IOCP, rather than + /// updating them as we come. If we're waiting on the IOCP updates should be run immediately. + pending_updates: ConcurrentQueue, + + /// Are we currently polling? + /// + /// This indicates whether or not we are blocking on the IOCP, and is used to determine + /// whether pending updates should be run immediately or queued. + polling: AtomicBool, + + /// The packet used to notify the poller. + /// + /// This is a special-case packet that is used to wake up the poller when it is waiting. + notifier: Packet, +} + +unsafe impl Send for Poller {} +unsafe impl Sync for Poller {} + +impl Poller { + /// Creates a new poller. + pub(crate) fn new() -> io::Result { + // Make sure AFD is able to be used. + if let Err(e) = afd::NtdllImports::force_load() { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + AfdError::new("failed to initialize unstable Windows functions", e), + )); + } + + // Create and destroy a single AFD to test if we support it. + Afd::::new().map_err(|e| { + io::Error::new( + io::ErrorKind::Unsupported, + AfdError::new("failed to initialize \\Device\\Afd", e), + ) + })?; + + let port = IoCompletionPort::new(0)?; + tracing::trace!(handle = ?port, "new"); + + Ok(Poller { + port: Arc::new(port), + afd: Mutex::new(vec![]), + sources: RwLock::new(HashMap::new()), + waitables: RwLock::new(HashMap::new()), + pending_updates: ConcurrentQueue::bounded(1024), + polling: AtomicBool::new(false), + notifier: Arc::pin( + PacketInner::Wakeup { + _pinned: PhantomPinned, + } + .into(), + ), + }) + } + + /// Whether this poller supports level-triggered events. + pub(crate) fn supports_level(&self) -> bool { + true + } + + /// Whether this poller supports edge-triggered events. + pub(crate) fn supports_edge(&self) -> bool { + false + } + + /// Add a new source to the poller. + /// + /// # Safety + /// + /// The socket must be a valid socket and must last until it is deleted. + pub(crate) unsafe fn add( + &self, + socket: RawSocket, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + let span = tracing::trace_span!( + "add", + handle = ?self.port, + sock = ?socket, + ev = ?interest, + ); + let _enter = span.enter(); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Create a new packet. + let socket_state = { + // Create a new socket state and assign an AFD handle to it. + let state = SocketState { + socket, + base_socket: base_socket(socket)?, + interest, + interest_error: true, + afd: self.afd_handle()?, + mode, + waiting_on_delete: false, + status: SocketStatus::Idle, + }; + + // We wrap this socket state in a Packet so the IOCP can use it. + Arc::pin(IoStatusBlock::from(PacketInner::Socket { + packet: UnsafeCell::new(AfdPollInfo::default()), + socket: Mutex::new(state), + })) + }; + + // Keep track of the source in the poller. + { + let mut sources = lock!(self.sources.write()); + + match sources.entry(socket) { + Entry::Vacant(v) => { + v.insert(Pin::>::clone(&socket_state)); + } + + Entry::Occupied(_) => { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + } + } + + // Update the packet. + self.update_packet(socket_state) + } + + /// Update a source in the poller. + pub(crate) fn modify( + &self, + socket: BorrowedSocket<'_>, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + let span = tracing::trace_span!( + "modify", + handle = ?self.port, + sock = ?socket, + ev = ?interest, + ); + let _enter = span.enter(); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Get a reference to the source. + let source = { + let sources = lock!(self.sources.read()); + + sources + .get(&socket.as_raw_socket()) + .cloned() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? + }; + + // Set the new event. + if source.as_ref().set_events(interest, mode) { + // The packet needs to be updated. + self.update_packet(source)?; + } + + Ok(()) + } + + /// Delete a source from the poller. + pub(crate) fn delete(&self, socket: BorrowedSocket<'_>) -> io::Result<()> { + let span = tracing::trace_span!( + "remove", + handle = ?self.port, + sock = ?socket, + ); + let _enter = span.enter(); + + // Remove the source from our associative map. + let source = { + let mut sources = lock!(self.sources.write()); + + match sources.remove(&socket.as_raw_socket()) { + Some(s) => s, + None => { + // If the source has already been removed, then we can just return. + return Ok(()); + } + } + }; + + // Indicate to the source that it is being deleted. + // This cancels any ongoing AFD_IOCTL_POLL operations. + source.begin_delete() + } + + /// Add a new waitable to the poller. + pub(crate) fn add_waitable( + &self, + handle: RawHandle, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + tracing::trace!( + "add_waitable: handle={:?}, waitable={:p}, ev={:?}", + self.port, + handle, + interest + ); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Create a new packet. + let handle_state = { + let state = WaitableState { + handle, + port: Arc::downgrade(&self.port), + interest, + mode, + status: WaitableStatus::Idle, + }; + + Arc::pin(IoStatusBlock::from(PacketInner::Waitable { + handle: Mutex::new(state), + })) + }; + + // Keep track of the source in the poller. + { + let mut sources = lock!(self.waitables.write()); + + match sources.entry(handle) { + Entry::Vacant(v) => { + v.insert(Pin::>::clone(&handle_state)); + } + + Entry::Occupied(_) => { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + } + } + + // Update the packet. + self.update_packet(handle_state) + } + + /// Update a waitable in the poller. + pub(crate) fn modify_waitable( + &self, + waitable: RawHandle, + interest: Event, + mode: PollMode, + ) -> io::Result<()> { + tracing::trace!( + "modify_waitable: handle={:?}, waitable={:p}, ev={:?}", + self.port, + waitable, + interest + ); + + // We don't support edge-triggered events. + if matches!(mode, PollMode::Edge | PollMode::EdgeOneshot) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "edge-triggered events are not supported", + )); + } + + // Get a reference to the source. + let source = { + let sources = lock!(self.waitables.read()); + + sources + .get(&waitable) + .cloned() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? + }; + + // Set the new event. + if source.as_ref().set_events(interest, mode) { + self.update_packet(source)?; + } + + Ok(()) + } + + /// Delete a waitable from the poller. + pub(crate) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> { + tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable); + + // Get a reference to the source. + let source = { + let mut sources = lock!(self.waitables.write()); + + match sources.remove(&waitable) { + Some(s) => s, + None => { + // If the source has already been removed, then we can just return. + return Ok(()); + } + } + }; + + // Indicate to the source that it is being deleted. + // This cancels any ongoing AFD_IOCTL_POLL operations. + source.begin_delete() + } + + /// Wait for events. + pub(crate) fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { + let span = tracing::trace_span!( + "wait", + handle = ?self.port, + ?timeout, + ); + let _enter = span.enter(); + + // Make sure we have a consistent timeout. + let deadline = timeout.and_then(|timeout| Instant::now().checked_add(timeout)); + let mut notified = false; + events.packets.clear(); + + loop { + let mut new_events = 0; + + // Indicate that we are now polling. + let was_polling = self.polling.swap(true, Ordering::SeqCst); + debug_assert!(!was_polling); + + // Even if we panic, we want to make sure we indicate that polling has stopped. + let guard = CallOnDrop(|| { + let was_polling = self.polling.swap(false, Ordering::SeqCst); + debug_assert!(was_polling); + }); + + // Process every entry in the queue before we start polling. + self.drain_update_queue(false)?; + + // Get the time to wait for. + let timeout = deadline.map(|t| t.saturating_duration_since(Instant::now())); + + // Wait for I/O events. + let len = self.port.wait(&mut events.completions, timeout)?; + tracing::trace!( + handle = ?self.port, + res = ?len, + "new events"); + + // We are no longer polling. + drop(guard); + + // Process all of the events. + for entry in events.completions.drain(..) { + let packet = entry.into_packet(); + + // Feed the event into the packet. + match packet.feed_event(self)? { + FeedEventResult::NoEvent => {} + FeedEventResult::Event(event) => { + events.packets.push(event); + new_events += 1; + } + FeedEventResult::Notified => { + notified = true; + } + } + } + + // Break if there was a notification or at least one event, or if deadline is reached. + let timeout_is_empty = + timeout.map_or(false, |t| t.as_secs() == 0 && t.subsec_nanos() == 0); + if notified || new_events > 0 || timeout_is_empty { + break; + } + + tracing::trace!("wait: no events found, re-entering polling loop"); + } + + Ok(()) + } + + /// Notify this poller. + pub(crate) fn notify(&self) -> io::Result<()> { + // Push the notify packet into the IOCP. + self.port.post(0, 0, self.notifier.clone()) + } + + /// Push an IOCP packet into the queue. + pub(crate) fn post(&self, packet: CompletionPacket) -> io::Result<()> { + self.port.post(0, 0, packet.0) + } + + /// Run an update on a packet. + fn update_packet(&self, mut packet: Packet) -> io::Result<()> { + loop { + // If we are currently polling, we need to update the packet immediately. + if self.polling.load(Ordering::Acquire) { + packet.update()?; + return Ok(()); + } + + // Try to queue the update. + match self.pending_updates.push(packet) { + Ok(()) => return Ok(()), + Err(p) => packet = p.into_inner(), + } + + // If we failed to queue the update, we need to drain the queue first. + self.drain_update_queue(true)?; + + // Loop back and try again. + } + } + + /// Drain the update queue. + fn drain_update_queue(&self, limit: bool) -> io::Result<()> { + // Determine how many packets to process. + let max = if limit { + // Only drain the queue's capacity, since this could in theory run forever. + self.pending_updates.capacity().unwrap() + } else { + // Less of a concern if we're draining the queue prior to a poll operation. + std::usize::MAX + }; + + self.pending_updates + .try_iter() + .take(max) + .try_for_each(|packet| packet.update()) + } + + /// Get a handle to the AFD reference. + /// + /// This finds an AFD handle with less than 32 associated sockets, or creates a new one if + /// one does not exist. + fn afd_handle(&self) -> io::Result>> { + const AFD_MAX_SIZE: usize = 32; + + // Crawl the list and see if there are any existing AFD instances that we can use. + // While we're here, remove any unused AFD pointers. + let mut afd_handles = lock!(self.afd.lock()); + let mut i = 0; + while i < afd_handles.len() { + // Get the reference count of the AFD instance. + let refcount = Weak::strong_count(&afd_handles[i]); + + match refcount { + 0 => { + // Prune the AFD pointer if it has no references. + afd_handles.swap_remove(i); + } + + refcount if refcount >= AFD_MAX_SIZE => { + // Skip this one, since it is already at the maximum size. + i += 1; + } + + _ => { + // We can use this AFD instance. + match afd_handles[i].upgrade() { + Some(afd) => return Ok(afd), + None => { + // The last socket dropped the AFD before we could acquire it. + // Prune the AFD pointer and continue. + afd_handles.swap_remove(i); + } + } + } + } + } + + // No available handles, create a new AFD instance. + let afd = Arc::new(Afd::new()?); + + // Register the AFD instance with the I/O completion port. + self.port.register(&*afd, true)?; + + // Insert a weak pointer to the AFD instance into the list for other sockets. + afd_handles.push(Arc::downgrade(&afd)); + + Ok(afd) + } +} + +impl AsRawHandle for Poller { + fn as_raw_handle(&self) -> RawHandle { + self.port.as_raw_handle() + } +} + +impl AsHandle for Poller { + fn as_handle(&self) -> BorrowedHandle<'_> { + unsafe { BorrowedHandle::borrow_raw(self.as_raw_handle()) } + } +} + +/// The container for events. +pub(crate) struct Events { + /// List of IOCP packets. + packets: Vec, + + /// Buffer for completion packets. + completions: Vec>, +} + +unsafe impl Send for Events {} + +impl Events { + /// Creates an empty list of events. + pub fn with_capacity(cap: usize) -> Events { + Events { + packets: Vec::with_capacity(cap), + completions: Vec::with_capacity(cap), + } + } + + /// Iterate over I/O events. + pub fn iter(&self) -> impl Iterator + '_ { + self.packets.iter().copied() + } + + /// Clear the list of events. + pub fn clear(&mut self) { + self.packets.clear(); + } + + /// The capacity of the list of events. + pub fn capacity(&self) -> usize { + self.packets.capacity() + } +} + +/// Extra information about an event. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct EventExtra { + /// Flags associated with this event. + flags: AfdPollMask, +} + +impl EventExtra { + /// Create a new, empty version of this struct. + #[inline] + pub const fn empty() -> EventExtra { + EventExtra { + flags: AfdPollMask::empty(), + } + } + + /// Is this a HUP event? + #[inline] + pub fn is_hup(&self) -> bool { + self.flags.intersects(AfdPollMask::ABORT) + } + + /// Is this a PRI event? + #[inline] + pub fn is_pri(&self) -> bool { + self.flags.intersects(AfdPollMask::RECEIVE_EXPEDITED) + } + + /// Set up a listener for HUP events. + #[inline] + pub fn set_hup(&mut self, active: bool) { + self.flags.set(AfdPollMask::ABORT, active); + } + + /// Set up a listener for PRI events. + #[inline] + pub fn set_pri(&mut self, active: bool) { + self.flags.set(AfdPollMask::RECEIVE_EXPEDITED, active); + } + + /// Check if TCP connect failed. Deprecated. + #[inline] + pub fn is_connect_failed(&self) -> Option { + Some(self.flags.intersects(AfdPollMask::CONNECT_FAIL)) + } + + /// Check if TCP connect failed. + #[inline] + pub fn is_err(&self) -> Option { + Some(self.flags.intersects(AfdPollMask::CONNECT_FAIL)) + } +} + +/// A packet used to wake up the poller with an event. +#[derive(Debug, Clone)] +pub struct CompletionPacket(Packet); + +impl CompletionPacket { + /// Create a new completion packet with a custom event. + pub fn new(event: Event) -> Self { + Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event }))) + } + + /// Get the event associated with this packet. + pub fn event(&self) -> &Event { + let data = self.0.as_ref().data().project_ref(); + + match data { + PacketInnerProj::Custom { event } => event, + _ => unreachable!(), + } + } +} + +/// The type of our completion packet. +/// +/// It needs to be pinned, since it contains data that is expected by IOCP not to be moved. +type Packet = Pin>; +type PacketUnwrapped = IoStatusBlock; + +pin_project! { + /// The inner type of the packet. + #[project_ref = PacketInnerProj] + #[project = PacketInnerProjMut] + enum PacketInner { + // A packet for a socket. + Socket { + // The AFD packet state. + #[pin] + packet: UnsafeCell, + + // The socket state. + socket: Mutex + }, + + /// A packet for a waitable handle. + Waitable { + handle: Mutex + }, + + /// A custom event sent by the user. + Custom { + event: Event, + }, + + // A packet used to wake up the poller. + Wakeup { #[pin] _pinned: PhantomPinned }, + } +} + +unsafe impl Send for PacketInner {} +unsafe impl Sync for PacketInner {} + +impl fmt::Debug for PacketInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Wakeup { .. } => f.write_str("Wakeup { .. }"), + Self::Custom { event } => f.debug_struct("Custom").field("event", event).finish(), + Self::Socket { socket, .. } => f + .debug_struct("Socket") + .field("packet", &"..") + .field("socket", socket) + .finish(), + Self::Waitable { handle } => { + f.debug_struct("Waitable").field("handle", handle).finish() + } + } + } +} + +impl HasAfdInfo for PacketInner { + fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell> { + match self.project_ref() { + PacketInnerProj::Socket { packet, .. } => packet, + _ => unreachable!(), + } + } +} + +impl PacketUnwrapped { + /// Set the new events that this socket is waiting on. + /// + /// Returns `true` if we need to be updated. + fn set_events(self: Pin<&Self>, interest: Event, mode: PollMode) -> bool { + match self.data().project_ref() { + PacketInnerProj::Socket { socket, .. } => { + let mut socket = lock!(socket.lock()); + socket.interest = interest; + socket.mode = mode; + socket.interest_error = true; + + // If there was a change, indicate that we need an update. + match socket.status { + SocketStatus::Polling { flags } => { + let our_flags = event_to_afd_mask(socket.interest, socket.interest_error); + our_flags != flags + } + _ => true, + } + } + PacketInnerProj::Waitable { handle } => { + let mut handle = lock!(handle.lock()); + + // Set the new interest. + handle.interest = interest; + handle.mode = mode; + + // Update if there is no ongoing wait. + handle.status.is_idle() + } + _ => true, + } + } + + /// Update the socket and install the new status in AFD. + /// + /// This function does one of the following: + /// + /// - Nothing, if the packet is waiting on being dropped anyways. + /// - Cancels the ongoing poll, if we want to poll for different events than we are currently + /// polling for. + /// - Starts a new AFD_POLL operation, if we are not currently polling. + fn update(self: Pin>) -> io::Result<()> { + let mut socket = match self.as_ref().data().project_ref() { + PacketInnerProj::Socket { socket, .. } => lock!(socket.lock()), + PacketInnerProj::Waitable { handle } => { + let mut handle = lock!(handle.lock()); + + // If there is no interests, or if we have been cancelled, we don't need to update. + if !handle.interest.readable && !handle.interest.writable { + return Ok(()); + } + + // If we are idle, we need to update. + if !handle.status.is_idle() { + return Ok(()); + } + + // Start a new wait. + let packet = self.clone(); + let wait_handle = WaitHandle::new( + handle.handle, + move || { + let mut handle = match packet.as_ref().data().project_ref() { + PacketInnerProj::Waitable { handle } => lock!(handle.lock()), + _ => unreachable!(), + }; + + // Try to get the IOCP. + let iocp = match handle.port.upgrade() { + Some(iocp) => iocp, + None => return, + }; + + // Set us back into the idle state. + handle.status = WaitableStatus::Idle; + + // Push this packet. + drop(handle); + if let Err(e) = iocp.post(0, 0, packet) { + tracing::error!("failed to post completion packet: {}", e); + } + }, + None, + false, + )?; + + // Set the new status. + handle.status = WaitableStatus::Waiting(wait_handle); + + return Ok(()); + } + _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")), + }; + + // If we are waiting on a delete, just return, dropping the packet. + if socket.waiting_on_delete { + return Ok(()); + } + + // Check the current status. + match socket.status { + SocketStatus::Polling { flags } => { + // If we need to poll for events aside from what we are currently polling, we need + // to update the packet. Cancel the ongoing poll. + let our_flags = event_to_afd_mask(socket.interest, socket.interest_error); + if our_flags != flags { + return self.cancel(socket); + } + + // All events that we are currently waiting on are accounted for. + Ok(()) + } + + SocketStatus::Cancelled => { + // The ongoing operation was cancelled, and we're still waiting for it to return. + // For now, wait until the top-level loop calls feed_event(). + Ok(()) + } + + SocketStatus::Idle => { + // Start a new poll. + let mask = event_to_afd_mask(socket.interest, socket.interest_error); + let result = socket.afd.poll(self.clone(), socket.base_socket, mask); + + match result { + Ok(()) => {} + + Err(err) + if err.raw_os_error() == Some(ERROR_IO_PENDING as i32) + || err.kind() == io::ErrorKind::WouldBlock => + { + // The operation is pending. + } + + Err(err) if err.raw_os_error() == Some(ERROR_INVALID_HANDLE as i32) => { + // The socket was closed. We need to delete it. + // This should happen after we drop it here. + } + + Err(err) => return Err(err), + } + + // We are now polling for the current events. + socket.status = SocketStatus::Polling { flags: mask }; + + Ok(()) + } + } + } + + /// This socket state was notified; see if we need to update it. + /// + /// This indicates that this packet was indicated as "ready" by the IOCP and needs to be + /// processed. + fn feed_event(self: Pin>, poller: &Poller) -> io::Result { + let inner = self.as_ref().data().project_ref(); + + let (afd_info, socket) = match inner { + PacketInnerProj::Socket { packet, socket } => (packet, socket), + PacketInnerProj::Custom { event } => { + // This is a custom event. + return Ok(FeedEventResult::Event(*event)); + } + PacketInnerProj::Wakeup { .. } => { + // The poller was notified. + return Ok(FeedEventResult::Notified); + } + PacketInnerProj::Waitable { handle } => { + let mut handle = lock!(handle.lock()); + let event = handle.interest; + + // Clear the events if we are in one-shot mode. + if matches!(handle.mode, PollMode::Oneshot) { + handle.interest = Event::none(handle.interest.key); + } + + // Submit for an update. + drop(handle); + poller.update_packet(self)?; + + return Ok(FeedEventResult::Event(event)); + } + }; + + let mut socket_state = lock!(socket.lock()); + let mut event = Event::none(socket_state.interest.key); + + // Put ourselves into the idle state. + socket_state.status = SocketStatus::Idle; + + // If we are waiting to be deleted, just return and let the drop handler do their thing. + if socket_state.waiting_on_delete { + return Ok(FeedEventResult::NoEvent); + } + + unsafe { + // SAFETY: The packet is not in transit. + let iosb = &mut *self.as_ref().iosb().get(); + + // Check the status. + match iosb.Anonymous.Status { + STATUS_CANCELLED => { + // Poll request was cancelled. + } + + status if status < 0 => { + // There was an error, so we signal both ends. + event.readable = true; + event.writable = true; + } + + _ => { + // Check in on the AFD data. + let afd_data = &*afd_info.get(); + + // There was at least one event. + if afd_data.handle_count() >= 1 { + let events = afd_data.events(); + + // If we closed the socket, remove it from being polled. + if events.intersects(AfdPollMask::LOCAL_CLOSE) { + let source = lock!(poller.sources.write()) + .remove(&socket_state.socket) + .unwrap(); + return source.begin_delete().map(|()| FeedEventResult::NoEvent); + } + + // Report socket-related events. + let (readable, writable) = afd_mask_to_event(events); + event.readable = readable; + event.writable = writable; + event.extra.flags = events; + } + } + } + } + + // Filter out events that the user didn't ask for. + event.readable &= socket_state.interest.readable; + event.writable &= socket_state.interest.writable; + + // If this event doesn't have anything that interests us, don't return or + // update the oneshot state. + let return_value = if event.readable + || event.writable + || event + .extra + .flags + .intersects(socket_state.interest.extra.flags) + { + // If we are in oneshot mode, remove the interest. + if matches!(socket_state.mode, PollMode::Oneshot) { + socket_state.interest = Event::none(socket_state.interest.key); + socket_state.interest_error = false; + } + + FeedEventResult::Event(event) + } else { + FeedEventResult::NoEvent + }; + + // Put ourselves in the update queue. + drop(socket_state); + poller.update_packet(self)?; + + // Return the event. + Ok(return_value) + } + + /// Begin deleting this socket. + fn begin_delete(self: Pin>) -> io::Result<()> { + // If we aren't already being deleted, start deleting. + let mut socket = match self.as_ref().data().project_ref() { + PacketInnerProj::Socket { socket, .. } => lock!(socket.lock()), + PacketInnerProj::Waitable { handle } => { + let mut handle = lock!(handle.lock()); + + // Set the status to be cancelled. This drops the wait handle and prevents + // any further updates. + handle.status = WaitableStatus::Cancelled; + + return Ok(()); + } + _ => panic!("can't delete packet that doesn't belong to a socket"), + }; + if !socket.waiting_on_delete { + socket.waiting_on_delete = true; + + if matches!(socket.status, SocketStatus::Polling { .. }) { + // Cancel the ongoing poll. + self.cancel(socket)?; + } + } + + // Either drop it now or wait for it to be dropped later. + Ok(()) + } + + fn cancel(self: &Pin>, mut socket: MutexGuard<'_, SocketState>) -> io::Result<()> { + assert!(matches!(socket.status, SocketStatus::Polling { .. })); + + // Send the cancel request. + unsafe { + socket.afd.cancel(self)?; + } + + // Move state to cancelled. + socket.status = SocketStatus::Cancelled; + + Ok(()) + } +} + +/// Per-socket state. +#[derive(Debug)] +struct SocketState { + /// The raw socket handle. + socket: RawSocket, + + /// The base socket handle. + base_socket: RawSocket, + + /// The event that this socket is currently waiting on. + interest: Event, + + /// Whether to listen for error events. + interest_error: bool, + + /// The current poll mode. + mode: PollMode, + + /// The AFD instance that this socket is registered with. + afd: Arc>, + + /// Whether this socket is waiting to be deleted. + waiting_on_delete: bool, + + /// The current status of the socket. + status: SocketStatus, +} + +/// The mode that a socket can be in. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum SocketStatus { + /// We are currently not polling. + Idle, + + /// We are currently polling these events. + Polling { + /// The flags we are currently polling for. + flags: AfdPollMask, + }, + + /// The last poll operation was cancelled, and we're waiting for it to + /// complete. + Cancelled, +} + +/// Per-waitable handle state. +#[derive(Debug)] +struct WaitableState { + /// The handle that this state is for. + handle: RawHandle, + + /// The IO completion port that this handle is registered with. + port: Weak>, + + /// The event that this handle will report. + interest: Event, + + /// The current poll mode. + mode: PollMode, + + /// The status of this waitable. + status: WaitableStatus, +} + +#[derive(Debug)] +enum WaitableStatus { + /// We are not polling. + Idle, + + /// We are waiting on this handle to become signaled. + Waiting(#[allow(dead_code)] WaitHandle), + + /// This handle has been cancelled. + Cancelled, +} + +impl WaitableStatus { + fn is_idle(&self) -> bool { + matches!(self, WaitableStatus::Idle) + } +} + +/// The result of calling `feed_event`. +#[derive(Debug)] +enum FeedEventResult { + /// No event was yielded. + NoEvent, + + /// An event was yielded. + Event(Event), + + /// The poller has been notified. + Notified, +} + +/// A handle for an ongoing wait operation. +#[derive(Debug)] +struct WaitHandle(RawHandle); + +impl Drop for WaitHandle { + fn drop(&mut self) { + unsafe { + UnregisterWait(self.0 as _); + } + } +} + +impl WaitHandle { + /// Wait for a waitable handle to become signaled. + fn new( + handle: RawHandle, + callback: F, + timeout: Option, + long_wait: bool, + ) -> io::Result + where + F: FnOnce() + Send + Sync + 'static, + { + // Make sure a panic in the callback doesn't propagate to the OS. + struct AbortOnDrop; + + impl Drop for AbortOnDrop { + fn drop(&mut self) { + std::process::abort(); + } + } + + unsafe extern "system" fn wait_callback( + context: *mut c_void, + _timer_fired: BOOLEAN, + ) { + let _guard = AbortOnDrop; + let callback = Box::from_raw(context as *mut F); + callback(); + + // We executed without panicking, so don't abort. + forget(_guard); + } + + let mut wait_handle = MaybeUninit::::uninit(); + + let mut flags = WT_EXECUTEONLYONCE; + if long_wait { + flags |= WT_EXECUTELONGFUNCTION; + } + + let res = unsafe { + RegisterWaitForSingleObject( + wait_handle.as_mut_ptr().cast::<_>(), + handle as _, + Some(wait_callback::), + Box::into_raw(Box::new(callback)) as _, + timeout.map_or(INFINITE, dur2timeout), + flags, + ) + }; + + if res == 0 { + return Err(io::Error::last_os_error()); + } + + let wait_handle = unsafe { wait_handle.assume_init() }; + Ok(Self(wait_handle)) + } +} + +/// Translate an event to the mask expected by AFD. +#[inline] +fn event_to_afd_mask(event: Event, error: bool) -> afd::AfdPollMask { + event_properties_to_afd_mask(event.readable, event.writable, error) | event.extra.flags +} + +/// Translate an event to the mask expected by AFD. +#[inline] +fn event_properties_to_afd_mask(readable: bool, writable: bool, error: bool) -> afd::AfdPollMask { + use afd::AfdPollMask as AfdPoll; + + let mut mask = AfdPoll::empty(); + + if error || readable || writable { + mask |= AfdPoll::ABORT | AfdPoll::CONNECT_FAIL; + } + + if readable { + mask |= + AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED; + } + + if writable { + mask |= AfdPoll::SEND; + } + + mask +} + +/// Convert the mask reported by AFD to an event. +#[inline] +fn afd_mask_to_event(mask: afd::AfdPollMask) -> (bool, bool) { + use afd::AfdPollMask as AfdPoll; + + let mut readable = false; + let mut writable = false; + + if mask.intersects( + AfdPoll::RECEIVE | AfdPoll::ACCEPT | AfdPoll::DISCONNECT | AfdPoll::RECEIVE_EXPEDITED, + ) { + readable = true; + } + + if mask.intersects(AfdPoll::SEND) { + writable = true; + } + + if mask.intersects(AfdPoll::ABORT | AfdPoll::CONNECT_FAIL) { + readable = true; + writable = true; + } + + (readable, writable) +} + +/// An error type that wraps around failing to open AFD. +struct AfdError { + /// String description of what happened. + description: &'static str, + + /// The underlying system error. + system: io::Error, +} + +impl AfdError { + #[inline] + fn new(description: &'static str, system: io::Error) -> Self { + Self { + description, + system, + } + } +} + +impl fmt::Debug for AfdError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AfdError") + .field("description", &self.description) + .field("system", &self.system) + .field("note", &"probably caused by old Windows or Wine") + .finish() + } +} + +impl fmt::Display for AfdError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}: {}\nThis error is usually caused by running on old Windows or Wine", + self.description, &self.system + ) + } +} + +impl std::error::Error for AfdError { + #[inline] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.system) + } +} + +struct CallOnDrop(F); + +impl Drop for CallOnDrop { + fn drop(&mut self) { + (self.0)(); + } +} diff --git a/src/iocp/port.rs b/src/iocp/wepoll/port.rs similarity index 100% rename from src/iocp/port.rs rename to src/iocp/wepoll/port.rs diff --git a/tests/multiple_pollers.rs b/tests/multiple_pollers.rs index 18f0efd..1f97212 100644 --- a/tests/multiple_pollers.rs +++ b/tests/multiple_pollers.rs @@ -1,5 +1,8 @@ //! Test registering one source into multiple pollers. +// On Windows, a socket handle can be registered to only one IOCP at a time. +#![cfg(not(all(windows, feature = "iocp-psn")))] + use polling::{Event, Events, PollMode, Poller}; use std::io::{self, prelude::*}; diff --git a/tests/other_modes.rs b/tests/other_modes.rs index 407e42b..dfb2c1f 100644 --- a/tests/other_modes.rs +++ b/tests/other_modes.rs @@ -120,7 +120,8 @@ fn edge_triggered() { target_os = "freebsd", target_os = "netbsd", target_os = "openbsd", - target_os = "dragonfly" + target_os = "dragonfly", + all(target_os = "windows", feature = "iocp-psn") ), not(polling_test_poll_backend) ))] { @@ -155,6 +156,9 @@ fn edge_triggered() { .unwrap(); assert!(events.is_empty()); + // On Windows, the buffer should be cleared to trigger the edge. + reader.read_exact(&mut [0; 2]).unwrap(); + // If we write more data, a notification should be delivered. writer.write_all(&data).unwrap(); events.clear(); @@ -216,7 +220,8 @@ fn edge_oneshot_triggered() { target_os = "freebsd", target_os = "netbsd", target_os = "openbsd", - target_os = "dragonfly" + target_os = "dragonfly", + all(target_os = "windows", feature = "iocp-psn") ), not(polling_test_poll_backend) ))] { @@ -251,6 +256,9 @@ fn edge_oneshot_triggered() { .unwrap(); assert!(events.is_empty()); + // On Windows, the buffer should be cleared to trigger the edge. + reader.read_exact(&mut [0; 2]).unwrap(); + // If we modify to re-enable the notification, it should be delivered. poller .modify_with_mode( @@ -259,6 +267,11 @@ fn edge_oneshot_triggered() { PollMode::EdgeOneshot, ) .unwrap(); + // On Windows, the notification won't be queued up. + // The condition must change while the registration is enabled. + #[cfg(windows)] + writer.write_all(&data).unwrap(); + events.clear(); poller .wait(&mut events, Some(Duration::from_secs(0))) From 5dacb1af3c28aef847bd0046a6ebc7aa8528a1c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 19:08:27 +0800 Subject: [PATCH 2/7] fix: use map in iter --- src/iocp/psn/mod.rs | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs index 892c446..2f41137 100644 --- a/src/iocp/psn/mod.rs +++ b/src/iocp/psn/mod.rs @@ -14,11 +14,11 @@ use wait::WaitCompletionPacket; use windows_sys::Win32::Foundation::{ERROR_SUCCESS, INVALID_HANDLE_VALUE, WAIT_TIMEOUT}; use windows_sys::Win32::Networking::WinSock::{ ProcessSocketNotifications, SOCK_NOTIFY_EVENT_ERR, SOCK_NOTIFY_EVENT_HANGUP, - SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_EVENT_REMOVE, SOCK_NOTIFY_OP_DISABLE, - SOCK_NOTIFY_OP_ENABLE, SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, - SOCK_NOTIFY_REGISTER_EVENT_IN, SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, - SOCK_NOTIFY_REGISTRATION, SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, - SOCK_NOTIFY_TRIGGER_ONESHOT, SOCK_NOTIFY_TRIGGER_PERSISTENT, + SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_OP_DISABLE, SOCK_NOTIFY_OP_ENABLE, + SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, SOCK_NOTIFY_REGISTER_EVENT_IN, + SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, SOCK_NOTIFY_REGISTRATION, + SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, SOCK_NOTIFY_TRIGGER_ONESHOT, + SOCK_NOTIFY_TRIGGER_PERSISTENT, }; use windows_sys::Win32::System::Threading::INFINITE; use windows_sys::Win32::System::IO::{ @@ -407,18 +407,9 @@ impl Events { /// Iterates over I/O events. pub fn iter(&self) -> impl Iterator + '_ { - self.list.iter().filter_map(|ev| { - let key = ev.lpCompletionKey; - // The post CompletionPacket. - if key == NOTIFY_KEY { - return None; - } + self.list.iter().map(|ev| { let events = ev.dwNumberOfBytesTransferred; - // Just ignore the remove event. - if events == SOCK_NOTIFY_EVENT_REMOVE { - return None; - } - Some(Event { + Event { key: ev.lpCompletionKey, readable: (events & SOCK_NOTIFY_EVENT_IN) != 0, writable: (events & SOCK_NOTIFY_EVENT_OUT) != 0, @@ -426,7 +417,7 @@ impl Events { hup: (events & SOCK_NOTIFY_EVENT_HANGUP) != 0, err: (events & SOCK_NOTIFY_EVENT_ERR) != 0, }, - }) + } }) } From 6104c0d932a8639b0782af195d2cadd83d495246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 19:28:57 +0800 Subject: [PATCH 3/7] doc: add comments for psn backend --- src/iocp/mod.rs | 5 +++++ src/iocp/psn/mod.rs | 31 ++++++++++++++++++++++++++++--- src/iocp/psn/wait.rs | 10 ++++++++++ 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index 4bcfddb..3f995b6 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -1,3 +1,8 @@ +//! Bindings to Windows I/O Completion Ports. +//! +//! There are two implementations. The `wepoll` one is classical and uses AFD subsystem. +//! The `psn` one uses `ProcessSocketNotifications` and need Windows 10 21H1+. + cfg_if::cfg_if! { if #[cfg(feature = "iocp-psn")] { mod psn; diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs index 2f41137..023ec76 100644 --- a/src/iocp/psn/mod.rs +++ b/src/iocp/psn/mod.rs @@ -1,3 +1,19 @@ +//! Bindings to Windows IOCP with `ProcessSocketNotifications` and +//! `NtAssociateWaitCompletionPacket` support. +//! +//! `ProcessSocketNotifications` is a new Windows API after 21H1. It is much like kqueue, +//! and support edge triggers. The implementation is easier to be adapted to the crate's API. +//! However, there are some behaviors different from other platforms: +//! - The `psn` poller distingushes "disabled" state and "removed" state. When the registration +//! disabled, the notifications won't be queued to the poller. +//! - The edge trigger only triggers condition changes after it is enabled. You cannot expect +//! an event coming if you change the condition before registering the notification. +//! - A socket can be registered to only one IOCP at a time. +//! +//! `NtAssociateWaitCompletionPacket` is an undocumented API and it's the back of thread pool +//! APIs like `RegisterWaitForSingleObject`. We use it to avoid starting thread pools. It only +//! supports `Oneshot` mode. + mod wait; use std::collections::HashMap; @@ -33,14 +49,18 @@ use crate::{Event, PollMode, NOTIFY_KEY}; pub struct Poller { /// The I/O completion port. port: Arc, + /// Attribute map. sources: RwLock>, } +/// Attributes of added sources. #[derive(Debug)] pub(crate) enum SourceAttr { - Socket { - key: usize, - }, + /// A socket with key. + Socket { key: usize }, + /// A waitable object with key and [`WaitCompletionPacket`]. + /// + /// [`WaitCompletionPacket`]: wait::WaitCompletionPacket Waitable { key: usize, packet: wait::WaitCompletionPacket, @@ -139,6 +159,7 @@ impl Poller { } } + /// Add a new waitable to the poller. pub(crate) fn add_waitable( &self, handle: RawHandle, @@ -180,6 +201,7 @@ impl Poller { ) } + /// Update a waitable in the poller. pub(crate) fn modify_waitable( &self, waitable: RawHandle, @@ -215,6 +237,7 @@ impl Poller { }) } + /// Delete a waitable from the poller. pub(crate) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> { tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable); @@ -286,6 +309,7 @@ impl Poller { .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) } + /// Add or modify the registration. unsafe fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> { let res = unsafe { ProcessSocketNotifications( @@ -353,6 +377,7 @@ impl Poller { self.post(CompletionPacket::new(Event::none(NOTIFY_KEY))) } + /// Push an IOCP packet into the queue. pub fn post(&self, packet: CompletionPacket) -> io::Result<()> { let span = tracing::trace_span!( "post", diff --git a/src/iocp/psn/wait.rs b/src/iocp/psn/wait.rs index 30a3749..fa98d73 100644 --- a/src/iocp/psn/wait.rs +++ b/src/iocp/psn/wait.rs @@ -1,3 +1,5 @@ +//! Safe wrapper around `NtAssociateWaitCompletionPacket` API series. + use std::ffi::c_void; use std::io; use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}; @@ -9,6 +11,7 @@ use windows_sys::Win32::Foundation::{ STATUS_CANCELLED, STATUS_PENDING, STATUS_SUCCESS, }; +#[link(name = "ntdll")] extern "system" { fn NtCreateWaitCompletionPacket( WaitCompletionPacketHandle: *mut HANDLE, @@ -33,6 +36,7 @@ extern "system" { ) -> NTSTATUS; } +/// Wrapper of NT WaitCompletionPacket. #[derive(Debug)] pub struct WaitCompletionPacket { handle: OwnedHandle, @@ -58,6 +62,8 @@ impl WaitCompletionPacket { Ok(Self { handle }) } + /// Associate waitable object to IOCP. The parameter `info` is the + /// field `dwNumberOfBytesTransferred` in `OVERLAPPED_ENTRY` pub fn associate( &mut self, port: RawHandle, @@ -80,6 +86,10 @@ impl WaitCompletionPacket { Ok(()) } + /// Cancels the completion packet. The return value means: + /// - `Ok(true)`: cancellation is successful. + /// - `Ok(false)`: cancellation failed, the packet is still in use. + /// - `Err(e)`: other errors. pub fn cancel(&mut self) -> io::Result { let status = unsafe { NtCancelWaitCompletionPacket(self.handle.as_raw_handle() as _, 0) }; match status { From 7bfcc53a3e1d8c6f74ee735c7fc1412c8dea75ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 19:46:57 +0800 Subject: [PATCH 4/7] feat: split sources and waitables --- src/iocp/psn/mod.rs | 195 +++++++++++++++++--------------------------- 1 file changed, 74 insertions(+), 121 deletions(-) diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs index 023ec76..f0adbf6 100644 --- a/src/iocp/psn/mod.rs +++ b/src/iocp/psn/mod.rs @@ -23,7 +23,7 @@ use std::os::windows::io::{ RawHandle, RawSocket, }; use std::ptr::null_mut; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::time::Duration; use wait::WaitCompletionPacket; @@ -44,27 +44,35 @@ use windows_sys::Win32::System::IO::{ use super::dur2timeout; use crate::{Event, PollMode, NOTIFY_KEY}; +/// Macro to lock and ignore lock poisoning. +macro_rules! lock { + ($lock_result:expr) => {{ + $lock_result.unwrap_or_else(|e| e.into_inner()) + }}; +} + /// Interface to kqueue. #[derive(Debug)] pub struct Poller { /// The I/O completion port. port: Arc, - /// Attribute map. - sources: RwLock>, + + /// The state of the sources registered with this poller. + /// + /// Each source is keyed by its raw socket ID. + sources: RwLock>, + + /// The state of the waitable handles registered with this poller. + waitables: Mutex>, } -/// Attributes of added sources. +/// A waitable object with key and [`WaitCompletionPacket`]. +/// +/// [`WaitCompletionPacket`]: wait::WaitCompletionPacket #[derive(Debug)] -pub(crate) enum SourceAttr { - /// A socket with key. - Socket { key: usize }, - /// A waitable object with key and [`WaitCompletionPacket`]. - /// - /// [`WaitCompletionPacket`]: wait::WaitCompletionPacket - Waitable { - key: usize, - packet: wait::WaitCompletionPacket, - }, +struct WaitableAttr { + key: usize, + packet: wait::WaitCompletionPacket, } impl Poller { @@ -80,6 +88,7 @@ impl Poller { Ok(Poller { port, sources: RwLock::default(), + waitables: Mutex::default(), }) } @@ -107,11 +116,11 @@ impl Poller { ); let _enter = span.enter(); - self.add_source( - socket as _, - SourceAttr::Socket { key: interest.key }, - |_| Ok(()), - )?; + let mut sources = lock!(self.sources.write()); + if sources.contains_key(&socket) { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + sources.insert(socket, interest.key); let info = create_registration(socket, interest, mode, true); self.update_source(info) @@ -134,7 +143,9 @@ impl Poller { let socket = socket.as_raw_socket(); - self.has_socket(socket as _)?; + lock!(self.sources.read()) + .get(&socket) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; let info = create_registration(socket, interest, mode, true); unsafe { self.update_source(info) } @@ -151,12 +162,11 @@ impl Poller { let socket = socket.as_raw_socket(); - if let SourceAttr::Socket { key } = self.remove_source(socket as _)? { - let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); - unsafe { self.update_source(info) } - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } + let key = lock!(self.sources.write()) + .remove(&socket) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); + unsafe { self.update_source(info) } } /// Add a new waitable to the poller. @@ -182,23 +192,20 @@ impl Poller { let key = interest.key; - let packet = wait::WaitCompletionPacket::new()?; - self.add_source( - handle as _, - SourceAttr::Waitable { key, packet }, - |source| { - if let SourceAttr::Waitable { key, packet } = source { - packet.associate( - self.port.as_raw_handle(), - handle, - *key, - interest_to_events(&interest) as _, - ) - } else { - unreachable!() - } - }, - ) + let mut waitables = lock!(self.waitables.lock()); + if waitables.contains_key(&handle) { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + + let mut packet = wait::WaitCompletionPacket::new()?; + packet.associate( + self.port.as_raw_handle(), + handle, + key, + interest_to_events(&interest) as _, + )?; + waitables.insert(handle, WaitableAttr { key, packet }); + Ok(()) } /// Update a waitable in the poller. @@ -222,91 +229,34 @@ impl Poller { )); } - self.has_waitable(waitable as _, |key, packet| { - let cancelled = packet.cancel()?; - if !cancelled { - // The packet could not be reused, create a new one. - *packet = WaitCompletionPacket::new()?; - } - packet.associate( - self.port.as_raw_handle(), - waitable, - key, - interest_to_events(&interest) as _, - ) - }) + let mut waitables = lock!(self.waitables.lock()); + let WaitableAttr { key, packet } = waitables + .get_mut(&waitable) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + + let cancelled = packet.cancel()?; + if !cancelled { + // The packet could not be reused, create a new one. + *packet = WaitCompletionPacket::new()?; + } + packet.associate( + self.port.as_raw_handle(), + waitable, + *key, + interest_to_events(&interest) as _, + ) } /// Delete a waitable from the poller. pub(crate) fn remove_waitable(&self, waitable: RawHandle) -> io::Result<()> { tracing::trace!("remove: handle={:?}, waitable={:p}", self.port, waitable); - if let SourceAttr::Waitable { mut packet, .. } = self.remove_source(waitable as _)? { - packet.cancel()?; - Ok(()) - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } - } + let WaitableAttr { mut packet, .. } = lock!(self.waitables.lock()) + .remove(&waitable) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; - /// Add a source to the sources set. - #[inline] - pub(crate) fn add_source( - &self, - handle: usize, - source: SourceAttr, - handler: impl FnOnce(&mut SourceAttr) -> io::Result<()>, - ) -> io::Result<()> { - let mut sources = self.sources.write().unwrap_or_else(|e| e.into_inner()); - if sources.contains_key(&handle) { - return Err(io::Error::from(io::ErrorKind::AlreadyExists)); - } - let source = sources.entry(handle).or_insert(source); - handler(source) - } - - /// Tell if a socket is currently inside the set. - #[inline] - pub(crate) fn has_socket(&self, handle: usize) -> io::Result { - if let Some(SourceAttr::Socket { key }) = self - .sources - .read() - .unwrap_or_else(|e| e.into_inner()) - .get(&handle) - { - Ok(*key) - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } - } - - /// Tell if a waitable is currently inside the set. - #[inline] - pub(crate) fn has_waitable( - &self, - handle: usize, - handler: impl FnOnce(usize, &mut WaitCompletionPacket) -> io::Result<()>, - ) -> io::Result<()> { - if let Some(SourceAttr::Waitable { key, packet }) = self - .sources - .write() - .unwrap_or_else(|e| e.into_inner()) - .get_mut(&handle) - { - handler(*key, packet) - } else { - Err(io::Error::from(io::ErrorKind::NotFound)) - } - } - - /// Remove a source from the sources set. - #[inline] - pub(crate) fn remove_source(&self, handle: usize) -> io::Result { - self.sources - .write() - .unwrap_or_else(|e| e.into_inner()) - .remove(&handle) - .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) + packet.cancel()?; + Ok(()) } /// Add or modify the registration. @@ -415,6 +365,9 @@ impl AsHandle for Poller { } } +unsafe impl Send for Poller {} +unsafe impl Sync for Poller {} + /// A list of reported I/O events. pub struct Events { list: Vec, From 34f7cac5ca4fbad28b8a6bfe663ea8d212aac4ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 19:58:18 +0800 Subject: [PATCH 5/7] ci: add iocp-psn test --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a61bec8..879f0d2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,6 +61,8 @@ jobs: # the backend that uses pipes, and is not a public API. RUSTFLAGS: ${{ env.RUSTFLAGS }} --cfg polling_test_epoll_pipe if: startsWith(matrix.os, 'ubuntu') + - run: cargo test --features iocp-psn + if: startsWith(matrix.os, 'windows') - run: cargo hack build --feature-powerset --no-dev-deps # TODO: broken due to https://github.com/rust-lang/rust/pull/119026. # - name: Check selected Tier 3 targets From 16b6fa6bddfc41eaf55a68f732bd22cd4cf40b52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 20:03:17 +0800 Subject: [PATCH 6/7] test: fix edge-oneshot trigger on other platforms --- tests/other_modes.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/other_modes.rs b/tests/other_modes.rs index dfb2c1f..ecb1c48 100644 --- a/tests/other_modes.rs +++ b/tests/other_modes.rs @@ -256,9 +256,6 @@ fn edge_oneshot_triggered() { .unwrap(); assert!(events.is_empty()); - // On Windows, the buffer should be cleared to trigger the edge. - reader.read_exact(&mut [0; 2]).unwrap(); - // If we modify to re-enable the notification, it should be delivered. poller .modify_with_mode( @@ -267,10 +264,15 @@ fn edge_oneshot_triggered() { PollMode::EdgeOneshot, ) .unwrap(); - // On Windows, the notification won't be queued up. - // The condition must change while the registration is enabled. + #[cfg(windows)] - writer.write_all(&data).unwrap(); + { + // On Windows, the buffer should be cleared to trigger the edge. + reader.read_exact(&mut [0; 2]).unwrap(); + // On Windows, the notification won't be queued up. + // The condition must change while the registration is enabled. + writer.write_all(&data).unwrap(); + } events.clear(); poller From e2d17ed82dce1b6b580472c88509f126969121e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=87=E9=80=B8?= Date: Tue, 11 Jun 2024 22:59:58 +0800 Subject: [PATCH 7/7] fix: handle REMOVE event --- src/iocp/psn/mod.rs | 131 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 115 insertions(+), 16 deletions(-) diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs index f0adbf6..0deb382 100644 --- a/src/iocp/psn/mod.rs +++ b/src/iocp/psn/mod.rs @@ -18,6 +18,7 @@ mod wait; use std::collections::HashMap; use std::io; +use std::mem::MaybeUninit; use std::os::windows::io::{ AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, FromRawHandle, OwnedHandle, RawHandle, RawSocket, @@ -30,15 +31,15 @@ use wait::WaitCompletionPacket; use windows_sys::Win32::Foundation::{ERROR_SUCCESS, INVALID_HANDLE_VALUE, WAIT_TIMEOUT}; use windows_sys::Win32::Networking::WinSock::{ ProcessSocketNotifications, SOCK_NOTIFY_EVENT_ERR, SOCK_NOTIFY_EVENT_HANGUP, - SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_OP_DISABLE, SOCK_NOTIFY_OP_ENABLE, - SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, SOCK_NOTIFY_REGISTER_EVENT_IN, - SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, SOCK_NOTIFY_REGISTRATION, - SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, SOCK_NOTIFY_TRIGGER_ONESHOT, - SOCK_NOTIFY_TRIGGER_PERSISTENT, + SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_EVENT_REMOVE, SOCK_NOTIFY_OP_DISABLE, + SOCK_NOTIFY_OP_ENABLE, SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, + SOCK_NOTIFY_REGISTER_EVENT_IN, SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, + SOCK_NOTIFY_REGISTRATION, SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, + SOCK_NOTIFY_TRIGGER_ONESHOT, SOCK_NOTIFY_TRIGGER_PERSISTENT, }; use windows_sys::Win32::System::Threading::INFINITE; use windows_sys::Win32::System::IO::{ - CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED_ENTRY, + CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED, OVERLAPPED_ENTRY, }; use super::dur2timeout; @@ -143,12 +144,18 @@ impl Poller { let socket = socket.as_raw_socket(); - lock!(self.sources.read()) + let sources = lock!(self.sources.read()); + let oldkey = sources .get(&socket) .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + if oldkey != &interest.key { + // To change the key, remove the old registration and wait for REMOVE event. + let info = create_registration(socket, Event::none(*oldkey), PollMode::Oneshot, false); + self.update_and_wait_for_remove(info, *oldkey)?; + } let info = create_registration(socket, interest, mode, true); - unsafe { self.update_source(info) } + self.update_source(info) } /// Deletes a socket. @@ -166,7 +173,7 @@ impl Poller { .remove(&socket) .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); - unsafe { self.update_source(info) } + self.update_and_wait_for_remove(info, key) } /// Add a new waitable to the poller. @@ -260,7 +267,7 @@ impl Poller { } /// Add or modify the registration. - unsafe fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> { + fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> { let res = unsafe { ProcessSocketNotifications( self.port.as_raw_handle() as _, @@ -283,6 +290,94 @@ impl Poller { } } + /// Attempt to remove a registration, and wait for the `SOCK_NOTIFY_EVENT_REMOVE` event. + fn update_and_wait_for_remove( + &self, + mut reg: SOCK_NOTIFY_REGISTRATION, + key: usize, + ) -> io::Result<()> { + debug_assert_eq!(reg.operation, SOCK_NOTIFY_OP_REMOVE as _); + let mut received = 0; + let mut entry: MaybeUninit = MaybeUninit::uninit(); + + let repost = |entry: OVERLAPPED_ENTRY| { + self.post_raw( + entry.dwNumberOfBytesTransferred, + entry.lpCompletionKey, + entry.lpOverlapped, + ) + }; + + // Update the registration and wait for the event in the same time. + // However, the returned completion entry may not be the wanted REMOVE event. + let res = unsafe { + ProcessSocketNotifications( + self.port.as_raw_handle() as _, + 1, + &mut reg, + 0, + 1, + entry.as_mut_ptr().cast(), + &mut received, + ) + }; + match res { + ERROR_SUCCESS | WAIT_TIMEOUT => { + if reg.registrationResult != ERROR_SUCCESS { + // If the registration is not successful, the received entry should be reposted. + if received == 1 { + repost(unsafe { entry.assume_init() })?; + } + return Err(io::Error::from_raw_os_error(reg.registrationResult as _)); + } + } + _ => return Err(io::Error::from_raw_os_error(res as _)), + } + if received == 1 { + // The registration is successful, and check the received entry. + let entry = unsafe { entry.assume_init() }; + if entry.lpCompletionKey == key { + // If the entry is current key but not the remove event, just ignore it. + if (entry.dwNumberOfBytesTransferred & SOCK_NOTIFY_EVENT_REMOVE) != 0 { + return Ok(()); + } + } else { + repost(entry)?; + } + } + + // No wanted event, start a loop to wait for it. + // TODO: any better solutions? + loop { + let res = unsafe { + ProcessSocketNotifications( + self.port.as_raw_handle() as _, + 0, + null_mut(), + 0, + 1, + entry.as_mut_ptr().cast(), + &mut received, + ) + }; + match res { + ERROR_SUCCESS => { + debug_assert_eq!(received, 1); + let entry = unsafe { entry.assume_init() }; + if entry.lpCompletionKey == key { + if (entry.dwNumberOfBytesTransferred & SOCK_NOTIFY_EVENT_REMOVE) != 0 { + return Ok(()); + } + } else { + repost(entry)?; + } + } + WAIT_TIMEOUT => {} + _ => return Err(io::Error::from_raw_os_error(res as _)), + } + } + } + /// Waits for I/O events with an optional timeout. pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { let span = tracing::trace_span!( @@ -337,13 +432,17 @@ impl Poller { let _enter = span.enter(); let event = packet.event(); + self.post_raw(interest_to_events(event), event.key, null_mut()) + } + + fn post_raw( + &self, + transferred: u32, + key: usize, + overlapped: *mut OVERLAPPED, + ) -> io::Result<()> { let res = unsafe { - PostQueuedCompletionStatus( - self.port.as_raw_handle() as _, - interest_to_events(event), - event.key, - null_mut(), - ) + PostQueuedCompletionStatus(self.port.as_raw_handle() as _, transferred, key, overlapped) }; if res == 0 { Err(io::Error::last_os_error())