Skip to content

Commit

Permalink
fix(rust): get_dtype handles input node schema and CSE (#16582)
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored May 30, 2024
1 parent 040da95 commit d7b4f72
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
9 changes: 6 additions & 3 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,17 @@ impl NodeTraverser {
let lp_arena = self.lp_arena.lock().unwrap();
let ir_node = lp_arena.get(self.root);
let expr_arena = self.expr_arena.lock().unwrap();
let schema = {
let schema = if let Some(schema) = ir_node.input_schema(&lp_arena) {
// TODO: This is a hack for CSE expressions when
// determining the dtype. It should be removed once
// to_field, or its moral equivalent, can handle this in a
// proper way. The schema needs to include the dtype of
// CSE expressions for to_field to work with expressions
// that reference them, but is not part of the public
// schema of the node.
let schema = ir_node.schema(&lp_arena);
// schema of the input.
match ir_node {
// Both select and hstack must augment with any CSE
// expressions.
IR::Select { expr, .. } | IR::HStack { exprs: expr, .. } => {
let cse_exprs = expr.cse_exprs();
if cse_exprs.is_empty() {
Expand All @@ -153,6 +154,8 @@ impl NodeTraverser {
},
_ => schema,
}
} else {
raise_err!("Not able to compute input schema", ComputeError)
};
let field = expr_arena
.get(expr_node)
Expand Down
15 changes: 11 additions & 4 deletions py-polars/src/lazyframe/visitor/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,15 @@ pub struct HStack {
options: (), // ProjectionOptions,
}

#[pyclass]
/// Like Select, but all operations produce a single row.
// Python-visible node class (registered via `m.add_class::<Reduce>()`).
// `into_py` constructs it from an IR variant carrying `input`, `exprs` and
// `schema` fields; the `schema` field is ignored there and not stored here.
pub struct Reduce {
// Exposed to Python as a read-only attribute via `#[pyo3(get)]`.
// Filled from `input.0` of the IR node; presumably the arena index of the
// input plan node — TODO confirm against the `Node` definition.
#[pyo3(get)]
input: usize,
// Exposed to Python as a read-only attribute via `#[pyo3(get)]`.
// One `PyExprIR` per entry of the IR node's `exprs`, converted with `into()`.
#[pyo3(get)]
exprs: Vec<PyExprIR>,
}

#[pyclass]
/// Remove duplicates from the table
pub struct Distinct {
Expand Down Expand Up @@ -436,11 +445,9 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult<PyObject> {
input,
exprs,
schema: _,
} => Select {
} => Reduce {
input: input.0,
expr: exprs.iter().map(|e| e.into()).collect(),
cse_expr: vec![],
options: (),
exprs: exprs.iter().map(|e| e.into()).collect(),
}
.into_py(py),
IR::Distinct { input, options } => Distinct {
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ fn _ir_nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<GroupBy>().unwrap();
m.add_class::<Join>().unwrap();
m.add_class::<HStack>().unwrap();
m.add_class::<Reduce>().unwrap();
m.add_class::<Distinct>().unwrap();
m.add_class::<MapFunction>().unwrap();
m.add_class::<Union>().unwrap();
Expand Down

0 comments on commit d7b4f72

Please sign in to comment.