Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Series/Expr.list.product #15148

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod dispersion;
pub(crate) mod hash;
mod min_max;
mod namespace;
mod product;
#[cfg(feature = "list_sets")]
mod sets;
mod sum_mean;
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::Write;

use arrow::array::ValueSize;
use arrow::legacy::kernels::list::{index_is_oob, sublist_get};
use namespace::product::product_with_nulls;
use polars_core::chunked_array::builder::get_list_builder;
#[cfg(feature = "list_gather")]
use polars_core::export::num::ToPrimitive;
Expand Down Expand Up @@ -204,6 +205,11 @@ pub trait ListNameSpaceImpl: AsList {
}
}

/// Compute the product of the values in every sublist.
///
/// Delegates to `product_with_nulls`, passing the list's inner dtype so the
/// helper can select the output type.
fn lst_product(&self) -> PolarsResult<Series> {
    let list = self.as_list();
    let inner = list.inner_dtype();
    product_with_nulls(list, &inner)
}

fn lst_mean(&self) -> Series {
let ca = self.as_list();

Expand Down
76 changes: 76 additions & 0 deletions crates/polars-ops/src/chunked_array/list/product.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use polars_core::export::num::NumCast;

use super::*;

fn product<T>(s: &Series) -> PolarsResult<T>
where
T: NumCast,
{
let prod = s.product()?.cast(&DataType::Float64)?;
Ok(T::from(prod.f64().unwrap().get(0).unwrap()).unwrap())
}

pub(super) fn product_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> PolarsResult<Series> {
petrosbar marked this conversation as resolved.
Show resolved Hide resolved
use DataType::*;
let out = match inner_dtype {
Boolean => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
Int8 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
UInt8 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
Int16 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
UInt16 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
Int32 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
UInt32 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
Int64 => {
let out: Int64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<i64>(s.as_ref()).unwrap()));
out.into_series()
},
UInt64 => {
let out: UInt64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<u64>(s.as_ref()).unwrap()));
out.into_series()
},
Float32 => {
let out: Float32Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<f32>(s.as_ref()).unwrap()));
out.into_series()
},
Float64 => {
let out: Float64Chunked =
ca.apply_amortized_generic(|s| s.map(|s| product::<f64>(s.as_ref()).unwrap()));
out.into_series()
},
_ => {
polars_bail!(InvalidOperation: "`list.product` operation not supported for dtype `{inner_dtype}`")
},
};
Ok(out.with_name(ca.name()))
}
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub enum ListFunction {
#[cfg(feature = "list_count")]
CountMatches,
Sum,
Product,
Length,
Max,
Min,
Expand Down Expand Up @@ -79,6 +80,7 @@ impl ListFunction {
#[cfg(feature = "list_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
Sum => mapper.nested_sum_type(),
Product => mapper.nested_product_type(),
Min => mapper.map_to_list_and_array_inner_dtype(),
Max => mapper.map_to_list_and_array_inner_dtype(),
Mean => mapper.with_dtype(DataType::Float64),
Expand Down Expand Up @@ -144,6 +146,7 @@ impl Display for ListFunction {
#[cfg(feature = "list_count")]
CountMatches => "count_matches",
Sum => "sum",
Product => "product",
Min => "min",
Max => "max",
Mean => "mean",
Expand Down Expand Up @@ -211,6 +214,7 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "list_count")]
CountMatches => map_as_slice!(count_matches),
Sum => map!(sum),
Product => map!(product),
Length => map!(length),
Max => map!(max),
Min => map!(min),
Expand Down Expand Up @@ -514,6 +518,10 @@ pub(super) fn sum(s: &Series) -> PolarsResult<Series> {
s.list()?.lst_sum()
}

/// Evaluate `list.product` on `s`; fails if `s` is not a list `Series`.
pub(super) fn product(s: &Series) -> PolarsResult<Series> {
    let lists = s.list()?;
    lists.lst_product()
}

pub(super) fn length(s: &Series) -> PolarsResult<Series> {
Ok(s.list()?.lst_lengths().into_series())
}
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,19 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

/// Output field for a product over a nested (list) column.
///
/// Starts from the first input field and coerces it to the dtype the product
/// yields: boolean and integer inner dtypes narrower than 64 bits become
/// `Int64`; every other inner dtype is kept as is. If the first field has no
/// inner dtype the result is `Unknown`.
pub fn nested_product_type(&self) -> PolarsResult<Field> {
    use DataType::*;
    let mut first = self.fields[0].clone();
    let inner = first.data_type().inner_dtype().cloned().unwrap_or(Unknown);

    // `Boolean` is included because the list product kernel also accumulates
    // booleans as `Int64`; without it the planned schema reports `Boolean`
    // while the computed column is `Int64`.
    let output = if matches!(
        inner,
        Boolean | UInt8 | Int8 | UInt16 | Int16 | UInt32 | Int32
    ) {
        Int64
    } else {
        inner
    };
    first.coerce(output);
    Ok(first)
}

pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
let base_dtype = self.fields[0].data_type();
let exponent_dtype = self.fields[1].data_type();
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ impl ListNameSpace {
.map_private(FunctionExpr::ListExpr(ListFunction::Sum))
}

/// Multiply together the items of every sublist.
pub fn product(self) -> Expr {
    let function = FunctionExpr::ListExpr(ListFunction::Product);
    self.0.map_private(function)
}

/// Compute the mean of every sublist and return a `Series` of dtype `Float64`
pub fn mean(self) -> Expr {
self.0
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The following methods are available under the `expr.list` attribute.
Expr.list.mean
Expr.list.median
Expr.list.min
Expr.list.product
Expr.list.reverse
Expr.list.sample
Expr.list.set_difference
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The following methods are available under the `Series.list` attribute.
Series.list.mean
Series.list.median
Series.list.min
Series.list.product
Series.list.reverse
Series.list.sample
Series.list.set_difference
Expand Down
20 changes: 20 additions & 0 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,26 @@ def sum(self) -> Expr:
"""
return wrap_expr(self._pyexpr.list_sum())

def product(self) -> Expr:
"""
Compute the product of the values in each sublist.

Null values inside a sublist are skipped. Boolean sublists and integer
sublists other than UInt64 give an Int64 result; UInt64 and float sublists
keep their own dtype.

Examples
--------
>>> df = pl.DataFrame({"values": [[2, 2], [2, 3, 4]]})
>>> df.with_columns(product=pl.col("values").list.product())
shape: (2, 2)
┌───────────┬─────────┐
│ values ┆ product │
│ --- ┆ --- │
│ list[i64] ┆ i64 │
╞═══════════╪═════════╡
│ [2, 2] ┆ 4 │
│ [2, 3, 4] ┆ 24 │
└───────────┴─────────┘
"""
return wrap_expr(self._pyexpr.list_product())

def max(self) -> Expr:
"""
Compute the max value of the lists in the array.
Expand Down
16 changes: 16 additions & 0 deletions py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,22 @@ def sum(self) -> Series:
]
"""

def product(self) -> Series:
"""
Compute the product of the values in each sublist.

Null values inside a sublist are skipped. Boolean sublists and integer
sublists other than UInt64 give an Int64 result; UInt64 and float sublists
keep their own dtype.

Examples
--------
>>> s = pl.Series("values", [[2, 2], [2, 3, 4]])
>>> s.list.product()
shape: (2,)
Series: 'values' [i64]
[
4
24
]
"""

def max(self) -> Series:
"""
Compute the max value of the arrays in the list.
Expand Down
9 changes: 9 additions & 0 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,15 @@ impl PyExpr {
self.inner.clone().list().sum().with_fmt("list.sum").into()
}

/// Build the `list.product` expression from a copy of the wrapped expression.
fn list_product(&self) -> Self {
    let expr = self.inner.clone();
    let product = expr.list().product();
    product.with_fmt("list.product").into()
}

#[cfg(feature = "list_drop_nulls")]
fn list_drop_nulls(&self) -> Self {
self.inner.clone().list().drop_nulls().into()
Expand Down
62 changes: 62 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,3 +831,65 @@ def test_take_list_15719() -> None:
)

assert_frame_equal(df, expected)


def test_list_product_and_dtypes() -> None:
# Each pair is (inner dtype of the list column, expected dtype of the product).
# Integer types narrower than 64 bits are expected to come back as Int64;
# Int64, UInt64 and the float types are expected to keep their dtype.
for dt_in, dt_out in [
(pl.Int8, pl.Int64),
(pl.Int16, pl.Int64),
(pl.Int32, pl.Int64),
(pl.Int64, pl.Int64),
(pl.UInt8, pl.Int64),
(pl.UInt16, pl.Int64),
(pl.UInt32, pl.Int64),
(pl.UInt64, pl.UInt64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
]:
df = pl.DataFrame(
{"a": [[1], [None, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]},
schema={"a": pl.List(dt_in)},
)
assert df.select(pl.col("a").list.product()).dtypes == [dt_out]

# Lists of numerics
assert pl.DataFrame(
{"a": [[1], [2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]},
).select(pl.col("a").list.product()).to_dict(as_series=False) == {
"a": [1, 6, 24, 120]
}

# Lists of numerics with nulls
# The expected values match the null-free case above: nulls do not change the product.
assert pl.DataFrame(
{"a": [[1], [None, 2, 3], [1, 2, 3, 4, None], [1, 2, 3, 4, 5]]},
).select(pl.col("a").list.product()).to_dict(as_series=False) == {
"a": [1, 6, 24, 120]
}

# List of booleans
# Booleans multiply as integers: any False in a sublist gives 0, otherwise 1.
assert pl.DataFrame(
{"a": [[True], [True, True], [True, False], [False, False]]},
).select(pl.col("a").list.product()).to_dict(as_series=False) == {"a": [1, 1, 0, 0]}

# List of booleans with nulls
assert pl.DataFrame(
{
"a": [
[True],
[True, True],
[True, False],
[True, True, None],
[False, False],
]
},
).select(pl.col("a").list.product()).to_dict(as_series=False) == {
"a": [1, 1, 0, 1, 0]
}


def test_list_product_invalid_type_raises() -> None:
# A list of strings has no product, so the call must raise
# InvalidOperationError with the `list.product` message.
with pytest.raises(
pl.InvalidOperationError,
match="`list.product` operation not supported for dtype",
):
pl.Series("a", [["a", "b"]]).list.product()
1 change: 1 addition & 0 deletions py-polars/tests/unit/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def test_list_arithmetic() -> None:
assert_series_equal(s.list.mean(), pl.Series("a", [1.5, 2.0]))
assert_series_equal(s.list.max(), pl.Series("a", [2, 3]))
assert_series_equal(s.list.min(), pl.Series("a", [1, 1]))
assert_series_equal(s.list.product(), pl.Series("a", [2, 6]))


def test_list_ordering() -> None:
Expand Down
Loading