Skip to content

Commit

Permalink
feat(rust, python): add time_unit argument to duration
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Oct 7, 2023
1 parent c01d599 commit 8f3f29e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 23 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub use correlation::*;
pub use horizontal::*;
pub use index::*;
#[cfg(feature = "temporal")]
use polars_core::export::arrow::temporal_conversions::NANOSECONDS;
use polars_core::export::arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
#[cfg(feature = "temporal")]
use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
#[cfg(feature = "dtype-struct")]
Expand Down
69 changes: 48 additions & 21 deletions crates/polars-plan/src/dsl/functions/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ pub struct DurationArgs {
pub milliseconds: Expr,
pub microseconds: Expr,
pub nanoseconds: Expr,
pub time_unit: TimeUnit,
}

impl Default for DurationArgs {
Expand All @@ -190,6 +191,7 @@ impl Default for DurationArgs {
milliseconds: lit(0),
microseconds: lit(0),
nanoseconds: lit(0),
time_unit: TimeUnit::Nanoseconds,
}
}
}
Expand Down Expand Up @@ -258,15 +260,15 @@ pub fn duration(args: DurationArgs) -> Expr {
if s.iter().any(|s| s.is_empty()) {
return Ok(Some(Series::new_empty(
s[0].name(),
&DataType::Duration(TimeUnit::Nanoseconds),
&DataType::Duration(args.time_unit),
)));
}

let days = s[0].cast(&DataType::Int64).unwrap();
let seconds = s[1].cast(&DataType::Int64).unwrap();
let mut nanoseconds = s[2].cast(&DataType::Int64).unwrap();
let microseconds = s[3].cast(&DataType::Int64).unwrap();
let milliseconds = s[4].cast(&DataType::Int64).unwrap();
let mut microseconds = s[3].cast(&DataType::Int64).unwrap();
let mut milliseconds = s[4].cast(&DataType::Int64).unwrap();
let minutes = s[5].cast(&DataType::Int64).unwrap();
let hours = s[6].cast(&DataType::Int64).unwrap();
let weeks = s[7].cast(&DataType::Int64).unwrap();
Expand All @@ -278,34 +280,59 @@ pub fn duration(args: DurationArgs) -> Expr {
(s.len() != max_len && s.get(0).unwrap() != AnyValue::Int64(0)) || s.len() == max_len
};

if nanoseconds.len() != max_len {
nanoseconds = nanoseconds.new_from_index(0, max_len);
}
if condition(&microseconds) {
nanoseconds = nanoseconds + (microseconds * 1_000);
}
if condition(&milliseconds) {
nanoseconds = nanoseconds + (milliseconds * 1_000_000);
}
let multiplier = match args.time_unit {
TimeUnit::Nanoseconds => NANOSECONDS,
TimeUnit::Microseconds => MICROSECONDS,
TimeUnit::Milliseconds => MILLISECONDS,
};

let mut duration = match args.time_unit {
TimeUnit::Nanoseconds => {
if nanoseconds.len() != max_len {
nanoseconds = nanoseconds.new_from_index(0, max_len);
}
if condition(&microseconds) {
nanoseconds = nanoseconds + (microseconds * 1_000);
}
if condition(&milliseconds) {
nanoseconds = nanoseconds + (milliseconds * 1_000_000);
}
nanoseconds
},
TimeUnit::Microseconds => {
if microseconds.len() != max_len {
microseconds = microseconds.new_from_index(0, max_len);
}
if condition(&milliseconds) {
microseconds = microseconds + (milliseconds * 1_000);
}
microseconds
},
TimeUnit::Milliseconds => {
if milliseconds.len() != max_len {
milliseconds = milliseconds.new_from_index(0, max_len);
}
milliseconds
},
};

if condition(&seconds) {
nanoseconds = nanoseconds + (seconds * NANOSECONDS);
duration = duration + (seconds * multiplier);
}
if condition(&days) {
nanoseconds = nanoseconds + (days * NANOSECONDS * SECONDS_IN_DAY);
duration = duration + (days * multiplier * SECONDS_IN_DAY);
}
if condition(&minutes) {
nanoseconds = nanoseconds + minutes * NANOSECONDS * 60;
duration = duration + minutes * multiplier * 60;
}
if condition(&hours) {
nanoseconds = nanoseconds + hours * NANOSECONDS * 60 * 60;
duration = duration + hours * multiplier * 60 * 60;
}
if condition(&weeks) {
nanoseconds = nanoseconds + weeks * NANOSECONDS * SECONDS_IN_DAY * 7;
duration = duration + weeks * multiplier * SECONDS_IN_DAY * 7;
}

nanoseconds
.cast(&DataType::Duration(TimeUnit::Nanoseconds))
.map(Some)
duration.cast(&DataType::Duration(args.time_unit)).map(Some)
}) as Arc<dyn SeriesUdf>);

Expr::AnonymousFunction {
Expand All @@ -320,7 +347,7 @@ pub fn duration(args: DurationArgs) -> Expr {
args.weeks,
],
function,
output_type: GetOutput::from_type(DataType::Duration(TimeUnit::Nanoseconds)),
output_type: GetOutput::from_type(DataType::Duration(args.time_unit)),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/functions/as_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def duration(
minutes: Expr | str | int | None = None,
hours: Expr | str | int | None = None,
weeks: Expr | str | int | None = None,
time_unit: TimeUnit = "ns",
) -> Expr:
"""
Create polars `Duration` from distinct time components.
Expand Down Expand Up @@ -294,6 +295,7 @@ def duration(
minutes,
hours,
weeks,
time_unit,
)
)

Expand Down
3 changes: 3 additions & 0 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ pub fn dtype_cols(dtypes: Vec<Wrap<DataType>>) -> PyResult<PyExpr> {

#[allow(clippy::too_many_arguments)]
#[pyfunction]
#[pyo3(signature = (days, seconds, nanoseconds, microseconds, milliseconds, minutes, hours, weeks, time_unit))]
pub fn duration(
days: Option<PyExpr>,
seconds: Option<PyExpr>,
Expand All @@ -297,6 +298,7 @@ pub fn duration(
minutes: Option<PyExpr>,
hours: Option<PyExpr>,
weeks: Option<PyExpr>,
time_unit: Wrap<TimeUnit>,
) -> PyExpr {
set_unwrapped_or_0!(
days,
Expand All @@ -317,6 +319,7 @@ pub fn duration(
minutes,
hours,
weeks,
time_unit: time_unit.0,
};
dsl::duration(args).into()
}
Expand Down
28 changes: 27 additions & 1 deletion py-polars/tests/unit/functions/test_as_datatype.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from datetime import date, datetime
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING

import pytest
Expand Down Expand Up @@ -100,6 +100,32 @@ def test_empty_duration() -> None:
assert s.shape == (0, 1)


@pytest.mark.parametrize(
("time_unit", "expected"),
[
("ms", timedelta(days=1, minutes=2, seconds=3, milliseconds=4)),
("us", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)),
("ns", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)),
],
)
def test_duration_time_units(time_unit: TimeUnit, expected: timedelta) -> None:
result = pl.LazyFrame().select(
pl.duration(
days=1,
minutes=2,
seconds=3,
milliseconds=4,
microseconds=5,
nanoseconds=6,
time_unit=time_unit,
)
)
assert result.schema["duration"] == pl.Duration(time_unit)
assert result.collect()["duration"].item() == expected
if time_unit == "ns":
assert result.collect()["duration"].dt.nanoseconds().item() == 86523004005006


def test_list_concat() -> None:
s0 = pl.Series("a", [[1, 2]])
s1 = pl.Series("b", [[3, 4, 5]])
Expand Down

0 comments on commit 8f3f29e

Please sign in to comment.