diff --git a/py-polars/src/lazyframe/visit.rs b/py-polars/src/lazyframe/visit.rs index 1cbf02322161..d62f1388978a 100644 --- a/py-polars/src/lazyframe/visit.rs +++ b/py-polars/src/lazyframe/visit.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::sync::Mutex; use polars_plan::logical_plan::{to_aexpr, Context, IR}; @@ -122,8 +123,37 @@ impl NodeTraverser { fn get_dtype(&self, expr_node: usize, py: Python<'_>) -> PyResult { let expr_node = Node(expr_node); let lp_arena = self.lp_arena.lock().unwrap(); - let schema = lp_arena.get(self.root).schema(&lp_arena); + let ir_node = lp_arena.get(self.root); let expr_arena = self.expr_arena.lock().unwrap(); + let schema = { + // TODO: This is a hack for CSE expressions when + // determining the dtype. It should be removed once + // to_field, or its moral equivalent can handle this in a + // proper way. The schema needs to include the dtype of + // CSE expressions for to_field to work with expressions + // that reference them, but is not part of the public + // schema of the node. + let schema = ir_node.schema(&lp_arena); + match ir_node { + IR::Select { expr, .. } | IR::HStack { exprs: expr, .. } => { + let cse_exprs = expr.cse_exprs(); + if cse_exprs.is_empty() { + schema + } else { + let mut new_schema: Schema = (**schema).clone(); + for e in cse_exprs { + let field = expr_arena + .get(e.node()) + .to_field(&schema, Context::Default, &expr_arena) + .map_err(PyPolarsErr::from)?; + new_schema.with_column(e.output_name().into(), field.dtype); + } + Cow::Owned(Arc::new(new_schema)) + } + }, + _ => schema, + } + }; let field = expr_arena .get(expr_node) .to_field(&schema, Context::Default, &expr_arena)