Skip to content

Commit

Permalink
fix: Revert length check of patterns in str.extract_many() (#20953)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher authored Jan 29, 2025
1 parent 96a2d01 commit a7b933a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
1 change: 0 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ fn extract_many(
     ascii_case_insensitive: bool,
     overlapping: bool,
 ) -> PolarsResult<Column> {
-    _check_same_length(s, "extract_many")?;
     let ca = s[0].str()?;
     let patterns = &s[1];

Expand Down
28 changes: 9 additions & 19 deletions py-polars/tests/unit/operations/namespaces/string/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,21 +1824,17 @@ def test_replace_lit_n_char_13385(


 def test_extract_many() -> None:
-    df = pl.DataFrame({"values": ["discontent"]})
+    df = pl.DataFrame({"values": ["discontent", "foobar"]})
     patterns = ["winter", "disco", "onte", "discontent"]
-    assert (
-        df.with_columns(
-            pl.col("values")
-            .str.extract_many(patterns, overlapping=False)
-            .alias("matches"),
-            pl.col("values")
-            .str.extract_many(patterns, overlapping=True)
-            .alias("matches_overlapping"),
-        )
+    assert df.with_columns(
+        pl.col("values").str.extract_many(patterns, overlapping=False).alias("matches"),
+        pl.col("values")
+        .str.extract_many(patterns, overlapping=True)
+        .alias("matches_overlapping"),
     ).to_dict(as_series=False) == {
-        "values": ["discontent"],
-        "matches": [["disco"]],
-        "matches_overlapping": [["disco", "onte", "discontent"]],
+        "values": ["discontent", "foobar"],
+        "matches": [["disco"], []],
+        "matches_overlapping": [["disco", "onte", "discontent"], []],
     }

# many patterns
Expand All @@ -1865,12 +1861,6 @@ def test_extract_many() -> None:
     assert f2.to_list() == [[0], [0, 5]]


-def test_str_extract_many_wrong_length() -> None:
-    df = pl.DataFrame({"num": ["-10", "-1", "0"]})
-    with pytest.raises(ComputeError, match="should have equal or unit length"):
-        df.select(pl.col("num").str.extract_many(pl.Series(["a", "b"])))
-
-
 def test_json_decode_raise_on_data_type_mismatch_13061() -> None:
     assert_series_equal(
         pl.Series(["null", "null"]).str.json_decode(infer_schema_length=1),
Expand Down

0 comments on commit a7b933a

Please sign in to comment.