Skip to content

Commit

Permalink
refactor(rust): Add WithRowIndexNode to new-streaming engine (#19037)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Oct 1, 2024
1 parent 4c6d501 commit 9ee3fa2
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 12 deletions.
1 change: 1 addition & 0 deletions crates/polars-stream/src/nodes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod reduce;
pub mod select;
pub mod simple_projection;
pub mod streaming_slice;
pub mod with_row_index;
pub mod zip;

/// The imports you'll always need for implementing a ComputeNode.
Expand Down
87 changes: 87 additions & 0 deletions crates/polars-stream/src/nodes/with_row_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use polars_core::prelude::*;
use polars_core::utils::Container;
use polars_utils::pl_str::PlSmallStr;

use super::compute_node_prelude::*;
use crate::async_primitives::distributor_channel::distributor_channel;
use crate::async_primitives::wait_group::WaitGroup;
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;

pub struct WithRowIndexNode {
name: PlSmallStr,
offset: IdxSize,
}

impl WithRowIndexNode {
pub fn new(name: PlSmallStr, offset: Option<IdxSize>) -> Self {
Self {
name,
offset: offset.unwrap_or(0),
}
}
}

impl ComputeNode for WithRowIndexNode {
fn name(&self) -> &str {
"with_row_index"
}

fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> {
assert!(recv.len() == 1 && send.len() == 1);
recv.swap_with_slice(send);
Ok(())
}

fn spawn<'env, 's>(
&'env mut self,
scope: &'s TaskScope<'s, 'env>,
recv: &mut [Option<RecvPort<'_>>],
send: &mut [Option<SendPort<'_>>],
_state: &'s ExecutionState,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) {
assert!(recv.len() == 1 && send.len() == 1);
let mut receiver = recv[0].take().unwrap().serial();
let senders = send[0].take().unwrap().parallel();

let (mut distributor, distr_receivers) =
distributor_channel(senders.len(), DEFAULT_DISTRIBUTOR_BUFFER_SIZE);

let name = self.name.clone();

// To figure out the correct offsets we need to be serial.
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
while let Ok(morsel) = receiver.recv().await {
let offset = self.offset;
self.offset = self
.offset
.checked_add(morsel.df().len().try_into().unwrap())
.unwrap();
if distributor.send((morsel, offset)).await.is_err() {
break;
}
}

Ok(())
}));

// But adding the new row index column can be done in parallel.
for (mut send, mut recv) in senders.into_iter().zip(distr_receivers) {
let name = name.clone();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
let wait_group = WaitGroup::default();
while let Ok((morsel, offset)) = recv.recv().await {
let mut morsel =
morsel.try_map(|df| df.with_row_index(name.clone(), Some(offset)))?;
morsel.set_consume_token(wait_group.token());
if send.send(morsel).await.is_err() {
break;
}
wait_group.wait().await;
}

Ok(())
}));
}
}
}
8 changes: 8 additions & 0 deletions crates/polars-stream/src/physical_plan/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ fn visualize_plan_rec(
from_ref(input),
)
},
PhysNodeKind::WithRowIndex {
input,
name,
offset,
} => (
format!("with-row-index\\nname: {name}\\noffset: {offset:?}"),
from_ref(input),
),
PhysNodeKind::InputIndependentSelect { selectors } => (
format!(
"input-independent-select\\n{}",
Expand Down
36 changes: 25 additions & 11 deletions crates/polars-stream/src/physical_plan/lower_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,18 +254,32 @@ pub fn lower_ir(
expr_cache,
)?;

if function.is_streamable() {
let map = Arc::new(move |df| function.evaluate(df));
PhysNodeKind::Map {
input: phys_input,
map,
}
} else {
let map = Arc::new(move |df| function.evaluate(df));
PhysNodeKind::InMemoryMap {
match function {
FunctionIR::RowIndex {
name,
offset,
schema: _,
} => PhysNodeKind::WithRowIndex {
input: phys_input,
map,
}
name,
offset,
},

function if function.is_streamable() => {
let map = Arc::new(move |df| function.evaluate(df));
PhysNodeKind::Map {
input: phys_input,
map,
}
},

function => {
let map = Arc::new(move |df| function.evaluate(df));
PhysNodeKind::InMemoryMap {
input: phys_input,
map,
}
},
}
},

Expand Down
9 changes: 8 additions & 1 deletion crates/polars-stream/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use polars_core::frame::DataFrame;
use polars_core::prelude::{InitHashMaps, PlHashMap, SortMultipleOptions};
use polars_core::prelude::{IdxSize, InitHashMaps, PlHashMap, SortMultipleOptions};
use polars_core::schema::{Schema, SchemaRef};
use polars_error::PolarsResult;
use polars_plan::plans::hive::HivePartitions;
Expand Down Expand Up @@ -58,6 +58,12 @@ pub enum PhysNodeKind {
extend_original: bool,
},

WithRowIndex {
input: PhysNodeKey,
name: PlSmallStr,
offset: Option<IdxSize>,
},

InputIndependentSelect {
selectors: Vec<ExprIR>,
},
Expand Down Expand Up @@ -164,6 +170,7 @@ fn insert_multiplexers(
| PhysNodeKind::FileScan { .. }
| PhysNodeKind::InputIndependentSelect { .. } => {},
PhysNodeKind::Select { input, .. }
| PhysNodeKind::WithRowIndex { input, .. }
| PhysNodeKind::Reduce { input, .. }
| PhysNodeKind::StreamingSlice { input, .. }
| PhysNodeKind::Filter { input, .. }
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-stream/src/physical_plan/to_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ fn to_graph_rec<'a>(
)
},

WithRowIndex {
input,
name,
offset,
} => {
let input_key = to_graph_rec(*input, ctx)?;
ctx.graph.add_node(
nodes::with_row_index::WithRowIndexNode::new(name.clone(), *offset),
[input_key],
)
},

InputIndependentSelect { selectors } => {
let phys_selectors = selectors
.iter()
Expand Down

0 comments on commit 9ee3fa2

Please sign in to comment.