Skip to content

Commit

Permalink
refactor(rust): Fix input independence tests in new-streaming engine (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Sep 16, 2024
1 parent a53bf03 commit f7b6d86
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 39 deletions.
23 changes: 14 additions & 9 deletions crates/polars-stream/src/nodes/in_memory_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ impl ComputeNode for InMemorySourceNode {
assert!(recv.is_empty());
assert!(send.len() == 1);

let exhausted = self
.source
.as_ref()
.map(|s| {
self.seq.load(Ordering::Relaxed) * self.morsel_size as u64 >= s.height() as u64
})
.unwrap_or(true);

// As a temporary hack for some nodes (like the FunctionIR::FastCount
// node) that rely on an empty input, always ensure we send at least one
// morsel.
// TODO: remove this hack.
let exhausted = if let Some(src) = &self.source {
let seq = self.seq.load(Ordering::Relaxed);
seq > 0 && seq * self.morsel_size as u64 >= src.height() as u64
} else {
true
};
if send[0] == PortState::Done || exhausted {
send[0] = PortState::Done;
self.source = None;
Expand Down Expand Up @@ -77,7 +79,10 @@ impl ComputeNode for InMemorySourceNode {
let seq = slf.seq.fetch_add(1, Ordering::Relaxed);
let offset = (seq as usize * slf.morsel_size) as i64;
let df = source.slice(offset, slf.morsel_size);
if df.is_empty() {

// TODO: remove this 'always send at least one morsel'
// condition, see update_state.
if df.is_empty() && seq > 0 {
break;
}

Expand Down
64 changes: 64 additions & 0 deletions crates/polars-stream/src/nodes/input_independent_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use polars_core::prelude::IntoColumn;

use super::compute_node_prelude::*;
use crate::expression::StreamExpr;
use crate::morsel::SourceToken;

/// Compute node with no input ports that evaluates a list of selector
/// expressions against an empty `DataFrame` and emits the result on its
/// single output port.
pub struct InputIndependentSelectNode {
// Expressions evaluated (against an empty frame) to build the output columns.
selectors: Vec<StreamExpr>,
// Set to `true` by the spawned task after the output morsel has been sent;
// `update_state` then reports the output port as done.
done: bool,
}

impl InputIndependentSelectNode {
    /// Builds a node for the given selector expressions.
    ///
    /// The node starts out not done, i.e. it has not yet produced its output.
    pub fn new(selectors: Vec<StreamExpr>) -> Self {
        let done = false;
        Self { selectors, done }
    }
}

impl ComputeNode for InputIndependentSelectNode {
    /// Name under which this node is reported.
    fn name(&self) -> &str {
        "input_independent_select"
    }

    /// Marks the single output port as done once the consumer is done or this
    /// node has already produced its output; otherwise marks it ready.
    ///
    /// Panics if there is any input port or not exactly one output port.
    fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> {
        assert!(recv.is_empty() && send.len() == 1);
        let finished = self.done || send[0] == PortState::Done;
        send[0] = if finished {
            PortState::Done
        } else {
            PortState::Ready
        };
        Ok(())
    }

    /// Spawns one low-priority task that evaluates every selector against an
    /// empty `DataFrame`, combines the resulting columns with
    /// `DataFrame::new_with_broadcast`, and sends them as a single morsel.
    ///
    /// Panics if there is any input port or not exactly one output port.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv: &mut [Option<RecvPort<'_>>],
        send: &mut [Option<SendPort<'_>>],
        state: &'s ExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv.is_empty() && send.len() == 1);
        let mut sender = send[0].take().unwrap().serial();

        join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
            // The selectors do not depend on any input, so an empty frame is
            // used as the evaluation context.
            let input = DataFrame::empty();
            let mut columns = Vec::with_capacity(self.selectors.len());
            for expr in &self.selectors {
                columns.push(expr.evaluate(&input, state).await?.into_column());
            }

            let morsel = Morsel::new(
                DataFrame::new_with_broadcast(columns)?,
                MorselSeq::default(),
                SourceToken::new(),
            );
            // The outcome of the send is intentionally discarded.
            let _ = sender.send(morsel).await;
            self.done = true;
            Ok(())
        }));
    }
}
1 change: 1 addition & 0 deletions crates/polars-stream/src/nodes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod filter;
pub mod in_memory_map;
pub mod in_memory_sink;
pub mod in_memory_source;
pub mod input_independent_select;
pub mod map;
pub mod multiplexer;
pub mod ordered_union;
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-stream/src/physical_plan/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ fn visualize_plan_rec(
from_ref(input),
)
},
PhysNodeKind::InputIndependentSelect { selectors } => (
format!(
"input-independent-select\\n{}",
fmt_exprs(selectors, expr_arena)
),
&[][..],
),
PhysNodeKind::Reduce { input, exprs } => (
format!("reduce\\n{}", fmt_exprs(exprs, expr_arena)),
from_ref(input),
Expand Down
53 changes: 24 additions & 29 deletions crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,26 +228,12 @@ fn build_input_independent_node_with_ctx(
exprs: &[ExprIR],
ctx: &mut LowerExprContext,
) -> PolarsResult<PhysNodeKey> {
let expr_depth_limit = get_expr_depth_limit()?;
let mut state = ExpressionConversionState::new(false, expr_depth_limit);
let empty = DataFrame::empty();
let execution_state = ExecutionState::new();
let columns = exprs
.iter()
.map(|expr| {
let phys_expr =
create_physical_expr(expr, Context::Default, ctx.expr_arena, None, &mut state)?;

phys_expr
.evaluate(&empty, &execution_state)
.map(Column::from)
})
.try_collect_vec()?;

let df = Arc::new(DataFrame::new_with_broadcast(columns)?);
let output_schema = compute_output_schema(&Schema::default(), exprs, ctx.expr_arena)?;
Ok(ctx.phys_sm.insert(PhysNode::new(
Arc::new(df.schema()),
PhysNodeKind::InMemorySource { df },
output_schema,
PhysNodeKind::InputIndependentSelect {
selectors: exprs.to_vec(),
},
)))
}

Expand Down Expand Up @@ -653,29 +639,38 @@ fn lower_exprs_with_ctx(
Ok((zip_node, transformed_exprs))
}

/// Computes the schema that selecting the given expressions on the input node
/// Computes the schema that selecting the given expressions on the input schema
/// would result in.
fn schema_for_select(
input: PhysNodeKey,
fn compute_output_schema(
input_schema: &Schema,
exprs: &[ExprIR],
ctx: &mut LowerExprContext,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<Arc<Schema>> {
let input_schema = &ctx.phys_sm[input].output_schema;
let output_schema: Schema = exprs
.iter()
.map(|e| {
let name = e.output_name().clone();
let dtype = ctx.expr_arena.get(e.node()).to_dtype(
input_schema,
Context::Default,
ctx.expr_arena,
)?;
let dtype =
expr_arena
.get(e.node())
.to_dtype(input_schema, Context::Default, expr_arena)?;
PolarsResult::Ok(Field::new(name, dtype))
})
.try_collect()?;
Ok(Arc::new(output_schema))
}

/// Computes the schema that results from selecting `exprs` on the output of
/// the physical node `input`.
///
/// Looks up the node's output schema in `ctx.phys_sm` and delegates to
/// `compute_output_schema`, propagating any error it returns.
fn schema_for_select(
    input: PhysNodeKey,
    exprs: &[ExprIR],
    ctx: &mut LowerExprContext,
) -> PolarsResult<Arc<Schema>> {
    compute_output_schema(&ctx.phys_sm[input].output_schema, exprs, ctx.expr_arena)
}

fn build_select_node_with_ctx(
input: PhysNodeKey,
exprs: &[ExprIR],
Expand Down
8 changes: 7 additions & 1 deletion crates/polars-stream/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ pub enum PhysNodeKind {
extend_original: bool,
},

InputIndependentSelect {
selectors: Vec<ExprIR>,
},

Reduce {
input: PhysNodeKey,
exprs: Vec<ExprIR>,
Expand Down Expand Up @@ -156,7 +160,9 @@ fn insert_multiplexers(

if !seen_before {
match &phys_sm[node].kind {
PhysNodeKind::InMemorySource { .. } | PhysNodeKind::FileScan { .. } => {},
PhysNodeKind::InMemorySource { .. }
| PhysNodeKind::FileScan { .. }
| PhysNodeKind::InputIndependentSelect { .. } => {},
PhysNodeKind::Select { input, .. }
| PhysNodeKind::Reduce { input, .. }
| PhysNodeKind::StreamingSlice { input, .. }
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-stream/src/physical_plan/to_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ fn to_graph_rec<'a>(
[input_key],
)
},

InputIndependentSelect { selectors } => {
let phys_selectors = selectors
.iter()
.map(|selector| create_stream_expr(selector, ctx))
.collect::<PolarsResult<_>>()?;
ctx.graph.add_node(
nodes::input_independent_select::InputIndependentSelectNode::new(phys_selectors),
[],
)
},

Reduce { input, exprs } => {
let input_key = to_graph_rec(*input, ctx)?;
let input_schema = &ctx.phys_sm[*input].output_schema;
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/streaming/test_streaming_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def test_streaming_empty_parquet_16523(tmp_path: Path) -> None:
assert q.join(q2, on="a").collect(streaming=True).shape == (0, 1)


@pytest.mark.may_fail_auto_streaming
@pytest.mark.parametrize(
"method",
["parquet", "csv"],
Expand Down

0 comments on commit f7b6d86

Please sign in to comment.