Skip to content

Commit

Permalink
fix: Correct wildcard and input expansion for some more functions (#19588)
Browse files Browse the repository at this point in the history

Co-authored-by: siddharthv <[email protected]>
  • Loading branch information
siddharth-vi and siddharthv authored Nov 4, 2024
1 parent 7be7f06 commit e52a598
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 17 deletions.
17 changes: 6 additions & 11 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,12 @@ impl ArrayNameSpace {
pub fn count_matches<E: Into<Expr>>(self, element: E) -> Expr {
let other = element.into();

self.0
.map_many_private(
FunctionExpr::ArrayExpr(ArrayFunction::CountMatches),
&[other],
false,
None,
)
.with_function_options(|mut options| {
options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION;
options
})
self.0.map_many_private(
FunctionExpr::ArrayExpr(ArrayFunction::CountMatches),
&[other],
false,
None,
)
}

#[cfg(feature = "array_to_struct")]
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-plan/src/dsl/functions/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,7 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
}),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION
| FunctionFlags::ALLOW_RENAME,
flags: FunctionFlags::default() | FunctionFlags::ALLOW_RENAME,
fmt_str: "datetime",
..Default::default()
},
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,7 @@ impl ListNameSpace {
cast_to_supertypes: Some(SuperTypeOptions {
flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST,
}),
flags: FunctionFlags::default()
| FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR,
flags: FunctionFlags::default() & !FunctionFlags::RETURNS_SCALAR,
..Default::default()
},
}
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/functions/as_datatype/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,13 @@ def test_datetime_ambiguous_time_zone_earliest() -> None:
expected = datetime(2018, 10, 28, 2, 30, tzinfo=ZoneInfo("Europe/Brussels"))
assert result == expected
assert result.fold == 0


def test_datetime_wildcard_expansion() -> None:
    """`pl.datetime` built from `pl.all()` must yield one output per input column."""
    frame = pl.DataFrame({"a": [1], "b": [2]})
    # Year, month and day are all wildcards; `.name.keep()` preserves the
    # originating column name on each expanded output.
    expr = pl.datetime(year=pl.all(), month=pl.all(), day=pl.all()).name.keep()
    result = frame.select(expr).to_dict(as_series=False)
    expected = {
        "a": [datetime(1, 1, 1, 0, 0)],
        "b": [datetime(2, 2, 2, 0, 0)],
    }
    assert result == expected
11 changes: 11 additions & 0 deletions py-polars/tests/unit/operations/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,17 @@ def test_array_count_matches(
assert out.to_dict(as_series=False) == {"count_matches": expected}


def test_array_count_matches_wildcard_expansion() -> None:
    """`arr.count_matches` applied to `pl.all()` must return a count per array column."""
    array_dtype = pl.Array(pl.Int64, 2)
    frame = pl.DataFrame(
        {"a": [[1, 2]], "b": [[3, 4]]},
        schema={"a": array_dtype, "b": array_dtype},
    )
    counts = frame.select(pl.all().arr.count_matches(3)).to_dict(as_series=False)
    # Only column "b" holds the value 3.
    expected = {
        "a": [0],
        "b": [1],
    }
    assert counts == expected


def test_array_to_struct() -> None:
df = pl.DataFrame(
{"a": [[1, 2, 3], [4, 5, None]]}, schema={"a": pl.Array(pl.Int8, 3)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ def test_list_set_operations_float() -> None:

def test_list_set_operations() -> None:
df = pl.DataFrame(
{"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}
{
"a": [[1, 2, 3], [1, 1, 1], [4]],
"b": [[4, 2, 1], [2, 1, 12], [4]],
"c": [[1, 2], [2, 1, 76], [8, 9]],
}
)

assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [
Expand All @@ -64,6 +68,24 @@ def test_list_set_operations() -> None:
[],
]

# check expansion of columns
assert df.select(pl.col("a", "b").list.set_intersection("c")).to_dict(
as_series=False
) == {"a": [[1, 2], [1], []], "b": [[2, 1], [2, 1], []]}

assert df.select(pl.col("a", "b").list.set_union("c")).to_dict(as_series=False) == {
"a": [[1, 2, 3], [1, 2, 76], [4, 8, 9]],
"b": [[4, 2, 1], [2, 1, 12, 76], [4, 8, 9]],
}

assert df.select(pl.col("a", "b").list.set_difference("c")).to_dict(
as_series=False
) == {"a": [[3], [], [4]], "b": [[4], [12], [4]]}

assert df.select(pl.col("a", "b").list.set_symmetric_difference("c")).to_dict(
as_series=False
) == {"a": [[3], [2, 76], [4, 8, 9]], "b": [[4], [12, 76], [4, 8, 9]]}

# check logical types
dtype = pl.List(pl.Date)
assert (
Expand Down

0 comments on commit e52a598

Please sign in to comment.