From 0870a5dbe935c98c38d35f78431e69caa3102591 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 11 Oct 2024 11:26:49 +0200 Subject: [PATCH] fix: Fix invalid list collection in expression engine (#19191) --- .../src/chunked_array/builder/list/mod.rs | 11 +++- .../src/chunked_array/from_iterator.rs | 48 +++++--------- .../src/chunked_array/from_iterator_par.rs | 32 ++------- .../chunked_array/object/extension/list.rs | 2 +- .../src/chunked_array/object/extension/mod.rs | 2 +- .../src/chunked_array/object/registry.rs | 21 +++++- crates/polars-expr/src/expressions/apply.rs | 66 +++++++++---------- crates/polars-expr/src/expressions/mod.rs | 7 -- .../tests/unit/operations/test_group_by.py | 7 ++ 9 files changed, 92 insertions(+), 104 deletions(-) diff --git a/crates/polars-core/src/chunked_array/builder/list/mod.rs b/crates/polars-core/src/chunked_array/builder/list/mod.rs index 645a2a168e90..a5110c8f2149 100644 --- a/crates/polars-core/src/chunked_array/builder/list/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/list/mod.rs @@ -19,6 +19,8 @@ pub use null::*; pub use primitive::*; use super::*; +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::get_object_builder; pub trait ListBuilderTrait { fn append_opt_series(&mut self, opt_s: Option<&Series>) -> PolarsResult<()> { @@ -115,7 +117,14 @@ pub fn get_list_builder( match &physical_type { #[cfg(feature = "object")] - DataType::Object(_, _) => polars_bail!(opq = list_builder, &physical_type), + DataType::Object(_, _) => { + let builder = get_object_builder(PlSmallStr::EMPTY, 0).get_list_builder( + name, + value_capacity, + list_capacity, + ); + Ok(Box::new(builder)) + }, #[cfg(feature = "dtype-struct")] DataType::Struct(_) => Ok(Box::new(AnonymousOwnedListBuilder::new( name, diff --git a/crates/polars-core/src/chunked_array/from_iterator.rs b/crates/polars-core/src/chunked_array/from_iterator.rs index 72f2bc8c60cb..36f776d001cf 100644 --- a/crates/polars-core/src/chunked_array/from_iterator.rs +++ b/crates/polars-core/src/chunked_array/from_iterator.rs @@ -199,42 +199,24 @@ impl FromIterator> for ListChunked { } builder.finish() } else { - match first_s.dtype() { - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let mut builder = - first_s.get_list_builder(PlSmallStr::EMPTY, capacity * 5, capacity); - for _ in 0..init_null_count { - builder.append_null(); - } - builder.append_series(first_s).unwrap(); + // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. + let mut builder = get_list_builder( + first_s.dtype(), + capacity * 5, + capacity, + PlSmallStr::EMPTY, + ) + .unwrap(); - for opt_s in it { - builder.append_opt_series(opt_s.as_ref()).unwrap(); - } - builder.finish() - }, - _ => { - // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. - let mut builder = get_list_builder( - first_s.dtype(), - capacity * 5, - capacity, - PlSmallStr::EMPTY, - ) - .unwrap(); - - for _ in 0..init_null_count { - builder.append_null(); - } - builder.append_series(first_s).unwrap(); + for _ in 0..init_null_count { + builder.append_null(); + } + builder.append_series(first_s).unwrap(); - for opt_s in it { - builder.append_opt_series(opt_s.as_ref()).unwrap(); - } - builder.finish() - }, + for opt_s in it { + builder.append_opt_series(opt_s.as_ref()).unwrap(); } + builder.finish() } }, } diff --git a/crates/polars-core/src/chunked_array/from_iterator_par.rs b/crates/polars-core/src/chunked_array/from_iterator_par.rs index 5c9abf4620af..f2ba901dd6ef 100644 --- a/crates/polars-core/src/chunked_array/from_iterator_par.rs +++ b/crates/polars-core/src/chunked_array/from_iterator_par.rs @@ -177,33 +177,13 @@ fn materialize_list( value_capacity: usize, list_capacity: usize, ) -> ListChunked { - match &dtype { - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let s = vectors - .iter() - .flatten() - .find_map(|opt_s| opt_s.as_ref()) - .unwrap(); - let mut builder = s.get_list_builder(name, value_capacity, list_capacity); - - for v in vectors { - for val in v { - builder.append_opt_series(val.as_ref()).unwrap(); - } - } - builder.finish() - }, - dtype => { - let mut builder = get_list_builder(dtype, value_capacity, list_capacity, name).unwrap(); - for v in vectors { - for val in v { - builder.append_opt_series(val.as_ref()).unwrap(); - } - } - builder.finish() - }, + let mut builder = get_list_builder(&dtype, value_capacity, list_capacity, name).unwrap(); + for v in vectors { + for val in v { + builder.append_opt_series(val.as_ref()).unwrap(); + } } + builder.finish() } impl FromParallelIterator> for ListChunked { diff --git a/crates/polars-core/src/chunked_array/object/extension/list.rs b/crates/polars-core/src/chunked_array/object/extension/list.rs index 1918039d647e..2d34315c378d 100644 --- a/crates/polars-core/src/chunked_array/object/extension/list.rs +++ b/crates/polars-core/src/chunked_array/object/extension/list.rs @@ -18,7 +18,7 @@ impl ObjectChunked { } } -struct ExtensionListBuilder { +pub(crate) struct ExtensionListBuilder { values_builder: ObjectChunkedBuilder, offsets: Vec, fast_explode: bool, diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index 5a049da4a01f..89ccd65a7c1a 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -1,5 +1,5 @@ pub(crate) mod drop; -mod list; +pub(super) mod list; pub(crate) mod polars_extension; use std::mem; diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs index e84c7ab69ba5..4bda1162bb94 100644 --- a/crates/polars-core/src/chunked_array/object/registry.rs +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -13,7 +13,7 @@ use polars_utils::pl_str::PlSmallStr; use crate::chunked_array::object::builder::ObjectChunkedBuilder; use crate::datatypes::AnyValue; -use crate::prelude::PolarsObject; +use crate::prelude::{ListBuilderTrait, PolarsObject}; use crate::series::{IntoSeries, Series}; /// Takes a `name` and `capacity` and constructs a new builder. @@ -71,6 +71,13 @@ pub trait AnonymousObjectBuilder { /// Take the current state and materialize as a [`Series`] /// the builder should not be used after that. fn to_series(&mut self) -> Series; + + fn get_list_builder( + &self, + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box; } impl AnonymousObjectBuilder for ObjectChunkedBuilder { @@ -87,6 +94,18 @@ impl AnonymousObjectBuilder for ObjectChunkedBuilder { let builder = std::mem::take(self); builder.finish().into_series() } + fn get_list_builder( + &self, + name: PlSmallStr, + values_capacity: usize, + list_capacity: usize, + ) -> Box { + Box::new(super::extension::list::ExtensionListBuilder::::new( + name, + values_capacity, + list_capacity, + )) + } } pub fn register_object_builder( diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index a52cff4ca2f5..f9993473099a 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; +use polars_core::chunked_array::builder::get_list_builder; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "parquet")] @@ -265,46 +266,43 @@ impl ApplyExpr { // Length of the items to iterate over. let len = iters[0].size_hint().0; - if len == 0 { - drop(iters); - - // Take the first aggregation context that as that is the input series. - let mut ac = acs.swap_remove(0); - ac.with_update_groups(UpdateGroups::No); - - let agg_state = if self.function_returns_scalar { - AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype)) - } else { - match self.collect_groups { - ApplyOptions::ElementWise | ApplyOptions::ApplyList => ac - .agg_state() - .map(|_| Series::new_empty(field.name().clone(), &field.dtype)), - ApplyOptions::GroupWise => AggState::AggregatedList(Series::new_empty( - field.name().clone(), - &DataType::List(Box::new(field.dtype.clone())), - )), - } - }; - - ac.with_agg_state(agg_state); - return Ok(ac); - } - - let ca = (0..len) - .map(|_| { + let ca = if len == 0 { + let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name)?; + for _ in 0..len { container.clear(); for iter in &mut iters { match iter.next().unwrap() { - None => return Ok(None), + None => { + builder.append_null(); + }, Some(s) => container.push(s.deep_clone().into()), } } - self.function + let out = self + .function .call_udf(&mut container) - .map(|r| r.map(|c| c.as_materialized_series().clone())) - }) - .collect::>()? - .with_name(field.name.clone()); + .map(|r| r.map(|c| c.as_materialized_series().clone()))?; + + builder.append_opt_series(out.as_ref())? + } + builder.finish() + } else { + (0..len) + .map(|_| { + container.clear(); + for iter in &mut iters { + match iter.next().unwrap() { + None => return Ok(None), + Some(s) => container.push(s.deep_clone().into()), + } + } + self.function + .call_udf(&mut container) + .map(|r| r.map(|c| c.as_materialized_series().clone())) + }) + .collect::>()? + .with_name(field.name.clone()) + }; drop(iters); @@ -443,7 +441,7 @@ impl PhysicalExpr for ApplyExpr { self.expr.to_field(input_schema, Context::Default) } #[cfg(feature = "parquet")] - fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> { + fn as_stats_evaluator(&self) -> Option<&dyn StatsEvaluator> { let function = match &self.expr { Expr::Function { function, .. } => function, _ => return None, diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 8a74033953dc..15550c517fe7 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -72,13 +72,6 @@ impl AggState { AggState::NotAggregated(s) => AggState::NotAggregated(func(s)?), }) } - - fn map(&self, func: F) -> Self - where - F: FnOnce(&Series) -> Series, - { - self.try_map(|s| Ok(func(s))).unwrap() - } } // lazy update strategy diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index af9dc9a180e2..5ed57b374149 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -1146,3 +1146,10 @@ def test_positional_by_with_list_or_tuple_17540() -> None: pl.DataFrame({"a": [1, 2, 3]}).group_by(by=["a"]) with pytest.raises(TypeError, match="Hint: if you"): pl.LazyFrame({"a": [1, 2, 3]}).group_by(by=["a"]) + + +def test_group_by_agg_19173() -> None: + df = pl.DataFrame({"x": [1.0], "g": [0]}) + out = df.head(0).group_by("g").agg((pl.col.x - pl.col.x.sum() * pl.col.x) ** 2) + assert out.to_dict(as_series=False) == {"g": [], "x": []} + assert out.schema == pl.Schema([("g", pl.Int64), ("x", pl.List(pl.Float64))])