From 7dfc53e9a6e75516a69b81bd014737bc0569099e Mon Sep 17 00:00:00 2001 From: Lava <34743145+CanglongCl@users.noreply.github.com> Date: Sat, 6 Apr 2024 05:17:01 -0700 Subject: [PATCH] fix(rust, python): `pow` return type evaluation (#15506) --- crates/polars-core/src/utils/mod.rs | 17 +++ .../polars-plan/src/dsl/function_expr/pow.rs | 107 ++++++++---------- .../src/dsl/function_expr/schema.rs | 18 +-- py-polars/polars/dataframe/frame.py | 60 +++++----- py-polars/polars/expr/expr.py | 32 +++--- py-polars/polars/functions/lazy.py | 6 +- py-polars/polars/lazyframe/frame.py | 60 +++++----- py-polars/polars/series/series.py | 28 ++--- .../map/test_inefficient_map_warning.py | 4 +- py-polars/tests/unit/series/test_series.py | 40 +++++-- py-polars/tests/unit/sql/test_functions.py | 2 +- 11 files changed, 197 insertions(+), 177 deletions(-) diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index ef864b4a226b..90e23998f07c 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -319,11 +319,15 @@ macro_rules! with_match_physical_integer_type {( macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} use $crate::datatypes::DataType::*; match $dtype { + #[cfg(feature = "dtype-i8")] Int8 => __with_ty__! { i8 }, + #[cfg(feature = "dtype-i16")] Int16 => __with_ty__! { i16 }, Int32 => __with_ty__! { i32 }, Int64 => __with_ty__! { i64 }, + #[cfg(feature = "dtype-u8")] UInt8 => __with_ty__! { u8 }, + #[cfg(feature = "dtype-u16")] UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, UInt64 => __with_ty__! { u64 }, @@ -331,6 +335,19 @@ macro_rules! with_match_physical_integer_type {( } })} +#[macro_export] +macro_rules! with_match_physical_float_type {( + $dtype:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + match $dtype { + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + dt => panic!("not implemented for dtype {:?}", dt), + } +})} + #[macro_export] macro_rules! with_match_physical_float_polars_type {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* diff --git a/crates/polars-plan/src/dsl/function_expr/pow.rs b/crates/polars-plan/src/dsl/function_expr/pow.rs index d46bf20043bb..1016e63ec487 100644 --- a/crates/polars-plan/src/dsl/function_expr/pow.rs +++ b/crates/polars-plan/src/dsl/function_expr/pow.rs @@ -2,6 +2,7 @@ use arrow::legacy::kernels::pow::pow as pow_kernel; use num::pow::Pow; use polars_core::export::num; use polars_core::export::num::{Float, ToPrimitive}; +use polars_core::with_match_physical_integer_type; use super::*; @@ -128,65 +129,53 @@ where fn pow_on_series(base: &Series, exponent: &Series) -> PolarsResult> { use DataType::*; - match (base.dtype(), exponent.dtype()) { - #[cfg(feature = "dtype-u8")] - (UInt8, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.u8().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - #[cfg(feature = "dtype-i8")] - (Int8, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.i8().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - #[cfg(feature = "dtype-u16")] - (UInt16, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.u16().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - #[cfg(feature = "dtype-i16")] - (Int16, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.i16().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - (UInt32, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.u32().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - (Int32, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.i32().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - (UInt64, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.u64().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - (Int64, UInt8 | UInt16 | UInt32 | UInt64) => { - let ca = base.i64().unwrap(); - let exponent = exponent.strict_cast(&DataType::UInt32)?; - pow_to_uint_dtype(ca, exponent.u32().unwrap()) - }, - (Float32, _) => { - let ca = base.f32().unwrap(); - let exponent = exponent.strict_cast(&DataType::Float32)?; - pow_on_floats(ca, exponent.f32().unwrap()) - }, - (Float64, _) => { - let ca = base.f64().unwrap(); - let exponent = exponent.strict_cast(&DataType::Float64)?; - pow_on_floats(ca, exponent.f64().unwrap()) - }, - _ => { - let base = base.cast(&DataType::Float64)?; - pow_on_series(&base, exponent) - }, + + let base_dtype = base.dtype(); + polars_ensure!( + base_dtype.is_numeric(), + InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype + ); + let exponent_dtype = exponent.dtype(); + polars_ensure!( + exponent_dtype.is_numeric(), + InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype + ); + + // if false, dtype is float + if base_dtype.is_integer() { + with_match_physical_integer_type!(base_dtype, |$native_type| { + if exponent_dtype.is_float() { + match exponent_dtype { + Float32 => { + let ca = base.cast(&DataType::Float32)?; + pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap()) + }, + Float64 => { + let ca = base.cast(&DataType::Float64)?; + pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap()) + }, + _ => unreachable!(), + } + } else { + let ca = base.$native_type().unwrap(); + let exponent = exponent.strict_cast(&DataType::UInt32)?; + pow_to_uint_dtype(ca, exponent.u32().unwrap()) + } + }) + } else { + match base_dtype { + Float32 => { + let ca = base.f32().unwrap(); + let exponent = exponent.strict_cast(&DataType::Float32)?; + pow_on_floats(ca, exponent.f32().unwrap()) + }, + Float64 => { + let ca = base.f64().unwrap(); + let exponent = exponent.strict_cast(&DataType::Float64)?; + pow_on_floats(ca, exponent.f64().unwrap()) + }, + _ => unreachable!(), + } } } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 4d4652f3ea88..820a8f9b4199 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -466,14 +466,16 @@ impl<'a> FieldsMapper<'a> { } pub(super) fn pow_dtype(&self) -> PolarsResult { - // base, exponent - match (self.fields[0].data_type(), self.fields[1].data_type()) { - ( - base_dtype, - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64, - ) => Ok(Field::new(self.fields[0].name(), base_dtype.clone())), - (DataType::Float32, _) => Ok(Field::new(self.fields[0].name(), DataType::Float32)), - (_, _) => Ok(Field::new(self.fields[0].name(), DataType::Float64)), + let base_dtype = self.fields[0].data_type(); + let exponent_dtype = self.fields[1].data_type(); + if base_dtype.is_integer() { + if exponent_dtype.is_float() { + Ok(Field::new(self.fields[0].name(), exponent_dtype.clone())) + } else { + Ok(Field::new(self.fields[0].name(), base_dtype.clone())) + } + } else { + Ok(Field::new(self.fields[0].name(), base_dtype.clone())) } } diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 8b9caf5a3a85..1d6ebae09259 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -8192,16 +8192,16 @@ def with_columns( ... ) >>> df.with_columns((pl.col("a") ** 2).alias("a^2")) shape: (4, 4) - ┌─────┬──────┬───────┬──────┐ - │ a ┆ b ┆ c ┆ a^2 │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ bool ┆ f64 │ - ╞═════╪══════╪═══════╪══════╡ - │ 1 ┆ 0.5 ┆ true ┆ 1.0 │ - │ 2 ┆ 4.0 ┆ true ┆ 4.0 │ - │ 3 ┆ 10.0 ┆ false ┆ 9.0 │ - │ 4 ┆ 13.0 ┆ true ┆ 16.0 │ - └─────┴──────┴───────┴──────┘ + ┌─────┬──────┬───────┬─────┐ + │ a ┆ b ┆ c ┆ a^2 │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ bool ┆ i64 │ + ╞═════╪══════╪═══════╪═════╡ + │ 1 ┆ 0.5 ┆ true ┆ 1 │ + │ 2 ┆ 4.0 ┆ true ┆ 4 │ + │ 3 ┆ 10.0 ┆ false ┆ 9 │ + │ 4 ┆ 13.0 ┆ true ┆ 16 │ + └─────┴──────┴───────┴─────┘ Added columns will replace existing columns with the same name. @@ -8228,16 +8228,16 @@ def with_columns( ... ] ... ) shape: (4, 6) - ┌─────┬──────┬───────┬──────┬──────┬───────┐ - │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ bool ┆ f64 ┆ f64 ┆ bool │ - ╞═════╪══════╪═══════╪══════╪══════╪═══════╡ - │ 1 ┆ 0.5 ┆ true ┆ 1.0 ┆ 0.25 ┆ false │ - │ 2 ┆ 4.0 ┆ true ┆ 4.0 ┆ 2.0 ┆ false │ - │ 3 ┆ 10.0 ┆ false ┆ 9.0 ┆ 5.0 ┆ true │ - │ 4 ┆ 13.0 ┆ true ┆ 16.0 ┆ 6.5 ┆ false │ - └─────┴──────┴───────┴──────┴──────┴───────┘ + ┌─────┬──────┬───────┬─────┬──────┬───────┐ + │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ bool ┆ i64 ┆ f64 ┆ bool │ + ╞═════╪══════╪═══════╪═════╪══════╪═══════╡ + │ 1 ┆ 0.5 ┆ true ┆ 1 ┆ 0.25 ┆ false │ + │ 2 ┆ 4.0 ┆ true ┆ 4 ┆ 2.0 ┆ false │ + │ 3 ┆ 10.0 ┆ false ┆ 9 ┆ 5.0 ┆ true │ + │ 4 ┆ 13.0 ┆ true ┆ 16 ┆ 6.5 ┆ false │ + └─────┴──────┴───────┴─────┴──────┴───────┘ Multiple columns also can be added using positional arguments instead of a list. @@ -8247,16 +8247,16 @@ def with_columns( ... (pl.col("c").not_()).alias("not c"), ... ) shape: (4, 6) - ┌─────┬──────┬───────┬──────┬──────┬───────┐ - │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ bool ┆ f64 ┆ f64 ┆ bool │ - ╞═════╪══════╪═══════╪══════╪══════╪═══════╡ - │ 1 ┆ 0.5 ┆ true ┆ 1.0 ┆ 0.25 ┆ false │ - │ 2 ┆ 4.0 ┆ true ┆ 4.0 ┆ 2.0 ┆ false │ - │ 3 ┆ 10.0 ┆ false ┆ 9.0 ┆ 5.0 ┆ true │ - │ 4 ┆ 13.0 ┆ true ┆ 16.0 ┆ 6.5 ┆ false │ - └─────┴──────┴───────┴──────┴──────┴───────┘ + ┌─────┬──────┬───────┬─────┬──────┬───────┐ + │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ bool ┆ i64 ┆ f64 ┆ bool │ + ╞═════╪══════╪═══════╪═════╪══════╪═══════╡ + │ 1 ┆ 0.5 ┆ true ┆ 1 ┆ 0.25 ┆ false │ + │ 2 ┆ 4.0 ┆ true ┆ 4 ┆ 2.0 ┆ false │ + │ 3 ┆ 10.0 ┆ false ┆ 9 ┆ 5.0 ┆ true │ + │ 4 ┆ 13.0 ┆ true ┆ 16 ┆ 6.5 ┆ false │ + └─────┴──────┴───────┴─────┴──────┴───────┘ Use keyword arguments to easily name your expression inputs. diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 617459587ab8..75237213df47 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -5330,16 +5330,16 @@ def pow(self, exponent: IntoExprColumn | int | float) -> Self: ... pl.col("x").pow(pl.col("x").log(2)).alias("x ** xlog2"), ... ) shape: (4, 3) - ┌─────┬───────┬────────────┐ - │ x ┆ cube ┆ x ** xlog2 │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ f64 │ - ╞═════╪═══════╪════════════╡ - │ 1 ┆ 1.0 ┆ 1.0 │ - │ 2 ┆ 8.0 ┆ 2.0 │ - │ 4 ┆ 64.0 ┆ 16.0 │ - │ 8 ┆ 512.0 ┆ 512.0 │ - └─────┴───────┴────────────┘ + ┌─────┬──────┬────────────┐ + │ x ┆ cube ┆ x ** xlog2 │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ f64 │ + ╞═════╪══════╪════════════╡ + │ 1 ┆ 1 ┆ 1.0 │ + │ 2 ┆ 8 ┆ 2.0 │ + │ 4 ┆ 64 ┆ 16.0 │ + │ 8 ┆ 512 ┆ 512.0 │ + └─────┴──────┴────────────┘ """ return self.__pow__(exponent) @@ -9185,13 +9185,13 @@ def cumulative_eval( ┌────────┐ │ values │ │ --- │ - │ f64 │ + │ i64 │ ╞════════╡ - │ 0.0 │ - │ -3.0 │ - │ -8.0 │ - │ -15.0 │ - │ -24.0 │ + │ 0 │ + │ -3 │ + │ -8 │ + │ -15 │ + │ -24 │ └────────┘ """ return self._from_pyexpr( diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index ecb884f942b1..9c135062dd6c 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -2210,10 +2210,10 @@ def sql_expr(sql: str | Sequence[str]) -> Expr | list[Expr]: ┌─────┬─────┬───────┐ │ a ┆ a_a ┆ a_txt │ │ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ str │ + │ i64 ┆ i64 ┆ str │ ╞═════╪═════╪═══════╡ - │ 2 ┆ 4.0 ┆ 2 │ - │ 1 ┆ 1.0 ┆ 1 │ + │ 2 ┆ 4 ┆ 2 │ + │ 1 ┆ 1 ┆ 1 │ └─────┴─────┴───────┘ """ if isinstance(sql, str): diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index f96adfd0bbc4..805b96c2e415 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4164,16 +4164,16 @@ def with_columns( ... ) >>> lf.with_columns((pl.col("a") ** 2).alias("a^2")).collect() shape: (4, 4) - ┌─────┬──────┬───────┬──────┐ - │ a ┆ b ┆ c ┆ a^2 │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ bool ┆ f64 │ - ╞═════╪══════╪═══════╪══════╡ - │ 1 ┆ 0.5 ┆ true ┆ 1.0 │ - │ 2 ┆ 4.0 ┆ true ┆ 4.0 │ - │ 3 ┆ 10.0 ┆ false ┆ 9.0 │ - │ 4 ┆ 13.0 ┆ true ┆ 16.0 │ - └─────┴──────┴───────┴──────┘ + ┌─────┬──────┬───────┬─────┐ + │ a ┆ b ┆ c ┆ a^2 │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ bool ┆ i64 │ + ╞═════╪══════╪═══════╪═════╡ + │ 1 ┆ 0.5 ┆ true ┆ 1 │ + │ 2 ┆ 4.0 ┆ true ┆ 4 │ + │ 3 ┆ 10.0 ┆ false ┆ 9 │ + │ 4 ┆ 13.0 ┆ true ┆ 16 │ + └─────┴──────┴───────┴─────┘ Added columns will replace existing columns with the same name. @@ -4200,16 +4200,16 @@ def with_columns( ... ] ... ).collect() shape: (4, 6) - ┌─────┬──────┬───────┬──────┬──────┬───────┐ - │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ bool ┆ f64 ┆ f64 ┆ bool │ - ╞═════╪══════╪═══════╪══════╪══════╪═══════╡ - │ 1 ┆ 0.5 ┆ true ┆ 1.0 ┆ 0.25 ┆ false │ - │ 2 ┆ 4.0 ┆ true ┆ 4.0 ┆ 2.0 ┆ false │ - │ 3 ┆ 10.0 ┆ false ┆ 9.0 ┆ 5.0 ┆ true │ - │ 4 ┆ 13.0 ┆ true ┆ 16.0 ┆ 6.5 ┆ false │ - └─────┴──────┴───────┴──────┴──────┴───────┘ + ┌─────┬──────┬───────┬─────┬──────┬───────┐ + │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ bool ┆ i64 ┆ f64 ┆ bool │ + ╞═════╪══════╪═══════╪═════╪══════╪═══════╡ + │ 1 ┆ 0.5 ┆ true ┆ 1 ┆ 0.25 ┆ false │ + │ 2 ┆ 4.0 ┆ true ┆ 4 ┆ 2.0 ┆ false │ + │ 3 ┆ 10.0 ┆ false ┆ 9 ┆ 5.0 ┆ true │ + │ 4 ┆ 13.0 ┆ true ┆ 16 ┆ 6.5 ┆ false │ + └─────┴──────┴───────┴─────┴──────┴───────┘ Multiple columns also can be added using positional arguments instead of a list. @@ -4219,16 +4219,16 @@ def with_columns( ... (pl.col("c").not_()).alias("not c"), ... ).collect() shape: (4, 6) - ┌─────┬──────┬───────┬──────┬──────┬───────┐ - │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ f64 ┆ bool ┆ f64 ┆ f64 ┆ bool │ - ╞═════╪══════╪═══════╪══════╪══════╪═══════╡ - │ 1 ┆ 0.5 ┆ true ┆ 1.0 ┆ 0.25 ┆ false │ - │ 2 ┆ 4.0 ┆ true ┆ 4.0 ┆ 2.0 ┆ false │ - │ 3 ┆ 10.0 ┆ false ┆ 9.0 ┆ 5.0 ┆ true │ - │ 4 ┆ 13.0 ┆ true ┆ 16.0 ┆ 6.5 ┆ false │ - └─────┴──────┴───────┴──────┴──────┴───────┘ + ┌─────┬──────┬───────┬─────┬──────┬───────┐ + │ a ┆ b ┆ c ┆ a^2 ┆ b/2 ┆ not c │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ bool ┆ i64 ┆ f64 ┆ bool │ + ╞═════╪══════╪═══════╪═════╪══════╪═══════╡ + │ 1 ┆ 0.5 ┆ true ┆ 1 ┆ 0.25 ┆ false │ + │ 2 ┆ 4.0 ┆ true ┆ 4 ┆ 2.0 ┆ false │ + │ 3 ┆ 10.0 ┆ false ┆ 9 ┆ 5.0 ┆ true │ + │ 4 ┆ 13.0 ┆ true ┆ 16 ┆ 6.5 ┆ false │ + └─────┴──────┴───────┴─────┴──────┴───────┘ Use keyword arguments to easily name your expression inputs. diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 74834ecde8c8..a88b571b9364 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1167,9 +1167,6 @@ def __pow__(self, exponent: int | float | Series) -> Series: return self.pow(exponent) def __rpow__(self, other: Any) -> Series: - if self.dtype.is_temporal(): - msg = "first cast to integer before raising datelike dtypes to a power" - raise TypeError(msg) return self.to_frame().select_seq(other ** F.col(self.name)).to_series() def __matmul__(self, other: Any) -> float | Series | None: @@ -1957,17 +1954,14 @@ def pow(self, exponent: int | float | Series) -> Series: >>> s = pl.Series("foo", [1, 2, 3, 4]) >>> s.pow(3) shape: (4,) - Series: 'foo' [f64] + Series: 'foo' [i64] [ - 1.0 - 8.0 - 27.0 - 64.0 + 1 + 8 + 27 + 64 ] """ - if self.dtype.is_temporal(): - msg = "first cast to integer before raising datelike dtypes to a power" - raise TypeError(msg) if _check_for_numpy(exponent) and isinstance(exponent, np.ndarray): exponent = Series(exponent) return self.to_frame().select_seq(F.col(self.name).pow(exponent)).to_series() @@ -2828,13 +2822,13 @@ def cumulative_eval( >>> s = pl.Series("values", [1, 2, 3, 4, 5]) >>> s.cumulative_eval(pl.element().first() - pl.element().last() ** 2) shape: (5,) - Series: 'values' [f64] + Series: 'values' [i64] [ - 0.0 - -3.0 - -8.0 - -15.0 - -24.0 + 0 + -3 + -8 + -15 + -24 ] """ diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index fa1663bc146c..7967e94027af 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -200,8 +200,8 @@ # --------------------------------------------- ( "a", - "lambda x: (3 << (32-x)) & 3", - '(3 * 2**(32 - pl.col("a"))).cast(pl.Int64) & 3', + "lambda x: (3 << (30-x)) & 3", + '(3 * 2**(30 - pl.col("a"))).cast(pl.Int64) & 3', ), ( "a", diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 0d3cf9775b91..1ed04c97ecb0 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -447,7 +447,9 @@ def test_arithmetic_datetime() -> None: a * 2 with pytest.raises(TypeError): a % 2 - with pytest.raises(TypeError): + with pytest.raises( + pl.InvalidOperationError, + ): a**2 with pytest.raises(TypeError): 2 / a @@ -457,7 +459,9 @@ def test_arithmetic_datetime() -> None: 2 * a with pytest.raises(TypeError): 2 % a - with pytest.raises(TypeError): + with pytest.raises( + pl.InvalidOperationError, + ): 2**a @@ -476,12 +480,11 @@ def test_power() -> None: m = pl.Series([2**33, 2**33], dtype=UInt64) # pow - assert_series_equal(a**2, pl.Series([1.0, 4.0], dtype=Float64)) + assert_series_equal(a**2, pl.Series([1, 4], dtype=Int64)) assert_series_equal(b**3, pl.Series([None, 8.0], dtype=Float64)) - assert_series_equal(a**a, pl.Series([1.0, 4.0], dtype=Float64)) + assert_series_equal(a**a, pl.Series([1, 4], dtype=Int64)) assert_series_equal(b**b, pl.Series([None, 4.0], dtype=Float64)) assert_series_equal(a**b, pl.Series([None, 4.0], dtype=Float64)) - assert_series_equal(a**None, pl.Series([None] * len(a), dtype=Float64)) # type: ignore[operator] assert_series_equal(d**d, pl.Series([1, 4], dtype=UInt8)) assert_series_equal(e**d, pl.Series([1, 4], dtype=Int8)) assert_series_equal(f**d, pl.Series([1, 4], dtype=UInt16)) @@ -490,8 +493,24 @@ def test_power() -> None: assert_series_equal(i**d, pl.Series([1, 4], dtype=Int32)) assert_series_equal(j**d, pl.Series([1, 4], dtype=UInt64)) assert_series_equal(k**d, pl.Series([1, 4], dtype=Int64)) - with pytest.raises(TypeError): + + with pytest.raises( + pl.InvalidOperationError, + match="`pow` operation not supported for dtype `null` as exponent", + ): + a ** pl.lit(None) + + with pytest.raises( + pl.InvalidOperationError, + match="`pow` operation not supported for dtype `date` as base", + ): c**2 + with pytest.raises( + pl.InvalidOperationError, + match="`pow` operation not supported for dtype `date` as exponent", + ): + 2**c + with pytest.raises(pl.ColumnNotFoundError): a ** "hi" # type: ignore[operator] @@ -504,13 +523,12 @@ def test_power() -> None: # rpow assert_series_equal(2.0**a, pl.Series("literal", [2.0, 4.0], dtype=Float64)) assert_series_equal(2**b, pl.Series("literal", [None, 4.0], dtype=Float64)) - with pytest.raises(TypeError): - 2**c + with pytest.raises(pl.ColumnNotFoundError): "hi" ** a # Series.pow() method - assert_series_equal(a.pow(2), pl.Series([1.0, 4.0], dtype=Float64)) + assert_series_equal(a.pow(2), pl.Series([1, 4], dtype=Int64)) def test_add_string() -> None: @@ -1986,13 +2004,13 @@ def test_cumulative_eval() -> None: expr2 = pl.element().last() ** 2 expected1 = pl.Series("values", [1, 1, 1, 1, 1]) - expected2 = pl.Series("values", [1.0, 4.0, 9.0, 16.0, 25.0]) + expected2 = pl.Series("values", [1, 4, 9, 16, 25]) assert_series_equal(s.cumulative_eval(expr1), expected1) assert_series_equal(s.cumulative_eval(expr2), expected2) # evaluate combined expressions and validate expr3 = expr1 - expr2 - expected3 = pl.Series("values", [0.0, -3.0, -8.0, -15.0, -24.0]) + expected3 = pl.Series("values", [0, -3, -8, -15, -24]) assert_series_equal(s.cumulative_eval(expr3), expected3) diff --git a/py-polars/tests/unit/sql/test_functions.py b/py-polars/tests/unit/sql/test_functions.py index 7ffb0be03d63..1449daa52482 100644 --- a/py-polars/tests/unit/sql/test_functions.py +++ b/py-polars/tests/unit/sql/test_functions.py @@ -25,7 +25,7 @@ def test_sql_expr() -> None: ) result = df.select(*sql_exprs) expected = pl.DataFrame( - {"a": [1, 1, 1], "aa": [1.0, 4.0, 27.0], "b2": ["yz", "bc", None]} + {"a": [1, 1, 1], "aa": [1, 4, 27], "b2": ["yz", "bc", None]} ) assert_frame_equal(result, expected)