diff --git a/crates/polars-stream/src/nodes/io_sources/parquet/init.rs b/crates/polars-stream/src/nodes/io_sources/parquet/init.rs index c683aecaac98..7e1effd3431a 100644 --- a/crates/polars-stream/src/nodes/io_sources/parquet/init.rs +++ b/crates/polars-stream/src/nodes/io_sources/parquet/init.rs @@ -1,9 +1,6 @@ -use std::collections::VecDeque; use std::future::Future; use std::sync::Arc; -use futures::stream::FuturesUnordered; -use futures::StreamExt; use polars_core::frame::DataFrame; use polars_core::prelude::PlIndexSet; use polars_error::PolarsResult; @@ -13,12 +10,11 @@ use polars_io::prelude::_internal::PrefilterMaskSetting; use super::row_group_data_fetch::RowGroupDataFetcher; use super::row_group_decode::RowGroupDecoder; use super::{AsyncTaskData, ParquetSourceNode}; -use crate::async_executor; -use crate::async_primitives::connector::connector; -use crate::async_primitives::wait_group::IndexedWaitGroup; +use crate::async_primitives::distributor_channel::distributor_channel; use crate::morsel::get_ideal_morsel_size; use crate::nodes::{MorselSeq, TaskPriority}; -use crate::utils::task_handles_ext; +use crate::utils::task_handles_ext::{self, AbortOnDropHandle}; +use crate::{async_executor, DEFAULT_DISTRIBUTOR_BUFFER_SIZE}; impl ParquetSourceNode { /// # Panics @@ -72,9 +68,8 @@ impl ParquetSourceNode { let use_statistics = self.options.use_statistics; - let (mut raw_morsel_senders, raw_morsel_receivers): (Vec<_>, Vec<_>) = - (0..self.config.num_pipelines).map(|_| connector()).unzip(); - + let (mut raw_morsel_sender, raw_morsel_receivers) = + distributor_channel(self.config.num_pipelines, DEFAULT_DISTRIBUTOR_BUFFER_SIZE); if let Some((_, 0)) = self.file_options.slice { return ( raw_morsel_receivers, @@ -84,10 +79,9 @@ impl ParquetSourceNode { let reader_schema = self.schema.clone().unwrap(); - let (normalized_slice_oneshot_rx, metadata_rx, metadata_task_handle) = + let (normalized_slice_oneshot_rx, metadata_rx, metadata_task) = self.init_metadata_fetcher(); - let num_pipelines = self.config.num_pipelines; let row_group_prefetch_size = self.config.row_group_prefetch_size; let projection = self.file_options.with_columns.clone(); assert_eq!(self.physical_predicate.is_some(), self.predicate.is_some()); @@ -121,14 +115,15 @@ impl ParquetSourceNode { eprintln!("[ParquetSource]: ideal_morsel_size: {}", ideal_morsel_size); } - // Distributes morsels across pipelines. This does not perform any CPU or I/O bound work - - // it is purely a dispatch loop. - let raw_morsel_distributor_task_handle = io_runtime.spawn(async move { + // Prefetch loop (spawns prefetches on the tokio scheduler). + let (prefetch_send, mut prefetch_recv) = + tokio::sync::mpsc::channel(row_group_prefetch_size); + let prefetch_task = AbortOnDropHandle(io_runtime.spawn(async move { let slice_range = { let Ok(slice) = normalized_slice_oneshot_rx.await else { // If we are here then the producer probably errored. drop(row_group_data_fetcher); - return metadata_task_handle.await.unwrap(); + return PolarsResult::Ok(()); }; slice.map(|(offset, len)| offset..offset + len) @@ -136,107 +131,62 @@ impl ParquetSourceNode { row_group_data_fetcher.slice_range = slice_range; - // Ensure proper backpressure by only polling the buffered iterator when a wait group - // is free. - let mut wait_groups = (0..num_pipelines) - .map(|index| IndexedWaitGroup::new(index).wait()) - .collect::>(); - - let mut df_stream = row_group_data_fetcher - .into_stream() - .map(|x| async { - match x { - Ok(handle) => handle.await.unwrap(), - Err(e) => Err(e), - } - }) - .buffered(row_group_prefetch_size) - .map(|x| async { - let row_group_decoder = row_group_decoder.clone(); - - match x { - Ok(row_group_data) => { - async_executor::spawn(TaskPriority::Low, async move { - row_group_decoder.row_group_data_to_df(row_group_data).await - }) - .await - }, - Err(e) => Err(e), - } - }) - .buffered( - // Because we are using an ordered buffer, we may suffer from head-of-line blocking, - // so we add a small amount of buffer. - num_pipelines + 4, - ); - - let morsel_seq_ref = &mut MorselSeq::default(); - let mut dfs = VecDeque::with_capacity(1); - - 'main: loop { - let Some(mut indexed_wait_group) = wait_groups.next().await else { + loop { + let Some(prefetch) = row_group_data_fetcher.next().await else { break; }; + if prefetch_send.send(prefetch?).await.is_err() { + break; + } + } + PolarsResult::Ok(()) + })); + + // Decode loop (spawns decodes on the computational executor). + let (decode_send, mut decode_recv) = tokio::sync::mpsc::channel(self.config.num_pipelines); + let decode_task = AbortOnDropHandle(io_runtime.spawn(async move { + while let Some(prefetch) = prefetch_recv.recv().await { + let row_group_data = prefetch.await.unwrap()?; + let row_group_decoder = row_group_decoder.clone(); + let decode_fut = async_executor::spawn(TaskPriority::High, async move { + row_group_decoder.row_group_data_to_df(row_group_data).await + }); + if decode_send.send(decode_fut).await.is_err() { + break; + } + } + PolarsResult::Ok(()) + })); - while dfs.is_empty() { - let Some(v) = df_stream.next().await else { - break 'main; - }; - - let df = v?; - - if df.is_empty() { - continue; - } - - let (iter, n) = split_to_morsels(&df, ideal_morsel_size); - - dfs.reserve(n); - dfs.extend(iter); + // Distributes morsels across pipelines. This does not perform any CPU or I/O bound work - + // it is purely a dispatch loop. + let distribute_task = AbortOnDropHandle(io_runtime.spawn(async move { + let mut morsel_seq = MorselSeq::default(); + while let Some(decode_fut) = decode_recv.recv().await { + let df = decode_fut.await?; + if df.is_empty() { + continue; } - let mut df = dfs.pop_front().unwrap(); - let morsel_seq = *morsel_seq_ref; - *morsel_seq_ref = morsel_seq.successor(); - - loop { - use crate::async_primitives::connector::SendError; - - let channel_index = indexed_wait_group.index(); - let wait_token = indexed_wait_group.token(); - - match raw_morsel_senders[channel_index].try_send((df, morsel_seq, wait_token)) { - Ok(_) => { - wait_groups.push(indexed_wait_group.wait()); - break; - }, - Err(SendError::Closed(v)) => { - // The channel assigned to this wait group has been closed, so we will not - // add it back to the list of wait groups, and we will try to send this - // across another channel. - df = v.0 - }, - Err(SendError::Full(_)) => unreachable!(), + for df in split_to_morsels(&df, ideal_morsel_size) { + if raw_morsel_sender.send((df, morsel_seq)).await.is_err() { + return Ok(()); } - - let Some(v) = wait_groups.next().await else { - // All channels have closed - break 'main; - }; - - indexed_wait_group = v; + morsel_seq = morsel_seq.successor(); } } - - // Join on the producer handle to catch errors/panics. - drop(df_stream); - metadata_task_handle.await.unwrap() + PolarsResult::Ok(()) + })); + + let join_task = io_runtime.spawn(async move { + metadata_task.await.unwrap()?; + prefetch_task.await.unwrap()?; + decode_task.await.unwrap()?; + distribute_task.await.unwrap()?; + Ok(()) }); - let raw_morsel_distributor_task_handle = - task_handles_ext::AbortOnDropHandle(raw_morsel_distributor_task_handle); - - (raw_morsel_receivers, raw_morsel_distributor_task_handle) + (raw_morsel_receivers, AbortOnDropHandle(join_task)) } /// Creates a `RowGroupDecoder` that turns `RowGroupData` into DataFrames. @@ -387,7 +337,7 @@ fn filtered_range(exclude: &[usize], len: usize) -> Vec { fn split_to_morsels( df: &DataFrame, ideal_morsel_size: usize, -) -> (impl Iterator + '_, usize) { +) -> impl Iterator + '_ { let n_morsels = if df.height() > 3 * ideal_morsel_size / 2 { // num_rows > (1.5 * ideal_morsel_size) (df.height() / ideal_morsel_size).max(2) @@ -397,12 +347,9 @@ fn split_to_morsels( let rows_per_morsel = 1 + df.height() / n_morsels; - ( - (0..i64::try_from(df.height()).unwrap()) - .step_by(rows_per_morsel) - .map(move |offset| df.slice(offset, rows_per_morsel)), - n_morsels, - ) + (0..i64::try_from(df.height()).unwrap()) + .step_by(rows_per_morsel) + .map(move |offset| df.slice(offset, rows_per_morsel)) } mod tests { diff --git a/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs b/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs index 098a0fd90b04..d4dd5cade584 100644 --- a/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs +++ b/crates/polars-stream/src/nodes/io_sources/parquet/mod.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use mem_prefetch_funcs::get_memory_prefetch_func; use polars_core::config; -use polars_core::frame::DataFrame; use polars_core::prelude::ArrowSchema; use polars_error::PolarsResult; use polars_expr::prelude::{phys_expr_to_io_expr, PhysicalExpr}; @@ -17,10 +16,10 @@ use polars_plan::prelude::FileScanOptions; use polars_utils::index::AtomicIdxSize; use polars_utils::pl_str::PlSmallStr; -use crate::async_primitives::wait_group::WaitToken; +use crate::async_primitives::wait_group::WaitGroup; use crate::morsel::SourceToken; use crate::nodes::compute_node_prelude::*; -use crate::nodes::{MorselSeq, TaskPriority}; +use crate::nodes::TaskPriority; use crate::utils::task_handles_ext; mod init; @@ -31,7 +30,7 @@ mod row_group_data_fetch; mod row_group_decode; type AsyncTaskData = ( - Vec>, + Vec>, task_handles_ext::AbortOnDropHandle>, ); @@ -242,19 +241,20 @@ impl ComputeNode for ParquetSourceNode { let is_finished = is_finished.clone(); let source_token = source_token.clone(); scope.spawn_task(TaskPriority::Low, async move { + let wait_group = WaitGroup::default(); loop { - let Ok((df, morsel_seq, wait_token)) = raw_morsel_rx.recv().await else { + let Ok((df, seq)) = raw_morsel_rx.recv().await else { is_finished.store(true, Ordering::Relaxed); break; }; - let mut morsel = Morsel::new(df, morsel_seq, source_token.clone()); - morsel.set_consume_token(wait_token); - + let mut morsel = Morsel::new(df, seq, source_token.clone()); + morsel.set_consume_token(wait_group.token()); if morsel_tx.send(morsel).await.is_err() { break; } + wait_group.wait().await; if source_token.stop_requested() { break; } diff --git a/crates/polars-stream/src/nodes/io_sources/parquet/row_group_data_fetch.rs b/crates/polars-stream/src/nodes/io_sources/parquet/row_group_data_fetch.rs index bf2e7e60ea6e..d7d004a33b7e 100644 --- a/crates/polars-stream/src/nodes/io_sources/parquet/row_group_data_fetch.rs +++ b/crates/polars-stream/src/nodes/io_sources/parquet/row_group_data_fetch.rs @@ -1,4 +1,3 @@ -use std::future::Future; use std::sync::Arc; use polars_core::prelude::{ArrowSchema, PlHashMap}; @@ -55,10 +54,6 @@ pub(super) struct RowGroupDataFetcher { } impl RowGroupDataFetcher { - pub(super) fn into_stream(self) -> RowGroupDataStream { - RowGroupDataStream::new(self) - } - pub(super) async fn init_next_file_state(&mut self) -> bool { let Ok((path_index, row_offset, byte_source, metadata)) = self.metadata_rx.recv().await else { @@ -278,66 +273,6 @@ impl FetchedBytes { } } -#[rustfmt::skip] -type RowGroupDataStreamFut = std::pin::Pin , - Option < - PolarsResult < - task_handles_ext::AbortOnDropHandle < - PolarsResult < - RowGroupData > > > > - ) - > + Send ->>; - -pub(super) struct RowGroupDataStream { - current_future: RowGroupDataStreamFut, -} - -impl RowGroupDataStream { - fn new(row_group_data_fetcher: RowGroupDataFetcher) -> Self { - // [`RowGroupDataFetcher`] is a big struct, so we Box it once here to avoid boxing it on - // every `next()` call. - let current_future = Self::call_next_owned(Box::new(row_group_data_fetcher)); - Self { current_future } - } - - fn call_next_owned( - mut row_group_data_fetcher: Box, - ) -> RowGroupDataStreamFut { - Box::pin(async move { - let out = row_group_data_fetcher.next().await; - (row_group_data_fetcher, out) - }) - } -} - -impl futures::stream::Stream for RowGroupDataStream { - type Item = PolarsResult>>; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - use std::pin::Pin; - use std::task::Poll; - - match Pin::new(&mut self.current_future.as_mut()).poll(cx) { - Poll::Ready((row_group_data_fetcher, out)) => { - if out.is_some() { - self.current_future = Self::call_next_owned(row_group_data_fetcher); - } - - Poll::Ready(out) - }, - Poll::Pending => Poll::Pending, - } - } -} - fn get_row_group_byte_ranges_for_projection<'a>( row_group_metadata: &'a RowGroupMetadata, columns: &'a [PlSmallStr],