From 011e366b7dbbf1977ea663566c5789fbaf743105 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 25 Sep 2024 13:36:09 +0200 Subject: [PATCH 01/33] fix: Incorrect broadcasting on list-of-string set ops (#18918) --- .../polars-ops/src/chunked_array/list/sets.rs | 10 +++++++-- .../namespaces/list/test_set_operations.py | 22 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index e105d96b737a..4a3187631575 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -303,7 +303,10 @@ fn binary( let offset = if broadcast_rhs { // going via skip iterator instead of slice doesn't heap alloc nor trigger a bitcount let a_iter = a.into_iter().skip(start_a).take(end_a - start_a); - let b_iter = b.into_iter(); + let b_iter = b + .into_iter() + .skip(first_b as usize) + .take(second_b as usize - first_b as usize); set_operation( &mut set, &mut set2, @@ -314,7 +317,10 @@ fn binary( true, ) } else if broadcast_lhs { - let a_iter = a.into_iter(); + let a_iter = a + .into_iter() + .skip(first_a as usize) + .take(second_a as usize - first_a as usize); let b_iter = b.into_iter().skip(start_b).take(end_b - start_b); set_operation( &mut set, diff --git a/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py b/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py index 8082b33391cf..ba451dcb1e41 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_set_operations.py @@ -163,6 +163,28 @@ def test_list_set_operations_binary() -> None: ] +def test_list_set_operations_broadcast_binary() -> None: + df = pl.DataFrame( + { + "a": [["2", "3", "3"], ["3", "1"], ["1", "2", "3"]], + "b": [["1", "2"], ["4"], ["5"]], + } + ) + + assert 
df.select(pl.col("a").list.set_intersection(pl.col.b.first())).to_dict( + as_series=False + ) == {"a": [["2"], ["1"], ["1", "2"]]} + assert df.select(pl.col("a").list.set_union(pl.col.b.first())).to_dict( + as_series=False + ) == {"a": [["2", "3", "1"], ["3", "1", "2"], ["1", "2", "3"]]} + assert df.select(pl.col("a").list.set_difference(pl.col.b.first())).to_dict( + as_series=False + ) == {"a": [["3"], ["3"], ["3"]]} + assert df.select(pl.col.b.first().list.set_difference("a")).to_dict( + as_series=False + ) == {"b": [["1"], ["2"], []]} + + def test_set_operations_14290() -> None: df = pl.DataFrame( { From 2467b2716ffa04b93aedda55a82d83bb6d998ed0 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 25 Sep 2024 16:05:38 +0200 Subject: [PATCH 02/33] refactor: Fix/skip variety of new-streaming tests (#18924) --- crates/polars-plan/src/plans/python/pyarrow.rs | 2 +- .../src/physical_plan/lower_expr.rs | 18 ++++++++++++++---- py-polars/tests/unit/io/test_parquet.py | 1 + .../tests/unit/lazyframe/test_lazyframe.py | 1 + .../operations/namespaces/list/test_list.py | 1 + .../namespaces/string/test_string.py | 2 +- .../operations/namespaces/test_strptime.py | 2 +- 7 files changed, 20 insertions(+), 7 deletions(-) diff --git a/crates/polars-plan/src/plans/python/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs index 20b800fa81b1..78fcc20cc453 100644 --- a/crates/polars-plan/src/plans/python/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -44,7 +44,7 @@ pub fn predicate_to_pa( } else { let mut list_repr = String::with_capacity(s.len() * 5); list_repr.push('['); - for av in s.iter() { + for av in s.rechunk().iter() { if let AnyValue::Boolean(v) = av { let s = if v { "True" } else { "False" }; write!(list_repr, "{},", s).unwrap(); diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 8e891bd408af..919694e8c538 100644 --- 
a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -87,13 +87,23 @@ pub(crate) fn is_elementwise( function: _, output_type: _, options, - } - | AExpr::Function { + } => { + options.is_elementwise() && input.iter().all(|e| is_elementwise(e.node(), arena, cache)) + }, + AExpr::Function { input, - function: _, + function, options, } => { - options.is_elementwise() && input.iter().all(|e| is_elementwise(e.node(), arena, cache)) + match function { + // Non-strict strptime must be done in-memory to ensure the format + // is consistent across the entire dataframe. + FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)) => opts.strict, + _ => { + options.is_elementwise() + && input.iter().all(|e| is_elementwise(e.node(), arena, cache)) + }, + } }, AExpr::Window { .. } => false, diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 10aa23dfa7b6..b53766ae2c2c 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -114,6 +114,7 @@ def test_to_from_buffer( @pytest.mark.parametrize("use_pyarrow", [True, False]) @pytest.mark.parametrize("rechunk_and_expected_chunks", [(True, 1), (False, 3)]) +@pytest.mark.may_fail_auto_streaming def test_read_parquet_respects_rechunk_16416( use_pyarrow: bool, rechunk_and_expected_chunks: tuple[bool, int] ) -> None: diff --git a/py-polars/tests/unit/lazyframe/test_lazyframe.py b/py-polars/tests/unit/lazyframe/test_lazyframe.py index 23394110b40e..026d4f157e2f 100644 --- a/py-polars/tests/unit/lazyframe/test_lazyframe.py +++ b/py-polars/tests/unit/lazyframe/test_lazyframe.py @@ -354,6 +354,7 @@ def test_inspect(capsys: CaptureFixture[str]) -> None: assert len(res.out) > 0 +@pytest.mark.may_fail_auto_streaming def test_fetch(fruits_cars: pl.DataFrame) -> None: res = fruits_cars.lazy().select("*")._fetch(2) assert_frame_equal(res, res[:2]) diff --git 
a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index f306bbff5d7b..966fee3ea5ac 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -620,6 +620,7 @@ def test_list_unique2() -> None: assert sorted(result[1]) == [1, 2] +@pytest.mark.may_fail_auto_streaming def test_list_to_struct() -> None: df = pl.DataFrame({"n": [[0, 1, 2], [0, 1]]}) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py index fe47b8d07d2e..842b0fd141a5 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_string.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -429,7 +429,7 @@ def test_str_to_integer_base_expr() -> None: # test strict raise df = pl.DataFrame({"str": ["110", "ff00", "cafe", None], "base": [2, 10, 10, 8]}) - with pytest.raises(ComputeError, match="failed for 2 value"): + with pytest.raises(ComputeError): df.select(pl.col("str").str.to_integer(base="base")) diff --git a/py-polars/tests/unit/operations/namespaces/test_strptime.py b/py-polars/tests/unit/operations/namespaces/test_strptime.py index 3aa5890198df..41fdb028e31d 100644 --- a/py-polars/tests/unit/operations/namespaces/test_strptime.py +++ b/py-polars/tests/unit/operations/namespaces/test_strptime.py @@ -161,7 +161,7 @@ def test_to_date_all_inferred_date_patterns(time_string: str, expected: date) -> ], ) def test_non_exact_short_elements_10223(value: str, attr: str) -> None: - with pytest.raises(InvalidOperationError, match="conversion .* failed"): + with pytest.raises((InvalidOperationError, ComputeError)): getattr(pl.Series(["2019-01-01", value]).str, attr)(exact=False) From bef75b9d3f8f8b4ee81fd7f6c48226d70bf4959b Mon Sep 17 00:00:00 2001 From: Alex Harris <122488678+aleexharris@users.noreply.github.com> Date: Thu, 26 Sep 
2024 07:49:05 +0100 Subject: [PATCH 03/33] docs(rust): Typo for `IntoDf` trait (#18933) --- crates/polars-ops/src/frame/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index 539d4e0cebc1..4604920351eb 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -27,7 +27,7 @@ impl IntoDf for DataFrame { impl DataFrameOps for T {} pub trait DataFrameOps: IntoDf { - /// Crea dummy variables. + /// Create dummy variables. /// /// # Example /// From 68b6f0e4f73fbaaacb212abad2210eb6a75df79e Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 26 Sep 2024 08:51:26 +0200 Subject: [PATCH 04/33] refactor: Fix/skip variety of new-streaming tests, cont (#18928) --- crates/polars-expr/src/reduce/min_max.rs | 11 +++++++++-- crates/polars-python/src/series/buffers.rs | 1 + crates/polars-stream/src/physical_plan/lower_ir.rs | 8 +++++++- py-polars/tests/unit/operations/test_ewm.py | 1 + py-polars/tests/unit/operations/test_join_asof.py | 1 + py-polars/tests/unit/operations/test_sort.py | 1 + py-polars/tests/unit/operations/test_statistics.py | 6 +++--- py-polars/tests/unit/test_projections.py | 1 + 8 files changed, 24 insertions(+), 6 deletions(-) diff --git a/crates/polars-expr/src/reduce/min_max.rs b/crates/polars-expr/src/reduce/min_max.rs index ba011d7d95f0..27cf3d5b5727 100644 --- a/crates/polars-expr/src/reduce/min_max.rs +++ b/crates/polars-expr/src/reduce/min_max.rs @@ -25,6 +25,9 @@ struct MinReduceState { impl MinReduceState { fn update_with_value(&mut self, other: &AnyValue<'static>) { + // AnyValue uses total ordering, so NaN is greater than any value. + // This means other < self.value.value() already ignores incoming NaNs. + // We still must check if self is NaN and if so replace. 
if self.value.is_null() || !other.is_null() && (other < self.value.value() || self.value.is_nan()) { @@ -80,8 +83,12 @@ struct MaxReduceState { impl MaxReduceState { fn update_with_value(&mut self, other: &AnyValue<'static>) { + // AnyValue uses total ordering, so NaN is greater than any value. + // This means other > self.value.value() might have false positives. + // We also must check if self is NaN and if so replace. if self.value.is_null() - || !other.is_null() && (other > self.value.value() || self.value.is_nan()) + || !other.is_null() + && (other > self.value.value() && !other.is_nan() || self.value.is_nan()) { self.value.update(other.clone()); } @@ -90,7 +97,7 @@ impl MaxReduceState { impl ReductionState for MaxReduceState { fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.min_reduce()?; + let sc = batch.max_reduce()?; self.update_with_value(sc.value()); Ok(()) } diff --git a/crates/polars-python/src/series/buffers.rs b/crates/polars-python/src/series/buffers.rs index 55013ea1dd2c..30eddce08e39 100644 --- a/crates/polars-python/src/series/buffers.rs +++ b/crates/polars-python/src/series/buffers.rs @@ -337,6 +337,7 @@ where T: PolarsNumericType, { let ca: &ChunkedArray = s.as_ref().as_ref(); + let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); arr.values().clone() } diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index beec7a57e358..c00f4adc6003 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -4,7 +4,7 @@ use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; use polars_core::schema::Schema; use polars_error::PolarsResult; use polars_plan::plans::expr_ir::{ExprIR, OutputName}; -use polars_plan::plans::{AExpr, IR}; +use polars_plan::plans::{AExpr, FunctionIR, IR}; use polars_plan::prelude::SinkType; use polars_utils::arena::{Arena, Node}; use 
polars_utils::itertools::Itertools; @@ -238,6 +238,12 @@ pub fn lower_ir( }, IR::MapFunction { input, function } => { + // MergeSorted uses a rechunk hack incompatible with the + // streaming engine. + if let FunctionIR::MergeSorted { .. } = function { + todo!() + } + let function = function.clone(); let phys_input = lower_ir( *input, diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py index f4b31e880ec3..b776bce1c243 100644 --- a/py-polars/tests/unit/operations/test_ewm.py +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -201,6 +201,7 @@ def test_ewm_param_validation() -> None: # https://github.com/pola-rs/polars/issues/4951 +@pytest.mark.may_fail_auto_streaming def test_ewm_with_multiple_chunks() -> None: df0 = pl.DataFrame( data=[ diff --git a/py-polars/tests/unit/operations/test_join_asof.py b/py-polars/tests/unit/operations/test_join_asof.py index 58b29207dac7..21b7db4ebb90 100644 --- a/py-polars/tests/unit/operations/test_join_asof.py +++ b/py-polars/tests/unit/operations/test_join_asof.py @@ -1091,6 +1091,7 @@ def test_asof_join_nearest_by_date() -> None: assert_frame_equal(out, expected) +@pytest.mark.may_fail_auto_streaming # See #18927. 
def test_asof_join_string() -> None: left = pl.DataFrame({"x": [None, "a", "b", "c", None, "d", None]}).set_sorted("x") right = pl.DataFrame({"x": ["apple", None, "chutney"], "y": [0, 1, 2]}).set_sorted( diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 6b2060e8305f..57dbec1a13ee 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -806,6 +806,7 @@ def test_sort_string_nulls() -> None: ] +@pytest.mark.may_fail_auto_streaming def test_sort_by_unequal_lengths_7207() -> None: df = pl.DataFrame({"a": [0, 1, 1, 0], "b": [3, 2, 3, 2]}) with pytest.raises(pl.exceptions.ShapeError): diff --git a/py-polars/tests/unit/operations/test_statistics.py b/py-polars/tests/unit/operations/test_statistics.py index ed8b964582cb..30a0bb7cd7a4 100644 --- a/py-polars/tests/unit/operations/test_statistics.py +++ b/py-polars/tests/unit/operations/test_statistics.py @@ -96,9 +96,9 @@ def test_median_quantile_duration() -> None: def test_correlation_cast_supertype() -> None: df = pl.DataFrame({"a": [1, 8, 3], "b": [4.0, 5.0, 2.0]}) df = df.with_columns(pl.col("b")) - assert df.select(pl.corr("a", "b")).to_dict(as_series=False) == { - "a": [0.5447047794019219] - } + assert_frame_equal( + df.select(pl.corr("a", "b")), pl.DataFrame({"a": [0.5447047794019219]}) + ) def test_cov_corr_f32_type() -> None: diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 700100ced4c4..7c279648fa1c 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -122,6 +122,7 @@ def test_hconcat_projection_pushdown_length_maintained() -> None: assert_frame_equal(out, expected) +@pytest.mark.may_fail_auto_streaming def test_unnest_columns_available() -> None: df = pl.DataFrame( { From aec911f9cce0754c0aa658713d79189c09fe045b Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Thu, 26 Sep 2024 08:58:26 +0200 
Subject: [PATCH 05/33] fix: Infer reshape dims when determining schema (#18923) --- crates/polars-core/src/datatypes/mod.rs | 2 + crates/polars-core/src/datatypes/reshape.rs | 118 ++++++++++++++++++ crates/polars-core/src/frame/column/mod.rs | 5 +- .../src/series/arithmetic/borrowed.rs | 10 +- crates/polars-core/src/series/ops/reshape.rs | 118 +++++++++--------- .../src/expressions/aggregation.rs | 7 +- crates/polars-expr/src/expressions/mod.rs | 7 +- .../src/chunked_array/list/namespace.rs | 4 +- .../src/dsl/function_expr/dispatch.rs | 9 +- .../polars-plan/src/dsl/function_expr/list.rs | 6 +- .../polars-plan/src/dsl/function_expr/mod.rs | 16 +-- .../src/dsl/function_expr/schema.rs | 48 ++++--- crates/polars-plan/src/dsl/mod.rs | 10 +- crates/polars-plan/src/dsl/options.rs | 2 +- crates/polars-python/src/expr/general.rs | 3 +- .../src/lazyframe/visitor/expr_nodes.rs | 5 +- crates/polars-python/src/series/general.rs | 6 + .../tests/unit/operations/test_reshape.py | 19 ++- 18 files changed, 285 insertions(+), 110 deletions(-) create mode 100644 crates/polars-core/src/datatypes/reshape.rs diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 04f79f1eb086..64266a0066db 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -13,6 +13,7 @@ mod any_value; mod dtype; mod field; mod into_scalar; +mod reshape; #[cfg(feature = "object")] mod static_array_collect; mod time_unit; @@ -41,6 +42,7 @@ use polars_utils::abs_diff::AbsDiff; use polars_utils::float::IsFloat; use polars_utils::min_max::MinMax; use polars_utils::nulls::IsNull; +pub use reshape::*; #[cfg(feature = "serde")] use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] diff --git a/crates/polars-core/src/datatypes/reshape.rs b/crates/polars-core/src/datatypes/reshape.rs new file mode 100644 index 000000000000..7fa5e3dcfa7f --- /dev/null +++ 
b/crates/polars-core/src/datatypes/reshape.rs @@ -0,0 +1,118 @@ +use std::fmt; +use std::hash::Hash; +use std::num::NonZeroU64; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[repr(transparent)] +pub struct Dimension(NonZeroU64); + +/// A dimension in a reshape. +/// +/// Any dimension smaller than 0 is seen as an `infer`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum ReshapeDimension { + Infer, + Specified(Dimension), +} + +impl fmt::Debug for Dimension { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl fmt::Display for ReshapeDimension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Infer => f.write_str("inferred"), + Self::Specified(v) => v.get().fmt(f), + } + } +} + +impl Hash for ReshapeDimension { + fn hash(&self, state: &mut H) { + self.to_repr().hash(state) + } +} + +impl Dimension { + #[inline] + pub const fn new(v: u64) -> Self { + assert!(v <= i64::MAX as u64); + + // SAFETY: Bounds check done before + let dim = unsafe { NonZeroU64::new_unchecked(v.wrapping_add(1)) }; + Self(dim) + } + + #[inline] + pub const fn get(self) -> u64 { + self.0.get() - 1 + } +} + +impl ReshapeDimension { + #[inline] + pub const fn new(v: i64) -> Self { + if v < 0 { + Self::Infer + } else { + // SAFETY: We have bounds checked for -1 + let dim = unsafe { NonZeroU64::new_unchecked((v as u64).wrapping_add(1)) }; + Self::Specified(Dimension(dim)) + } + } + + #[inline] + fn to_repr(self) -> u64 { + match self { + Self::Infer => 0, + Self::Specified(dim) => dim.0.get(), + } + } + + #[inline] + pub const fn get(self) -> Option { + match self { + ReshapeDimension::Infer => None, + ReshapeDimension::Specified(dim) => Some(dim.get()), + } + } + + #[inline] + pub const fn get_or_infer(self, inferred: u64) -> u64 { + 
match self { + ReshapeDimension::Infer => inferred, + ReshapeDimension::Specified(dim) => dim.get(), + } + } + + #[inline] + pub fn get_or_infer_with(self, f: impl Fn() -> u64) -> u64 { + match self { + ReshapeDimension::Infer => f(), + ReshapeDimension::Specified(dim) => dim.get(), + } + } + + pub const fn new_dimension(dimension: u64) -> ReshapeDimension { + Self::Specified(Dimension::new(dimension)) + } +} + +impl TryFrom for Dimension { + type Error = (); + + #[inline] + fn try_from(value: i64) -> Result { + let ReshapeDimension::Specified(v) = ReshapeDimension::new(value) else { + return Err(()); + }; + + Ok(v) + } +} diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index b39f66b543ea..727faf0768c8 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -9,6 +9,7 @@ use polars_utils::pl_str::PlSmallStr; use self::gather::check_bounds_ca; use crate::chunked_array::cast::CastOptions; use crate::chunked_array::metadata::{MetadataFlags, MetadataTrait}; +use crate::datatypes::ReshapeDimension; use crate::prelude::*; use crate::series::{BitRepr, IsSorted, SeriesPhysIter}; use crate::utils::{slice_offsets, Container}; @@ -730,7 +731,7 @@ impl Column { self.as_materialized_series().unique().map(Column::from) } - pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult { + pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { // @scalar-opt self.as_materialized_series() .reshape_list(dimensions) @@ -738,7 +739,7 @@ impl Column { } #[cfg(feature = "dtype-array")] - pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult { + pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { // @scalar-opt self.as_materialized_series() .reshape_array(dimensions) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 115626d63805..2e613ea7e1a0 100644 
--- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -115,16 +115,18 @@ impl NumOpsDispatchInner for BooleanType { } #[cfg(feature = "dtype-array")] -fn array_shape(dt: &DataType, infer: bool) -> Vec { - fn inner(dt: &DataType, buf: &mut Vec) { +fn array_shape(dt: &DataType, infer: bool) -> Vec { + fn inner(dt: &DataType, buf: &mut Vec) { if let DataType::Array(_, size) = dt { - buf.push(*size as i64) + buf.push(ReshapeDimension::Specified( + Dimension::try_from(*size as i64).unwrap(), + )) } } let mut buf = vec![]; if infer { - buf.push(-1) + buf.push(ReshapeDimension::Infer) } inner(dt, &mut buf); buf diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index 85c8e283e166..fdc6b6091058 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -1,14 +1,9 @@ use std::borrow::Cow; -#[cfg(feature = "dtype-array")] -use std::cmp::Ordering; -#[cfg(feature = "dtype-array")] -use std::collections::VecDeque; use arrow::array::*; use arrow::legacy::kernels::list::array_to_unit_list; use arrow::offset::Offsets; use polars_error::{polars_bail, polars_ensure, PolarsResult}; -#[cfg(feature = "dtype-array")] use polars_utils::format_tuple; use crate::chunked_array::builder::get_list_builder; @@ -90,70 +85,70 @@ impl Series { } #[cfg(feature = "dtype-array")] - pub fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult { + pub fn reshape_array(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { polars_ensure!( !dimensions.is_empty(), InvalidOperation: "at least one dimension must be specified" ); - let mut dims = dimensions.iter().copied().collect::>(); - let leaf_array = self.get_leaf_array(); let size = leaf_array.len(); let mut total_dim_size = 1; - let mut infer_dim_index: Option = None; - for (index, &dim) in dims.iter().enumerate() { - match dim.cmp(&0) { - Ordering::Greater => 
total_dim_size *= dim as usize, - Ordering::Equal => { + let mut num_infers = 0; + for (index, &dim) in dimensions.iter().enumerate() { + match dim { + ReshapeDimension::Infer => { polars_ensure!( - index == 0, - InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}", - format_tuple!(dims) + num_infers == 0, + InvalidOperation: "can only specify one inferred dimension" ); - total_dim_size = 0; - // We can early exit here, as empty arrays will error with multiple dimensions, - // and non-empty arrays will error when the first dimension is zero. - break; + num_infers += 1; }, - Ordering::Less => { - polars_ensure!( - infer_dim_index.is_none(), - InvalidOperation: "can only specify one unknown dimension" - ); - infer_dim_index = Some(index); + ReshapeDimension::Specified(dim) => { + let dim = dim.get(); + + if dim > 0 { + total_dim_size *= dim as usize + } else { + polars_ensure!( + index == 0, + InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}", + format_tuple!(dimensions) + ); + total_dim_size = 0; + // We can early exit here, as empty arrays will error with multiple dimensions, + // and non-empty arrays will error when the first dimension is zero. 
+ break; + } }, } } if size == 0 { - if dims.len() > 1 || (infer_dim_index.is_none() && total_dim_size != 0) { - polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dims)) + if dimensions.len() > 1 || (num_infers == 0 && total_dim_size != 0) { + polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dimensions)) } } else if total_dim_size == 0 { - polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dims)) + polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions)) } else { polars_ensure!( size % total_dim_size == 0, - InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dims) + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) ); } - // Infer dimension - if let Some(index) = infer_dim_index { - let inferred_dim = size / total_dim_size; - let item = dims.get_mut(index).unwrap(); - *item = i64::try_from(inferred_dim).unwrap(); - } - let leaf_array = leaf_array.rechunk(); let mut prev_dtype = leaf_array.dtype().clone(); let mut prev_array = leaf_array.chunks()[0].clone(); // We pop the outer dimension as that is the height of the series. 
- let _ = dims.pop_front(); - while let Some(dim) = dims.pop_back() { + for idx in (1..dimensions.len()).rev() { + // Infer dimension if needed + let dim = dimensions[idx].get_or_infer_with(|| { + debug_assert!(num_infers > 0); + (size / total_dim_size) as u64 + }); prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); prev_array = FixedSizeListArray::new( @@ -172,7 +167,7 @@ impl Series { }) } - pub fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult { + pub fn reshape_list(&self, dimensions: &[ReshapeDimension]) -> PolarsResult { polars_ensure!( !dimensions.is_empty(), InvalidOperation: "at least one dimension must be specified" @@ -187,38 +182,43 @@ impl Series { let s_ref = s.as_ref(); - let dimensions = dimensions.to_vec(); + // let dimensions = dimensions.to_vec(); match dimensions.len() { 1 => { polars_ensure!( - dimensions[0] as usize == s_ref.len() || dimensions[0] == -1_i64, + dimensions[0].get().map_or(true, |dim| dim as usize == s_ref.len()), InvalidOperation: "cannot reshape len {} into shape {:?}", s_ref.len(), dimensions, ); Ok(s_ref.clone()) }, 2 => { - let mut rows = dimensions[0]; - let mut cols = dimensions[1]; + let rows = dimensions[0]; + let cols = dimensions[1]; if s_ref.len() == 0_usize { - if (rows == -1 || rows == 0) && (cols == -1 || cols == 0 || cols == 1) { + if rows.get_or_infer(0) == 0 && cols.get_or_infer(0) <= 1 { let s = reshape_fast_path(s.name().clone(), s_ref); return Ok(s); } else { - polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {:?}", dimensions,) + polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {}", format_tuple!(dimensions)) } } + use ReshapeDimension as RD; // Infer dimension. 
- if rows == -1 && cols >= 1 { - rows = s_ref.len() as i64 / cols - } else if cols == -1 && rows >= 1 { - cols = s_ref.len() as i64 / rows - } else if rows == -1 && cols == -1 { - rows = s_ref.len() as i64; - cols = 1_i64; - } + + let (rows, cols) = match (rows, cols) { + (RD::Infer, RD::Specified(cols)) if cols.get() >= 1 => { + (s_ref.len() as u64 / cols.get(), cols.get()) + }, + (RD::Specified(rows), RD::Infer) if rows.get() >= 1 => { + (rows.get(), s_ref.len() as u64 / rows.get()) + }, + (RD::Infer, RD::Infer) => (s_ref.len() as u64, 1u64), + (RD::Specified(rows), RD::Specified(cols)) => (rows.get(), cols.get()), + _ => polars_bail!(InvalidOperation: "reshape of non-zero list into zero list"), + }; // Fast path, we can create a unit list so we only allocate offsets. if rows as usize == s_ref.len() && cols == 1 { @@ -234,9 +234,9 @@ impl Series { let mut builder = get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone())?; - let mut offset = 0i64; + let mut offset = 0u64; for _ in 0..rows { - let row = s_ref.slice(offset, cols as usize); + let row = s_ref.slice(offset as i64, cols as usize); builder.append_series(&row).unwrap(); offset += cols; } @@ -279,7 +279,11 @@ mod test { (&[-1, 2], 2), (&[2, -1], 2), ] { - let out = s.reshape_list(dims)?; + let dims = dims + .iter() + .map(|&v| ReshapeDimension::new(v)) + .collect::>(); + let out = s.reshape_list(&dims)?; assert_eq!(out.len(), list_len); assert!(matches!(out.dtype(), DataType::List(_))); assert_eq!(out.explode()?.len(), 4); diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index cdac9a46610a..e41886a29590 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -375,7 +375,12 @@ impl PhysicalExpr for AggregationExpr { let s = match ac.agg_state() { // mean agg: // -> f64 -> list - AggState::AggregatedScalar(s) => s.reshape_list(&[-1, 1]).unwrap(), + 
AggState::AggregatedScalar(s) => s + .reshape_list(&[ + ReshapeDimension::Infer, + ReshapeDimension::new_dimension(1), + ]) + .unwrap(), _ => { let agg = ac.aggregated(); agg.as_list().into_series() diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index ec17842d719c..8a74033953dc 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -421,7 +421,12 @@ impl<'a> AggregationContext<'a> { self.groups(); let rows = self.groups.len(); let s = s.new_from_index(0, rows); - let out = s.reshape_list(&[rows as i64, -1]).unwrap(); + let out = s + .reshape_list(&[ + ReshapeDimension::new_dimension(rows as u64), + ReshapeDimension::Infer, + ]) + .unwrap(); self.state = AggState::AggregatedList(out.clone()); out }, diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index e16ac5da4453..3584fa792d07 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -44,7 +44,9 @@ fn cast_rhs( } if !matches!(s.dtype(), DataType::List(_)) && s.dtype() == inner_type { // coerce to list JIT - *s = s.reshape_list(&[-1, 1]).unwrap(); + *s = s + .reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)]) + .unwrap(); } if s.dtype() != dtype { *s = s.cast(dtype).map_err(|e| { diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index 6ee70819b4f7..7d0669626253 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -72,12 +72,9 @@ pub(super) fn unique_counts(s: &Column) -> PolarsResult { polars_ops::prelude::unique_counts(s.as_materialized_series()).map(Column::from) } -pub(super) fn reshape(s: &Column, dimensions: &[i64], nested: &NestedType) -> PolarsResult { - match nested { - NestedType::List => 
s.reshape_list(dimensions), - #[cfg(feature = "dtype-array")] - NestedType::Array => s.reshape_array(dimensions), - } +#[cfg(feature = "dtype-array")] +pub(super) fn reshape(c: &Column, dimensions: &[ReshapeDimension]) -> PolarsResult { + c.reshape_array(dimensions) } #[cfg(feature = "repeat_by")] diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 35467eff92bc..9159c2b6b3ee 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -395,7 +395,9 @@ pub(super) fn concat(s: &mut [Column]) -> PolarsResult> { let mut first_ca = match first.list().ok() { Some(ca) => ca, None => { - first = first.reshape_list(&[-1, 1]).unwrap(); + first = first + .reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)]) + .unwrap(); first.list().unwrap() }, } @@ -506,7 +508,7 @@ pub(super) fn gather(args: &[Column], null_on_oob: bool) -> PolarsResult let idx = idx.get(0)?.try_extract::()?; let out = ca.lst_get(idx, null_on_oob).map(Column::from)?; // make sure we return a list - out.reshape_list(&[-1, 1]) + out.reshape_list(&[ReshapeDimension::Infer, ReshapeDimension::new_dimension(1)]) } else { ca.lst_gather(idx.as_materialized_series(), null_on_oob) .map(Column::from) diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index a5049b84eb1b..6347f6cee7b4 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -80,6 +80,7 @@ pub(crate) use correlation::CorrelationMethod; #[cfg(feature = "fused")] pub(crate) use fused::FusedOperator; pub(crate) use list::ListFunction; +use polars_core::datatypes::ReshapeDimension; use polars_core::prelude::*; #[cfg(feature = "random")] pub(crate) use random::RandomMethod; @@ -172,7 +173,8 @@ pub enum FunctionExpr { Skew(bool), #[cfg(feature = "moment")] Kurtosis(bool, bool), - Reshape(Vec, 
NestedType), + #[cfg(feature = "dtype-array")] + Reshape(Vec), #[cfg(feature = "repeat_by")] RepeatBy, ArgUnique, @@ -524,10 +526,8 @@ impl Hash for FunctionExpr { left_closed.hash(state); include_breaks.hash(state); }, - Reshape(dims, nested) => { - dims.hash(state); - nested.hash(state); - }, + #[cfg(feature = "dtype-array")] + Reshape(dims) => dims.hash(state), #[cfg(feature = "repeat_by")] RepeatBy => {}, #[cfg(feature = "cutqcut")] @@ -728,7 +728,8 @@ impl Display for FunctionExpr { Cut { .. } => "cut", #[cfg(feature = "cutqcut")] QCut { .. } => "qcut", - Reshape(_, _) => "reshape", + #[cfg(feature = "dtype-array")] + Reshape(_) => "reshape", #[cfg(feature = "repeat_by")] RepeatBy => "repeat_by", #[cfg(feature = "rle")] @@ -1068,7 +1069,8 @@ impl From for SpecialEq> { PeakMax => map!(peaks::peak_max), #[cfg(feature = "repeat_by")] RepeatBy => map_as_slice!(dispatch::repeat_by), - Reshape(dims, nested) => map!(dispatch::reshape, &dims, &nested), + #[cfg(feature = "dtype-array")] + Reshape(dims) => map!(dispatch::reshape, &dims), #[cfg(feature = "cutqcut")] Cut { breaks, diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 15f03e6bb848..11b190b41d50 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -244,25 +244,43 @@ impl FunctionExpr { }, #[cfg(feature = "repeat_by")] RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())), - Reshape(dims, nested_type) => mapper.map_dtype(|dt| { + #[cfg(feature = "dtype-array")] + Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| { let dtype = dt.inner_dtype().unwrap_or(dt).clone(); + if dims.len() == 1 { - dtype - } else { - match nested_type { - NestedType::List => DataType::List(Box::new(dtype)), - #[cfg(feature = "dtype-array")] - NestedType::Array => { - let mut prev_dtype = dtype.leaf_dtype().clone(); - - // We pop the outer dimension as that is the height of the 
series. - for dim in &dims[1..] { - prev_dtype = DataType::Array(Box::new(prev_dtype), *dim as usize); - } - prev_dtype - }, + return Ok(dtype); + } + + let num_infers = dims.iter().filter(|d| matches!(d, ReshapeDimension::Infer)).count(); + + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + + let mut inferred_size = 0; + if num_infers == 1 { + let mut total_size = 1u64; + let mut current = dt; + while let DataType::Array(dt, width) = current { + if *width == 0 { + total_size = 0; + break; + } + + current = dt.as_ref(); + total_size *= *width as u64; } + + let current_size = dims.iter().map(|d| d.get_or_infer(1)).product::(); + inferred_size = total_size / current_size; + } + + let mut prev_dtype = dtype.leaf_dtype().clone(); + + // We pop the outer dimension as that is the height of the series. + for dim in &dims[1..] { + prev_dtype = DataType::Array(Box::new(prev_dtype), dim.get_or_infer(inferred_size) as usize); } + Ok(prev_dtype) }), #[cfg(feature = "cutqcut")] QCut { diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index e346b247e3ee..0d591ce81313 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1738,9 +1738,13 @@ impl Expr { }) } - pub fn reshape(self, dimensions: &[i64], nested_type: NestedType) -> Self { - let dimensions = dimensions.to_vec(); - self.apply_private(FunctionExpr::Reshape(dimensions, nested_type)) + #[cfg(feature = "dtype-array")] + pub fn reshape(self, dimensions: &[i64]) -> Self { + let dimensions = dimensions + .iter() + .map(|&v| ReshapeDimension::new(v)) + .collect(); + self.apply_private(FunctionExpr::Reshape(dimensions)) } #[cfg(feature = "ewma")] diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index a4d9ae84cd73..73481796d3e0 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -106,7 +106,7 @@ pub enum WindowMapping { pub enum 
NestedType { #[cfg(feature = "dtype-array")] Array, - List, + // List, } #[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index 715fec9ccbb0..e215a6ba2fe4 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -762,8 +762,9 @@ impl PyExpr { self.inner.clone().kurtosis(fisher, bias).into() } + #[cfg(feature = "dtype-array")] fn reshape(&self, dims: Vec) -> Self { - self.inner.clone().reshape(&dims, NestedType::Array).into() + self.inner.clone().reshape(&dims).into() } fn to_physical(&self) -> Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index f5ea2455a8b4..a1e5b26f1e27 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1169,9 +1169,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::Mode => ("mode",).to_object(py), FunctionExpr::Skew(bias) => ("skew", bias).to_object(py), FunctionExpr::Kurtosis(fisher, bias) => ("kurtosis", fisher, bias).to_object(py), - FunctionExpr::Reshape(_, _) => { - return Err(PyNotImplementedError::new_err("reshape")) - }, + #[cfg(feature = "dtype-array")] + FunctionExpr::Reshape(_) => return Err(PyNotImplementedError::new_err("reshape")), #[cfg(feature = "repeat_by")] FunctionExpr::RepeatBy => ("repeat_by",).to_object(py), FunctionExpr::ArgUnique => ("arg_unique",).to_object(py), diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs index e0563a9c327d..1398aa0c3dd1 100644 --- a/crates/polars-python/src/series/general.rs +++ b/crates/polars-python/src/series/general.rs @@ -77,7 +77,13 @@ impl PySeries { }) } + #[cfg(feature = "dtype-array")] fn reshape(&self, dims: Vec) -> PyResult { + let dims = dims + .into_iter() + 
.map(ReshapeDimension::new) + .collect::>(); + let out = self .series .reshape_array(&dims) diff --git a/py-polars/tests/unit/operations/test_reshape.py b/py-polars/tests/unit/operations/test_reshape.py index 677e4ba6d107..12ddfed628c5 100644 --- a/py-polars/tests/unit/operations/test_reshape.py +++ b/py-polars/tests/unit/operations/test_reshape.py @@ -9,6 +9,10 @@ from polars.testing import assert_series_equal +def display_shape(shape: tuple[int, ...]) -> str: + return "(" + ", ".join(tuple(str(d) if d >= 0 else "inferred" for d in shape)) + ")" + + def test_reshape() -> None: s = pl.Series("a", [1, 2, 3, 4]) out = s.reshape((-1, 2)) @@ -47,10 +51,11 @@ def test_reshape() -> None: @pytest.mark.parametrize("shape", [(1, 3), (5, 1), (-1, 5), (3, -1)]) def test_reshape_invalid_dimension_size(shape: tuple[int, ...]) -> None: s = pl.Series("a", [1, 2, 3, 4]) - print(shape) with pytest.raises( InvalidOperationError, - match=re.escape(f"cannot reshape array of size 4 into shape {shape}"), + match=re.escape( + f"cannot reshape array of size 4 into shape {display_shape(shape)}" + ), ): s.reshape(shape) @@ -61,7 +66,7 @@ def test_reshape_invalid_zero_dimension() -> None: with pytest.raises( InvalidOperationError, match=re.escape( - f"cannot reshape array into shape containing a zero dimension after the first: {shape}" + f"cannot reshape array into shape containing a zero dimension after the first: {display_shape(shape)}" ), ): s.reshape(shape) @@ -73,7 +78,7 @@ def test_reshape_invalid_zero_dimension2(shape: tuple[int, ...]) -> None: with pytest.raises( InvalidOperationError, match=re.escape( - f"cannot reshape non-empty array into shape containing a zero dimension: {shape}" + f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}" ), ): s.reshape(shape) @@ -83,7 +88,7 @@ def test_reshape_invalid_zero_dimension2(shape: tuple[int, ...]) -> None: def test_reshape_invalid_multiple_unknown_dims(shape: tuple[int, ...]) -> None: s = 
pl.Series("a", [1, 2, 3, 4]) with pytest.raises( - InvalidOperationError, match="can only specify one unknown dimension" + InvalidOperationError, match="can only specify one inferred dimension" ): s.reshape(shape) @@ -100,7 +105,9 @@ def test_reshape_empty_invalid_2d(shape: tuple[int, ...]) -> None: s = pl.Series("a", [], dtype=pl.Int64) with pytest.raises( InvalidOperationError, - match=re.escape(f"cannot reshape empty array into shape {shape}"), + match=re.escape( + f"cannot reshape empty array into shape {display_shape(shape)}" + ), ): s.reshape(shape) From d85240dca1900a5c88b49a5da397f0df4d1e347d Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 26 Sep 2024 13:02:13 +0200 Subject: [PATCH 06/33] fix: Disable very old date in timezone test for CI (#18935) --- .../operations/namespaces/temporal/test_datetime.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index 7dc5dfd8a9cc..6bbd10828496 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -1367,7 +1367,7 @@ def test_dt_mean_deprecated() -> None: @pytest.mark.parametrize( "value", [ - date(1677, 9, 22), + # date(1677, 9, 22), # See test_literal_from_datetime. date(1970, 1, 1), date(2024, 2, 29), date(2262, 4, 11), @@ -1400,8 +1400,13 @@ def test_literal_from_date( @pytest.mark.parametrize( "value", [ - datetime(1677, 9, 22), - datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")), + # Very old dates with a timezone like EST caused problems for the CI due + # to the IANA timezone database updating their historical offset, so + # these have been disabled for now. A mismatch between the timezone + # database that chrono_tz crate uses vs. the one that Python uses (which + # differs from platform to platform) will cause this to fail. 
+ # datetime(1677, 9, 22), + # datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")), datetime(1970, 1, 1), datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")), datetime(2024, 2, 29), From 503582e1a39d0281f342e078a2b1c04702e4368b Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Thu, 26 Sep 2024 13:29:32 +0200 Subject: [PATCH 07/33] fix: Properly choose inner physical type for Array (#18942) --- .../polars-core/src/chunked_array/ops/full.rs | 23 +++++++++++++++---- py-polars/tests/unit/datatypes/test_array.py | 15 ++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index ee307cc3ca8e..e33d38118891 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -128,14 +128,21 @@ impl ArrayChunked { ArrowDataType::FixedSizeList( Box::new(ArrowField::new( PlSmallStr::from_static("item"), - inner_dtype.to_arrow(CompatLevel::newest()), + inner_dtype.to_physical().to_arrow(CompatLevel::newest()), true, )), width, ), length, ); - ChunkedArray::with_chunk(name, arr) + // SAFETY: physical type matches the logical. + unsafe { + ChunkedArray::from_chunks_and_dtype( + name, + vec![Box::new(arr)], + DataType::Array(Box::new(inner_dtype.clone()), width), + ) + } } } @@ -147,14 +154,22 @@ impl ChunkFull<&Series> for ArrayChunked { let arrow_dtype = ArrowDataType::FixedSizeList( Box::new(ArrowField::new( PlSmallStr::from_static("item"), - dtype.to_arrow(CompatLevel::newest()), + dtype.to_physical().to_arrow(CompatLevel::newest()), true, )), width, ); let value = value.rechunk().chunks()[0].clone(); let arr = FixedSizeListArray::full(length, value, arrow_dtype); - ChunkedArray::with_chunk(name, arr) + + // SAFETY: physical type matches the logical. 
+ unsafe { + ChunkedArray::from_chunks_and_dtype( + name, + vec![Box::new(arr)], + DataType::Array(Box::new(dtype.clone()), width), + ) + } } } diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index 03f92bd68d11..6c4f240803bf 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -327,3 +327,18 @@ def test_array_inner_recursive_python_dtype() -> None: def test_array_missing_shape() -> None: with pytest.raises(TypeError): pl.Array(pl.Int8) + + +def test_array_invalid_physical_type_18920() -> None: + s1 = pl.Series("x", [[1000, 2000]], pl.List(pl.Datetime)) + s2 = pl.Series("x", [None], pl.List(pl.Datetime)) + + df1 = s1.to_frame().with_columns(pl.col.x.list.to_array(2)) + df2 = s2.to_frame().with_columns(pl.col.x.list.to_array(2)) + + df = pl.concat([df1, df2]) + + expected_s = pl.Series("x", [[1000, 2000], None], pl.List(pl.Datetime)) + + expected = expected_s.to_frame().with_columns(pl.col.x.list.to_array(2)) + assert_frame_equal(df, expected) From 71a8b0543362aadb771c2a446ac93756ea36a124 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 26 Sep 2024 14:00:34 +0200 Subject: [PATCH 08/33] fix: Incorrect mode for sorted input (#18945) --- crates/polars-ops/src/chunked_array/mode.rs | 31 ++++++++------------- py-polars/tests/unit/series/test_series.py | 3 +- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs index a36b161775ca..1c981c9eb3e2 100644 --- a/crates/polars-ops/src/chunked_array/mode.rs +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -1,4 +1,3 @@ -use arrow::legacy::utils::CustomIterTools; use polars_core::prelude::*; use polars_core::{with_match_physical_integer_polars_type, POOL}; @@ -33,29 +32,23 @@ fn mode_64(ca: &Float64Chunked) -> PolarsResult { fn mode_indices(groups: GroupsProxy) -> Vec { match groups { GroupsProxy::Idx(groups) 
=> { - let mut groups = groups.into_iter().collect_trusted::>(); - groups.sort_unstable_by_key(|k| k.1.len()); - let last = &groups.last().unwrap(); - let max_occur = last.1.len(); + let Some(max_len) = groups.iter().map(|g| g.1.len()).max() else { + return Vec::new(); + }; groups - .iter() - .rev() - .take_while(|v| v.1.len() == max_occur) - .map(|v| v.0) + .into_iter() + .filter(|g| g.1.len() == max_len) + .map(|g| g.0) .collect() }, GroupsProxy::Slice { groups, .. } => { - let last = groups.last().unwrap(); - let max_occur = last[1]; - + let Some(max_len) = groups.iter().map(|g| g[1]).max() else { + return Vec::new(); + }; groups - .iter() - .rev() - .take_while(|v| { - let len = v[1]; - len == max_occur - }) - .map(|v| v[0]) + .into_iter() + .filter(|g| g[1] == max_len) + .map(|g| g[0]) .collect() }, } diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 2a4d6f1d5285..0f45cdfb6e21 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -980,6 +980,7 @@ def test_reinterpret() -> None: def test_mode() -> None: s = pl.Series("a", [1, 1, 2]) assert s.mode().to_list() == [1] + assert s.set_sorted().mode().to_list() == [1] df = pl.DataFrame([s]) assert df.select([pl.col("a").mode()])["a"].to_list() == [1] @@ -990,7 +991,7 @@ def test_mode() -> None: assert pl.Series([1.0, 2.0, 3.0, 2.0]).mode().item() == 2.0 # sorted data - assert pl.int_range(0, 3, eager=True).mode().to_list() == [2, 1, 0] + assert set(pl.int_range(0, 3, eager=True).mode().to_list()) == {0, 1, 2} def test_diff() -> None: From bb214d6ca72348c3263d2629d2f4067fc3ade2fc Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 26 Sep 2024 20:46:41 +0200 Subject: [PATCH 09/33] refactor: Another set of new-stream test skip/fixes (#18952) --- .../src/plans/conversion/expr_expansion.rs | 2 +- .../src/physical_plan/lower_expr.rs | 2 +- .../operations/namespaces/test_categorical.py | 20 +++++-- 
.../tests/unit/operations/test_transpose.py | 9 +++ .../tests/unit/operations/test_unpivot.py | 56 +++++++++---------- .../tests/unit/sql/test_wildcard_opts.py | 2 +- py-polars/tests/unit/test_datatypes.py | 1 + 7 files changed, 53 insertions(+), 39 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index b17db4c728d0..4d1aa76caff5 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -643,7 +643,7 @@ fn find_flags(expr: &Expr) -> PolarsResult { #[cfg(feature = "dtype-struct")] fn toggle_cse(opt_flags: &mut OptFlags) { - if opt_flags.contains(OptFlags::EAGER) { + if opt_flags.contains(OptFlags::EAGER) && !opt_flags.contains(OptFlags::NEW_STREAMING) { #[cfg(debug_assertions)] { use polars_core::config::verbose; diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index 919694e8c538..39493af054c2 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -348,7 +348,7 @@ fn build_fallback_node_with_ctx( expr, Context::Default, ctx.expr_arena, - None, + Some(&ctx.phys_sm[input_node].output_schema), &mut conv_state, ) }) diff --git a/py-polars/tests/unit/operations/namespaces/test_categorical.py b/py-polars/tests/unit/operations/namespaces/test_categorical.py index 708abf7eed4d..3e491894c18e 100644 --- a/py-polars/tests/unit/operations/namespaces/test_categorical.py +++ b/py-polars/tests/unit/operations/namespaces/test_categorical.py @@ -1,3 +1,5 @@ +import pytest + import polars as pl from polars.testing import assert_frame_equal @@ -58,20 +60,26 @@ def test_categorical_lexical_ordering_after_concat() -> None: } -def test_sort_categoricals_6014() -> None: +@pytest.mark.may_fail_auto_streaming +def test_sort_categoricals_6014_internal() -> None: with 
pl.StringCache(): # create basic categorical - df1 = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( + df = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( pl.col("key").cast(pl.Categorical) ) + + out = df.sort("key") + assert out.to_dict(as_series=False) == {"key": ["bbb", "aaa", "ccc"]} + + +def test_sort_categoricals_6014_lexical() -> None: + with pl.StringCache(): # create lexically-ordered categorical - df2 = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( + df = pl.DataFrame({"key": ["bbb", "aaa", "ccc"]}).with_columns( pl.col("key").cast(pl.Categorical("lexical")) ) - out = df1.sort("key") - assert out.to_dict(as_series=False) == {"key": ["bbb", "aaa", "ccc"]} - out = df2.sort("key") + out = df.sort("key") assert out.to_dict(as_series=False) == {"key": ["aaa", "bbb", "ccc"]} diff --git a/py-polars/tests/unit/operations/test_transpose.py b/py-polars/tests/unit/operations/test_transpose.py index ebb58ac4f2bf..a43a6e7f629e 100644 --- a/py-polars/tests/unit/operations/test_transpose.py +++ b/py-polars/tests/unit/operations/test_transpose.py @@ -13,6 +13,7 @@ from polars.testing import assert_frame_equal, assert_series_equal +@pytest.mark.may_fail_auto_streaming def test_transpose_supertype() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": ["foo", "bar", "ham"]}) result = df.transpose() @@ -26,6 +27,7 @@ def test_transpose_supertype() -> None: assert_frame_equal(result, expected) +@pytest.mark.may_fail_auto_streaming def test_transpose_tz_naive_and_tz_aware() -> None: df = pl.DataFrame( { @@ -41,6 +43,7 @@ def test_transpose_tz_naive_and_tz_aware() -> None: df.transpose() +@pytest.mark.may_fail_auto_streaming def test_transpose_struct() -> None: df = pl.DataFrame( { @@ -82,6 +85,7 @@ def test_transpose_struct() -> None: assert_frame_equal(result, expected) +@pytest.mark.may_fail_auto_streaming def test_transpose_arguments() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) expected = pl.DataFrame( @@ -136,6 +140,7 @@ 
def name_generator() -> Iterator[str]: assert_frame_equal(expected, out) +@pytest.mark.may_fail_auto_streaming def test_transpose_categorical_data() -> None: with pl.StringCache(): df = pl.DataFrame( @@ -174,6 +179,7 @@ def test_transpose_categorical_data() -> None: ).transpose() +@pytest.mark.may_fail_auto_streaming def test_transpose_logical_data() -> None: df = pl.DataFrame( { @@ -192,6 +198,7 @@ def test_transpose_logical_data() -> None: assert_frame_equal(result, expected) +@pytest.mark.may_fail_auto_streaming def test_err_transpose_object() -> None: class CustomObject: pass @@ -200,12 +207,14 @@ class CustomObject: pl.DataFrame([CustomObject()]).transpose() +@pytest.mark.may_fail_auto_streaming def test_transpose_name_from_column_13777() -> None: csv_file = io.BytesIO(b"id,kc\nhi,3") df = pl.read_csv(csv_file).transpose(column_names="id") assert_series_equal(df.to_series(0), pl.Series("hi", [3])) +@pytest.mark.may_fail_auto_streaming def test_transpose_multiple_chunks() -> None: df = pl.DataFrame({"a": ["1"]}) expected = pl.DataFrame({"column_0": ["1"], "column_1": ["1"]}) diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py index 7b51d91122dc..ada642c294ae 100644 --- a/py-polars/tests/unit/operations/test_unpivot.py +++ b/py-polars/tests/unit/operations/test_unpivot.py @@ -7,48 +7,44 @@ def test_unpivot() -> None: df = pl.DataFrame({"A": ["a", "b", "c"], "B": [1, 3, 5], "C": [2, 4, 6]}) + expected = { + ("a", "B", 1), + ("b", "B", 3), + ("c", "B", 5), + ("a", "C", 2), + ("b", "C", 4), + ("c", "C", 6), + } for _idv, _vv in (("A", ("B", "C")), (cs.string(), cs.integer())): unpivoted_eager = df.unpivot(index="A", on=["B", "C"]) - assert all(unpivoted_eager["value"] == [1, 3, 5, 2, 4, 6]) + assert set(unpivoted_eager.iter_rows()) == expected - unpivoted_lazy = df.lazy().unpivot(index="A", on=["B", "C"]) - assert all(unpivoted_lazy.collect()["value"] == [1, 3, 5, 2, 4, 6]) + unpivoted_lazy = 
df.lazy().unpivot(index="A", on=["B", "C"]).collect() + assert set(unpivoted_lazy.iter_rows()) == expected unpivoted = df.unpivot(index="A", on="B") - assert all(unpivoted["value"] == [1, 3, 5]) - n = 3 - + assert set(unpivoted["value"]) == {1, 3, 5} + + expected_full = { + ("A", "a"), + ("A", "b"), + ("A", "c"), + ("B", "1"), + ("B", "3"), + ("B", "5"), + ("C", "2"), + ("C", "4"), + ("C", "6"), + } for unpivoted in [df.unpivot(), df.lazy().unpivot().collect()]: - assert unpivoted["variable"].to_list() == ["A"] * n + ["B"] * n + ["C"] * n - assert unpivoted["value"].to_list() == [ - "a", - "b", - "c", - "1", - "3", - "5", - "2", - "4", - "6", - ] + assert set(unpivoted.iter_rows()) == expected_full with pytest.deprecated_call(match="unpivot"): for unpivoted in [ df.melt(value_name="foo", variable_name="bar"), df.lazy().melt(value_name="foo", variable_name="bar").collect(), ]: - assert unpivoted["bar"].to_list() == ["A"] * n + ["B"] * n + ["C"] * n - assert unpivoted["foo"].to_list() == [ - "a", - "b", - "c", - "1", - "3", - "5", - "2", - "4", - "6", - ] + assert set(unpivoted.iter_rows()) == expected_full def test_unpivot_projection_pd_7747() -> None: diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py index e27ce9ac14b3..c31a55d61829 100644 --- a/py-polars/tests/unit/sql/test_wildcard_opts.py +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -180,6 +180,6 @@ def test_select_wildcard_errors(df: pl.DataFrame) -> None: # note: missing "()" around the exclude option results in dupe col with pytest.raises( DuplicateError, - match="the name 'City' is duplicate", + match="City", ): assert df.sql("SELECT * EXCLUDE Address, City FROM self") diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py index 9bd545125f64..4d604f2964e9 100644 --- a/py-polars/tests/unit/test_datatypes.py +++ b/py-polars/tests/unit/test_datatypes.py @@ -138,6 +138,7 @@ def test_repr(dtype: 
PolarsDataType, representation: str) -> None: assert repr(dtype) == representation +@pytest.mark.may_fail_auto_streaming def test_conversion_dtype() -> None: df = ( pl.DataFrame( From a7845c46d614fad0ccc8264c228fb2660ca3a943 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 08:43:27 +0200 Subject: [PATCH 10/33] feat: Allow for zero-width fixed size lists (#18940) --- .../src/array/fixed_size_list/data.rs | 2 + .../src/array/fixed_size_list/ffi.rs | 14 +- .../src/array/fixed_size_list/mod.rs | 70 +++-- .../src/array/fixed_size_list/mutable.rs | 50 +++- .../src/array/growable/fixed_size_list.rs | 30 ++- crates/polars-arrow/src/array/static_array.rs | 4 +- crates/polars-arrow/src/compute/cast/mod.rs | 33 ++- crates/polars-arrow/src/datatypes/mod.rs | 11 + .../src/io/ipc/read/array/fixed_size_list.rs | 5 +- .../src/legacy/array/fixed_size_list.rs | 7 + crates/polars-arrow/src/legacy/array/mod.rs | 16 +- .../polars-compute/src/comparisons/array.rs | 8 + .../src/chunked_array/array/mod.rs | 3 +- crates/polars-core/src/chunked_array/cast.rs | 2 +- crates/polars-core/src/chunked_array/from.rs | 1 + crates/polars-core/src/datatypes/dtype.rs | 7 +- crates/polars-core/src/frame/column/mod.rs | 78 ++++++ crates/polars-core/src/series/from.rs | 52 ++-- crates/polars-core/src/series/mod.rs | 11 + crates/polars-core/src/series/ops/downcast.rs | 239 +++++++++++++++--- crates/polars-core/src/series/ops/reshape.rs | 105 +++++--- .../src/arrow/read/deserialize/mod.rs | 6 +- .../src/arrow/read/deserialize/nested.rs | 4 +- .../arrow/read/deserialize/nested_utils.rs | 17 +- .../src/arrow/read/statistics/list.rs | 1 + .../polars-plan/src/dsl/function_expr/list.rs | 2 +- .../it/arrow/array/fixed_size_list/mod.rs | 10 +- .../it/arrow/compute/aggregate/memory.rs | 2 +- .../tests/unit/constructors/test_series.py | 11 +- py-polars/tests/unit/datatypes/test_array.py | 41 +++ py-polars/tests/unit/io/test_parquet.py | 13 +- .../tests/unit/operations/test_reshape.py | 38 ++- 32 
files changed, 677 insertions(+), 216 deletions(-) diff --git a/crates/polars-arrow/src/array/fixed_size_list/data.rs b/crates/polars-arrow/src/array/fixed_size_list/data.rs index de9bc1b882c2..c1f353db691a 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/data.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/data.rs @@ -18,6 +18,7 @@ impl Arrow2Arrow for FixedSizeListArray { fn from_data(data: &ArrayData) -> Self { let dtype: ArrowDataType = data.data_type().clone().into(); + let length = data.len() - data.offset(); let size = match dtype { ArrowDataType::FixedSizeList(_, size) => size, _ => unreachable!("must be FixedSizeList type"), @@ -28,6 +29,7 @@ impl Arrow2Arrow for FixedSizeListArray { Self { size, + length, dtype, values, validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), diff --git a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs index 29cf7957cf6c..297d7ae8e5f2 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs @@ -1,4 +1,4 @@ -use polars_error::PolarsResult; +use polars_error::{polars_ensure, PolarsResult}; use super::FixedSizeListArray; use crate::array::ffi::{FromFfi, ToFfi}; @@ -31,11 +31,19 @@ unsafe impl ToFfi for FixedSizeListArray { impl FromFfi for FixedSizeListArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { let dtype = array.dtype().clone(); + let (_, width) = FixedSizeListArray::try_child_and_size(&dtype)?; let validity = unsafe { array.validity() }?; - let child = unsafe { array.child(0)? 
}; + let child = unsafe { array.child(0) }?; let values = ffi::try_from(child)?; - let mut fsl = Self::try_new(dtype, values, validity)?; + let length = if values.len() == 0 { + 0 + } else { + polars_ensure!(width > 0, InvalidOperation: "Zero-width array with values"); + values.len() / width + }; + + let mut fsl = Self::try_new(dtype, length, values, validity)?; fsl.slice(array.offset(), array.length()); Ok(fsl) } diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 37cc7e2ad781..4f1622819813 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -10,7 +10,7 @@ mod iterator; mod mutable; pub use mutable::*; -use polars_error::{polars_bail, PolarsResult}; +use polars_error::{polars_bail, polars_ensure, PolarsResult}; use polars_utils::pl_str::PlSmallStr; /// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. @@ -18,6 +18,7 @@ use polars_utils::pl_str::PlSmallStr; #[derive(Clone)] pub struct FixedSizeListArray { size: usize, // this is redundant with `dtype`, but useful to not have to deconstruct the dtype. + length: usize, // invariant: this is values.len() / size if size > 0 dtype: ArrowDataType, values: Box, validity: Option, @@ -34,6 +35,7 @@ impl FixedSizeListArray { /// * the validity's length is not equal to `values.len() / size`. pub fn try_new( dtype: ArrowDataType, + length: usize, values: Box, validity: Option, ) -> PolarsResult { @@ -45,34 +47,61 @@ impl FixedSizeListArray { polars_bail!(ComputeError: "FixedSizeListArray's child's DataType must match. 
However, the expected DataType is {child_dtype:?} while it got {values_dtype:?}.") } - if values.len() % size != 0 { - polars_bail!(ComputeError: - "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", - values.len(), - size - ) - } - let len = values.len() / size; + polars_ensure!(size == 0 || values.len() % size == 0, ComputeError: + "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", + values.len(), + size + ); + + polars_ensure!(size == 0 || values.len() / size == length, ComputeError: + "length of values ({}) is not equal to given length ({}) in FixedSizeListArray({size}).", + values.len() / size, + length, + ); + polars_ensure!(size != 0 || values.len() == 0, ComputeError: + "zero width FixedSizeListArray has values (length = {}).", + values.len(), + ); if validity .as_ref() - .map_or(false, |validity| validity.len() != len) + .map_or(false, |validity| validity.len() != length) { polars_bail!(ComputeError: "validity mask length must be equal to the number of values divided by size") } Ok(Self { size, + length, dtype, values, validity, }) } + #[inline] + fn has_invariants(&self) -> bool { + let has_valid_length = (self.size == 0 && self.values().len() == 0) + || (self.size > 0 + && self.values().len() % self.size() == 0 + && self.values().len() / self.size() == self.length); + let has_valid_validity = self + .validity + .as_ref() + .map_or(true, |v| v.len() == self.length); + + has_valid_length && has_valid_validity + } + /// Alias to `Self::try_new(...).unwrap()` #[track_caller] - pub fn new(dtype: ArrowDataType, values: Box, validity: Option) -> Self { - Self::try_new(dtype, values, validity).unwrap() + pub fn new( + dtype: ArrowDataType, + length: usize, + values: Box, + validity: Option, + ) -> Self { + Self::try_new(dtype, length, values, validity).unwrap() } /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. 
@@ -83,7 +112,7 @@ impl FixedSizeListArray { /// Returns a new empty [`FixedSizeListArray`]. pub fn new_empty(dtype: ArrowDataType) -> Self { let values = new_empty_array(Self::get_child_and_size(&dtype).0.dtype().clone()); - Self::new(dtype, values, None) + Self::new(dtype, 0, values, None) } /// Returns a new null [`FixedSizeListArray`]. @@ -91,7 +120,7 @@ impl FixedSizeListArray { let (field, size) = Self::get_child_and_size(&dtype); let values = new_null_array(field.dtype().clone(), length * size); - Self::new(dtype, values, Some(Bitmap::new_zeroed(length))) + Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length))) } } @@ -124,6 +153,7 @@ impl FixedSizeListArray { .filter(|bitmap| bitmap.unset_bits() > 0); self.values .slice_unchecked(offset * self.size, length * self.size); + self.length = length; } impl_sliced!(); @@ -136,7 +166,8 @@ impl FixedSizeListArray { /// Returns the length of this array #[inline] pub fn len(&self) -> usize { - self.values.len() / self.size + debug_assert!(self.has_invariants()); + self.length } /// The optional validity. 
@@ -184,12 +215,7 @@ impl FixedSizeListArray { impl FixedSizeListArray { pub(crate) fn try_child_and_size(dtype: &ArrowDataType) -> PolarsResult<(&Field, usize)> { match dtype.to_logical_type() { - ArrowDataType::FixedSizeList(child, size) => { - if *size == 0 { - polars_bail!(ComputeError: "FixedSizeBinaryArray expects a positive size") - } - Ok((child.as_ref(), *size)) - }, + ArrowDataType::FixedSizeList(child, size) => Ok((child.as_ref(), *size)), _ => polars_bail!(ComputeError: "FixedSizeListArray expects DataType::FixedSizeList"), } } @@ -233,12 +259,14 @@ impl Splitable for FixedSizeListArray { ( Self { dtype: self.dtype.clone(), + length: offset, values: lhs_values, validity: lhs_validity, size, }, Self { dtype: self.dtype.clone(), + length: self.length - offset, values: rhs_values, validity: rhs_validity, size, diff --git a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs index 04802e59bd67..b3a32be0802c 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs @@ -14,6 +14,7 @@ use crate::datatypes::{ArrowDataType, Field}; pub struct MutableFixedSizeListArray { dtype: ArrowDataType, size: usize, + length: usize, values: M, validity: Option, } @@ -22,6 +23,7 @@ impl From> for FixedSizeListArray fn from(mut other: MutableFixedSizeListArray) -> Self { FixedSizeListArray::new( other.dtype, + other.length, other.values.as_box(), other.validity.map(|x| x.into()), ) @@ -53,12 +55,19 @@ impl MutableFixedSizeListArray { }; Self { size, + length: 0, dtype, values, validity: None, } } + #[inline] + fn has_valid_invariants(&self) -> bool { + (self.size == 0 && self.values().len() == 0) + || (self.size > 0 && self.values.len() / self.size == self.length) + } + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. 
pub const fn size(&self) -> usize { self.size @@ -66,7 +75,8 @@ impl MutableFixedSizeListArray { /// The length of this array pub fn len(&self) -> usize { - self.values.len() / self.size + debug_assert!(self.has_valid_invariants()); + self.length } /// The inner values @@ -74,11 +84,6 @@ impl MutableFixedSizeListArray { &self.values } - /// The values as a mutable reference - pub fn mut_values(&mut self) -> &mut M { - &mut self.values - } - fn init_validity(&mut self) { let len = self.values.len() / self.size; @@ -98,6 +103,10 @@ impl MutableFixedSizeListArray { if let Some(validity) = &mut self.validity { validity.push(true) } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); + Ok(()) } @@ -108,6 +117,9 @@ impl MutableFixedSizeListArray { if let Some(validity) = &mut self.validity { validity.push(true) } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); } #[inline] @@ -117,6 +129,9 @@ impl MutableFixedSizeListArray { Some(validity) => validity.push(false), None => self.init_validity(), } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); } /// Reserves `additional` slots. 
@@ -138,7 +153,8 @@ impl MutableFixedSizeListArray { impl MutableArray for MutableFixedSizeListArray { fn len(&self) -> usize { - self.values.len() / self.size + debug_assert!(self.has_valid_invariants()); + self.length } fn validity(&self) -> Option<&MutableBitmap> { @@ -148,6 +164,7 @@ impl MutableArray for MutableFixedSizeListArray { fn as_box(&mut self) -> Box { FixedSizeListArray::new( self.dtype.clone(), + self.length, self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), ) @@ -157,6 +174,7 @@ impl MutableArray for MutableFixedSizeListArray { fn as_arc(&mut self) -> Arc { FixedSizeListArray::new( self.dtype.clone(), + self.length, self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), ) @@ -185,6 +203,9 @@ impl MutableArray for MutableFixedSizeListArray { } else { self.init_validity() } + self.length += 1; + + debug_assert!(self.has_valid_invariants()); } fn reserve(&mut self, additional: usize) { @@ -206,6 +227,9 @@ where for items in iter { self.try_push(items)?; } + + debug_assert!(self.has_valid_invariants()); + Ok(()) } } @@ -223,6 +247,9 @@ where } else { self.push_null(); } + + debug_assert!(self.has_valid_invariants()); + Ok(()) } } @@ -243,6 +270,8 @@ where } else { self.push_null(); } + + debug_assert!(self.has_valid_invariants()); } } @@ -253,6 +282,11 @@ where fn try_extend_from_self(&mut self, other: &Self) -> PolarsResult<()> { extend_validity(self.len(), &mut self.validity, &other.validity); - self.values.try_extend_from_self(&other.values) + self.values.try_extend_from_self(&other.values)?; + self.length += other.len(); + + debug_assert!(self.has_valid_invariants()); + + Ok(()) } } diff --git a/crates/polars-arrow/src/array/growable/fixed_size_list.rs b/crates/polars-arrow/src/array/growable/fixed_size_list.rs index c15202084006..5fedb9a4d254 100644 --- a/crates/polars-arrow/src/array/growable/fixed_size_list.rs +++ b/crates/polars-arrow/src/array/growable/fixed_size_list.rs @@ -6,7 +6,6 @@ use 
super::{make_growable, Growable}; use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity}; use crate::array::{Array, FixedSizeListArray}; use crate::bitmap::MutableBitmap; -use crate::datatypes::ArrowDataType; /// Concrete [`Growable`] for the [`FixedSizeListArray`]. pub struct GrowableFixedSizeList<'a> { @@ -14,6 +13,7 @@ pub struct GrowableFixedSizeList<'a> { validity: Option, values: Box + 'a>, size: usize, + length: usize, } impl<'a> GrowableFixedSizeList<'a> { @@ -33,24 +33,25 @@ impl<'a> GrowableFixedSizeList<'a> { use_validity = true; }; - let size = - if let ArrowDataType::FixedSizeList(_, size) = &arrays[0].dtype().to_logical_type() { - *size - } else { - unreachable!("`GrowableFixedSizeList` expects `DataType::FixedSizeList`") - }; + let size = arrays[0].size(); let inner = arrays .iter() - .map(|array| array.values().as_ref()) + .map(|array| { + debug_assert_eq!(array.size(), size); + array.values().as_ref() + }) .collect::>(); let values = make_growable(&inner, use_validity, 0); + assert_eq!(values.len(), 0); + Self { arrays, values, validity: prepare_validity(use_validity, capacity), size, + length: 0, } } @@ -60,6 +61,7 @@ impl<'a> GrowableFixedSizeList<'a> { FixedSizeListArray::new( self.arrays[0].dtype().clone(), + self.length, values, validity.map(|v| v.into()), ) @@ -71,16 +73,24 @@ impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { let array = *self.arrays.get_unchecked_release(index); extend_validity(&mut self.validity, array, start, len); + self.length += len; + let start_length = self.values.len(); self.values .extend(index, start * self.size, len * self.size); + debug_assert!(self.size == 0 || (self.values.len() - start_length) / self.size == len); } unsafe fn extend_copies(&mut self, index: usize, start: usize, len: usize, copies: usize) { let array = *self.arrays.get_unchecked_release(index); extend_validity_copies(&mut self.validity, array, start, len, copies); + self.length += len * copies; + let 
start_length = self.values.len(); self.values .extend_copies(index, start * self.size, len * self.size, copies); + debug_assert!( + self.size == 0 || (self.values.len() - start_length) / self.size == len * copies + ); } fn extend_validity(&mut self, additional: usize) { @@ -88,11 +98,12 @@ impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { if let Some(validity) = &mut self.validity { validity.extend_constant(additional, false); } + self.length += additional; } #[inline] fn len(&self) -> usize { - self.values.len() / self.size + self.length } fn as_arc(&mut self) -> Arc { @@ -111,6 +122,7 @@ impl<'a> From> for FixedSizeListArray { Self::new( val.arrays[0].dtype().clone(), + val.length, values, val.validity.map(|v| v.into()), ) diff --git a/crates/polars-arrow/src/array/static_array.rs b/crates/polars-arrow/src/array/static_array.rs index 3cfbc870e141..a79b6a909fe6 100644 --- a/crates/polars-arrow/src/array/static_array.rs +++ b/crates/polars-arrow/src/array/static_array.rs @@ -1,8 +1,8 @@ use bytemuck::Zeroable; use polars_utils::no_call_const; +use super::growable::{Growable, GrowableFixedSizeList}; use crate::array::binview::BinaryViewValueIter; -use crate::array::growable::{Growable, GrowableFixedSizeList}; use crate::array::static_array_collect::ArrayFromIterDtype; use crate::array::{ Array, ArrayValuesIter, BinaryArray, BinaryValueIter, BinaryViewArray, BooleanArray, @@ -394,7 +394,7 @@ impl StaticArray for FixedSizeListArray { } fn full(length: usize, value: Self::ValueT<'_>, dtype: ArrowDataType) -> Self { - let singular_arr = FixedSizeListArray::new(dtype, value, None); + let singular_arr = FixedSizeListArray::new(dtype, 1, value, None); let mut arr = GrowableFixedSizeList::new(vec![&singular_arr], false, length); unsafe { arr.extend_copies(0, 0, 1, length) } arr.into() diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 4dd8857b95c3..27f93eb07356 100644 --- 
a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -174,25 +174,20 @@ fn cast_list_to_fixed_size_list( ) -> PolarsResult { let null_cnt = list.null_count(); let new_values = if null_cnt == 0 { - let offsets = list.offsets().buffer().iter(); - let expected = - (list.offsets().first().to_usize()..list.len()).map(|ix| O::from_as_usize(ix * size)); - - match offsets - .zip(expected) - .find(|(actual, expected)| *actual != expected) - { - Some(_) => polars_bail!(ComputeError: - "not all elements have the specified width {size}" - ), - None => { - let sliced_values = list.values().sliced( - list.offsets().first().to_usize(), - list.offsets().range().to_usize(), - ); - cast(sliced_values.as_ref(), inner.dtype(), options)? - }, + let start_offset = list.offsets().first().to_usize(); + let offsets = list.offsets().buffer(); + + let mut is_valid = true; + for (i, offset) in offsets.iter().enumerate() { + is_valid &= offset.to_usize() == start_offset + i * size; } + + polars_ensure!(is_valid, ComputeError: "not all elements have the specified width {size}"); + + let sliced_values = list + .values() + .sliced(start_offset, list.offsets().range().to_usize()); + cast(sliced_values.as_ref(), inner.dtype(), options)? } else { let offsets = list.offsets().as_slice(); // Check the lengths of each list are equal to the fixed size. @@ -232,8 +227,10 @@ fn cast_list_to_fixed_size_list( cast(take_values.as_ref(), inner.dtype(), options)? 
}; + FixedSizeListArray::try_new( ArrowDataType::FixedSizeList(Box::new(inner.clone()), size), + list.len(), new_values, list.validity().cloned(), ) diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index 6ef9687f146e..8f2226c709e6 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -567,6 +567,17 @@ impl ArrowDataType { pub fn is_view(&self) -> bool { matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) } + + pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType { + ArrowDataType::FixedSizeList( + Box::new(Field::new( + PlSmallStr::from_static("item"), + self, + is_nullable, + )), + size, + ) + } } impl From for ArrowDataType { diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs index eac68f9fda54..fdfa13574e3d 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -1,7 +1,7 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; -use polars_error::{polars_err, PolarsResult}; +use polars_error::{polars_ensure, polars_err, PolarsResult}; use super::super::super::IpcField; use super::super::deserialize::{read, skip}; @@ -41,6 +41,7 @@ pub fn read_fixed_size_list( )?; let (field, size) = FixedSizeListArray::get_child_and_size(&dtype); + polars_ensure!(size > 0, nyi = "Cannot read zero sized arrays from IPC"); let limit = limit.map(|x| x.saturating_mul(size)); @@ -59,7 +60,7 @@ pub fn read_fixed_size_list( version, scratch, )?; - FixedSizeListArray::try_new(dtype, values, validity) + FixedSizeListArray::try_new(dtype, values.len() / size, values, validity) } pub fn skip_fixed_size_list( diff --git a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs index 99382b0b6407..813f357e2137 100644 
--- a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs @@ -10,6 +10,7 @@ use crate::legacy::kernels::concatenate::concatenate_owned_unchecked; pub struct AnonymousBuilder { arrays: Vec, validity: Option, + length: usize, pub width: usize, } @@ -19,6 +20,7 @@ impl AnonymousBuilder { arrays: Vec::with_capacity(capacity), validity: None, width, + length: 0, } } pub fn is_empty(&self) -> bool { @@ -32,6 +34,8 @@ impl AnonymousBuilder { if let Some(validity) = &mut self.validity { validity.push(true) } + + self.length += 1; } pub fn push_null(&mut self) { @@ -41,6 +45,8 @@ impl AnonymousBuilder { Some(validity) => validity.push(false), None => self.init_validity(), } + + self.length += 1; } fn init_validity(&mut self) { @@ -82,6 +88,7 @@ impl AnonymousBuilder { let dtype = FixedSizeListArray::default_datatype(inner_dtype.clone(), self.width); Ok(FixedSizeListArray::new( dtype, + self.length, values, self.validity.map(|validity| validity.into()), )) diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs index bbb876283470..f15ac1811f96 100644 --- a/crates/polars-arrow/src/legacy/array/mod.rs +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -212,11 +212,23 @@ pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { + let width = *width; + let array = array.as_any().downcast_ref::().unwrap(); let inner = array.values(); + let length = if width == array.size() { + array.len() + } else { + assert!(array.values().len() > 0 || width != 0); + if width == 0 { + 0 + } else { + array.values().len() / width + } + }; let new_values = convert_inner_type(inner.as_ref(), field.dtype()); - let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), *width); - FixedSizeListArray::new(dtype, new_values, array.validity().cloned()).boxed() + let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), width); + 
FixedSizeListArray::new(dtype, length, new_values, array.validity().cloned()).boxed() }, ArrowDataType::Struct(fields) => { let array = array.as_any().downcast_ref::().unwrap(); diff --git a/crates/polars-compute/src/comparisons/array.rs b/crates/polars-compute/src/comparisons/array.rs index da120f27553b..23d43887a280 100644 --- a/crates/polars-compute/src/comparisons/array.rs +++ b/crates/polars-compute/src/comparisons/array.rs @@ -48,6 +48,10 @@ impl TotalEqKernel for FixedSizeListArray { return Bitmap::new_with_value(false, self.len()); } + if *self_width == 0 { + return Bitmap::new_with_value(true, self.len()); + } + let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref()); agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0) @@ -69,6 +73,10 @@ impl TotalEqKernel for FixedSizeListArray { return Bitmap::new_with_value(true, self.len()); } + if *self_width == 0 { + return Bitmap::new_with_value(false, self.len()); + } + let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref()); agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size()) diff --git a/crates/polars-core/src/chunked_array/array/mod.rs b/crates/polars-core/src/chunked_array/array/mod.rs index 59bdd92b67cc..3e0e47a7e86a 100644 --- a/crates/polars-core/src/chunked_array/array/mod.rs +++ b/crates/polars-core/src/chunked_array/array/mod.rs @@ -74,7 +74,8 @@ impl ArrayChunked { out.dtype().to_arrow(CompatLevel::newest()), ca.width(), ); - let arr = FixedSizeListArray::new(inner_dtype, values, arr.validity().cloned()); + let arr = + FixedSizeListArray::new(inner_dtype, arr.len(), values, arr.validity().cloned()); Ok(arr) }); diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index ea758742169e..1b8228d4ea69 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -664,7 +664,7 @@ fn cast_fixed_size_list( let new_values = 
new_inner.array_ref(0).clone(); let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), ca.width()); - let new_arr = FixedSizeListArray::new(dtype, new_values, arr.validity().cloned()); + let new_arr = FixedSizeListArray::new(dtype, ca.len(), new_values, arr.validity().cloned()); Ok((Box::new(new_arr), inner_dtype)) } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index bf5c748eeed1..33e984b94e0f 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -71,6 +71,7 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy let arrow_dtype = FixedSizeListArray::default_datatype(ArrowDataType::UInt32, width); let new_array = FixedSizeListArray::new( arrow_dtype, + values_arr.len(), cat.array_ref(0).clone(), list_arr.validity().cloned(), ); diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 956d055a52c2..cd79349bfcd8 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -591,10 +591,9 @@ impl DataType { Duration(unit) => Ok(ArrowDataType::Duration(unit.to_arrow())), Time => Ok(ArrowDataType::Time64(ArrowTimeUnit::Nanosecond)), #[cfg(feature = "dtype-array")] - Array(dt, size) => Ok(ArrowDataType::FixedSizeList( - Box::new(dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level)), - *size, - )), + Array(dt, size) => Ok(dt + .try_to_arrow(compat_level)? 
+ .to_fixed_size_list(*size, true)), List(dt) => Ok(ArrowDataType::LargeList(Box::new( dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level), ))), diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 727faf0768c8..1dea44ee393a 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -189,6 +189,84 @@ impl Column { } } + // # Try to Chunked Arrays + pub fn try_bool(&self) -> Option<&BooleanChunked> { + self.as_materialized_series().try_bool() + } + pub fn try_i8(&self) -> Option<&Int8Chunked> { + self.as_materialized_series().try_i8() + } + pub fn try_i16(&self) -> Option<&Int16Chunked> { + self.as_materialized_series().try_i16() + } + pub fn try_i32(&self) -> Option<&Int32Chunked> { + self.as_materialized_series().try_i32() + } + pub fn try_i64(&self) -> Option<&Int64Chunked> { + self.as_materialized_series().try_i64() + } + pub fn try_u8(&self) -> Option<&UInt8Chunked> { + self.as_materialized_series().try_u8() + } + pub fn try_u16(&self) -> Option<&UInt16Chunked> { + self.as_materialized_series().try_u16() + } + pub fn try_u32(&self) -> Option<&UInt32Chunked> { + self.as_materialized_series().try_u32() + } + pub fn try_u64(&self) -> Option<&UInt64Chunked> { + self.as_materialized_series().try_u64() + } + pub fn try_f32(&self) -> Option<&Float32Chunked> { + self.as_materialized_series().try_f32() + } + pub fn try_f64(&self) -> Option<&Float64Chunked> { + self.as_materialized_series().try_f64() + } + pub fn try_str(&self) -> Option<&StringChunked> { + self.as_materialized_series().try_str() + } + pub fn try_list(&self) -> Option<&ListChunked> { + self.as_materialized_series().try_list() + } + pub fn try_binary(&self) -> Option<&BinaryChunked> { + self.as_materialized_series().try_binary() + } + pub fn try_idx(&self) -> Option<&IdxCa> { + self.as_materialized_series().try_idx() + } + pub fn try_binary_offset(&self) -> Option<&BinaryOffsetChunked> { + 
self.as_materialized_series().try_binary_offset() + } + #[cfg(feature = "dtype-datetime")] + pub fn try_datetime(&self) -> Option<&DatetimeChunked> { + self.as_materialized_series().try_datetime() + } + #[cfg(feature = "dtype-struct")] + pub fn try_struct(&self) -> Option<&StructChunked> { + self.as_materialized_series().try_struct() + } + #[cfg(feature = "dtype-decimal")] + pub fn try_decimal(&self) -> Option<&DecimalChunked> { + self.as_materialized_series().try_decimal() + } + #[cfg(feature = "dtype-array")] + pub fn try_array(&self) -> Option<&ArrayChunked> { + self.as_materialized_series().try_array() + } + #[cfg(feature = "dtype-categorical")] + pub fn try_categorical(&self) -> Option<&CategoricalChunked> { + self.as_materialized_series().try_categorical() + } + #[cfg(feature = "dtype-date")] + pub fn try_date(&self) -> Option<&DateChunked> { + self.as_materialized_series().try_date() + } + #[cfg(feature = "dtype-duration")] + pub fn try_duration(&self) -> Option<&DurationChunked> { + self.as_materialized_series().try_duration() + } + // # To Chunked Arrays pub fn bool(&self) -> PolarsResult<&BooleanChunked> { self.as_materialized_series().bool() diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index ce473a4d60fb..7f61f99895f4 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -517,37 +517,33 @@ unsafe fn to_physical_and_dtype( to_physical_and_dtype(out, md) }, #[cfg(feature = "dtype-array")] - #[allow(unused_variables)] ArrowDataType::FixedSizeList(field, size) => { - feature_gated!("dtype-array", { - let values = arrays - .iter() - .map(|arr| { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.values().clone() - }) - .collect::>(); + let values = arrays + .iter() + .map(|arr| { + let arr = arr.as_any().downcast_ref::().unwrap(); + arr.values().clone() + }) + .collect::>(); - let (converted_values, dtype) = - to_physical_and_dtype(values, Some(&field.metadata)); 
+ let (converted_values, dtype) = to_physical_and_dtype(values, Some(&field.metadata)); - let arrays = arrays - .iter() - .zip(converted_values) - .map(|(arr, values)| { - let arr = arr.as_any().downcast_ref::().unwrap(); - - let dtype = - FixedSizeListArray::default_datatype(values.dtype().clone(), *size); - Box::from(FixedSizeListArray::new( - dtype, - values, - arr.validity().cloned(), - )) as ArrayRef - }) - .collect(); - (arrays, DataType::Array(Box::new(dtype), *size)) - }) + let arrays = arrays + .iter() + .zip(converted_values) + .map(|(arr, values)| { + let arr = arr.as_any().downcast_ref::().unwrap(); + + let dtype = FixedSizeListArray::default_datatype(values.dtype().clone(), *size); + Box::from(FixedSizeListArray::new( + dtype, + arr.len(), + values, + arr.validity().cloned(), + )) as ArrayRef + }) + .collect(); + (arrays, DataType::Array(Box::new(dtype), *size)) }, ArrowDataType::LargeList(field) => { let values = arrays diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index a41c5822283c..ce9bcffba2f0 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -825,6 +825,17 @@ impl Series { unsafe { Ok(self.take_unchecked(&idx)) } } + pub fn try_idx(&self) -> Option<&IdxCa> { + #[cfg(feature = "bigidx")] + { + self.try_u64() + } + #[cfg(not(feature = "bigidx"))] + { + self.try_u32() + } + } + pub fn idx(&self) -> PolarsResult<&IdxCa> { #[cfg(feature = "bigidx")] { diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs index ce57e42c610c..2189fc319b5e 100644 --- a/crates/polars-core/src/series/ops/downcast.rs +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -1,36 +1,190 @@ use crate::prelude::*; use crate::series::implementations::null::NullChunked; -macro_rules! unpack_chunked { - ($series:expr, $expected:pat => $ca:ty, $name:expr) => { +macro_rules! 
unpack_chunked_err { + ($series:expr => $name:expr) => { + polars_err!(SchemaMismatch: "invalid series dtype: expected `{}`, got `{}`", $name, $series.dtype()) + }; +} + +macro_rules! try_unpack_chunked { + ($series:expr, $expected:pat => $ca:ty) => { match $series.dtype() { $expected => { // Check downcast in debug compiles #[cfg(debug_assertions)] { - Ok($series.as_ref().as_any().downcast_ref::<$ca>().unwrap()) + Some($series.as_ref().as_any().downcast_ref::<$ca>().unwrap()) } #[cfg(not(debug_assertions))] unsafe { - Ok(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) + Some(&*($series.as_ref() as *const dyn SeriesTrait as *const $ca)) } }, - dt => polars_bail!( - SchemaMismatch: "invalid series dtype: expected `{}`, got `{}`", $name, dt, - ), + _ => None, } }; } impl Series { + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int8]` + pub fn try_i8(&self) -> Option<&Int8Chunked> { + try_unpack_chunked!(self, DataType::Int8 => Int8Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int16]` + pub fn try_i16(&self) -> Option<&Int16Chunked> { + try_unpack_chunked!(self, DataType::Int16 => Int16Chunked) + } + + /// Unpack to [`ChunkedArray`] + /// ``` + /// # use polars_core::prelude::*; + /// let s = Series::new("foo".into(), [1i32 ,2, 3]); + /// let s_squared: Series = s.i32() + /// .unwrap() + /// .into_iter() + /// .map(|opt_v| { + /// match opt_v { + /// Some(v) => Some(v * v), + /// None => None, // null value + /// } + /// }).collect(); + /// ``` + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int32]` + pub fn try_i32(&self) -> Option<&Int32Chunked> { + try_unpack_chunked!(self, DataType::Int32 => Int32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int64]` + pub fn try_i64(&self) -> Option<&Int64Chunked> { + try_unpack_chunked!(self, DataType::Int64 => Int64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float32]` + pub fn try_f32(&self) -> Option<&Float32Chunked> { + 
try_unpack_chunked!(self, DataType::Float32 => Float32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float64]` + pub fn try_f64(&self) -> Option<&Float64Chunked> { + try_unpack_chunked!(self, DataType::Float64 => Float64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt8]` + pub fn try_u8(&self) -> Option<&UInt8Chunked> { + try_unpack_chunked!(self, DataType::UInt8 => UInt8Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt16]` + pub fn try_u16(&self) -> Option<&UInt16Chunked> { + try_unpack_chunked!(self, DataType::UInt16 => UInt16Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt32]` + pub fn try_u32(&self) -> Option<&UInt32Chunked> { + try_unpack_chunked!(self, DataType::UInt32 => UInt32Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt64]` + pub fn try_u64(&self) -> Option<&UInt64Chunked> { + try_unpack_chunked!(self, DataType::UInt64 => UInt64Chunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Boolean]` + pub fn try_bool(&self) -> Option<&BooleanChunked> { + try_unpack_chunked!(self, DataType::Boolean => BooleanChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::String]` + pub fn try_str(&self) -> Option<&StringChunked> { + try_unpack_chunked!(self, DataType::String => StringChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` + pub fn try_binary(&self) -> Option<&BinaryChunked> { + try_unpack_chunked!(self, DataType::Binary => BinaryChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` + pub fn try_binary_offset(&self) -> Option<&BinaryOffsetChunked> { + try_unpack_chunked!(self, DataType::BinaryOffset => BinaryOffsetChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Time]` + #[cfg(feature = "dtype-time")] + pub fn try_time(&self) -> Option<&TimeChunked> { + try_unpack_chunked!(self, DataType::Time => TimeChunked) + } + + /// Unpack to 
[`ChunkedArray`] of dtype `[DataType::Date]` + #[cfg(feature = "dtype-date")] + pub fn try_date(&self) -> Option<&DateChunked> { + try_unpack_chunked!(self, DataType::Date => DateChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Datetime]` + #[cfg(feature = "dtype-datetime")] + pub fn try_datetime(&self) -> Option<&DatetimeChunked> { + try_unpack_chunked!(self, DataType::Datetime(_, _) => DatetimeChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Duration]` + #[cfg(feature = "dtype-duration")] + pub fn try_duration(&self) -> Option<&DurationChunked> { + try_unpack_chunked!(self, DataType::Duration(_) => DurationChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Decimal]` + #[cfg(feature = "dtype-decimal")] + pub fn try_decimal(&self) -> Option<&DecimalChunked> { + try_unpack_chunked!(self, DataType::Decimal(_, _) => DecimalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype list + pub fn try_list(&self) -> Option<&ListChunked> { + try_unpack_chunked!(self, DataType::List(_) => ListChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Array]` + #[cfg(feature = "dtype-array")] + pub fn try_array(&self) -> Option<&ArrayChunked> { + try_unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Categorical]` + #[cfg(feature = "dtype-categorical")] + pub fn try_categorical(&self) -> Option<&CategoricalChunked> { + try_unpack_chunked!(self, DataType::Categorical(_, _) | DataType::Enum(_, _) => CategoricalChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Struct]` + #[cfg(feature = "dtype-struct")] + pub fn try_struct(&self) -> Option<&StructChunked> { + #[cfg(debug_assertions)] + { + if let DataType::Struct(_) = self.dtype() { + let any = self.as_any(); + assert!(any.is::()); + } + } + try_unpack_chunked!(self, DataType::Struct(_) => StructChunked) + } + + /// Unpack to [`ChunkedArray`] of dtype `[DataType::Null]` + 
pub fn try_null(&self) -> Option<&NullChunked> { + try_unpack_chunked!(self, DataType::Null => NullChunked) + } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int8]` pub fn i8(&self) -> PolarsResult<&Int8Chunked> { - unpack_chunked!(self, DataType::Int8 => Int8Chunked, "Int8") + self.try_i8() + .ok_or_else(|| unpack_chunked_err!(self => "Int8")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int16]` pub fn i16(&self) -> PolarsResult<&Int16Chunked> { - unpack_chunked!(self, DataType::Int16 => Int16Chunked, "Int16") + self.try_i16() + .ok_or_else(|| unpack_chunked_err!(self => "Int16")) } /// Unpack to [`ChunkedArray`] @@ -49,109 +203,129 @@ impl Series { /// ``` /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int32]` pub fn i32(&self) -> PolarsResult<&Int32Chunked> { - unpack_chunked!(self, DataType::Int32 => Int32Chunked, "Int32") + self.try_i32() + .ok_or_else(|| unpack_chunked_err!(self => "Int32")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Int64]` pub fn i64(&self) -> PolarsResult<&Int64Chunked> { - unpack_chunked!(self, DataType::Int64 => Int64Chunked, "Int64") + self.try_i64() + .ok_or_else(|| unpack_chunked_err!(self => "Int64")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float32]` pub fn f32(&self) -> PolarsResult<&Float32Chunked> { - unpack_chunked!(self, DataType::Float32 => Float32Chunked, "Float32") + self.try_f32() + .ok_or_else(|| unpack_chunked_err!(self => "Float32")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Float64]` pub fn f64(&self) -> PolarsResult<&Float64Chunked> { - unpack_chunked!(self, DataType::Float64 => Float64Chunked, "Float64") + self.try_f64() + .ok_or_else(|| unpack_chunked_err!(self => "Float64")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt8]` pub fn u8(&self) -> PolarsResult<&UInt8Chunked> { - unpack_chunked!(self, DataType::UInt8 => UInt8Chunked, "UInt8") + self.try_u8() + .ok_or_else(|| unpack_chunked_err!(self => "UInt8")) } /// Unpack to [`ChunkedArray`] of dtype 
`[DataType::UInt16]` pub fn u16(&self) -> PolarsResult<&UInt16Chunked> { - unpack_chunked!(self, DataType::UInt16 => UInt16Chunked, "UInt16") + self.try_u16() + .ok_or_else(|| unpack_chunked_err!(self => "UInt16")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt32]` pub fn u32(&self) -> PolarsResult<&UInt32Chunked> { - unpack_chunked!(self, DataType::UInt32 => UInt32Chunked, "UInt32") + self.try_u32() + .ok_or_else(|| unpack_chunked_err!(self => "UInt32")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::UInt64]` pub fn u64(&self) -> PolarsResult<&UInt64Chunked> { - unpack_chunked!(self, DataType::UInt64 => UInt64Chunked, "UInt64") + self.try_u64() + .ok_or_else(|| unpack_chunked_err!(self => "UInt64")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Boolean]` pub fn bool(&self) -> PolarsResult<&BooleanChunked> { - unpack_chunked!(self, DataType::Boolean => BooleanChunked, "Boolean") + self.try_bool() + .ok_or_else(|| unpack_chunked_err!(self => "Boolean")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::String]` pub fn str(&self) -> PolarsResult<&StringChunked> { - unpack_chunked!(self, DataType::String => StringChunked, "String") + self.try_str() + .ok_or_else(|| unpack_chunked_err!(self => "String")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` pub fn binary(&self) -> PolarsResult<&BinaryChunked> { - unpack_chunked!(self, DataType::Binary => BinaryChunked, "Binary") + self.try_binary() + .ok_or_else(|| unpack_chunked_err!(self => "Binary")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Binary]` pub fn binary_offset(&self) -> PolarsResult<&BinaryOffsetChunked> { - unpack_chunked!(self, DataType::BinaryOffset => BinaryOffsetChunked, "BinaryOffset") + self.try_binary_offset() + .ok_or_else(|| unpack_chunked_err!(self => "BinaryOffset")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Time]` #[cfg(feature = "dtype-time")] pub fn time(&self) -> PolarsResult<&TimeChunked> { - unpack_chunked!(self, 
DataType::Time => TimeChunked, "Time") + self.try_time() + .ok_or_else(|| unpack_chunked_err!(self => "Time")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Date]` #[cfg(feature = "dtype-date")] pub fn date(&self) -> PolarsResult<&DateChunked> { - unpack_chunked!(self, DataType::Date => DateChunked, "Date") + self.try_date() + .ok_or_else(|| unpack_chunked_err!(self => "Date")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Datetime]` #[cfg(feature = "dtype-datetime")] pub fn datetime(&self) -> PolarsResult<&DatetimeChunked> { - unpack_chunked!(self, DataType::Datetime(_, _) => DatetimeChunked, "Datetime") + self.try_datetime() + .ok_or_else(|| unpack_chunked_err!(self => "Datetime")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Duration]` #[cfg(feature = "dtype-duration")] pub fn duration(&self) -> PolarsResult<&DurationChunked> { - unpack_chunked!(self, DataType::Duration(_) => DurationChunked, "Duration") + self.try_duration() + .ok_or_else(|| unpack_chunked_err!(self => "Duration")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Decimal]` #[cfg(feature = "dtype-decimal")] pub fn decimal(&self) -> PolarsResult<&DecimalChunked> { - unpack_chunked!(self, DataType::Decimal(_, _) => DecimalChunked, "Decimal") + self.try_decimal() + .ok_or_else(|| unpack_chunked_err!(self => "Decimal")) } /// Unpack to [`ChunkedArray`] of dtype list pub fn list(&self) -> PolarsResult<&ListChunked> { - unpack_chunked!(self, DataType::List(_) => ListChunked, "List") + self.try_list() + .ok_or_else(|| unpack_chunked_err!(self => "List")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Array]` #[cfg(feature = "dtype-array")] pub fn array(&self) -> PolarsResult<&ArrayChunked> { - unpack_chunked!(self, DataType::Array(_, _) => ArrayChunked, "FixedSizeList") + self.try_array() + .ok_or_else(|| unpack_chunked_err!(self => "FixedSizeList")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Categorical]` #[cfg(feature = "dtype-categorical")] pub fn 
categorical(&self) -> PolarsResult<&CategoricalChunked> { - unpack_chunked!(self, DataType::Categorical(_, _) | DataType::Enum(_, _) => CategoricalChunked, "Enum | Categorical") + self.try_categorical() + .ok_or_else(|| unpack_chunked_err!(self => "Enum | Categorical")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Struct]` @@ -164,11 +338,14 @@ impl Series { assert!(any.is::()); } } - unpack_chunked!(self, DataType::Struct(_) => StructChunked, "Struct") + + self.try_struct() + .ok_or_else(|| unpack_chunked_err!(self => "Struct")) } /// Unpack to [`ChunkedArray`] of dtype `[DataType::Null]` pub fn null(&self) -> PolarsResult<&NullChunked> { - unpack_chunked!(self, DataType::Null => NullChunked, "Null") + self.try_null() + .ok_or_else(|| unpack_chunked_err!(self => "Null")) } } diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index fdc6b6091058..544754755e6e 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -96,63 +96,90 @@ impl Series { let mut total_dim_size = 1; let mut num_infers = 0; - for (index, &dim) in dimensions.iter().enumerate() { + for &dim in dimensions { match dim { - ReshapeDimension::Infer => { - polars_ensure!( - num_infers == 0, - InvalidOperation: "can only specify one inferred dimension" - ); - num_infers += 1; - }, - ReshapeDimension::Specified(dim) => { - let dim = dim.get(); - - if dim > 0 { - total_dim_size *= dim as usize - } else { - polars_ensure!( - index == 0, - InvalidOperation: "cannot reshape array into shape containing a zero dimension after the first: {}", - format_tuple!(dimensions) - ); - total_dim_size = 0; - // We can early exit here, as empty arrays will error with multiple dimensions, - // and non-empty arrays will error when the first dimension is zero. 
- break; - } - }, + ReshapeDimension::Infer => num_infers += 1, + ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize, } } + polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension"); + if size == 0 { - if dimensions.len() > 1 || (num_infers == 0 && total_dim_size != 0) { - polars_bail!(InvalidOperation: "cannot reshape empty array into shape {}", format_tuple!(dimensions)) - } - } else if total_dim_size == 0 { - polars_bail!(InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", format_tuple!(dimensions)) - } else { polars_ensure!( - size % total_dim_size == 0, - InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + num_infers > 0 || total_dim_size == 0, + InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}", + format_tuple!(dimensions), ); + + let mut prev_arrow_dtype = leaf_array + .dtype() + .to_physical() + .to_arrow(CompatLevel::newest()); + let mut prev_dtype = leaf_array.dtype().clone(); + let mut prev_array = leaf_array.chunks()[0].clone(); + + // @NOTE: We need to collect the iterator here because it is lazily processed. + let mut current_length = dimensions[0].get_or_infer(0); + let len_iter = dimensions[1..] + .iter() + .map(|d| { + let length = current_length as usize; + current_length *= d.get_or_infer(0); + length + }) + .collect::>(); + + // We pop the outer dimension as that is the height of the series. 
+ for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() { + // Infer dimension if needed + let dim = dim.get_or_infer(0); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); + prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); + + prev_array = + FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None) + .boxed(); + } + + return Ok(unsafe { + Series::from_chunks_and_dtype_unchecked( + leaf_array.name().clone(), + vec![prev_array], + &prev_dtype, + ) + }); } + polars_ensure!( + total_dim_size > 0, + InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}", + format_tuple!(dimensions) + ); + + polars_ensure!( + size % total_dim_size == 0, + InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions) + ); + let leaf_array = leaf_array.rechunk(); + let mut prev_arrow_dtype = leaf_array + .dtype() + .to_physical() + .to_arrow(CompatLevel::newest()); let mut prev_dtype = leaf_array.dtype().clone(); let mut prev_array = leaf_array.chunks()[0].clone(); // We pop the outer dimension as that is the height of the series. 
- for idx in (1..dimensions.len()).rev() { + for dim in dimensions[1..].iter().rev() { // Infer dimension if needed - let dim = dimensions[idx].get_or_infer_with(|| { - debug_assert!(num_infers > 0); - (size / total_dim_size) as u64 - }); + let dim = dim.get_or_infer((size / total_dim_size) as u64); + prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true); prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize); prev_array = FixedSizeListArray::new( - prev_dtype.to_arrow(CompatLevel::newest()), + prev_arrow_dtype.clone(), + prev_array.len() / dim as usize, prev_array, None, ) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs index 520f7f8596e1..3bc1beb30973 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs @@ -45,7 +45,7 @@ pub fn create_list( nested: &mut NestedState, values: Box, ) -> Box { - let (mut offsets, validity) = nested.pop().unwrap(); + let (length, mut offsets, validity) = nested.pop().unwrap(); let validity = validity.and_then(freeze_validity); match dtype.to_logical_type() { ArrowDataType::List(_) => { @@ -75,7 +75,7 @@ pub fn create_list( )) }, ArrowDataType::FixedSizeList(_, _) => { - Box::new(FixedSizeListArray::new(dtype, values, validity)) + Box::new(FixedSizeListArray::new(dtype, length, values, validity)) }, _ => unreachable!(), } @@ -87,7 +87,7 @@ pub fn create_map( nested: &mut NestedState, values: Box, ) -> Box { - let (mut offsets, validity) = nested.pop().unwrap(); + let (_, mut offsets, validity) = nested.pop().unwrap(); match dtype.to_logical_type() { ArrowDataType::Map(_, _) => { offsets.push(values.len() as i64); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index 114eeef67341..b5b083f8b882 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ 
b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -403,7 +403,7 @@ pub fn columns_to_iter_recursive( let (mut nested, last_array) = field_to_nested_array(init.clone(), &mut columns, &mut types, last_field)?; debug_assert!(matches!(nested.last().unwrap(), NestedContent::Struct)); - let (_, struct_validity) = nested.pop().unwrap(); + let (_, _, struct_validity) = nested.pop().unwrap(); let mut field_arrays = Vec::>::with_capacity(fields.len()); field_arrays.push(last_array); @@ -416,7 +416,7 @@ pub fn columns_to_iter_recursive( { debug_assert!(matches!(_nested.last().unwrap(), NestedContent::Struct)); debug_assert_eq!( - _nested.pop().unwrap().1.and_then(freeze_validity), + _nested.pop().unwrap().2.and_then(freeze_validity), struct_validity.clone().and_then(freeze_validity), ); } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index ad542cf05753..ab769848ca92 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -87,12 +87,17 @@ impl Nested { } } - fn take(mut self) -> (Vec, Option) { + fn take(mut self) -> (usize, Vec, Option) { if !matches!(self.content, NestedContent::Primitive) { if let Some(validity) = self.validity.as_mut() { validity.extend_constant(self.num_valids, true); validity.extend_constant(self.num_invalids, false); } + + debug_assert!(self + .validity + .as_ref() + .map_or(true, |v| v.len() == self.length)); } self.num_valids = 0; @@ -101,11 +106,11 @@ impl Nested { match self.content { NestedContent::Primitive => { debug_assert!(self.validity.map_or(true, |validity| validity.is_empty())); - (Vec::new(), None) + (self.length, Vec::new(), None) }, - NestedContent::List { offsets } => (offsets, self.validity), - NestedContent::FixedSizeList { .. 
} => (Vec::new(), self.validity), - NestedContent::Struct => (Vec::new(), self.validity), + NestedContent::List { offsets } => (self.length, offsets, self.validity), + NestedContent::FixedSizeList { .. } => (self.length, Vec::new(), self.validity), + NestedContent::Struct => (self.length, Vec::new(), self.validity), } } @@ -254,7 +259,7 @@ impl NestedState { Self { nested } } - pub fn pop(&mut self) -> Option<(Vec, Option)> { + pub fn pop(&mut self) -> Option<(usize, Vec, Option)> { Some(self.nested.pop()?.take()) } diff --git a/crates/polars-parquet/src/arrow/read/statistics/list.rs b/crates/polars-parquet/src/arrow/read/statistics/list.rs index 54f308c94f4d..baea27289124 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/list.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/list.rs @@ -64,6 +64,7 @@ impl MutableArray for DynMutableListArray { }, ArrowDataType::FixedSizeList(field, _) => Box::new(FixedSizeListArray::new( ArrowDataType::FixedSizeList(field.clone(), inner.len()), + 1, inner, None, )), diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 9159c2b6b3ee..9225dfa4cae0 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -392,7 +392,7 @@ pub(super) fn concat(s: &mut [Column]) -> PolarsResult> { let mut first = std::mem::take(&mut s[0]); let other = &s[1..]; - let mut first_ca = match first.list().ok() { + let mut first_ca = match first.try_list() { Some(ca) => ca, None => { first = first diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs index 5e3e4174f667..db0c4ccfd802 100644 --- a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs @@ -12,6 +12,7 @@ fn data() -> FixedSizeListArray { Box::new(Field::new("a".into(), values.dtype().clone(), true)), 2, ), + 2, 
values.boxed(), Some([true, false].into()), ) @@ -87,6 +88,7 @@ fn wrong_size() { Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), 2 ), + 2, values.boxed(), None ) @@ -95,12 +97,13 @@ fn wrong_size() { #[test] fn wrong_len() { - let values = Int32Array::from_slice([10, 20, 0]); + let values = Int32Array::from_slice([10, 20, 0, 0]); assert!(FixedSizeListArray::try_new( ArrowDataType::FixedSizeList( Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), 2 ), + 2, values.boxed(), Some([true, false, false].into()), // it should be 2 ) @@ -109,11 +112,12 @@ fn wrong_len() { #[test] fn wrong_dtype() { - let values = Int32Array::from_slice([10, 20, 0]); + let values = Int32Array::from_slice([10, 20, 0, 0]); assert!(FixedSizeListArray::try_new( ArrowDataType::Binary, + 2, values.boxed(), - Some([true, false, false].into()), // it should be 2 + Some([true, false, false, false].into()), ) .is_err()); } diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs index 45e19d194a46..075e5179e1ca 100644 --- a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -27,6 +27,6 @@ fn fixed_size_list() { 3, ); let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); - let a = FixedSizeListArray::new(dtype, values, None); + let a = FixedSizeListArray::new(dtype, 2, values, None); assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); } diff --git a/py-polars/tests/unit/constructors/test_series.py b/py-polars/tests/unit/constructors/test_series.py index c31a5b48ce68..9c6346bf5395 100644 --- a/py-polars/tests/unit/constructors/test_series.py +++ b/py-polars/tests/unit/constructors/test_series.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re from datetime import date, datetime, timedelta from typing import TYPE_CHECKING, Any @@ -9,7 +8,6 @@ import pytest import polars as pl -from 
polars.exceptions import InvalidOperationError from polars.testing.asserts.series import assert_series_equal if TYPE_CHECKING: @@ -157,11 +155,10 @@ def test_series_init_pandas_timestamp_18127() -> None: def test_series_init_np_2d_zero_zero_shape() -> None: arr = np.array([]).reshape(0, 0) - with pytest.raises( - InvalidOperationError, - match=re.escape("cannot reshape empty array into shape (0, 0)"), - ): - pl.Series(arr) + assert_series_equal( + pl.Series("a", arr), + pl.Series("a", [], pl.Array(pl.Float64, 0)), + ) def test_list_null_constructor_schema() -> None: diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index 6c4f240803bf..b578266b0c6f 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -342,3 +342,44 @@ def test_array_invalid_physical_type_18920() -> None: expected = expected_s.to_frame().with_columns(pl.col.x.list.to_array(2)) assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + "fn", + [ + "__add__", + "__sub__", + "__mul__", + "__truediv__", + "__mod__", + "__eq__", + "__ne__", + ], +) +def test_zero_width_array(fn: str) -> None: + series_f = getattr(pl.Series, fn) + expr_f = getattr(pl.Expr, fn) + + values = [ + [ + [[]], + [None], + ], + [ + [[], []], + [None, []], + [[], None], + [None, None], + ], + ] + + for vs in values: + for lhs in vs: + for rhs in vs: + a = pl.Series("a", lhs, pl.Array(pl.Int8, 0)) + b = pl.Series("b", rhs, pl.Array(pl.Int8, 0)) + + series_f(a, b) + + df = pl.concat([a.to_frame(), b.to_frame()], how="horizontal") + df.select(c=expr_f(pl.col.a, pl.col.b)) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index b53766ae2c2c..39e515bf31d8 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -876,9 +876,8 @@ def test_parquet_array_dtype_nulls() -> None: ([[1, 2, 3]], pl.Array(pl.Int64, 3)), ([[1, None, 3], 
None, [1, 2, None]], pl.Array(pl.Int64, 3)), ([[1, 2], None, [None, 3]], pl.Array(pl.Int64, 2)), - # @TODO: Enable when zero-width arrays are enabled - # ([[], [], []], pl.Array(pl.Int64, 0)), - # ([[], None, []], pl.Array(pl.Int64, 0)), + ([[], [], []], pl.Array(pl.Int64, 0)), + ([[], None, []], pl.Array(pl.Int64, 0)), ( [[[1, 5, 2], [42, 13, 37]], [[1, 2, 3], [5, 2, 3]], [[1, 2, 1], [3, 1, 3]]], pl.Array(pl.Array(pl.Int8, 3), 2), @@ -924,7 +923,7 @@ def test_parquet_array_dtype_nulls() -> None: [[]], [[None]], [[[None], None]], - [[[None], []]], + [[[None], [None]]], [[[[None]], [[[1]]]]], [[[[[None]]]]], [[[[[1]]]]], @@ -940,12 +939,6 @@ def test_complex_types(series: list[Any], dtype: pl.DataType) -> None: test_round_trip(df) -@pytest.mark.xfail -def test_placeholder_zero_array() -> None: - # @TODO: if this does not fail anymore please enable the upper test-cases - pl.Series([[]], dtype=pl.Array(pl.Int8, 0)) - - @pytest.mark.write_disk def test_parquet_array_statistics(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) diff --git a/py-polars/tests/unit/operations/test_reshape.py b/py-polars/tests/unit/operations/test_reshape.py index 12ddfed628c5..d6d000769637 100644 --- a/py-polars/tests/unit/operations/test_reshape.py +++ b/py-polars/tests/unit/operations/test_reshape.py @@ -66,7 +66,7 @@ def test_reshape_invalid_zero_dimension() -> None: with pytest.raises( InvalidOperationError, match=re.escape( - f"cannot reshape array into shape containing a zero dimension after the first: {display_shape(shape)}" + f"cannot reshape non-empty array into shape containing a zero dimension: {display_shape(shape)}" ), ): s.reshape(shape) @@ -100,24 +100,14 @@ def test_reshape_empty_valid_1d(shape: tuple[int, ...]) -> None: assert_series_equal(out, s) -@pytest.mark.parametrize("shape", [(0, 1), (1, -1), (-1, 1)]) -def test_reshape_empty_invalid_2d(shape: tuple[int, ...]) -> None: - s = pl.Series("a", [], dtype=pl.Int64) - with pytest.raises( - InvalidOperationError, - 
match=re.escape( - f"cannot reshape empty array into shape {display_shape(shape)}" - ), - ): - s.reshape(shape) - - @pytest.mark.parametrize("shape", [(1,), (2,)]) def test_reshape_empty_invalid_1d(shape: tuple[int, ...]) -> None: s = pl.Series("a", [], dtype=pl.Int64) with pytest.raises( InvalidOperationError, - match=re.escape(f"cannot reshape empty array into shape ({shape[0]})"), + match=re.escape( + f"cannot reshape empty array into shape without zero dimension: ({shape[0]})" + ), ): s.reshape(shape) @@ -131,3 +121,23 @@ def test_array_ndarray_reshape() -> None: n = n[0] s = s[0] assert (n[0] == s[0].to_numpy()).all() + + +@pytest.mark.parametrize( + "shape", + [ + (0, 1), + (1, 0), + (-1, 10, 20, 10), + (-1, 1, 0), + (10, 1, 0), + (10, 0, 1, 0), + (10, 0, 1), + (42, 2, 3, 4, 0, 2, 3, 4), + (42, 1, 1, 1, 0), + ], +) +def test_reshape_empty(shape: tuple[int, ...]) -> None: + s = pl.Series("a", [], dtype=pl.Int64) + expected_len = max(shape[0], 0) + assert s.reshape(shape).len() == expected_len From d097d3cb12acd80d08a4887267dc348662968579 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 09:13:33 +0200 Subject: [PATCH 11/33] fix: Properly fetch type of full None List Series (#18916) --- .../src/frame/column/arithmetic.rs | 14 +- crates/polars-core/src/frame/column/mod.rs | 307 ++---------------- crates/polars-core/src/frame/column/scalar.rs | 287 ++++++++++++++++ crates/polars-core/src/frame/mod.rs | 2 +- crates/polars-core/src/series/any_value.rs | 150 +++++---- crates/polars-io/src/utils/other.rs | 6 +- crates/polars-lazy/src/tests/aggregations.rs | 2 +- crates/polars-lazy/src/tests/queries.rs | 4 +- crates/polars-ops/src/series/ops/duration.rs | 21 +- .../src/dsl/function_expr/round.rs | 26 +- py-polars/tests/unit/test_scalar.py | 23 ++ 11 files changed, 465 insertions(+), 377 deletions(-) create mode 100644 crates/polars-core/src/frame/column/scalar.rs diff --git a/crates/polars-core/src/frame/column/arithmetic.rs 
b/crates/polars-core/src/frame/column/arithmetic.rs index 79fd0053b320..05cc091c0a15 100644 --- a/crates/polars-core/src/frame/column/arithmetic.rs +++ b/crates/polars-core/src/frame/column/arithmetic.rs @@ -27,7 +27,7 @@ fn unit_series_op PolarsResult>( debug_assert!(r.len() <= 1); op(l, r) - .and_then(|s| ScalarColumn::from_single_value_series(s, length)) + .map(|s| ScalarColumn::from_single_value_series(s, length)) .map(Column::from) } @@ -70,12 +70,12 @@ fn num_op_with_broadcast Series>( c: &'_ Column, n: T, op: F, -) -> PolarsResult { +) -> Column { match c { - Column::Series(s) => Ok(op(s, n).into()), + Column::Series(s) => op(s, n).into(), Column::Scalar(s) => { - ScalarColumn::from_single_value_series(op(&s.as_single_value_series(), n), s.length) - .map(Column::from) + ScalarColumn::from_single_value_series(op(&s.as_single_value_series(), n), s.len()) + .into() }, } } @@ -111,7 +111,7 @@ macro_rules! broadcastable_num_ops { where T: Num + NumCast, { - type Output = PolarsResult; + type Output = Self; #[inline] fn $op(self, rhs: T) -> Self::Output { @@ -123,7 +123,7 @@ macro_rules! 
broadcastable_num_ops { where T: Num + NumCast, { - type Output = PolarsResult; + type Output = Column; #[inline] fn $op(self, rhs: T) -> Self::Output { diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 1dea44ee393a..78c36db57f78 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -1,10 +1,10 @@ use std::borrow::Cow; -use std::sync::OnceLock; use num_traits::{Num, NumCast}; use polars_error::PolarsResult; use polars_utils::index::check_bounds; use polars_utils::pl_str::PlSmallStr; +pub use scalar::ScalarColumn; use self::gather::check_bounds_ca; use crate::chunked_array::cast::CastOptions; @@ -16,6 +16,7 @@ use crate::utils::{slice_offsets, Container}; use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; mod arithmetic; +mod scalar; /// A column within a [`DataFrame`]. /// @@ -35,25 +36,6 @@ pub enum Column { Scalar(ScalarColumn), } -/// A [`Column`] that consists of a repeated [`Scalar`] -/// -/// This is lazily materialized into a [`Series`]. -#[derive(Debug, Clone)] -pub struct ScalarColumn { - name: PlSmallStr, - // The value of this scalar may be incoherent when `length == 0`. 
- scalar: Scalar, - length: usize, - - // invariants: - // materialized.name() == name - // materialized.len() == length - // materialized.dtype() == value.dtype - // materialized[i] == value, for all 0 <= i < length - /// A lazily materialized [`Series`] variant of this [`ScalarColumn`] - materialized: OnceLock, -} - /// Convert `Self` into a [`Column`] pub trait IntoColumn: Sized { fn into_column(self) -> Column; @@ -98,7 +80,11 @@ impl Column { match self { Column::Series(s) => s, Column::Scalar(s) => { - let series = s.materialized.take().unwrap_or_else(|| s.to_series()); + let series = std::mem::replace( + s, + ScalarColumn::new_empty(PlSmallStr::EMPTY, DataType::Null), + ) + .take_materialized_series(); *self = Column::Series(series); let Column::Series(s) = self else { unreachable!(); @@ -122,7 +108,7 @@ impl Column { pub fn dtype(&self) -> &DataType { match self { Column::Series(s) => s.dtype(), - Column::Scalar(s) => s.scalar.dtype(), + Column::Scalar(s) => s.dtype(), } } @@ -130,8 +116,8 @@ impl Column { pub fn field(&self) -> Cow { match self { Column::Series(s) => s.field(), - Column::Scalar(s) => match s.materialized.get() { - None => Cow::Owned(Field::new(s.name.clone(), s.scalar.dtype().clone())), + Column::Scalar(s) => match s.lazy_as_materialized_series() { + None => Cow::Owned(Field::new(s.name().clone(), s.dtype().clone())), Some(s) => s.field(), }, } @@ -141,7 +127,7 @@ impl Column { pub fn name(&self) -> &PlSmallStr { match self { Column::Series(s) => s.name(), - Column::Scalar(s) => &s.name, + Column::Scalar(s) => s.name(), } } @@ -149,7 +135,7 @@ impl Column { pub fn len(&self) -> usize { match self { Column::Series(s) => s.len(), - Column::Scalar(s) => s.length, + Column::Scalar(s) => s.len(), } } @@ -163,13 +149,7 @@ impl Column { pub fn rename(&mut self, name: PlSmallStr) { match self { Column::Series(s) => _ = s.rename(name), - Column::Scalar(s) => { - if let Some(series) = s.materialized.get_mut() { - series.rename(name.clone()); - } - - 
s.name = name; - }, + Column::Scalar(s) => _ = s.rename(name), } } @@ -377,7 +357,7 @@ impl Column { pub fn clear(&self) -> Self { match self { Column::Series(s) => s.clear().into(), - Column::Scalar(s) => Self::new_scalar(s.name.clone(), s.scalar.clone(), 0), + Column::Scalar(s) => s.resize(0).into(), } } @@ -394,8 +374,8 @@ impl Column { match self { Column::Series(s) => s.new_from_index(index, length).into(), Column::Scalar(s) => { - if index >= s.length { - Self::full_null(s.name.clone(), length, s.scalar.dtype()) + if index >= s.len() { + Self::full_null(s.name().clone(), length, s.dtype()) } else { s.resize(length).into() } @@ -415,14 +395,18 @@ impl Column { pub fn is_null(&self) -> BooleanChunked { match self { Self::Series(s) => s.is_null(), - Self::Scalar(s) => BooleanChunked::full(s.name.clone(), s.scalar.is_null(), s.length), + Self::Scalar(s) => { + BooleanChunked::full(s.name().clone(), s.scalar().is_null(), s.len()) + }, } } #[inline] pub fn is_not_null(&self) -> BooleanChunked { match self { Self::Series(s) => s.is_not_null(), - Self::Scalar(s) => BooleanChunked::full(s.name.clone(), !s.scalar.is_null(), s.length), + Self::Scalar(s) => { + BooleanChunked::full(s.name().clone(), !s.scalar().is_null(), s.len()) + }, } } @@ -449,7 +433,7 @@ impl Column { match self { Column::Series(s) => s.slice(offset, length).into(), Column::Scalar(s) => { - let (_, length) = slice_offsets(offset, length, s.length); + let (_, length) = slice_offsets(offset, length, s.len()); s.resize(length).into() }, } @@ -465,7 +449,7 @@ impl Column { pub fn null_count(&self) -> usize { match self { Self::Series(s) => s.null_count(), - Self::Scalar(s) if s.scalar.is_null() => s.length, + Self::Scalar(s) if s.scalar().is_null() => s.len(), Self::Scalar(_) => 0, } } @@ -875,7 +859,7 @@ impl Column { match self { Column::Series(s) => s.gather_every(n, offset).into(), - Column::Scalar(s) => s.resize(s.length - offset / n).into(), + Column::Scalar(s) => s.resize(s.len() - offset / 
n).into(), } } @@ -891,7 +875,7 @@ impl Column { match self { Column::Series(s) => s.extend_constant(value, n).map(Column::from), Column::Scalar(s) => { - if s.scalar.as_any_value() == value { + if s.scalar().as_any_value() == value { Ok(s.resize(s.len() + n).into()) } else { s.as_materialized_series() @@ -958,7 +942,7 @@ impl Column { match self { Column::Series(s) => s.get_unchecked(index), - Column::Scalar(s) => s.scalar.as_any_value(), + Column::Scalar(s) => s.scalar().as_any_value(), } } @@ -1089,240 +1073,16 @@ impl PartialEq for Column { impl From for Column { #[inline] fn from(series: Series) -> Self { + // We instantiate a Scalar Column if the Series is length is 1. This makes it possible for + // future operations to be faster. if series.len() == 1 { - // SAFETY: We just did the bounds check - let value = unsafe { series.get_unchecked(0) }.into_static(); - let value = Scalar::new(series.dtype().clone(), value); - let mut col = ScalarColumn::new(series.name().clone(), value, 1); - col.materialized = OnceLock::from(series); - return Self::Scalar(col); + return Self::Scalar(ScalarColumn::unit_scalar_from_series(series)); } Self::Series(series) } } -impl From for Column { - #[inline] - fn from(value: ScalarColumn) -> Self { - Self::Scalar(value) - } -} - -impl ScalarColumn { - #[inline] - pub fn new(name: PlSmallStr, scalar: Scalar, length: usize) -> Self { - Self { - name, - scalar, - length, - - materialized: OnceLock::new(), - } - } - - #[inline] - pub fn new_empty(name: PlSmallStr, dtype: DataType) -> Self { - Self { - name, - scalar: Scalar::new(dtype, AnyValue::Null), - length: 0, - - materialized: OnceLock::new(), - } - } - - pub fn name(&self) -> &PlSmallStr { - &self.name - } - - pub fn dtype(&self) -> &DataType { - self.scalar.dtype() - } - - pub fn len(&self) -> usize { - self.length - } - - pub fn is_empty(&self) -> bool { - self.length == 0 - } - - fn _to_series(name: PlSmallStr, value: Scalar, length: usize) -> Series { - let series = if length 
== 0 { - Series::new_empty(name, value.dtype()) - } else { - value.into_series(name).new_from_index(0, length) - }; - - debug_assert_eq!(series.len(), length); - - series - } - - /// Materialize the [`ScalarColumn`] into a [`Series`]. - pub fn to_series(&self) -> Series { - Self::_to_series(self.name.clone(), self.scalar.clone(), self.length) - } - - /// Get the [`ScalarColumn`] as [`Series`] - /// - /// This needs to materialize upon the first call. Afterwards, this is cached. - pub fn as_materialized_series(&self) -> &Series { - self.materialized.get_or_init(|| self.to_series()) - } - - /// Take the [`ScalarColumn`] and materialize as a [`Series`] if not already done. - pub fn take_materialized_series(self) -> Series { - self.materialized - .into_inner() - .unwrap_or_else(|| Self::_to_series(self.name, self.scalar, self.length)) - } - - /// Take the [`ScalarColumn`] as a series with a single value. - /// - /// If the [`ScalarColumn`] has `length=0` the resulting `Series` will also have `length=0`. - pub fn as_single_value_series(&self) -> Series { - match self.materialized.get() { - Some(s) => s.head(Some(1)), - None => Self::_to_series( - self.name.clone(), - self.scalar.clone(), - usize::min(1, self.length), - ), - } - } - - /// Create a new [`ScalarColumn`] from a `length=1` Series and expand it `length`. - /// - /// This will panic if the value cannot be made static or if the series has length `0`. - pub fn from_single_value_series(series: Series, length: usize) -> PolarsResult { - debug_assert_eq!(series.len(), 1); - let value = series.get(0)?; - let value = value.into_static(); - let value = Scalar::new(series.dtype().clone(), value); - Ok(ScalarColumn::new(series.name().clone(), value, length)) - } - - /// Resize the [`ScalarColumn`] to new `length`. - /// - /// This reuses the materialized [`Series`], if `length <= self.length`. 
- pub fn resize(&self, length: usize) -> ScalarColumn { - if self.length == length { - return self.clone(); - } - - // This is violates an invariant if this triggers, the scalar value is undefined if the - // self.length == 0 so therefore we should never resize using that value. - debug_assert_ne!(self.length, 0); - - let mut resized = Self { - name: self.name.clone(), - scalar: self.scalar.clone(), - length, - materialized: OnceLock::new(), - }; - - if self.length >= length { - if let Some(materialized) = self.materialized.get() { - resized.materialized = OnceLock::from(materialized.head(Some(length))); - debug_assert_eq!(resized.materialized.get().unwrap().len(), length); - } - } - - resized - } - - pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { - // @NOTE: We expect that when casting the materialized series mostly does not need change - // the physical array. Therefore, we try to cast the entire materialized array if it is - // available. - - match self.materialized.get() { - Some(s) => { - let materialized = s.cast_with_options(dtype, options)?; - assert_eq!(self.length, materialized.len()); - - let mut casted = if materialized.len() == 0 { - Self::new_empty(materialized.name().clone(), materialized.dtype().clone()) - } else { - // SAFETY: Just did bounds check - let scalar = unsafe { materialized.get_unchecked(0) }.into_static(); - Self::new( - materialized.name().clone(), - Scalar::new(materialized.dtype().clone(), scalar), - self.length, - ) - }; - casted.materialized = OnceLock::from(materialized); - Ok(casted) - }, - None => { - let s = self - .as_single_value_series() - .cast_with_options(dtype, options)?; - assert_eq!(1, s.len()); - - if self.length == 0 { - Ok(Self::new_empty(s.name().clone(), s.dtype().clone())) - } else { - Self::from_single_value_series(s, self.length) - } - }, - } - } - - pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { - self.cast_with_options(dtype, CastOptions::Strict) - } - 
pub fn cast(&self, dtype: &DataType) -> PolarsResult { - self.cast_with_options(dtype, CastOptions::NonStrict) - } - /// # Safety - /// - /// This can lead to invalid memory access in downstream code. - pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { - // @NOTE: We expect that when casting the materialized series mostly does not need change - // the physical array. Therefore, we try to cast the entire materialized array if it is - // available. - - match self.materialized.get() { - Some(s) => { - let materialized = s.cast_unchecked(dtype)?; - assert_eq!(self.length, materialized.len()); - - let mut casted = if materialized.len() == 0 { - Self::new_empty(materialized.name().clone(), materialized.dtype().clone()) - } else { - // SAFETY: Just did bounds check - let scalar = unsafe { materialized.get_unchecked(0) }.into_static(); - Self::new( - materialized.name().clone(), - Scalar::new(materialized.dtype().clone(), scalar), - self.length, - ) - }; - casted.materialized = OnceLock::from(materialized); - Ok(casted) - }, - None => { - let s = self.as_single_value_series().cast_unchecked(dtype)?; - assert_eq!(1, s.len()); - - if self.length == 0 { - Ok(Self::new_empty(s.name().clone(), s.dtype().clone())) - } else { - Self::from_single_value_series(s, self.length) - } - }, - } - } - - pub fn has_nulls(&self) -> bool { - self.length != 0 && self.scalar.is_null() - } -} - impl IntoColumn for T { #[inline] fn into_column(self) -> Column { @@ -1337,13 +1097,6 @@ impl IntoColumn for Column { } } -impl IntoColumn for ScalarColumn { - #[inline(always)] - fn into_column(self) -> Column { - self.into() - } -} - /// We don't want to serialize the scalar columns. So this helps pretend that columns are always /// initialized without implementing From for Series. 
/// diff --git a/crates/polars-core/src/frame/column/scalar.rs b/crates/polars-core/src/frame/column/scalar.rs new file mode 100644 index 000000000000..18e53c469960 --- /dev/null +++ b/crates/polars-core/src/frame/column/scalar.rs @@ -0,0 +1,287 @@ +use std::sync::OnceLock; + +use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; + +use super::{AnyValue, Column, DataType, IntoColumn, Scalar, Series}; +use crate::chunked_array::cast::CastOptions; + +/// A [`Column`] that consists of a repeated [`Scalar`] +/// +/// This is lazily materialized into a [`Series`]. +#[derive(Debug, Clone)] +pub struct ScalarColumn { + name: PlSmallStr, + // The value of this scalar may be incoherent when `length == 0`. + scalar: Scalar, + length: usize, + + // invariants: + // materialized.name() == name + // materialized.len() == length + // materialized.dtype() == value.dtype + // materialized[i] == value, for all 0 <= i < length + /// A lazily materialized [`Series`] variant of this [`ScalarColumn`] + materialized: OnceLock, +} + +impl ScalarColumn { + #[inline] + pub fn new(name: PlSmallStr, scalar: Scalar, length: usize) -> Self { + Self { + name, + scalar, + length, + + materialized: OnceLock::new(), + } + } + + #[inline] + pub fn new_empty(name: PlSmallStr, dtype: DataType) -> Self { + Self { + name, + scalar: Scalar::new(dtype, AnyValue::Null), + length: 0, + + materialized: OnceLock::new(), + } + } + + pub fn name(&self) -> &PlSmallStr { + &self.name + } + + pub fn scalar(&self) -> &Scalar { + &self.scalar + } + + pub fn dtype(&self) -> &DataType { + self.scalar.dtype() + } + + pub fn len(&self) -> usize { + self.length + } + + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + fn _to_series(name: PlSmallStr, value: Scalar, length: usize) -> Series { + let series = if length == 0 { + Series::new_empty(name, value.dtype()) + } else { + value.into_series(name).new_from_index(0, length) + }; + + debug_assert_eq!(series.len(), length); + + series + } + + 
/// Materialize the [`ScalarColumn`] into a [`Series`]. + pub fn to_series(&self) -> Series { + Self::_to_series(self.name.clone(), self.scalar.clone(), self.length) + } + + /// Get the [`ScalarColumn`] as [`Series`] if it was already materialized. + pub fn lazy_as_materialized_series(&self) -> Option<&Series> { + self.materialized.get() + } + + /// Get the [`ScalarColumn`] as [`Series`] + /// + /// This needs to materialize upon the first call. Afterwards, this is cached. + pub fn as_materialized_series(&self) -> &Series { + self.materialized.get_or_init(|| self.to_series()) + } + + /// Take the [`ScalarColumn`] and materialize as a [`Series`] if not already done. + pub fn take_materialized_series(self) -> Series { + self.materialized + .into_inner() + .unwrap_or_else(|| Self::_to_series(self.name, self.scalar, self.length)) + } + + /// Take the [`ScalarColumn`] as a series with a single value. + /// + /// If the [`ScalarColumn`] has `length=0` the resulting `Series` will also have `length=0`. + pub fn as_single_value_series(&self) -> Series { + match self.materialized.get() { + Some(s) => s.head(Some(1)), + None => Self::_to_series( + self.name.clone(), + self.scalar.clone(), + usize::min(1, self.length), + ), + } + } + + /// Create a new [`ScalarColumn`] from a `length=1` Series and expand it `length`. + /// + /// This will panic if the value cannot be made static or if the series has length `0`. + #[inline] + pub fn unit_scalar_from_series(series: Series) -> Self { + assert_eq!(series.len(), 1); + // SAFETY: We just did the bounds check + let value = unsafe { series.get_unchecked(0) }; + let value = value.into_static(); + let value = Scalar::new(series.dtype().clone(), value); + let mut sc = ScalarColumn::new(series.name().clone(), value, 1); + sc.materialized = OnceLock::from(series); + sc + } + + /// Create a new [`ScalarColumn`] from a `length=1` Series and expand it `length`. 
+ /// + /// This will panic if the value cannot be made static or if the series has length `0`. + pub fn from_single_value_series(series: Series, length: usize) -> Self { + debug_assert_eq!(series.len(), 1); + let value = series.get(0).unwrap(); + let value = value.into_static(); + let value = Scalar::new(series.dtype().clone(), value); + ScalarColumn::new(series.name().clone(), value, length) + } + + /// Resize the [`ScalarColumn`] to new `length`. + /// + /// This reuses the materialized [`Series`], if `length <= self.length`. + pub fn resize(&self, length: usize) -> ScalarColumn { + if self.length == length { + return self.clone(); + } + + // This is violates an invariant if this triggers, the scalar value is undefined if the + // self.length == 0 so therefore we should never resize using that value. + debug_assert!(length == 0 || self.length > 0); + + let mut resized = Self { + name: self.name.clone(), + scalar: self.scalar.clone(), + length, + materialized: OnceLock::new(), + }; + + if self.length >= length { + if let Some(materialized) = self.materialized.get() { + resized.materialized = OnceLock::from(materialized.head(Some(length))); + debug_assert_eq!(resized.materialized.get().unwrap().len(), length); + } + } + + resized + } + + pub fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + // @NOTE: We expect that when casting the materialized series mostly does not need change + // the physical array. Therefore, we try to cast the entire materialized array if it is + // available. 
+ + match self.materialized.get() { + Some(s) => { + let materialized = s.cast_with_options(dtype, options)?; + assert_eq!(self.length, materialized.len()); + + let mut casted = if materialized.len() == 0 { + Self::new_empty(materialized.name().clone(), materialized.dtype().clone()) + } else { + // SAFETY: Just did bounds check + let scalar = unsafe { materialized.get_unchecked(0) }.into_static(); + Self::new( + materialized.name().clone(), + Scalar::new(materialized.dtype().clone(), scalar), + self.length, + ) + }; + casted.materialized = OnceLock::from(materialized); + Ok(casted) + }, + None => { + let s = self + .as_single_value_series() + .cast_with_options(dtype, options)?; + + if self.length == 0 { + Ok(Self::new_empty(s.name().clone(), s.dtype().clone())) + } else { + assert_eq!(1, s.len()); + Ok(Self::from_single_value_series(s, self.length)) + } + }, + } + } + + pub fn strict_cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Strict) + } + pub fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) + } + /// # Safety + /// + /// This can lead to invalid memory access in downstream code. + pub unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + // @NOTE: We expect that when casting the materialized series mostly does not need change + // the physical array. Therefore, we try to cast the entire materialized array if it is + // available. 
+ + match self.materialized.get() { + Some(s) => { + let materialized = s.cast_unchecked(dtype)?; + assert_eq!(self.length, materialized.len()); + + let mut casted = if materialized.len() == 0 { + Self::new_empty(materialized.name().clone(), materialized.dtype().clone()) + } else { + // SAFETY: Just did bounds check + let scalar = unsafe { materialized.get_unchecked(0) }.into_static(); + Self::new( + materialized.name().clone(), + Scalar::new(materialized.dtype().clone(), scalar), + self.length, + ) + }; + casted.materialized = OnceLock::from(materialized); + Ok(casted) + }, + None => { + let s = self.as_single_value_series().cast_unchecked(dtype)?; + assert_eq!(1, s.len()); + + if self.length == 0 { + Ok(Self::new_empty(s.name().clone(), s.dtype().clone())) + } else { + Ok(Self::from_single_value_series(s, self.length)) + } + }, + } + } + + pub fn rename(&mut self, name: PlSmallStr) -> &mut Self { + if let Some(series) = self.materialized.get_mut() { + series.rename(name.clone()); + } + + self.name = name; + self + } + + pub fn has_nulls(&self) -> bool { + self.length != 0 && self.scalar.is_null() + } +} + +impl IntoColumn for ScalarColumn { + #[inline(always)] + fn into_column(self) -> Column { + self.into() + } +} + +impl From for Column { + #[inline] + fn from(value: ScalarColumn) -> Self { + Self::Scalar(value) + } +} diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 0867f7a3686d..a741ad846351 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -2166,7 +2166,7 @@ impl DataFrame { /// let mut df = DataFrame::new(vec![s0, s1])?; /// /// // Add 32 to get lowercase ascii values - /// df.apply_at_idx(1, |s| (s + 32).unwrap()); + /// df.apply_at_idx(1, |s| s + 32); /// # Ok::<(), PolarsError>(()) /// ``` /// Results in: diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 1f6562b5397a..5cd438db3441 100644 --- 
a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -2,6 +2,7 @@ use std::fmt::Write; use arrow::bitmap::MutableBitmap; +use crate::chunked_array::builder::{get_list_builder, AnonymousOwnedListBuilder}; #[cfg(feature = "object")] use crate::chunked_array::object::registry::ObjectRegistry; use crate::prelude::*; @@ -603,76 +604,103 @@ fn any_values_to_list( inner_type: &DataType, strict: bool, ) -> PolarsResult { - let it = match inner_type { - // Structs don't support empty fields yet. - // We must ensure the data-types match what we do physical - #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) if fields.is_empty() => { - DataType::Struct(vec![Field::new(PlSmallStr::EMPTY, DataType::Null)]) - }, - _ => inner_type.clone(), - }; - let target_dtype = DataType::List(Box::new(it)); + // GB: + // Lord forgive for the sins I have committed in this function. The amount of strange + // exceptions that need to happen for this to work are insane and I feel like I am going crazy. + // + // This function is essentially a copy of the `` where it does not + // sample the datatype from the first element and instead we give it explicitly. This allows + // this function to properly assign a datatype if `avs` starts with a `null` value. Previously, + // this was solved by assigning the `dtype` again afterwards, but why? We should not link the + // implementation of these functions. We still need to assign the dtype of the ListArray and + // such, anyways. + // + // Then, `collect_ca_with_dtype` does not possess the necessary exceptions shown in this + // function to use that. I have tried adding the exceptions there and it broke other things. I + // really do feel like this is the simplest solution. - // This is handled downstream. The builder will choose the first non-null type. 
let mut valid = true; - #[allow(unused_mut)] - let mut out: ListChunked = if inner_type == &DataType::Null { - avs.iter() - .map(|av| match av { - AnyValue::List(b) => Some(b.clone()), - AnyValue::Null => None, - _ => { - valid = false; - None - }, - }) - .collect_trusted() - } - // Make sure that wrongly inferred AnyValues don't deviate from the datatype. - else { - avs.iter() - .map(|av| match av { - AnyValue::List(b) => { - if b.dtype() == inner_type { - Some(b.clone()) - } else { - match b.cast(inner_type) { - Ok(out) => { - if out.null_count() != b.null_count() { - valid = !strict; - } - Some(out) - }, - Err(_) => { - valid = !strict; - Some(Series::full_null(b.name().clone(), b.len(), inner_type)) - }, - } - } + let capacity = avs.len(); + + let ca = match inner_type { + // AnyValues with empty lists in python can create + // Series of an unknown dtype. + // We use the anonymousbuilder without a dtype + // the empty arrays is then not added (we add an extra offset instead) + // the next non-empty series then must have the correct dtype. + DataType::Null => { + let mut builder = AnonymousOwnedListBuilder::new(PlSmallStr::EMPTY, capacity, None); + for av in avs { + match av { + AnyValue::List(b) => builder.append_series(b)?, + AnyValue::Null => builder.append_null(), + _ => { + valid = false; + builder.append_null(); + }, + } + } + builder.finish() + }, + + #[cfg(feature = "object")] + DataType::Object(_, _) => polars_bail!(nyi = "Nested object types"), + + _ => { + let list_inner_type = match inner_type { + // Categoricals may not have a revmap yet. We just give them an empty one here and + // the list builder takes care of the rest. + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(None, ordering) => { + DataType::Categorical(Some(Arc::new(RevMapping::default())), *ordering) }, - AnyValue::Null => None, - _ => { - valid = false; - None + + // Structs don't support empty fields yet. 
+ // We must ensure the data-types match what we do physical + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) if fields.is_empty() => { + DataType::Struct(vec![Field::new(PlSmallStr::EMPTY, DataType::Null)]) }, - }) - .collect_trusted() + + _ => inner_type.clone(), + }; + + let mut builder = + get_list_builder(&list_inner_type, capacity * 5, capacity, PlSmallStr::EMPTY)?; + + for av in avs { + match av { + AnyValue::List(b) => match b.cast(inner_type) { + Ok(casted) => { + if casted.null_count() != b.null_count() { + valid = !strict; + } + builder.append_series(&casted)?; + }, + Err(_) => { + valid = false; + for _ in 0..b.len() { + builder.append_null(); + } + }, + }, + AnyValue::Null => builder.append_null(), + _ => { + valid = false; + builder.append_null() + }, + } + } + + builder.finish() + }, }; if strict && !valid { - polars_bail!(SchemaMismatch: "unexpected value while building Series of type {:?}", target_dtype); + polars_bail!(SchemaMismatch: "unexpected value while building Series of type {:?}", DataType::List(Box::new(inner_type.clone()))); } - // Ensure the logical type is correct for nested types. - #[cfg(feature = "dtype-struct")] - if !matches!(inner_type, DataType::Null) && out.inner_dtype().is_nested() { - unsafe { - out.set_dtype(target_dtype.clone()); - }; - } - - Ok(out) + Ok(ca) } #[cfg(feature = "dtype-array")] diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs index 45300d80d319..12e3ee2f9d01 100644 --- a/crates/polars-io/src/utils/other.rs +++ b/crates/polars-io/src/utils/other.rs @@ -87,7 +87,7 @@ pub(crate) fn update_row_counts(dfs: &mut [(DataFrame, IdxSize)], offset: IdxSiz let mut previous = dfs[0].1 + offset; for (df, n_read) in &mut dfs[1..] 
{ if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = (&*s + previous).unwrap(); + *s = &*s + previous; } previous += *n_read; } @@ -103,7 +103,7 @@ pub(crate) fn update_row_counts2(dfs: &mut [DataFrame], offset: IdxSize) { for df in &mut dfs[1..] { let n_read = df.height() as IdxSize; if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = (&*s + previous).unwrap(); + *s = &*s + previous; } previous += n_read; } @@ -122,7 +122,7 @@ pub(crate) fn update_row_counts3(dfs: &mut [DataFrame], heights: &[IdxSize], off let n_read = heights[i]; if let Some(s) = unsafe { df.get_columns_mut() }.get_mut(0) { - *s = (&*s + previous).unwrap(); + *s = &*s + previous; } previous += n_read; diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 7bb21eb5bcd6..6b2d8cb05da0 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -31,7 +31,7 @@ fn test_agg_exprs() -> PolarsResult<()> { .lazy() .group_by_stable([col("cars")]) .agg([(lit(1) - col("A")) - .map(|s| Ok(Some((&s * 2)?)), GetOutput::same_type()) + .map(|s| Ok(Some(&s * 2)), GetOutput::same_type()) .alias("foo")]) .collect()?; let ca = out.column("foo")?.list()?; diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index ff4894b99857..4d482202cd67 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -88,7 +88,7 @@ fn test_lazy_udf() { let df = get_df(); let new = df .lazy() - .select([col("sepal_width").map(|s| Ok(Some((s * 200.0)?)), GetOutput::same_type())]) + .select([col("sepal_width").map(|s| Ok(Some(s * 200.0)), GetOutput::same_type())]) .collect() .unwrap(); assert_eq!( @@ -247,7 +247,7 @@ fn test_lazy_query_2() { let df = load_df(); let ldf = df .lazy() - .with_column(col("a").map(|s| Ok(Some((s * 2)?)), GetOutput::same_type())) + .with_column(col("a").map(|s| Ok(Some(s * 2)), 
GetOutput::same_type())) .filter(col("a").lt(lit(2))) .select([col("b"), col("a")]); diff --git a/crates/polars-ops/src/series/ops/duration.rs b/crates/polars-ops/src/series/ops/duration.rs index b839fda8d375..2c8f0ae022e5 100644 --- a/crates/polars-ops/src/series/ops/duration.rs +++ b/crates/polars-ops/src/series/ops/duration.rs @@ -35,7 +35,7 @@ pub fn impl_duration(s: &[Column], time_unit: TimeUnit) -> PolarsResult microseconds = (microseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000)))?; } if !is_zero_scalar(&milliseconds) { - microseconds = (microseconds + (milliseconds * 1_000)?)?; + microseconds = (microseconds + milliseconds * 1_000)?; } microseconds }, @@ -44,10 +44,10 @@ pub fn impl_duration(s: &[Column], time_unit: TimeUnit) -> PolarsResult nanoseconds = nanoseconds.new_from_index(0, max_len); } if !is_zero_scalar(&microseconds) { - nanoseconds = (nanoseconds + (microseconds * 1_000)?)?; + nanoseconds = (nanoseconds + microseconds * 1_000)?; } if !is_zero_scalar(&milliseconds) { - nanoseconds = (nanoseconds + (milliseconds * 1_000_000)?)?; + nanoseconds = (nanoseconds + milliseconds * 1_000_000)?; } nanoseconds }, @@ -72,24 +72,19 @@ pub fn impl_duration(s: &[Column], time_unit: TimeUnit) -> PolarsResult TimeUnit::Milliseconds => MILLISECONDS, }; if !is_zero_scalar(&seconds) { - let units = seconds * multiplier; - duration = (duration + units?)?; + duration = (duration + seconds * multiplier)?; } if !is_zero_scalar(&minutes) { - let units = minutes * (multiplier * 60); - duration = (duration + units?)?; + duration = (duration + minutes * multiplier * 60)?; } if !is_zero_scalar(&hours) { - let units = hours * (multiplier * 60 * 60); - duration = (duration + units?)?; + duration = (duration + hours * multiplier * 60 * 60)?; } if !is_zero_scalar(&days) { - let units = days * (multiplier * SECONDS_IN_DAY); - duration = (duration + units?)?; + duration = (duration + days * multiplier * SECONDS_IN_DAY)?; } if !is_zero_scalar(&weeks) { - let units = weeks * 
(multiplier * SECONDS_IN_DAY * 7); - duration = (duration + units?)?; + duration = (duration + weeks * multiplier * SECONDS_IN_DAY * 7)?; } duration diff --git a/crates/polars-plan/src/dsl/function_expr/round.rs b/crates/polars-plan/src/dsl/function_expr/round.rs index 41b2f04324d0..110d639d3c92 100644 --- a/crates/polars-plan/src/dsl/function_expr/round.rs +++ b/crates/polars-plan/src/dsl/function_expr/round.rs @@ -8,11 +8,11 @@ pub(super) fn round(c: &Column, decimals: u32) -> PolarsResult { Column::Scalar(s) if s.is_empty() => { s.as_materialized_series().round(decimals).map(Column::from) }, - Column::Scalar(s) => ScalarColumn::from_single_value_series( + Column::Scalar(s) => Ok(ScalarColumn::from_single_value_series( s.as_single_value_series().round(decimals)?, s.len(), ) - .map(Column::from), + .into()), } } @@ -23,11 +23,11 @@ pub(super) fn round_sig_figs(c: &Column, digits: i32) -> PolarsResult { .as_materialized_series() .round_sig_figs(digits) .map(Column::from), - Column::Scalar(s) => ScalarColumn::from_single_value_series( + Column::Scalar(s) => Ok(ScalarColumn::from_single_value_series( s.as_single_value_series().round_sig_figs(digits)?, s.len(), ) - .map(Column::from), + .into()), } } @@ -35,10 +35,11 @@ pub(super) fn floor(c: &Column) -> PolarsResult { match c { Column::Series(s) => s.floor().map(Column::from), Column::Scalar(s) if s.is_empty() => s.as_materialized_series().floor().map(Column::from), - Column::Scalar(s) => { - ScalarColumn::from_single_value_series(s.as_single_value_series().floor()?, s.len()) - .map(Column::from) - }, + Column::Scalar(s) => Ok(ScalarColumn::from_single_value_series( + s.as_single_value_series().floor()?, + s.len(), + ) + .into()), } } @@ -46,9 +47,10 @@ pub(super) fn ceil(c: &Column) -> PolarsResult { match c { Column::Series(s) => s.ceil().map(Column::from), Column::Scalar(s) if s.is_empty() => s.as_materialized_series().ceil().map(Column::from), - Column::Scalar(s) => { - 
ScalarColumn::from_single_value_series(s.as_single_value_series().ceil()?, s.len()) - .map(Column::from) - }, + Column::Scalar(s) => Ok(ScalarColumn::from_single_value_series( + s.as_single_value_series().ceil()?, + s.len(), + ) + .into()), } } diff --git a/py-polars/tests/unit/test_scalar.py b/py-polars/tests/unit/test_scalar.py index f0a845dd43ab..d1f354d8e48e 100644 --- a/py-polars/tests/unit/test_scalar.py +++ b/py-polars/tests/unit/test_scalar.py @@ -13,3 +13,26 @@ def test_invalid_broadcast() -> None: ) with pytest.raises(pl.exceptions.InvalidOperationError): df.select(pl.col("group").filter(pl.col("group") == 0), "a") + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Null, + pl.Int32, + pl.String, + pl.Enum(["foo"]), + pl.Binary, + pl.List(pl.Int32), + pl.Struct({"a": pl.Int32}), + pl.Array(pl.Int32, 1), + pl.List(pl.List(pl.Int32)), + ], +) +def test_null_literals(dtype: pl.DataType) -> None: + assert ( + pl.DataFrame([pl.Series("a", [1, 2], pl.Int64)]) + .with_columns(pl.lit(None).cast(dtype).alias("b")) + .collect_schema() + .dtypes() + ) == [pl.Int64, dtype] From e2c71501c8186b23803571882e01983501b2a89d Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Fri, 27 Sep 2024 17:58:05 +1000 Subject: [PATCH 12/33] fix: Fix `Expr.over` with `order_by` did not take effect if group keys were sorted (#18947) --- crates/polars-expr/src/expressions/window.rs | 12 +++--------- .../polars/tests/it/lazy/expressions/window.rs | 2 -- py-polars/tests/unit/operations/test_window.py | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index 5a455cf5932b..b47d1744f662 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -315,7 +315,6 @@ impl WindowExpr { fn determine_map_strategy( &self, agg_state: &AggState, - sorted_keys: bool, gb: &GroupBy, ) -> PolarsResult { match (self.mapping, 
agg_state) { @@ -334,13 +333,8 @@ impl WindowExpr { // no explicit aggregations, map over the groups //`(col("x").sum() * col("y")).over("groups")` (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => { - if sorted_keys { - if let GroupsProxy::Idx(g) = gb.get_groups() { - debug_assert!(g.is_sorted_flag()) - } - // GroupsProxy::Slice is always sorted - - // Note that group columns must be sorted for this to make sense!!! + if let GroupsProxy::Slice { .. } = gb.get_groups() { + // Result can be directly exploded if the input was sorted. Ok(MapStrategy::Explode) } else { Ok(MapStrategy::Map) @@ -516,7 +510,7 @@ impl PhysicalExpr for WindowExpr { let mut ac = self.run_aggregation(df, state, &gb)?; use MapStrategy::*; - match self.determine_map_strategy(ac.agg_state(), sorted_keys, &gb)? { + match self.determine_map_strategy(ac.agg_state(), &gb)? { Nothing => { let mut out = ac.flat_naive().into_owned(); diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs index 21d8a3d26bf7..fb52ac3810bb 100644 --- a/crates/polars/tests/it/lazy/expressions/window.rs +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -150,9 +150,7 @@ fn test_sort_by_in_groups() -> PolarsResult<()> { col("cars"), col("A") .sort_by([col("B")], SortMultipleOptions::default()) - .implode() .over([col("cars")]) - .explode() .alias("sorted_A_by_B"), ]) .collect()?; diff --git a/py-polars/tests/unit/operations/test_window.py b/py-polars/tests/unit/operations/test_window.py index 8171fd5f9b03..69b9cd8f0e55 100644 --- a/py-polars/tests/unit/operations/test_window.py +++ b/py-polars/tests/unit/operations/test_window.py @@ -518,3 +518,21 @@ def test_lit_window_broadcast() -> None: assert pl.DataFrame({"a": [1, 1, 2]}).select(pl.lit(0).over("a").alias("a"))[ "a" ].to_list() == [0, 0, 0] + + +def test_order_by_sorted_keys_18943() -> None: + df = pl.DataFrame( + { + "g": [1, 1, 1, 1], + "t": [4, 3, 2, 1], + "x": [10, 20, 30, 40], + } + ) + 
+ expect = pl.DataFrame({"x": [100, 90, 70, 40]}) + + out = df.select(pl.col("x").cum_sum().over("g", order_by="t")) + assert_frame_equal(out, expect) + + out = df.set_sorted("g").select(pl.col("x").cum_sum().over("g", order_by="t")) + assert_frame_equal(out, expect) From cafc163f1c72d5926a3500d072115a137f7f6468 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 10:36:14 +0200 Subject: [PATCH 13/33] fix: Properly implement AnyValue::Binary `into_py` (#18960) --- crates/polars-python/src/conversion/any_value.rs | 4 ++-- crates/polars-python/src/dataframe/export.rs | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/crates/polars-python/src/conversion/any_value.rs b/crates/polars-python/src/conversion/any_value.rs index 70cfaaf6d3ab..9eb6401a7b57 100644 --- a/crates/polars-python/src/conversion/any_value.rs +++ b/crates/polars-python/src/conversion/any_value.rs @@ -105,8 +105,8 @@ pub(crate) fn any_value_into_py_object(av: AnyValue, py: Python) -> PyObject { let object = v.0.as_any().downcast_ref::().unwrap(); object.inner.clone() }, - AnyValue::Binary(v) => v.into_py(py), - AnyValue::BinaryOwned(v) => v.into_py(py), + AnyValue::Binary(v) => PyBytes::new_bound(py, v).into_py(py), + AnyValue::BinaryOwned(v) => PyBytes::new_bound(py, &v).into_py(py), AnyValue::Decimal(v, scale) => { let convert = utils.getattr(intern!(py, "to_py_decimal")).unwrap(); const N: usize = 3; diff --git a/crates/polars-python/src/dataframe/export.rs b/crates/polars-python/src/dataframe/export.rs index c9fe0d4a48d0..013cb32dfdd9 100644 --- a/crates/polars-python/src/dataframe/export.rs +++ b/crates/polars-python/src/dataframe/export.rs @@ -56,8 +56,11 @@ impl PyDataFrame { c.get_object(idx).map(|any| any.into()); obj.to_object(py) }, - // SAFETY: we are in bounds. - _ => unsafe { Wrap(c.get_unchecked(idx)).into_py(py) }, + _ => { + // SAFETY: we are in bounds. 
+ let av = unsafe { c.get_unchecked(idx) }; + Wrap(av).into_py(py) + }, }), ) }), From 74b53070992bfd5b4e78b621411d0aa732d159cb Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 27 Sep 2024 11:15:01 +0200 Subject: [PATCH 14/33] feat: Use FFI to extract Series from different Polars binaries (#18964) --- crates/polars-python/src/map/lazy.rs | 45 +++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/crates/polars-python/src/map/lazy.rs b/crates/polars-python/src/map/lazy.rs index c1a680056774..77389b974085 100644 --- a/crates/polars-python/src/map/lazy.rs +++ b/crates/polars-python/src/map/lazy.rs @@ -1,6 +1,7 @@ use polars::prelude::*; +use pyo3::ffi::Py_uintptr_t; use pyo3::prelude::*; -use pyo3::types::PyList; +use pyo3::types::{PyDict, PyList}; use crate::py_modules::POLARS; use crate::series::PySeries; @@ -42,9 +43,45 @@ impl ToSeries for PyObject { } }, }; - let pyseries = py_pyseries.extract::(py).unwrap(); - // Finally get the actual Series - Ok(pyseries.series) + let s = match py_pyseries.extract::(py) { + Ok(pyseries) => pyseries.series, + // This happens if the executed Polars is not from this source. + // Currently only happens in PC-workers + // For now use arrow to convert + // Eventually we must use Polars' Series Export as that can deal with + // multiple chunks + Err(_) => { + use polars::export::arrow::ffi; + let kwargs = PyDict::new_bound(py); + kwargs.set_item("in_place", true).unwrap(); + py_pyseries + .call_method_bound(py, "rechunk", (), Some(&kwargs)) + .map_err(|e| polars_err!(ComputeError: "could not rechunk: {e}"))?; + + // Prepare a pointer to receive the Array struct. 
+ let array = Box::new(ffi::ArrowArray::empty()); + let schema = Box::new(ffi::ArrowSchema::empty()); + + let array_ptr = &*array as *const ffi::ArrowArray; + let schema_ptr = &*schema as *const ffi::ArrowSchema; + // SAFETY: + // this is unsafe as it write to the pointers we just prepared + py_pyseries + .call_method1( + py, + "_export_arrow_to_c", + (array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), + ) + .map_err(|e| polars_err!(ComputeError: "{e}"))?; + + unsafe { + let field = ffi::import_field_from_c(schema.as_ref())?; + let array = ffi::import_array_from_c(*array, field.dtype)?; + Series::from_arrow(field.name, array)? + } + }, + }; + Ok(s) } } From a030634f689d146b3e1990e6bdb3d4462b864869 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 11:48:37 +0200 Subject: [PATCH 15/33] fix: Parallel evaluation of `cumulative_eval` (#18959) --- crates/polars-lazy/src/dsl/eval.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs index a3e3a97a589a..92783aa874ee 100644 --- a/crates/polars-lazy/src/dsl/eval.rs +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -81,8 +81,8 @@ pub trait ExprEvalExtension: IntoExpr + Sized { (1..c.len() + 1) .into_par_iter() .map(|len| { - let s = c.slice(0, len); - if (len - s.null_count()) >= min_periods { + let c = c.slice(0, len); + if (len - c.null_count()) >= min_periods { let df = c.clone().into_frame(); let out = phys_expr.evaluate(&df, &state)?.into_column(); finish(out) From 3342cc27d0169e7d12e2bace0ebd4d73915a40bf Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Fri, 27 Sep 2024 20:48:18 +1000 Subject: [PATCH 16/33] fix: Fix `lit().shrink_dtype()` broadcasting (#18958) --- crates/polars-plan/src/plans/aexpr/scalar.rs | 2 +- py-polars/tests/unit/expr/test_exprs.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/plans/aexpr/scalar.rs 
b/crates/polars-plan/src/plans/aexpr/scalar.rs index f7d681b407d4..553c8800b538 100644 --- a/crates/polars-plan/src/plans/aexpr/scalar.rs +++ b/crates/polars-plan/src/plans/aexpr/scalar.rs @@ -8,7 +8,7 @@ pub fn is_scalar_ae(node: Node, expr_arena: &Arena) -> bool { AExpr::Literal(lv) => lv.is_scalar(), AExpr::Function { options, input, .. } | AExpr::AnonymousFunction { options, input, .. } => { - if options.is_elementwise() { + if options.is_elementwise() || !options.flags.contains(FunctionFlags::CHANGES_LENGTH) { input.iter().all(|e| e.is_scalar(expr_arena)) } else { options.flags.contains(FunctionFlags::RETURNS_SCALAR) diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index b97b3c8f1288..31bf08534df7 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -645,3 +645,11 @@ def test_slice() -> None: result = df.select(pl.all().slice(1, 1)) expected = pl.DataFrame({"a": data["a"][1:2], "b": data["b"][1:2]}) assert_frame_equal(result, expected) + + +def test_function_expr_scalar_identification_18755() -> None: + # The function uses `ApplyOptions::GroupWise`, however the input is scalar. 
+ assert_frame_equal( + pl.DataFrame({"a": [1, 2]}).with_columns(pl.lit(5).shrink_dtype().alias("b")), + pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int8)}), + ) From a0ec630b25aa847699f9c2d7389fee84749a6491 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 27 Sep 2024 12:56:26 +0200 Subject: [PATCH 17/33] feat(python): Drop python 3.8 support (#18965) --- .github/workflows/lint-python.yml | 4 +- .github/workflows/release-python.yml | 2 +- .github/workflows/test-python.yml | 2 +- py-polars/Cargo.toml | 2 +- py-polars/polars/_typing.py | 52 ++++++++----------- py-polars/polars/_utils/async_.py | 4 +- .../polars/_utils/construction/dataframe.py | 8 ++- .../polars/_utils/construction/series.py | 7 ++- py-polars/polars/_utils/construction/utils.py | 4 +- py-polars/polars/_utils/convert.py | 2 +- py-polars/polars/_utils/deprecation.py | 5 +- py-polars/polars/_utils/getitem.py | 5 +- py-polars/polars/_utils/parse/expr.py | 3 +- py-polars/polars/_utils/udfs.py | 4 +- py-polars/polars/_utils/various.py | 13 +++-- py-polars/polars/convert/general.py | 5 +- py-polars/polars/convert/normalize.py | 17 ++---- py-polars/polars/dataframe/_html.py | 3 +- py-polars/polars/dataframe/frame.py | 23 ++++---- py-polars/polars/dataframe/group_by.py | 3 +- py-polars/polars/dataframe/plotting.py | 4 +- py-polars/polars/datatypes/classes.py | 5 +- py-polars/polars/datatypes/constructor.py | 4 +- py-polars/polars/datatypes/convert.py | 3 +- py-polars/polars/datatypes/group.py | 3 +- py-polars/polars/dependencies.py | 13 ++--- py-polars/polars/expr/array.py | 3 +- py-polars/polars/expr/datetime.py | 4 +- py-polars/polars/expr/expr.py | 10 ++-- py-polars/polars/expr/list.py | 3 +- py-polars/polars/expr/string.py | 3 +- py-polars/polars/expr/struct.py | 4 +- py-polars/polars/expr/whenthen.py | 4 +- .../functions/aggregation/horizontal.py | 4 +- py-polars/polars/functions/as_datatype.py | 3 +- py-polars/polars/functions/business.py | 4 +- py-polars/polars/functions/col.py | 3 
+- py-polars/polars/functions/eager.py | 5 +- py-polars/polars/functions/lazy.py | 6 ++- py-polars/polars/functions/whenthen.py | 4 +- py-polars/polars/interchange/protocol.py | 6 +-- py-polars/polars/io/_utils.py | 6 ++- py-polars/polars/io/csv/_utils.py | 4 +- py-polars/polars/io/csv/batched_reader.py | 3 +- py-polars/polars/io/csv/functions.py | 5 +- .../polars/io/database/_cursor_proxies.py | 4 +- py-polars/polars/io/database/_executor.py | 6 +-- py-polars/polars/io/database/functions.py | 3 +- py-polars/polars/io/ipc/functions.py | 4 +- py-polars/polars/io/ndjson.py | 3 +- py-polars/polars/io/plugins.py | 6 ++- py-polars/polars/io/spreadsheet/_utils.py | 5 +- .../polars/io/spreadsheet/_write_utils.py | 4 +- py-polars/polars/io/spreadsheet/functions.py | 3 +- py-polars/polars/lazyframe/frame.py | 8 ++- py-polars/polars/lazyframe/group_by.py | 4 +- py-polars/polars/ml/torch.py | 3 +- py-polars/polars/plugins.py | 4 +- py-polars/polars/schema.py | 12 ++--- py-polars/polars/selectors.py | 4 +- py-polars/polars/series/array.py | 3 +- py-polars/polars/series/datetime.py | 3 +- py-polars/polars/series/list.py | 3 +- py-polars/polars/series/series.py | 7 +-- py-polars/polars/series/string.py | 4 +- py-polars/polars/series/struct.py | 4 +- py-polars/polars/sql/context.py | 3 +- .../testing/parametric/strategies/core.py | 4 +- .../testing/parametric/strategies/data.py | 4 +- .../testing/parametric/strategies/dtype.py | 4 +- .../testing/parametric/strategies/legacy.py | 3 +- py-polars/pyproject.toml | 3 +- py-polars/tests/docs/run_doctest.py | 3 +- py-polars/tests/docs/test_user_guide.py | 2 +- py-polars/tests/unit/conftest.py | 11 ++-- .../unit/constructors/test_constructors.py | 5 +- .../tests/unit/constructors/test_dataframe.py | 6 ++- py-polars/tests/unit/dataframe/test_df.py | 3 +- .../tests/unit/dataframe/test_upsample.py | 1 - py-polars/tests/unit/datatypes/test_parse.py | 15 +++--- py-polars/tests/unit/io/cloud/test_aws.py | 3 +- 
.../tests/unit/io/database/test_async.py | 3 +- py-polars/tests/unit/io/database/test_read.py | 2 +- .../tests/unit/io/database/test_write.py | 8 +-- py-polars/tests/unit/io/test_plugins.py | 2 +- py-polars/tests/unit/io/test_spreadsheet.py | 4 +- py-polars/tests/unit/io/test_utils.py | 5 +- .../unit/operations/map/test_map_elements.py | 7 +-- .../unit/operations/map/test_map_groups.py | 5 +- .../temporal/test_add_business_days.py | 9 +--- .../namespaces/temporal/test_to_datetime.py | 10 +--- .../tests/unit/operations/test_comparison.py | 4 +- .../tests/unit/operations/test_cross_join.py | 11 +--- .../tests/unit/operations/test_ewm_by.py | 9 +--- .../tests/unit/operations/test_interpolate.py | 9 +--- .../tests/unit/operations/test_transpose.py | 2 +- py-polars/tests/unit/series/test_series.py | 3 +- py-polars/tests/unit/sql/test_joins.py | 30 +++++++---- py-polars/tests/unit/test_config.py | 5 +- py-polars/tests/unit/test_format.py | 9 ++-- py-polars/tests/unit/test_selectors.py | 10 +--- py-polars/tests/unit/test_string_cache.py | 2 +- py-polars/tests/unit/utils/test_utils.py | 3 +- 103 files changed, 325 insertions(+), 278 deletions(-) diff --git a/.github/workflows/lint-python.yml b/.github/workflows/lint-python.yml index 5197645a8f23..f75e5e576e02 100644 --- a/.github/workflows/lint-python.yml +++ b/.github/workflows/lint-python.yml @@ -39,7 +39,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8', '3.12'] + python-version: ['3.9', '3.12'] steps: - uses: actions/checkout@v4 @@ -63,4 +63,4 @@ jobs: # Allow untyped calls for older Python versions - name: Run mypy working-directory: py-polars - run: mypy ${{ (matrix.python-version == '3.8') && '--allow-untyped-calls' || '' }} + run: mypy ${{ (matrix.python-version == '3.9') && '--allow-untyped-calls' || '' }} diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index e7b482b42749..c44242a2a374 100644 --- a/.github/workflows/release-python.yml +++ 
b/.github/workflows/release-python.yml @@ -18,7 +18,7 @@ concurrency: cancel-in-progress: true env: - PYTHON_VERSION: '3.8' + PYTHON_VERSION: '3.9' CARGO_INCREMENTAL: 0 CARGO_NET_RETRY: 10 RUSTUP_MAX_RETRIES: 10 diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 23fe797f2795..4f7c7e6a7027 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -39,7 +39,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ['3.8', '3.11', '3.12'] + python-version: ['3.9', '3.12'] include: - os: windows-latest python-version: '3.12' diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 6f26a5475da4..7a783d9e7481 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -12,7 +12,7 @@ libc = { workspace = true } # Explicit dependency is needed to add bigidx in CI during release polars = { workspace = true } polars-python = { workspace = true, features = ["pymethods", "iejoin"] } -pyo3 = { workspace = true, features = ["abi3-py38", "chrono", "extension-module", "multiple-pymethods"] } +pyo3 = { workspace = true, features = ["abi3-py39", "chrono", "extension-module", "multiple-pymethods"] } [build-dependencies] built = { version = "0.7", features = ["chrono", "git2", "cargo-lock"], optional = true } diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 9b0cc722de57..1670b08aeb2f 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -1,17 +1,11 @@ from __future__ import annotations +from collections.abc import Collection, Iterable, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, - Collection, - Iterable, - List, Literal, - Mapping, Protocol, - Sequence, - Tuple, - Type, TypedDict, TypeVar, Union, @@ -55,29 +49,29 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: # Data types PolarsDataType: TypeAlias = Union["DataTypeClass", "DataType"] -PolarsTemporalType: TypeAlias = 
Union[Type["TemporalType"], "TemporalType"] -PolarsIntegerType: TypeAlias = Union[Type["IntegerType"], "IntegerType"] +PolarsTemporalType: TypeAlias = Union[type["TemporalType"], "TemporalType"] +PolarsIntegerType: TypeAlias = Union[type["IntegerType"], "IntegerType"] OneOrMoreDataTypes: TypeAlias = Union[PolarsDataType, Iterable[PolarsDataType]] PythonDataType: TypeAlias = Union[ - Type[int], - Type[float], - Type[bool], - Type[str], - Type["date"], - Type["time"], - Type["datetime"], - Type["timedelta"], - Type[List[Any]], - Type[Tuple[Any, ...]], - Type[bytes], - Type[object], - Type["Decimal"], - Type[None], + type[int], + type[float], + type[bool], + type[str], + type["date"], + type["time"], + type["datetime"], + type["timedelta"], + type[list[Any]], + type[tuple[Any, ...]], + type[bytes], + type[object], + type["Decimal"], + type[None], ] SchemaDefinition: TypeAlias = Union[ Mapping[str, Union[PolarsDataType, PythonDataType]], - Sequence[Union[str, Tuple[str, Union[PolarsDataType, PythonDataType, None]]]], + Sequence[Union[str, tuple[str, Union[PolarsDataType, PythonDataType, None]]]], ] SchemaDict: TypeAlias = Mapping[str, PolarsDataType] @@ -85,7 +79,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: TemporalLiteral: TypeAlias = Union["date", "time", "datetime", "timedelta"] NonNestedLiteral: TypeAlias = Union[NumericLiteral, TemporalLiteral, str, bool, bytes] # Python literal types (can convert into a `lit` expression) -PythonLiteral: TypeAlias = Union[NonNestedLiteral, List[Any]] +PythonLiteral: TypeAlias = Union[NonNestedLiteral, list[Any]] # Inputs that can convert into a `col` expression IntoExprColumn: TypeAlias = Union["Expr", "Series", str] # Inputs that can convert into an expression @@ -204,7 +198,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: # Excel IO ColumnFormatDict: TypeAlias = Mapping[ # dict of colname(s) or selector(s) to format string or dict - 
Union[ColumnNameOrSelector, Tuple[ColumnNameOrSelector, ...]], + Union[ColumnNameOrSelector, tuple[ColumnNameOrSelector, ...]], Union[str, Mapping[str, str]], ] ConditionalFormatDict: TypeAlias = Mapping[ @@ -214,12 +208,12 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: ] ColumnTotalsDefinition: TypeAlias = Union[ # dict of colname(s) to str, a collection of str, or a boolean - Mapping[Union[ColumnNameOrSelector, Tuple[ColumnNameOrSelector]], str], + Mapping[Union[ColumnNameOrSelector, tuple[ColumnNameOrSelector]], str], Sequence[str], bool, ] ColumnWidthsDefinition: TypeAlias = Union[ - Mapping[ColumnNameOrSelector, Union[Tuple[str, ...], int]], int + Mapping[ColumnNameOrSelector, Union[tuple[str, ...], int]], int ] RowTotalsDefinition: TypeAlias = Union[ # dict of colname to str(s), a collection of str, or a boolean @@ -234,7 +228,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: # typevars for core polars types PolarsType = TypeVar("PolarsType", "DataFrame", "LazyFrame", "Series", "Expr") FrameType = TypeVar("FrameType", "DataFrame", "LazyFrame") -BufferInfo: TypeAlias = Tuple[int, int, int] +BufferInfo: TypeAlias = tuple[int, int, int] # type alias for supported spreadsheet engines ExcelSpreadsheetEngine: TypeAlias = Literal["xlsx2csv", "openpyxl", "calamine"] diff --git a/py-polars/polars/_utils/async_.py b/py-polars/polars/_utils/async_.py index a8fed8facda7..966f28cecdb6 100644 --- a/py-polars/polars/_utils/async_.py +++ b/py-polars/polars/_utils/async_.py @@ -1,12 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, TypeVar +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Any, Generic, TypeVar from polars._utils.wrap import wrap_df from polars.dependencies import _GEVENT_AVAILABLE if TYPE_CHECKING: from asyncio.futures import Future + from collections.abc import Generator from polars.polars import 
PyDataFrame diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py index 1ab05930f5da..f174fbc736cf 100644 --- a/py-polars/polars/_utils/construction/dataframe.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +from collections.abc import Generator, Mapping from datetime import date, datetime, time, timedelta from functools import singledispatch from itertools import islice, zip_longest @@ -9,11 +10,6 @@ TYPE_CHECKING, Any, Callable, - Generator, - Iterable, - Mapping, - MutableMapping, - Sequence, ) import polars._reexport as pl @@ -63,6 +59,8 @@ from polars.polars import PyDataFrame if TYPE_CHECKING: + from collections.abc import Iterable, MutableMapping, Sequence + from polars import DataFrame, Expr, Series from polars._typing import ( Orientation, diff --git a/py-polars/polars/_utils/construction/series.py b/py-polars/polars/_utils/construction/series.py index 379bdbeb0a30..121378b6d873 100644 --- a/py-polars/polars/_utils/construction/series.py +++ b/py-polars/polars/_utils/construction/series.py @@ -1,16 +1,13 @@ from __future__ import annotations import contextlib +from collections.abc import Generator, Iterator from datetime import date, datetime, time, timedelta from itertools import islice from typing import ( TYPE_CHECKING, Any, Callable, - Generator, - Iterable, - Iterator, - Sequence, ) import polars._reexport as pl @@ -65,6 +62,8 @@ from polars.polars import PySeries if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from polars import DataFrame, Series from polars._typing import PolarsDataType from polars.dependencies import pandas as pd diff --git a/py-polars/polars/_utils/construction/utils.py b/py-polars/polars/_utils/construction/utils.py index de214a2dfb15..8b73728c92fc 100644 --- a/py-polars/polars/_utils/construction/utils.py +++ b/py-polars/polars/_utils/construction/utils.py @@ -2,11 
+2,13 @@ import sys from functools import lru_cache -from typing import TYPE_CHECKING, Any, Callable, Sequence, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, get_type_hints from polars.dependencies import _check_for_pydantic, pydantic if TYPE_CHECKING: + from collections.abc import Sequence + import pandas as pd PANDAS_SIMPLE_NUMPY_DTYPES = { diff --git a/py-polars/polars/_utils/convert.py b/py-polars/polars/_utils/convert.py index e2f4066c62de..cc23c77f73ac 100644 --- a/py-polars/polars/_utils/convert.py +++ b/py-polars/polars/_utils/convert.py @@ -8,7 +8,6 @@ Any, Callable, NoReturn, - Sequence, no_type_check, overload, ) @@ -26,6 +25,7 @@ from polars.dependencies import _ZONEINFO_AVAILABLE, zoneinfo if TYPE_CHECKING: + from collections.abc import Sequence from datetime import date, tzinfo from decimal import Decimal diff --git a/py-polars/polars/_utils/deprecation.py b/py-polars/polars/_utils/deprecation.py index bf93cff116f7..4a84826bca86 100644 --- a/py-polars/polars/_utils/deprecation.py +++ b/py-polars/polars/_utils/deprecation.py @@ -1,14 +1,15 @@ from __future__ import annotations import inspect +from collections.abc import Sequence from functools import wraps -from typing import TYPE_CHECKING, Callable, Sequence, TypeVar +from typing import TYPE_CHECKING, Callable, TypeVar from polars._utils.various import issue_warning if TYPE_CHECKING: import sys - from typing import Mapping + from collections.abc import Mapping from polars._typing import Ambiguous diff --git a/py-polars/polars/_utils/getitem.py b/py-polars/polars/_utils/getitem.py index 991dba52bb2d..2ab09ba9af8c 100644 --- a/py-polars/polars/_utils/getitem.py +++ b/py-polars/polars/_utils/getitem.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, NoReturn, Sequence, overload +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, NoReturn, overload import polars._reexport as pl import polars.functions as F @@ 
-22,6 +23,8 @@ from polars.meta.index_type import get_index_type if TYPE_CHECKING: + from collections.abc import Iterable + from polars import DataFrame, Series from polars._typing import ( MultiColSelector, diff --git a/py-polars/polars/_utils/parse/expr.py b/py-polars/polars/_utils/parse/expr.py index 69e372ae6219..2213f20a1b6f 100644 --- a/py-polars/polars/_utils/parse/expr.py +++ b/py-polars/polars/_utils/parse/expr.py @@ -1,7 +1,8 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Any, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import polars._reexport as pl from polars import functions as F diff --git a/py-polars/polars/_utils/udfs.py b/py-polars/polars/_utils/udfs.py index 661c4e9cf303..0ff968ed59ec 100644 --- a/py-polars/polars/_utils/udfs.py +++ b/py-polars/polars/_utils/udfs.py @@ -16,11 +16,9 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - AbstractSet, Any, Callable, ClassVar, - Iterator, Literal, NamedTuple, Union, @@ -29,6 +27,8 @@ from polars._utils.various import re_escape if TYPE_CHECKING: + from collections.abc import Iterator + from collections.abc import Set as AbstractSet from dis import Instruction if sys.version_info >= (3, 10): diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index f25ab95e23c2..4acad1df5237 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -5,18 +5,21 @@ import re import sys import warnings -from collections.abc import MappingView, Sized +from collections.abc import ( + Collection, + Generator, + Iterable, + MappingView, + Sequence, + Sized, +) from enum import Enum from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, - Collection, - Generator, - Iterable, Literal, - Sequence, TypeVar, overload, ) diff --git a/py-polars/polars/convert/general.py b/py-polars/polars/convert/general.py index 011054169d04..cee4b925e9e9 100644 --- 
a/py-polars/polars/convert/general.py +++ b/py-polars/polars/convert/general.py @@ -3,7 +3,8 @@ import io import itertools import re -from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, overload +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING, Any, overload import polars._reexport as pl from polars import functions as F @@ -24,6 +25,8 @@ from polars.exceptions import NoDataError if TYPE_CHECKING: + from collections.abc import Mapping + from polars import DataFrame, Series from polars._typing import Orientation, SchemaDefinition, SchemaDict from polars.dependencies import numpy as np diff --git a/py-polars/polars/convert/normalize.py b/py-polars/polars/convert/normalize.py index 7f90260aaac4..d6e24ddc8af5 100644 --- a/py-polars/polars/convert/normalize.py +++ b/py-polars/polars/convert/normalize.py @@ -4,27 +4,20 @@ import json from collections import abc -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any from polars._utils.unstable import unstable from polars.dataframe import DataFrame from polars.datatypes.constants import N_INFER_DEFAULT if TYPE_CHECKING: - from polars.schema import Schema - -import sys + from collections.abc import Sequence -if sys.version_info >= (3, 9): + from polars.schema import Schema - def _remove_prefix(text: str, prefix: str) -> str: - return text.removeprefix(prefix) -else: - def _remove_prefix(text: str, prefix: str) -> str: - if text.startswith(prefix): - return text[len(prefix) :] - return text +def _remove_prefix(text: str, prefix: str) -> str: + return text.removeprefix(prefix) def _simple_json_normalize( diff --git a/py-polars/polars/dataframe/_html.py b/py-polars/polars/dataframe/_html.py index 62464873e42a..6f034eab0f41 100644 --- a/py-polars/polars/dataframe/_html.py +++ b/py-polars/polars/dataframe/_html.py @@ -4,11 +4,12 @@ import os from textwrap import dedent -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING 
from polars.dependencies import html if TYPE_CHECKING: + from collections.abc import Iterable from types import TracebackType from polars import DataFrame diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 7fafdd058735..911d057a1f46 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -6,7 +6,12 @@ import os import random from collections import defaultdict -from collections.abc import Sized +from collections.abc import ( + Generator, + Iterable, + Sequence, + Sized, +) from io import BytesIO, StringIO from operator import itemgetter from pathlib import Path @@ -16,13 +21,7 @@ Any, Callable, ClassVar, - Collection, - Generator, - Iterable, - Iterator, - Mapping, NoReturn, - Sequence, TypeVar, get_args, overload, @@ -115,6 +114,11 @@ if TYPE_CHECKING: import sys + from collections.abc import ( + Collection, + Iterator, + Mapping, + ) from datetime import timedelta from io import IOBase from typing import Literal @@ -3973,8 +3977,9 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]: else (connection, False) ) with ( - conn if can_close_conn else contextlib.nullcontext() - ), conn.cursor() as cursor: + conn if can_close_conn else contextlib.nullcontext(), + conn.cursor() as cursor, + ): catalog, db_schema, unpacked_table_name = unpack_table_name(table_name) n_rows: int if adbc_version >= (0, 7): diff --git a/py-polars/polars/dataframe/group_by.py b/py-polars/polars/dataframe/group_by.py index 204afc3d25c1..ce3cf054a7a0 100644 --- a/py-polars/polars/dataframe/group_by.py +++ b/py-polars/polars/dataframe/group_by.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterable +from typing import TYPE_CHECKING, Callable from polars import functions as F from polars._utils.convert import parse_as_duration_string @@ -8,6 +8,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable from datetime import timedelta 
from polars import DataFrame diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting.py index 75a2b92aa09e..11828cd54ebc 100644 --- a/py-polars/polars/dataframe/plotting.py +++ b/py-polars/polars/dataframe/plotting.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, Union +from typing import TYPE_CHECKING, Callable, Union from polars.dependencies import altair as alt @@ -27,7 +27,7 @@ from typing_extensions import Unpack Encoding: TypeAlias = Union[X, Y, Color, Order, Size, Tooltip] - Encodings: TypeAlias = Dict[str, Encoding] + Encodings: TypeAlias = dict[str, Encoding] def _maybe_extract_shorthand(encoding: Encoding) -> Encoding: diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 8906f2459c28..bb538b4f01e8 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -2,9 +2,10 @@ import contextlib from collections import OrderedDict +from collections.abc import Mapping from datetime import timezone from inspect import isclass -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any import polars._reexport as pl import polars.datatypes @@ -14,6 +15,8 @@ from polars.polars import dtype_str_repr as _dtype_str_repr if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + from polars import Series from polars._typing import ( CategoricalOrdering, diff --git a/py-polars/polars/datatypes/constructor.py b/py-polars/polars/datatypes/constructor.py index 63e0d912cbf8..bcd77f9b54a0 100644 --- a/py-polars/polars/datatypes/constructor.py +++ b/py-polars/polars/datatypes/constructor.py @@ -2,7 +2,7 @@ import functools from decimal import Decimal as PyDecimal -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable from polars import datatypes as dt from polars.dependencies import 
numpy as np @@ -16,6 +16,8 @@ _DOCUMENTING = True if TYPE_CHECKING: + from collections.abc import Sequence + from polars._typing import PolarsDataType if not _DOCUMENTING: diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 91fe61da4b6f..d46d8c111581 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -4,9 +4,10 @@ import functools import re import sys +from collections.abc import Collection from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal -from typing import TYPE_CHECKING, Any, Collection, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from polars.datatypes.classes import ( Array, diff --git a/py-polars/polars/datatypes/group.py b/py-polars/polars/datatypes/group.py index f30153ed6e7a..3332dd4a7e7f 100644 --- a/py-polars/polars/datatypes/group.py +++ b/py-polars/polars/datatypes/group.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from polars.datatypes.classes import ( Array, @@ -27,6 +27,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable from polars._typing import ( PolarsDataType, diff --git a/py-polars/polars/dependencies.py b/py-polars/polars/dependencies.py index 10548da8c904..9770c2035ce8 100644 --- a/py-polars/polars/dependencies.py +++ b/py-polars/polars/dependencies.py @@ -2,11 +2,12 @@ import re import sys -from functools import lru_cache +from collections.abc import Hashable +from functools import cache from importlib import import_module from importlib.util import find_spec from types import ModuleType -from typing import TYPE_CHECKING, Any, ClassVar, Hashable, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast _ALTAIR_AVAILABLE = True _DELTALAKE_AVAILABLE = True @@ -149,6 +150,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: import json import pickle 
import subprocess + import zoneinfo import altair import deltalake @@ -161,11 +163,6 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: import pyarrow import pydantic import pyiceberg - - if sys.version_info >= (3, 9): - import zoneinfo - else: - from backports import zoneinfo else: # infrequently-used builtins dataclasses, _ = _lazy_import("dataclasses") @@ -193,7 +190,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: gevent, _GEVENT_AVAILABLE = _lazy_import("gevent") -@lru_cache(maxsize=None) +@cache def _might_be(cls: type, type_: str) -> bool: # infer whether the given class "might" be associated with the given # module (in which case it's reasonable to do a real isinstance check; diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 302b0b024adf..928e3149c35d 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Callable from polars._utils.parse import parse_into_expression from polars._utils.wrap import wrap_expr diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index a35674412ccf..9aaed1352d09 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime as dt -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING import polars._reexport as pl from polars import functions as F @@ -13,6 +13,8 @@ from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Int32 if TYPE_CHECKING: + from collections.abc import Iterable + from polars import Expr from polars._typing import ( Ambiguous, diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 1d3b2a07fdcb..341f99876b21 100644 --- a/py-polars/polars/expr/expr.py +++ 
b/py-polars/polars/expr/expr.py @@ -4,6 +4,7 @@ import math import operator import warnings +from collections.abc import Collection, Mapping, Sequence from datetime import timedelta from functools import reduce from io import BytesIO, StringIO @@ -13,13 +14,7 @@ Any, Callable, ClassVar, - Collection, - FrozenSet, - Iterable, - Mapping, NoReturn, - Sequence, - Set, TypeVar, ) @@ -69,6 +64,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable from io import IOBase from polars import DataFrame, LazyFrame, Series @@ -5771,7 +5767,7 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Expr: └───────────┴──────────────────┴──────────┘ """ if isinstance(other, Collection) and not isinstance(other, str): - if isinstance(other, (Set, FrozenSet)): + if isinstance(other, (set, frozenset)): other = list(other) other = F.lit(pl.Series(other))._pyexpr else: diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 5560f368767b..48b4d1da9c49 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1,7 +1,8 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Any, Callable, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable import polars._reexport as pl from polars import functions as F diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 1b04ab7febad..7582758d5921 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Mapping +from collections.abc import Mapping +from typing import TYPE_CHECKING import polars._reexport as pl from polars import functions as F diff --git a/py-polars/polars/expr/struct.py b/py-polars/polars/expr/struct.py index 09b9c1688fbb..57b8b6eddfb3 100644 --- a/py-polars/polars/expr/struct.py +++ b/py-polars/polars/expr/struct.py @@ -1,12 +1,14 @@ 
from __future__ import annotations import os -from typing import TYPE_CHECKING, Iterable, Sequence +from typing import TYPE_CHECKING from polars._utils.parse import parse_into_list_of_expressions from polars._utils.wrap import wrap_expr if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from polars import Expr from polars._typing import IntoExpr diff --git a/py-polars/polars/expr/whenthen.py b/py-polars/polars/expr/whenthen.py index 752ed1b44c06..65ad722f2cce 100644 --- a/py-polars/polars/expr/whenthen.py +++ b/py-polars/polars/expr/whenthen.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any import polars.functions as F from polars._utils.parse import ( @@ -11,6 +11,8 @@ from polars.expr.expr import Expr if TYPE_CHECKING: + from collections.abc import Iterable + from polars._typing import IntoExpr from polars.polars import PyExpr diff --git a/py-polars/polars/functions/aggregation/horizontal.py b/py-polars/polars/functions/aggregation/horizontal.py index 121f1bb41497..5406a77d287d 100644 --- a/py-polars/polars/functions/aggregation/horizontal.py +++ b/py-polars/polars/functions/aggregation/horizontal.py @@ -1,7 +1,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING import polars.functions as F from polars._utils.parse import parse_into_list_of_expressions @@ -12,6 +12,8 @@ import polars.polars as plr if TYPE_CHECKING: + from collections.abc import Iterable + from polars import Expr from polars._typing import IntoExpr diff --git a/py-polars/polars/functions/as_datatype.py b/py-polars/polars/functions/as_datatype.py index 5937e5c5b091..30398daa01d9 100644 --- a/py-polars/polars/functions/as_datatype.py +++ b/py-polars/polars/functions/as_datatype.py @@ -1,7 +1,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Iterable, overload +from 
typing import TYPE_CHECKING, overload from polars import functions as F from polars._utils.parse import ( @@ -16,6 +16,7 @@ if TYPE_CHECKING: + from collections.abc import Iterable from typing import Literal from polars import Expr, Series diff --git a/py-polars/polars/functions/business.py b/py-polars/polars/functions/business.py index dc70a4ba3679..2f92c40e987e 100644 --- a/py-polars/polars/functions/business.py +++ b/py-polars/polars/functions/business.py @@ -2,7 +2,7 @@ import contextlib from datetime import date -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING from polars._utils.parse import parse_into_expression from polars._utils.wrap import wrap_expr @@ -11,6 +11,8 @@ import polars.polars as plr if TYPE_CHECKING: + from collections.abc import Iterable + from polars import Expr from polars._typing import IntoExprColumn diff --git a/py-polars/polars/functions/col.py b/py-polars/polars/functions/col.py index 01cca3c4fe21..b354cd878176 100644 --- a/py-polars/polars/functions/col.py +++ b/py-polars/polars/functions/col.py @@ -1,7 +1,8 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Iterable +from collections.abc import Iterable +from typing import TYPE_CHECKING from polars._utils.wrap import wrap_expr from polars.datatypes import is_polars_dtype diff --git a/py-polars/polars/functions/eager.py b/py-polars/polars/functions/eager.py index e8cbb00e3dca..e65f4aa949c0 100644 --- a/py-polars/polars/functions/eager.py +++ b/py-polars/polars/functions/eager.py @@ -1,9 +1,10 @@ from __future__ import annotations import contextlib +from collections.abc import Sequence from functools import reduce from itertools import chain -from typing import TYPE_CHECKING, Iterable, Sequence, get_args +from typing import TYPE_CHECKING, get_args import polars._reexport as pl from polars import functions as F @@ -16,6 +17,8 @@ import polars.polars as plr if TYPE_CHECKING: + from collections.abc import Iterable + from 
polars import DataFrame, Expr, LazyFrame, Series from polars._typing import FrameType, JoinStrategy, PolarsType diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index c78996e07753..61cead9871d8 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1,7 +1,8 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, overload +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable, overload import polars._reexport as pl import polars.functions as F @@ -20,7 +21,8 @@ import polars.polars as plr if TYPE_CHECKING: - from typing import Awaitable, Collection, Literal + from collections.abc import Awaitable, Collection, Iterable + from typing import Literal from polars import DataFrame, Expr, LazyFrame, Series from polars._typing import ( diff --git a/py-polars/polars/functions/whenthen.py b/py-polars/polars/functions/whenthen.py index e72708d311e7..b9d9040cceb3 100644 --- a/py-polars/polars/functions/whenthen.py +++ b/py-polars/polars/functions/whenthen.py @@ -1,7 +1,7 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any import polars._reexport as pl from polars._utils.parse import parse_predicates_constraints_into_expression @@ -10,6 +10,8 @@ import polars.polars as plr if TYPE_CHECKING: + from collections.abc import Iterable + from polars._typing import IntoExprColumn diff --git a/py-polars/polars/interchange/protocol.py b/py-polars/polars/interchange/protocol.py index 4eda7fa95f2d..4865baa35ae6 100644 --- a/py-polars/polars/interchange/protocol.py +++ b/py-polars/polars/interchange/protocol.py @@ -5,11 +5,8 @@ TYPE_CHECKING, Any, ClassVar, - Iterable, Literal, Protocol, - Sequence, - Tuple, TypedDict, ) @@ -17,6 +14,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable, Sequence from 
polars.interchange.buffer import PolarsBuffer from polars.interchange.column import PolarsColumn @@ -71,7 +69,7 @@ class DtypeKind(IntEnum): CATEGORICAL = 23 -Dtype: TypeAlias = Tuple[DtypeKind, int, str, str] # see Column.dtype +Dtype: TypeAlias = tuple[DtypeKind, int, str, str] # see Column.dtype class ColumnNullType(IntEnum): diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index e8971b08660d..68d4b604d6a6 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -5,12 +5,16 @@ from contextlib import contextmanager from io import BytesIO, StringIO from pathlib import Path -from typing import IO, Any, ContextManager, Iterator, Sequence, overload +from typing import IO, TYPE_CHECKING, Any, overload from polars._utils.various import is_int_sequence, is_str_sequence, normalize_filepath from polars.dependencies import _FSSPEC_AVAILABLE, fsspec from polars.exceptions import NoDataError +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + from contextlib import AbstractContextManager as ContextManager + def parse_columns_arg( columns: Sequence[str] | Sequence[int] | str | int | None, diff --git a/py-polars/polars/io/csv/_utils.py b/py-polars/polars/io/csv/_utils.py index b4bbb055c3a9..bdc910de5aa4 100644 --- a/py-polars/polars/io/csv/_utils.py +++ b/py-polars/polars/io/csv/_utils.py @@ -1,8 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Sequence + from polars import DataFrame diff --git a/py-polars/polars/io/csv/batched_reader.py b/py-polars/polars/io/csv/batched_reader.py index 2bbb1583317c..6866a0b79723 100644 --- a/py-polars/polars/io/csv/batched_reader.py +++ b/py-polars/polars/io/csv/batched_reader.py @@ -1,7 +1,8 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Sequence +from collections.abc import Sequence +from typing import 
TYPE_CHECKING from polars._utils.various import ( _process_null_values, diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 71f71ef6ce68..2e0562ceaee7 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -2,9 +2,10 @@ import contextlib import os +from collections.abc import Sequence from io import BytesIO, StringIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Mapping, Sequence +from typing import IO, TYPE_CHECKING, Any, Callable import polars._reexport as pl import polars.functions as F @@ -30,6 +31,8 @@ from polars.polars import PyDataFrame, PyLazyFrame if TYPE_CHECKING: + from collections.abc import Mapping + from polars import DataFrame, LazyFrame from polars._typing import CsvEncoding, PolarsDataType, SchemaDict diff --git a/py-polars/polars/io/database/_cursor_proxies.py b/py-polars/polars/io/database/_cursor_proxies.py index 129f7609759a..d25abc3acce5 100644 --- a/py-polars/polars/io/database/_cursor_proxies.py +++ b/py-polars/polars/io/database/_cursor_proxies.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from polars.io.database._utils import _run_async if TYPE_CHECKING: import sys - from collections.abc import Coroutine + from collections.abc import Coroutine, Iterable import pyarrow as pa diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index b45301b87da2..278e3e8e0738 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -1,10 +1,10 @@ from __future__ import annotations import re -from collections.abc import Coroutine +from collections.abc import Coroutine, Sequence from contextlib import suppress from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, Iterable, Sequence +from typing import TYPE_CHECKING, Any from 
polars import functions as F from polars._utils.various import parse_version @@ -22,7 +22,7 @@ if TYPE_CHECKING: import sys - from collections.abc import Iterator + from collections.abc import Iterable, Iterator from types import TracebackType import pyarrow as pa diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index 098dabfdd90f..ac5dcaaac3e7 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Any, Iterable, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, overload from polars.datatypes import N_INFER_DEFAULT from polars.dependencies import import_optional @@ -10,6 +10,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Iterable if sys.version_info >= (3, 10): from typing import TypeAlias diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 43fbc8136de2..6d64a560d094 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -3,7 +3,7 @@ import contextlib import os from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Sequence +from typing import IO, TYPE_CHECKING, Any import polars._reexport as pl import polars.functions as F @@ -28,6 +28,8 @@ from polars.polars import read_ipc_schema as _read_ipc_schema if TYPE_CHECKING: + from collections.abc import Sequence + from polars import DataFrame, DataType, LazyFrame from polars._typing import SchemaDict diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 33e85914ce34..7a5fb2c0d1e6 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -1,9 +1,10 @@ from __future__ import annotations import contextlib +from collections.abc import Sequence from io import BytesIO, StringIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Sequence 
+from typing import IO, TYPE_CHECKING, Any from polars._utils.deprecation import deprecate_renamed_parameter from polars._utils.various import is_path_or_str_sequence, normalize_filepath diff --git a/py-polars/polars/io/plugins.py b/py-polars/polars/io/plugins.py index 02f598515c1e..8c9e6581bf7e 100644 --- a/py-polars/polars/io/plugins.py +++ b/py-polars/polars/io/plugins.py @@ -2,13 +2,15 @@ import os import sys -from typing import TYPE_CHECKING, Callable, Iterator +from collections.abc import Iterator +from typing import TYPE_CHECKING, Callable import polars._reexport as pl from polars._utils.unstable import unstable if TYPE_CHECKING: - from typing import Callable, Iterator + from collections.abc import Iterator + from typing import Callable from polars import DataFrame, Expr, LazyFrame from polars._typing import SchemaDict diff --git a/py-polars/polars/io/spreadsheet/_utils.py b/py-polars/polars/io/spreadsheet/_utils.py index c7f647c9b01d..c535662fa33b 100644 --- a/py-polars/polars/io/spreadsheet/_utils.py +++ b/py-polars/polars/io/spreadsheet/_utils.py @@ -2,7 +2,10 @@ from contextlib import contextmanager from pathlib import Path -from typing import Any, Iterator, cast +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from collections.abc import Iterator @contextmanager diff --git a/py-polars/polars/io/spreadsheet/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py index ba4032f5293f..8489b359a416 100644 --- a/py-polars/polars/io/spreadsheet/_write_utils.py +++ b/py-polars/polars/io/spreadsheet/_write_utils.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Sequence from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Sequence, overload +from typing import TYPE_CHECKING, Any, overload from polars import functions as F from polars.datatypes import ( @@ -21,6 +22,7 @@ from polars.selectors import _expand_selector_dicts, _expand_selectors, numeric if 
TYPE_CHECKING: + from collections.abc import Iterable from typing import Literal from xlsxwriter import Workbook diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index dd445f9e9c97..2b04fc1ac4dc 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -1,10 +1,11 @@ from __future__ import annotations import re +from collections.abc import Sequence from datetime import time from io import BufferedReader, BytesIO, StringIO, TextIOWrapper from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, NoReturn, Sequence, overload +from typing import IO, TYPE_CHECKING, Any, Callable, NoReturn, overload import polars._reexport as pl from polars import from_arrow diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 60e8b803421c..9e65f918d385 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -3,6 +3,7 @@ import contextlib import os import warnings +from collections.abc import Collection, Mapping from datetime import date, datetime, time, timedelta from functools import lru_cache, partial, reduce from io import BytesIO, StringIO @@ -13,11 +14,7 @@ Any, Callable, ClassVar, - Collection, - Iterable, - Mapping, NoReturn, - Sequence, TypeVar, overload, ) @@ -91,8 +88,9 @@ if TYPE_CHECKING: import sys + from collections.abc import Awaitable, Iterable, Sequence from io import IOBase - from typing import Awaitable, Literal + from typing import Literal import pyarrow as pa diff --git a/py-polars/polars/lazyframe/group_by.py b/py-polars/polars/lazyframe/group_by.py index 8a469e8b1e5d..1143ea9c4831 100644 --- a/py-polars/polars/lazyframe/group_by.py +++ b/py-polars/polars/lazyframe/group_by.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Iterable +from typing import TYPE_CHECKING, Callable from polars import functions as F 
from polars._utils.deprecation import deprecate_renamed_function @@ -8,6 +8,8 @@ from polars._utils.wrap import wrap_ldf if TYPE_CHECKING: + from collections.abc import Iterable + from polars import DataFrame, LazyFrame from polars._typing import IntoExpr, RollingInterpolationMethod, SchemaDict from polars.polars import PyLazyGroupBy diff --git a/py-polars/polars/ml/torch.py b/py-polars/polars/ml/torch.py index 134e1c260e63..f92f2ab30fe9 100644 --- a/py-polars/polars/ml/torch.py +++ b/py-polars/polars/ml/torch.py @@ -1,7 +1,7 @@ # mypy: disable-error-code="unused-ignore" from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING from polars._utils.unstable import issue_unstable_warning from polars.dataframe import DataFrame @@ -10,6 +10,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Sequence from torch import Tensor, memory_format diff --git a/py-polars/polars/plugins.py b/py-polars/polars/plugins.py index e295b87a63a6..a501f8b5ed09 100644 --- a/py-polars/polars/plugins.py +++ b/py-polars/polars/plugins.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from polars._utils.parse import parse_into_list_of_expressions from polars._utils.wrap import wrap_expr @@ -12,6 +12,8 @@ import polars.polars as plr if TYPE_CHECKING: + from collections.abc import Iterable + from polars import Expr from polars._typing import IntoExpr diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index a099b67aac2c..72eb8b86d25e 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -2,18 +2,18 @@ from collections import OrderedDict from collections.abc import Mapping -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING +from polars.datatypes import DataType from polars.datatypes._parse import parse_into_dtype +BaseSchema = 
OrderedDict[str, DataType] + if TYPE_CHECKING: + from collections.abc import Iterable + from polars._typing import PythonDataType - from polars.datatypes import DataType - BaseSchema = OrderedDict[str, DataType] -else: - # Python 3.8 does not support generic OrderedDict at runtime - BaseSchema = OrderedDict __all__ = ["Schema"] diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 36b26068a77f..2631f222612a 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -1,16 +1,14 @@ from __future__ import annotations +from collections.abc import Collection, Mapping, Sequence from datetime import timezone from functools import reduce from operator import or_ from typing import ( TYPE_CHECKING, Any, - Collection, Literal, - Mapping, NoReturn, - Sequence, overload, ) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 3051228ca941..877ec303fcfb 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Callable from polars import functions as F from polars._utils.wrap import wrap_s from polars.series.utils import expr_dispatch if TYPE_CHECKING: + from collections.abc import Sequence from datetime import date, datetime, time from polars import Series diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index e4da43de78c7..dcf65ff15312 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING from polars._utils.deprecation import deprecate_function from polars._utils.unstable import unstable @@ -9,6 +9,7 @@ if TYPE_CHECKING: import datetime as dt + from collections.abc import Iterable from polars import Expr, Series from 
polars._typing import ( diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index c6480e3ddbff..cf70f5225f56 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -1,12 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable from polars import functions as F from polars._utils.wrap import wrap_s from polars.series.utils import expr_dispatch if TYPE_CHECKING: + from collections.abc import Sequence from datetime import date, datetime, time from polars import Expr, Series diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 506c52ed9f7a..c66d1f0a4abc 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3,6 +3,7 @@ import contextlib import math import os +from collections.abc import Iterable, Sequence from contextlib import nullcontext from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal @@ -11,13 +12,8 @@ Any, Callable, ClassVar, - Collection, - Generator, - Iterable, Literal, - Mapping, NoReturn, - Sequence, Union, overload, ) @@ -114,6 +110,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Collection, Generator, Mapping import jax import numpy.typing as npt diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index e2104f67af44..af9ce66850c4 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING from polars._utils.deprecation import deprecate_function from polars._utils.unstable import unstable @@ -9,6 +9,8 @@ from polars.series.utils import expr_dispatch if TYPE_CHECKING: + from collections.abc import Mapping + from polars import Expr, Series from polars._typing import ( Ambiguous, diff --git 
a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index 750ccfd51275..e8137a23be32 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING from polars._utils.various import BUILDING_SPHINX_DOCS, sphinx_accessor from polars._utils.wrap import wrap_df @@ -8,6 +8,8 @@ from polars.series.utils import expr_dispatch if TYPE_CHECKING: + from collections.abc import Sequence + from polars import DataFrame, Series from polars.polars import PySeries elif BUILDING_SPHINX_DOCS: diff --git a/py-polars/polars/sql/context.py b/py-polars/polars/sql/context.py index 85101e65ed4b..c48290e547c4 100644 --- a/py-polars/polars/sql/context.py +++ b/py-polars/polars/sql/context.py @@ -5,9 +5,7 @@ from typing import ( TYPE_CHECKING, Callable, - Collection, Generic, - Mapping, Union, overload, ) @@ -30,6 +28,7 @@ if TYPE_CHECKING: import sys + from collections.abc import Collection, Mapping from types import TracebackType from typing import Any, Final, Literal diff --git a/py-polars/polars/testing/parametric/strategies/core.py b/py-polars/polars/testing/parametric/strategies/core.py index b3aef2e51301..984853afe057 100644 --- a/py-polars/polars/testing/parametric/strategies/core.py +++ b/py-polars/polars/testing/parametric/strategies/core.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Collection, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Any, overload import hypothesis.strategies as st from hypothesis.errors import InvalidArgument @@ -16,6 +17,7 @@ from polars.testing.parametric.strategies.dtype import _instantiate_dtype, dtypes if TYPE_CHECKING: + from collections.abc import Collection, Sequence from typing import Literal from hypothesis.strategies import DrawFn, 
SearchStrategy diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index adf94379326b..f63b6d8a8945 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -3,8 +3,9 @@ from __future__ import annotations import decimal +from collections.abc import Mapping from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal import hypothesis.strategies as st from hypothesis.errors import InvalidArgument @@ -60,6 +61,7 @@ ) if TYPE_CHECKING: + from collections.abc import Sequence from datetime import date, time from hypothesis.strategies import SearchStrategy diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index 96dd928d2b84..fff1ad58c726 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Collection, Sequence +from typing import TYPE_CHECKING import hypothesis.strategies as st from hypothesis.errors import InvalidArgument @@ -35,6 +35,8 @@ ) if TYPE_CHECKING: + from collections.abc import Collection, Sequence + from hypothesis.strategies import DrawFn, SearchStrategy from polars._typing import CategoricalOrdering, PolarsDataType, TimeUnit diff --git a/py-polars/polars/testing/parametric/strategies/legacy.py b/py-polars/polars/testing/parametric/strategies/legacy.py index 7e238594a5c3..e5bc328f1ed5 100644 --- a/py-polars/polars/testing/parametric/strategies/legacy.py +++ b/py-polars/polars/testing/parametric/strategies/legacy.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import 
TYPE_CHECKING, Any import hypothesis.strategies as st from hypothesis.errors import InvalidArgument diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index fed084292021..0cfc6b9e43c7 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -10,7 +10,7 @@ authors = [ { name = "Ritchie Vink", email = "ritchie46@gmail.com" }, ] license = { file = "LICENSE" } -requires-python = ">=3.8" +requires-python = ">=3.9" keywords = ["dataframe", "arrow", "out-of-core"] classifiers = [ @@ -22,7 +22,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 353b9cfb0dd0..404070deae85 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -37,11 +37,12 @@ import warnings from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Iterator +from typing import TYPE_CHECKING, Any import polars as pl if TYPE_CHECKING: + from collections.abc import Iterator from types import ModuleType diff --git a/py-polars/tests/docs/test_user_guide.py b/py-polars/tests/docs/test_user_guide.py index 7cae31e8c0c3..69586f002e8d 100644 --- a/py-polars/tests/docs/test_user_guide.py +++ b/py-polars/tests/docs/test_user_guide.py @@ -2,8 +2,8 @@ import os import runpy +from collections.abc import Iterator from pathlib import Path -from typing import Iterator import matplotlib as mpl import pytest diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py index 7335f8f46835..ead5bee2c265 100644 --- a/py-polars/tests/unit/conftest.py +++ b/py-polars/tests/unit/conftest.py @@ -6,7 +6,7 @@ import string import sys import tracemalloc -from 
typing import Any, Generator, List, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np import pytest @@ -14,6 +14,9 @@ import polars as pl from polars.testing.parametric import load_profile +if TYPE_CHECKING: + from collections.abc import Generator + load_profile( profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] ) @@ -128,7 +131,7 @@ def str_ints_df() -> pl.DataFrame: @pytest.fixture(params=ISO8601_FORMATS_DATETIME) def iso8601_format_datetime(request: pytest.FixtureRequest) -> list[str]: - return cast(List[str], request.param) + return cast(list[str], request.param) ISO8601_TZ_AWARE_FORMATS_DATETIME = [] @@ -151,7 +154,7 @@ def iso8601_format_datetime(request: pytest.FixtureRequest) -> list[str]: @pytest.fixture(params=ISO8601_TZ_AWARE_FORMATS_DATETIME) def iso8601_tz_aware_format_datetime(request: pytest.FixtureRequest) -> list[str]: - return cast(List[str], request.param) + return cast(list[str], request.param) ISO8601_FORMATS_DATE = [] @@ -163,7 +166,7 @@ def iso8601_tz_aware_format_datetime(request: pytest.FixtureRequest) -> list[str @pytest.fixture(params=ISO8601_FORMATS_DATE) def iso8601_format_date(request: pytest.FixtureRequest) -> list[str]: - return cast(List[str], request.param) + return cast(list[str], request.param) class MemoryUsage: diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index dbfe77b46c97..d340433ddf10 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -4,7 +4,7 @@ from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from random import shuffle -from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple +from typing import TYPE_CHECKING, Any, Literal, NamedTuple import numpy as np import pandas as pd @@ -22,7 +22,6 @@ if TYPE_CHECKING: from collections.abc import Callable - from zoneinfo 
import ZoneInfo from polars._typing import PolarsDataType @@ -281,7 +280,7 @@ class PageView(BaseModel): "top": 123 }] """ - adapter: TypeAdapter[Any] = TypeAdapter(List[PageView]) + adapter: TypeAdapter[Any] = TypeAdapter(list[PageView]) models = adapter.validate_json(data_json) result = pl.DataFrame(models) diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py index 3703475aa438..5e56630e7552 100644 --- a/py-polars/tests/unit/constructors/test_dataframe.py +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -2,13 +2,17 @@ import sys from collections import OrderedDict -from typing import Any, Iterator, Mapping +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any import pytest import polars as pl from polars.exceptions import DataOrientationWarning, InvalidOperationError +if TYPE_CHECKING: + from collections.abc import Iterator + def test_df_mixed_dtypes_string() -> None: data = {"x": [["abc", 12, 34.5]], "y": [1]} diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 8798dc89ed22..2389bef36c88 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -7,7 +7,7 @@ from decimal import Decimal from io import BytesIO from operator import floordiv, truediv -from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, cast import numpy as np import pyarrow as pa @@ -32,6 +32,7 @@ from tests.unit.conftest import INTEGER_DTYPES if TYPE_CHECKING: + from collections.abc import Iterator, Sequence from zoneinfo import ZoneInfo from polars import Expr diff --git a/py-polars/tests/unit/dataframe/test_upsample.py b/py-polars/tests/unit/dataframe/test_upsample.py index 21160ad54df8..28bcbf13f401 100644 --- a/py-polars/tests/unit/dataframe/test_upsample.py +++ b/py-polars/tests/unit/dataframe/test_upsample.py @@ -11,7 +11,6 
@@ if TYPE_CHECKING: from datetime import timezone - from zoneinfo import ZoneInfo from polars._typing import FillNullStrategy, PolarsIntegerType diff --git a/py-polars/tests/unit/datatypes/test_parse.py b/py-polars/tests/unit/datatypes/test_parse.py index c95763033b32..0979292e8e50 100644 --- a/py-polars/tests/unit/datatypes/test_parse.py +++ b/py-polars/tests/unit/datatypes/test_parse.py @@ -4,12 +4,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, ForwardRef, - List, NamedTuple, Optional, - Tuple, Union, ) @@ -63,9 +60,9 @@ def test_parse_py_type_into_dtype(input: Any, expected: PolarsDataType) -> None: @pytest.mark.parametrize( ("input", "expected"), [ - (List[int], pl.List(pl.Int64())), - (Tuple[str, ...], pl.List(pl.String())), - (Tuple[datetime, datetime], pl.List(pl.Datetime("us"))), + (list[int], pl.List(pl.Int64())), + (tuple[str, ...], pl.List(pl.String())), + (tuple[datetime, datetime], pl.List(pl.Datetime("us"))), ], ) def test_parse_generic_into_dtype(input: Any, expected: PolarsDataType) -> None: @@ -76,9 +73,9 @@ def test_parse_generic_into_dtype(input: Any, expected: PolarsDataType) -> None: @pytest.mark.parametrize( "input", [ - Dict[str, float], - Tuple[int, str], - Tuple[int, float, float], + dict[str, float], + tuple[int, str], + tuple[int, float, float], ], ) def test_parse_generic_into_dtype_invalid(input: Any) -> None: diff --git a/py-polars/tests/unit/io/cloud/test_aws.py b/py-polars/tests/unit/io/cloud/test_aws.py index 9a004a7d825f..6f2116421822 100644 --- a/py-polars/tests/unit/io/cloud/test_aws.py +++ b/py-polars/tests/unit/io/cloud/test_aws.py @@ -1,7 +1,7 @@ from __future__ import annotations import multiprocessing -from typing import TYPE_CHECKING, Any, Callable, Iterator +from typing import TYPE_CHECKING, Any, Callable import boto3 import pytest @@ -11,6 +11,7 @@ from polars.testing import assert_frame_equal if TYPE_CHECKING: + from collections.abc import Iterator from pathlib import Path pytestmark = [ diff --git 
a/py-polars/tests/unit/io/database/test_async.py b/py-polars/tests/unit/io/database/test_async.py index a8492e7a5276..3bdea8207d2c 100644 --- a/py-polars/tests/unit/io/database/test_async.py +++ b/py-polars/tests/unit/io/database/test_async.py @@ -2,7 +2,7 @@ import asyncio from math import ceil -from typing import TYPE_CHECKING, Any, Iterable, overload +from typing import TYPE_CHECKING, Any, overload import pytest import sqlalchemy @@ -13,6 +13,7 @@ from polars.testing import assert_frame_equal if TYPE_CHECKING: + from collections.abc import Iterable from pathlib import Path SURREAL_MOCK_DATA: list[dict[str, Any]] = [ diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index 967379f356f7..deb44a5a79f4 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -32,7 +32,7 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: - with suppress(ModuleNotFoundError): # not available on 3.8/windows + with suppress(ModuleNotFoundError): # not available on windows from adbc_driver_sqlite.dbapi import connect args = tuple(str(a) if isinstance(a, Path) else a for a in args) diff --git a/py-polars/tests/unit/io/database/test_write.py b/py-polars/tests/unit/io/database/test_write.py index da9550d7126a..0e50044030f4 100644 --- a/py-polars/tests/unit/io/database/test_write.py +++ b/py-polars/tests/unit/io/database/test_write.py @@ -28,16 +28,16 @@ "adbc", True, marks=pytest.mark.skipif( - sys.version_info < (3, 9) or sys.platform == "win32", - reason="adbc not available on Windows or <= Python 3.8", + sys.platform == "win32", + reason="adbc not available on Windows", ), ), pytest.param( "adbc", False, marks=pytest.mark.skipif( - sys.version_info < (3, 9) or sys.platform == "win32", - reason="adbc not available on Windows or <= Python 3.8", + sys.platform == "win32", + reason="adbc not available on Windows", ), ), ], diff --git 
a/py-polars/tests/unit/io/test_plugins.py b/py-polars/tests/unit/io/test_plugins.py index 98c25edc3f4a..6303df166962 100644 --- a/py-polars/tests/unit/io/test_plugins.py +++ b/py-polars/tests/unit/io/test_plugins.py @@ -7,7 +7,7 @@ from polars.testing import assert_frame_equal if TYPE_CHECKING: - from typing import Iterator + from collections.abc import Iterator # A simple python source. But this can dispatch into a rust IO source as well. diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 06af29659ba0..7483f371b51c 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -5,7 +5,7 @@ from datetime import date, datetime from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable import pytest @@ -17,6 +17,8 @@ from tests.unit.conftest import FLOAT_DTYPES, NUMERIC_DTYPES if TYPE_CHECKING: + from collections.abc import Sequence + from polars._typing import ExcelSpreadsheetEngine, SelectorType pytestmark = pytest.mark.slow() diff --git a/py-polars/tests/unit/io/test_utils.py b/py-polars/tests/unit/io/test_utils.py index 7c2173469ebb..e115aec4f71f 100644 --- a/py-polars/tests/unit/io/test_utils.py +++ b/py-polars/tests/unit/io/test_utils.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING import pytest from polars.io._utils import looks_like_url, parse_columns_arg, parse_row_index_args +if TYPE_CHECKING: + from collections.abc import Sequence + @pytest.mark.parametrize( ("columns", "expected"), diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index ce147be9ef27..0ef231bc8943 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -340,9 
+340,10 @@ def test_map_elements_chunked_14390() -> None: def test_cabbage_strategy_14396() -> None: df = pl.DataFrame({"x": [1, 2, 3]}) - with pytest.raises( - ValueError, match="strategy 'cabbage' is not supported" - ), pytest.warns(PolarsInefficientMapWarning): + with ( + pytest.raises(ValueError, match="strategy 'cabbage' is not supported"), + pytest.warns(PolarsInefficientMapWarning), + ): df.select(pl.col("x").map_elements(lambda x: 2 * x, strategy="cabbage")) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/operations/map/test_map_groups.py b/py-polars/tests/unit/operations/map/test_map_groups.py index d675d43e28f6..772f4d088249 100644 --- a/py-polars/tests/unit/operations/map/test_map_groups.py +++ b/py-polars/tests/unit/operations/map/test_map_groups.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any import numpy as np import pytest @@ -9,6 +9,9 @@ from polars.exceptions import ComputeError from polars.testing import assert_frame_equal +if TYPE_CHECKING: + from collections.abc import Sequence + def test_map_groups() -> None: df = pl.DataFrame( diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py b/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py index 5fe78225670e..4d646e2bba49 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_add_business_days.py @@ -1,7 +1,6 @@ from __future__ import annotations import datetime as dt -import sys from datetime import date, datetime, timedelta from typing import TYPE_CHECKING @@ -11,19 +10,13 @@ from hypothesis import assume, given import polars as pl -from polars.dependencies import _ZONEINFO_AVAILABLE from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_series_equal if TYPE_CHECKING: from polars._typing import Roll, 
TimeUnit -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo +from zoneinfo import ZoneInfo def test_add_business_days() -> None: diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py index 5bb09a0f8d1c..f07106ee76ac 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_to_datetime.py @@ -1,25 +1,17 @@ from __future__ import annotations -import sys from datetime import date, datetime from typing import TYPE_CHECKING +from zoneinfo import ZoneInfo import hypothesis.strategies as st import pytest from hypothesis import given import polars as pl -from polars.dependencies import _ZONEINFO_AVAILABLE from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_series_equal -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo - if TYPE_CHECKING: from hypothesis.strategies import DrawFn diff --git a/py-polars/tests/unit/operations/test_comparison.py b/py-polars/tests/unit/operations/test_comparison.py index 72771ff64eab..26f95e269339 100644 --- a/py-polars/tests/unit/operations/test_comparison.py +++ b/py-polars/tests/unit/operations/test_comparison.py @@ -2,7 +2,7 @@ import math from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, ContextManager +from typing import TYPE_CHECKING, Any import pytest @@ -11,6 +11,8 @@ from polars.testing import assert_frame_equal if TYPE_CHECKING: + 
from contextlib import AbstractContextManager as ContextManager + from polars._typing import PolarsDataType diff --git a/py-polars/tests/unit/operations/test_cross_join.py b/py-polars/tests/unit/operations/test_cross_join.py index 94830371ed35..f424da5ab170 100644 --- a/py-polars/tests/unit/operations/test_cross_join.py +++ b/py-polars/tests/unit/operations/test_cross_join.py @@ -1,17 +1,8 @@ -import sys from datetime import datetime +from zoneinfo import ZoneInfo import pytest -from polars.dependencies import _ZONEINFO_AVAILABLE - -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo - import polars as pl diff --git a/py-polars/tests/unit/operations/test_ewm_by.py b/py-polars/tests/unit/operations/test_ewm_by.py index e804f6b6d5d8..8004303238d4 100644 --- a/py-polars/tests/unit/operations/test_ewm_by.py +++ b/py-polars/tests/unit/operations/test_ewm_by.py @@ -1,25 +1,18 @@ from __future__ import annotations -import sys from datetime import date, datetime, timedelta from typing import TYPE_CHECKING import pytest import polars as pl -from polars.dependencies import _ZONEINFO_AVAILABLE from polars.exceptions import InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from polars._typing import PolarsIntegerType, TimeUnit -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo +from zoneinfo import ZoneInfo @pytest.mark.parametrize("sort", [True, False]) diff --git a/py-polars/tests/unit/operations/test_interpolate.py b/py-polars/tests/unit/operations/test_interpolate.py index 
e80d4102466c..bac0693205da 100644 --- a/py-polars/tests/unit/operations/test_interpolate.py +++ b/py-polars/tests/unit/operations/test_interpolate.py @@ -1,24 +1,17 @@ from __future__ import annotations -import sys from datetime import date, datetime, time, timedelta from typing import TYPE_CHECKING, Any import pytest import polars as pl -from polars.dependencies import _ZONEINFO_AVAILABLE from polars.testing import assert_frame_equal if TYPE_CHECKING: from polars._typing import PolarsDataType, PolarsTemporalType -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo +from zoneinfo import ZoneInfo @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/operations/test_transpose.py b/py-polars/tests/unit/operations/test_transpose.py index a43a6e7f629e..d81ab5b1404e 100644 --- a/py-polars/tests/unit/operations/test_transpose.py +++ b/py-polars/tests/unit/operations/test_transpose.py @@ -1,6 +1,6 @@ import io +from collections.abc import Iterator from datetime import date, datetime -from typing import Iterator import pytest diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 0f45cdfb6e21..491890b5dc13 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -2,7 +2,7 @@ import math from datetime import date, datetime, time, timedelta -from typing import TYPE_CHECKING, Any, Iterator, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np import pandas as pd @@ -31,6 +31,7 @@ from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder if TYPE_CHECKING: + from collections.abc import Iterator from zoneinfo import ZoneInfo from polars._typing import EpochTimeUnit, PolarsDataType, TimeUnit diff --git 
a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py index 3a1e90bed1b4..43c00ed8b3d5 100644 --- a/py-polars/tests/unit/sql/test_joins.py +++ b/py-polars/tests/unit/sql/test_joins.py @@ -295,10 +295,13 @@ def test_join_misc_16255() -> None: ) def test_non_equi_joins(constraint: str) -> None: # no support (yet) for non equi-joins in polars joins - with pytest.raises( - SQLInterfaceError, - match=r"only equi-join constraints are supported", - ), pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx: + with ( + pytest.raises( + SQLInterfaceError, + match=r"only equi-join constraints are supported", + ), + pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx, + ): ctx.execute( f""" SELECT * @@ -310,12 +313,19 @@ def test_non_equi_joins(constraint: str) -> None: def test_implicit_joins() -> None: # no support for this yet; ensure we catch it - with pytest.raises( - SQLInterfaceError, - match=r"not currently supported .* use explicit JOIN syntax instead", - ), pl.SQLContext( - {"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2], "c": ["x", "y", "z"]})} - ) as ctx: + with ( + pytest.raises( + SQLInterfaceError, + match=r"not currently supported .* use explicit JOIN syntax instead", + ), + pl.SQLContext( + { + "tbl": pl.DataFrame( + {"a": [1, 2, 3], "b": [4, 3, 2], "c": ["x", "y", "z"]} + ) + } + ) as ctx, + ): ctx.execute( """ SELECT t1.* diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py index f6cd9411bebf..cd0491b5435f 100644 --- a/py-polars/tests/unit/test_config.py +++ b/py-polars/tests/unit/test_config.py @@ -2,7 +2,7 @@ import os from pathlib import Path -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any import pytest @@ -11,6 +11,9 @@ from polars._utils.unstable import issue_unstable_warning from polars.config import _POLARS_CFG_ENV_VARS +if TYPE_CHECKING: + from collections.abc import Iterator + @pytest.fixture(autouse=True) def 
_environ() -> Iterator[None]: diff --git a/py-polars/tests/unit/test_format.py b/py-polars/tests/unit/test_format.py index 2461f100209b..44c754e4884b 100644 --- a/py-polars/tests/unit/test_format.py +++ b/py-polars/tests/unit/test_format.py @@ -2,7 +2,7 @@ import string from decimal import Decimal as D -from typing import TYPE_CHECKING, Any, Iterator +from typing import TYPE_CHECKING, Any import pytest @@ -10,6 +10,8 @@ from polars.exceptions import InvalidOperationError if TYPE_CHECKING: + from collections.abc import Iterator + from polars._typing import PolarsDataType @@ -290,8 +292,9 @@ def test_fmt_float_full() -> None: def test_fmt_list_12188() -> None: # set max_items to 1 < 4(size of failed list) to touch the testing branch. - with pl.Config(fmt_table_cell_list_len=1), pytest.raises( - InvalidOperationError, match="from `i64` to `u8` failed" + with ( + pl.Config(fmt_table_cell_list_len=1), + pytest.raises(InvalidOperationError, match="from `i64` to `u8` failed"), ): pl.DataFrame( { diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index a5d573d10379..bf44ff87bac5 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -1,26 +1,18 @@ -import sys from collections import OrderedDict from datetime import datetime from typing import Any +from zoneinfo import ZoneInfo import pytest import polars as pl import polars.selectors as cs from polars._typing import SelectorType -from polars.dependencies import _ZONEINFO_AVAILABLE from polars.exceptions import ColumnNotFoundError, InvalidOperationError from polars.selectors import expand_selector, is_selector from polars.testing import assert_frame_equal from tests.unit.conftest import INTEGER_DTYPES, TEMPORAL_DTYPES -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - 
from backports.zoneinfo._zoneinfo import ZoneInfo - def assert_repr_equals(item: Any, expected: str) -> None: """Assert that the repr of an item matches the expected string.""" diff --git a/py-polars/tests/unit/test_string_cache.py b/py-polars/tests/unit/test_string_cache.py index b54b08d48a86..def1f15db07a 100644 --- a/py-polars/tests/unit/test_string_cache.py +++ b/py-polars/tests/unit/test_string_cache.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator import pytest diff --git a/py-polars/tests/unit/utils/test_utils.py b/py-polars/tests/unit/utils/test_utils.py index f6eca92215da..96730c91434b 100644 --- a/py-polars/tests/unit/utils/test_utils.py +++ b/py-polars/tests/unit/utils/test_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import date, datetime, time, timedelta -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any import numpy as np import pytest @@ -25,6 +25,7 @@ ) if TYPE_CHECKING: + from collections.abc import Sequence from zoneinfo import ZoneInfo from polars._typing import TimeUnit From 79fcd5332de196bd8e0a8024a7535822fac57d29 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 12:56:44 +0200 Subject: [PATCH 18/33] refactor: Divide `ChunkCompare` into `Eq` and `Ineq` variants (#18963) --- .../chunked_array/comparison/categorical.rs | 18 +- .../src/chunked_array/comparison/mod.rs | 75 +++---- .../src/chunked_array/comparison/scalar.rs | 24 ++- .../polars-core/src/chunked_array/ops/mod.rs | 29 +-- .../src/chunked_array/ops/unique/mod.rs | 3 +- crates/polars-core/src/frame/column/mod.rs | 22 +- crates/polars-core/src/series/comparison.rs | 203 +++++++++++++----- crates/polars-core/src/series/mod.rs | 5 +- crates/polars-expr/src/expressions/apply.rs | 16 +- crates/polars-expr/src/expressions/binary.rs | 20 +- .../src/chunked_array/array/count.rs | 2 +- .../src/chunked_array/list/count.rs | 2 +- crates/polars-ops/src/chunked_array/peaks.rs | 4 
+- 13 files changed, 264 insertions(+), 159 deletions(-) diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index faa7f619cdb2..bbcd6b6047c9 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -96,7 +96,7 @@ where } } -impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { +impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked { type Item = PolarsResult; fn equal(&self, rhs: &CategoricalChunked) -> Self::Item { @@ -134,6 +134,10 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { UInt32Chunked::not_equal_missing, ) } +} + +impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked { + type Item = PolarsResult; fn gt(&self, rhs: &CategoricalChunked) -> Self::Item { cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r) @@ -217,7 +221,7 @@ where } } -impl ChunkCompare<&StringChunked> for CategoricalChunked { +impl ChunkCompareEq<&StringChunked> for CategoricalChunked { type Item = PolarsResult; fn equal(&self, rhs: &StringChunked) -> Self::Item { @@ -265,6 +269,10 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { StringChunked::not_equal_missing, ) } +} + +impl ChunkCompareIneq<&StringChunked> for CategoricalChunked { + type Item = PolarsResult; fn gt(&self, rhs: &StringChunked) -> Self::Item { cat_str_compare_helper( @@ -376,7 +384,7 @@ where } } -impl ChunkCompare<&str> for CategoricalChunked { +impl ChunkCompareEq<&str> for CategoricalChunked { type Item = PolarsResult; fn equal(&self, rhs: &str) -> Self::Item { @@ -414,6 +422,10 @@ impl ChunkCompare<&str> for CategoricalChunked { UInt32Chunked::equal_missing, ) } +} + +impl ChunkCompareIneq<&str> for CategoricalChunked { + type Item = PolarsResult; fn gt(&self, rhs: &str) -> Self::Item { cat_single_str_compare_helper( diff --git 
a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 300f5f338cff..ecf8f78fcd9a 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -16,7 +16,7 @@ use crate::series::implementations::null::NullChunked; use crate::series::IsSorted; use crate::utils::align_chunks_binary; -impl ChunkCompare<&ChunkedArray> for ChunkedArray +impl ChunkCompareEq<&ChunkedArray> for ChunkedArray where T: PolarsNumericType, T::Array: TotalOrdKernel + TotalEqKernel, @@ -126,6 +126,14 @@ where ), } } +} + +impl ChunkCompareIneq<&ChunkedArray> for ChunkedArray +where + T: PolarsNumericType, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; fn lt(&self, rhs: &ChunkedArray) -> BooleanChunked { // Broadcast. @@ -188,7 +196,7 @@ where } } -impl ChunkCompare<&NullChunked> for NullChunked { +impl ChunkCompareEq<&NullChunked> for NullChunked { type Item = BooleanChunked; fn equal(&self, rhs: &NullChunked) -> Self::Item { @@ -206,6 +214,10 @@ impl ChunkCompare<&NullChunked> for NullChunked { fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item { BooleanChunked::full(self.name().clone(), false, get_broadcast_length(self, rhs)) } +} + +impl ChunkCompareIneq<&NullChunked> for NullChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &NullChunked) -> Self::Item { BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) @@ -234,7 +246,7 @@ fn get_broadcast_length(lhs: &NullChunked, rhs: &NullChunked) -> usize { } } -impl ChunkCompare<&BooleanChunked> for BooleanChunked { +impl ChunkCompareEq<&BooleanChunked> for BooleanChunked { type Item = BooleanChunked; fn equal(&self, rhs: &BooleanChunked) -> BooleanChunked { @@ -348,6 +360,10 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { ), } } +} + +impl ChunkCompareIneq<&BooleanChunked> for BooleanChunked { + type Item = BooleanChunked; fn 
lt(&self, rhs: &BooleanChunked) -> BooleanChunked { // Broadcast. @@ -410,7 +426,7 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { } } -impl ChunkCompare<&StringChunked> for StringChunked { +impl ChunkCompareEq<&StringChunked> for StringChunked { type Item = BooleanChunked; fn equal(&self, rhs: &StringChunked) -> BooleanChunked { @@ -424,9 +440,14 @@ impl ChunkCompare<&StringChunked> for StringChunked { fn not_equal(&self, rhs: &StringChunked) -> BooleanChunked { self.as_binary().not_equal(&rhs.as_binary()) } + fn not_equal_missing(&self, rhs: &StringChunked) -> BooleanChunked { self.as_binary().not_equal_missing(&rhs.as_binary()) } +} + +impl ChunkCompareIneq<&StringChunked> for StringChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &StringChunked) -> BooleanChunked { self.as_binary().gt(&rhs.as_binary()) @@ -445,7 +466,7 @@ impl ChunkCompare<&StringChunked> for StringChunked { } } -impl ChunkCompare<&BinaryChunked> for BinaryChunked { +impl ChunkCompareEq<&BinaryChunked> for BinaryChunked { type Item = BooleanChunked; fn equal(&self, rhs: &BinaryChunked) -> BooleanChunked { @@ -551,6 +572,10 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { ), } } +} + +impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked { + type Item = BooleanChunked; fn lt(&self, rhs: &BinaryChunked) -> BooleanChunked { // Broadcast. @@ -644,7 +669,7 @@ where } } -impl ChunkCompare<&ListChunked> for ListChunked { +impl ChunkCompareEq<&ListChunked> for ListChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ListChunked) -> BooleanChunked { let _series_equals = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { @@ -684,23 +709,6 @@ impl ChunkCompare<&ListChunked> for ListChunked { _list_comparison_helper(self, rhs, _series_not_equal_missing) } - - // The following are not implemented because gt, lt comparison of series don't make sense. 
- fn gt(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } - - fn gt_eq(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt_eq(&self, _rhs: &ListChunked) -> BooleanChunked { - unimplemented!() - } } #[cfg(feature = "dtype-struct")] @@ -741,7 +749,7 @@ where } #[cfg(feature = "dtype-struct")] -impl ChunkCompare<&StructChunked> for StructChunked { +impl ChunkCompareEq<&StructChunked> for StructChunked { type Item = BooleanChunked; fn equal(&self, rhs: &StructChunked) -> BooleanChunked { struct_helper( @@ -785,7 +793,7 @@ impl ChunkCompare<&StructChunked> for StructChunked { } #[cfg(feature = "dtype-array")] -impl ChunkCompare<&ArrayChunked> for ArrayChunked { +impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked { if self.width() != rhs.width() { @@ -834,23 +842,6 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked { PlSmallStr::EMPTY, ) } - - // following are not implemented because gt, lt comparison of series don't make sense - fn gt(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } - - fn gt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } - - fn lt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked { - unimplemented!() - } } impl Not for &BooleanChunked { diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs index 1c632299c1e4..4649696b41dd 100644 --- a/crates/polars-core/src/chunked_array/comparison/scalar.rs +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -61,13 +61,14 @@ where ca } -impl ChunkCompare for ChunkedArray +impl ChunkCompareEq for ChunkedArray where T: PolarsNumericType, Rhs: ToPrimitive, T::Array: TotalOrdKernel + 
TotalEqKernel, { type Item = BooleanChunked; + fn equal(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); let fa = Some(|x: T::Native| x.tot_ge(&rhs)); @@ -111,6 +112,15 @@ where }) } } +} + +impl ChunkCompareIneq for ChunkedArray +where + T: PolarsNumericType, + Rhs: ToPrimitive, + T::Array: TotalOrdKernel + TotalEqKernel, +{ + type Item = BooleanChunked; fn gt(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); @@ -157,7 +167,7 @@ where } } -impl ChunkCompare<&[u8]> for BinaryChunked { +impl ChunkCompareEq<&[u8]> for BinaryChunked { type Item = BooleanChunked; fn equal(&self, rhs: &[u8]) -> BooleanChunked { @@ -175,6 +185,10 @@ impl ChunkCompare<&[u8]> for BinaryChunked { fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked { arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into()) } +} + +impl ChunkCompareIneq<&[u8]> for BinaryChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &[u8]) -> BooleanChunked { arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into()) @@ -193,7 +207,7 @@ impl ChunkCompare<&[u8]> for BinaryChunked { } } -impl ChunkCompare<&str> for StringChunked { +impl ChunkCompareEq<&str> for StringChunked { type Item = BooleanChunked; fn equal(&self, rhs: &str) -> BooleanChunked { @@ -211,6 +225,10 @@ impl ChunkCompare<&str> for StringChunked { fn not_equal_missing(&self, rhs: &str) -> BooleanChunked { arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into()) } +} + +impl ChunkCompareIneq<&str> for StringChunked { + type Item = BooleanChunked; fn gt(&self, rhs: &str) -> BooleanChunked { arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into()) diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 8da567d06491..2bc1337e598f 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ 
b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -38,7 +38,6 @@ pub(crate) mod unique; #[cfg(feature = "zip_with")] pub mod zip; -use polars_utils::no_call_const; #[cfg(feature = "serde-lazy")] use serde::{Deserialize, Serialize}; pub use sort::options::*; @@ -312,7 +311,7 @@ pub trait ChunkVar { /// df.filter(&mask) /// } /// ``` -pub trait ChunkCompare { +pub trait ChunkCompareEq { type Item; /// Check for equality. @@ -326,30 +325,24 @@ pub trait ChunkCompare { /// Check for inequality where `None == None`. fn not_equal_missing(&self, rhs: Rhs) -> Self::Item; +} + +/// Compare [`Series`] and [`ChunkedArray`]'s using inequality operators (`<`, `>=`, etc.) and get +/// a `boolean` mask that can be used to filter rows. +pub trait ChunkCompareIneq { + type Item; /// Greater than comparison. - #[allow(unused_variables)] - fn gt(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn gt(&self, rhs: Rhs) -> Self::Item; /// Greater than or equal comparison. - #[allow(unused_variables)] - fn gt_eq(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn gt_eq(&self, rhs: Rhs) -> Self::Item; /// Less than comparison. 
- #[allow(unused_variables)] - fn lt(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn lt(&self, rhs: Rhs) -> Self::Item; /// Less than or equal comparison - #[allow(unused_variables)] - fn lt_eq(&self, rhs: Rhs) -> Self::Item { - no_call_const!() - } + fn lt_eq(&self, rhs: Rhs) -> Self::Item; } /// Get unique values in a `ChunkedArray` diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index b645088b4d68..b073700867af 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -87,7 +87,8 @@ where T: PolarsNumericType, T::Native: TotalHash + TotalEq + ToTotalOrd, ::TotalOrdItem: Hash + Eq + Ord, - ChunkedArray: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, + ChunkedArray: + IntoSeries + for<'a> ChunkCompareEq<&'a ChunkedArray, Item = BooleanChunked>, { fn unique(&self) -> PolarsResult { // prevent stackoverflow repeated sorted.unique call diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index 78c36db57f78..cb88a9946ca4 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -992,61 +992,65 @@ impl Column { } } -impl ChunkCompare<&Column> for Column { +impl ChunkCompareEq<&Column> for Column { type Item = PolarsResult; /// Create a boolean mask by checking for equality. #[inline] - fn equal(&self, rhs: &Column) -> PolarsResult { + fn equal(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .equal(rhs.as_materialized_series()) } /// Create a boolean mask by checking for equality. #[inline] - fn equal_missing(&self, rhs: &Column) -> PolarsResult { + fn equal_missing(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .equal_missing(rhs.as_materialized_series()) } /// Create a boolean mask by checking for inequality. 
#[inline] - fn not_equal(&self, rhs: &Column) -> PolarsResult { + fn not_equal(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .not_equal(rhs.as_materialized_series()) } /// Create a boolean mask by checking for inequality. #[inline] - fn not_equal_missing(&self, rhs: &Column) -> PolarsResult { + fn not_equal_missing(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .not_equal_missing(rhs.as_materialized_series()) } +} + +impl ChunkCompareIneq<&Column> for Column { + type Item = PolarsResult; /// Create a boolean mask by checking if self > rhs. #[inline] - fn gt(&self, rhs: &Column) -> PolarsResult { + fn gt(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .gt(rhs.as_materialized_series()) } /// Create a boolean mask by checking if self >= rhs. #[inline] - fn gt_eq(&self, rhs: &Column) -> PolarsResult { + fn gt_eq(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .gt_eq(rhs.as_materialized_series()) } /// Create a boolean mask by checking if self < rhs. #[inline] - fn lt(&self, rhs: &Column) -> PolarsResult { + fn lt(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .lt(rhs.as_materialized_series()) } /// Create a boolean mask by checking if self <= rhs. #[inline] - fn lt_eq(&self, rhs: &Column) -> PolarsResult { + fn lt_eq(&self, rhs: &Column) -> Self::Item { self.as_materialized_series() .lt_eq(rhs.as_materialized_series()) } diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index bea981db89f1..228221c076aa 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -4,8 +4,8 @@ use crate::prelude::*; use crate::series::arithmetic::coerce_lhs_rhs; use crate::series::nulls::replace_non_null; -macro_rules! impl_compare { - ($self:expr, $rhs:expr, $method:ident, $struct_function:expr) => {{ +macro_rules! 
impl_eq_compare { + ($self:expr, $rhs:expr, $method:ident) => {{ use DataType::*; let (lhs, rhs) = ($self, $rhs); validate_types(lhs.dtype(), rhs.dtype())?; @@ -70,14 +70,7 @@ macro_rules! impl_compare { #[cfg(feature = "dtype-array")] Array(_, _) => lhs.array().unwrap().$method(rhs.array().unwrap()), #[cfg(feature = "dtype-struct")] - Struct(_) => { - let lhs = lhs - .struct_() - .unwrap(); - let rhs = rhs.struct_().unwrap(); - - $struct_function(lhs, rhs)? - }, + Struct(_) => lhs.struct_().unwrap().$method(rhs.struct_().unwrap()), #[cfg(feature = "dtype-decimal")] Decimal(_, s1) => { let DataType::Decimal(_, s2) = rhs.dtype() else { @@ -96,14 +89,108 @@ macro_rules! impl_compare { }}; } -#[cfg(feature = "dtype-struct")] -fn raise_struct(_a: &StructChunked, _b: &StructChunked) -> PolarsResult { - polars_bail!(InvalidOperation: "order comparison not support for struct dtype") +macro_rules! bail_invalid_ineq { + ($lhs:expr, $rhs:expr, $op:literal) => { + polars_bail!( + InvalidOperation: "cannot perform '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + $op, + $lhs.name(), $lhs.dtype(), + $rhs.name(), $rhs.dtype(), + ) + }; } -#[cfg(not(feature = "dtype-struct"))] -fn raise_struct(_a: &(), _b: &()) -> PolarsResult { - unimplemented!() +macro_rules! 
impl_ineq_compare { + ($self:expr, $rhs:expr, $method:ident, $op:literal) => {{ + use DataType::*; + let (lhs, rhs) = ($self, $rhs); + validate_types(lhs.dtype(), rhs.dtype())?; + + polars_ensure!( + lhs.len() == rhs.len() || + + // Broadcast + lhs.len() == 1 || + rhs.len() == 1, + ShapeMismatch: + "could not perform '{}' comparison between series '{}' of length: {} and series '{}' of length: {}, because they have different lengths", + $op, + lhs.name(), lhs.len(), + rhs.name(), rhs.len() + ); + + #[cfg(feature = "dtype-categorical")] + match (lhs.dtype(), rhs.dtype()) { + (Categorical(_, _) | Enum(_, _), Categorical(_, _) | Enum(_, _)) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.categorical().unwrap())? + .with_name(lhs.name().clone())); + }, + (Categorical(_, _) | Enum(_, _), String) => { + return Ok(lhs + .categorical() + .unwrap() + .$method(rhs.str().unwrap())? + .with_name(lhs.name().clone())); + }, + (String, Categorical(_, _) | Enum(_, _)) => { + return Ok(rhs + .categorical() + .unwrap() + .$method(lhs.str().unwrap())? 
+ .with_name(lhs.name().clone())); + }, + _ => (), + }; + + let (lhs, rhs) = coerce_lhs_rhs(lhs, rhs).map_err(|_| + polars_err!( + SchemaMismatch: "could not evaluate '{}' comparison between series '{}' of dtype: {} and series '{}' of dtype: {}", + $op, + lhs.name(), lhs.dtype(), + rhs.name(), rhs.dtype() + ) + )?; + let lhs = lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); + let mut out = match lhs.dtype() { + Null => lhs.null().unwrap().$method(rhs.null().unwrap()), + Boolean => lhs.bool().unwrap().$method(rhs.bool().unwrap()), + String => lhs.str().unwrap().$method(rhs.str().unwrap()), + Binary => lhs.binary().unwrap().$method(rhs.binary().unwrap()), + UInt8 => lhs.u8().unwrap().$method(rhs.u8().unwrap()), + UInt16 => lhs.u16().unwrap().$method(rhs.u16().unwrap()), + UInt32 => lhs.u32().unwrap().$method(rhs.u32().unwrap()), + UInt64 => lhs.u64().unwrap().$method(rhs.u64().unwrap()), + Int8 => lhs.i8().unwrap().$method(rhs.i8().unwrap()), + Int16 => lhs.i16().unwrap().$method(rhs.i16().unwrap()), + Int32 => lhs.i32().unwrap().$method(rhs.i32().unwrap()), + Int64 => lhs.i64().unwrap().$method(rhs.i64().unwrap()), + Float32 => lhs.f32().unwrap().$method(rhs.f32().unwrap()), + Float64 => lhs.f64().unwrap().$method(rhs.f64().unwrap()), + List(_) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-array")] + Array(_, _) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-struct")] + Struct(_) => bail_invalid_ineq!(lhs, rhs, $op), + #[cfg(feature = "dtype-decimal")] + Decimal(_, s1) => { + let DataType::Decimal(_, s2) = rhs.dtype() else { + unreachable!() + }; + let scale = s1.max(s2).unwrap(); + let lhs = lhs.decimal().unwrap().to_scale(scale).unwrap(); + let rhs = rhs.decimal().unwrap().to_scale(scale).unwrap(); + lhs.0.$method(&rhs.0) + }, + + dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()), + }; + out.rename(lhs.name().clone()); + 
PolarsResult::Ok(out) + }}; } fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> { @@ -124,74 +211,61 @@ fn validate_types(left: &DataType, right: &DataType) -> PolarsResult<()> { Ok(()) } -impl ChunkCompare<&Series> for Series { +impl ChunkCompareEq<&Series> for Series { type Item = PolarsResult; /// Create a boolean mask by checking for equality. - fn equal(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, equal, |a: &StructChunked, b: &StructChunked| { - PolarsResult::Ok(a.equal(b)) - }) + fn equal(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, equal) } /// Create a boolean mask by checking for equality. - fn equal_missing(&self, rhs: &Series) -> PolarsResult { - impl_compare!( - self, - rhs, - equal_missing, - |a: &StructChunked, b: &StructChunked| PolarsResult::Ok(a.equal_missing(b)) - ) + fn equal_missing(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, equal_missing) } /// Create a boolean mask by checking for inequality. - fn not_equal(&self, rhs: &Series) -> PolarsResult { - impl_compare!( - self, - rhs, - not_equal, - |a: &StructChunked, b: &StructChunked| PolarsResult::Ok(a.not_equal(b)) - ) + fn not_equal(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, not_equal) } /// Create a boolean mask by checking for inequality. - fn not_equal_missing(&self, rhs: &Series) -> PolarsResult { - impl_compare!( - self, - rhs, - not_equal_missing, - |a: &StructChunked, b: &StructChunked| PolarsResult::Ok(a.not_equal_missing(b)) - ) + fn not_equal_missing(&self, rhs: &Series) -> Self::Item { + impl_eq_compare!(self, rhs, not_equal_missing) } +} + +impl ChunkCompareIneq<&Series> for Series { + type Item = PolarsResult; /// Create a boolean mask by checking if self > rhs. 
- fn gt(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, gt, raise_struct) + fn gt(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, gt, ">") } /// Create a boolean mask by checking if self >= rhs. - fn gt_eq(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, gt_eq, raise_struct) + fn gt_eq(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, gt_eq, ">=") } /// Create a boolean mask by checking if self < rhs. - fn lt(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, lt, raise_struct) + fn lt(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, lt, "<") } /// Create a boolean mask by checking if self <= rhs. - fn lt_eq(&self, rhs: &Series) -> PolarsResult { - impl_compare!(self, rhs, lt_eq, raise_struct) + fn lt_eq(&self, rhs: &Series) -> Self::Item { + impl_ineq_compare!(self, rhs, lt_eq, "<=") } } -impl ChunkCompare for Series +impl ChunkCompareEq for Series where Rhs: NumericNative, { type Item = PolarsResult; - fn equal(&self, rhs: Rhs) -> PolarsResult { + fn equal(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, equal, rhs)) @@ -203,7 +277,7 @@ where Ok(apply_method_physical_numeric!(&s, equal_missing, rhs)) } - fn not_equal(&self, rhs: Rhs) -> PolarsResult { + fn not_equal(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, not_equal, rhs)) @@ -214,33 +288,40 @@ where let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, not_equal_missing, rhs)) } +} - fn gt(&self, rhs: Rhs) -> PolarsResult { +impl ChunkCompareIneq for Series +where + Rhs: NumericNative, +{ + type Item = PolarsResult; + + fn gt(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, gt, 
rhs)) } - fn gt_eq(&self, rhs: Rhs) -> PolarsResult { + fn gt_eq(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, gt_eq, rhs)) } - fn lt(&self, rhs: Rhs) -> PolarsResult { + fn lt(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, lt, rhs)) } - fn lt_eq(&self, rhs: Rhs) -> PolarsResult { + fn lt_eq(&self, rhs: Rhs) -> Self::Item { validate_types(self.dtype(), &DataType::Int8)?; let s = self.to_physical_repr(); Ok(apply_method_physical_numeric!(&s, lt_eq, rhs)) } } -impl ChunkCompare<&str> for Series { +impl ChunkCompareEq<&str> for Series { type Item = PolarsResult; fn equal(&self, rhs: &str) -> PolarsResult { @@ -294,8 +375,12 @@ impl ChunkCompare<&str> for Series { _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)), } } +} + +impl ChunkCompareIneq<&str> for Series { + type Item = PolarsResult; - fn gt(&self, rhs: &str) -> PolarsResult { + fn gt(&self, rhs: &str) -> Self::Item { validate_types(self.dtype(), &DataType::String)?; match self.dtype() { DataType::String => Ok(self.str().unwrap().gt(rhs)), diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index ce9bcffba2f0..72cb3b67dc41 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -1,5 +1,5 @@ //! Type agnostic columnar data structure. 
-pub use crate::prelude::ChunkCompare; +pub use crate::prelude::ChunkCompareEq; use crate::prelude::*; use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; @@ -90,7 +90,8 @@ use crate::POOL; /// .all(|(a, b)| a == *b)) /// ``` /// -/// See all the comparison operators in the [CmpOps trait](crate::chunked_array::ops::ChunkCompare) +/// See all the comparison operators in the [ChunkCompareEq trait](crate::chunked_array::ops::ChunkCompareEq) and +/// [ChunkCompareIneq trait](crate::chunked_array::ops::ChunkCompareIneq). /// /// ## Iterators /// The Series variants contain differently typed [ChunkedArray](crate::chunked_array::ChunkedArray)s. diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 0eeb8555071b..803e0801e636 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -614,12 +614,12 @@ impl ApplyExpr { if max.get(0).unwrap() == min.get(0).unwrap() { let one_equals = - |value: &Series| Some(ChunkCompare::equal(input, value).ok()?.any()); + |value: &Series| Some(ChunkCompareEq::equal(input, value).ok()?.any()); return one_equals(min); } - let smaller = ChunkCompare::lt(input, min).ok()?; - let bigger = ChunkCompare::gt(input, max).ok()?; + let smaller = ChunkCompareIneq::lt(input, min).ok()?; + let bigger = ChunkCompareIneq::gt(input, max).ok()?; Some(!(smaller | bigger).all()) }; @@ -662,7 +662,7 @@ impl ApplyExpr { // don't read the row_group anyways as // the condition will evaluate to false. // e.g. in_between(10, 5) - if ChunkCompare::gt(&left, &right).ok()?.all() { + if ChunkCompareIneq::gt(&left, &right).ok()?.all() { return Some(false); } @@ -674,15 +674,15 @@ impl ApplyExpr { }; // check the right limit of the interval. // if the end is open, we should be stricter (lt_eq instead of lt). 
- if right_open && ChunkCompare::lt_eq(&right, min).ok()?.all() - || !right_open && ChunkCompare::lt(&right, min).ok()?.all() + if right_open && ChunkCompareIneq::lt_eq(&right, min).ok()?.all() + || !right_open && ChunkCompareIneq::lt(&right, min).ok()?.all() { return Some(false); } // we couldn't conclude anything using the right limit, // check the left limit of the interval - if left_open && ChunkCompare::gt_eq(&left, max).ok()?.all() - || !left_open && ChunkCompare::gt(&left, max).ok()?.all() + if left_open && ChunkCompareIneq::gt_eq(&left, max).ok()?.all() + || !left_open && ChunkCompareIneq::gt(&left, max).ok()?.all() { return Some(false); } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 179f4524aa7e..d0b00bf2ddac 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -55,12 +55,12 @@ fn apply_operator_owned(left: Series, right: Series, op: Operator) -> PolarsResu pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResult { use DataType::*; match op { - Operator::Gt => ChunkCompare::gt(left, right).map(|ca| ca.into_series()), - Operator::GtEq => ChunkCompare::gt_eq(left, right).map(|ca| ca.into_series()), - Operator::Lt => ChunkCompare::lt(left, right).map(|ca| ca.into_series()), - Operator::LtEq => ChunkCompare::lt_eq(left, right).map(|ca| ca.into_series()), - Operator::Eq => ChunkCompare::equal(left, right).map(|ca| ca.into_series()), - Operator::NotEq => ChunkCompare::not_equal(left, right).map(|ca| ca.into_series()), + Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_series()), + Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_series()), + Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_series()), + Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_series()), + Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| 
ca.into_series()), + Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_series()), Operator::Plus => left + right, Operator::Minus => left - right, Operator::Multiply => left * right, @@ -283,7 +283,7 @@ mod stats { use super::*; fn apply_operator_stats_eq(min_max: &Series, literal: &Series) -> bool { - use ChunkCompare as C; + use ChunkCompareIneq as C; // Literal is greater than max, don't need to read. if C::gt(literal, min_max).map(|s| s.all()).unwrap_or(false) { return false; @@ -301,7 +301,7 @@ mod stats { if min_max.len() < 2 || min_max.null_count() > 0 { return true; } - use ChunkCompare as C; + use ChunkCompareEq as C; // First check proofs all values are the same (e.g. min/max is the same) // Second check proofs all values are equal, so we can skip as we search @@ -315,7 +315,7 @@ mod stats { } fn apply_operator_stats_rhs_lit(min_max: &Series, literal: &Series, op: Operator) -> bool { - use ChunkCompare as C; + use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), Operator::NotEq => apply_operator_stats_neq(min_max, literal), @@ -351,7 +351,7 @@ mod stats { } fn apply_operator_stats_lhs_lit(literal: &Series, min_max: &Series, op: Operator) -> bool { - use ChunkCompare as C; + use ChunkCompareIneq as C; match op { Operator::Eq => apply_operator_stats_eq(min_max, literal), Operator::NotEq => apply_operator_stats_eq(min_max, literal), diff --git a/crates/polars-ops/src/chunked_array/array/count.rs b/crates/polars-ops/src/chunked_array/array/count.rs index ef54e7b70591..466f148463bf 100644 --- a/crates/polars-ops/src/chunked_array/array/count.rs +++ b/crates/polars-ops/src/chunked_array/array/count.rs @@ -11,7 +11,7 @@ pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult::equal_missing(&s, &value).map(|ca| ca.into_series()) + ChunkCompareEq::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) })?; let out = count_boolean_bits(&ca); Ok(out.into_series()) 
diff --git a/crates/polars-ops/src/chunked_array/list/count.rs b/crates/polars-ops/src/chunked_array/list/count.rs index e54c603f3a25..89fdd71ed5d2 100644 --- a/crates/polars-ops/src/chunked_array/list/count.rs +++ b/crates/polars-ops/src/chunked_array/list/count.rs @@ -45,7 +45,7 @@ pub fn list_count_matches(ca: &ListChunked, value: AnyValue) -> PolarsResult::equal_missing(&s, &value).map(|ca| ca.into_series()) + ChunkCompareEq::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) })?; let out = count_boolean_bits(&ca); Ok(out.into_series()) diff --git a/crates/polars-ops/src/chunked_array/peaks.rs b/crates/polars-ops/src/chunked_array/peaks.rs index 437756a44327..7631a07ac141 100644 --- a/crates/polars-ops/src/chunked_array/peaks.rs +++ b/crates/polars-ops/src/chunked_array/peaks.rs @@ -4,7 +4,7 @@ use polars_core::prelude::*; /// Get a boolean mask of the local maximum peaks. pub fn peak_max(ca: &ChunkedArray) -> BooleanChunked where - ChunkedArray: for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, + ChunkedArray: for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, { let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); @@ -14,7 +14,7 @@ where /// Get a boolean mask of the local minimum peaks. 
pub fn peak_min(ca: &ChunkedArray) -> BooleanChunked where - ChunkedArray: for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, + ChunkedArray: for<'a> ChunkCompareIneq<&'a ChunkedArray, Item = BooleanChunked>, { let shift_left = ca.shift_and_fill(1, Some(Zero::zero())); let shift_right = ca.shift_and_fill(-1, Some(Zero::zero())); From 13e97174287d9102496834ac82ca9720be80f0d4 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Fri, 27 Sep 2024 21:34:20 +1000 Subject: [PATCH 19/33] feat: Add `allow_missing_columns` option to `read/scan_parquet` (#18922) --- .../polars-io/src/parquet/read/async_impl.rs | 15 ++-- .../polars-io/src/parquet/read/predicates.rs | 2 +- .../polars-io/src/parquet/read/read_impl.rs | 86 ++++++++++++++----- crates/polars-io/src/parquet/read/reader.rs | 14 ++- crates/polars-lazy/src/scan/ndjson.rs | 1 + crates/polars-lazy/src/scan/parquet.rs | 3 + .../src/executors/scan/parquet.rs | 9 +- .../arrow/read/deserialize/utils/filter.rs | 2 +- .../src/parquet/metadata/row_metadata.rs | 8 +- .../src/parquet/read/column/mod.rs | 1 + .../src/executors/sources/parquet.rs | 2 + crates/polars-plan/src/plans/builder_dsl.rs | 5 ++ crates/polars-plan/src/plans/options.rs | 1 + crates/polars-python/src/lazyframe/general.rs | 4 +- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../nodes/parquet_source/metadata_fetch.rs | 14 ++- .../parquet_source/row_group_data_fetch.rs | 4 + .../nodes/parquet_source/row_group_decode.rs | 77 +++++++++++++---- crates/polars/tests/it/io/parquet/read/mod.rs | 1 + .../tests/it/io/parquet/read/row_group.rs | 1 + .../polars/tests/it/io/parquet/write/mod.rs | 1 + py-polars/polars/io/parquet/functions.py | 18 ++++ py-polars/tests/unit/io/test_parquet.py | 44 ++++++++++ 23 files changed, 256 insertions(+), 59 deletions(-) diff --git a/crates/polars-io/src/parquet/read/async_impl.rs b/crates/polars-io/src/parquet/read/async_impl.rs index 0c1ead03b85b..da50364855da 100644 --- a/crates/polars-io/src/parquet/read/async_impl.rs 
+++ b/crates/polars-io/src/parquet/read/async_impl.rs @@ -178,12 +178,15 @@ async fn download_projection( let mut offsets = Vec::with_capacity(fields.len()); fields.iter().for_each(|name| { // A single column can have multiple matches (structs). - let iter = row_group.columns_under_root_iter(name).map(|meta| { - let byte_range = meta.byte_range(); - let offset = byte_range.start; - let byte_range = byte_range.start as usize..byte_range.end as usize; - (offset, byte_range) - }); + let iter = row_group + .columns_under_root_iter(name) + .unwrap() + .map(|meta| { + let byte_range = meta.byte_range(); + let offset = byte_range.start; + let byte_range = byte_range.start as usize..byte_range.end as usize; + (offset, byte_range) + }); for (offset, range) in iter { offsets.push(offset); diff --git a/crates/polars-io/src/parquet/read/predicates.rs b/crates/polars-io/src/parquet/read/predicates.rs index 87615de1b8c2..eb8f7747f078 100644 --- a/crates/polars-io/src/parquet/read/predicates.rs +++ b/crates/polars-io/src/parquet/read/predicates.rs @@ -24,7 +24,7 @@ pub(crate) fn collect_statistics( let stats = schema .iter_values() .map(|field| { - let iter = md.columns_under_root_iter(&field.name); + let iter = md.columns_under_root_iter(&field.name).unwrap(); Ok(if iter.len() == 0 { ColumnStats::new(field.into(), None, None, None) diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 830208776d30..45aa2260de30 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -326,12 +326,19 @@ fn rg_to_dfs_prefiltered( .map(|i| { let col_idx = live_idx_to_col_idx[i]; - let name = schema.get_at_index(col_idx).unwrap().0; - let field_md = file_metadata.row_groups[rg_idx] - .columns_under_root_iter(name) - .collect::>(); + let (name, field) = schema.get_at_index(col_idx).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( 
+ name.clone(), + md.num_rows(), + &DataType::from_arrow(&field.dtype, true), + )); + }; + + let part = iter.collect::>(); - column_idx_to_series(col_idx, field_md.as_slice(), None, schema, store) + column_idx_to_series(col_idx, part.as_slice(), None, schema, store) .map(Column::from) }) .collect::>>()?; @@ -384,20 +391,30 @@ fn rg_to_dfs_prefiltered( .then(|| calc_prefilter_cost(&filter_mask)) .unwrap_or_default(); + #[cfg(debug_assertions)] + { + let md = &file_metadata.row_groups[rg_idx]; + debug_assert_eq!(md.num_rows(), mask.len()); + } + + let n_rows_in_result = filter_mask.set_bits(); + let mut dead_columns = (0..num_dead_columns) .into_par_iter() .map(|i| { let col_idx = dead_idx_to_col_idx[i]; - let name = schema.get_at_index(col_idx).unwrap().0; - #[cfg(debug_assertions)] - { - let md = &file_metadata.row_groups[rg_idx]; - debug_assert_eq!(md.num_rows(), mask.len()); - } - let field_md = file_metadata.row_groups[rg_idx] - .columns_under_root_iter(name) - .collect::>(); + let (name, field) = schema.get_at_index(col_idx).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( + name.clone(), + n_rows_in_result, + &DataType::from_arrow(&field.dtype, true), + )); + }; + + let field_md = iter.collect::>(); let pre = || { column_idx_to_series( @@ -556,8 +573,17 @@ fn rg_to_dfs_optionally_par_over_columns( projection .par_iter() .map(|column_i| { - let name = schema.get_at_index(*column_i).unwrap().0; - let part = md.columns_under_root_iter(name).collect::>(); + let (name, field) = schema.get_at_index(*column_i).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( + name.clone(), + rg_slice.1, + &DataType::from_arrow(&field.dtype, true), + )); + }; + + let part = iter.collect::>(); column_idx_to_series( *column_i, @@ -574,8 +600,17 @@ fn rg_to_dfs_optionally_par_over_columns( projection .iter() .map(|column_i| { - let name = schema.get_at_index(*column_i).unwrap().0; - 
let part = md.columns_under_root_iter(name).collect::>(); + let (name, field) = schema.get_at_index(*column_i).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( + name.clone(), + rg_slice.1, + &DataType::from_arrow(&field.dtype, true), + )); + }; + + let part = iter.collect::>(); column_idx_to_series( *column_i, @@ -672,12 +707,21 @@ fn rg_to_dfs_par_over_rg( let columns = projection .iter() .map(|column_i| { - let name = schema.get_at_index(*column_i).unwrap().0; - let field_md = md.columns_under_root_iter(name).collect::>(); + let (name, field) = schema.get_at_index(*column_i).unwrap(); + + let Some(iter) = md.columns_under_root_iter(name) else { + return Ok(Column::full_null( + name.clone(), + md.num_rows(), + &DataType::from_arrow(&field.dtype, true), + )); + }; + + let part = iter.collect::>(); column_idx_to_series( *column_i, - field_md.as_slice(), + part.as_slice(), Some(Filter::new_ranged(slice.0, slice.0 + slice.1)), schema, store, diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 3f4328943973..3cef93a89d73 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -85,9 +85,14 @@ impl ParquetReader { /// dtype, and sets the projection indices. 
pub fn with_arrow_schema_projection( mut self, - first_schema: &ArrowSchema, + first_schema: &Arc, projected_arrow_schema: Option<&ArrowSchema>, + allow_missing_columns: bool, ) -> PolarsResult { + if allow_missing_columns { + self.schema.replace(first_schema.clone()); + } + let schema = self.schema()?; if let Some(projected_arrow_schema) = projected_arrow_schema { @@ -301,9 +306,14 @@ impl ParquetAsyncReader { pub async fn with_arrow_schema_projection( mut self, - first_schema: &ArrowSchema, + first_schema: &Arc, projected_arrow_schema: Option<&ArrowSchema>, + allow_missing_columns: bool, ) -> PolarsResult { + if allow_missing_columns { + self.schema.replace(first_schema.clone()); + } + let schema = self.schema().await?; if let Some(projected_arrow_schema) = projected_arrow_schema { diff --git a/crates/polars-lazy/src/scan/ndjson.rs b/crates/polars-lazy/src/scan/ndjson.rs index a44ce9053ef5..635d23c2ee2d 100644 --- a/crates/polars-lazy/src/scan/ndjson.rs +++ b/crates/polars-lazy/src/scan/ndjson.rs @@ -137,6 +137,7 @@ impl LazyFileListReader for LazyJsonLineReader { }, glob: true, include_file_paths: self.include_file_paths, + allow_missing_columns: false, }; let options = NDJsonReadOptions { diff --git a/crates/polars-lazy/src/scan/parquet.rs b/crates/polars-lazy/src/scan/parquet.rs index eb26eafb6144..382addea7ed1 100644 --- a/crates/polars-lazy/src/scan/parquet.rs +++ b/crates/polars-lazy/src/scan/parquet.rs @@ -21,6 +21,7 @@ pub struct ScanArgsParquet { /// Expand path given via globbing rules. pub glob: bool, pub include_file_paths: Option, + pub allow_missing_columns: bool, } impl Default for ScanArgsParquet { @@ -37,6 +38,7 @@ impl Default for ScanArgsParquet { cache: true, glob: true, include_file_paths: None, + allow_missing_columns: false, } } } @@ -74,6 +76,7 @@ impl LazyFileListReader for LazyParquetReader { self.args.hive_options, self.args.glob, self.args.include_file_paths, + self.args.allow_missing_columns, )? 
.build() .into(); diff --git a/crates/polars-mem-engine/src/executors/scan/parquet.rs b/crates/polars-mem-engine/src/executors/scan/parquet.rs index e15f8ee8be00..49b01c471610 100644 --- a/crates/polars-mem-engine/src/executors/scan/parquet.rs +++ b/crates/polars-mem-engine/src/executors/scan/parquet.rs @@ -202,6 +202,8 @@ impl ParquetExec { }) .collect::>(); + let allow_missing_columns = self.file_options.allow_missing_columns; + let out = POOL.install(|| { readers_and_metadata .into_par_iter() @@ -217,8 +219,9 @@ impl ParquetExec { .with_row_index(row_index) .with_predicate(predicate.clone()) .with_arrow_schema_projection( - first_schema.as_ref(), + &first_schema, projected_arrow_schema.as_deref(), + allow_missing_columns, )? .finish()?; @@ -395,6 +398,7 @@ impl ParquetExec { let first_schema = first_schema.clone(); let projected_arrow_schema = projected_arrow_schema.clone(); let predicate = predicate.clone(); + let allow_missing_columns = self.file_options.allow_missing_columns; if verbose { eprintln!("reading of {}/{} file...", processed, paths.len()); @@ -422,8 +426,9 @@ impl ParquetExec { .with_slice(Some(slice)) .with_row_index(row_index) .with_arrow_schema_projection( - first_schema.as_ref(), + &first_schema, projected_arrow_schema.as_deref(), + allow_missing_columns, ) .await? 
.use_statistics(use_statistics) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs index 03e641634467..a9f0f7b3ef87 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs @@ -29,7 +29,7 @@ impl Filter { } } - pub(crate) fn num_rows(&self) -> usize { + pub fn num_rows(&self) -> usize { match self { Filter::Range(range) => range.len(), Filter::Mask(bitmap) => bitmap.set_bits(), diff --git a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs index 013308ad7f12..9cca27553415 100644 --- a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs @@ -49,16 +49,14 @@ impl RowGroupMetadata { self.columns.len() } - /// Fetch all columns under this root name. + /// Fetch all columns under this root name if it exists. pub fn columns_under_root_iter( &self, root_name: &str, - ) -> impl ExactSizeIterator + DoubleEndedIterator { + ) -> Option + DoubleEndedIterator> { self.column_lookup .get(root_name) - .unwrap() - .iter() - .map(|&x| &self.columns[x]) + .map(|x| x.iter().map(|&x| &self.columns[x])) } /// Number of rows in this row group. 
diff --git a/crates/polars-parquet/src/parquet/read/column/mod.rs b/crates/polars-parquet/src/parquet/read/column/mod.rs index 56f914ba568e..8e5a1f533375 100644 --- a/crates/polars-parquet/src/parquet/read/column/mod.rs +++ b/crates/polars-parquet/src/parquet/read/column/mod.rs @@ -23,6 +23,7 @@ pub fn get_column_iterator<'a>( ) -> ColumnIterator<'a> { let columns = row_group .columns_under_root_iter(field_name) + .unwrap() .rev() .collect::>(); ColumnIterator::new(reader, columns, max_page_size) diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs index faed9d4b667e..efe3edac1b87 100644 --- a/crates/polars-pipe/src/executors/sources/parquet.rs +++ b/crates/polars-pipe/src/executors/sources/parquet.rs @@ -134,6 +134,7 @@ impl ParquetSource { .with_arrow_schema_projection( &self.first_schema, self.projected_arrow_schema.as_deref(), + self.file_options.allow_missing_columns, )? .with_row_index(file_options.row_index) .with_predicate(predicate.clone()) @@ -199,6 +200,7 @@ impl ParquetSource { .with_arrow_schema_projection( &self.first_schema, self.projected_arrow_schema.as_deref(), + self.file_options.allow_missing_columns, ) .await? 
.with_predicate(predicate.clone()) diff --git a/crates/polars-plan/src/plans/builder_dsl.rs b/crates/polars-plan/src/plans/builder_dsl.rs index 5458c0442abe..bc695b47a035 100644 --- a/crates/polars-plan/src/plans/builder_dsl.rs +++ b/crates/polars-plan/src/plans/builder_dsl.rs @@ -54,6 +54,7 @@ impl DslBuilder { }, glob: false, include_file_paths: None, + allow_missing_columns: false, }; Ok(DslPlan::Scan { @@ -87,6 +88,7 @@ impl DslBuilder { hive_options: HiveOptions, glob: bool, include_file_paths: Option, + allow_missing_columns: bool, ) -> PolarsResult { let options = FileScanOptions { with_columns: None, @@ -98,6 +100,7 @@ impl DslBuilder { hive_options, glob, include_file_paths, + allow_missing_columns, }; Ok(DslPlan::Scan { sources, @@ -143,6 +146,7 @@ impl DslBuilder { hive_options, glob: true, include_file_paths, + allow_missing_columns: false, }, scan_type: FileScan::Ipc { options, @@ -181,6 +185,7 @@ impl DslBuilder { }, glob, include_file_paths, + allow_missing_columns: false, }; Ok(DslPlan::Scan { sources, diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index 078acbae7177..f0df191d395f 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -39,6 +39,7 @@ pub struct FileScanOptions { pub hive_options: HiveOptions, pub glob: bool, pub include_file_paths: Option, + pub allow_missing_columns: bool, } #[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)] diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 3d0327c8f2cd..18b9323388e7 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -240,7 +240,7 @@ impl PyLazyFrame { #[cfg(feature = "parquet")] #[staticmethod] #[pyo3(signature = (source, sources, n_rows, cache, parallel, rechunk, row_index, - low_memory, cloud_options, use_statistics, hive_partitioning, hive_schema, try_parse_hive_dates, 
retries, glob, include_file_paths) + low_memory, cloud_options, use_statistics, hive_partitioning, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns) )] fn new_from_parquet( source: Option, @@ -259,6 +259,7 @@ impl PyLazyFrame { retries: usize, glob: bool, include_file_paths: Option, + allow_missing_columns: bool, ) -> PyResult { let parallel = parallel.0; let hive_schema = hive_schema.map(|s| Arc::new(s.0)); @@ -287,6 +288,7 @@ impl PyLazyFrame { hive_options, glob, include_file_paths: include_file_paths.map(|x| x.into()), + allow_missing_columns, }; let sources = sources.0; diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 8507d590d84c..726b5e7debd4 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). 
- const VERSION: Version = (2, 0); + const VERSION: Version = (2, 1); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs index f65a4436a75f..0bee88861769 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs @@ -6,7 +6,6 @@ use polars_io::prelude::FileMetadata; use polars_io::utils::byte_source::{DynByteSource, MemSliceByteSource}; use polars_io::utils::slice::SplitSlicePosition; use polars_utils::mmap::MemSlice; -use polars_utils::pl_str::PlSmallStr; use super::metadata_utils::{ensure_schema_has_projected_fields, read_parquet_metadata_bytes}; use super::ParquetSourceNode; @@ -116,6 +115,7 @@ impl ParquetSourceNode { .unwrap_left() .len(); let has_projection = self.file_options.with_columns.is_some(); + let allow_missing_columns = self.file_options.allow_missing_columns; let process_metadata_bytes = { move |handle: task_handles_ext::AbortOnDropHandle< @@ -145,7 +145,12 @@ impl ParquetSourceNode { ) } - ensure_schema_has_projected_fields(&schema, projected_arrow_schema.as_ref())?; + if !allow_missing_columns { + ensure_schema_has_projected_fields( + &schema, + projected_arrow_schema.as_ref(), + )?; + } PolarsResult::Ok((path_index, byte_source, metadata)) }); @@ -213,11 +218,12 @@ impl ParquetSourceNode { let (path_index, byte_source, metadata) = v.map_err(|err| { err.wrap_msg(|msg| { format!( - "error at path (index: {}, path: {:?}): {}", + "error at path (index: {}, path: {}): {}", current_path_index, scan_sources .get(current_path_index) - .map(|x| PlSmallStr::from_str(x.to_include_path_name())), + .unwrap() + .to_include_path_name(), msg ) }) diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs index 
e3d2ba329d68..dfa4b11e3b02 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs @@ -353,6 +353,10 @@ fn get_row_group_byte_ranges_for_projection<'a>( columns.iter().flat_map(|col_name| { row_group_metadata .columns_under_root_iter(col_name) + // `Option::into_iter` so that we return an empty iterator for the + // `allow_missing_columns` case + .into_iter() + .flatten() .map(|col| { let byte_range = col.byte_range(); byte_range.start as usize..byte_range.end as usize diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs index dc8283b7f735..eda18101d1a0 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs @@ -215,6 +215,11 @@ impl RowGroupDecoder { filter: Option, ) -> PolarsResult<()> { let projected_arrow_schema = &self.projected_arrow_schema; + let expected_num_rows = filter + .as_ref() + .map_or(row_group_data.row_group_metadata.num_rows(), |x| { + x.num_rows() + }); let Some((cols_per_thread, remainder)) = calc_cols_per_thread( row_group_data.row_group_metadata.num_rows(), @@ -222,10 +227,14 @@ impl RowGroupDecoder { self.min_values_per_thread, ) else { // Single-threaded - for s in projected_arrow_schema - .iter_values() - .map(|arrow_field| decode_column(arrow_field, row_group_data, filter.clone())) - { + for s in projected_arrow_schema.iter_values().map(|arrow_field| { + decode_column( + arrow_field, + row_group_data, + filter.clone(), + expected_num_rows, + ) + }) { out_vec.push(s?) 
} @@ -253,7 +262,12 @@ impl RowGroupDecoder { let (_, arrow_field) = projected_arrow_schema.get_at_index(i).unwrap(); - decode_column(arrow_field, &row_group_data, filter.clone()) + decode_column( + arrow_field, + &row_group_data, + filter.clone(), + expected_num_rows, + ) }) .collect::>>() } @@ -270,7 +284,14 @@ impl RowGroupDecoder { for out in projected_arrow_schema .iter_values() .take(remainder) - .map(|arrow_field| decode_column(arrow_field, row_group_data, filter.clone())) + .map(|arrow_field| { + decode_column( + arrow_field, + row_group_data, + filter.clone(), + expected_num_rows, + ) + }) { out_vec.push(out?); } @@ -307,10 +328,20 @@ fn decode_column( arrow_field: &ArrowField, row_group_data: &RowGroupData, filter: Option, + expected_num_rows: usize, ) -> PolarsResult { - let columns_to_deserialize = row_group_data + let Some(iter) = row_group_data .row_group_metadata .columns_under_root_iter(&arrow_field.name) + else { + return Ok(Column::full_null( + arrow_field.name.clone(), + expected_num_rows, + &DataType::from_arrow(&arrow_field.dtype, true), + )); + }; + + let columns_to_deserialize = iter .map(|col_md| { let byte_range = col_md.byte_range(); @@ -329,6 +360,8 @@ fn decode_column( filter, )?; + assert_eq!(array.len(), expected_num_rows); + let series = Series::try_from((arrow_field, array))?; // TODO: Also load in the metadata. 
@@ -440,14 +473,12 @@ impl RowGroupDecoder { let row_group_data = Arc::new(row_group_data); - let mut live_columns = { - let capacity = self.row_index.is_some() as usize + let mut live_columns = Vec::with_capacity( + self.row_index.is_some() as usize + self.predicate_arrow_field_indices.len() + self.hive_partitions_width - + self.include_file_paths.is_some() as usize; - - Vec::with_capacity(capacity) - }; + + self.include_file_paths.is_some() as usize, + ); if let Some(s) = self.materialize_row_index( row_group_data.as_ref(), @@ -479,7 +510,9 @@ impl RowGroupDecoder { .predicate_arrow_field_indices .iter() .map(|&i| self.projected_arrow_schema.get_at_index(i).unwrap()) - .map(|(_, arrow_field)| decode_column(arrow_field, &row_group_data, None)) + .map(|(_, arrow_field)| { + decode_column(arrow_field, &row_group_data, None, projection_height) + }) { live_columns.push(s?); } @@ -514,6 +547,7 @@ impl RowGroupDecoder { assert_eq!(mask_bitmap.len(), projection_height); let prefilter_cost = calc_prefilter_cost(&mask_bitmap); + let expected_num_rows = mask_bitmap.set_bits(); let dead_cols_filtered = self .non_predicate_arrow_field_indices @@ -527,6 +561,7 @@ impl RowGroupDecoder { prefilter_setting, mask, &mask_bitmap, + expected_num_rows, ) }) .collect::>>()?; @@ -569,10 +604,20 @@ fn decode_column_prefiltered( prefilter_setting: &PrefilterMaskSetting, mask: &BooleanChunked, mask_bitmap: &Bitmap, + expected_num_rows: usize, ) -> PolarsResult { - let columns_to_deserialize = row_group_data + let Some(iter) = row_group_data .row_group_metadata .columns_under_root_iter(&arrow_field.name) + else { + return Ok(Column::full_null( + arrow_field.name.clone(), + expected_num_rows, + &DataType::from_arrow(&arrow_field.dtype, true), + )); + }; + + let columns_to_deserialize = iter .map(|col_md| { let byte_range = col_md.byte_range(); @@ -596,6 +641,8 @@ fn decode_column_prefiltered( deserialize_filter, )?; + assert_eq!(array.len(), expected_num_rows); + let column = 
Series::try_from((arrow_field, array))?.into_column(); if !prefilter { diff --git a/crates/polars/tests/it/io/parquet/read/mod.rs b/crates/polars/tests/it/io/parquet/read/mod.rs index 60ed6108edb7..d671e085c86a 100644 --- a/crates/polars/tests/it/io/parquet/read/mod.rs +++ b/crates/polars/tests/it/io/parquet/read/mod.rs @@ -205,6 +205,7 @@ pub fn read_column( let mut statistics = metadata.row_groups[row_group] .columns_under_root_iter(field.name()) + .unwrap() .map(|column_meta| column_meta.statistics().transpose()) .collect::>>()?; diff --git a/crates/polars/tests/it/io/parquet/read/row_group.rs b/crates/polars/tests/it/io/parquet/read/row_group.rs index 6d567a120c92..80478a0da958 100644 --- a/crates/polars/tests/it/io/parquet/read/row_group.rs +++ b/crates/polars/tests/it/io/parquet/read/row_group.rs @@ -75,6 +75,7 @@ pub fn read_columns<'a, R: Read + Seek>( ) -> PolarsResult)>> { row_group_metadata .columns_under_root_iter(field_name) + .unwrap() .map(|meta| _read_single_column(reader, meta)) .collect() } diff --git a/crates/polars/tests/it/io/parquet/write/mod.rs b/crates/polars/tests/it/io/parquet/write/mod.rs index 4403277a0552..e98a0223937f 100644 --- a/crates/polars/tests/it/io/parquet/write/mod.rs +++ b/crates/polars/tests/it/io/parquet/write/mod.rs @@ -215,6 +215,7 @@ fn basic() -> ParquetResult<()> { assert_eq!( metadata.row_groups[0] .columns_under_root_iter("col") + .unwrap() .next() .unwrap() .uncompressed_size(), diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index bc434b05cc2d..687c827a9c57 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -60,6 +60,7 @@ def read_parquet( use_pyarrow: bool = False, pyarrow_options: dict[str, Any] | None = None, memory_map: bool = True, + allow_missing_columns: bool = False, ) -> DataFrame: """ Read into a DataFrame from a parquet file. @@ -139,6 +140,12 @@ def read_parquet( memory_map Memory map underlying file. 
This will likely increase performance. Only used when `use_pyarrow=True`. + allow_missing_columns + When reading a list of parquet files, if a column existing in the first + file cannot be found in subsequent files, the default behavior is to + raise an error. However, if `allow_missing_columns` is set to + `True`, a full-NULL column is returned instead of erroring for the files + that do not contain the column. Returns ------- @@ -198,6 +205,7 @@ def read_parquet( retries=retries, glob=glob, include_file_paths=None, + allow_missing_columns=allow_missing_columns, ) if columns is not None: @@ -307,6 +315,7 @@ def scan_parquet( storage_options: dict[str, Any] | None = None, retries: int = 2, include_file_paths: str | None = None, + allow_missing_columns: bool = False, ) -> LazyFrame: """ Lazily read from a local or cloud-hosted parquet file (or files). @@ -388,6 +397,12 @@ def scan_parquet( Number of retries if accessing a cloud instance fails. include_file_paths Include the path of the source file(s) as a column with this name. + allow_missing_columns + When reading a list of parquet files, if a column existing in the first + file cannot be found in subsequent files, the default behavior is to + raise an error. However, if `allow_missing_columns` is set to + `True`, a full-NULL column is returned instead of erroring for the files + that do not contain the column. 
See Also -------- @@ -439,6 +454,7 @@ def scan_parquet( retries=retries, glob=glob, include_file_paths=include_file_paths, + allow_missing_columns=allow_missing_columns, ) @@ -460,6 +476,7 @@ def _scan_parquet_impl( try_parse_hive_dates: bool = True, retries: int = 2, include_file_paths: str | None = None, + allow_missing_columns: bool = False, ) -> LazyFrame: if isinstance(source, list): sources = source @@ -490,5 +507,6 @@ def _scan_parquet_impl( retries=retries, glob=glob, include_file_paths=include_file_paths, + allow_missing_columns=allow_missing_columns, ) return wrap_ldf(pylf) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 39e515bf31d8..3f7e8b977505 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -1911,3 +1911,47 @@ def test_prefilter_with_projection() -> None: .select(pl.col.a) .collect() ) + + +@pytest.mark.parametrize("parallel", ["columns", "row_groups", "prefiltered", "none"]) +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize("projection", [pl.all(), pl.col("b")]) +@pytest.mark.write_disk +def test_allow_missing_columns( + tmp_path: Path, + parallel: str, + streaming: bool, + projection: pl.Expr, +) -> None: + tmp_path.mkdir(exist_ok=True) + dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2})] + paths = [tmp_path / "1", tmp_path / "2"] + + for df, path in zip(dfs, paths): + df.write_parquet(path) + + expected = pl.DataFrame({"a": [1, 2], "b": [1, None]}).select(projection) + + with pytest.raises(pl.exceptions.SchemaError, match="did not find column"): + pl.read_parquet(paths, parallel=parallel) # type: ignore[arg-type] + + with pytest.raises(pl.exceptions.SchemaError, match="did not find column"): + pl.scan_parquet(paths, parallel=parallel).select(projection).collect( # type: ignore[arg-type] + streaming=streaming + ) + + assert_frame_equal( + pl.read_parquet( + paths, + parallel=parallel, # type: 
ignore[arg-type] + allow_missing_columns=True, + ).select(projection), + expected, + ) + + assert_frame_equal( + pl.scan_parquet(paths, parallel=parallel, allow_missing_columns=True) # type: ignore[arg-type] + .select(projection) + .collect(streaming=streaming), + expected, + ) From 6bf93c47dbe7995d0208f704bd9c90c76d428ce1 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 13:53:38 +0200 Subject: [PATCH 20/33] refactor: Add FixedSizeList equality broadcasting (#18967) --- .../polars-compute/src/comparisons/array.rs | 172 +++++++++++++++++- crates/polars-compute/src/comparisons/list.rs | 2 +- .../src/chunked_array/comparison/mod.rs | 70 ++++--- 3 files changed, 213 insertions(+), 31 deletions(-) diff --git a/crates/polars-compute/src/comparisons/array.rs b/crates/polars-compute/src/comparisons/array.rs index 23d43887a280..facde12a5c37 100644 --- a/crates/polars-compute/src/comparisons/array.rs +++ b/crates/polars-compute/src/comparisons/array.rs @@ -1,7 +1,13 @@ -use arrow::array::{Array, FixedSizeListArray}; +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + FixedSizeListArray, ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, + Utf8ViewArray, +}; use arrow::bitmap::utils::count_zeros; use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::{days_ms, f16, i256, months_days_ns}; use super::TotalEqKernel; use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; @@ -52,6 +58,8 @@ impl TotalEqKernel for FixedSizeListArray { return Bitmap::new_with_value(true, self.len()); } + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. 
let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref()); agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0) @@ -77,16 +85,170 @@ impl TotalEqKernel for FixedSizeListArray { return Bitmap::new_with_value(false, self.len()); } + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref()); agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size()) } - fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_logical_type() else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.dtype(), other.dtype().to_logical_type()); + + let width = *width; + + if width != other.len() { + return Bitmap::new_with_value(false, self.len()); + } + + if width == 0 { + return Bitmap::new_with_value(true, self.len()); + } + + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. 
+ array_fsl_tot_eq_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width) } - fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + let ArrowDataType::FixedSizeList(self_type, width) = self.dtype().to_logical_type() else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.dtype(), other.dtype().to_logical_type()); + + let width = *width; + + if width != other.len() { + return Bitmap::new_with_value(true, self.len()); + } + + if width == 0 { + return Bitmap::new_with_value(false, self.len()); + } + + // @TODO: It is probably worth it to dispatch to a special kernel for when there are + // several nested arrays because that can be rather slow with this code. + array_fsl_tot_ne_missing_kernel(self.values().as_ref(), other.as_ref(), self.len(), width) } } + +macro_rules! compare { + ($lhs:expr, $rhs:expr, $length:expr, $width:expr, $op:path, $true_op:expr) => {{ + let lhs = $lhs; + let rhs = $rhs; + + macro_rules! call_binary { + ($T:ty) => {{ + let values: &$T = $lhs.as_any().downcast_ref().unwrap(); + let scalar: &$T = $rhs.as_any().downcast_ref().unwrap(); + + (0..$length) + .map(move |i| { + // @TODO: I feel like there is a better way to do this. 
+ let mut values: $T = values.clone(); + <$T>::slice(&mut values, i * $width, $width); + + $true_op($op(&values, scalar)) + }) + .collect_trusted() + }}; + } + + assert_eq!(lhs.dtype(), rhs.dtype()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + 
PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} + +fn array_fsl_tot_eq_missing_kernel( + values: &dyn Array, + scalar: &dyn Array, + length: usize, + width: usize, +) -> Bitmap { + // @NOTE: Zero-Width Array are handled before + debug_assert_eq!(values.len(), length * width); + debug_assert_eq!(scalar.len(), width); + + compare!( + values, + scalar, + length, + width, + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0 + ) +} + +fn array_fsl_tot_ne_missing_kernel( + values: &dyn Array, + scalar: &dyn Array, + length: usize, + width: usize, +) -> Bitmap { + // @NOTE: Zero-Width Array are handled before + debug_assert_eq!(values.len(), length * width); + debug_assert_eq!(scalar.len(), width); + + compare!( + values, + scalar, + length, + width, + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0 + ) +} diff --git a/crates/polars-compute/src/comparisons/list.rs b/crates/polars-compute/src/comparisons/list.rs index a66ad4f4312a..fa35cbaac9b6 100644 --- a/crates/polars-compute/src/comparisons/list.rs +++ b/crates/polars-compute/src/comparisons/list.rs @@ -32,7 +32,7 @@ impl TotalEqKernel for ListArray { let mut lhs_values = self.values().clone(); lhs_values.slice(lstart, lend - lstart); - let mut rhs_values = 
self.values().clone(); + let mut rhs_values = other.values().clone(); rhs_values.slice(rstart, rend - rstart); let result = array_tot_eq_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index ecf8f78fcd9a..4ee9cad0a482 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -6,7 +6,7 @@ mod categorical; use std::ops::{BitAnd, Not}; use arrow::array::BooleanArray; -use arrow::bitmap::MutableBitmap; +use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::compute; use num_traits::{NumCast, ToPrimitive}; use polars_compute::comparisons::{TotalEqKernel, TotalOrdKernel}; @@ -792,54 +792,74 @@ impl ChunkCompareEq<&StructChunked> for StructChunked { } } +#[cfg(feature = "dtype-array")] +fn _array_comparison_helper( + lhs: &ArrayChunked, + rhs: &ArrayChunked, + op: F, + broadcast_op: B, +) -> BooleanChunked +where + F: Fn(&FixedSizeListArray, &FixedSizeListArray) -> Bitmap, + B: Fn(&FixedSizeListArray, &Box) -> Bitmap, +{ + match (lhs.len(), rhs.len()) { + (_, 1) => { + let right = rhs.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + arity::unary_mut_values(lhs, |a| broadcast_op(a, right).into()) + }, + (1, _) => { + let left = lhs.chunks()[0] + .as_any() + .downcast_ref::() + .unwrap() + .values(); + arity::unary_mut_values(rhs, |a| broadcast_op(a, left).into()) + }, + _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), + } +} + #[cfg(feature = "dtype-array")] impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked { - if self.width() != rhs.width() { - return BooleanChunked::full(PlSmallStr::EMPTY, false, self.len()); - } - arity::binary_mut_values( + _array_comparison_helper( self, rhs, - |a, b| a.tot_eq_kernel(b).into(), - 
PlSmallStr::EMPTY, + TotalEqKernel::tot_eq_kernel, + TotalEqKernel::tot_eq_kernel_broadcast, ) } fn equal_missing(&self, rhs: &ArrayChunked) -> BooleanChunked { - if self.width() != rhs.width() { - return BooleanChunked::full(PlSmallStr::EMPTY, false, self.len()); - } - arity::binary_mut_with_options( + _array_comparison_helper( self, rhs, - |a, b| a.tot_eq_missing_kernel(b).into(), - PlSmallStr::EMPTY, + TotalEqKernel::tot_eq_missing_kernel, + TotalEqKernel::tot_eq_missing_kernel_broadcast, ) } fn not_equal(&self, rhs: &ArrayChunked) -> BooleanChunked { - if self.width() != rhs.width() { - return BooleanChunked::full(PlSmallStr::EMPTY, true, self.len()); - } - arity::binary_mut_values( + _array_comparison_helper( self, rhs, - |a, b| a.tot_ne_kernel(b).into(), - PlSmallStr::EMPTY, + TotalEqKernel::tot_ne_kernel, + TotalEqKernel::tot_ne_kernel_broadcast, ) } fn not_equal_missing(&self, rhs: &ArrayChunked) -> Self::Item { - if self.width() != rhs.width() { - return BooleanChunked::full(PlSmallStr::EMPTY, true, self.len()); - } - arity::binary_mut_with_options( + _array_comparison_helper( self, rhs, - |a, b| a.tot_ne_missing_kernel(b).into(), - PlSmallStr::EMPTY, + TotalEqKernel::tot_ne_missing_kernel, + TotalEqKernel::tot_ne_missing_kernel_broadcast, ) } } From 8bb5a02242cb4d54067018a2079603d4c2ea17ca Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 27 Sep 2024 14:44:37 +0200 Subject: [PATCH 21/33] refactor: Disable CSE-specific test on new streaming engine (#18971) --- py-polars/tests/unit/test_cse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index 6874512e6133..d6fdb8976c5b 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -777,6 +777,7 @@ def test_cse_chunks_18124() -> None: ).collect().shape == (4, 4) +@pytest.mark.may_fail_auto_streaming def test_eager_cse_during_struct_expansion_18411() -> None: df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]}) vc = 
pl.col("foo").value_counts() From 538fd6c455f6536c86f4aea15fefd33c0446a9de Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 27 Sep 2024 14:44:54 +0200 Subject: [PATCH 22/33] fix(rust): Window function had incorrect output name on ExprIR (#18970) --- .../src/plans/conversion/expr_to_ir.rs | 4 +++- crates/polars-plan/src/plans/expr_ir.rs | 16 ++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index 1e6457eed810..fe13dd1d3592 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -304,6 +304,8 @@ pub(super) fn to_aexpr_impl( order_by, options, } => { + // Process function first so name is correct. + let function = to_aexpr_impl(owned(function), arena, state)?; let order_by = if let Some((e, options)) = order_by { Some((to_aexpr_impl(owned(e.clone()), arena, state)?, options)) } else { @@ -311,7 +313,7 @@ pub(super) fn to_aexpr_impl( }; AExpr::Window { - function: to_aexpr_impl(owned(function), arena, state)?, + function, partition_by: to_aexprs(partition_by, arena, state)?, order_by, options, diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index 8d66ceb9072f..748b20a8740c 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -28,17 +28,21 @@ pub enum OutputName { } impl OutputName { - pub fn unwrap(&self) -> &PlSmallStr { + pub fn get(&self) -> Option<&PlSmallStr> { match self { - OutputName::Alias(name) => name, - OutputName::ColumnLhs(name) => name, - OutputName::LiteralLhs(name) => name, + OutputName::Alias(name) => Some(name), + OutputName::ColumnLhs(name) => Some(name), + OutputName::LiteralLhs(name) => Some(name), #[cfg(feature = "dtype-struct")] - OutputName::Field(name) => name, - OutputName::None => panic!("no output name set"), + 
OutputName::Field(name) => Some(name), + OutputName::None => None, } } + pub fn unwrap(&self) -> &PlSmallStr { + self.get().expect("no output name set") + } + pub(crate) fn is_none(&self) -> bool { matches!(self, OutputName::None) } From 73a334482dc59e3d3700b01f9bb08ae1760e9118 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 27 Sep 2024 14:58:50 +0200 Subject: [PATCH 23/33] fix: Make join test order-agnostic (#18975) --- py-polars/tests/unit/streaming/test_streaming_join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index cc09dde80cd0..f936d9bff6b9 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -122,7 +122,7 @@ def test_streaming_join_rechunk_12498() -> None: b = pl.select(B=rows).lazy() q = a.join(b, how="cross") - assert q.collect(streaming=True).to_dict(as_series=False) == { + assert q.collect(streaming=True).sort(["B", "A"]).to_dict(as_series=False) == { "A": [0, 1, 0, 1], "B": [0, 0, 1, 1], } From a7432b91cfe40c763db4e732b9583d0c5b60399c Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 27 Sep 2024 14:59:57 +0200 Subject: [PATCH 24/33] docs: Recommend targetDir for rust-analyzer (#18973) --- docs/source/development/contributing/ide.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/development/contributing/ide.md b/docs/source/development/contributing/ide.md index 562216e9ff5e..31811e3f12ae 100644 --- a/docs/source/development/contributing/ide.md +++ b/docs/source/development/contributing/ide.md @@ -19,7 +19,8 @@ For it to work well for the Polars code base, add the following settings to your ```json { - "rust-analyzer.cargo.features": "all" + "rust-analyzer.cargo.features": "all", + "rust-analyzer.cargo.targetDir": true } ``` From 89fd28567f9b668394556e58f34e387243a36618 Mon Sep 17 00:00:00 2001 From: 
nameexhaustion Date: Sat, 28 Sep 2024 00:28:18 +1000 Subject: [PATCH 25/33] refactor(rust): Fix new-streaming `test_lazy_parquet::test_row_index` (#18978) --- .../src/nodes/parquet_source/row_group_decode.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs index eda18101d1a0..057d9e15d1f6 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs @@ -645,11 +645,15 @@ fn decode_column_prefiltered( let column = Series::try_from((arrow_field, array))?.into_column(); - if !prefilter { - column.filter(mask) + let column = if !prefilter { + column.filter(mask)? } else { - Ok(column) - } + column + }; + + assert_eq!(column.len(), expected_num_rows); + + Ok(column) } mod tests { From 3904774715d9a8858618a6813c1a81e8bb775bb9 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 27 Sep 2024 16:28:46 +0200 Subject: [PATCH 26/33] fix: Respect allow_threading in TernaryExpr (#18977) --- crates/polars-expr/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 48c91fab0c9e..c7208b10d63f 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -462,7 +462,7 @@ fn create_physical_expr_inner( truthy, falsy, node_to_expr(expression, expr_arena), - lit_count < 2, + state.allow_threading && lit_count < 2, is_scalar, ))) }, From 653a4cdb96f416226be4b2c824519b517a422134 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 27 Sep 2024 16:55:55 +0200 Subject: [PATCH 27/33] fix: IPC don't write variadic_buffer_counts in blocks, but only dictionaries (#18980) --- crates/polars-arrow/src/io/ipc/write/common.rs | 10 +++------- crates/polars-plan/src/plans/ir/scan_sources.rs | 15 +++++++++++++-- 
py-polars/tests/unit/io/test_ipc.py | 14 ++++++++++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs index 2aebf1ec5d50..a49c7fdcd790 100644 --- a/crates/polars-arrow/src/io/ipc/write/common.rs +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -254,13 +254,9 @@ fn set_variadic_buffer_counts(counts: &mut Vec, array: &dyn Array) { let array = array.as_any().downcast_ref::().unwrap(); set_variadic_buffer_counts(counts, array.values().as_ref()) }, - ArrowDataType::Dictionary(_, _, _) => { - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - set_variadic_buffer_counts(counts, array.values().as_ref()) - }, + // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct + // is read. + ArrowDataType::Dictionary(_, _, _) => (), _ => (), } } diff --git a/crates/polars-plan/src/plans/ir/scan_sources.rs b/crates/polars-plan/src/plans/ir/scan_sources.rs index f6674c70fbce..789a5c4f4811 100644 --- a/crates/polars-plan/src/plans/ir/scan_sources.rs +++ b/crates/polars-plan/src/plans/ir/scan_sources.rs @@ -1,3 +1,4 @@ +use std::fmt::{Debug, Formatter}; use std::fs::File; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -14,10 +15,10 @@ use super::FileScanOptions; /// Set of sources to scan from /// -/// This is can either be a list of paths to files, opened files or in-memory buffers. Mixing of +/// This can either be a list of paths to files, opened files or in-memory buffers. Mixing of /// buffers is not currently possible. 
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Debug, Clone)] +#[derive(Clone)] pub enum ScanSources { Paths(Arc<[PathBuf]>), @@ -27,6 +28,16 @@ pub enum ScanSources { Buffers(Arc<[bytes::Bytes]>), } +impl Debug for ScanSources { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Paths(p) => write!(f, "paths: {:?}", p.as_ref()), + Self::Files(p) => write!(f, "files: {} files", p.len()), + Self::Buffers(b) => write!(f, "buffers: {} in-memory-buffers", b.len()), + } + } +} + /// A reference to a single item in [`ScanSources`] #[derive(Debug, Clone, Copy)] pub enum ScanSourceRef<'a> { diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index dd60d0ae209c..37a583d929df 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -340,3 +340,17 @@ def test_ipc_decimal_15920( path = f"{tmp_path}/data" df.write_ipc(path) assert_frame_equal(pl.read_ipc(path), df) + + +def test_ipc_variadic_buffers_categorical_binview_18636() -> None: + df = pl.DataFrame( + { + "Test": pl.Series(["Value012"], dtype=pl.Categorical), + "Test2": pl.Series(["Value Two 20032"], dtype=pl.String), + } + ) + + b = io.BytesIO() + df.write_ipc(b) + b.seek(0) + assert_frame_equal(pl.read_ipc(b), df) From 2dbb444f3c8713db175f5478c7274be1b6699796 Mon Sep 17 00:00:00 2001 From: barak1412 Date: Fri, 27 Sep 2024 18:08:22 +0300 Subject: [PATCH 28/33] fix: Return correct value for `when().then().else()` on structs when using `first()`\`last()` (#18969) --- .../polars-core/src/chunked_array/ops/zip.rs | 10 +++---- .../tests/unit/functions/test_when_then.py | 29 +++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 630cef7f4808..2e6ca0f7afb8 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ 
b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -375,10 +375,10 @@ impl ChunkZip for StructChunked { .all(|(r, m)| r == m)); let combine = if l.null_count() == 0 { - |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or_not(r, m)) + |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m)) } else { |r: Option<&Bitmap>, m: &Bitmap| { - Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and(r, m))) + Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m))) } }; @@ -411,10 +411,10 @@ impl ChunkZip for StructChunked { .all(|(l, m)| l == m)); let combine = if r.null_count() == 0 { - |r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m)) + |l: Option<&Bitmap>, m: &Bitmap| l.map(|l| arrow::bitmap::or_not(l, m)) } else { - |r: Option<&Bitmap>, m: &Bitmap| { - Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m))) + |l: Option<&Bitmap>, m: &Bitmap| { + Some(l.map_or_else(|| m.clone(), |l| arrow::bitmap::and(l, m))) } }; diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index 9d0b5ecaced0..b84b42dd22f5 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -597,6 +597,35 @@ def test_when_then_parametric( assert ref["if_true"].to_list() == ans["if_true"].to_list() +def test_when_then_else_struct_18961() -> None: + v1 = [None, {"foo": 0, "bar": "1"}] + v2 = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}] + + df = pl.DataFrame({"left": v1, "right": v2, "mask": [False, True]}) + + expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}] + ans = ( + df.select( + pl.when(pl.col.mask).then(pl.col.left).otherwise(pl.col.right.first()) + ) + .get_column("left") + .to_list() + ) + assert expected == ans + + df = pl.DataFrame({"left": v2, "right": v1, "mask": [True, False]}) + + expected = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": "1"}] + ans = ( + df.select( + 
pl.when(pl.col.mask).then(pl.col.left.first()).otherwise(pl.col.right) + ) + .get_column("left") + .to_list() + ) + assert expected == ans + + def test_when_then_supertype_15975() -> None: df = pl.DataFrame({"a": [1, 2, 3]}) From 6abc2f19fb6b01234ac36e0c384a8abb75debcc2 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Sat, 28 Sep 2024 01:13:00 +1000 Subject: [PATCH 29/33] refactor: Mention `allow_missing_columns` in error message when column not found (parquet) (#18972) --- crates/polars-io/src/parquet/read/reader.rs | 92 +++++++++++++------ crates/polars-io/src/parquet/read/utils.rs | 2 +- .../nodes/parquet_source/metadata_utils.rs | 2 +- py-polars/tests/unit/io/test_lazy_parquet.py | 7 +- py-polars/tests/unit/io/test_parquet.py | 10 +- 5 files changed, 77 insertions(+), 36 deletions(-) diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 3cef93a89d73..eb3609c127ac 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -95,22 +95,38 @@ impl ParquetReader { let schema = self.schema()?; - if let Some(projected_arrow_schema) = projected_arrow_schema { - self.projection = projected_arrow_schema_to_projection_indices( - schema.as_ref(), - projected_arrow_schema, - )?; - } else { - if schema.len() > first_schema.len() { - polars_bail!( - SchemaMismatch: - "parquet file contained extra columns and no selection was given" - ) + (|| { + if let Some(projected_arrow_schema) = projected_arrow_schema { + self.projection = projected_arrow_schema_to_projection_indices( + schema.as_ref(), + projected_arrow_schema, + )?; + } else { + if schema.len() > first_schema.len() { + polars_bail!( + SchemaMismatch: + "parquet file contained extra columns and no selection was given" + ) + } + + self.projection = + projected_arrow_schema_to_projection_indices(schema.as_ref(), first_schema)?; + }; + Ok(()) + })() + .map_err(|e| { + if !allow_missing_columns && matches!(e, 
PolarsError::ColumnNotFound(_)) { + e.wrap_msg(|s| { + format!( + "error with column selection, \ + consider enabling `allow_missing_columns`: {}", + s + ) + }) + } else { + e } - - self.projection = - projected_arrow_schema_to_projection_indices(schema.as_ref(), first_schema)?; - } + })?; Ok(self) } @@ -316,22 +332,38 @@ impl ParquetAsyncReader { let schema = self.schema().await?; - if let Some(projected_arrow_schema) = projected_arrow_schema { - self.projection = projected_arrow_schema_to_projection_indices( - schema.as_ref(), - projected_arrow_schema, - )?; - } else { - if schema.len() > first_schema.len() { - polars_bail!( - SchemaMismatch: - "parquet file contained extra columns and no selection was given" - ) + (|| { + if let Some(projected_arrow_schema) = projected_arrow_schema { + self.projection = projected_arrow_schema_to_projection_indices( + schema.as_ref(), + projected_arrow_schema, + )?; + } else { + if schema.len() > first_schema.len() { + polars_bail!( + SchemaMismatch: + "parquet file contained extra columns and no selection was given" + ) + } + + self.projection = + projected_arrow_schema_to_projection_indices(schema.as_ref(), first_schema)?; + }; + Ok(()) + })() + .map_err(|e| { + if !allow_missing_columns && matches!(e, PolarsError::ColumnNotFound(_)) { + e.wrap_msg(|s| { + format!( + "error with column selection, \ + consider enabling `allow_missing_columns`: {}", + s + ) + }) + } else { + e } - - self.projection = - projected_arrow_schema_to_projection_indices(schema.as_ref(), first_schema)?; - } + })?; Ok(self) } diff --git a/crates/polars-io/src/parquet/read/utils.rs b/crates/polars-io/src/parquet/read/utils.rs index 62df261237eb..7ce183088aee 100644 --- a/crates/polars-io/src/parquet/read/utils.rs +++ b/crates/polars-io/src/parquet/read/utils.rs @@ -40,7 +40,7 @@ pub(super) fn projected_arrow_schema_to_projection_indices( for (i, field) in projected_arrow_schema.iter_values().enumerate() { let dtype = { let Some((idx, _, field)) = 
schema.get_full(&field.name) else { - polars_bail!(SchemaMismatch: "did not find column in file: {}", field.name) + polars_bail!(ColumnNotFound: "did not find column in file: {}", field.name) }; projection_indices.push(idx); diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs index 3e4d03a3a270..61db45d54a0a 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs @@ -130,7 +130,7 @@ pub(super) fn ensure_schema_has_projected_fields( let expected_dtype = DataType::from_arrow(&field.dtype, true); let dtype = { let Some(field) = schema.get(&field.name) else { - polars_bail!(SchemaMismatch: "did not find column: {}", field.name) + polars_bail!(ColumnNotFound: "error with column selection, consider enabling `allow_missing_columns`: did not find column in file: {}", field.name) }; DataType::from_arrow(&field.dtype, true) }; diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 111289408dea..9604793c7baf 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -449,7 +449,7 @@ def test_parquet_schema_mismatch_panic_17067(tmp_path: Path, streaming: bool) -> pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).write_parquet(tmp_path / "1.parquet") pl.DataFrame({"c": [1, 2, 3], "d": [4, 5, 6]}).write_parquet(tmp_path / "2.parquet") - with pytest.raises(pl.exceptions.SchemaError): + with pytest.raises(pl.exceptions.ColumnNotFoundError): pl.scan_parquet(tmp_path).collect(streaming=streaming) @@ -642,5 +642,8 @@ def test_parquet_unaligned_schema_read_missing_cols_from_first( lf = pl.scan_parquet(paths) - with pytest.raises(pl.exceptions.SchemaError, match="did not find column"): + with pytest.raises( + pl.exceptions.ColumnNotFoundError, + match="did not find column in file: a", + ): 
lf.collect(streaming=streaming) diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 3f7e8b977505..8431c659cce0 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -1932,10 +1932,16 @@ def test_allow_missing_columns( expected = pl.DataFrame({"a": [1, 2], "b": [1, None]}).select(projection) - with pytest.raises(pl.exceptions.SchemaError, match="did not find column"): + with pytest.raises( + pl.exceptions.ColumnNotFoundError, + match="error with column selection, consider enabling `allow_missing_columns`: did not find column in file: b", + ): pl.read_parquet(paths, parallel=parallel) # type: ignore[arg-type] - with pytest.raises(pl.exceptions.SchemaError, match="did not find column"): + with pytest.raises( + pl.exceptions.ColumnNotFoundError, + match="error with column selection, consider enabling `allow_missing_columns`: did not find column in file: b", + ): pl.scan_parquet(paths, parallel=parallel).select(projection).collect( # type: ignore[arg-type] streaming=streaming ) From fa7ec47be7c0381ab9eacd737d116113b0ecb4c3 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 27 Sep 2024 18:02:28 +0200 Subject: [PATCH 30/33] fix: Ensure same fmt in Series/AnyValue to string cast (#18982) --- .../src/compute/cast/primitive_to.rs | 2 +- crates/polars-core/src/datatypes/any_value.rs | 17 ++++++++++------- py-polars/tests/unit/operations/test_cast.py | 6 ++++++ 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index 03a9c427cf4b..d017b0a8e212 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -13,7 +13,7 @@ use crate::offset::{Offset, Offsets}; use crate::temporal_conversions::*; use crate::types::{days_ms, f16, months_days_ns, NativeType}; -pub(super) trait SerPrimitive { +pub trait 
SerPrimitive { fn write(f: &mut Vec, val: Self) -> usize where Self: Sized; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 83d9d4f8301d..4155a9bf14e9 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; +use arrow::compute::cast::SerPrimitive; use arrow::types::PrimitiveType; -use polars_utils::format_pl_smallstr; #[cfg(feature = "dtype-categorical")] use polars_utils::sync::SyncPtr; use polars_utils::total_ord::ToTotalOrd; @@ -563,19 +563,22 @@ impl<'a> AnyValue<'a> { (AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()), // to string - (AnyValue::String(v), DataType::String) => { - AnyValue::StringOwned(PlSmallStr::from_str(v)) - }, + (AnyValue::String(v), DataType::String) => AnyValue::String(v), (AnyValue::StringOwned(v), DataType::String) => AnyValue::StringOwned(v.clone()), (av, DataType::String) => { + let mut tmp = vec![]; if av.is_unsigned_integer() { - AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); } else if av.is_float() { - AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); } else { - AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + let val = av.extract::()?; + SerPrimitive::write(&mut tmp, val); } + AnyValue::StringOwned(PlSmallStr::from_str(std::str::from_utf8(&tmp).unwrap())) }, // to binary diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index cdf7688f7c32..e01912237a19 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -686,3 +686,9 @@ def test_bool_numeric_supertype(dtype: PolarsDataType) -> None: df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]}) result = 
df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len()) assert result.item() - 0.3333333 <= 0.00001 + + +def test_cast_consistency() -> None: + assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns( + b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String) + ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]} From 27e26071dab154745827b715a66137db93927da1 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 18:02:42 +0200 Subject: [PATCH 31/33] refactor: Preserve scalar in more places (#18898) --- .../src/chunked_array/ops/min_max_binary.rs | 48 +++++++---- .../polars-core/src/frame/column/compare.rs | 86 +++++++++++++++++++ crates/polars-core/src/frame/column/mod.rs | 71 +++------------ crates/polars-core/src/frame/mod.rs | 36 +++----- crates/polars-python/src/dataframe/general.rs | 4 +- 5 files changed, 143 insertions(+), 102 deletions(-) create mode 100644 crates/polars-core/src/frame/column/compare.rs diff --git a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs index bc33f088b1f9..28e7c491095b 100644 --- a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs +++ b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs @@ -31,31 +31,45 @@ where arity::binary_elementwise_values(left, right, op) } -pub(crate) fn min_max_binary_series( - left: &Series, - right: &Series, +pub(crate) fn min_max_binary_columns( + left: &Column, + right: &Column, min: bool, -) -> PolarsResult { +) -> PolarsResult { if left.dtype().to_physical().is_numeric() && left.null_count() == 0 && right.null_count() == 0 && left.len() == right.len() { - let (lhs, rhs) = coerce_lhs_rhs(left, right)?; - let logical = lhs.dtype(); - let lhs = lhs.to_physical_repr(); - let rhs = rhs.to_physical_repr(); + match (left, right) { + (Column::Series(left), Column::Series(right)) => { + let (lhs, rhs) = coerce_lhs_rhs(left, right)?; + let logical = lhs.dtype(); + let lhs = 
lhs.to_physical_repr(); + let rhs = rhs.to_physical_repr(); - with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| { - let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); - let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + with_match_physical_numeric_polars_type!(lhs.dtype(), |$T| { + let a: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let b: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); - if min { - min_binary(a, b).into_series().cast(logical) - } else { - max_binary(a, b).into_series().cast(logical) - } - }) + if min { + min_binary(a, b).into_series().cast(logical) + } else { + max_binary(a, b).into_series().cast(logical) + } + }) + .map(Column::from) + }, + _ => { + let mask = if min { + left.lt(right)? + } else { + left.gt(right)? + }; + + left.zip_with(&mask, right) + }, + } } else { let mask = if min { left.lt(right)? & left.is_not_null() | right.is_null() diff --git a/crates/polars-core/src/frame/column/compare.rs b/crates/polars-core/src/frame/column/compare.rs new file mode 100644 index 000000000000..fdb792d60074 --- /dev/null +++ b/crates/polars-core/src/frame/column/compare.rs @@ -0,0 +1,86 @@ +use polars_error::PolarsResult; + +use super::{BooleanChunked, ChunkCompareEq, ChunkCompareIneq, ChunkExpandAtIndex, Column, Series}; + +macro_rules! 
column_element_wise_broadcasting { + ($lhs:expr, $rhs:expr, $op:expr) => { + match ($lhs, $rhs) { + (Column::Series(lhs), Column::Series(rhs)) => $op(lhs, rhs), + (Column::Series(lhs), Column::Scalar(rhs)) => $op(lhs, &rhs.as_single_value_series()), + (Column::Scalar(lhs), Column::Series(rhs)) => $op(&lhs.as_single_value_series(), rhs), + (Column::Scalar(lhs), Column::Scalar(rhs)) => { + $op(&lhs.as_single_value_series(), &rhs.as_single_value_series()).map(|ca| { + if ca.len() == 0 { + ca + } else { + ca.new_from_index(0, lhs.len()) + } + }) + }, + } + }; +} + +impl ChunkCompareEq<&Column> for Column { + type Item = PolarsResult; + + /// Create a boolean mask by checking for equality. + #[inline] + fn equal(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::equal) + } + + /// Create a boolean mask by checking for equality. + #[inline] + fn equal_missing(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!( + self, + rhs, + >::equal_missing + ) + } + + /// Create a boolean mask by checking for inequality. + #[inline] + fn not_equal(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::not_equal) + } + + /// Create a boolean mask by checking for inequality. + #[inline] + fn not_equal_missing(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!( + self, + rhs, + >::not_equal_missing + ) + } +} + +impl ChunkCompareIneq<&Column> for Column { + type Item = PolarsResult; + + /// Create a boolean mask by checking if self > rhs. + #[inline] + fn gt(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::gt) + } + + /// Create a boolean mask by checking if self >= rhs. + #[inline] + fn gt_eq(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::gt_eq) + } + + /// Create a boolean mask by checking if self < rhs. 
+ #[inline] + fn lt(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::lt) + } + + /// Create a boolean mask by checking if self <= rhs. + #[inline] + fn lt_eq(&self, rhs: &Column) -> PolarsResult { + column_element_wise_broadcasting!(self, rhs, >::lt_eq) + } +} diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs index cb88a9946ca4..3a6343415a6a 100644 --- a/crates/polars-core/src/frame/column/mod.rs +++ b/crates/polars-core/src/frame/column/mod.rs @@ -16,6 +16,7 @@ use crate::utils::{slice_offsets, Container}; use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; mod arithmetic; +mod compare; mod scalar; /// A column within a [`DataFrame`]. @@ -990,69 +991,17 @@ impl Column { // @scalar-opt self.as_materialized_series().estimated_size() } -} - -impl ChunkCompareEq<&Column> for Column { - type Item = PolarsResult; - - /// Create a boolean mask by checking for equality. - #[inline] - fn equal(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .equal(rhs.as_materialized_series()) - } - - /// Create a boolean mask by checking for equality. - #[inline] - fn equal_missing(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .equal_missing(rhs.as_materialized_series()) - } - - /// Create a boolean mask by checking for inequality. - #[inline] - fn not_equal(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .not_equal(rhs.as_materialized_series()) - } - - /// Create a boolean mask by checking for inequality. - #[inline] - fn not_equal_missing(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .not_equal_missing(rhs.as_materialized_series()) - } -} -impl ChunkCompareIneq<&Column> for Column { - type Item = PolarsResult; - - /// Create a boolean mask by checking if self > rhs. 
- #[inline] - fn gt(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .gt(rhs.as_materialized_series()) - } - - /// Create a boolean mask by checking if self >= rhs. - #[inline] - fn gt_eq(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .gt_eq(rhs.as_materialized_series()) - } - - /// Create a boolean mask by checking if self < rhs. - #[inline] - fn lt(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .lt(rhs.as_materialized_series()) - } + pub(crate) fn sort_with(&self, options: SortOptions) -> PolarsResult { + match self { + Column::Series(s) => s.sort_with(options).map(Self::from), + Column::Scalar(s) => { + // This makes this function throw the same errors as Series::sort_with + _ = s.as_single_value_series().sort_with(options)?; - /// Create a boolean mask by checking if self <= rhs. - #[inline] - fn lt_eq(&self, rhs: &Column) -> Self::Item { - self.as_materialized_series() - .lt_eq(rhs.as_materialized_series()) + Ok(self.clone()) + }, + } } } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index a741ad846351..998b82eddc78 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -37,7 +37,7 @@ use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] use crate::hashing::_df_rows_to_hashes_threaded_vertical; #[cfg(feature = "zip_with")] -use crate::prelude::min_max_binary::min_max_binary_series; +use crate::prelude::min_max_binary::min_max_binary_columns; use crate::prelude::sort::{argsort_multiple_row_fmt, prepare_arg_sort}; use crate::series::IsSorted; use crate::POOL; @@ -1870,7 +1870,7 @@ impl DataFrame { let df = df.as_single_chunk_par(); let mut take = match (by_column.len(), has_struct) { (1, false) => { - let s = &by_column[0].as_materialized_series(); + let s = &by_column[0]; let options = SortOptions { descending: sort_options.descending[0], nulls_last: sort_options.nulls_last[0], @@ -2584,24 
+2584,19 @@ impl DataFrame { /// Aggregate the column horizontally to their min values. #[cfg(feature = "zip_with")] - pub fn min_horizontal(&self) -> PolarsResult> { - let min_fn = |acc: &Series, s: &Series| min_max_binary_series(acc, s, true); + pub fn min_horizontal(&self) -> PolarsResult> { + let min_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, true); match self.columns.len() { 0 => Ok(None), - 1 => Ok(Some( - self.columns[0].clone().as_materialized_series().clone(), - )), - 2 => min_fn( - self.columns[0].as_materialized_series(), - self.columns[1].as_materialized_series(), - ) - .map(Some), + 1 => Ok(Some(self.columns[0].clone())), + 2 => min_fn(&self.columns[0], &self.columns[1]).map(Some), _ => { // the try_reduce_with is a bit slower in parallelism, // but I don't think it matters here as we parallelize over columns, not over elements POOL.install(|| { - self.par_materialized_column_iter() + self.columns + .par_iter() .map(|s| Ok(Cow::Borrowed(s))) .try_reduce_with(|l, r| min_fn(&l, &r).map(Cow::Owned)) // we can unwrap the option, because we are certain there is a column @@ -2615,22 +2610,19 @@ impl DataFrame { /// Aggregate the column horizontally to their max values. 
#[cfg(feature = "zip_with")] - pub fn max_horizontal(&self) -> PolarsResult> { - let max_fn = |acc: &Series, s: &Series| min_max_binary_series(acc, s, false); + pub fn max_horizontal(&self) -> PolarsResult> { + let max_fn = |acc: &Column, s: &Column| min_max_binary_columns(acc, s, false); match self.columns.len() { 0 => Ok(None), - 1 => Ok(Some(self.columns[0].as_materialized_series().clone())), - 2 => max_fn( - self.columns[0].as_materialized_series(), - self.columns[1].as_materialized_series(), - ) - .map(Some), + 1 => Ok(Some(self.columns[0].clone())), + 2 => max_fn(&self.columns[0], &self.columns[1]).map(Some), _ => { // the try_reduce_with is a bit slower in parallelism, // but I don't think it matters here as we parallelize over columns, not over elements POOL.install(|| { - self.par_materialized_column_iter() + self.columns + .par_iter() .map(|s| Ok(Cow::Borrowed(s))) .try_reduce_with(|l, r| max_fn(&l, &r).map(Cow::Owned)) // we can unwrap the option, because we are certain there is a column diff --git a/crates/polars-python/src/dataframe/general.rs b/crates/polars-python/src/dataframe/general.rs index 78727ffefd33..5df89ed423c7 100644 --- a/crates/polars-python/src/dataframe/general.rs +++ b/crates/polars-python/src/dataframe/general.rs @@ -460,12 +460,12 @@ impl PyDataFrame { pub fn max_horizontal(&self) -> PyResult> { let s = self.df.max_horizontal().map_err(PyPolarsErr::from)?; - Ok(s.map(|s| s.into())) + Ok(s.map(|s| s.take_materialized_series().into())) } pub fn min_horizontal(&self) -> PyResult> { let s = self.df.min_horizontal().map_err(PyPolarsErr::from)?; - Ok(s.map(|s| s.into())) + Ok(s.map(|s| s.take_materialized_series().into())) } pub fn sum_horizontal(&self, ignore_nulls: bool) -> PyResult> { From 1be1792cbf7e8e372dcb27fc8f2f03414eb62735 Mon Sep 17 00:00:00 2001 From: Marshall Date: Fri, 27 Sep 2024 13:36:06 -0400 Subject: [PATCH 32/33] docs: Fix `is_not_nan` description (#18985) --- py-polars/polars/series/series.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index c66d1f0a4abc..be3734f8e7ca 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3695,7 +3695,7 @@ def is_infinite(self) -> Series: def is_nan(self) -> Series: """ - Returns a boolean Series indicating which values are not NaN. + Returns a boolean Series indicating which values are NaN. Returns ------- From 901b2437cabdc66998642d4d761bc5d36053b720 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 27 Sep 2024 19:36:29 +0200 Subject: [PATCH 33/33] perf: Use List's TotalEqKernel (#18984) --- crates/polars-compute/src/comparisons/list.rs | 285 ++++++++++++++---- .../src/chunked_array/comparison/mod.rs | 183 +++++++---- .../tests/unit/operations/test_explode.py | 11 +- 3 files changed, 364 insertions(+), 115 deletions(-) diff --git a/crates/polars-compute/src/comparisons/list.rs b/crates/polars-compute/src/comparisons/list.rs index fa35cbaac9b6..cd0414b7cca8 100644 --- a/crates/polars-compute/src/comparisons/list.rs +++ b/crates/polars-compute/src/comparisons/list.rs @@ -1,86 +1,257 @@ -use arrow::array::ListArray; -use arrow::bitmap::{Bitmap, MutableBitmap}; -use arrow::types::Offset; +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::CustomIterTools; +use arrow::types::{days_ms, f16, i256, months_days_ns, Offset}; use super::TotalEqKernel; -use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; -impl TotalEqKernel for ListArray { - type Scalar = (); +macro_rules! 
compare { + ( + $lhs:expr, $rhs:expr, + $op:path, $true_op:expr, + $ineq_len_rv:literal, $invalid_rv:literal + ) => {{ + let lhs = $lhs; + let rhs = $rhs; - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - assert_eq!(self.len(), other.len()); + assert_eq!(lhs.len(), rhs.len()); + assert_eq!(lhs.dtype(), rhs.dtype()); - let mut bitmap = MutableBitmap::with_capacity(self.len()); + macro_rules! call_binary { + ($T:ty) => {{ + let lhs_values: &$T = $lhs.values().as_any().downcast_ref().unwrap(); + let rhs_values: &$T = $rhs.values().as_any().downcast_ref().unwrap(); - for i in 0..self.len() { - let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); - let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + (0..$lhs.len()) + .map(|i| { + let lval = $lhs.validity().map_or(true, |v| v.get(i).unwrap()); + let rval = $rhs.validity().map_or(true, |v| v.get(i).unwrap()); - if !lval || !rval { - bitmap.push(true); - continue; - } + if !lval || !rval { + return $invalid_rv; + } - let (lstart, lend) = self.offsets().start_end(i); - let (rstart, rend) = other.offsets().start_end(i); + // SAFETY: ListArray's invariant offsets.len_proxy() == len + let (lstart, lend) = unsafe { $lhs.offsets().start_end_unchecked(i) }; + let (rstart, rend) = unsafe { $rhs.offsets().start_end_unchecked(i) }; - if lend - lstart != rend - rstart { - bitmap.push(false); - continue; - } + if lend - lstart != rend - rstart { + return $ineq_len_rv; + } - let mut lhs_values = self.values().clone(); - lhs_values.slice(lstart, lend - lstart); - let mut rhs_values = other.values().clone(); - rhs_values.slice(rstart, rend - rstart); + let mut lhs_values = lhs_values.clone(); + lhs_values.slice(lstart, lend - lstart); + let mut rhs_values = rhs_values.clone(); + rhs_values.slice(rstart, rend - rstart); - let result = array_tot_eq_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); - bitmap.push(result.unset_bits() == 0); + $true_op($op(&lhs_values, &rhs_values)) + }) + 
.collect_trusted() + }}; } - bitmap.freeze() - } + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.values().dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - assert_eq!(self.len(), other.len()); + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), - let mut bitmap = MutableBitmap::with_capacity(self.len()); + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => 
call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} - for i in 0..self.len() { - let (lstart, lend) = self.offsets().start_end(i); - let (rstart, rend) = other.offsets().start_end(i); +macro_rules! compare_broadcast { + ( + $lhs:expr, $rhs:expr, + $offsets:expr, $validity:expr, + $op:path, $true_op:expr, + $ineq_len_rv:literal, $invalid_rv:literal + ) => {{ + let lhs = $lhs; + let rhs = $rhs; - let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); - let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + macro_rules! 
call_binary { + ($T:ty) => {{ + let values: &$T = $lhs.as_any().downcast_ref().unwrap(); + let scalar: &$T = $rhs.as_any().downcast_ref().unwrap(); - if !lval || !rval { - bitmap.push(false); - continue; - } + let length = $offsets.len_proxy(); - if lend - lstart != rend - rstart { - bitmap.push(true); - continue; - } + (0..length) + .map(move |i| { + let v = $validity.map_or(true, |v| v.get(i).unwrap()); - let mut lhs_values = self.values().clone(); - lhs_values.slice(lstart, lend - lstart); - let mut rhs_values = self.values().clone(); - rhs_values.slice(rstart, rend - rstart); + if !v { + return $invalid_rv; + } - let result = array_tot_ne_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); - bitmap.push(result.set_bits() > 0); + let (start, end) = unsafe { $offsets.start_end_unchecked(i) }; + + if end - start != scalar.len() { + return $ineq_len_rv; + } + + // @TODO: I feel like there is a better way to do this. + let mut values: $T = values.clone(); + <$T>::slice(&mut values, start, end - start); + + $true_op($op(&values, scalar)) + }) + .collect_trusted() + }}; } - bitmap.freeze() + assert_eq!(lhs.dtype(), rhs.dtype()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.dtype().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray), + PH::BinaryView => call_binary!(BinaryViewArray), + PH::Utf8View => call_binary!(Utf8ViewArray), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::UInt128) => 
call_binary!(PrimitiveArray), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray), + PH::Binary => call_binary!(BinaryArray), + PH::LargeBinary => call_binary!(BinaryArray), + PH::Utf8 => call_binary!(Utf8Array), + PH::LargeUtf8 => call_binary!(Utf8Array), + PH::List => call_binary!(ListArray), + PH::LargeList => call_binary!(ListArray), + PH::Struct => call_binary!(StructArray), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray), + } + }}; +} + +impl TotalEqKernel for ListArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + compare!( + self, + other, + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0, + false, + true + ) + } + + fn tot_ne_kernel(&self, other: 
&Self) -> Bitmap { + compare!( + self, + other, + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0, + true, + false + ) } - fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + compare_broadcast!( + self.values().as_ref(), + other.as_ref(), + self.offsets(), + self.validity(), + TotalEqKernel::tot_eq_missing_kernel, + |bm: Bitmap| bm.unset_bits() == 0, + false, + true + ) } - fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + compare_broadcast!( + self.values().as_ref(), + other.as_ref(), + self.offsets(), + self.validity(), + TotalEqKernel::tot_ne_missing_kernel, + |bm: Bitmap| bm.set_bits() > 0, + true, + false + ) } } diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 4ee9cad0a482..344f7a796735 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -638,76 +638,117 @@ impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked { } } -#[doc(hidden)] -fn _list_comparison_helper(lhs: &ListChunked, rhs: &ListChunked, op: F) -> BooleanChunked +fn _list_comparison_helper( + lhs: &ListChunked, + rhs: &ListChunked, + op: F, + broadcast_op: B, + missing: bool, + is_ne: bool, +) -> BooleanChunked where - F: Fn(Option<&Series>, Option<&Series>) -> Option, + F: Fn(&ListArray, &ListArray) -> Bitmap, + B: Fn(&ListArray, &Box) -> Bitmap, { match (lhs.len(), rhs.len()) { (_, 1) => { - let right = rhs.get_as_series(0).map(|s| s.with_name(PlSmallStr::EMPTY)); - lhs.amortized_iter() - .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) - .collect_trusted() + let right = rhs.chunks()[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + if !right.validity().map_or(true, |v| v.get(0).unwrap()) { + if 
missing { + if is_ne { + return lhs.is_not_null(); + } else { + return lhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, lhs.len()); + } + } + + let values = right.values().sliced( + (*right.offsets().first()).try_into().unwrap(), + right.offsets().range().try_into().unwrap(), + ); + + arity::unary_mut_values(lhs, |a| broadcast_op(a, &values).into()) }, (1, _) => { - let left = lhs.get_as_series(0).map(|s| s.with_name(PlSmallStr::EMPTY)); - rhs.amortized_iter() - .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) - .collect_trusted() + let left = lhs.chunks()[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + if !left.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return rhs.is_not_null(); + } else { + return rhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()); + } + } + + let values = left.values().sliced( + (*left.offsets().first()).try_into().unwrap(), + left.offsets().range().try_into().unwrap(), + ); + + arity::unary_mut_values(rhs, |a| broadcast_op(a, &values).into()) }, - _ => lhs - .amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| { - op( - left.as_ref().map(|us| us.as_ref()), - right.as_ref().map(|us| us.as_ref()), - ) - }) - .collect_trusted(), + _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), } } impl ChunkCompareEq<&ListChunked> for ListChunked { type Item = BooleanChunked; fn equal(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_equals = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(l.equals(r)), - _ => None, - }; - - _list_comparison_helper(self, rhs, _series_equals) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_kernel, + TotalEqKernel::tot_eq_kernel_broadcast, + false, + false, + ) } fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_equals_missing = |lhs: 
Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(l.equals_missing(r)), - (None, None) => Some(true), - _ => Some(false), - }; - - _list_comparison_helper(self, rhs, _series_equals_missing) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_eq_missing_kernel, + TotalEqKernel::tot_eq_missing_kernel_broadcast, + true, + false, + ) } fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_not_equal = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(!l.equals(r)), - _ => None, - }; - - _list_comparison_helper(self, rhs, _series_not_equal) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_kernel, + TotalEqKernel::tot_ne_kernel_broadcast, + false, + true, + ) } fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked { - let _series_not_equal_missing = - |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) { - (Some(l), Some(r)) => Some(!l.equals_missing(r)), - (None, None) => Some(false), - _ => Some(true), - }; - - _list_comparison_helper(self, rhs, _series_not_equal_missing) + _list_comparison_helper( + self, + rhs, + TotalEqKernel::tot_ne_missing_kernel, + TotalEqKernel::tot_ne_missing_kernel_broadcast, + true, + true, + ) } } @@ -798,6 +839,8 @@ fn _array_comparison_helper( rhs: &ArrayChunked, op: F, broadcast_op: B, + missing: bool, + is_ne: bool, ) -> BooleanChunked where F: Fn(&FixedSizeListArray, &FixedSizeListArray) -> Bitmap, @@ -808,17 +851,41 @@ where let right = rhs.chunks()[0] .as_any() .downcast_ref::() - .unwrap() - .values(); - arity::unary_mut_values(lhs, |a| broadcast_op(a, right).into()) + .unwrap(); + + if !right.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return lhs.is_not_null(); + } else { + return lhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, lhs.len()); + } + } + + arity::unary_mut_values(lhs, |a| broadcast_op(a, 
right.values()).into()) }, (1, _) => { let left = lhs.chunks()[0] .as_any() .downcast_ref::() - .unwrap() - .values(); - arity::unary_mut_values(rhs, |a| broadcast_op(a, left).into()) + .unwrap(); + + if !left.validity().map_or(true, |v| v.get(0).unwrap()) { + if missing { + if is_ne { + return rhs.is_not_null(); + } else { + return rhs.is_null(); + } + } else { + return BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()); + } + } + + arity::unary_mut_values(rhs, |a| broadcast_op(a, left.values()).into()) }, _ => arity::binary_mut_values(lhs, rhs, |a, b| op(a, b).into(), PlSmallStr::EMPTY), } @@ -833,6 +900,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_eq_kernel, TotalEqKernel::tot_eq_kernel_broadcast, + false, + false, ) } @@ -842,6 +911,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_eq_missing_kernel, TotalEqKernel::tot_eq_missing_kernel_broadcast, + true, + false, ) } @@ -851,6 +922,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_ne_kernel, TotalEqKernel::tot_ne_kernel_broadcast, + false, + true, ) } @@ -860,6 +933,8 @@ impl ChunkCompareEq<&ArrayChunked> for ArrayChunked { rhs, TotalEqKernel::tot_ne_missing_kernel, TotalEqKernel::tot_ne_missing_kernel_broadcast, + true, + true, ) } } diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 14aefa93c3c1..3807a6b29ef5 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -405,14 +405,14 @@ def test_fast_explode_merge_left_16923() -> None: @pytest.mark.parametrize( ("values", "exploded"), [ - (["foobar", None], ["f", "o", "o", "b", "a", "r", None]), - ([None, "foo", "bar"], [None, "f", "o", "o", "b", "a", "r"]), + (["foobar", None], ["f", "o", "o", "b", "a", "r", ""]), + ([None, "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), ( [None, "foo", "bar", None, "ham"], - [None, "f", 
"o", "o", "b", "a", "r", None, "h", "a", "m"], + ["", "f", "o", "o", "b", "a", "r", "", "h", "a", "m"], ), (["foo", "bar", "ham"], ["f", "o", "o", "b", "a", "r", "h", "a", "m"]), - (["", None, "foo", "bar"], ["", None, "f", "o", "o", "b", "a", "r"]), + (["", None, "foo", "bar"], ["", "", "f", "o", "o", "b", "a", "r"]), (["", "foo", "bar"], ["", "f", "o", "o", "b", "a", "r"]), ], ) @@ -421,6 +421,9 @@ def test_series_str_explode_deprecated( ) -> None: with pytest.deprecated_call(): result = pl.Series(values).str.explode() + if result.to_list() != exploded: + print(result.to_list()) + print(exploded) assert result.to_list() == exploded