Skip to content

Commit

Permalink
feat(rust): handle CSE dtypes in NodeTraverser.get_dtype (#16552)
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored May 29, 2024
1 parent 243b61e commit 10ea42b
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::sync::Mutex;

use polars_plan::logical_plan::{to_aexpr, Context, IR};
Expand Down Expand Up @@ -122,8 +123,37 @@ impl NodeTraverser {
fn get_dtype(&self, expr_node: usize, py: Python<'_>) -> PyResult<PyObject> {
let expr_node = Node(expr_node);
let lp_arena = self.lp_arena.lock().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena);
let ir_node = lp_arena.get(self.root);
let expr_arena = self.expr_arena.lock().unwrap();
let schema = {
// TODO: This is a hack for CSE expressions when
// determining the dtype. It should be removed once
// to_field, or its moral equivalent can handle this in a
// proper way. The schema needs to include the dtype of
// CSE expressions for to_field to work with expressions
// that reference them, but is not part of the public
// schema of the node.
let schema = ir_node.schema(&lp_arena);
match ir_node {
IR::Select { expr, .. } | IR::HStack { exprs: expr, .. } => {
let cse_exprs = expr.cse_exprs();
if cse_exprs.is_empty() {
schema
} else {
let mut new_schema: Schema = (**schema).clone();
for e in cse_exprs {
let field = expr_arena
.get(e.node())
.to_field(&schema, Context::Default, &expr_arena)
.map_err(PyPolarsErr::from)?;
new_schema.with_column(e.output_name().into(), field.dtype);
}
Cow::Owned(Arc::new(new_schema))
}
},
_ => schema,
}
};
let field = expr_arena
.get(expr_node)
.to_field(&schema, Context::Default, &expr_arena)
Expand Down

0 comments on commit 10ea42b

Please sign in to comment.