From 901b2437cabdc66998642d4d761bc5d36053b720 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 19:36:29 +0200 Subject: [PATCH] perf: Use List's TotalEqKernel (#18984) --- crates/polars-compute/src/comparisons/list.rs | 285 ++++++++++++++---- .../src/chunked_array/comparison/mod.rs | 183 +++++++---- .../tests/unit/operations/test_explode.py | 11 +- 3 files changed, 364 insertions(+), 115 deletions(-) diff --git a/crates/polars-compute/src/comparisons/list.rs b/crates/polars-compute/src/comparisons/list.rs index fa35cbaac9b6..cd0414b7cca8 100644 --- a/crates/polars-compute/src/comparisons/list.rs +++ b/crates/polars-compute/src/comparisons/list.rs @@ -1,86 +1,257 @@ -use arrow::array::ListArray; -use arrow::bitmap::{Bitmap, MutableBitmap}; -use arrow::types::Offset; +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::{days_ms, f16, i256, months_days_ns, Offset}; use super::TotalEqKernel; -use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; -impl TotalEqKernel for ListArray { - type Scalar = (); +macro_rules! compare { + ( + $lhs:expr, $rhs:expr, + $op:path, $true_op:expr, + $ineq_len_rv:literal, $invalid_rv:literal + ) => {{ + let lhs = $lhs; + let rhs = $rhs; - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - assert_eq!(self.len(), other.len()); + assert_eq!(lhs.len(), rhs.len()); + assert_eq!(lhs.dtype(), rhs.dtype()); - let mut bitmap = MutableBitmap::with_capacity(self.len()); + macro_rules! call_binary { + ($T:ty) => {{ + let lhs_values: &$T = $lhs.values().as_any().downcast_ref().unwrap(); + let rhs_values: &$T = $rhs.values().as_any().downcast_ref().unwrap(); - for i in 0..self.len() { - let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); - let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + (0..$lhs.len()) + .map(|i| { + let lval = $lhs.validity().map_or(true, |v| v.get(i).unwrap()); + let rval = $rhs.validity().map_or(true, |v| v.get(i).unwrap()); - if !lval || !rval { - bitmap.push(true); - continue; - } + if !lval || !rval { + return $invalid_rv; + } - let (lstart, lend) = self.offsets().start_end(i); - let (rstart, rend) = other.offsets().start_end(i); + // SAFETY: ListArray's invariant offsets.len_proxy() == len + let (lstart, lend) = unsafe { $lhs.offsets().start_end_unchecked(i) }; + let (rstart, rend) = unsafe { $rhs.offsets().start_end_unchecked(i) }; - if lend - lstart != rend - rstart { - bitmap.push(false); - continue; - } + if lend - lstart != rend - rstart { + return $ineq_len_rv; + } - let mut lhs_values = self.values().clone(); - lhs_values.slice(lstart, lend - lstart); - let mut rhs_values = other.values().clone(); - rhs_values.slice(rstart, rend - rstart); + let mut lhs_values = lhs_values.clone(); + lhs_values.slice(lstart, lend - lstart); + let mut rhs_values = rhs_values.clone(); + rhs_values.slice(rstart, rend - rstart); - let result = array_tot_eq_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); - bitmap.push(result.unset_bits() == 0); + $true_op($op(&lhs_values, &rhs_values)) + }) + .collect_trusted() + }}; } - bitmap.freeze() - } + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.values().dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - assert_eq!(self.len(), other.len()); + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), - let mut bitmap = MutableBitmap::with_capacity(self.len()); + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} - for i in 0..self.len() { - let (lstart, lend) = self.offsets().start_end(i); - let (rstart, rend) = other.offsets().start_end(i); +macro_rules! compare_broadcast { + ( + $lhs:expr, $rhs:expr, + $offsets:expr, $validity:expr, + $op:path, $true_op:expr, + $ineq_len_rv:literal, $invalid_rv:literal + ) => {{ + let lhs = $lhs; + let rhs = $rhs; - let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); - let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + macro_rules! call_binary { + ($T:ty) => {{ + let values: &$T = $lhs.as_any().downcast_ref().unwrap(); + let scalar: &$T = $rhs.as_any().downcast_ref().unwrap(); - if !lval || !rval { - bitmap.push(false); - continue; - } + let length = $offsets.len_proxy(); - if lend - lstart != rend - rstart { - bitmap.push(true); - continue; - } + (0..length) + .map(move |i| { + let v = $validity.map_or(true, |v| v.get(i).unwrap()); - let mut lhs_values = self.values().clone(); - lhs_values.slice(lstart, lend - lstart); - let mut rhs_values = self.values().clone(); - rhs_values.slice(rstart, rend - rstart); + if !v { + return $invalid_rv; + } - let result = array_tot_ne_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); - bitmap.push(result.set_bits() > 0); + let (start, end) = unsafe { $offsets.start_end_unchecked(i) }; + + if end - start != scalar.len() { + return $ineq_len_rv; + } + + // @TODO: I feel like there is a better way to do this. + let mut values: $T = values.clone(); + <$T>::slice(&mut values, start, end - start); + + $true_op($op(&values, scalar)) + }) + .collect_trusted() + }}; } - bitmap.freeze() + assert_eq!(lhs.dtype(), rhs.dtype()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} + +impl TotalEqKernel for ListArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + compare!( + self, + other, + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0, + false, + true + ) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + compare!( + self, + other, + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0, + true, + false + ) } - fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + compare_broadcast!( + self.values().as_ref(), + other.as_ref(), + self.offsets(), + self.validity(), + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0, + false, + true + ) } - fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + compare_broadcast!( + self.values().as_ref(), + other.as_ref(), + self.offsets(), + self.validity(), + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0, + true, + false + ) } } diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 4ee9cad0a482..344f7a796735 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -638,76 +638,117 @@ impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked { } } -#[doc(hidden)] -fn _list_comparison_helper(lhs: &ListChunked, rhs: &ListChunked, op: F) -> BooleanChunked +fn _list_comparison_helper( + lhs: &ListChunked, + rhs: &ListChunked, + op: F, + broadcast_op: B, + missing: bool, + is_ne: bool, +) -> BooleanChunked where - F: Fn(Option<&Series>, Option<&Series>) -> Option, + F: Fn(&ListArray, &ListArray) -> Bitmap, + B: Fn(&ListArray, &Box) -> Bitmap, { match (lhs.len(), rhs.len()) { (_, 1) => { - let right = rhs.get_as_series(0).map(|s| s.with_name(PlSmallStr::EMPTY)); - lhs.amortized_iter() - .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) - .collect_trusted() + let right = rhs.chunks()[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + if !right.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return lhs.is_not_null(); + } else { + return lhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, lhs.len()); + } + } + + let values = right.values().sliced( + (*right.offsets().first()).try_into().unwrap(), + right.offsets().range().try_into().unwrap(), + ); + + arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into()) }, (1, _) => { - let left = lhs.get_as_series(0).map(|s| s.with_name(PlSmallStr::EMPTY)); - rhs.amortized_iter() - .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) - .collect_trusted() + let left = lhs.chunks()[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + if !left.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return rhs.is_not_null(); + } else { + return rhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()); + } + } + + let values = left.values().sliced( + (*left.offsets().first()).try_into().unwrap(), + left.offsets().range().try_into().unwrap(), + ); + + arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into()) }, - _ => lhs - .amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| { - op( - left.as_ref().map(|us| us.as_ref()), - right.as_ref().map(|us| us.as_ref()), - ) - }) - .collect_trusted(), + _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), } } impl ChunkCompareEq<&ListChunked> for ListChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_equals = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(l.equals(r)), - _ => None, - }; - - _list_comparison_helper(self, rhs, _series_equals) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_kernel, + TotalEqKernel::tot_eq_kernel_broadcast, + false, + false, + ) } fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_equals_missing = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(l.equals_missing(r)), - (None, None) => Some(true), - _ => Some(false), - }; - - _list_comparison_helper(self, rhs, _series_equals_missing) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_missing_kernel, + TotalEqKernel::tot_eq_missing_kernel_broadcast, + true, + false, + ) } fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_not_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(!l.equals(r)), - _ => None, - }; - - _list_comparison_helper(self, rhs, _series_not_equal) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_kernel, + TotalEqKernel::tot_ne_kernel_broadcast, + false, + true, + ) } fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_not_equal_missing = - |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(!l.equals_missing(r)), - (None, None) => Some(false), - _ => Some(true), - }; - - _list_comparison_helper(self, rhs, _series_not_equal_missing) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_missing_kernel, + TotalEqKernel::tot_ne_missing_kernel_broadcast, + true, + true, + ) } } @@ -798,6 +839,8 @@ fn _array_comparison_helper( rhs: &ArrayChunked, op: F, broadcast_op: B, + missing: bool, + is_ne: bool, ) -> BooleanChunked where F: Fn(&FixedSizeListArray, &FixedSizeListArray) -> Bitmap, @@ -808,17 +851,41 @@ where let right = rhs.chunks()[0] .as_any() .downcast_ref::() - .unwrap() - .values(); - arity::unary_mut_values(lhs, |a| broadcast_op(a, right).into()) + .unwrap(); + + if !right.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return lhs.is_not_null(); + } else { + return lhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, lhs.len()); + } + } + + arity::unary_mut_values(lhs, |a| broadcast_op(a, right.values()).into()) }, (1, _) => { let left = lhs.chunks()[0] .as_any() .downcast_ref::() - .unwrap() - .values(); - arity::unary_mut_values(rhs, |a| broadcast_op(a, left).into()) + .unwrap(); + + if !left.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return rhs.is_not_null(); + } else { + return rhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()); + } + } + + arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into()) }, _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), } @@ -833,6 +900,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_eq_kernel, TotalEqKernel::tot_eq_kernel_broadcast, + false, + false, ) } @@ -842,6 +911,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_eq_missing_kernel, TotalEqKernel::tot_eq_missing_kernel_broadcast, + true, + false, ) } @@ -851,6 +922,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_ne_kernel, TotalEqKernel::tot_ne_kernel_broadcast, + false, + true, ) } @@ -860,6 +933,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_ne_missing_kernel, TotalEqKernel::tot_ne_missing_kernel_broadcast, + true, + true, ) } } diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 14aefa93c3c1..3807a6b29ef5 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -405,14 +405,14 @@ def test_fast_explode_merge_left_16923() -> None: @pytest.mark.parametrize( ("values", "exploded"), [ - (["foobar", None], ["f", "o", "o", "b", "a", "r", None]), - ([None, "foo", "bar"], [None, "f", "o", "o", "b", "a", "r"]), + (["foobar", None], ["f", "o", "o", "b", "a", "r", ""]), + ([None, "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), ( [None, "foo", "bar", None, "ham"], - [None, "f", "o", "o", "b", "a", "r", None, "h", "a", "m"], + ["", "f", "o", "o", "b", "a", "r", "", "h", "a", "m"], ), (["foo", "bar", "ham"], ["f", "o", "o", "b", "a", "r", "h", "a", "m"]), - (["", None, "foo", "bar"], ["", None, "f", "o", "o", "b", "a", "r"]), + (["", None, "foo", "bar"], ["", "", "f", "o", "o", "b", "a", "r"]), (["", "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), ], ) @@ -421,6 +421,9 @@ def test_series_str_explode_deprecated( ) -> None: with pytest.deprecated_call(): result = pl.Series(values).str.explode() + if result.to_list() != exploded: + print(result.to_list()) + print(exploded) assert result.to_list() == exploded