From d7b4f7232aeb17c28785f2c052d32b203605a3c1 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 30 May 2024 15:35:16 +0100 Subject: [PATCH] fix(rust): get_dtype handles input node schema and CSE (#16582) --- py-polars/src/lazyframe/visit.rs | 9 ++++++--- py-polars/src/lazyframe/visitor/nodes.rs | 15 +++++++++++---- py-polars/src/lib.rs | 1 + 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/py-polars/src/lazyframe/visit.rs b/py-polars/src/lazyframe/visit.rs index d62f1388978a..f7f8d2ceee34 100644 --- a/py-polars/src/lazyframe/visit.rs +++ b/py-polars/src/lazyframe/visit.rs @@ -125,16 +125,17 @@ impl NodeTraverser { let lp_arena = self.lp_arena.lock().unwrap(); let ir_node = lp_arena.get(self.root); let expr_arena = self.expr_arena.lock().unwrap(); - let schema = { + let schema = if let Some(schema) = ir_node.input_schema(&lp_arena) { // TODO: This is a hack for CSE expressions when // determining the dtype. It should be removed once // to_field, or its moral equivalent can handle this in a // proper way. The schema needs to include the dtype of // CSE expressions for to_field to work with expressions // that reference them, but is not part of the public - // schema of the node. - let schema = ir_node.schema(&lp_arena); + // schema of the input. match ir_node { + // Both select and hstack must augment with any CSE + // expressions. IR::Select { expr, .. } | IR::HStack { exprs: expr, .. } => { let cse_exprs = expr.cse_exprs(); if cse_exprs.is_empty() { @@ -153,6 +154,8 @@ impl NodeTraverser { }, _ => schema, } + } else { + raise_err!("Not able to compute input schema", ComputeError) }; let field = expr_arena .get(expr_node) diff --git a/py-polars/src/lazyframe/visitor/nodes.rs b/py-polars/src/lazyframe/visitor/nodes.rs index 88f7143969b7..15e6e32dce2e 100644 --- a/py-polars/src/lazyframe/visitor/nodes.rs +++ b/py-polars/src/lazyframe/visitor/nodes.rs @@ -202,6 +202,15 @@ pub struct HStack { options: (), // ProjectionOptions, } +#[pyclass] +/// Like Select, but all operations produce a single row. +pub struct Reduce { + #[pyo3(get)] + input: usize, + #[pyo3(get)] + exprs: Vec, +} + #[pyclass] /// Remove duplicates from the table pub struct Distinct { @@ -436,11 +445,9 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { input, exprs, schema: _, - } => Select { + } => Reduce { input: input.0, - expr: exprs.iter().map(|e| e.into()).collect(), - cse_expr: vec![], - options: (), + exprs: exprs.iter().map(|e| e.into()).collect(), } .into_py(py), IR::Distinct { input, options } => Distinct { diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index bf41d6846aad..4cce971059ef 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -75,6 +75,7 @@ fn _ir_nodes(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); + m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap();