Skip to content

Commit

Permalink
Logic for arithmetic between a ListChunked and a numeric Series.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Sep 24, 2024
1 parent e3141fd commit 835f9fa
Showing 1 changed file with 102 additions and 14 deletions.
116 changes: 102 additions & 14 deletions crates/polars-core/src/series/arithmetic/list_borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,112 @@ fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool {
}
}

/// Arithmetic operations that can be applied to a Series
#[derive(Clone, Copy)]
enum Op {
Add,
Subtract,
Multiply,
Divide,
Remainder,
}

impl Op {
/// Apply the operation to a pair of Series.
fn apply_with_series(&self, lhs: &Series, rhs: &Series) -> PolarsResult<Series> {
use Op::*;

match self {
Add => lhs + rhs,
Subtract => lhs - rhs,
Multiply => lhs * rhs,
Divide => lhs / rhs,
Remainder => lhs % rhs,
}
}

/// Apply the operation to a Series and scalar.
fn apply_with_scalar<T: Num + NumCast>(&self, lhs: &Series, rhs: T) -> Series {
use Op::*;

match self {
Add => lhs + rhs,
Subtract => lhs - rhs,
Multiply => lhs * rhs,
Divide => lhs / rhs,
Remainder => lhs % rhs,
}
}
}

impl ListChunked {
/// Helper function for NumOpsDispatchInner implementation for ListChunked.
///
/// Run the given `op` on `self` and `rhs`, for cases where `rhs` has a
/// primitive numeric dtype.
fn arithm_helper_numeric(&self, rhs: &Series, op: Op) -> PolarsResult<Series> {
let mut result = AnonymousListBuilder::new(
self.name().clone(),
self.len(),
Some(self.inner_dtype().clone()),
);
macro_rules! combine {
($ca:expr) => {{
self.amortized_iter()
.zip($ca.iter())
.map(|(a, b)| {
let (Some(a_owner), Some(b)) = (a, b) else {
// Operations with nulls always result in nulls:
return Ok(None);
};
let a = a_owner.as_ref().rechunk();
let leaf_result = op.apply_with_scalar(&a.get_leaf_array(), b);
let result =
reshape_list_based_on(&leaf_result.chunks()[0], &a.chunks()[0]);
Ok(Some(result))
})
.collect::<PolarsResult<Vec<Option<Box<dyn Array>>>>>()?
}};
}
let combined = downcast_as_macro_arg_physical!(rhs, combine);
for arr in combined.iter() {
if let Some(arr) = arr {
result.append_array(arr.as_ref());
} else {
result.append_null();
}
}
Ok(result.finish().into())
}

/// Helper function for NumOpsDispatchInner implementation for ListChunked.
///
/// Run the given `op` on `self` and `rhs`.
fn arithm_helper(
&self,
rhs: &Series,
op: &dyn Fn(&Series, &Series) -> PolarsResult<Series>,
has_nulls: Option<bool>,
) -> PolarsResult<Series> {
fn arithm_helper(&self, rhs: &Series, op: Op, has_nulls: Option<bool>) -> PolarsResult<Series> {
polars_ensure!(
self.dtype().leaf_dtype().is_numeric() && rhs.dtype().leaf_dtype().is_numeric(),
InvalidOperation: "List Series can only do arithmetic operations if they and other Series are numeric, left and right dtypes are {:?} and {:?}",
self.dtype(),
rhs.dtype()
);
polars_ensure!(
self.len() == rhs.len(),
InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}",
self.len(),
rhs.len()
);

if rhs.dtype().is_numeric() {
return self.arithm_helper_numeric(rhs, op);
}

polars_ensure!(
self.dtype() == rhs.dtype(),
InvalidOperation: "List Series doing arithmetic operations to each other should have same dtype; got {:?} and {:?}",
self.dtype(),
rhs.dtype()
);

let mut has_nulls = has_nulls.unwrap_or(false);
if !has_nulls {
for chunk in self.chunks().iter() {
Expand Down Expand Up @@ -118,7 +207,7 @@ impl ListChunked {
// along.
a_listchunked.arithm_helper(b, op, Some(true))
} else {
op(a, b)
op.apply_with_series(a, b)
};
chunk_result.map(Some)
}).collect::<PolarsResult<Vec<Option<Series>>>>()?;
Expand All @@ -139,8 +228,7 @@ impl ListChunked {
InvalidOperation: "can only do arithmetic operations on lists of the same size"
);

let result = op(&l_leaf_array, &r_leaf_array)?;

let result = op.apply_with_series(&l_leaf_array, &r_leaf_array)?;
// We now need to wrap the Arrow arrays with the metadata that turns
// them into lists:
// TODO is there a way to do this without cloning the underlying data?
Expand All @@ -160,18 +248,18 @@ impl ListChunked {

impl NumOpsDispatchInner for ListType {
fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None)
lhs.arithm_helper(rhs, Op::Add, None)
}
fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None)
lhs.arithm_helper(rhs, Op::Subtract, None)
}
fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None)
lhs.arithm_helper(rhs, Op::Multiply, None)
}
fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.divide(r), None)
lhs.arithm_helper(rhs, Op::Divide, None)
}
fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None)
lhs.arithm_helper(rhs, Op::Remainder, None)
}
}

0 comments on commit 835f9fa

Please sign in to comment.