Skip to content

Commit

Permalink
feat: Support arithmetic between Series with dtype list (#17823)
Browse files Browse the repository at this point in the history
Co-authored-by: Itamar Turner-Trauring <[email protected]>
  • Loading branch information
itamarst and pythonspeed authored Sep 23, 2024
1 parent fa84fc0 commit 341df85
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 24 deletions.
177 changes: 177 additions & 0 deletions crates/polars-core/src/series/arithmetic/list_borrowed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
//! Allow arithmetic operations for ListChunked.
use super::*;
use crate::chunked_array::builder::AnonymousListBuilder;

/// Given an ArrayRef with some primitive values, wrap it in list(s) until it
/// matches the requested shape.
fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef {
if let Some(list_chunk) = shape.as_any().downcast_ref::<LargeListArray>() {
let result = LargeListArray::new(
list_chunk.dtype().clone(),
list_chunk.offsets().clone(),
reshape_list_based_on(data, list_chunk.values()),
list_chunk.validity().cloned(),
);
Box::new(result)
} else {
data.clone()
}
}

/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or
/// more nulls.
fn does_list_have_nulls(data: &ArrayRef) -> bool {
if let Some(list_chunk) = data.as_any().downcast_ref::<LargeListArray>() {
if list_chunk
.validity()
.map(|bitmap| bitmap.unset_bits() > 0)
.unwrap_or(false)
{
true
} else {
does_list_have_nulls(list_chunk.values())
}
} else {
false
}
}

/// Return whether the left and right have the same shape. We assume neither has
/// any nulls, recursively.
fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool {
debug_assert!(!does_list_have_nulls(left));
debug_assert!(!does_list_have_nulls(right));
let left_as_list = left.as_any().downcast_ref::<LargeListArray>();
let right_as_list = right.as_any().downcast_ref::<LargeListArray>();
match (left_as_list, right_as_list) {
(Some(left), Some(right)) => {
left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values())
},
(None, None) => left.len() == right.len(),
_ => false,
}
}

impl ListChunked {
/// Helper function for NumOpsDispatchInner implementation for ListChunked.
///
/// Run the given `op` on `self` and `rhs`.
fn arithm_helper(
&self,
rhs: &Series,
op: &dyn Fn(&Series, &Series) -> PolarsResult<Series>,
has_nulls: Option<bool>,
) -> PolarsResult<Series> {
polars_ensure!(
self.len() == rhs.len(),
InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}",
self.len(),
rhs.len()
);

let mut has_nulls = has_nulls.unwrap_or(false);
if !has_nulls {
for chunk in self.chunks().iter() {
if does_list_have_nulls(chunk) {
has_nulls = true;
break;
}
}
}
if !has_nulls {
for chunk in rhs.chunks().iter() {
if does_list_have_nulls(chunk) {
has_nulls = true;
break;
}
}
}
if has_nulls {
// A slower implementation since we can't just add the underlying
// values Arrow arrays. Given nulls, the two values arrays might not
// line up the way we expect.
let mut result = AnonymousListBuilder::new(
self.name().clone(),
self.len(),
Some(self.inner_dtype().clone()),
);
let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| {
let (Some(a_owner), Some(b_owner)) = (a, b) else {
// Operations with nulls always result in nulls:
return Ok(None);
};
let a = a_owner.as_ref();
let b = b_owner.as_ref();
polars_ensure!(
a.len() == b.len(),
InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}",
a.len(),
b.len()
);
let chunk_result = if let Ok(a_listchunked) = a.list() {
// If `a` contains more lists, we're going to reach this
// function recursively, and again have to decide whether to
// use the fast path (no nulls) or slow path (there were
// nulls). Since we know there were nulls, that means we
// have to stick to the slow path, so pass that information
// along.
a_listchunked.arithm_helper(b, op, Some(true))
} else {
op(a, b)
};
chunk_result.map(Some)
}).collect::<PolarsResult<Vec<Option<Series>>>>()?;
for s in combined.iter() {
if let Some(s) = s {
result.append_series(s)?;
} else {
result.append_null();
}
}
return Ok(result.finish().into());
}
let l_rechunked = self.clone().rechunk().into_series();
let l_leaf_array = l_rechunked.get_leaf_array();
let r_leaf_array = rhs.rechunk().get_leaf_array();
polars_ensure!(
lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]),
InvalidOperation: "can only do arithmetic operations on lists of the same size"
);

let result = op(&l_leaf_array, &r_leaf_array)?;

// We now need to wrap the Arrow arrays with the metadata that turns
// them into lists:
// TODO is there a way to do this without cloning the underlying data?
let result_chunks = result.chunks();
assert_eq!(result_chunks.len(), 1);
let left_chunk = &l_rechunked.chunks()[0];
let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk);

unsafe {
let mut result =
ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0);
result.compute_len();
Ok(result.into())
}
}
}

impl NumOpsDispatchInner for ListType {
fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None)
}
fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None)
}
fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None)
}
fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.divide(r), None)
}
fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None)
}
}
1 change: 1 addition & 0 deletions crates/polars-core/src/series/arithmetic/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod borrowed;
mod list_borrowed;
mod owned;

use std::borrow::Cow;
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ impl private::PrivateSeries for SeriesWrap<ListChunked> {
fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
(&self.0).into_total_eq_inner()
}

fn add_to(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.add_to(rhs)
}

fn subtract(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.subtract(rhs)
}

fn multiply(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.multiply(rhs)
}
fn divide(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.divide(rhs)
}
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.remainder(rhs)
}
}

impl SeriesTrait for SeriesWrap<ListChunked> {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)? / right.cast(&right_dt)?
},
dt @ List(_) => {
let left_dt = dt.cast_leaf(Float64);
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)? / right.cast(&right_dt)?
},
_ => {
if right.dtype().is_temporal() {
return left / right;
Expand Down
14 changes: 8 additions & 6 deletions crates/polars-plan/src/plans/conversion/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ fn process_list_arithmetic(
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
match (&type_left, &type_right) {
(DataType::List(inner), _) => {
if type_right != **inner {
(DataType::List(_), _) => {
let leaf = type_left.leaf_dtype();
if type_right != *leaf {
let new_node_right = expr_arena.add(AExpr::Cast {
expr: node_right,
dtype: *inner.clone(),
dtype: type_left.cast_leaf(leaf.clone()),
options: CastOptions::NonStrict,
});

Expand All @@ -73,11 +74,12 @@ fn process_list_arithmetic(
Ok(None)
}
},
(_, DataType::List(inner)) => {
if type_left != **inner {
(_, DataType::List(_)) => {
let leaf = type_right.leaf_dtype();
if type_left != *leaf {
let new_node_left = expr_arena.add(AExpr::Cast {
expr: node_left,
dtype: *inner.clone(),
dtype: type_right.cast_leaf(leaf.clone()),
options: CastOptions::NonStrict,
});

Expand Down
22 changes: 20 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,22 @@ def __sub__(self, other: Any) -> Self | Expr:
return F.lit(self) - other
return self._arithmetic(other, "sub", "sub_<>")

def _recursive_cast_to_dtype(self, leaf_dtype: PolarsDataType) -> Series:
"""
Convert leaf dtype the to given primitive datatype.
This is equivalent to logic in DataType::cast_leaf() in Rust.
"""

def convert_to_primitive(dtype: PolarsDataType) -> PolarsDataType:
if isinstance(dtype, Array):
return Array(convert_to_primitive(dtype.inner), shape=dtype.shape)
if isinstance(dtype, List):
return List(convert_to_primitive(dtype.inner))
return leaf_dtype

return self.cast(convert_to_primitive(self.dtype))

@overload
def __truediv__(self, other: Expr) -> Expr: ...

Expand All @@ -1073,9 +1089,11 @@ def __truediv__(self, other: Any) -> Series | Expr:

# this branch is exactly the floordiv function without rounding the floats
if self.dtype.is_float() or self.dtype == Decimal:
return self._arithmetic(other, "div", "div_<>")
as_float = self
else:
as_float = self._recursive_cast_to_dtype(Float64())

return self.cast(Float64) / other
return as_float._arithmetic(other, "div", "div_<>")

@overload
def __floordiv__(self, other: Expr) -> Expr: ...
Expand Down
Loading

0 comments on commit 341df85

Please sign in to comment.