From d2a04417ed1f0f14f030a7e1b1fd7957e1ad04e2 Mon Sep 17 00:00:00 2001 From: barak1412 Date: Fri, 1 Nov 2024 15:47:56 +0200 Subject: [PATCH] fix: Added input validation for `explode` operation in the array namespace (#19163) --- crates/polars-plan/src/dsl/array.rs | 5 +++++ crates/polars-plan/src/dsl/function_expr/array.rs | 4 ++++ crates/polars-plan/src/dsl/function_expr/schema.rs | 10 ++++++++++ .../optimizer/simplify_expr/simplify_functions.rs | 6 ++++++ crates/polars-python/src/expr/array.rs | 4 ++++ py-polars/polars/expr/array.py | 2 +- py-polars/polars/series/array.py | 2 +- .../unit/operations/namespaces/array/test_array.py | 13 ++++++++++++- 8 files changed, 43 insertions(+), 3 deletions(-) diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 23faf363421c..389ed11a16be 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -193,4 +193,9 @@ impl ArrayNameSpace { None, ) } + /// Returns a column with a separate row for every array element. + pub fn explode(self) -> Expr { + self.0 + .map_private(FunctionExpr::ArrayExpr(ArrayFunction::Explode)) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index dce6d44bce94..2ecd016981e3 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -30,6 +30,7 @@ pub enum ArrayFunction { #[cfg(feature = "array_count")] CountMatches, Shift, + Explode, } impl ArrayFunction { @@ -56,6 +57,7 @@ impl ArrayFunction { #[cfg(feature = "array_count")] CountMatches => mapper.with_dtype(IDX_DTYPE), Shift => mapper.with_same_dtype(), + Explode => mapper.try_map_to_array_inner_dtype(), } } } @@ -96,6 +98,7 @@ impl Display for ArrayFunction { #[cfg(feature = "array_count")] CountMatches => "count_matches", Shift => "shift", + Explode => "explode", }; write!(f, "arr.{name}") } @@ -129,6 +132,7 @@ impl From for SpecialEq> { #[cfg(feature = "array_count")] CountMatches => map_as_slice!(count_matches), Shift => map_as_slice!(shift), + Explode => unreachable!(), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 606ab81207c4..8ac54c172993 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -488,6 +488,16 @@ impl<'a> FieldsMapper<'a> { Ok(first) } + #[cfg(feature = "dtype-array")] + /// Map the dtype to the dtype of the array elements, with typo validation. + pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult { + let dt = self.fields[0].dtype(); + match dt { + DataType::Array(_, _) => self.map_to_list_and_array_inner_dtype(), + _ => polars_bail!(InvalidOperation: "expected Array type, got: {}", dt), + } + } + /// Map the dtypes to the "supertype" of a list of lists. pub fn map_to_list_supertype(&self) -> PolarsResult { self.try_map_dtypes(|dts| { diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs index 2b5493c62e6b..03f274e5211a 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs @@ -7,6 +7,12 @@ pub(super) fn optimize_functions( expr_arena: &mut Arena, ) -> PolarsResult> { let out = match function { + #[cfg(feature = "dtype-array")] + // arr.explode() -> explode() + FunctionExpr::ArrayExpr(ArrayFunction::Explode) => { + let input_node = input[0].node(); + Some(AExpr::Explode(input_node)) + }, // is_null().any() -> null_count() > 0 // is_not_null().any() -> null_count() < len() // CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS diff --git a/crates/polars-python/src/expr/array.rs b/crates/polars-python/src/expr/array.rs index a646ab4f375f..9d057dcc4c4f 100644 --- a/crates/polars-python/src/expr/array.rs +++ b/crates/polars-python/src/expr/array.rs @@ -132,4 +132,8 @@ impl PyExpr { fn arr_shift(&self, n: PyExpr) -> Self { self.inner.clone().arr().shift(n.inner).into() } + + fn arr_explode(&self) -> Self { + self.inner.clone().arr().explode().into() + } } diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 928e3149c35d..5a964b464934 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -605,7 +605,7 @@ def explode(self) -> Expr: │ 6 │ └─────┘ """ - return wrap_expr(self._pyexpr.explode()) + return wrap_expr(self._pyexpr.arr_explode()) def contains( self, item: float | str | bool | int | date | datetime | time | IntoExprColumn diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 877ec303fcfb..6bb7285b1176 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -17,7 +17,7 @@ @expr_dispatch class ArrayNameSpace: - """Namespace for list related methods.""" + """Namespace for array related methods.""" _accessor = "arr" diff --git a/py-polars/tests/unit/operations/namespaces/array/test_array.py b/py-polars/tests/unit/operations/namespaces/array/test_array.py index 78948406578a..f59366d2c4a5 100644 --- a/py-polars/tests/unit/operations/namespaces/array/test_array.py +++ b/py-polars/tests/unit/operations/namespaces/array/test_array.py @@ -6,7 +6,7 @@ import pytest import polars as pl -from polars.exceptions import ComputeError +from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal @@ -449,3 +449,14 @@ def test_array_n_unique() -> None: {"n_unique": [2, 1, 1, None]}, schema={"n_unique": pl.UInt32} ) assert_frame_equal(out, expected) + + +def test_explode_19049() -> None: + df = pl.DataFrame({"a": [[1, 2, 3]]}, schema={"a": pl.Array(pl.Int64, 3)}) + result_df = df.select(pl.col.a.arr.explode()) + expected_df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64}) + assert_frame_equal(result_df, expected_df) + + df = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64}) + with pytest.raises(InvalidOperationError, match="expected Array type, got: i64"): + df.select(pl.col.a.arr.explode())