Skip to content

Commit

Permalink
fix: Added input validation for explode operation in the array name…
Browse files Browse the repository at this point in the history
…space (#19163)
  • Loading branch information
barak1412 authored Nov 1, 2024
1 parent 1dc5efc commit d2a0441
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 3 deletions.
5 changes: 5 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,9 @@ impl ArrayNameSpace {
None,
)
}
/// Returns a column with a separate row for every array element.
///
/// Builds the `ArrayFunction::Explode` function expression and applies it
/// to the wrapped expression through `map_private`.
pub fn explode(self) -> Expr {
    let explode_fn = FunctionExpr::ArrayExpr(ArrayFunction::Explode);
    self.0.map_private(explode_fn)
}
}
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub enum ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches,
Shift,
Explode,
}

impl ArrayFunction {
Expand All @@ -56,6 +57,7 @@ impl ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
Shift => mapper.with_same_dtype(),
Explode => mapper.try_map_to_array_inner_dtype(),
}
}
}
Expand Down Expand Up @@ -96,6 +98,7 @@ impl Display for ArrayFunction {
#[cfg(feature = "array_count")]
CountMatches => "count_matches",
Shift => "shift",
Explode => "explode",
};
write!(f, "arr.{name}")
}
Expand Down Expand Up @@ -129,6 +132,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
#[cfg(feature = "array_count")]
CountMatches => map_as_slice!(count_matches),
Shift => map_as_slice!(shift),
Explode => unreachable!(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ impl<'a> FieldsMapper<'a> {
Ok(first)
}

#[cfg(feature = "dtype-array")]
/// Map the dtype to the dtype of the array elements, with type validation.
///
/// Only the dtype of the first field is inspected. When it is an `Array`,
/// the mapping is delegated to `map_to_list_and_array_inner_dtype`.
///
/// # Errors
///
/// Bails with an `InvalidOperation` error when the first field's dtype is
/// not `DataType::Array`.
pub fn try_map_to_array_inner_dtype(&self) -> PolarsResult<Field> {
    let first_dtype = self.fields[0].dtype();
    // Guard clause: reject anything that is not an Array before mapping.
    if !matches!(first_dtype, DataType::Array(_, _)) {
        polars_bail!(InvalidOperation: "expected Array type, got: {}", first_dtype);
    }
    self.map_to_list_and_array_inner_dtype()
}

/// Map the dtypes to the "supertype" of a list of lists.
pub fn map_to_list_supertype(&self) -> PolarsResult<Field> {
self.try_map_dtypes(|dts| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ pub(super) fn optimize_functions(
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
let out = match function {
#[cfg(feature = "dtype-array")]
// arr.explode() -> explode()
FunctionExpr::ArrayExpr(ArrayFunction::Explode) => {
let input_node = input[0].node();
Some(AExpr::Explode(input_node))
},
// is_null().any() -> null_count() > 0
// is_not_null().any() -> null_count() < len()
// CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-python/src/expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,8 @@ impl PyExpr {
// Python binding for `arr.shift`: shifts using the expression wrapped by `n`.
fn arr_shift(&self, n: PyExpr) -> Self {
    let array_ns = self.inner.clone().arr();
    let shifted = array_ns.shift(n.inner);
    shifted.into()
}

// Python binding for `arr.explode`: forwards to the array namespace's
// `explode` on a clone of the inner expression.
fn arr_explode(&self) -> Self {
    let array_ns = self.inner.clone().arr();
    let exploded = array_ns.explode();
    exploded.into()
}
}
2 changes: 1 addition & 1 deletion py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def explode(self) -> Expr:
│ 6 │
└─────┘
"""
return wrap_expr(self._pyexpr.explode())
return wrap_expr(self._pyexpr.arr_explode())

def contains(
self, item: float | str | bool | int | date | datetime | time | IntoExprColumn
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@expr_dispatch
class ArrayNameSpace:
"""Namespace for list related methods."""
"""Namespace for array related methods."""

_accessor = "arr"

Expand Down
13 changes: 12 additions & 1 deletion py-polars/tests/unit/operations/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.exceptions import ComputeError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -449,3 +449,14 @@ def test_array_n_unique() -> None:
{"n_unique": [2, 1, 1, None]}, schema={"n_unique": pl.UInt32}
)
assert_frame_equal(out, expected)


def test_explode_19049() -> None:
    """Check `arr.explode` on an Array column and its rejection of a non-Array column."""
    # Happy path: a single fixed-size array row explodes into one row per element.
    array_frame = pl.DataFrame(
        {"a": [[1, 2, 3]]},
        schema={"a": pl.Array(pl.Int64, 3)},
    )
    exploded = array_frame.select(pl.col.a.arr.explode())
    assert_frame_equal(
        exploded,
        pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64}),
    )

    # Error path: a plain Int64 column must be rejected with a clear message.
    int_frame = pl.DataFrame({"a": [1, 2, 3]}, schema={"a": pl.Int64})
    with pytest.raises(InvalidOperationError, match="expected Array type, got: i64"):
        int_frame.select(pl.col.a.arr.explode())

0 comments on commit d2a0441

Please sign in to comment.