Skip to content

Commit

Permalink
Block rounding durations to negative durations
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-sil committed Mar 21, 2024
1 parent ab8fae6 commit d14a6a8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
3 changes: 3 additions & 0 deletions crates/polars-time/src/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ impl PolarsRound for DateChunked {
#[cfg(feature = "dtype-duration")]
impl PolarsRound for DurationChunked {
fn round(&self, every: Duration, offset: Duration, _tz: Option<&Tz>) -> PolarsResult<Self> {
if every.negative {
polars_bail!(ComputeError: "cannot round a Duration to a negative duration")
}
if !every.is_constant_duration() {
polars_bail!(InvalidOperation: "Cannot round a Duration series to a non-constant duration.");
}
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-time/src/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ impl PolarsTruncate for DurationChunked {
let out = if every.len() == 1 {
if let Some(every) = every.get(0) {
let every_duration = Duration::parse(every);
if every_duration.negative {
polars_bail!(ComputeError: "cannot truncate a Duration to a negative duration")
}
if every_duration.is_constant_duration() {
let every_units = to_time_unit(&every_duration);

Expand All @@ -152,6 +155,9 @@ impl PolarsTruncate for DurationChunked {
try_binary_elementwise(self, every, |opt_duration, opt_every| {
if let (Some(duration), Some(every)) = (opt_duration, opt_every) {
let every_duration = Duration::parse(every);
if every_duration.negative {
polars_bail!(ComputeError: "cannot truncate a Duration to a negative duration")
}
if every_duration.is_constant_duration() {
let every_units = to_time_unit(&every_duration);

Expand Down
24 changes: 19 additions & 5 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ def test_truncate_duration(time_unit: TimeUnit) -> None:
).dt.cast_time_unit(time_unit)

assert_series_equal(durations.dt.truncate("10s"), expected)
assert_series_equal(durations.dt.truncate("-10s"), expected)


def test_truncate_duration_zero() -> None:
Expand Down Expand Up @@ -591,7 +590,8 @@ def test_truncate_negative() -> None:
{
"date": [date(1895, 5, 7), date(1955, 11, 5)],
"datetime": [datetime(1895, 5, 7), datetime(1955, 11, 5)],
"duration": ["-1m", "1m"],
"duration": [timedelta(minutes=1), timedelta(minutes=-1)],
"every": ["-1m", "1m"],
}
)

Expand All @@ -605,15 +605,25 @@ def test_truncate_negative() -> None:
):
df.select(pl.col("datetime").dt.truncate("-1m"))

with pytest.raises(
ComputeError, match="cannot truncate a Duration to a negative duration"
):
df.select(pl.col("duration").dt.truncate("-1m"))

with pytest.raises(
ComputeError, match="cannot truncate a Date to a negative duration"
):
df.select(pl.col("date").dt.truncate(pl.col("duration")))
df.select(pl.col("date").dt.truncate(pl.col("every")))

with pytest.raises(
ComputeError, match="cannot truncate a Datetime to a negative duration"
):
df.select(pl.col("datetime").dt.truncate(pl.col("duration")))
df.select(pl.col("datetime").dt.truncate(pl.col("every")))

with pytest.raises(
ComputeError, match="cannot truncate a Duration to a negative duration"
):
df.select(pl.col("duration").dt.truncate(pl.col("every")))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -673,7 +683,6 @@ def test_round_duration(time_unit: TimeUnit) -> None:
).dt.cast_time_unit(time_unit)

assert_series_equal(durations.dt.round("10s"), expected)
assert_series_equal(durations.dt.round("-10s"), expected)


def test_round_duration_zero() -> None:
Expand Down Expand Up @@ -745,6 +754,11 @@ def test_round_negative() -> None:
):
pl.Series([datetime(1895, 5, 7)]).dt.round("-1m")

with pytest.raises(
ComputeError, match="cannot round a Duration to a negative duration"
):
pl.Series([timedelta(days=1)]).dt.round("-1m")


@pytest.mark.parametrize(
("time_unit", "date_in_that_unit"),
Expand Down

0 comments on commit d14a6a8

Please sign in to comment.