From 4cfe56f2198b7c5607baf2bf9a88ba1ec3b73de5 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 23 Oct 2024 18:47:25 +0200 Subject: [PATCH] refactor(rust): Add streaming groupby for reductions (#19291) --- Cargo.lock | 2 + .../polars-arrow/src/compute/aggregate/mod.rs | 2 +- .../chunked_array/logical/categorical/mod.rs | 18 ++ .../polars-core/src/chunked_array/ops/mod.rs | 1 + .../src/chunked_array/ops/row_encode.rs | 220 ++++++++++++++++ .../chunked_array/ops/sort/arg_bottom_k.rs | 1 + .../ops/sort/arg_sort_multiple.rs | 179 +------------ .../src/chunked_array/ops/sort/mod.rs | 43 +-- .../src/chunked_array/struct_/mod.rs | 2 +- crates/polars-core/src/datatypes/dtype.rs | 25 ++ .../src/frame/group_by/into_groups.rs | 2 +- crates/polars-core/src/frame/group_by/mod.rs | 2 +- crates/polars-core/src/frame/mod.rs | 41 ++- crates/polars-core/src/schema.rs | 9 + crates/polars-core/src/series/mod.rs | 79 +++++- crates/polars-expr/Cargo.toml | 2 + crates/polars-expr/src/groups/mod.rs | 68 +++++ crates/polars-expr/src/groups/row_encoded.rs | 183 +++++++++++++ crates/polars-expr/src/lib.rs | 1 + crates/polars-expr/src/reduce/convert.rs | 3 +- crates/polars-expr/src/reduce/len.rs | 2 +- crates/polars-expr/src/reduce/mean.rs | 2 +- crates/polars-expr/src/reduce/min_max.rs | 3 +- crates/polars-expr/src/reduce/mod.rs | 5 +- crates/polars-expr/src/reduce/sum.rs | 2 +- crates/polars-lazy/Cargo.toml | 1 + crates/polars-ops/src/frame/join/mod.rs | 4 +- crates/polars-ops/src/frame/mod.rs | 5 +- crates/polars-ops/src/series/ops/various.rs | 4 +- .../src/executors/sinks/sort/sink_multiple.rs | 2 +- crates/polars-plan/src/plans/aexpr/schema.rs | 8 +- .../polars-plan/src/plans/aexpr/traverse.rs | 2 +- crates/polars-row/src/decode.rs | 3 + crates/polars-row/src/encode.rs | 2 + crates/polars-stream/Cargo.toml | 1 + crates/polars-stream/src/nodes/group_by.rs | 221 ++++++++++++++++ crates/polars-stream/src/nodes/mod.rs | 1 + crates/polars-stream/src/physical_plan/fmt.rs | 11 + 
.../src/physical_plan/lower_expr.rs | 2 +- .../src/physical_plan/lower_ir.rs | 247 ++++++++++-------- crates/polars-stream/src/physical_plan/mod.rs | 9 +- .../src/physical_plan/to_graph.rs | 44 +++- .../tests/unit/streaming/test_streaming.py | 2 + 43 files changed, 1088 insertions(+), 378 deletions(-) create mode 100644 crates/polars-core/src/chunked_array/ops/row_encode.rs create mode 100644 crates/polars-expr/src/groups/mod.rs create mode 100644 crates/polars-expr/src/groups/row_encoded.rs create mode 100644 crates/polars-stream/src/nodes/group_by.rs diff --git a/Cargo.lock b/Cargo.lock index c601e355be99..dbd7af500fcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2878,6 +2878,7 @@ version = "0.43.1" dependencies = [ "ahash", "bitflags", + "hashbrown 0.15.0", "num-traits", "once_cell", "polars-arrow", @@ -2887,6 +2888,7 @@ dependencies = [ "polars-json", "polars-ops", "polars-plan", + "polars-row", "polars-time", "polars-utils", "rayon", diff --git a/crates/polars-arrow/src/compute/aggregate/mod.rs b/crates/polars-arrow/src/compute/aggregate/mod.rs index 9528f833a67e..481194c1551c 100644 --- a/crates/polars-arrow/src/compute/aggregate/mod.rs +++ b/crates/polars-arrow/src/compute/aggregate/mod.rs @@ -1,4 +1,4 @@ -//! Contains different aggregation functions +/// ! Contains different aggregation functions #[cfg(feature = "compute_aggregate")] mod sum; #[cfg(feature = "compute_aggregate")] diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index a59ff68e40d9..d740455777c4 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -203,6 +203,24 @@ impl CategoricalChunked { } } + /// Create a [`CategoricalChunked`] from a physical array and dtype. + /// + /// # Safety + /// It's not checked that the indices are in-bounds or that the dtype is + /// correct. 
+ pub unsafe fn from_cats_and_dtype_unchecked(idx: UInt32Chunked, dtype: DataType) -> Self { + debug_assert!(matches!( + dtype, + DataType::Enum { .. } | DataType::Categorical { .. } + )); + let mut logical = Logical::::new_logical::(idx); + logical.2 = Some(dtype); + Self { + physical: logical, + bit_settings: Default::default(), + } + } + /// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`. /// /// # Safety diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index a3e7f04cc9e1..c0daaa72bdf6 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -34,6 +34,7 @@ pub(crate) mod nulls; mod reverse; #[cfg(feature = "rolling_window")] pub(crate) mod rolling_window; +pub mod row_encode; pub mod search_sorted; mod set; mod shift; diff --git a/crates/polars-core/src/chunked_array/ops/row_encode.rs b/crates/polars-core/src/chunked_array/ops/row_encode.rs new file mode 100644 index 000000000000..5ac627327389 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/row_encode.rs @@ -0,0 +1,220 @@ +use arrow::compute::utils::combine_validities_and_many; +use polars_row::{convert_columns, EncodingField, RowsEncoded}; +use rayon::prelude::*; + +use crate::prelude::*; +use crate::utils::_split_offsets; +use crate::POOL; + +pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + #[cfg(feature = "dtype-categorical")] + Categorical(_, _) | Enum(_, _) => s.rechunk(), + Binary | Boolean => s.clone(), + BinaryOffset => s.clone(), + String => s.str().unwrap().as_binary().into_series(), + #[cfg(feature = "dtype-struct")] + Struct(_) => { + let ca = s.struct_().unwrap(); + let new_fields = ca + .fields_as_series() + .iter() + .map(convert_series_for_row_encoding) + .collect::>>()?; + let mut out = + StructChunked::from_series(ca.name().clone(), 
ca.len(), new_fields.iter())?; + out.zip_outer_validity(ca); + out.into_series() + }, + // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => s.clone(), + List(inner) if !inner.is_nested() => s.clone(), + Null => s.clone(), + _ => { + let phys = s.to_physical_repr().into_owned(); + polars_ensure!( + phys.dtype().is_numeric(), + InvalidOperation: "cannot sort column of dtype `{}`", s.dtype() + ); + phys + }, + }; + Ok(out) +} + +pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { + let by = convert_series_for_row_encoding(by)?; + let by = by.rechunk(); + + let out = match by.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + let ca = by.categorical().unwrap(); + if ca.uses_lexical_ordering() { + by.to_arrow(0, CompatLevel::newest()) + } else { + ca.physical().chunks[0].clone() + } + }, + // Take physical + _ => by.chunks()[0].clone(), + }; + Ok(out) +} + +pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + Ok(rows.into_array()) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +// Almost the same but broadcast nulls to the row-encoded array. 
+pub fn encode_rows_vertical_par_unordered_broadcast_nulls( + by: &[Series], +) -> PolarsResult { + let n_threads = POOL.current_num_threads(); + let len = by[0].len(); + let splits = _split_offsets(len, n_threads); + + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded_unordered(&sliced)?; + + let validities = sliced + .iter() + .flat_map(|s| { + let s = s.rechunk(); + #[allow(clippy::unnecessary_to_owned)] + s.chunks() + .to_vec() + .into_iter() + .map(|arr| arr.validity().cloned()) + }) + .collect::>(); + + let validity = combine_validities_and_many(&validities); + Ok(rows.into_array().with_validity_typed(validity)) + }); + let chunks = POOL.install(|| chunks.collect::>>()); + + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) +} + +pub fn encode_rows_unordered(by: &[Series]) -> PolarsResult { + let rows = _get_rows_encoded_unordered(by)?; + Ok(BinaryOffsetChunked::with_chunk( + PlSmallStr::EMPTY, + rows.into_array(), + )) +} + +pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + for by in by { + let arr = _get_rows_encoded_compat_array(by)?; + let field = EncodingField::new_unsorted(); + match arr.dtype() { + // Flatten the struct fields. 
+ ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + for arr in arr.values() { + cols.push(arr.clone() as ArrayRef); + fields.push(field) + } + }, + _ => { + cols.push(arr); + fields.push(field) + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + +pub fn _get_rows_encoded( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + debug_assert_eq!(by.len(), descending.len()); + debug_assert_eq!(by.len(), nulls_last.len()); + + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + + for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { + let by = by.as_materialized_series(); + let arr = _get_rows_encoded_compat_array(by)?; + let sort_field = EncodingField { + descending: *desc, + nulls_last: *null_last, + no_order: false, + }; + match arr.dtype() { + // Flatten the struct fields. + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + let arr = arr.propagate_nulls(); + for value_arr in arr.values() { + cols.push(value_arr.clone() as ArrayRef); + fields.push(sort_field); + } + }, + _ => { + cols.push(arr); + fields.push(sort_field); + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + +pub fn _get_rows_encoded_ca( + name: PlSmallStr, + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult { + _get_rows_encoded(by, descending, nulls_last) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} + +pub fn _get_rows_encoded_arr( + by: &[Column], + descending: &[bool], + nulls_last: &[bool], +) -> PolarsResult> { + _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) +} + +pub fn _get_rows_encoded_ca_unordered( + name: PlSmallStr, + by: &[Series], +) -> PolarsResult { + _get_rows_encoded_unordered(by) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs 
b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs index cad95d6b1d10..7f257f23f59e 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -1,6 +1,7 @@ use polars_utils::itertools::Itertools; use super::*; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; #[derive(Eq)] struct CompareRow<'a> { diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 5653039ff02e..2291cc2306e1 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,10 +1,8 @@ -use arrow::compute::utils::combine_validities_and_many; use compare_inner::NullOrderCmp; -use polars_row::{convert_columns, EncodingField, RowsEncoded}; use polars_utils::itertools::Itertools; use super::*; -use crate::utils::_split_offsets; +use crate::chunked_array::ops::row_encode::_get_rows_encoded; pub(crate) fn args_validate( ca: &ChunkedArray, @@ -86,181 +84,6 @@ pub(crate) fn arg_sort_multiple_impl( Ok(ca.into_inner()) } -pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { - let by = convert_sort_column_multi_sort(by)?; - let by = by.rechunk(); - - let out = match by.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - let ca = by.categorical().unwrap(); - if ca.uses_lexical_ordering() { - by.to_arrow(0, CompatLevel::newest()) - } else { - ca.physical().chunks[0].clone() - } - }, - // Take physical - _ => by.chunks()[0].clone(), - }; - Ok(out) -} - -pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult { - let n_threads = POOL.current_num_threads(); - let len = by[0].len(); - let splits = _split_offsets(len, n_threads); - - let chunks = splits.into_par_iter().map(|(offset, len)| { - let sliced = by - .iter() - 
.map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded_unordered(&sliced)?; - Ok(rows.into_array()) - }); - let chunks = POOL.install(|| chunks.collect::>>()); - - Ok(BinaryOffsetChunked::from_chunk_iter( - PlSmallStr::EMPTY, - chunks?, - )) -} - -// Almost the same but broadcast nulls to the row-encoded array. -pub fn encode_rows_vertical_par_unordered_broadcast_nulls( - by: &[Series], -) -> PolarsResult { - let n_threads = POOL.current_num_threads(); - let len = by[0].len(); - let splits = _split_offsets(len, n_threads); - - let chunks = splits.into_par_iter().map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded_unordered(&sliced)?; - - let validities = sliced - .iter() - .flat_map(|s| { - let s = s.rechunk(); - #[allow(clippy::unnecessary_to_owned)] - s.chunks() - .to_vec() - .into_iter() - .map(|arr| arr.validity().cloned()) - }) - .collect::>(); - - let validity = combine_validities_and_many(&validities); - Ok(rows.into_array().with_validity_typed(validity)) - }); - let chunks = POOL.install(|| chunks.collect::>>()); - - Ok(BinaryOffsetChunked::from_chunk_iter( - PlSmallStr::EMPTY, - chunks?, - )) -} - -pub(crate) fn encode_rows_unordered(by: &[Series]) -> PolarsResult { - let rows = _get_rows_encoded_unordered(by)?; - Ok(BinaryOffsetChunked::with_chunk( - PlSmallStr::EMPTY, - rows.into_array(), - )) -} - -pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { - let mut cols = Vec::with_capacity(by.len()); - let mut fields = Vec::with_capacity(by.len()); - for by in by { - let arr = _get_rows_encoded_compat_array(by)?; - let field = EncodingField::new_unsorted(); - match arr.dtype() { - // Flatten the struct fields. 
- ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for arr in arr.values() { - cols.push(arr.clone() as ArrayRef); - fields.push(field) - } - }, - _ => { - cols.push(arr); - fields.push(field) - }, - } - } - Ok(convert_columns(&cols, &fields)) -} - -pub fn _get_rows_encoded( - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult { - debug_assert_eq!(by.len(), descending.len()); - debug_assert_eq!(by.len(), nulls_last.len()); - - let mut cols = Vec::with_capacity(by.len()); - let mut fields = Vec::with_capacity(by.len()); - - for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { - let by = by.as_materialized_series(); - let arr = _get_rows_encoded_compat_array(by)?; - let sort_field = EncodingField { - descending: *desc, - nulls_last: *null_last, - no_order: false, - }; - match arr.dtype() { - // Flatten the struct fields. - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - let arr = arr.propagate_nulls(); - for value_arr in arr.values() { - cols.push(value_arr.clone() as ArrayRef); - fields.push(sort_field); - } - }, - _ => { - cols.push(arr); - fields.push(sort_field); - }, - } - } - Ok(convert_columns(&cols, &fields)) -} - -pub fn _get_rows_encoded_ca( - name: PlSmallStr, - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult { - _get_rows_encoded(by, descending, nulls_last) - .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) -} - -pub fn _get_rows_encoded_arr( - by: &[Column], - descending: &[bool], - nulls_last: &[bool], -) -> PolarsResult> { - _get_rows_encoded(by, descending, nulls_last).map(|rows| rows.into_array()) -} - -pub fn _get_rows_encoded_ca_unordered( - name: PlSmallStr, - by: &[Series], -) -> PolarsResult { - _get_rows_encoded_unordered(by) - .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) -} - pub(crate) fn argsort_multiple_row_fmt( by: &[Column], mut descending: 
Vec, diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index d68583f1dc5c..727f2ace15a8 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -18,6 +18,9 @@ use compare_inner::NonNull; use rayon::prelude::*; pub use slice::*; +use crate::chunked_array::ops::row_encode::{ + _get_rows_encoded_ca, convert_series_for_row_encoding, +}; use crate::prelude::compare_inner::TotalOrdInner; use crate::prelude::sort::arg_sort_multiple::*; use crate::prelude::*; @@ -708,44 +711,6 @@ impl ChunkSort for BooleanChunked { } } -pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult { - use DataType::*; - let out = match s.dtype() { - #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => s.rechunk(), - Binary | Boolean => s.clone(), - BinaryOffset => s.clone(), - String => s.str().unwrap().as_binary().into_series(), - #[cfg(feature = "dtype-struct")] - Struct(_) => { - let ca = s.struct_().unwrap(); - let new_fields = ca - .fields_as_series() - .iter() - .map(convert_sort_column_multi_sort) - .collect::>>()?; - let mut out = - StructChunked::from_series(ca.name().clone(), ca.len(), new_fields.iter())?; - out.zip_outer_validity(ca); - out.into_series() - }, - // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here - #[cfg(feature = "dtype-decimal")] - Decimal(_, _) => s.clone(), - List(inner) if !inner.is_nested() => s.clone(), - Null => s.clone(), - _ => { - let phys = s.to_physical_repr().into_owned(); - polars_ensure!( - phys.dtype().is_numeric(), - InvalidOperation: "cannot sort column of dtype `{}`", s.dtype() - ); - phys - }, - }; - Ok(out) -} - pub fn _broadcast_bools(n_cols: usize, values: &mut Vec) { if n_cols > values.len() && values.len() == 1 { while n_cols != values.len() { @@ -763,7 +728,7 @@ pub(crate) fn prepare_arg_sort( let mut columns = columns 
.iter() .map(Column::as_materialized_series) - .map(convert_sort_column_multi_sort) + .map(convert_series_for_row_encoding) .map(|s| s.map(Column::from)) .collect::>>()?; diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs index 05ac32424d5e..625da8881117 100644 --- a/crates/polars-core/src/chunked_array/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -10,8 +10,8 @@ use polars_utils::aliases::PlHashMap; use polars_utils::itertools::Itertools; use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::chunked_array::ChunkedArray; -use crate::prelude::sort::arg_sort_multiple::{_get_rows_encoded_arr, _get_rows_encoded_ca}; use crate::prelude::*; use crate::series::Series; use crate::utils::Container; diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 02cf360f3bf5..650fb8eab05d 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; #[cfg(feature = "dtype-array")] use polars_utils::format_tuple; +use polars_utils::itertools::Itertools; use super::*; #[cfg(feature = "object")] @@ -191,6 +192,30 @@ impl DataType { } } + /// Materialize this datatype if it is unknown. All other datatypes + /// are left unchanged. 
+ pub fn materialize_unknown(&self) -> PolarsResult { + match self { + DataType::Unknown(u) => u + .materialize() + .ok_or_else(|| polars_err!(SchemaMismatch: "failed to materialize unknown type")), + DataType::List(inner) => Ok(DataType::List(Box::new(inner.materialize_unknown()?))), + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => Ok(DataType::Struct( + fields + .iter() + .map(|f| { + PolarsResult::Ok(Field::new( + f.name().clone(), + f.dtype().materialize_unknown()?, + )) + }) + .try_collect_vec()?, + )), + _ => Ok(self.clone()), + } + } + #[cfg(feature = "dtype-array")] /// Get the full shape of a multidimensional array. pub fn get_shape(&self) -> Option> { diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index bdaa439a1232..12b4b27de7e2 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -3,8 +3,8 @@ use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; use crate::chunked_array::cast::CastOptions; +use crate::chunked_array::ops::row_encode::_get_rows_encoded_ca_unordered; use crate::config::verbose; -use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered; use crate::series::BitRepr; use crate::utils::flatten::flatten_par; diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 31936a3a5906..9dee1e1f411a 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -21,7 +21,7 @@ mod proxy; pub use into_groups::*; pub use proxy::*; -use crate::prelude::sort::arg_sort_multiple::{ +use crate::chunked_array::ops::row_encode::{ encode_rows_unordered, encode_rows_vertical_par_unordered, }; diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 5941970d14cf..8ce1525b2ed2 100644 --- a/crates/polars-core/src/frame/mod.rs +++ 
b/crates/polars-core/src/frame/mod.rs @@ -309,16 +309,6 @@ impl DataFrame { /// Converts a sequence of columns into a DataFrame, broadcasting length-1 /// columns to match the other columns. pub fn new_with_broadcast(columns: Vec) -> PolarsResult { - ensure_names_unique(&columns, |s| s.name().as_str())?; - unsafe { Self::new_with_broadcast_no_checks(columns) } - } - - /// Converts a sequence of columns into a DataFrame, broadcasting length-1 - /// columns to match the other columns. - /// - /// # Safety - /// Does not check that the column names are unique (which they must be). - pub unsafe fn new_with_broadcast_no_checks(mut columns: Vec) -> PolarsResult { // The length of the longest non-unit length column determines the // broadcast length. If all columns are unit-length the broadcast length // is one. @@ -328,17 +318,42 @@ impl DataFrame { .filter(|l| *l != 1) .max() .unwrap_or(1); + Self::new_with_broadcast_len(columns, broadcast_len) + } + + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to broadcast_len. + pub fn new_with_broadcast_len( + columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { + ensure_names_unique(&columns, |s| s.name().as_str())?; + unsafe { Self::new_with_broadcast_no_namecheck(columns, broadcast_len) } + } + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to broadcast_len. + /// + /// # Safety + /// Does not check that the column names are unique (which they must be). + pub unsafe fn new_with_broadcast_no_namecheck( + mut columns: Vec, + broadcast_len: usize, + ) -> PolarsResult { for col in &mut columns { // Length not equal to the broadcast len, needs broadcast or is an error.
let len = col.len(); if len != broadcast_len { if len != 1 { let name = col.name().to_owned(); - let longest_column = columns.iter().max_by_key(|c| c.len()).unwrap().name(); + let extra_info = + if let Some(c) = columns.iter().find(|c| c.len() == broadcast_len) { + format!(" (matching column '{}')", c.name()) + } else { + String::new() + }; polars_bail!( - ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", - name, len, longest_column, broadcast_len + ShapeMismatch: "could not create a new DataFrame: series {name:?} has length {len} while trying to broadcast to length {broadcast_len}{extra_info}", ); } *col = col.new_from_index(0, broadcast_len); diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index d100cf91172f..38fe0377f9c4 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -20,6 +20,8 @@ pub trait SchemaExt { fn iter_fields(&self) -> impl ExactSizeIterator + '_; fn to_supertype(&mut self, other: &Schema) -> PolarsResult; + + fn materialize_unknown_dtypes(&self) -> PolarsResult; } impl SchemaExt for Schema { @@ -88,6 +90,13 @@ impl SchemaExt for Schema { } Ok(changed) } + + /// Materialize all unknown dtypes in this schema. 
+ fn materialize_unknown_dtypes(&self) -> PolarsResult { + self.iter() + .map(|(name, dtype)| Ok((name.clone(), dtype.materialize_unknown()?))) + .collect() + } } pub trait SchemaNamesAndDtypes { diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 227e53ec56d7..9508a72cc1d7 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -23,6 +23,7 @@ use arrow::offset::Offsets; pub use from::*; pub use iterator::{SeriesIter, SeriesPhysIter}; use num_traits::NumCast; +use polars_utils::itertools::Itertools; pub use series_trait::{IsSorted, *}; use crate::chunked_array::cast::CastOptions; @@ -590,15 +591,16 @@ impl Series { lhs.zip_with_same_type(mask, rhs.as_ref()) } - /// Cast a datelike Series to their physical representation. - /// Primitives remain unchanged + /// Converts a Series to its physical representation, if it has one, + /// otherwise the series is left unchanged. /// /// * Date -> Int32 - /// * Datetime-> Int64 + /// * Datetime -> Int64 + /// * Duration -> Int64 /// * Time -> Int64 /// * Categorical -> UInt32 /// * List(inner) -> List(physical of inner) - /// + /// * Struct -> Struct with physical repr of each struct column pub fn to_physical_repr(&self) -> Cow { use DataType::*; match self.dtype() { @@ -639,6 +641,75 @@ impl Series { } } + /// Attempts to convert a Series to dtype, only allowing conversions from + /// physical to logical dtypes--the inverse of to_physical_repr(). + /// + /// # Safety + /// When converting from UInt32 to Categorical it is not checked that the + /// values are in-bounds for the categorical mapping.
+ pub unsafe fn to_logical_repr_unchecked(&self, dtype: &DataType) -> PolarsResult { + use DataType::*; + + let err = || { + Err( + polars_err!(ComputeError: "can't cast from {} to {} in to_logical_repr_unchecked", self.dtype(), dtype), + ) + }; + + match dtype { + dt if self.dtype() == dt => Ok(self.clone()), + #[cfg(feature = "dtype-date")] + Date => Ok(self.i32()?.clone().into_date().into_series()), + #[cfg(feature = "dtype-datetime")] + Datetime(u, z) => Ok(self + .i64()? + .clone() + .into_datetime(*u, z.clone()) + .into_series()), + #[cfg(feature = "dtype-duration")] + Duration(u) => Ok(self.i64()?.clone().into_duration(*u).into_series()), + #[cfg(feature = "dtype-time")] + Time => Ok(self.i64()?.clone().into_time().into_series()), + #[cfg(feature = "dtype-categorical")] + Categorical { .. } | Enum { .. } => { + Ok(CategoricalChunked::from_cats_and_dtype_unchecked( + self.u32()?.clone(), + dtype.clone(), + ) + .into_series()) + }, + List(inner) => { + if let List(self_inner) = self.dtype() { + if inner.to_physical() == **self_inner { + return self.cast(dtype); + } + } + err() + }, + #[cfg(feature = "dtype-struct")] + Struct(target_fields) => { + let ca = self.struct_().unwrap(); + if ca.struct_fields().len() != target_fields.len() { + return err(); + } + let fields = ca + .fields_as_series() + .iter() + .zip(target_fields) + .map(|(s, tf)| s.to_logical_repr_unchecked(tf.dtype())) + .try_collect_vec()?; + let mut result = + StructChunked::from_series(self.name().clone(), ca.len(), fields.iter())?; + if ca.null_count() > 0 { + result.zip_outer_validity(ca); + } + Ok(result.into_series()) + }, + + _ => err(), + } + } + /// Take by index if ChunkedArray contains a single chunk. 
/// /// # Safety diff --git a/crates/polars-expr/Cargo.toml b/crates/polars-expr/Cargo.toml index d53585b17d43..0911445617aa 100644 --- a/crates/polars-expr/Cargo.toml +++ b/crates/polars-expr/Cargo.toml @@ -12,6 +12,7 @@ description = "Physical expression implementation of the Polars project." ahash = { workspace = true } arrow = { workspace = true } bitflags = { workspace = true } +hashbrown = { workspace = true } num-traits = { workspace = true } once_cell = { workspace = true } polars-compute = { workspace = true } @@ -20,6 +21,7 @@ polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } polars-ops = { workspace = true, features = ["chunked_ids"] } polars-plan = { workspace = true } +polars-row = { workspace = true } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } rayon = { workspace = true } diff --git a/crates/polars-expr/src/groups/mod.rs b/crates/polars-expr/src/groups/mod.rs new file mode 100644 index 000000000000..5eb32b34a052 --- /dev/null +++ b/crates/polars-expr/src/groups/mod.rs @@ -0,0 +1,68 @@ +use std::any::Any; +use std::path::Path; + +use polars_core::prelude::*; +use polars_utils::aliases::PlRandomState; +use polars_utils::IdxSize; + +mod row_encoded; + +/// A Grouper maps keys to groups, such that duplicate keys map to the same group. +pub trait Grouper: Any + Send { + /// Creates a new empty Grouper similar to this one. + fn new_empty(&self) -> Box; + + /// Returns the number of groups in this Grouper. + fn num_groups(&self) -> IdxSize; + + /// Inserts the given keys into this Grouper, mutating groups_idxs such + /// that group_idxs[i] is the group index of keys[..][i]. + fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec); + + /// Adds the given Grouper into this one, mutating groups_idxs such that + /// the ith group of other now has group index group_idxs[i] in self. 
+ fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec); + + /// Partitions this Grouper into the given partitions. + /// + /// Updates partition_idxs and group_idxs such that the ith group of self + /// has group index group_idxs[i] in partition partition_idxs[i]. + /// + /// It is guaranteed that two equal keys in two independent partition_into + /// calls map to the same partition index if the seed and the number of + /// partitions is equal. + fn partition_into( + &self, + seed: u64, + partitions: &mut [Box], + partition_idxs: &mut Vec, + group_idxs: &mut Vec, + ); + + /// Returns the keys in this Grouper in group order, that is the key for + /// group i is returned in row i. + fn get_keys_in_group_order(&self) -> DataFrame; + + /// Returns the keys in this Grouper, mutating group_idxs such that the ith + /// key returned corresponds to group group_idxs[i]. + fn get_keys_groups(&self, group_idxs: &mut Vec) -> DataFrame; + + /// Stores this Grouper at the given path. + fn store_ooc(&self, _path: &Path) { + unimplemented!(); + } + + /// Loads this Grouper from the given path. 
+ fn load_ooc(&mut self, _path: &Path) { + unimplemented!(); + } + + fn as_any(&self) -> &dyn Any; +} + +pub fn new_hash_grouper(key_schema: Arc, random_state: PlRandomState) -> Box { + Box::new(row_encoded::RowEncodedHashGrouper::new( + key_schema, + random_state, + )) +} diff --git a/crates/polars-expr/src/groups/row_encoded.rs b/crates/polars-expr/src/groups/row_encoded.rs new file mode 100644 index 000000000000..46ec956106a5 --- /dev/null +++ b/crates/polars-expr/src/groups/row_encoded.rs @@ -0,0 +1,183 @@ +use std::mem::MaybeUninit; + +use hashbrown::hash_table::{Entry, HashTable}; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_unordered; +use polars_row::EncodingField; +use polars_utils::aliases::PlRandomState; +use polars_utils::itertools::Itertools; +use polars_utils::vec::PushUnchecked; + +use super::*; + +struct Group { + key_hash: u64, + key_offset: usize, + key_length: u32, + group_idx: IdxSize, +} + +impl Group { + unsafe fn key<'k>(&self, key_data: &'k [u8]) -> &'k [u8] { + key_data.get_unchecked(self.key_offset..self.key_offset + self.key_length as usize) + } +} + +#[derive(Default)] +pub struct RowEncodedHashGrouper { + key_schema: Arc, + table: HashTable, + key_data: Vec, + random_state: PlRandomState, +} + +impl RowEncodedHashGrouper { + pub fn new(key_schema: Arc, random_state: PlRandomState) -> Self { + Self { + key_schema, + random_state, + ..Default::default() + } + } + + fn insert_key(&mut self, hash: u64, key: &[u8]) -> IdxSize { + let num_groups = self.table.len(); + let entry = self.table.entry( + hash, + |g| unsafe { hash == g.key_hash && key == g.key(&self.key_data) }, + |g| g.key_hash, + ); + + match entry { + Entry::Occupied(e) => e.get().group_idx, + Entry::Vacant(e) => { + let group_idx: IdxSize = num_groups.try_into().unwrap(); + let group = Group { + key_hash: hash, + key_offset: self.key_data.len(), + key_length: key.len().try_into().unwrap(), + group_idx, + }; + self.key_data.extend(key); + e.insert(group); + 
group_idx + }, + } + } + + fn finalize_keys(&self, mut key_rows: Vec<&[u8]>) -> DataFrame { + let key_dtypes = self + .key_schema + .iter() + .map(|(_name, dt)| dt.to_physical().to_arrow(CompatLevel::newest())) + .collect::>(); + let fields = vec![EncodingField::new_unsorted(); key_dtypes.len()]; + let key_columns = + unsafe { polars_row::decode::decode_rows(&mut key_rows, &fields, &key_dtypes) }; + + let cols = self + .key_schema + .iter() + .zip(key_columns) + .map(|((name, dt), col)| { + let s = Series::try_from((name.clone(), col)).unwrap(); + unsafe { s.to_logical_repr_unchecked(dt) } + .unwrap() + .into_column() + }) + .collect(); + unsafe { DataFrame::new_no_checks_height_from_first(cols) } + } +} + +impl Grouper for RowEncodedHashGrouper { + fn new_empty(&self) -> Box { + Box::new(Self::new( + self.key_schema.clone(), + self.random_state.clone(), + )) + } + + fn num_groups(&self) -> IdxSize { + self.table.len() as IdxSize + } + + fn insert_keys(&mut self, keys: &DataFrame, group_idxs: &mut Vec) { + let series = keys + .get_columns() + .iter() + .map(|c| c.as_materialized_series().clone()) + .collect_vec(); + let keys_encoded = _get_rows_encoded_unordered(&series[..]) + .unwrap() + .into_array(); + assert!(keys_encoded.len() == keys[0].len()); + + group_idxs.clear(); + group_idxs.reserve(keys_encoded.len()); + for key in keys_encoded.values_iter() { + let hash = self.random_state.hash_one(key); + unsafe { + group_idxs.push_unchecked(self.insert_key(hash, key)); + } + } + } + + fn combine(&mut self, other: &dyn Grouper, group_idxs: &mut Vec) { + let other = other.as_any().downcast_ref::().unwrap(); + + self.table.reserve(other.table.len(), |g| g.key_hash); // TODO: cardinality estimation. 
+ + unsafe { + group_idxs.clear(); + group_idxs.reserve(other.table.len()); + let idx_out = group_idxs.spare_capacity_mut(); + for group in other.table.iter() { + let group_key = group.key(&other.key_data); + let new_idx = self.insert_key(group.key_hash, group_key); + *idx_out.get_unchecked_mut(group.group_idx as usize) = MaybeUninit::new(new_idx); + } + group_idxs.set_len(other.table.len()); + } + } + + fn get_keys_in_group_order(&self) -> DataFrame { + let mut key_rows: Vec<&[u8]> = Vec::with_capacity(self.table.len()); + unsafe { + let out = key_rows.spare_capacity_mut(); + for group in &self.table { + *out.get_unchecked_mut(group.group_idx as usize) = + MaybeUninit::new(group.key(&self.key_data)); + } + key_rows.set_len(self.table.len()); + } + self.finalize_keys(key_rows) + } + + fn get_keys_groups(&self, group_idxs: &mut Vec) -> DataFrame { + group_idxs.clear(); + group_idxs.reserve(self.table.len()); + self.finalize_keys( + self.table + .iter() + .map(|group| unsafe { + group_idxs.push(group.group_idx); + group.key(&self.key_data) + }) + .collect(), + ) + } + + fn partition_into( + &self, + _seed: u64, + _partitions: &mut [Box], + _partition_idxs: &mut Vec, + _group_idxs: &mut Vec, + ) { + unimplemented!() + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/lib.rs b/crates/polars-expr/src/lib.rs index 9981e47f1451..2778f4621dc2 100644 --- a/crates/polars-expr/src/lib.rs +++ b/crates/polars-expr/src/lib.rs @@ -1,4 +1,5 @@ mod expressions; +pub mod groups; pub mod planner; pub mod prelude; pub mod reduce; diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs index 3573192ae16f..53e4abfa7b58 100644 --- a/crates/polars-expr/src/reduce/convert.rs +++ b/crates/polars-expr/src/reduce/convert.rs @@ -17,7 +17,8 @@ pub fn into_reduction( let get_dt = |node| { expr_arena .get(node) - .to_dtype(schema, Context::Default, expr_arena) + .to_dtype(schema, Context::Default, expr_arena)? 
+ .materialize_unknown() }; let out = match expr_arena.get(node) { AExpr::Agg(agg) => match agg { diff --git a/crates/polars-expr/src/reduce/len.rs b/crates/polars-expr/src/reduce/len.rs index db8aee647824..fa5aedb91f18 100644 --- a/crates/polars-expr/src/reduce/len.rs +++ b/crates/polars-expr/src/reduce/len.rs @@ -42,7 +42,7 @@ impl GroupedReduction for LenReduce { group_idxs: &[IdxSize], ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - assert!(self.groups.len() == other.groups.len()); + assert!(other.groups.len() == group_idxs.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, v) in group_idxs.iter().zip(other.groups.iter()) { diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs index 0caa2ccabcb8..cef81b098701 100644 --- a/crates/polars-expr/src/reduce/mean.rs +++ b/crates/polars-expr/src/reduce/mean.rs @@ -141,7 +141,7 @@ impl Reducer for BoolMeanReducer { assert!(dtype == &DataType::Boolean); let ca: Float64Chunked = v .into_iter() - .map(|(s, c)| s as f64 / c as f64) + .map(|(s, c)| (c != 0).then(|| s as f64 / c as f64)) .collect_ca(PlSmallStr::EMPTY); Ok(ca.into_series()) } diff --git a/crates/polars-expr/src/reduce/min_max.rs b/crates/polars-expr/src/reduce/min_max.rs index f1ec0cbcc5d2..81d6bb85f8a5 100644 --- a/crates/polars-expr/src/reduce/min_max.rs +++ b/crates/polars-expr/src/reduce/min_max.rs @@ -401,8 +401,7 @@ impl GroupedReduction for BoolMaxGroupedReduction { group_idxs: &[IdxSize], ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - assert!(self.values.len() == other.values.len()); - assert!(self.mask.len() == other.mask.len()); + assert!(other.values.len() == group_idxs.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. 
for (g, (v, o)) in group_idxs diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs index 8fc0620f27fe..0be52a1de225 100644 --- a/crates/polars-expr/src/reduce/mod.rs +++ b/crates/polars-expr/src/reduce/mod.rs @@ -196,7 +196,7 @@ where ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); assert!(self.in_dtype == other.in_dtype); - assert!(self.values.len() == other.values.len()); + assert!(group_idxs.len() == other.values.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, v) in group_idxs.iter().zip(other.values.iter()) { @@ -297,8 +297,7 @@ where ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); assert!(self.in_dtype == other.in_dtype); - assert!(self.values.len() == other.values.len()); - assert!(self.mask.len() == other.mask.len()); + assert!(group_idxs.len() == other.values.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. for (g, (v, o)) in group_idxs diff --git a/crates/polars-expr/src/reduce/sum.rs b/crates/polars-expr/src/reduce/sum.rs index 2b5d9d79c13f..111d69eec4f2 100644 --- a/crates/polars-expr/src/reduce/sum.rs +++ b/crates/polars-expr/src/reduce/sum.rs @@ -116,7 +116,7 @@ where ) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); assert!(self.in_dtype == other.in_dtype); - assert!(self.sums.len() == other.sums.len()); + assert!(other.sums.len() == group_idxs.len()); unsafe { // SAFETY: indices are in-bounds guaranteed by trait. 
for (g, v) in group_idxs.iter().zip(other.sums.iter()) { diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 2dfd642cde1f..25986e512381 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -203,6 +203,7 @@ dynamic_group_by = [ "temporal", "polars-expr/dynamic_group_by", "polars-mem-engine/dynamic_group_by", + "polars-stream/dynamic_group_by", ] ewma = ["polars-plan/ewma"] ewma_by = ["polars-plan/ewma_by"] diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 81f4fe54e7e4..c12db482b9f1 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -34,11 +34,11 @@ use hashbrown::hash_map::{Entry, RawEntryMut}; pub use iejoin::{IEJoinOptions, InequalityOperator}; #[cfg(feature = "merge_sorted")] pub use merge_sorted::_merge_sorted_dfs; -use polars_core::hashing::_HASHMAP_INIT_SIZE; #[allow(unused_imports)] -use polars_core::prelude::sort::arg_sort_multiple::{ +use polars_core::chunked_array::ops::row_encode::{ encode_rows_vertical_par_unordered, encode_rows_vertical_par_unordered_broadcast_nulls, }; +use polars_core::hashing::_HASHMAP_INIT_SIZE; use polars_core::prelude::*; pub(super) use polars_core::series::IsSorted; use polars_core::utils::slice_offsets; diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index 4604920351eb..d72ff4488251 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -11,9 +11,6 @@ use polars_core::utils::accumulate_dataframes_horizontal; #[cfg(feature = "to_dummies")] use polars_core::POOL; -#[allow(unused_imports)] -use crate::prelude::*; - pub trait IntoDf { fn to_df(&self) -> &DataFrame; } @@ -94,6 +91,8 @@ pub trait DataFrameOps: IntoDf { separator: Option<&str>, drop_first: bool, ) -> PolarsResult { + use crate::series::ToDummies; + let df = self.to_df(); let set: PlHashSet<&str> = if let Some(columns) = columns { diff 
--git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index e49e24ddf7a8..47d467f7f7ba 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -1,7 +1,7 @@ use num_traits::Bounded; -use polars_core::prelude::arity::unary_elementwise_values; #[cfg(feature = "dtype-struct")] -use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_ca; +use polars_core::prelude::arity::unary_elementwise_values; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::with_match_physical_numeric_polars_type; diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 9cd85697ebaa..7c0a35db38a1 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,8 +1,8 @@ use std::any::Any; use arrow::array::BinaryArray; +use polars_core::chunked_array::ops::row_encode::_get_rows_encoded_compat_array; use polars_core::prelude::sort::_broadcast_bools; -use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_row::decode::decode_rows_from_binary; diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index a290321f4cf8..8cb4b8cc2387 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -19,17 +19,17 @@ impl AExpr { pub fn to_dtype( &self, schema: &Schema, - ctxt: Context, + ctx: Context, arena: &Arena, ) -> PolarsResult { - self.to_field(schema, ctxt, arena).map(|f| f.dtype) + self.to_field(schema, ctx, arena).map(|f| f.dtype) } /// Get Field result of the expression. The schema is the input data. 
pub fn to_field( &self, schema: &Schema, - ctxt: Context, + ctx: Context, arena: &Arena, ) -> PolarsResult { // During aggregation a column that isn't aggregated gets an extra nesting level @@ -37,7 +37,7 @@ impl AExpr { // But not if we do an aggregation: // col(foo: i64).sum() -> i64 // The `nested` keeps track of the nesting we need to add. - let mut nested = matches!(ctxt, Context::Aggregation) as u8; + let mut nested = matches!(ctx, Context::Aggregation) as u8; let mut field = self.to_field_impl(schema, arena, &mut nested)?; if nested >= 1 { diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 7163e18de165..1697a5571d4e 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -85,7 +85,7 @@ impl AExpr { } } - pub(crate) fn replace_inputs(mut self, inputs: &[Node]) -> Self { + pub fn replace_inputs(mut self, inputs: &[Node]) -> Self { use AExpr::*; let input = match &mut self { Column(_) | Literal(_) | Len => return self, diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index c60e614350cd..04c320e33ec3 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -66,6 +66,9 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataTyp .collect(); StructArray::new(dtype.clone(), rows.len(), values, None).to_boxed() }, + ArrowDataType::List { .. } | ArrowDataType::LargeList { .. 
} => { + todo!("list decoding is not yet supported in polars' row encoding") + }, dt => { with_match_arrow_primitive_type!(dt, |$T| { decode_primitive::<$T>(rows, field).to_boxed() diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index bd8a12dfbd01..fa9699c72f1b 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -260,6 +260,7 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE .map(|opt_s| opt_s.map(|s| s.as_bytes())); crate::variable::encode_iter(iter, out, field) }, + ArrowDataType::Null => {}, // No output needed. dt => { with_match_arrow_primitive_type!(dt, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); @@ -286,6 +287,7 @@ pub fn encoded_size(dtype: &ArrowDataType) -> usize { Float32 => f32::ENCODED_LEN, Float64 => f64::ENCODED_LEN, Boolean => bool::ENCODED_LEN, + Null => 0, dt => unimplemented!("{dt:?}"), } } diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index 4ed5ac3342c9..ddf4e7be4f18 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -39,3 +39,4 @@ version_check = { workspace = true } nightly = [] bitwise = ["polars-core/bitwise", "polars-plan/bitwise"] merge_sorted = ["polars-plan/merge_sorted"] +dynamic_group_by = [] diff --git a/crates/polars-stream/src/nodes/group_by.rs b/crates/polars-stream/src/nodes/group_by.rs new file mode 100644 index 000000000000..6954263bb99b --- /dev/null +++ b/crates/polars-stream/src/nodes/group_by.rs @@ -0,0 +1,221 @@ +use std::sync::Arc; + +use polars_core::prelude::IntoColumn; +use polars_core::schema::Schema; +use polars_expr::groups::Grouper; +use polars_expr::reduce::GroupedReduction; + +use super::compute_node_prelude::*; +use crate::async_primitives::connector::Receiver; +use crate::expression::StreamExpr; +use crate::nodes::in_memory_source::InMemorySourceNode; + +struct LocalGroupBySinkState { + grouper: Box, + grouped_reductions: Vec>, +} 
+ +struct GroupBySinkState { + key_selectors: Vec, + grouped_reduction_selectors: Vec, + grouper: Box, + grouped_reductions: Vec>, + local: Vec, +} + +impl GroupBySinkState { + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + receivers: Vec>, + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(receivers.len() >= self.local.len()); + self.local + .resize_with(receivers.len(), || LocalGroupBySinkState { + grouper: self.grouper.new_empty(), + grouped_reductions: self + .grouped_reductions + .iter() + .map(|r| r.new_empty()) + .collect(), + }); + for (mut recv, local) in receivers.into_iter().zip(&mut self.local) { + let key_selectors = &self.key_selectors; + let grouped_reduction_selectors = &self.grouped_reduction_selectors; + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + let mut group_idxs = Vec::new(); + while let Ok(morsel) = recv.recv().await { + // Compute group indices from key. + let df = morsel.into_df(); + let mut key_columns = Vec::new(); + for selector in key_selectors { + let s = selector.evaluate(&df, state).await?; + key_columns.push(s.into_column()); + } + let keys = DataFrame::new_with_broadcast_len(key_columns, df.height())?; + local.grouper.insert_keys(&keys, &mut group_idxs); + + // Update reductions. + for (selector, reduction) in grouped_reduction_selectors + .iter() + .zip(&mut local.grouped_reductions) + { + unsafe { + // SAFETY: we resize the reduction to the number of groups beforehand. + reduction.resize(local.grouper.num_groups()); + reduction.update_groups( + &selector.evaluate(&df, state).await?, + &group_idxs, + )?; + } + } + } + Ok(()) + })); + } + } + + fn into_source(mut self, output_schema: &Schema) -> PolarsResult { + // TODO: parallelize this with partitions. 
+ let mut group_idxs = Vec::new(); + let num_pipelines = self.local.len(); + let mut combined = self.local.pop().unwrap(); + for local in self.local { + combined.grouper.combine(&*local.grouper, &mut group_idxs); + for (l, r) in combined + .grouped_reductions + .iter_mut() + .zip(&local.grouped_reductions) + { + unsafe { + l.resize(combined.grouper.num_groups()); + l.combine(&**r, &group_idxs)?; + } + } + } + let mut out = combined.grouper.get_keys_in_group_order(); + let out_names = output_schema.iter_names().skip(out.width()); + for (mut r, name) in combined.grouped_reductions.into_iter().zip(out_names) { + unsafe { + out.with_column_unchecked(r.finalize()?.with_name(name.clone()).into_column()); + } + } + let mut source_node = InMemorySourceNode::new(Arc::new(out)); + source_node.initialize(num_pipelines); + Ok(source_node) + } +} + +enum GroupByState { + Sink(GroupBySinkState), + Source(InMemorySourceNode), + Done, +} + +pub struct GroupByNode { + state: GroupByState, + output_schema: Arc, +} + +impl GroupByNode { + pub fn new( + key_selectors: Vec, + grouped_reduction_selectors: Vec, + grouped_reductions: Vec>, + grouper: Box, + output_schema: Arc, + ) -> Self { + Self { + state: GroupByState::Sink(GroupBySinkState { + key_selectors, + grouped_reduction_selectors, + grouped_reductions, + grouper, + local: Vec::new(), + }), + output_schema, + } + } +} + +impl ComputeNode for GroupByNode { + fn name(&self) -> &str { + "group_by" + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(recv.len() == 1 && send.len() == 1); + + // State transitions. + match &mut self.state { + // If the output doesn't want any more data, transition to being done. + _ if send[0] == PortState::Done => { + self.state = GroupByState::Done; + }, + // Input is done, transition to being a source. 
+ GroupByState::Sink(_) if matches!(recv[0], PortState::Done) => { + let GroupByState::Sink(sink) = + core::mem::replace(&mut self.state, GroupByState::Done) + else { + unreachable!() + }; + self.state = GroupByState::Source(sink.into_source(&self.output_schema)?); + }, + // Defer to source node implementation. + GroupByState::Source(src) => { + src.update_state(&mut [], send)?; + if send[0] == PortState::Done { + self.state = GroupByState::Done; + } + }, + // Nothing to change. + GroupByState::Done | GroupByState::Sink(_) => {}, + } + + // Communicate our state. + match &self.state { + GroupByState::Sink { .. } => { + send[0] = PortState::Blocked; + recv[0] = PortState::Ready; + }, + GroupByState::Source(..) => { + recv[0] = PortState::Done; + send[0] = PortState::Ready; + }, + GroupByState::Done => { + recv[0] = PortState::Done; + send[0] = PortState::Done; + }, + } + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv: &mut [Option>], + send: &mut [Option>], + state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(send.len() == 1 && recv.len() == 1); + match &mut self.state { + GroupByState::Sink(sink) => { + assert!(send[0].is_none()); + sink.spawn( + scope, + recv[0].take().unwrap().parallel(), + state, + join_handles, + ) + }, + GroupByState::Source(source) => { + assert!(recv[0].is_none()); + source.spawn(scope, &mut [], send, state, join_handles); + }, + GroupByState::Done => unreachable!(), + } + } +} diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs index 82ad0f8293e9..af13c9548f59 100644 --- a/crates/polars-stream/src/nodes/mod.rs +++ b/crates/polars-stream/src/nodes/mod.rs @@ -1,4 +1,5 @@ pub mod filter; +pub mod group_by; pub mod in_memory_map; pub mod in_memory_sink; pub mod in_memory_source; diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 21dd8e9dd634..32def086f3bc 100644 --- 
a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -183,6 +183,17 @@ fn visualize_plan_rec( (out, &[][..]) }, + PhysNodeKind::GroupBy { input, key, aggs } => { + let label = "group-by"; + ( + format!( + "{label}\\nkey:\\n{}\\naggs:\\n{}", + fmt_exprs(key, expr_arena), + fmt_exprs(aggs, expr_arena) + ), + from_ref(input), + ) + }, }; out.push(format!( diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 6b61a98979f2..1a8c58ae392b 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -664,7 +664,7 @@ fn lower_exprs_with_ctx( /// Computes the schema that selecting the given expressions on the input schema /// would result in. -fn compute_output_schema( +pub fn compute_output_schema( input_schema: &Schema, exprs: &[ExprIR], expr_arena: &Arena, diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 74321642f380..c25083c412b5 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -2,16 +2,37 @@ use std::sync::Arc; use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; use polars_core::schema::Schema; -use polars_error::PolarsResult; +use polars_error::{polars_ensure, PolarsResult}; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; -use polars_plan::plans::{AExpr, FunctionIR, IR}; +use polars_plan::plans::{AExpr, FunctionIR, IRAggExpr, IR}; use polars_plan::prelude::SinkType; use polars_utils::arena::{Arena, Node}; use polars_utils::itertools::Itertools; use slotmap::SlotMap; use super::{PhysNode, PhysNodeKey, PhysNodeKind}; -use crate::physical_plan::lower_expr::{is_elementwise, ExprCache}; +use crate::physical_plan::lower_expr::{build_select_node, is_elementwise, lower_exprs, ExprCache}; + +fn build_slice_node( + input: 
PhysNodeKey, + offset: i64, + length: usize, + phys_sm: &mut SlotMap, +) -> PhysNodeKey { + if offset >= 0 { + let offset = offset as usize; + phys_sm.insert(PhysNode::new( + phys_sm[input].output_schema.clone(), + PhysNodeKind::StreamingSlice { + input, + offset, + length, + }, + )) + } else { + todo!() + } +} #[recursive::recursive] pub fn lower_ir( @@ -22,19 +43,26 @@ pub fn lower_ir( schema_cache: &mut PlHashMap>, expr_cache: &mut ExprCache, ) -> PolarsResult { - let ir_node = ir_arena.get(node); - let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); - let node_kind = match ir_node { - IR::SimpleProjection { input, columns } => { - let columns = columns.iter_names_cloned().collect::>(); - let phys_input = lower_ir( - *input, + // Helper macro to simplify recursive calls. + macro_rules! lower_ir { + ($input:expr) => { + lower_ir( + $input, ir_arena, expr_arena, phys_sm, schema_cache, expr_cache, - )?; + ) + }; + } + + let ir_node = ir_arena.get(node); + let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); + let node_kind = match ir_node { + IR::SimpleProjection { input, columns } => { + let columns = columns.iter_names_cloned().collect::>(); + let phys_input = lower_ir!(*input)?; PhysNodeKind::SimpleProjection { input: phys_input, columns, @@ -43,17 +71,8 @@ pub fn lower_ir( IR::Select { input, expr, .. } => { let selectors = expr.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; - return super::lower_expr::build_select_node( - phys_input, &selectors, expr_arena, phys_sm, expr_cache, - ); + let phys_input = lower_ir!(*input)?; + return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::HStack { input, exprs, .. } @@ -63,14 +82,7 @@ pub fn lower_ir( { // FIXME: constant literal columns should be broadcasted with hstack. 
let selectors = exprs.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; PhysNodeKind::Select { input: phys_input, selectors, @@ -84,14 +96,7 @@ pub fn lower_ir( // // FIXME: constant literal columns should be broadcasted with hstack. let exprs = exprs.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; let input_schema = &phys_sm[phys_input].output_schema; let mut selectors = PlIndexMap::with_capacity(input_schema.len() + exprs.len()); for name in input_schema.iter_names() { @@ -106,43 +111,19 @@ pub fn lower_ir( selectors.insert(expr.output_name().clone(), expr); } let selectors = selectors.into_values().collect_vec(); - return super::lower_expr::build_select_node( - phys_input, &selectors, expr_arena, phys_sm, expr_cache, - ); + return build_select_node(phys_input, &selectors, expr_arena, phys_sm, expr_cache); }, IR::Slice { input, offset, len } => { - if *offset >= 0 { - let offset = *offset as usize; - let length = *len as usize; - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; - PhysNodeKind::StreamingSlice { - input: phys_input, - offset, - length, - } - } else { - todo!() - } + let offset = *offset; + let len = *len as usize; + let phys_input = lower_ir!(*input)?; + return Ok(build_slice_node(phys_input, offset, len, phys_sm)); }, IR::Filter { input, predicate } => { let predicate = predicate.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; let cols_and_predicate = output_schema .iter_names() .cloned() @@ -154,7 +135,7 @@ pub fn lower_ir( }) .chain([predicate]) .collect_vec(); - let (trans_input, mut trans_cols_and_predicate) = super::lower_expr::lower_exprs( + let 
(trans_input, mut trans_cols_and_predicate) = lower_exprs( phys_input, &cols_and_predicate, expr_arena, @@ -170,7 +151,7 @@ pub fn lower_ir( let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); trans_cols_and_predicate.pop(); // Remove predicate. - return super::lower_expr::build_select_node( + return build_select_node( post_filter, &trans_cols_and_predicate, expr_arena, @@ -223,14 +204,7 @@ pub fn lower_ir( IR::Sink { input, payload } => { if *payload == SinkType::Memory { - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; PhysNodeKind::InMemorySink { input: phys_input } } else { todo!() @@ -246,14 +220,7 @@ pub fn lower_ir( } let function = function.clone(); - let phys_input = lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?; + let phys_input = lower_ir!(*input)?; match function { FunctionIR::RowIndex { @@ -293,14 +260,7 @@ pub fn lower_ir( by_column: by_column.clone(), slice: *slice, sort_options: sort_options.clone(), - input: lower_ir( - *input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - )?, + input: lower_ir!(*input)?, }, IR::Union { inputs, options } => { @@ -311,16 +271,7 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| { - lower_ir( - input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - ) - }) + .map(|input| lower_ir!(input)) .collect::>()?; PhysNodeKind::OrderedUnion { inputs } }, @@ -333,16 +284,7 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| { - lower_ir( - input, - ir_arena, - expr_arena, - phys_sm, - schema_cache, - expr_cache, - ) - }) + .map(|input| lower_ir!(input)) .collect::>()?; PhysNodeKind::Zip { inputs, @@ -378,7 +320,82 @@ pub fn lower_ir( IR::PythonScan { .. } => todo!(), IR::Reduce { .. 
} => todo!(), IR::Cache { .. } => todo!(), - IR::GroupBy { .. } => todo!(), + IR::GroupBy { + input, + keys, + aggs, + schema: _, + apply, + maintain_order, + options, + } => { + if apply.is_some() || *maintain_order { + todo!() + } + + #[cfg(feature = "dynamic_group_by")] + if options.dynamic.is_some() || options.rolling.is_some() { + todo!() + } + + let key = keys.clone(); + let mut aggs = aggs.clone(); + let options = options.clone(); + + polars_ensure!(!keys.is_empty(), ComputeError: "at least one key is required in a group_by operation"); + + // TODO: allow all aggregates. + let mut input_exprs = key.clone(); + for agg in &aggs { + match expr_arena.get(agg.node()) { + AExpr::Agg(expr) => match expr { + IRAggExpr::Min { input, .. } + | IRAggExpr::Max { input, .. } + | IRAggExpr::Mean(input) + | IRAggExpr::Sum(input) => { + if is_elementwise(*input, expr_arena, expr_cache) { + input_exprs.push(ExprIR::from_node(*input, expr_arena)); + } else { + todo!() + } + }, + _ => todo!(), + }, + AExpr::Len => input_exprs.push(key[0].clone()), // Hack, use the first key column for the length. + _ => todo!(), + } + } + + let phys_input = lower_ir!(*input)?; + let (trans_input, trans_exprs) = + lower_exprs(phys_input, &input_exprs, expr_arena, phys_sm, expr_cache)?; + let trans_key = trans_exprs[..key.len()].to_vec(); + let trans_aggs = aggs + .iter_mut() + .zip(trans_exprs.iter().skip(key.len())) + .map(|(agg, trans_expr)| { + let old_expr = expr_arena.get(agg.node()).clone(); + let new_expr = old_expr.replace_inputs(&[trans_expr.node()]); + ExprIR::new(expr_arena.add(new_expr), agg.output_name_inner().clone()) + }) + .collect(); + + let mut node = phys_sm.insert(PhysNode::new( + output_schema, + PhysNodeKind::GroupBy { + input: trans_input, + key: trans_key, + aggs: trans_aggs, + }, + )); + + // TODO: actually limit number of groups instead of computing full + // result and then slicing. 
+ if let Some((offset, len)) = options.slice { + node = build_slice_node(node, offset, len, phys_sm); + } + return Ok(node); + }, IR::Join { .. } => todo!(), IR::Distinct { .. } => todo!(), IR::ExtContext { .. } => todo!(), diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index eddbc87bda99..691ef8e672d7 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -136,6 +136,12 @@ pub enum PhysNodeKind { scan_type: FileScan, file_options: FileScanOptions, }, + + GroupBy { + input: PhysNodeKey, + key: Vec, + aggs: Vec, + }, } #[recursive::recursive] @@ -179,7 +185,8 @@ fn insert_multiplexers( | PhysNodeKind::InMemoryMap { input, .. } | PhysNodeKind::Map { input, .. } | PhysNodeKind::Sort { input, .. } - | PhysNodeKind::Multiplexer { input } => { + | PhysNodeKind::Multiplexer { input } + | PhysNodeKind::GroupBy { input, .. } => { insert_multiplexers(*input, phys_sm, referenced); }, diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 34692aa10b9a..b70ce9508f38 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use parking_lot::Mutex; -use polars_core::schema::Schema; +use polars_core::schema::{Schema, SchemaExt}; use polars_error::PolarsResult; +use polars_expr::groups::new_hash_grouper; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; use polars_expr::reduce::into_reduction; use polars_expr::state::ExecutionState; @@ -20,6 +21,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; +use crate::physical_plan::lower_expr::compute_output_schema; use crate::utils::late_materialized_df::LateMaterializedDataFrame; fn 
has_potential_recurring_entrance(node: Node, arena: &Arena) -> bool { @@ -349,6 +351,46 @@ fn to_graph_rec<'a>( } } }, + + GroupBy { input, key, aggs } => { + let input_key = to_graph_rec(*input, ctx)?; + + let input_schema = &ctx.phys_sm[*input].output_schema; + let key_schema = compute_output_schema(input_schema, key, ctx.expr_arena)? + .materialize_unknown_dtypes()?; + let random_state = Default::default(); + let grouper = new_hash_grouper(Arc::new(key_schema), random_state); + + let key_selectors = key + .iter() + .map(|e| create_stream_expr(e, ctx, input_schema)) + .try_collect_vec()?; + + let mut grouped_reductions = Vec::new(); + let mut grouped_reduction_selectors = Vec::new(); + for agg in aggs { + let (reduction, input_node) = + into_reduction(agg.node(), ctx.expr_arena, input_schema)?; + let selector = create_stream_expr( + &ExprIR::from_node(input_node, ctx.expr_arena), + ctx, + input_schema, + )?; + grouped_reductions.push(reduction); + grouped_reduction_selectors.push(selector); + } + + ctx.graph.add_node( + nodes::group_by::GroupByNode::new( + key_selectors, + grouped_reduction_selectors, + grouped_reductions, + grouper, + node.output_schema.clone(), + ), + [input_key], + ) + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key); diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index 80730273672e..225c0b97553c 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -19,6 +19,7 @@ pytestmark = pytest.mark.xdist_group("streaming") +@pytest.mark.may_fail_auto_streaming def test_streaming_categoricals_5921() -> None: with pl.StringCache(): out_lazy = ( @@ -74,6 +75,7 @@ def test_streaming_streamable_functions(monkeypatch: Any, capfd: Any) -> None: @pytest.mark.slow +@pytest.mark.may_fail_auto_streaming def test_cross_join_stack() -> None: a = pl.Series(np.arange(100_000)).to_frame().lazy() t0 = time.time()