Skip to content

Commit 9ee3fa2

Browse files
authored
refactor(rust): Add WithRowIndexNode to new-streaming engine (#19037)
1 parent 4c6d501 commit 9ee3fa2

File tree

6 files changed

+141
-12
lines changed

6 files changed

+141
-12
lines changed

crates/polars-stream/src/nodes/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod reduce;
1111
pub mod select;
1212
pub mod simple_projection;
1313
pub mod streaming_slice;
14+
pub mod with_row_index;
1415
pub mod zip;
1516

1617
/// The imports you'll always need for implementing a ComputeNode.
crates/polars-stream/src/nodes/with_row_index.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
use polars_core::prelude::*;
2+
use polars_core::utils::Container;
3+
use polars_utils::pl_str::PlSmallStr;
4+
5+
use super::compute_node_prelude::*;
6+
use crate::async_primitives::distributor_channel::distributor_channel;
7+
use crate::async_primitives::wait_group::WaitGroup;
8+
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;
9+
10+
pub struct WithRowIndexNode {
11+
name: PlSmallStr,
12+
offset: IdxSize,
13+
}
14+
15+
impl WithRowIndexNode {
16+
pub fn new(name: PlSmallStr, offset: Option<IdxSize>) -> Self {
17+
Self {
18+
name,
19+
offset: offset.unwrap_or(0),
20+
}
21+
}
22+
}
23+
24+
impl ComputeNode for WithRowIndexNode {
25+
fn name(&self) -> &str {
26+
"with_row_index"
27+
}
28+
29+
fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> {
30+
assert!(recv.len() == 1 && send.len() == 1);
31+
recv.swap_with_slice(send);
32+
Ok(())
33+
}
34+
35+
fn spawn<'env, 's>(
36+
&'env mut self,
37+
scope: &'s TaskScope<'s, 'env>,
38+
recv: &mut [Option<RecvPort<'_>>],
39+
send: &mut [Option<SendPort<'_>>],
40+
_state: &'s ExecutionState,
41+
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
42+
) {
43+
assert!(recv.len() == 1 && send.len() == 1);
44+
let mut receiver = recv[0].take().unwrap().serial();
45+
let senders = send[0].take().unwrap().parallel();
46+
47+
let (mut distributor, distr_receivers) =
48+
distributor_channel(senders.len(), DEFAULT_DISTRIBUTOR_BUFFER_SIZE);
49+
50+
let name = self.name.clone();
51+
52+
// To figure out the correct offsets we need to be serial.
53+
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
54+
while let Ok(morsel) = receiver.recv().await {
55+
let offset = self.offset;
56+
self.offset = self
57+
.offset
58+
.checked_add(morsel.df().len().try_into().unwrap())
59+
.unwrap();
60+
if distributor.send((morsel, offset)).await.is_err() {
61+
break;
62+
}
63+
}
64+
65+
Ok(())
66+
}));
67+
68+
// But adding the new row index column can be done in parallel.
69+
for (mut send, mut recv) in senders.into_iter().zip(distr_receivers) {
70+
let name = name.clone();
71+
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
72+
let wait_group = WaitGroup::default();
73+
while let Ok((morsel, offset)) = recv.recv().await {
74+
let mut morsel =
75+
morsel.try_map(|df| df.with_row_index(name.clone(), Some(offset)))?;
76+
morsel.set_consume_token(wait_group.token());
77+
if send.send(morsel).await.is_err() {
78+
break;
79+
}
80+
wait_group.wait().await;
81+
}
82+
83+
Ok(())
84+
}));
85+
}
86+
}
87+
}

crates/polars-stream/src/physical_plan/fmt.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ fn visualize_plan_rec(
5959
from_ref(input),
6060
)
6161
},
62+
PhysNodeKind::WithRowIndex {
63+
input,
64+
name,
65+
offset,
66+
} => (
67+
format!("with-row-index\\nname: {name}\\noffset: {offset:?}"),
68+
from_ref(input),
69+
),
6270
PhysNodeKind::InputIndependentSelect { selectors } => (
6371
format!(
6472
"input-independent-select\\n{}",

crates/polars-stream/src/physical_plan/lower_ir.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -254,18 +254,32 @@ pub fn lower_ir(
254254
expr_cache,
255255
)?;
256256

257-
if function.is_streamable() {
258-
let map = Arc::new(move |df| function.evaluate(df));
259-
PhysNodeKind::Map {
260-
input: phys_input,
261-
map,
262-
}
263-
} else {
264-
let map = Arc::new(move |df| function.evaluate(df));
265-
PhysNodeKind::InMemoryMap {
257+
match function {
258+
FunctionIR::RowIndex {
259+
name,
260+
offset,
261+
schema: _,
262+
} => PhysNodeKind::WithRowIndex {
266263
input: phys_input,
267-
map,
268-
}
264+
name,
265+
offset,
266+
},
267+
268+
function if function.is_streamable() => {
269+
let map = Arc::new(move |df| function.evaluate(df));
270+
PhysNodeKind::Map {
271+
input: phys_input,
272+
map,
273+
}
274+
},
275+
276+
function => {
277+
let map = Arc::new(move |df| function.evaluate(df));
278+
PhysNodeKind::InMemoryMap {
279+
input: phys_input,
280+
map,
281+
}
282+
},
269283
}
270284
},
271285

crates/polars-stream/src/physical_plan/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::sync::Arc;
22

33
use polars_core::frame::DataFrame;
4-
use polars_core::prelude::{InitHashMaps, PlHashMap, SortMultipleOptions};
4+
use polars_core::prelude::{IdxSize, InitHashMaps, PlHashMap, SortMultipleOptions};
55
use polars_core::schema::{Schema, SchemaRef};
66
use polars_error::PolarsResult;
77
use polars_plan::plans::hive::HivePartitions;
@@ -58,6 +58,12 @@ pub enum PhysNodeKind {
5858
extend_original: bool,
5959
},
6060

61+
WithRowIndex {
62+
input: PhysNodeKey,
63+
name: PlSmallStr,
64+
offset: Option<IdxSize>,
65+
},
66+
6167
InputIndependentSelect {
6268
selectors: Vec<ExprIR>,
6369
},
@@ -164,6 +170,7 @@ fn insert_multiplexers(
164170
| PhysNodeKind::FileScan { .. }
165171
| PhysNodeKind::InputIndependentSelect { .. } => {},
166172
PhysNodeKind::Select { input, .. }
173+
| PhysNodeKind::WithRowIndex { input, .. }
167174
| PhysNodeKind::Reduce { input, .. }
168175
| PhysNodeKind::StreamingSlice { input, .. }
169176
| PhysNodeKind::Filter { input, .. }

crates/polars-stream/src/physical_plan/to_graph.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ fn to_graph_rec<'a>(
131131
)
132132
},
133133

134+
WithRowIndex {
135+
input,
136+
name,
137+
offset,
138+
} => {
139+
let input_key = to_graph_rec(*input, ctx)?;
140+
ctx.graph.add_node(
141+
nodes::with_row_index::WithRowIndexNode::new(name.clone(), *offset),
142+
[input_key],
143+
)
144+
},
145+
134146
InputIndependentSelect { selectors } => {
135147
let phys_selectors = selectors
136148
.iter()

0 commit comments

Comments
 (0)