Skip to content

Commit

Permalink
Changed default behavior to True to avoid breaking change.
Browse files Browse the repository at this point in the history
  • Loading branch information
edwjames committed Mar 31, 2024
1 parent 8c0c162 commit f1eb908
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
2 changes: 1 addition & 1 deletion crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> {
// Array functions
// ----
ArrayContains => self.visit_binary::<Expr>(|e, s| e.list().contains(s)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i, false)),
ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)),
ArrayLength => self.visit_unary(|e| e.list().len()),
ArrayMax => self.visit_unary(|e| e.list().max()),
ArrayMean => self.visit_unary(|e| e.list().mean()),
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def get(
self,
index: int | Expr | str,
*,
null_on_oob: bool = False,
null_on_oob: bool = True,
) -> Expr:
"""
Get the value by index in the sublists.
Expand All @@ -530,7 +530,7 @@ def get(
Examples
--------
>>> df = pl.DataFrame({"a": [[3, 2, 1], [], [1, 2]]})
>>> df.with_columns(get=pl.col("a").list.get(0, null_on_oob=True))
>>> df.with_columns(get=pl.col("a").list.get(0))
shape: (3, 2)
┌───────────┬──────┐
│ a ┆ get │
Expand Down Expand Up @@ -650,7 +650,7 @@ def first(self) -> Expr:
│ [1, 2] ┆ 1 │
└───────────┴───────┘
"""
return self.get(0, null_on_oob=True)
return self.get(0)

def last(self) -> Expr:
"""
Expand All @@ -671,7 +671,7 @@ def last(self) -> Expr:
│ [1, 2] ┆ 2 │
└───────────┴──────┘
"""
return self.get(-1, null_on_oob=True)
return self.get(-1)

def contains(
self, item: float | str | bool | int | date | datetime | time | IntoExprColumn
Expand Down
27 changes: 16 additions & 11 deletions py-polars/tests/unit/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def test_list_arr_get() -> None:
a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
out = a.list.get(0)
out = a.list.get(0, null_on_oob=False)
expected = pl.Series("a", [1, 4, 6])
assert_series_equal(out, expected)
out = a.list[0]
Expand All @@ -22,7 +22,7 @@ def test_list_arr_get() -> None:
out = pl.select(pl.lit(a).list.first()).to_series()
assert_series_equal(out, expected)

out = a.list.get(-1)
out = a.list.get(-1, null_on_oob=False)
expected = pl.Series("a", [3, 5, 9])
assert_series_equal(out, expected)
out = a.list.last()
Expand All @@ -31,30 +31,35 @@ def test_list_arr_get() -> None:
assert_series_equal(out, expected)

with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
a.list.get(3)
a.list.get(3, null_on_oob=False)

# Null index.
out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None)))
out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=False))
expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame()
assert_frame_equal(out_df, expected_df)

a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])

with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
a.list.get(-3)
a.list.get(-3, null_on_oob=False)

with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
pl.DataFrame(
{"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]}
).with_columns(
[pl.col("a").list.get(i).alias(f"get_{i}") for i in range(4)]
).to_dict(as_series=False)
[
pl.col("a").list.get(i, null_on_oob=False).alias(f"get_{i}")
for i in range(4)
]
)

# get by indexes where some are out of bounds
df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]})

with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False)
df.select([pl.col("cars").list.get("indexes", null_on_oob=False)]).to_dict(
as_series=False
)

# exact on oob boundary
df = pl.DataFrame(
Expand All @@ -65,15 +70,15 @@ def test_list_arr_get() -> None:
)

with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
df.select(pl.col("lists").list.get(3)).to_dict(as_series=False)
df.select(pl.col("lists").list.get(3, null_on_oob=False))

with pytest.raises(pl.ComputeError, match="get index is out of bounds"):
df.select(pl.col("lists").list.get(pl.col("index"))).to_dict(as_series=False)
df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=False))


def test_list_arr_get_null_on_oob() -> None:
a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
out = a.list.first()
out = a.list.get(0, null_on_oob=True)
expected = pl.Series("a", [1, 4, 6])
assert_series_equal(out, expected)
out = a.list[0]
Expand Down

0 comments on commit f1eb908

Please sign in to comment.