From 8f3f29ecc4b66902c3bf70bb9101ab41b20dc739 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 7 Oct 2023 16:42:53 +0300 Subject: [PATCH] feat(rust, python): add time_unit argument to duration --- crates/polars-plan/src/dsl/functions/mod.rs | 2 +- .../polars-plan/src/dsl/functions/temporal.rs | 69 +++++++++++++------ py-polars/polars/functions/as_datatype.py | 2 + py-polars/src/functions/lazy.rs | 3 + .../tests/unit/functions/test_as_datatype.py | 28 +++++++- 5 files changed, 81 insertions(+), 23 deletions(-) diff --git a/crates/polars-plan/src/dsl/functions/mod.rs b/crates/polars-plan/src/dsl/functions/mod.rs index fd8d8247e339..f95be89c0af6 100644 --- a/crates/polars-plan/src/dsl/functions/mod.rs +++ b/crates/polars-plan/src/dsl/functions/mod.rs @@ -21,7 +21,7 @@ pub use correlation::*; pub use horizontal::*; pub use index::*; #[cfg(feature = "temporal")] -use polars_core::export::arrow::temporal_conversions::NANOSECONDS; +use polars_core::export::arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; #[cfg(feature = "temporal")] use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index e6b866c2ac10..b3cc5f5fe8e1 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -177,6 +177,7 @@ pub struct DurationArgs { pub milliseconds: Expr, pub microseconds: Expr, pub nanoseconds: Expr, + pub time_unit: TimeUnit, } impl Default for DurationArgs { @@ -190,6 +191,7 @@ impl Default for DurationArgs { milliseconds: lit(0), microseconds: lit(0), nanoseconds: lit(0), + time_unit: TimeUnit::Nanoseconds, } } } @@ -258,15 +260,15 @@ pub fn duration(args: DurationArgs) -> Expr { if s.iter().any(|s| s.is_empty()) { return Ok(Some(Series::new_empty( s[0].name(), - &DataType::Duration(TimeUnit::Nanoseconds), + &DataType::Duration(args.time_unit), ))); } let days = s[0].cast(&DataType::Int64).unwrap(); let seconds = s[1].cast(&DataType::Int64).unwrap(); let mut nanoseconds = s[2].cast(&DataType::Int64).unwrap(); - let microseconds = s[3].cast(&DataType::Int64).unwrap(); - let milliseconds = s[4].cast(&DataType::Int64).unwrap(); + let mut microseconds = s[3].cast(&DataType::Int64).unwrap(); + let mut milliseconds = s[4].cast(&DataType::Int64).unwrap(); let minutes = s[5].cast(&DataType::Int64).unwrap(); let hours = s[6].cast(&DataType::Int64).unwrap(); let weeks = s[7].cast(&DataType::Int64).unwrap(); @@ -278,34 +280,59 @@ pub fn duration(args: DurationArgs) -> Expr { (s.len() != max_len && s.get(0).unwrap() != AnyValue::Int64(0)) || s.len() == max_len }; - if nanoseconds.len() != max_len { - nanoseconds = nanoseconds.new_from_index(0, max_len); - } - if condition(µseconds) { - nanoseconds = nanoseconds + (microseconds * 1_000); - } - if condition(&milliseconds) { - nanoseconds = nanoseconds + (milliseconds * 1_000_000); - } + let multiplier = match args.time_unit { + TimeUnit::Nanoseconds => NANOSECONDS, + TimeUnit::Microseconds => MICROSECONDS, + TimeUnit::Milliseconds => MILLISECONDS, + }; + + let mut duration = match args.time_unit { + TimeUnit::Nanoseconds => { + if nanoseconds.len() != max_len { + nanoseconds = nanoseconds.new_from_index(0, max_len); + } + if condition(µseconds) { + nanoseconds = nanoseconds + (microseconds * 1_000); + } + if condition(&milliseconds) { + nanoseconds = nanoseconds + (milliseconds * 1_000_000); + } + nanoseconds + }, + TimeUnit::Microseconds => { + if microseconds.len() != max_len { + microseconds = microseconds.new_from_index(0, max_len); + } + if condition(&milliseconds) { + microseconds = microseconds + (milliseconds * 1_000); + } + microseconds + }, + TimeUnit::Milliseconds => { + if milliseconds.len() != max_len { + milliseconds = milliseconds.new_from_index(0, max_len); + } + milliseconds + }, + }; + if condition(&seconds) { - nanoseconds = nanoseconds + (seconds * NANOSECONDS); + duration = duration + (seconds * multiplier); } if condition(&days) { - nanoseconds = nanoseconds + (days * NANOSECONDS * SECONDS_IN_DAY); + duration = duration + (days * multiplier * SECONDS_IN_DAY); } if condition(&minutes) { - nanoseconds = nanoseconds + minutes * NANOSECONDS * 60; + duration = duration + minutes * multiplier * 60; } if condition(&hours) { - nanoseconds = nanoseconds + hours * NANOSECONDS * 60 * 60; + duration = duration + hours * multiplier * 60 * 60; } if condition(&weeks) { - nanoseconds = nanoseconds + weeks * NANOSECONDS * SECONDS_IN_DAY * 7; + duration = duration + weeks * multiplier * SECONDS_IN_DAY * 7; } - nanoseconds - .cast(&DataType::Duration(TimeUnit::Nanoseconds)) - .map(Some) + duration.cast(&DataType::Duration(args.time_unit)).map(Some) }) as Arc); Expr::AnonymousFunction { @@ -320,7 +347,7 @@ pub fn duration(args: DurationArgs) -> Expr { args.weeks, ], function, - output_type: GetOutput::from_type(DataType::Duration(TimeUnit::Nanoseconds)), + output_type: GetOutput::from_type(DataType::Duration(args.time_unit)), options: FunctionOptions { collect_groups: ApplyOptions::ApplyFlat, input_wildcard_expansion: true, diff --git a/py-polars/polars/functions/as_datatype.py b/py-polars/polars/functions/as_datatype.py index d991131a33d7..1760aae27f6a 100644 --- a/py-polars/polars/functions/as_datatype.py +++ b/py-polars/polars/functions/as_datatype.py @@ -185,6 +185,7 @@ def duration( minutes: Expr | str | int | None = None, hours: Expr | str | int | None = None, weeks: Expr | str | int | None = None, + time_unit: TimeUnit = "ns", ) -> Expr: """ Create polars `Duration` from distinct time components. @@ -294,6 +295,7 @@ def duration( minutes, hours, weeks, + time_unit, ) ) diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 574f824569b9..93017aa529ed 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -288,6 +288,7 @@ pub fn dtype_cols(dtypes: Vec>) -> PyResult { #[allow(clippy::too_many_arguments)] #[pyfunction] +#[pyo3(signature = (days, seconds, nanoseconds, microseconds, milliseconds, minutes, hours, weeks, time_unit))] pub fn duration( days: Option, seconds: Option, @@ -297,6 +298,7 @@ pub fn duration( minutes: Option, hours: Option, weeks: Option, + time_unit: Wrap, ) -> PyExpr { set_unwrapped_or_0!( days, @@ -317,6 +319,7 @@ pub fn duration( minutes, hours, weeks, + time_unit: time_unit.0, }; dsl::duration(args).into() } diff --git a/py-polars/tests/unit/functions/test_as_datatype.py b/py-polars/tests/unit/functions/test_as_datatype.py index 6a92f0effd4f..f0d3b333a9b9 100644 --- a/py-polars/tests/unit/functions/test_as_datatype.py +++ b/py-polars/tests/unit/functions/test_as_datatype.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import date, datetime +from datetime import date, datetime, timedelta from typing import TYPE_CHECKING import pytest @@ -100,6 +100,32 @@ def test_empty_duration() -> None: assert s.shape == (0, 1) +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("ms", timedelta(days=1, minutes=2, seconds=3, milliseconds=4)), + ("us", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)), + ("ns", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)), + ], +) +def test_duration_time_units(time_unit: TimeUnit, expected: timedelta) -> None: + result = pl.LazyFrame().select( + pl.duration( + days=1, + minutes=2, + seconds=3, + milliseconds=4, + microseconds=5, + nanoseconds=6, + time_unit=time_unit, + ) + ) + assert result.schema["duration"] == pl.Duration(time_unit) + assert result.collect()["duration"].item() == expected + if time_unit == "ns": + assert result.collect()["duration"].dt.nanoseconds().item() == 86523004005006 + + def test_list_concat() -> None: s0 = pl.Series("a", [[1, 2]]) s1 = pl.Series("b", [[3, 4, 5]])