diff --git a/crates/polars-stream/src/lib.rs b/crates/polars-stream/src/lib.rs index a31e20f22381..23f44fc7c306 100644 --- a/crates/polars-stream/src/lib.rs +++ b/crates/polars-stream/src/lib.rs @@ -16,3 +16,4 @@ mod utils; // TODO: experiment with these, and make them configurable through environment variables. const DEFAULT_LINEARIZER_BUFFER_SIZE: usize = 4; const DEFAULT_DISTRIBUTOR_BUFFER_SIZE: usize = 4; +const DEFAULT_ZIP_HEAD_BUFFER_SIZE: usize = 4; diff --git a/crates/polars-stream/src/nodes/zip.rs b/crates/polars-stream/src/nodes/zip.rs index 3a9290bde59d..cd72a3567442 100644 --- a/crates/polars-stream/src/nodes/zip.rs +++ b/crates/polars-stream/src/nodes/zip.rs @@ -5,9 +5,11 @@ use polars_core::functions::concat_df_horizontal; use polars_core::schema::Schema; use polars_core::series::Series; use polars_error::polars_ensure; +use polars_utils::itertools::Itertools; use super::compute_node_prelude::*; use crate::morsel::SourceToken; +use crate::DEFAULT_ZIP_HEAD_BUFFER_SIZE; /// The head of an input stream. #[derive(Debug)] @@ -211,7 +213,24 @@ impl ComputeNode for ZipNode { assert!(send.len() == 1); assert!(!recv.is_empty()); let mut sender = send[0].take().unwrap().serial(); - let mut receivers: Vec<_> = recv.iter_mut().map(|r| Some(r.take()?.serial())).collect(); + + let mut receivers = recv + .iter_mut() + .map(|r| { + // Add buffering to each receiver to reduce contention between input heads. + let mut serial_recv = r.take()?.serial(); + let (buf_send, buf_recv) = tokio::sync::mpsc::channel(DEFAULT_ZIP_HEAD_BUFFER_SIZE); + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + while let Ok(morsel) = serial_recv.recv().await { + if buf_send.send(morsel).await.is_err() { + break; + } + } + Ok(()) + })); + Some(buf_recv) + }) + .collect_vec(); join_handles.push(scope.spawn_task(TaskPriority::High, async move { let mut out = Vec::new(); @@ -227,7 +246,7 @@ impl ComputeNode for ZipNode { for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() { if let Some(recv) = opt_recv { while !self.input_heads[recv_idx].ready_to_send() { - if let Ok(morsel) = recv.recv().await { + if let Some(morsel) = recv.recv().await { self.input_heads[recv_idx].add_morsel(morsel); } else { break; @@ -283,7 +302,7 @@ impl ComputeNode for ZipNode { for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() { if let Some(recv) = opt_recv { - while let Ok(mut morsel) = recv.recv().await { + while let Some(mut morsel) = recv.recv().await { morsel.source_token().stop(); drop(morsel.take_consume_token()); self.input_heads[recv_idx].add_morsel(morsel);