Skip to content

Commit

Permalink
feat: Always preserve sorted flag for .dt.date (#18692)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Sep 15, 2024
1 parent 766e8e5 commit 5a262db
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 37 deletions.
36 changes: 15 additions & 21 deletions crates/polars-plan/src/dsl/function_expr/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub(super) fn time(s: &Column) -> PolarsResult<Column> {
pub(super) fn date(s: &Column) -> PolarsResult<Column> {
match s.dtype() {
#[cfg(feature = "timezones")]
DataType::Datetime(_, Some(tz)) => {
DataType::Datetime(_, Some(_)) => {
let mut out = {
polars_ops::chunked_array::replace_time_zone(
s.datetime().unwrap(),
Expand All @@ -279,10 +279,12 @@ pub(super) fn date(s: &Column) -> PolarsResult<Column> {
)?
.cast(&DataType::Date)?
};
if tz != "UTC" {
// DST transitions may not preserve sortedness.
out.set_sorted_flag(IsSorted::Not);
}
// `replace_time_zone` may unset the sorted flag. However, we only keep the date
// part of the result, so it is safe to preserve the input's sorted flag here. An
// exception may be needed if a time zone ever introduces a transition that
// "goes back in time" far enough to repeat a day, but we're not aware of that
// ever having happened.
out.set_sorted_flag(s.is_sorted_flag());
Ok(out.into())
},
DataType::Datetime(_, _) => s
Expand All @@ -297,22 +299,14 @@ pub(super) fn date(s: &Column) -> PolarsResult<Column> {
pub(super) fn datetime(s: &Column) -> PolarsResult<Column> {
match s.dtype() {
#[cfg(feature = "timezones")]
DataType::Datetime(tu, Some(tz)) => {
let mut out = {
polars_ops::chunked_array::replace_time_zone(
s.datetime().unwrap(),
None,
&StringChunked::from_iter(std::iter::once("raise")),
NonExistent::Raise,
)?
.cast(&DataType::Datetime(*tu, None))?
};
if tz != "UTC" {
// DST transitions may not preserve sortedness.
out.set_sorted_flag(IsSorted::Not);
}
Ok(out.into())
},
DataType::Datetime(tu, Some(_)) => polars_ops::chunked_array::replace_time_zone(
s.datetime().unwrap(),
None,
&StringChunked::from_iter(std::iter::once("raise")),
NonExistent::Raise,
)?
.cast(&DataType::Datetime(*tu, None))
.map(|x| x.into()),
DataType::Datetime(tu, _) => s
.datetime()
.unwrap()
Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,6 @@ def test_replace_time_zone_ambiguous_raises() -> None:
("from_tz", "expected_sortedness", "ambiguous"),
[
("Europe/London", False, "earliest"),
("Europe/London", False, "raise"),
("UTC", True, "earliest"),
("UTC", True, "raise"),
(None, True, "earliest"),
Expand All @@ -1375,20 +1374,20 @@ def test_replace_time_zone_sortedness_series(
from_tz: str | None, expected_sortedness: bool, ambiguous: Ambiguous
) -> None:
ser = (
pl.Series("ts", [1603584000000000, 1603587600000000])
pl.Series("ts", [1603584000000001, 1603587600000000])
.cast(pl.Datetime("us", from_tz))
.sort()
)
assert ser.flags["SORTED_ASC"]
result = ser.dt.replace_time_zone("UTC", ambiguous=ambiguous)
assert result.flags["SORTED_ASC"] == expected_sortedness
assert result.flags["SORTED_ASC"] == result.is_sorted()


@pytest.mark.parametrize(
("from_tz", "expected_sortedness", "ambiguous"),
[
("Europe/London", False, "earliest"),
("Europe/London", False, "raise"),
("UTC", True, "earliest"),
("UTC", True, "raise"),
(None, True, "earliest"),
Expand All @@ -1399,7 +1398,7 @@ def test_replace_time_zone_sortedness_expressions(
from_tz: str | None, expected_sortedness: bool, ambiguous: str
) -> None:
df = (
pl.Series("ts", [1603584000000000, 1603584060000000, 1603587600000000])
pl.Series("ts", [1603584000000001, 1603584060000000, 1603587600000000])
.cast(pl.Datetime("us", from_tz))
.sort()
.to_frame()
Expand All @@ -1410,6 +1409,7 @@ def test_replace_time_zone_sortedness_expressions(
pl.col("ts").dt.replace_time_zone("UTC", ambiguous=pl.col("ambiguous"))
)
assert result["ts"].flags["SORTED_ASC"] == expected_sortedness
assert result["ts"].is_sorted() == expected_sortedness


def test_invalid_ambiguous_value_in_expression() -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,26 +125,19 @@ def test_dt_datetime_deprecated() -> None:
assert result.item() == expected


@pytest.mark.parametrize(
("time_zone", "expected"),
[
(None, True),
("Asia/Kathmandu", False),
("UTC", True),
],
)
def test_local_date_sortedness(time_zone: str | None, expected: bool) -> None:
# singleton - always sorted
@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu", "UTC"])
def test_local_date_sortedness(time_zone: str | None) -> None:
# singleton
ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort()
result = ser.dt.date()
assert result.flags["SORTED_ASC"]

# 2 elements - depends on time zone
# 2 elements
ser = (
pl.Series([datetime(2022, 1, 1, 23)] * 2).dt.replace_time_zone(time_zone)
).sort()
result = ser.dt.date()
assert result.flags["SORTED_ASC"] >= expected
assert result.flags["SORTED_ASC"]


@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu", "UTC"])
Expand Down

0 comments on commit 5a262db

Please sign in to comment.