From fc35fbaf1790b111e59432f5374225f3d0f609c2 Mon Sep 17 00:00:00 2001 From: petrosbar Date: Sun, 17 Mar 2024 09:23:18 +0200 Subject: [PATCH 1/2] feat: Implement `Series/Expr.list.product` --- .../polars-ops/src/chunked_array/list/mod.rs | 1 + .../src/chunked_array/list/namespace.rs | 6 ++ .../src/chunked_array/list/product.rs | 76 +++++++++++++++++++ .../polars-plan/src/dsl/function_expr/list.rs | 8 ++ .../src/dsl/function_expr/schema.rs | 13 ++++ crates/polars-plan/src/dsl/list.rs | 6 ++ .../source/reference/expressions/list.rst | 1 + .../docs/source/reference/series/list.rst | 1 + py-polars/polars/expr/list.py | 20 +++++ py-polars/polars/series/list.py | 16 ++++ py-polars/src/expr/list.rs | 9 +++ py-polars/tests/unit/datatypes/test_list.py | 54 +++++++++++++ .../tests/unit/namespaces/list/test_list.py | 1 + 13 files changed, 212 insertions(+) create mode 100644 crates/polars-ops/src/chunked_array/list/product.rs diff --git a/crates/polars-ops/src/chunked_array/list/mod.rs b/crates/polars-ops/src/chunked_array/list/mod.rs index a93b1ed7e2b3..b1e2b7c6c1e3 100644 --- a/crates/polars-ops/src/chunked_array/list/mod.rs +++ b/crates/polars-ops/src/chunked_array/list/mod.rs @@ -8,6 +8,7 @@ mod dispersion; pub(crate) mod hash; mod min_max; mod namespace; +mod product; #[cfg(feature = "list_sets")] mod sets; mod sum_mean; diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 0d511c87967c..d493bad20238 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -2,6 +2,7 @@ use std::fmt::Write; use arrow::array::ValueSize; use arrow::legacy::kernels::list::{index_is_oob, sublist_get}; +use namespace::product::product_with_nulls; use polars_core::chunked_array::builder::get_list_builder; #[cfg(feature = "list_gather")] use polars_core::export::num::ToPrimitive; @@ -204,6 +205,11 @@ pub trait ListNameSpaceImpl: AsList { } } + fn lst_product(&self) -> PolarsResult { + let ca = self.as_list(); + product_with_nulls(ca, &ca.inner_dtype()) + } + fn lst_mean(&self) -> Series { let ca = self.as_list(); diff --git a/crates/polars-ops/src/chunked_array/list/product.rs b/crates/polars-ops/src/chunked_array/list/product.rs new file mode 100644 index 000000000000..4de7ae7f293e --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/product.rs @@ -0,0 +1,76 @@ +use polars_core::export::num::NumCast; + +use super::*; + +fn product(s: &Series) -> PolarsResult +where + T: NumCast, +{ + let prod = s.product()?.cast(&DataType::Float64)?; + Ok(T::from(prod.f64().unwrap().get(0).unwrap()).unwrap()) +} + +pub(super) fn product_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> PolarsResult { + use DataType::*; + let out = match inner_dtype { + Boolean => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + Int8 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + UInt8 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + Int16 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + UInt16 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + Int32 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + UInt32 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + Int64 => { + let out: Int64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + UInt64 => { + let out: UInt64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + Float32 => { + let out: Float32Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + Float64 => { + let out: Float64Chunked = + ca.apply_amortized_generic(|s| s.map(|s| product::(s.as_ref()).unwrap())); + out.into_series() + }, + _ => { + polars_bail!(InvalidOperation: "`list.product` operation not supported for dtype `{inner_dtype}`") + }, + }; + Ok(out.with_name(ca.name())) +} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 3b06841fbb55..5122eab08a5c 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -29,6 +29,7 @@ pub enum ListFunction { #[cfg(feature = "list_count")] CountMatches, Sum, + Product, Length, Max, Min, @@ -79,6 +80,7 @@ impl ListFunction { #[cfg(feature = "list_count")] CountMatches => mapper.with_dtype(IDX_DTYPE), Sum => mapper.nested_sum_type(), + Product => mapper.nested_product_type(), Min => mapper.map_to_list_and_array_inner_dtype(), Max => mapper.map_to_list_and_array_inner_dtype(), Mean => mapper.with_dtype(DataType::Float64), @@ -144,6 +146,7 @@ impl Display for ListFunction { #[cfg(feature = "list_count")] CountMatches => "count_matches", Sum => "sum", + Product => "product", Min => "min", Max => "max", Mean => "mean", @@ -211,6 +214,7 @@ impl From for SpecialEq> { #[cfg(feature = "list_count")] CountMatches => map_as_slice!(count_matches), Sum => map!(sum), + Product => map!(product), Length => map!(length), Max => map!(max), Min => map!(min), @@ -514,6 +518,10 @@ pub(super) fn sum(s: &Series) -> PolarsResult { s.list()?.lst_sum() } +pub(super) fn product(s: &Series) -> PolarsResult { + s.list()?.lst_product() +} + pub(super) fn length(s: &Series) -> PolarsResult { Ok(s.list()?.lst_lengths().into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 597b2003f955..c7fb4d879187 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -480,6 +480,19 @@ impl<'a> FieldsMapper<'a> { Ok(first) } + pub fn nested_product_type(&self) -> PolarsResult { + use DataType::*; + let mut first = self.fields[0].clone(); + let dt = first.data_type().inner_dtype().cloned().unwrap_or(Unknown); + + if matches!(dt, UInt8 | Int8 | UInt16 | Int16 | UInt32 | Int32) { + first.coerce(Int64); + } else { + first.coerce(dt); + } + Ok(first) + } + pub(super) fn pow_dtype(&self) -> PolarsResult { let base_dtype = self.fields[0].data_type(); let exponent_dtype = self.fields[1].data_type(); diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 0f6c15c755e7..024b1e3ad055 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -100,6 +100,12 @@ impl ListNameSpace { .map_private(FunctionExpr::ListExpr(ListFunction::Sum)) } + /// Compute the product of the items in every sublist. + pub fn product(self) -> Expr { + self.0 + .map_private(FunctionExpr::ListExpr(ListFunction::Product)) + } + /// Compute the mean of every sublist and return a `Series` of dtype `Float64` pub fn mean(self) -> Expr { self.0 diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index d168e3976f02..660fae008c44 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -33,6 +33,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.mean Expr.list.median Expr.list.min + Expr.list.product Expr.list.reverse Expr.list.sample Expr.list.set_difference diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index 2398fe0ea24d..8c7aa403e88b 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -33,6 +33,7 @@ The following methods are available under the `Series.list` attribute. Series.list.mean Series.list.median Series.list.min + Series.list.product Series.list.reverse Series.list.sample Series.list.set_difference diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 3c827794ffdb..047f091b6de4 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -215,6 +215,26 @@ def sum(self) -> Expr: """ return wrap_expr(self._pyexpr.list_sum()) + def product(self) -> Expr: + """ + Compute the product of the lists in the array. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[2, 2], [2, 3, 4]]}) + >>> df.with_columns(product=pl.col("values").list.product()) + shape: (2, 2) + ┌───────────┬─────────┐ + │ values ┆ product │ + │ --- ┆ --- │ + │ list[i64] ┆ i64 │ + ╞═══════════╪═════════╡ + │ [2, 2] ┆ 4 │ + │ [2, 3, 4] ┆ 24 │ + └───────────┴─────────┘ + """ + return wrap_expr(self._pyexpr.list_product()) + def max(self) -> Expr: """ Compute the max value of the lists in the array. diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 762d6d55eccb..9ccfc1f9fc60 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -185,6 +185,22 @@ def sum(self) -> Series: ] """ + def product(self) -> Series: + """ + Compute the product of the arrays in the list. + + Examples + -------- + >>> s = pl.Series("values", [[2, 2], [2, 3, 4]]) + >>> s.list.product() + shape: (2,) + Series: 'values' [i64] + [ + 4 + 24 + ] + """ + def max(self) -> Series: """ Compute the max value of the arrays in the list. diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 00b442cae1d3..d74fbc620734 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -144,6 +144,15 @@ impl PyExpr { self.inner.clone().list().sum().with_fmt("list.sum").into() } + fn list_product(&self) -> Self { + self.inner + .clone() + .list() + .product() + .with_fmt("list.product") + .into() + } + #[cfg(feature = "list_drop_nulls")] fn list_drop_nulls(&self) -> Self { self.inner.clone().list().drop_nulls().into() diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 018b8b4de159..fcc1c5502c7a 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -831,3 +831,57 @@ def test_take_list_15719() -> None: ) assert_frame_equal(df, expected) + + +def test_list_product_and_dtypes() -> None: + for dt_in, dt_out in [ + (pl.Int8, pl.Int64), + (pl.Int16, pl.Int64), + (pl.Int32, pl.Int64), + (pl.Int64, pl.Int64), + (pl.UInt8, pl.Int64), + (pl.UInt16, pl.Int64), + (pl.UInt32, pl.Int64), + (pl.UInt64, pl.UInt64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ]: + df = pl.DataFrame( + {"a": [[1], [None, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, + schema={"a": pl.List(dt_in)}, + ) + assert df.select(pl.col("a").list.product()).dtypes == [dt_out] + + # Lists of numerics + assert pl.DataFrame( + {"a": [[1], [2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, + ).select(pl.col("a").list.product()).to_dict(as_series=False) == { + "a": [1, 6, 24, 120] + } + + # Lists of numerics with nulls + assert pl.DataFrame( + {"a": [[1], [None, 2, 3], [1, 2, 3, 4, None], [1, 2, 3, 4, 5]]}, + ).select(pl.col("a").list.product()).to_dict(as_series=False) == { + "a": [1, 6, 24, 120] + } + + # List of booleans + assert pl.DataFrame( + {"a": [[True], [True, True], [True, False], [False, False]]}, + ).select(pl.col("a").list.product()).to_dict(as_series=False) == {"a": [1, 1, 0, 0]} + + # List of booleans with nulls + assert pl.DataFrame( + { + "a": [ + [True], + [True, True], + [True, False], + [True, True, None], + [False, False], + ] + }, + ).select(pl.col("a").list.product()).to_dict(as_series=False) == { + "a": [1, 1, 0, 1, 0] + } diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py index 95a0dc3e0472..993e476afbad 100644 --- a/py-polars/tests/unit/namespaces/list/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -791,6 +791,7 @@ def test_list_arithmetic() -> None: assert_series_equal(s.list.mean(), pl.Series("a", [1.5, 2.0])) assert_series_equal(s.list.max(), pl.Series("a", [2, 3])) assert_series_equal(s.list.min(), pl.Series("a", [1, 1])) + assert_series_equal(s.list.product(), pl.Series("a", [2, 6])) def test_list_ordering() -> None: From c65ffe477d649ba01d0fc0c5bbff0d889869993f Mon Sep 17 00:00:00 2001 From: petrosbar Date: Tue, 19 Mar 2024 08:45:47 +0200 Subject: [PATCH 2/2] test invalid dtype raises --- py-polars/tests/unit/datatypes/test_list.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index fcc1c5502c7a..991e7fa8bfea 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -885,3 +885,11 @@ def test_list_product_and_dtypes() -> None: ).select(pl.col("a").list.product()).to_dict(as_series=False) == { "a": [1, 1, 0, 1, 0] } + + +def test_list_product_invalid_type_raises() -> None: + with pytest.raises( + pl.InvalidOperationError, + match="`list.product` operation not supported for dtype", + ): + pl.Series("a", [["a", "b"]]).list.product()