Skip to content

Commit

Permalink
fix: Properly calculate duration units
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Sep 23, 2024
1 parent ea7953e commit 8b9d3a7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
15 changes: 10 additions & 5 deletions crates/polars-ops/src/series/ops/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,24 @@ pub fn impl_duration(s: &[Column], time_unit: TimeUnit) -> PolarsResult<Column>
TimeUnit::Milliseconds => MILLISECONDS,
};
if !is_zero_scalar(&seconds) {
duration = ((duration + seconds)? * multiplier)?;
let units = seconds * multiplier;
duration = (duration + units?)?;
}
if !is_zero_scalar(&minutes) {
duration = ((duration + minutes)? * (multiplier * 60))?;
let units = minutes * (multiplier * 60);
duration = (duration + units?)?;
}
if !is_zero_scalar(&hours) {
duration = ((duration + hours)? * (multiplier * 60 * 60))?;
let units = hours * (multiplier * 60 * 60);
duration = (duration + units?)?;
}
if !is_zero_scalar(&days) {
duration = ((duration + days)? * (multiplier * SECONDS_IN_DAY))?;
let units = days * (multiplier * SECONDS_IN_DAY);
duration = (duration + units?)?;
}
if !is_zero_scalar(&weeks) {
duration = ((duration + weeks)? * (multiplier * SECONDS_IN_DAY * 7))?;
let units = weeks * (multiplier * SECONDS_IN_DAY * 7);
duration = (duration + units?)?;
}

duration
Expand Down
33 changes: 33 additions & 0 deletions py-polars/tests/unit/datatypes/test_duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,36 @@ def test_series_duration_var_overflow() -> None:
s = pl.Series([timedelta(days=10), timedelta(days=20), timedelta(days=40)])
with pytest.raises(PanicException, match="OverflowError"):
s.var()


def test_series_duration_units() -> None:
td = timedelta

assert_frame_equal(
pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(weeks=pl.col("x"))),
pl.DataFrame({"x": [td(weeks=i) for i in range(4)]}),
)
assert_frame_equal(
pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(days=pl.col("x"))),
pl.DataFrame({"x": [td(days=i) for i in range(4)]}),
)
assert_frame_equal(
pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(hours=pl.col("x"))),
pl.DataFrame({"x": [td(hours=i) for i in range(4)]}),
)
assert_frame_equal(
pl.DataFrame({"x": [0, 1, 2, 3]}).select(x=pl.duration(minutes=pl.col("x"))),
pl.DataFrame({"x": [td(minutes=i) for i in range(4)]}),
)
assert_frame_equal(
pl.DataFrame({"x": [0, 1, 2, 3]}).select(
x=pl.duration(milliseconds=pl.col("x"))
),
pl.DataFrame({"x": [td(milliseconds=i) for i in range(4)]}),
)
assert_frame_equal(
pl.DataFrame({"x": [0, 1, 2, 3]}).select(
x=pl.duration(microseconds=pl.col("x"))
),
pl.DataFrame({"x": [td(microseconds=i) for i in range(4)]}),
)

0 comments on commit 8b9d3a7

Please sign in to comment.