From 7db314ebfdf41fbabb026d1ea190cec4c0f3b58a Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 12 Apr 2024 12:26:22 +0800 Subject: [PATCH] feat: Expressify `to_integer` --- .../src/chunked_array/strings/namespace.rs | 49 ++++++++++++++----- .../src/dsl/function_expr/strings.rs | 12 +++-- crates/polars-plan/src/dsl/string.rs | 12 +++-- py-polars/polars/expr/string.py | 8 ++- py-polars/polars/series/string.py | 3 +- py-polars/src/expr/string.rs | 4 +- .../unit/namespaces/string/test_string.py | 17 ++++++- 7 files changed, 76 insertions(+), 29 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 83713b788952..fa6718ffc596 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -1,4 +1,5 @@ use arrow::array::ValueSize; +use arrow::compute::boolean::all; use arrow::legacy::kernels::string::*; #[cfg(feature = "string_encoding")] use base64::engine::general_purpose; @@ -63,25 +64,47 @@ pub trait StringNameSpaceImpl: AsString { #[cfg(feature = "string_to_integer")] // Parse a string number with base _radix_ into a decimal (i64) - fn to_integer(&self, base: u32, strict: bool) -> PolarsResult { + fn to_integer(&self, base: &UInt32Chunked, strict: bool) -> PolarsResult { let ca = self.as_string(); - let f = |opt_s: Option<&str>| -> Option { - opt_s.and_then(|s| ::from_str_radix(s, base).ok()) + let f = |opt_s: Option<&str>, opt_base: Option| -> Option { + match (opt_s, opt_base) { + (Some(s), Some(base)) => ::from_str_radix(s, base).ok(), + _ => None, + } }; - let out: Int64Chunked = ca.apply_generic(f); - + let out = broadcast_binary_elementwise(ca, base, f); if strict && ca.null_count() != out.null_count() { - let failure_mask = !ca.is_null() & out.is_null(); + let failure_mask = ca.is_not_null() & out.is_null() & base.is_not_null(); let all_failures = ca.filter(&failure_mask)?; + if all_failures.is_empty() { + return Ok(out); + } let n_failures = all_failures.len(); let some_failures = all_failures.unique()?.slice(0, 10).sort(false); - let some_error_msg = some_failures - .get(0) - .and_then(|s| ::from_str_radix(s, base).err()) - .map_or_else( - || unreachable!("failed to extract ParseIntError"), - |e| format!("{}", e), - ); + let some_error_msg = match base.len() { + 1 => { + // we can ensure that base is not null. + let base = base.get(0).unwrap(); + some_failures + .get(0) + .and_then(|s| ::from_str_radix(s, base).err()) + .map_or_else( + || unreachable!("failed to extract ParseIntError"), + |e| format!("{}", e), + ) + }, + _ => { + let base_filures = base.filter(&failure_mask)?; + some_failures + .get(0) + .zip(base_filures.get(0)) + .and_then(|(s, base)| ::from_str_radix(s, base).err()) + .map_or_else( + || unreachable!("failed to extract ParseIntError"), + |e| format!("{}", e), + ) + }, + }; polars_bail!( ComputeError: "strict integer parsing failed for {} value(s): {}; error message for the \ diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 13d18d790c63..62f81865c22c 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -53,7 +53,7 @@ pub enum StringFunction { strict: bool, }, #[cfg(feature = "string_to_integer")] - ToInteger(u32, bool), + ToInteger(bool), LenBytes, LenChars, Lowercase, @@ -343,7 +343,7 @@ impl From for SpecialEq> { StripPrefix => map_as_slice!(strings::strip_prefix), StripSuffix => map_as_slice!(strings::strip_suffix), #[cfg(feature = "string_to_integer")] - ToInteger(base, strict) => map!(strings::to_integer, base, strict), + ToInteger(strict) => map_as_slice!(strings::to_integer, strict), Slice => map_as_slice!(strings::str_slice), #[cfg(feature = "string_encoding")] HexEncode => map!(strings::hex_encode), @@ -888,9 +888,11 @@ pub(super) fn reverse(s: &Series) -> PolarsResult { } #[cfg(feature = "string_to_integer")] -pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult { - let ca = s.str()?; - ca.to_integer(base, strict).map(|ok| ok.into_series()) +pub(super) fn to_integer(s: &[Series], strict: bool) -> PolarsResult { + let ca = s[0].str()?; + let base = s[1].strict_cast(&DataType::UInt32)?; + ca.to_integer(base.u32()?, strict) + .map(|ok| ok.into_series()) } pub(super) fn str_slice(s: &[Series]) -> PolarsResult { // Calculate the post-broadcast length and ensure everything is consistent. diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 88c43c4e5ff7..15c3db4cc463 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -483,11 +483,13 @@ impl StringNameSpace { #[cfg(feature = "string_to_integer")] /// Parse string in base radix into decimal. - pub fn to_integer(self, base: u32, strict: bool) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::ToInteger( - base, strict, - ))) + pub fn to_integer(self, base: Expr, strict: bool) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::ToInteger(strict)), + &[base], + false, + false, + ) } /// Return the length of each string as the number of bytes. diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 00306176c8a8..034148d1d4fd 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -2234,14 +2234,17 @@ def explode(self) -> Expr: """ return wrap_expr(self._pyexpr.str_explode()) - def to_integer(self, *, base: int = 10, strict: bool = True) -> Expr: + def to_integer( + self, *, base: int | IntoExprColumn = 10, strict: bool = True + ) -> Expr: """ Convert a String column into an Int64 column with base radix. Parameters ---------- base - Positive integer which is the base of the string we are parsing. + Positive integer or expression which is the base of the string + we are parsing. Default: 10. strict Bool, Default=True will raise any ParseError or overflow as ComputeError. @@ -2282,6 +2285,7 @@ def to_integer(self, *, base: int = 10, strict: bool = True) -> Expr: │ null ┆ null │ └──────┴────────┘ """ + base = parse_as_expression(base, str_as_lit=False) return wrap_expr(self._pyexpr.str_to_integer(base, strict)) @deprecate_renamed_function("to_integer", version="0.19.14") diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 8aad10f4dcf4..c14a32c82e31 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1690,7 +1690,8 @@ def to_integer(self, *, base: int = 10, strict: bool = True) -> Series: Parameters ---------- base - Positive integer which is the base of the string we are parsing. + Positive integer or expression which is the base of the string + we are parsing. Default: 10. strict Bool, Default=True will raise any ParseError or overflow as ComputeError. diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index e4e8b7bcceb7..5f870c204994 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -205,11 +205,11 @@ impl PyExpr { self.inner.clone().str().base64_decode(strict).into() } - fn str_to_integer(&self, base: u32, strict: bool) -> Self { + fn str_to_integer(&self, base: Self, strict: bool) -> Self { self.inner .clone() .str() - .to_integer(base, strict) + .to_integer(base.inner, strict) .with_fmt("str.to_integer") .into() } diff --git a/py-polars/tests/unit/namespaces/string/test_string.py b/py-polars/tests/unit/namespaces/string/test_string.py index ba7cae882c93..ce7e2e4a54b0 100644 --- a/py-polars/tests/unit/namespaces/string/test_string.py +++ b/py-polars/tests/unit/namespaces/string/test_string.py @@ -305,7 +305,22 @@ def test_str_to_integer() -> None: hex.str.to_integer(base=16) -def test_str_to_integer_df() -> None: +def test_str_to_integer_base_expr() -> None: + df = pl.DataFrame( + {"str": ["110", "ff00", "234", None, "130"], "base": [2, 16, 10, 8, None]} + ) + out = df.select(base_expr=pl.col("str").str.to_integer(base="base")) + expected = pl.DataFrame({"base_expr": [6, 65280, 234, None, None]}) + assert_frame_equal(out, expected) + + # test strict raise + df = pl.DataFrame({"str": ["110", "ff00", "cafe", None], "base": [2, 10, 10, 8]}) + + with pytest.raises(pl.ComputeError, match="failed for 2 value"): + df.select(pl.col("str").str.to_integer(base="base")) + + +def test_str_to_integer_base_literal() -> None: df = pl.DataFrame( { "bin": ["110", "101", "-010", "invalid", None],