From b3172aa411ebcc2eae8adc49b977fe8f7bbdbd34 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 29 Aug 2024 11:53:33 +0200 Subject: [PATCH] fix: Expr.sign should preserve dtype (#18446) --- .../polars-plan/src/dsl/function_expr/sign.rs | 53 ++++++++----------- py-polars/polars/expr/expr.py | 29 +++++----- py-polars/polars/series/series.py | 27 +++++----- .../map/test_inefficient_map_warning.py | 4 -- py-polars/tests/unit/series/test_series.py | 4 +- py-polars/tests/unit/sql/test_operators.py | 2 +- 6 files changed, 55 insertions(+), 64 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/sign.rs b/crates/polars-plan/src/dsl/function_expr/sign.rs index 41707664e3ac..a7bf4d3277e6 100644 --- a/crates/polars-plan/src/dsl/function_expr/sign.rs +++ b/crates/polars-plan/src/dsl/function_expr/sign.rs @@ -1,41 +1,34 @@ +use num::{One, Zero}; use polars_core::export::num; -use DataType::*; +use polars_core::with_match_physical_numeric_polars_type; use super::*; pub(super) fn sign(s: &Series) -> PolarsResult { - match s.dtype() { - Float32 => { - let ca = s.f32().unwrap(); - sign_float(ca) - }, - Float64 => { - let ca = s.f64().unwrap(); - sign_float(ca) - }, - dt if dt.is_numeric() => { - let s = s.cast(&Float64)?; - sign(&s) - }, - dt => polars_bail!(opq = sign, dt), - } + let dt = s.dtype(); + polars_ensure!(dt.is_numeric(), opq = sign, dt); + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref(); + Ok(sign_impl(ca)) + }) } -fn sign_float(ca: &ChunkedArray) -> PolarsResult +fn sign_impl(ca: &ChunkedArray) -> Series where - T: PolarsFloatType, - T::Native: num::Float, + T: PolarsNumericType, ChunkedArray: IntoSeries, { - ca.apply_values(signum_improved).into_series().cast(&Int64) -} - -// Wrapper for the signum function that handles +/-0.0 inputs differently -// See discussion here: https://github.com/rust-lang/rust/issues/57543 -fn signum_improved(v: F) -> F { - if v.is_zero() { - v - } else { - v.signum() - } + ca.apply_values(|x| { + if x < T::Native::zero() { + T::Native::zero() - T::Native::one() + } else if x > T::Native::zero() { + T::Native::one() + } else { + // Returning x here ensures we return NaN for NaN input, and + // maintain the sign for signed zeroes (although we don't really + // care about the latter). + x + } + }) + .into_series() } diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 54c9ba55e09d..962891d7aa58 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8746,30 +8746,31 @@ def upper_bound(self) -> Expr: def sign(self) -> Expr: """ - Compute the element-wise indication of the sign. + Compute the element-wise sign function on numeric types. - The returned values can be -1, 0, or 1: + The returned value is computed as follows: - * -1 if x < 0. - * 0 if x == 0. - * 1 if x > 0. + * -1 if x < 0. + * 1 if x > 0. + * x otherwise (typically 0, but could be NaN if the input is). - (null values are preserved as-is). + Null values are preserved as-is, and the dtype of the input is preserved. Examples -------- - >>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, None]}) - >>> df.select(pl.col("a").sign()) - shape: (5, 1) + >>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, float("nan"), None]}) + >>> df.select(pl.col.a.sign()) + shape: (6, 1) ┌──────┐ │ a │ │ --- │ - │ i64 │ + │ f64 │ ╞══════╡ - │ -1 │ - │ 0 │ - │ 0 │ - │ 1 │ + │ -1.0 │ + │ -0.0 │ + │ 0.0 │ + │ 1.0 │ + │ NaN │ │ null │ └──────┘ """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 901e424641fa..87e0d38f0b51 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4991,27 +4991,28 @@ def mode(self) -> Series: def sign(self) -> Series: """ - Compute the element-wise indication of the sign. + Compute the element-wise sign function on numeric types. - The returned values can be -1, 0, or 1: + The returned value is computed as follows: - * -1 if x < 0. - * 0 if x == 0. - * 1 if x > 0. + * -1 if x < 0. + * 1 if x > 0. + * x otherwise (typically 0, but could be NaN if the input is). - (null values are preserved as-is). + Null values are preserved as-is, and the dtype of the input is preserved. Examples -------- - >>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None]) + >>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None]) >>> s.sign() - shape: (5,) - Series: 'a' [i64] + shape: (6,) + Series: 'a' [f64] [ - -1 - 0 - 0 - 1 + -1.0 + -0.0 + 0.0 + 1.0 + NaN null ] """ diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 4c6877d08694..57a8ce795dc1 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -343,10 +343,6 @@ def test_parse_apply_raw_functions() -> None: ): df1 = lf.select(pl.col("a").map_elements(func)).collect() df2 = lf.select(getattr(pl.col("a"), func_name)()).collect() - if func_name == "sign": - # note: Polars' 'sign' function returns an Int64, while numpy's - # 'sign' function returns a Float64 - df1 = df1.with_columns(pl.col("a").cast(pl.Int64)) assert_frame_equal(df1, df2) # test bare 'json.loads' diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 08884308af48..3f7f159ccae0 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1747,8 +1747,8 @@ def test_sign() -> None: assert_series_equal(a.sign(), expected) # Floats - a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None]) - expected = pl.Series("a", [-1, 0, 0, 1, None]) + a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None]) + expected = pl.Series("a", [-1.0, 0.0, 0.0, 1.0, float("nan"), None]) assert_series_equal(a.sign(), expected) # Invalid input diff --git a/py-polars/tests/unit/sql/test_operators.py b/py-polars/tests/unit/sql/test_operators.py index 668ead0bc087..278e0776f2d7 100644 --- a/py-polars/tests/unit/sql/test_operators.py +++ b/py-polars/tests/unit/sql/test_operators.py @@ -37,7 +37,7 @@ def test_div() -> None: [ [-0.0995024875621891, 2.85714285714286, 12.0, None, -15.92356687898089], [-1, 2, 12, None, -16], - [-1, 1, 1, None, -1], + [-1.0, 1.0, 1.0, None, -1.0], ], schema=["a_div_b", "a_floordiv_b", "b_sign"], ),