diff --git a/Cargo.lock b/Cargo.lock index 36665ca5945b..400a373f591d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -280,6 +280,12 @@ version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.3.0" @@ -3253,6 +3259,21 @@ dependencies = [ "sqlparser", ] +[[package]] +name = "polars-stream" +version = "0.40.0" +dependencies = [ + "atomic-waker", + "crossbeam-deque", + "crossbeam-utils", + "parking_lot", + "pin-project-lite", + "polars-utils", + "rand", + "slotmap", + "version_check", +] + [[package]] name = "polars-time" version = "0.40.0" @@ -4131,6 +4152,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "slotmap" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" +dependencies = [ + "version_check", +] + [[package]] name = "smallvec" version = "1.13.2" diff --git a/Cargo.toml b/Cargo.toml index 20f726e7b12e..68a6f195e001 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ arrow-data = { version = ">=41", default-features = false } arrow-schema = { version = ">=41", default-features = false } atoi = "2" atoi_simd = "0.15.5" +atomic-waker = "1" avro-schema = { version = "0.3" } base64 = "0.22.0" bitflags = "2" @@ -39,7 +40,9 @@ chrono = { version = "0.4.31", default-features = false, features = ["std"] } chrono-tz = "0.8.1" ciborium = "0.2" crossbeam-channel = "0.5.8" +crossbeam-deque = "0.8.5" crossbeam-queue = "0.3" +crossbeam-utils = "0.8.20" either = "1.11" ethnum = "1.3.2" fallible-streaming-iterator = "0.1.9" @@ -57,7 +60,9 @@ ndarray = { version = "0.15", default-features = false } num-traits = "0.2" object_store = { version = "0.9", default-features = false } once_cell = "1" +parking_lot = "0.12" percent-encoding = "2.3" +pin-project-lite = "0.2" pyo3 = "0.21" rand = "0.8" rand_distr = "0.4" @@ -71,6 +76,7 @@ serde = { version = "1.0.188", features = ["derive"] } serde_json = "1" simd-json = { version = "0.13", features = ["known-key"] } simdutf8 = "0.1.4" +slotmap = "1" smartstring = "1" sqlparser = "0.45" stacker = "0.1" diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml new file mode 100644 index 000000000000..e4bdff39db8d --- /dev/null +++ b/crates/polars-stream/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "polars-stream" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Private crate for the streaming execution engine for the Polars DataFrame library" + +[dependencies] +atomic-waker = { workspace = true } +crossbeam-deque = { workspace = true } +crossbeam-utils = { workspace = true } +parking_lot = { workspace = true } +pin-project-lite = { workspace = true } +polars-utils = { workspace = true } +rand = { workspace = true } +slotmap = { workspace = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +nightly = [] diff --git a/crates/polars-stream/LICENSE b/crates/polars-stream/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-stream/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-stream/README.md b/crates/polars-stream/README.md new file mode 100644 index 000000000000..c16aedf1901e --- /dev/null +++ b/crates/polars-stream/README.md @@ -0,0 +1,5 @@ +# polars-stream + +`polars-stream` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, containing a streaming execution engine. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-stream/build.rs b/crates/polars-stream/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-stream/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-stream/src/async_primitives/distributor_channel.rs b/crates/polars-stream/src/async_primitives/distributor_channel.rs new file mode 100644 index 000000000000..59fcbe88c6f1 --- /dev/null +++ b/crates/polars-stream/src/async_primitives/distributor_channel.rs @@ -0,0 +1,276 @@ +use std::cell::UnsafeCell; +use std::mem::MaybeUninit; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use crossbeam_utils::CachePadded; +use rand::prelude::*; + +use super::task_parker::TaskParker; + +/// Single-producer multi-consumer FIFO channel. +/// +/// Each [`Receiver`] has an internal buffer of `bufsize`. Thus it is possible +/// that when one [`Sender`] is exhausted some other receivers still have data +/// available. +/// +/// The FIFO order is only guaranteed per receiver. That is, each receiver is +/// guaranteed to see a subset of the data sent by the sender in the order the +/// sender sent it in, but not necessarily contiguously. +pub fn distributor_channel( + num_receivers: usize, + bufsize: usize, +) -> (Sender, Vec>) { + let capacity = bufsize.next_power_of_two(); + let receivers = (0..num_receivers) + .map(|_| { + CachePadded::new(ReceiverSlot { + closed: AtomicBool::new(false), + read_head: AtomicUsize::new(0), + parker: TaskParker::default(), + data: (0..capacity) + .map(|_| UnsafeCell::new(MaybeUninit::uninit())) + .collect(), + }) + }) + .collect(); + let inner = Arc::new(DistributorInner { + send_closed: AtomicBool::new(false), + send_parker: TaskParker::default(), + write_heads: (0..num_receivers).map(|_| AtomicUsize::new(0)).collect(), + receivers, + + bufsize, + mask: capacity - 1, + }); + + let receivers = (0..num_receivers) + .map(|index| Receiver { + inner: inner.clone(), + index, + }) + .collect(); + + let sender = Sender { + inner, + round_robin_idx: 0, + rng: SmallRng::from_rng(&mut rand::thread_rng()).unwrap(), + }; + + (sender, receivers) +} + +pub enum SendError { + Full(T), + Closed(T), +} + +pub enum RecvError { + Empty, + Closed, +} + +struct ReceiverSlot { + closed: AtomicBool, + read_head: AtomicUsize, + parker: TaskParker, + data: Box<[UnsafeCell>]>, +} + +struct DistributorInner { + send_closed: AtomicBool, + send_parker: TaskParker, + write_heads: Vec, + receivers: Vec>>, + + bufsize: usize, + mask: usize, +} + +impl DistributorInner { + fn reduce_index(&self, idx: usize) -> usize { + idx & self.mask + } +} + +pub struct Sender { + inner: Arc>, + round_robin_idx: usize, + rng: SmallRng, +} + +pub struct Receiver { + inner: Arc>, + index: usize, +} + +unsafe impl Send for Sender {} +unsafe impl Send for Receiver {} + +impl Sender { + pub async fn send(&mut self, mut value: T) -> Result<(), T> { + let num_receivers = self.inner.receivers.len(); + loop { + // Fast-path. + self.round_robin_idx += 1; + if self.round_robin_idx >= num_receivers { + self.round_robin_idx -= num_receivers; + } + + let mut hungriest_idx = self.round_robin_idx; + let mut shortest_len = self.upper_bound_len(self.round_robin_idx); + for _ in 0..4 { + let idx = ((self.rng.gen::() as u64 * num_receivers as u64) >> 32) as usize; + let len = self.upper_bound_len(idx); + if len < shortest_len { + shortest_len = len; + hungriest_idx = idx; + } + } + + match self.try_send(hungriest_idx, value) { + Ok(()) => return Ok(()), + Err(SendError::Full(v)) => value = v, + Err(SendError::Closed(v)) => value = v, + } + + // Do one proper search before parking. + let park = self.inner.send_parker.park(); + + // Try all receivers, starting at a random index. + let mut idx = ((self.rng.gen::() as u64 * num_receivers as u64) >> 32) as usize; + let mut all_closed = true; + for _ in 0..num_receivers { + match self.try_send(idx, value) { + Ok(()) => return Ok(()), + Err(SendError::Full(v)) => { + all_closed = false; + value = v; + }, + Err(SendError::Closed(v)) => value = v, + } + + idx += 1; + if idx >= num_receivers { + idx -= num_receivers; + } + } + + if all_closed { + return Err(value); + } + + park.await; + } + } + + fn upper_bound_len(&self, recv_idx: usize) -> usize { + let read_head = self.inner.receivers[recv_idx] + .read_head + .load(Ordering::SeqCst); + let write_head = self.inner.write_heads[recv_idx].load(Ordering::Relaxed); + write_head.wrapping_sub(read_head) + } + + fn try_send(&self, recv_idx: usize, value: T) -> Result<(), SendError> { + let read_head = self.inner.receivers[recv_idx] + .read_head + .load(Ordering::SeqCst); + let write_head = self.inner.write_heads[recv_idx].load(Ordering::Relaxed); + let len = write_head.wrapping_sub(read_head); + if len < self.inner.bufsize { + let idx = self.inner.reduce_index(write_head); + unsafe { + self.inner.receivers[recv_idx].data[idx] + .get() + .write(MaybeUninit::new(value)); + self.inner.write_heads[recv_idx] + .store(write_head.wrapping_add(1), Ordering::SeqCst); + } + self.inner.receivers[recv_idx].parker.unpark(); + Ok(()) + } else { + Err(SendError::Full(value)) + } + } +} + +impl Receiver { + pub async fn recv(&mut self) -> Result { + loop { + // Fast-path. + match self.try_recv() { + Ok(v) => return Ok(v), + Err(RecvError::Closed) => return Err(()), + Err(RecvError::Empty) => {}, + } + + // Try again, threatening to park if there's still nothing. + let park = self.inner.receivers[self.index].parker.park(); + match self.try_recv() { + Ok(v) => return Ok(v), + Err(RecvError::Closed) => return Err(()), + Err(RecvError::Empty) => {}, + } + park.await; + } + } + + fn try_recv(&self) -> Result { + let read_head = self.inner.receivers[self.index] + .read_head + .load(Ordering::Relaxed); + let write_head = self.inner.write_heads[self.index].load(Ordering::SeqCst); + if read_head != write_head { + let idx = self.inner.reduce_index(read_head); + let read; + unsafe { + let ptr = self.inner.receivers[self.index].data[idx].get(); + read = ptr.read().assume_init(); + self.inner.receivers[self.index] + .read_head + .store(read_head.wrapping_add(1), Ordering::SeqCst); + } + self.inner.send_parker.unpark(); + Ok(read) + } else if self.inner.send_closed.load(Ordering::SeqCst) { + Err(RecvError::Closed) + } else { + Err(RecvError::Empty) + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.inner.send_closed.store(true, Ordering::SeqCst); + for recv in &self.inner.receivers { + recv.parker.unpark(); + } + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + self.inner.receivers[self.index] + .closed + .store(true, Ordering::SeqCst); + self.inner.send_parker.unpark(); + } +} + +impl Drop for DistributorInner { + fn drop(&mut self) { + for r in 0..self.receivers.len() { + while self.receivers[r].read_head.load(Ordering::Relaxed) + != self.write_heads[r].load(Ordering::Relaxed) + { + let read_head = self.receivers[r].read_head.fetch_add(1, Ordering::Relaxed); + let idx = self.reduce_index(read_head); + unsafe { + (*self.receivers[r].data[idx].get()).assume_init_drop(); + } + } + } + } +} diff --git a/crates/polars-stream/src/async_primitives/mod.rs b/crates/polars-stream/src/async_primitives/mod.rs new file mode 100644 index 000000000000..a4074b7e77b0 --- /dev/null +++ b/crates/polars-stream/src/async_primitives/mod.rs @@ -0,0 +1,4 @@ +pub mod distributor_channel; +pub mod pipe; +pub mod task_parker; +pub mod wait_group; diff --git a/crates/polars-stream/src/async_primitives/pipe.rs b/crates/polars-stream/src/async_primitives/pipe.rs new file mode 100644 index 000000000000..52310c0c51ad --- /dev/null +++ b/crates/polars-stream/src/async_primitives/pipe.rs @@ -0,0 +1,278 @@ +use std::cell::UnsafeCell; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; + +use atomic_waker::AtomicWaker; +use pin_project_lite::pin_project; + +/// Single-producer, single-consumer capacity-one channel. +pub fn pipe() -> (Sender, Receiver) { + let pipe = Arc::new(Pipe::default()); + (Sender { pipe: pipe.clone() }, Receiver { pipe }) +} + +/* + For UnsafeCell safety, a sender may only set the FULL_BIT (giving exclusive + access to value to the receiver), and a receiver may only unset the FULL_BIT + (giving exclusive access back to the sender). + + The exception is when the closed bit is set, at that point the unclosed + end has full exclusive access. +*/ + +const FULL_BIT: u8 = 0b1; +const CLOSED_BIT: u8 = 0b10; +const WAITING_BIT: u8 = 0b100; + +#[repr(align(64))] +struct Pipe { + send_waker: AtomicWaker, + recv_waker: AtomicWaker, + value: UnsafeCell>, + state: AtomicU8, +} + +impl Default for Pipe { + fn default() -> Self { + Self { + send_waker: AtomicWaker::new(), + recv_waker: AtomicWaker::new(), + value: UnsafeCell::new(MaybeUninit::uninit()), + state: AtomicU8::new(0), + } + } +} + +impl Drop for Pipe { + fn drop(&mut self) { + if self.state.load(Ordering::Acquire) & FULL_BIT == FULL_BIT { + unsafe { + self.value.get().drop_in_place(); + } + } + } +} + +pub enum SendError { + Full(T), + Closed(T), +} + +pub enum RecvError { + Empty, + Closed, +} + +// SAFETY: all the send methods may only be called from a single sender at a +// time, and similarly for all the recv methods from a single receiver. +impl Pipe { + unsafe fn poll_send(&self, value: &mut Option, waker: &Waker) -> Poll> { + if let Some(v) = value.take() { + let mut state = self.state.load(Ordering::Relaxed); + if state & FULL_BIT == FULL_BIT { + self.send_waker.register(waker); + let (Ok(s) | Err(s)) = self.state.compare_exchange( + state, + state | WAITING_BIT, + Ordering::Release, + Ordering::Relaxed, + ); + state = s; + } + + match self.try_send_impl(v, state) { + Ok(()) => {}, + Err(SendError::Closed(v)) => return Poll::Ready(Err(v)), + Err(SendError::Full(v)) => { + *value = Some(v); + return Poll::Pending; + }, + } + } + + Poll::Ready(Ok(())) + } + + unsafe fn try_send_impl(&self, value: T, state: u8) -> Result<(), SendError> { + if state & CLOSED_BIT == CLOSED_BIT { + return Err(SendError::Closed(value)); + } + if state & FULL_BIT == FULL_BIT { + return Err(SendError::Full(value)); + } + + unsafe { + self.value.get().write(MaybeUninit::new(value)); + let state = self.state.swap(FULL_BIT, Ordering::AcqRel); + if state & WAITING_BIT == WAITING_BIT { + self.recv_waker.wake(); + } + if state & CLOSED_BIT == CLOSED_BIT { + self.state.store(CLOSED_BIT, Ordering::Relaxed); + return Err(SendError::Closed(self.value.get().read().assume_init())); + } + } + + Ok(()) + } + + unsafe fn poll_recv(&self, waker: &Waker) -> Poll> { + let mut state = self.state.load(Ordering::Acquire); + if state & FULL_BIT == 0 { + self.recv_waker.register(waker); + let (Ok(s) | Err(s)) = self.state.compare_exchange( + state, + state | WAITING_BIT, + Ordering::Release, + Ordering::Acquire, + ); + state = s; + } + + match self.try_recv_impl(state) { + Ok(v) => Poll::Ready(Ok(v)), + Err(RecvError::Empty) => Poll::Pending, + Err(RecvError::Closed) => Poll::Ready(Err(())), + } + } + + unsafe fn try_recv_impl(&self, state: u8) -> Result { + if state & FULL_BIT == FULL_BIT { + unsafe { + let ret = self.value.get().read().assume_init(); + let state = self.state.swap(0, Ordering::Acquire); + if state & WAITING_BIT == WAITING_BIT { + self.send_waker.wake(); + } + if state & CLOSED_BIT == CLOSED_BIT { + self.state.store(CLOSED_BIT, Ordering::Relaxed); + } + return Ok(ret); + } + } + + // Check closed bit last so we do receive any last element sent before + // closing sender. + if state & CLOSED_BIT == CLOSED_BIT { + return Err(RecvError::Closed); + } + + Err(RecvError::Empty) + } + + unsafe fn try_send(&self, value: T) -> Result<(), SendError> { + self.try_send_impl(value, self.state.load(Ordering::Relaxed)) + } + + unsafe fn try_recv(&self) -> Result { + self.try_recv_impl(self.state.load(Ordering::Acquire)) + } + + /// # Safety + /// After calling close as a sender/receiver, you may not access + /// this pipe anymore as that end. + unsafe fn close(&self) { + self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed); + self.send_waker.wake(); + self.recv_waker.wake(); + } +} + +pub struct Sender { + pipe: Arc>, +} + +unsafe impl Send for Sender {} + +impl Drop for Sender { + fn drop(&mut self) { + unsafe { self.pipe.close() } + } +} + +pub struct Receiver { + pipe: Arc>, +} + +unsafe impl Send for Receiver {} + +impl Drop for Receiver { + fn drop(&mut self) { + unsafe { self.pipe.close() } + } +} + +pin_project! { + pub struct SendFuture<'a, T> { + pipe: &'a Pipe, + value: Option, + } +} + +unsafe impl<'a, T: Send> Send for SendFuture<'a, T> {} + +impl Sender { + /// Returns a future that when awaited will send the value to the [`Receiver`]. + /// Returns Err(value) if the pipe is closed. + #[must_use] + pub fn send(&mut self, value: T) -> SendFuture<'_, T> { + SendFuture { + pipe: &self.pipe, + value: Some(value), + } + } + + pub fn try_send(&mut self, value: T) -> Result<(), SendError> { + unsafe { self.pipe.try_send(value) } + } +} + +impl std::future::Future for SendFuture<'_, T> { + type Output = Result<(), T>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + assert!( + self.value.is_some(), + "re-poll after Poll::Ready in pipe SendFuture" + ); + unsafe { self.pipe.poll_send(self.project().value, cx.waker()) } + } +} + +pin_project! { + pub struct RecvFuture<'a, T> { + pipe: &'a Pipe, + done: bool, + } +} + +unsafe impl<'a, T: Send> Send for RecvFuture<'a, T> {} + +impl Receiver { + /// Returns a future that when awaited will return `Ok(value)` once the + /// value is received, or returns `Err(())` if the [`Sender`] was dropped + /// before sending a value. + #[must_use] + pub fn recv(&mut self) -> RecvFuture<'_, T> { + RecvFuture { + pipe: &self.pipe, + done: false, + } + } + + pub fn try_recv(&mut self) -> Result { + unsafe { self.pipe.try_recv() } + } +} + +impl std::future::Future for RecvFuture<'_, T> { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + assert!(!self.done, "re-poll after Poll::Ready in pipe SendFuture"); + unsafe { self.pipe.poll_recv(cx.waker()) } + } +} diff --git a/crates/polars-stream/src/async_primitives/task_parker.rs b/crates/polars-stream/src/async_primitives/task_parker.rs new file mode 100644 index 000000000000..9e48b79e468b --- /dev/null +++ b/crates/polars-stream/src/async_primitives/task_parker.rs @@ -0,0 +1,80 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::task::{Context, Poll, Waker}; + +use parking_lot::Mutex; + +#[derive(Default)] +pub struct TaskParker { + state: AtomicU8, + waker: Mutex>, +} + +impl TaskParker { + const RUNNING: u8 = 0; + const PREPARING_TO_PARK: u8 = 1; + const PARKED: u8 = 2; + + /// Returns a future that when awaited parks this task. + /// + /// Any notifications between calls to park and the await will cancel + /// the park attempt. + pub fn park(&self) -> TaskParkFuture<'_> { + self.state.store(Self::PREPARING_TO_PARK, Ordering::SeqCst); + TaskParkFuture { parker: self } + } + + /// Unparks the parked task, if it was parked. + pub fn unpark(&self) { + let state = self.state.load(Ordering::SeqCst); + if state != Self::RUNNING { + let old_state = self.state.swap(Self::RUNNING, Ordering::SeqCst); + if old_state == Self::PARKED { + if let Some(w) = self.waker.lock().take() { + w.wake(); + } + } + } + } +} + +pub struct TaskParkFuture<'a> { + parker: &'a TaskParker, +} + +impl<'a> Future for TaskParkFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut state = self.parker.state.load(Ordering::SeqCst); + loop { + match state { + TaskParker::RUNNING => return Poll::Ready(()), + + TaskParker::PARKED => { + // Refresh our waker. + match &mut *self.parker.waker.lock() { + Some(w) => w.clone_from(cx.waker()), + None => return Poll::Ready(()), // Apparently someone woke us up. + } + }, + TaskParker::PREPARING_TO_PARK => { + // Install waker first before publishing that we're parked + // to prevent missed notifications. + *self.parker.waker.lock() = Some(cx.waker().clone()); + match self.parker.state.compare_exchange_weak( + TaskParker::PREPARING_TO_PARK, + TaskParker::PARKED, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => return Poll::Pending, + Err(s) => state = s, + } + }, + _ => unreachable!(), + } + } + } +} diff --git a/crates/polars-stream/src/async_primitives/wait_group.rs b/crates/polars-stream/src/async_primitives/wait_group.rs new file mode 100644 index 000000000000..66a6e8c70170 --- /dev/null +++ b/crates/polars-stream/src/async_primitives/wait_group.rs @@ -0,0 +1,93 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; + +use parking_lot::Mutex; + +#[derive(Default)] +struct WaitGroupInner { + waker: Mutex>, + token_count: AtomicUsize, + is_waiting: AtomicBool, +} + +#[derive(Default)] +pub struct WaitGroup { + inner: Arc, +} + +impl WaitGroup { + /// Creates a token. + pub fn token(&self) -> WaitToken { + self.inner.token_count.fetch_add(1, Ordering::Relaxed); + WaitToken { + inner: Arc::clone(&self.inner), + } + } + + /// Waits until all created tokens are dropped. + /// + /// # Panics + /// Panics if there is more than one simultaneous waiter. + pub async fn wait(&self) { + let was_waiting = self.inner.is_waiting.swap(true, Ordering::Relaxed); + assert!(!was_waiting); + WaitGroupFuture { inner: &self.inner }.await + } +} + +struct WaitGroupFuture<'a> { + inner: &'a Arc, +} + +impl Future for WaitGroupFuture<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.token_count.load(Ordering::Acquire) == 0 { + return Poll::Ready(()); + } + + // Check count again while holding lock to prevent missed notifications. + let mut waker_lock = self.inner.waker.lock(); + if self.inner.token_count.load(Ordering::Acquire) == 0 { + return Poll::Ready(()); + } + + let waker = cx.waker().clone(); + *waker_lock = Some(waker); + Poll::Pending + } +} + +impl<'a> Drop for WaitGroupFuture<'a> { + fn drop(&mut self) { + self.inner.is_waiting.store(false, Ordering::Relaxed); + } +} + +pub struct WaitToken { + inner: Arc, +} + +impl Clone for WaitToken { + fn clone(&self) -> Self { + self.inner.token_count.fetch_add(1, Ordering::Relaxed); + Self { + inner: self.inner.clone(), + } + } +} + +impl Drop for WaitToken { + fn drop(&mut self) { + // Token count was 1, we must notify. + if self.inner.token_count.fetch_sub(1, Ordering::Release) == 1 { + if let Some(w) = self.inner.waker.lock().take() { + w.wake(); + } + } + } +} diff --git a/crates/polars-stream/src/executor/mod.rs b/crates/polars-stream/src/executor/mod.rs new file mode 100644 index 000000000000..7a2e08d96d17 --- /dev/null +++ b/crates/polars-stream/src/executor/mod.rs @@ -0,0 +1,345 @@ +mod park_group; +mod task; + +use std::cell::{Cell, UnsafeCell}; +use std::future::Future; +use std::marker::PhantomData; +use std::panic::AssertUnwindSafe; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, OnceLock, Weak}; + +use crossbeam_deque::{Injector, Steal, Stealer, Worker as WorkQueue}; +use crossbeam_utils::CachePadded; +use park_group::ParkGroup; +use parking_lot::Mutex; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +use slotmap::SlotMap; +use task::{CancelHandle, JoinHandle, Runnable}; + +static NUM_EXECUTOR_THREADS: AtomicUsize = AtomicUsize::new(0); +pub fn set_num_threads(t: usize) { + NUM_EXECUTOR_THREADS.store(t, Ordering::Relaxed); +} + +static GLOBAL_SCHEDULER: OnceLock = OnceLock::new(); + +thread_local!( + /// Used to store which executor thread this is. + static TLS_THREAD_ID: Cell = const { Cell::new(usize::MAX) }; +); + +slotmap::new_key_type! { + struct TaskKey; +} + +/// Metadata associated with a task to help schedule it and clean it up. +struct TaskMetadata { + priority: bool, + + task_key: TaskKey, + completed_tasks: Weak>>, +} + +impl Drop for TaskMetadata { + fn drop(&mut self) { + if let Some(completed_tasks) = self.completed_tasks.upgrade() { + completed_tasks.lock().push(self.task_key); + } + } +} + +/// A task ready to run. +type ReadyTask = Runnable; + +/// A per-thread task list. +struct ThreadLocalTaskList { + // May be used from any thread. + high_prio_tasks_stealer: Stealer, + + // SAFETY: these may only be used on the thread this task list belongs to. + high_prio_tasks: WorkQueue, + local_slot: UnsafeCell>, +} + +unsafe impl Sync for ThreadLocalTaskList {} + +struct Executor { + park_group: ParkGroup, + thread_task_lists: Vec>, + global_high_prio_task_queue: Injector, + global_low_prio_task_queue: Injector, +} + +impl Executor { + fn schedule_task(&self, task: ReadyTask) { + let thread = TLS_THREAD_ID.get(); + let priority = task.metadata().priority; + if let Some(ttl) = self.thread_task_lists.get(thread) { + // SAFETY: this slot may only be accessed from the local thread, which we are. + let slot = unsafe { &mut *ttl.local_slot.get() }; + + if priority { + // Insert new task into thread local slot, taking out the old task. + let Some(task) = slot.replace(task) else { + // We pushed a task into our local slot which was empty. Since + // we are already awake, no need to notify anyone. + return; + }; + + ttl.high_prio_tasks.push(task); + self.park_group.unpark_one(); + } else { + // Optimization: while this is a low priority task we have no + // high priority tasks on this thread so we'll execute this one. + if ttl.high_prio_tasks.is_empty() && slot.is_none() { + *slot = Some(task); + } else { + self.global_low_prio_task_queue.push(task); + self.park_group.unpark_one(); + } + } + } else { + // Scheduled from an unknown thread, add to global queue. + if priority { + self.global_high_prio_task_queue.push(task); + } else { + self.global_low_prio_task_queue.push(task); + } + self.park_group.unpark_one(); + } + } + + fn try_steal_task(&self, thread: usize, rng: &mut R) -> Option { + // Try to get a global task. + loop { + match self.global_high_prio_task_queue.steal() { + Steal::Empty => break, + Steal::Success(task) => return Some(task), + Steal::Retry => std::hint::spin_loop(), + } + } + + loop { + match self.global_low_prio_task_queue.steal() { + Steal::Empty => break, + Steal::Success(task) => return Some(task), + Steal::Retry => std::hint::spin_loop(), + } + } + + // Try to steal tasks. + let ttl = &self.thread_task_lists[thread]; + for _ in 0..4 { + let mut retry = true; + while retry { + retry = false; + + for idx in random_permutation(self.thread_task_lists.len() as u32, rng) { + let foreign_ttl = &self.thread_task_lists[idx as usize]; + match foreign_ttl + .high_prio_tasks_stealer + .steal_batch_and_pop(&ttl.high_prio_tasks) + { + Steal::Empty => {}, + Steal::Success(task) => return Some(task), + Steal::Retry => retry = true, + } + } + + std::hint::spin_loop() + } + } + + None + } + + fn runner(&self, thread: usize) { + TLS_THREAD_ID.set(thread); + + let mut rng = SmallRng::from_rng(&mut rand::thread_rng()).unwrap(); + let mut worker = self.park_group.new_worker(); + + loop { + let ttl = &self.thread_task_lists[thread]; + let task = (|| { + // Try to get a task from LIFO slot. + if let Some(task) = unsafe { (*ttl.local_slot.get()).take() } { + return Some(task); + } + + // Try to get a local high-priority task. + if let Some(task) = ttl.high_prio_tasks.pop() { + return Some(task); + } + + // Try to steal a task. + if let Some(task) = self.try_steal_task(thread, &mut rng) { + return Some(task); + } + + // Prepare to park, then try one more steal attempt. + let park = worker.prepare_park(); + if let Some(task) = self.try_steal_task(thread, &mut rng) { + return Some(task); + } + park.park(); + None + })(); + + if let Some(task) = task { + worker.recruit_next(); + task.run(); + } + } + } + + fn global() -> &'static Executor { + GLOBAL_SCHEDULER.get_or_init(|| { + let mut n_threads = NUM_EXECUTOR_THREADS.load(Ordering::Relaxed); + if n_threads == 0 { + n_threads = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4); + } + + let thread_task_lists = (0..n_threads) + .map(|t| { + std::thread::spawn(move || Self::global().runner(t)); + + let high_prio_tasks = WorkQueue::new_lifo(); + CachePadded::new(ThreadLocalTaskList { + high_prio_tasks_stealer: high_prio_tasks.stealer(), + high_prio_tasks, + local_slot: UnsafeCell::new(None), + }) + }) + .collect(); + Self { + park_group: ParkGroup::new(), + thread_task_lists, + global_high_prio_task_queue: Injector::new(), + global_low_prio_task_queue: Injector::new(), + } + }) + } +} + +pub struct TaskScope<'scope, 'env: 'scope> { + // Keep track of in-progress tasks so we can forcibly cancel them + // when the scope ends, to ensure the lifetimes are respected. + // Tasks add their own key to completed_tasks when done so we can + // reclaim the memory used by the cancel_handles. + cancel_handles: Mutex>, + completed_tasks: Arc>>, + + // Copied from std::thread::scope. Necessary to prevent unsoundness. + scope: PhantomData<&'scope mut &'scope ()>, + env: PhantomData<&'env mut &'env ()>, +} + +impl<'scope, 'env> TaskScope<'scope, 'env> { + // Not Drop because that extends lifetimes. + fn destroy(&self) { + // Make sure all tasks are cancelled. + for (_, t) in self.cancel_handles.lock().drain() { + t.cancel(); + } + } + + fn clear_completed_tasks(&self) { + let mut cancel_handles = self.cancel_handles.lock(); + for t in self.completed_tasks.lock().drain(..) { + cancel_handles.remove(t); + } + } + + pub fn spawn_task( + &self, + priority: bool, + fut: F, + ) -> JoinHandle + where + ::Output: Send + 'static, + { + self.clear_completed_tasks(); + + let mut runnable = None; + let mut join_handle = None; + self.cancel_handles.lock().insert_with_key(|task_key| { + let (run, jh) = unsafe { + // SAFETY: we make sure to cancel this task before 'scope ends. + let executor = Executor::global(); + let on_wake = move |task| executor.schedule_task(task); + task::spawn_with_lifetime( + fut, + on_wake, + TaskMetadata { + task_key, + priority, + completed_tasks: Arc::downgrade(&self.completed_tasks), + }, + ) + }; + let cancel_handle = jh.cancel_handle(); + runnable = Some(run); + join_handle = Some(jh); + cancel_handle + }); + runnable.unwrap().schedule(); + join_handle.unwrap() + } +} + +pub fn task_scope<'env, F, T>(f: F) -> T +where + F: for<'scope> FnOnce(&'scope TaskScope<'scope, 'env>) -> T, +{ + // By having this local variable inaccessible to anyone we guarantee + // that either abort is called killing the entire process, or that this + // executor is properly destroyed. + let scope = TaskScope { + cancel_handles: Mutex::default(), + completed_tasks: Arc::new(Mutex::default()), + scope: PhantomData, + env: PhantomData, + }; + + let result = std::panic::catch_unwind(AssertUnwindSafe(|| f(&scope))); + + // Make sure all tasks are properly destroyed. + scope.destroy(); + + match result { + Err(e) => std::panic::resume_unwind(e), + Ok(result) => result, + } +} + +fn random_permutation(len: u32, rng: &mut R) -> impl Iterator { + let modulus = len.next_power_of_two(); + let halfwidth = modulus.trailing_zeros() / 2; + let mask = modulus - 1; + let displace_zero = rng.gen::(); + let odd1 = rng.gen::() | 1; + let odd2 = rng.gen::() | 1; + let uniform_first = ((rng.gen::() as u64 * len as u64) >> 32) as u32; + + (0..modulus) + .map(move |mut i| { + // Invertible permutation on [0, modulus). + i = i.wrapping_add(displace_zero); + i = i.wrapping_mul(odd1); + i ^= (i & mask) >> halfwidth; + i = i.wrapping_mul(odd2); + i & mask + }) + .filter(move |i| *i < len) + .map(move |mut i| { + i += uniform_first; + if i >= len { + i -= len; + } + i + }) +} diff --git a/crates/polars-stream/src/executor/park_group.rs b/crates/polars-stream/src/executor/park_group.rs new file mode 100644 index 000000000000..d9da30ce7f3e --- /dev/null +++ b/crates/polars-stream/src/executor/park_group.rs @@ -0,0 +1,216 @@ +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::sync::Arc; + +use parking_lot::{Condvar, Mutex}; + +/// A group of workers that can park / unpark each other. +/// +/// There is at most one worker at a time which is considered a 'recruiter'. +/// A recruiter hasn't yet found work and will either park again or recruit the +/// next worker when it finds work. +/// +/// Calls to park/unpark participate in a global SeqCst order. +#[derive(Default)] +pub struct ParkGroup { + inner: Arc, +} + +#[derive(Default)] +struct ParkGroupInner { + // The condvar we park with. + condvar: Condvar, + + // Contains the number of notifications and whether or not the next unparked + // worker should become a recruiter. + notifications: Mutex<(u32, bool)>, + + // Bits 0..32: number of idle workers. + // Bit 32: set if there is an active recruiter. + // Bit 33: set if a worker is preparing to park. + // Bits 34..64: version that is incremented to cancel a park request. + state: AtomicU64, + + num_workers: AtomicU32, +} + +const IDLE_UNIT: u64 = 1; +const ACTIVE_RECRUITER_BIT: u64 = 1 << 32; +const PREPARING_TO_PARK_BIT: u64 = 1 << 33; +const VERSION_UNIT: u64 = 1 << 34; + +fn state_num_idle(state: u64) -> u32 { + state as u32 +} + +fn state_version(state: u64) -> u32 { + (state >> 34) as u32 +} + +pub struct ParkGroupWorker { + inner: Arc, + recruiter: bool, + version: u32, +} + +impl ParkGroup { + pub fn new() -> Self { + Self::default() + } + + /// Creates a new worker. + /// + /// # Panics + /// Panics if you try to create more than 2^32 - 1 workers. + pub fn new_worker(&self) -> ParkGroupWorker { + self.inner + .num_workers + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |w| w.checked_add(1)) + .expect("can't have more than 2^32 - 1 workers"); + + ParkGroupWorker { + version: 0, + inner: Arc::clone(&self.inner), + recruiter: false, + } + } + + /// Unparks an idle worker if there is no recruiter. + /// + /// Also cancels in-progress park attempts. + pub fn unpark_one(&self) { + self.inner.unpark_one(); + } +} + +impl ParkGroupWorker { + /// Prepares to park this worker. + pub fn prepare_park(&mut self) -> ParkAttempt<'_> { + let mut state = self.inner.state.load(Ordering::SeqCst); + self.version = state_version(state); + + // If the version changes or someone else has set the + // PREPARING_TO_PARK_BIT, stop trying to update the state. + while state & PREPARING_TO_PARK_BIT == 0 && state_version(state) == self.version { + // Notify that we're preparing to park, and while we're at it might as + // well try to become a recruiter to avoid expensive unparks. + let new_state = state | PREPARING_TO_PARK_BIT | ACTIVE_RECRUITER_BIT; + match self.inner.state.compare_exchange_weak( + state, + new_state, + Ordering::Relaxed, + Ordering::SeqCst, + ) { + Ok(s) => { + if s & ACTIVE_RECRUITER_BIT == 0 { + self.recruiter = true; + } + break; + }, + + Err(s) => state = s, + } + } + + ParkAttempt { worker: self } + } + + /// You should call this function after finding work to recruit the next + /// worker if this worker was a recruiter. + pub fn recruit_next(&mut self) { + if !self.recruiter { + return; + } + + // Recruit the next idle worker or mark that there is no recruiter anymore. + let mut recruit_next = false; + let _ = self + .inner + .state + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| { + debug_assert!(state & ACTIVE_RECRUITER_BIT != 0); + + recruit_next = state_num_idle(state) > 0; + let bit = if recruit_next { + IDLE_UNIT + } else { + ACTIVE_RECRUITER_BIT + }; + Some(state - bit) + }); + + if recruit_next { + self.inner.unpark_one_slow_as_recruiter(); + } + self.recruiter = false; + } +} + +pub struct ParkAttempt<'a> { + worker: &'a mut ParkGroupWorker, +} + +impl<'a> ParkAttempt<'a> { + /// Actually park this worker. + /// + /// If there were calls to unpark between calling prepare_park() and park(), + /// this park attempt is cancelled and immediately returns. + pub fn park(mut self) { + let state = &self.worker.inner.state; + let update = state.fetch_update(Ordering::Relaxed, Ordering::SeqCst, |state| { + if state_version(state) != self.worker.version { + // We got notified of new work, cancel park. + None + } else if self.worker.recruiter { + Some(state + IDLE_UNIT - ACTIVE_RECRUITER_BIT) + } else { + Some(state + IDLE_UNIT) + } + }); + + if update.is_ok() { + self.park_slow() + } + } + + #[cold] + fn park_slow(&mut self) { + let condvar = &self.worker.inner.condvar; + let mut notifications = self.worker.inner.notifications.lock(); + condvar.wait_while(&mut notifications, |n| n.0 == 0); + + // Possibly become a recruiter and consume the notification. + self.worker.recruiter = notifications.1; + notifications.0 -= 1; + notifications.1 = false; + } +} + +impl ParkGroupInner { + fn unpark_one(&self) { + let mut should_unpark = false; + let _ = self + .state + .fetch_update(Ordering::Release, Ordering::SeqCst, |state| { + should_unpark = state_num_idle(state) > 0 && state & ACTIVE_RECRUITER_BIT == 0; + if should_unpark { + Some(state - IDLE_UNIT + ACTIVE_RECRUITER_BIT) + } else if state & PREPARING_TO_PARK_BIT == PREPARING_TO_PARK_BIT { + Some(state.wrapping_add(VERSION_UNIT) & !PREPARING_TO_PARK_BIT) + } else { + None + } + }); + + if should_unpark { + self.unpark_one_slow_as_recruiter(); + } + } + + #[cold] + fn unpark_one_slow_as_recruiter(&self) { + let mut notifications = self.notifications.lock(); + notifications.0 += 1; + notifications.1 = true; + self.condvar.notify_one(); + } +} diff --git a/crates/polars-stream/src/executor/task.rs b/crates/polars-stream/src/executor/task.rs new file mode 100644 index 000000000000..b87b2a7b4be3 --- /dev/null +++ b/crates/polars-stream/src/executor/task.rs @@ -0,0 +1,404 @@ +use std::any::Any; +use std::future::Future; +use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll, Wake, Waker}; + +use parking_lot::Mutex; + +/// The state of the task. Can't be part of the TaskData enum as it needs to be +/// atomically updateable, even when we hold the lock on the data. +#[derive(Default)] +struct TaskState { + state: AtomicU8, +} + +impl TaskState { + /// Default state, not running, not scheduled. + const IDLE: u8 = 0; + + /// Task is scheduled, that is (task.schedule)(task) was called. + const SCHEDULED: u8 = 1; + + /// Task is currently running. + const RUNNING: u8 = 2; + + /// Task notified while running. + const NOTIFIED_WHILE_RUNNING: u8 = 3; + + /// Wake this task. Returns true if task.schedule should be called. + fn wake(&self) -> bool { + self.state + .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state { + Self::SCHEDULED | Self::NOTIFIED_WHILE_RUNNING => None, + Self::RUNNING => Some(Self::NOTIFIED_WHILE_RUNNING), + Self::IDLE => Some(Self::SCHEDULED), + _ => unreachable!("invalid TaskState"), + }) + .map(|state| state == Self::IDLE) + .unwrap_or(false) + } + + /// Start running this task. + fn start_running(&self) { + assert_eq!(self.state.load(Ordering::Acquire), Self::SCHEDULED); + self.state.store(Self::RUNNING, Ordering::Relaxed); + } + + /// Done running this task. Returns true if task.schedule should be called. + fn reschedule_after_running(&self) -> bool { + self.state + .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state { + Self::RUNNING => Some(Self::IDLE), + Self::NOTIFIED_WHILE_RUNNING => Some(Self::SCHEDULED), + _ => panic!("TaskState::reschedule_after_running() called on invalid state"), + }) + .map(|old_state| old_state == Self::NOTIFIED_WHILE_RUNNING) + .unwrap_or(false) + } +} + +enum TaskData { + Empty, + Polling(F, Waker), + Ready(F::Output), + Panic(Box), + Cancelled, + Joined, +} + +struct Task { + state: TaskState, + data: Mutex<(TaskData, Option)>, + schedule: S, + metadata: M, +} + +impl<'a, F, S, M> Task +where + F: Future + Send + 'a, + F::Output: Send + 'static, + S: Fn(Runnable) + Send + Sync + Copy + 'static, + M: Send + Sync + 'static, +{ + /// # Safety + /// It is the responsibility of the caller that before lifetime 'a ends the + /// task is either polled to completion or cancelled. + unsafe fn spawn(future: F, schedule: S, metadata: M) -> Arc { + let task = Arc::new(Self { + state: TaskState::default(), + data: Mutex::new((TaskData::Empty, None)), + schedule, + metadata, + }); + + let waker = unsafe { Waker::from_raw(std_shim::raw_waker(task.clone())) }; + task.data.try_lock().unwrap().0 = TaskData::Polling(future, waker); + task + } + + fn into_runnable(self: Arc) -> Runnable { + let arc: Arc + 'a> = self; + let arc: Arc> = unsafe { std::mem::transmute(arc) }; + Runnable(arc) + } + + fn into_join_handle(self: Arc) -> JoinHandle { + let arc: Arc + 'a> = self; + let arc: Arc> = unsafe { std::mem::transmute(arc) }; + JoinHandle(Some(arc)) + } + + fn into_cancel_handle(self: Arc) -> CancelHandle { + let arc: Arc = self; + let arc: Arc = unsafe { std::mem::transmute(arc) }; + CancelHandle(Arc::downgrade(&arc)) + } +} + +impl<'a, F, S, M> Wake for Task +where + F: Future + Send + 'a, + F::Output: Send + 'static, + S: Fn(Runnable) + Send + Sync + Copy + 'static, + M: Send + Sync + 'static, +{ + fn wake(self: Arc) { + if self.state.wake() { + let schedule = self.schedule; + (schedule)(self.into_runnable()); + } + } + + fn wake_by_ref(self: &Arc) { + self.clone().wake() + } +} + +pub trait DynTask: Send + Sync { + fn metadata(&self) -> &M; + fn run(self: Arc) -> bool; + fn schedule(self: Arc); +} + +impl<'a, F, S, M> DynTask for Task +where + F: Future + Send + 'a, + F::Output: Send + 'static, + S: Fn(Runnable) + Send + Sync + Copy + 'static, + M: Send + Sync + 'static, +{ + fn metadata(&self) -> &M { + &self.metadata + } + + fn run(self: Arc) -> bool { + let mut data = self.data.lock(); + + let poll_result = match &mut data.0 { + TaskData::Polling(future, waker) => { + self.state.start_running(); + // SAFETY: we always store a Task in an Arc and never move it. + let fut = unsafe { Pin::new_unchecked(future) }; + let mut ctx = Context::from_waker(waker); + catch_unwind(AssertUnwindSafe(|| fut.poll(&mut ctx))) + }, + TaskData::Cancelled => return true, + _ => unreachable!("invalid TaskData when polling"), + }; + + data.0 = match poll_result { + Err(error) => TaskData::Panic(error), + Ok(Poll::Ready(output)) => TaskData::Ready(output), + Ok(Poll::Pending) => { + drop(data); + if self.state.reschedule_after_running() { + let schedule = self.schedule; + (schedule)(self.into_runnable()); + } + return false; + }, + }; + + let join_waker = data.1.take(); + drop(data); + if let Some(w) = join_waker { + w.wake(); + } + true + } + + fn schedule(self: Arc) { + if self.state.wake() { + (self.schedule)(self.clone().into_runnable()); + } + } +} + +trait Joinable: Send + Sync { + fn cancel_handle(self: Arc) -> CancelHandle; + fn poll_join(&self, ctx: &mut Context<'_>) -> Poll; +} + +impl<'a, F, S, M> Joinable for Task +where + F: Future + Send + 'a, + F::Output: Send + 'static, + S: Fn(Runnable) + Send + Sync + Copy + 'static, + M: Send + Sync + 'static, +{ + fn cancel_handle(self: Arc) -> CancelHandle { + self.into_cancel_handle() + } + + fn poll_join(&self, cx: &mut Context<'_>) -> Poll { + let mut data = self.data.lock(); + if matches!(data.0, TaskData::Empty | TaskData::Polling(..)) { + data.1 = Some(cx.waker().clone()); + return Poll::Pending; + } + + match core::mem::replace(&mut data.0, TaskData::Joined) { + TaskData::Ready(output) => Poll::Ready(output), + TaskData::Panic(error) => resume_unwind(error), + TaskData::Cancelled => panic!("joined on cancelled task"), + _ => unreachable!("invalid TaskData when joining"), + } + } +} + +trait Cancellable: Send + Sync { + fn cancel(&self); +} + +impl<'a, F, S, M> Cancellable for Task +where + F: Future + Send + 'a, + F::Output: Send + 'static, + S: Send + Sync + 'static, + M: Send + Sync + 'static, +{ + fn cancel(&self) { + let mut data = self.data.lock(); + match data.0 { + // Already done. + TaskData::Panic(_) | TaskData::Joined => {}, + + // Still in-progress, cancel. + _ => { + data.0 = TaskData::Cancelled; + if let Some(join_waker) = data.1.take() { + join_waker.wake(); + } + }, + } + } +} + +pub struct Runnable(Arc>); + +impl Runnable { + /// Gives the metadata for this task. + pub fn metadata(&self) -> &M { + self.0.metadata() + } + + /// Runs a task, and returns true if the task is done. + pub fn run(self) -> bool { + self.0.run() + } + + /// Schedules this task. + pub fn schedule(self) { + self.0.schedule() + } +} + +pub struct JoinHandle(Option>>); +pub struct CancelHandle(Weak); + +impl JoinHandle { + pub fn cancel_handle(&self) -> CancelHandle { + let arc = self + .0 + .as_ref() + .expect("called cancel_handle on joined JoinHandle"); + Arc::clone(arc).cancel_handle() + } +} + +impl Future for JoinHandle { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let joinable = self.0.take().expect("JoinHandle polled after completion"); + + if let Poll::Ready(output) = joinable.poll_join(ctx) { + return Poll::Ready(output); + } + + self.0 = Some(joinable); + Poll::Pending + } +} + +impl CancelHandle { + pub fn cancel(self) { + if let Some(t) = self.0.upgrade() { + t.cancel(); + } + } +} + +#[allow(unused)] +pub fn spawn(future: F, schedule: S, metadata: M) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send + 'static, + S: Fn(Runnable) + Send + Sync + Copy + 'static, + M: Send + Sync + 'static, +{ + let task = unsafe { Task::spawn(future, schedule, metadata) }; + JoinHandle(Some(task)) +} + +/// Takes a future and turns it into a runnable task with associated metadata. +/// +/// When the task is pending its waker will be set to call schedule +/// with the runnable. +pub unsafe fn spawn_with_lifetime<'a, F, S, M>( + future: F, + schedule: S, + metadata: M, +) -> (Runnable, JoinHandle) +where + F: Future + Send + 'a, + F::Output: Send + 'static, + S: Fn(Runnable) + Send + Sync + Copy + 'static, + M: Send + Sync + 'static, +{ + let task = Task::spawn(future, schedule, metadata); + (task.clone().into_runnable(), task.into_join_handle()) +} + +// Copied from the standard library, except without the 'static bound. +mod std_shim { + use std::mem::ManuallyDrop; + use std::sync::Arc; + use std::task::{RawWaker, RawWakerVTable, Wake}; + + #[inline(always)] + pub unsafe fn raw_waker<'a, W: Wake + Send + Sync + 'a>(waker: Arc) -> RawWaker { + // Increment the reference count of the arc to clone it. + // + // The #[inline(always)] is to ensure that raw_waker and clone_waker are + // always generated in the same code generation unit as one another, and + // therefore that the structurally identical const-promoted RawWakerVTable + // within both functions is deduplicated at LLVM IR code generation time. + // This allows optimizing Waker::will_wake to a single pointer comparison of + // the vtable pointers, rather than comparing all four function pointers + // within the vtables. + #[inline(always)] + unsafe fn clone_waker(waker: *const ()) -> RawWaker { + unsafe { Arc::increment_strong_count(waker as *const W) }; + RawWaker::new( + waker, + &RawWakerVTable::new( + clone_waker::, + wake::, + wake_by_ref::, + drop_waker::, + ), + ) + } + + // Wake by value, moving the Arc into the Wake::wake function + unsafe fn wake(waker: *const ()) { + let waker = unsafe { Arc::from_raw(waker as *const W) }; + ::wake(waker); + } + + // Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it + unsafe fn wake_by_ref(waker: *const ()) { + let waker = unsafe { ManuallyDrop::new(Arc::from_raw(waker as *const W)) }; + ::wake_by_ref(&waker); + } + + // Decrement the reference count of the Arc on drop + unsafe fn drop_waker(waker: *const ()) { + unsafe { Arc::decrement_strong_count(waker as *const W) }; + } + + RawWaker::new( + Arc::into_raw(waker) as *const (), + &RawWakerVTable::new( + clone_waker::, + wake::, + wake_by_ref::, + drop_waker::, + ), + ) + } +} diff --git a/crates/polars-stream/src/lib.rs b/crates/polars-stream/src/lib.rs new file mode 100644 index 000000000000..1ba4dceaf39a --- /dev/null +++ b/crates/polars-stream/src/lib.rs @@ -0,0 +1,20 @@ +#[allow(unused)] +mod async_primitives; +#[allow(unused)] +mod executor; + +pub async fn dummy() { + let num_threads = 8; + executor::set_num_threads(num_threads); + executor::task_scope(|s| { + s.spawn_task(false, async {}); + }); + + let (mut send, mut recv) = async_primitives::pipe::pipe::(); + send.send(42).await.ok(); + recv.recv().await.ok(); + let (mut send, mut recvs) = + async_primitives::distributor_channel::distributor_channel::(num_threads, 8); + send.send(42).await.ok(); + recvs[0].recv().await.ok(); +}