diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index e67d1638e99d..46c339323b1b 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -75,6 +75,13 @@ pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { unsafe { take_unchecked(&**values, &take_by) } } +/// Check if an index is out of bounds for at least one sublist. +pub fn index_is_oob(arr: &ListArray, index: i64) -> bool { + arr.offsets() + .lengths() + .any(|len| index.negative_to_usize(len).is_none()) +} + /// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]` pub fn array_to_unit_list(array: ArrayRef) -> ListArray { let len = array.len(); diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index a4f7e78e2c6d..0d511c87967c 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -1,7 +1,7 @@ use std::fmt::Write; use arrow::array::ValueSize; -use arrow::legacy::kernels::list::sublist_get; +use arrow::legacy::kernels::list::{index_is_oob, sublist_get}; use polars_core::chunked_array::builder::get_list_builder; #[cfg(feature = "list_gather")] use polars_core::export::num::ToPrimitive; @@ -341,8 +341,12 @@ pub trait ListNameSpaceImpl: AsList { /// So index `0` would return the first item of every sublist /// and index `-1` would return the last item of every sublist /// if an index is out of bounds, it will return a `None`. - fn lst_get(&self, idx: i64) -> PolarsResult { + fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult { let ca = self.as_list(); + if !null_on_oob && ca.downcast_iter().any(|arr| index_is_oob(arr, idx)) { + polars_bail!(ComputeError: "get index is out of bounds"); + } + let chunks = ca .downcast_iter() .map(|arr| sublist_get(arr, idx)) diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index c43cfda13024..4b74a76692ed 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -72,7 +72,7 @@ pub trait ToStruct: AsList { (0..n_fields) .into_par_iter() .map(|i| { - ca.lst_get(i as i64).map(|mut s| { + ca.lst_get(i as i64, true).map(|mut s| { s.rename(&name_generator(i)); s }) diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 3fdbf6a18134..3b06841fbb55 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -21,7 +21,7 @@ pub enum ListFunction { }, Slice, Shift, - Get, + Get(bool), #[cfg(feature = "list_gather")] Gather(bool), #[cfg(feature = "list_gather")] @@ -71,7 +71,7 @@ impl ListFunction { Sample { .. } => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), Shift => mapper.with_same_dtype(), - Get => mapper.map_to_list_and_array_inner_dtype(), + Get(_) => mapper.map_to_list_and_array_inner_dtype(), #[cfg(feature = "list_gather")] Gather(_) => mapper.with_same_dtype(), #[cfg(feature = "list_gather")] @@ -136,7 +136,7 @@ impl Display for ListFunction { }, Slice => "slice", Shift => "shift", - Get => "get", + Get(_) => "get", #[cfg(feature = "list_gather")] Gather(_) => "gather", #[cfg(feature = "list_gather")] @@ -203,9 +203,9 @@ impl From for SpecialEq> { }, Slice => wrap!(slice), Shift => map_as_slice!(shift), - Get => wrap!(get), + Get(null_on_oob) => wrap!(get, null_on_oob), #[cfg(feature = "list_gather")] - Gather(null_ob_oob) => map_as_slice!(gather, null_ob_oob), + Gather(null_on_oob) => map_as_slice!(gather, null_on_oob), #[cfg(feature = "list_gather")] GatherEvery => map_as_slice!(gather_every), #[cfg(feature = "list_count")] @@ -414,7 +414,7 @@ pub(super) fn concat(s: &mut [Series]) -> PolarsResult> { first_ca.lst_concat(other).map(|ca| Some(ca.into_series())) } -pub(super) fn get(s: &mut [Series]) -> PolarsResult> { +pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult> { let ca = s[0].list()?; let index = s[1].cast(&DataType::Int64)?; let index = index.i64().unwrap(); @@ -423,7 +423,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { 1 => { let index = index.get(0); if let Some(index) = index { - ca.lst_get(index).map(Some) + ca.lst_get(index, null_on_oob).map(Some) } else { Ok(Some(Series::full_null( ca.name(), @@ -440,19 +440,24 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { let take_by = index .into_iter() .enumerate() - .map(|(i, opt_idx)| { - opt_idx.and_then(|idx| { + .map(|(i, opt_idx)| match opt_idx { + Some(idx) => { let (start, end) = unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) }; let offset = if idx >= 0 { start + idx } else { end + idx }; if offset >= end || offset < start || start == end { - None + if null_on_oob { + Ok(None) + } else { + polars_bail!(ComputeError: "get index is out of bounds"); + } } else { - Some(offset as IdxSize) + Ok(Some(offset as IdxSize)) } - }) + }, + None => Ok(None), }) - .collect::(); + .collect::>()?; let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); unsafe { s.take_unchecked(&take_by) } .cast(&ca.inner_dtype()) @@ -475,7 +480,7 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult if idx.len() == 1 && null_on_oob { // fast path let idx = idx.get(0)?.try_extract::()?; - let out = ca.lst_get(idx)?; + let out = ca.lst_get(idx, null_on_oob)?; // make sure we return a list out.reshape(&[-1, 1]) } else { diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 397959e9980f..5ff04aee0098 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -718,6 +718,14 @@ macro_rules! wrap { ($e:expr) => { SpecialEq::new(Arc::new($e)) }; + + ($e:expr, $($args:expr),*) => {{ + let f = move |s: &mut [Series]| { + $e(s, $($args),*) + }; + + SpecialEq::new(Arc::new(f)) + }}; } // Fn(&[Series], args) diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 603ec2553590..0f6c15c755e7 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -151,9 +151,9 @@ impl ListNameSpace { } /// Get items in every sublist by index. - pub fn get(self, index: Expr) -> Expr { + pub fn get(self, index: Expr, null_on_oob: bool) -> Expr { self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::Get), + FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)), &[index], false, false, @@ -187,12 +187,12 @@ impl ListNameSpace { /// Get first item of every sublist. pub fn first(self) -> Expr { - self.get(lit(0i64)) + self.get(lit(0i64), true) } /// Get last item of every sublist. pub fn last(self) -> Expr { - self.get(lit(-1i64)) + self.get(lit(-1i64), true) } /// Join all string items in a sublist and place a separator between them. diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 2149912be665..6cfd5263c416 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> { // Array functions // ---- ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), - ArrayGet => self.visit_binary(|e, i| e.list().get(i)), + ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)), ArrayLength => self.visit_unary(|e| e.list().len()), ArrayMax => self.visit_unary(|e| e.list().max()), ArrayMean => self.visit_unary(|e| e.list().mean()), diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 474e586a1c62..3c827794ffdb 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -505,7 +505,12 @@ def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> E other_list.insert(0, wrap_expr(self._pyexpr)) return F.concat_list(other_list) - def get(self, index: int | Expr | str) -> Expr: + def get( + self, + index: int | Expr | str, + *, + null_on_oob: bool = True, + ) -> Expr: """ Get the value by index in the sublists. @@ -517,6 +522,10 @@ def get(self, index: int | Expr | str) -> Expr: ---------- index Index to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error Examples -------- @@ -534,7 +543,7 @@ def get(self, index: int | Expr | str) -> Expr: └───────────┴──────┘ """ index = parse_as_expression(index) - return wrap_expr(self._pyexpr.list_get(index)) + return wrap_expr(self._pyexpr.list_get(index, null_on_oob)) def gather( self, diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 9540cc6a2860..7879d96c9ea2 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -373,7 +373,12 @@ def concat(self, other: list[Series] | Series | list[Any]) -> Series: ] """ - def get(self, index: int | Series | list[int]) -> Series: + def get( + self, + index: int | Series | list[int], + *, + null_on_oob: bool = True, + ) -> Series: """ Get the value by index in the sublists. @@ -385,11 +390,15 @@ def get(self, index: int | Series | list[int]) -> Series: ---------- index Index to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error Examples -------- >>> s = pl.Series("a", [[3, 2, 1], [], [1, 2]]) - >>> s.list.get(0) + >>> s.list.get(0, null_on_oob=True) shape: (3,) Series: 'a' [i64] [ diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index fde544a6ce41..b00476c7bb3a 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -44,8 +44,12 @@ impl PyExpr { self.inner.clone().list().eval(expr.inner, parallel).into() } - fn list_get(&self, index: PyExpr) -> Self { - self.inner.clone().list().get(index.inner).into() + fn list_get(&self, index: PyExpr, null_on_oob: bool) -> Self { + self.inner + .clone() + .list() + .get(index.inner, null_on_oob) + .into() } fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index ba580ea8e6ba..fe0028d5067c 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -781,7 +781,7 @@ def test_list_gather_null_struct_14927() -> None: {"index": [1], "col_0": [None], "field_0": [None]}, schema={**df.schema, "field_0": pl.Float64}, ) - expr = pl.col("col_0").list.get(0).struct.field("field_0") + expr = pl.col("col_0").list.get(0, null_on_oob=True).struct.field("field_0") out = df.filter(pl.col("index") > 0).with_columns(expr) assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py index 570716e14fe5..40dc3561598c 100644 --- a/py-polars/tests/unit/namespaces/list/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -11,7 +11,7 @@ def test_list_arr_get() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(0) + out = a.list.get(0, null_on_oob=False) expected = pl.Series("a", [1, 4, 6]) assert_series_equal(out, expected) out = a.list[0] @@ -22,7 +22,74 @@ def test_list_arr_get() -> None: out = pl.select(pl.lit(a).list.first()).to_series() assert_series_equal(out, expected) - out = a.list.get(-1) + out = a.list.get(-1, null_on_oob=False) + expected = pl.Series("a", [3, 5, 9]) + assert_series_equal(out, expected) + out = a.list.last() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.last()).to_series() + assert_series_equal(out, expected) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + a.list.get(3, null_on_oob=False) + + # Null index. + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=False)) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + a.list.get(-3, null_on_oob=False) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + pl.DataFrame( + {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} + ).with_columns( + [ + pl.col("a").list.get(i, null_on_oob=False).alias(f"get_{i}") + for i in range(4) + ] + ) + + # get by indexes where some are out of bounds + df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select([pl.col("cars").list.get("indexes", null_on_oob=False)]).to_dict( + as_series=False + ) + + # exact on oob boundary + df = pl.DataFrame( + { + "index": [3, 3, 3], + "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]], + } + ) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(3, null_on_oob=False)) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=False)) + + +def test_list_arr_get_null_on_oob() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + out = a.list.get(0, null_on_oob=True) + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list[0] + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list.first() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.first()).to_series() + assert_series_equal(out, expected) + + out = a.list.get(-1, null_on_oob=True) expected = pl.Series("a", [3, 5, 9]) assert_series_equal(out, expected) out = a.list.last() @@ -31,24 +98,24 @@ def test_list_arr_get() -> None: assert_series_equal(out, expected) # Out of bounds index. - out = a.list.get(3) + out = a.list.get(3, null_on_oob=True) expected = pl.Series("a", [None, None, 9]) assert_series_equal(out, expected) # Null index. - out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None))) + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=True)) expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() assert_frame_equal(out_df, expected_df) a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(-3) + out = a.list.get(-3, null_on_oob=True) expected = pl.Series("a", [1, None, 7]) assert_series_equal(out, expected) assert pl.DataFrame( {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} ).with_columns( - [pl.col("a").list.get(i).alias(f"get_{i}") for i in range(4)] + [pl.col("a").list.get(i, null_on_oob=True).alias(f"get_{i}") for i in range(4)] ).to_dict(as_series=False) == { "a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]], "get_0": [1, 2, 3, 4, 7, None], @@ -60,9 +127,9 @@ def test_list_arr_get() -> None: # get by indexes where some are out of bounds df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) - assert df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False) == { - "cars": [2, 3, None, None] - } + assert df.select([pl.col("cars").list.get("indexes", null_on_oob=True)]).to_dict( + as_series=False + ) == {"cars": [2, 3, None, None]} # exact on oob boundary df = pl.DataFrame( { @@ -71,12 +138,12 @@ def test_list_arr_get() -> None: } ) - assert df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) == { - "lists": [None, None, 4] - } - assert df.select(pl.col("lists").list.get(pl.col("index"))).to_dict( + assert df.select(pl.col("lists").list.get(3, null_on_oob=True)).to_dict( as_series=False ) == {"lists": [None, None, 4]} + assert df.select( + pl.col("lists").list.get(pl.col("index"), null_on_oob=True) + ).to_dict(as_series=False) == {"lists": [None, None, 4]} def test_list_categorical_get() -> None: @@ -88,7 +155,9 @@ def test_list_categorical_get() -> None: } ) expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical) - assert_series_equal(df["actions"].list.get(0), expected, categorical_as_str=True) + assert_series_equal( + df["actions"].list.get(0, null_on_oob=True), expected, categorical_as_str=True + ) def test_contains() -> None: @@ -597,7 +666,7 @@ def test_select_from_list_to_struct_11143() -> None: def test_list_arr_get_8810() -> None: assert pl.DataFrame(pl.Series("a", [None], pl.List(pl.Int64))).select( - pl.col("a").list.get(0) + pl.col("a").list.get(0, null_on_oob=True) ).to_dict(as_series=False) == {"a": [None]}