diff --git a/crates/polars-plan/src/logical_plan/functions/mod.rs b/crates/polars-plan/src/logical_plan/functions/mod.rs index 13c8b4be3d6e..8f4866ddb73b 100644 --- a/crates/polars-plan/src/logical_plan/functions/mod.rs +++ b/crates/polars-plan/src/logical_plan/functions/mod.rs @@ -9,10 +9,12 @@ mod schema; use std::borrow::Cow; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use std::ops::Deref; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use polars_core::prelude::*; +use schema::CachedSchema; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; @@ -23,8 +25,6 @@ use crate::dsl::python_udf::PythonFunction; use crate::logical_plan::functions::merge_sorted::merge_sorted; use crate::prelude::*; -type CachedSchema = Arc>>; - #[derive(Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum FunctionNode { @@ -96,6 +96,7 @@ pub enum FunctionNode { RowIndex { name: Arc, // Might be cached. + #[cfg_attr(feature = "serde", serde(skip))] schema: CachedSchema, offset: Option, }, diff --git a/crates/polars-plan/src/logical_plan/functions/schema.rs b/crates/polars-plan/src/logical_plan/functions/schema.rs index a722e48887a4..9bcdcb0d8643 100644 --- a/crates/polars-plan/src/logical_plan/functions/schema.rs +++ b/crates/polars-plan/src/logical_plan/functions/schema.rs @@ -108,3 +108,29 @@ fn row_index_schema( *guard = Some(schema_ref.clone()); schema_ref } + +// We don't use an `Arc` because caches should live in different query plans. +// For that reason we have a specialized deep clone. +#[derive(Default)] +pub struct CachedSchema(Mutex>); + +impl AsRef>> for CachedSchema { + fn as_ref(&self) -> &Mutex> { + &self.0 + } +} + +impl Deref for CachedSchema { + type Target = Mutex>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Clone for CachedSchema { + fn clone(&self) -> Self { + let inner = self.0.lock().unwrap(); + Self(Mutex::new(inner.clone())) + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/functions/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/functions/mod.rs index f0b71e3ff42f..d7faf2526a9a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/functions/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/functions/mod.rs @@ -79,10 +79,6 @@ pub(super) fn process_functions( ) }, _ => { - let lp = IR::MapFunction { - input, - function: function.clone(), - }; if function.allow_projection_pd() && !acc_projections.is_empty() { let original_acc_projection_len = acc_projections.len(); @@ -109,6 +105,10 @@ pub(super) fn process_functions( // Remove the cached schema function.clear_cached_schema(); + let lp = IR::MapFunction { + input, + function: function.clone(), + }; if local_projections.is_empty() { Ok(lp) @@ -127,6 +127,10 @@ pub(super) fn process_functions( } } } else { + let lp = IR::MapFunction { + input, + function: function.clone(), + }; // restart projection pushdown proj_pd.no_pushdown_restart_opt( lp, diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 8035eaa02441..ae9dae78e6f2 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -408,3 +408,14 @@ def test_projection_drop_with_series_lit_14382() -> None: "b_name": ["b"], "b_old": [7], } + + +def test_cached_schema_15651() -> None: + q = pl.LazyFrame({"col1": [1], "col2": [2], "col3": [3]}) + q = q.with_row_index() + q = q.filter(~pl.col("col1").is_null()) + # create a subplan diverging from q + _ = q.select(pl.len()).collect(projection_pushdown=True) + + # ensure that q's "cached" columns are still correct + assert q.columns == q.collect().columns