From f3bfe8ea923270e5d7f927ab7d803d55bec6df56 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 23 Oct 2024 15:14:00 -0400 Subject: [PATCH] Continued work on sketch of expr-based index_of(). --- crates/polars-ops/src/series/ops/mod.rs | 3 +- crates/polars-ops/src/series/ops/search.rs | 77 ------------------- .../src/dsl/function_expr/index_of.rs | 16 ++++ .../polars-plan/src/dsl/function_expr/mod.rs | 7 ++ .../src/dsl/function_expr/schema.rs | 1 + crates/polars-plan/src/dsl/mod.rs | 18 +++++ crates/polars-python/src/expr/general.rs | 5 ++ .../src/lazyframe/visitor/expr_nodes.rs | 1 + crates/polars-python/src/series/scatter.rs | 1 - py-polars/polars/expr/expr.py | 7 ++ py-polars/polars/series/series.py | 20 ++--- 11 files changed, 63 insertions(+), 93 deletions(-) delete mode 100644 crates/polars-ops/src/series/ops/search.rs create mode 100644 crates/polars-plan/src/dsl/function_expr/index_of.rs diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 1e1123d77d30..b5c57374afb1 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -54,7 +54,7 @@ mod rolling; #[cfg(feature = "round_series")] mod round; // TODO add a feature? 
-mod search; +mod index_of; #[cfg(feature = "search_sorted")] mod search_sorted; #[cfg(feature = "to_dummies")] @@ -124,6 +124,7 @@ pub use rle::*; pub use rolling::*; #[cfg(feature = "round_series")] pub use round::*; +pub use index_of::*; #[cfg(feature = "search_sorted")] pub use search_sorted::*; #[cfg(feature = "to_dummies")] diff --git a/crates/polars-ops/src/series/ops/search.rs b/crates/polars-ops/src/series/ops/search.rs deleted file mode 100644 index a5249e1107a3..000000000000 --- a/crates/polars-ops/src/series/ops/search.rs +++ /dev/null @@ -1,77 +0,0 @@ -use polars_core::downcast_as_macro_arg_physical; -use polars_core::prelude::*; -use polars_utils::float::IsFloat; - -use crate::series::ops::SeriesSealed; - -/// Search for an item, typically in a ChunkedArray. -pub trait ChunkSearch<'a, T> { - /// Return the index of the given value within self, or `None` if not found. - fn index_of(&'a self, value: Option) -> Option; -} - - -impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray -where - T: PolarsNumericType, -{ - fn index_of(&'a self, value: Option) -> Option { - // A NaN is never equal to anything, including itself. But we still want - // to be able to search for NaNs, so we handle them specially. - if value.map(|v| v.is_nan()) == Some(true) { - return self - .iter() - .position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true)); - } - - self.iter().position(|opt_val| opt_val == value) - } -} - -/// Try casting the value to the correct type, then call index_of(). -macro_rules! try_index_of_numeric_ca { - ($ca:expr, $value:expr) => {{ - let ca = $ca; - let cast_value = $value.map(|v| AnyValue::from(v).strict_cast(ca.dtype())); - if cast_value == Some(None) { - // We can can't cast the searched-for value to a valid data point - // within the dtype of the Series we're searching, which means we - // will never find that value. 
- None - } else { - let cast_value = cast_value.flatten(); - ca.index_of(cast_value.map(|v| v.extract().unwrap())) - } - }}; -} - -pub trait SearchSeries: SeriesSealed { - fn index_of(&self, value_series: &Series) -> PolarsResult> { - let series = self.as_series(); - let value_series = if value_series.dtype().is_null() { - // Should be able to cast null dtype to anything, so cast it to dtype of - // Series we're searching. - &value_series.cast(series.dtype())? - } else { - value_series - }; - let value_dtype = value_series.dtype(); - - if value_dtype.is_signed_integer() { - let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0); - let result = downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value); - return Ok(result); - } - if value_dtype.is_unsigned_integer() { - let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0); - return Ok(downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value)); - } - if value_dtype.is_float() { - let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0); - return Ok(downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value)); - } - // At this point we're done handling integers and floats. 
- unimplemented!("TODO") - } -} - diff --git a/crates/polars-plan/src/dsl/function_expr/index_of.rs b/crates/polars-plan/src/dsl/function_expr/index_of.rs new file mode 100644 index 000000000000..c583171fb405 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/index_of.rs @@ -0,0 +1,16 @@ +use polars_ops::series::index_of as index_of_op; +use super::*; + +pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Option<Column>> { + let series = s[0].as_materialized_series(); + let value = s[1].as_materialized_series(); + if value.len() != 1 { + polars_bail!( + ComputeError: + "there can only be a single value searched for in `index_of` expressions, but {} values were given", + value.len(), + ); + } + let result = index_of_op(series, value)?; + Ok(result.map(|r| Column::new(series.name().clone(), [r as IdxSize]))) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 6cebaa301b85..86aabd2314e0 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -57,6 +57,7 @@ mod round; #[cfg(feature = "row_hash")] mod row_hash; pub(super) mod schema; +mod index_of; #[cfg(feature = "search_sorted")] mod search_sorted; mod shift_and_fill; @@ -154,6 +155,7 @@ pub enum FunctionExpr { Hash(u64, u64, u64, u64), #[cfg(feature = "arg_where")] ArgWhere, + IndexOf, #[cfg(feature = "search_sorted")] SearchSorted(SearchSortedSide), #[cfg(feature = "range")] @@ -392,6 +394,7 @@ impl Hash for FunctionExpr { #[cfg(feature = "business")] Business(f) => f.hash(state), Pow(f) => f.hash(state), + IndexOf => {}, #[cfg(feature = "search_sorted")] SearchSorted(f) => f.hash(state), #[cfg(feature = "random")] @@ -629,6 +632,7 @@ impl Display for FunctionExpr { Hash(_, _, _, _) => "hash", #[cfg(feature = "arg_where")] ArgWhere => "arg_where", + IndexOf => "index_of", #[cfg(feature = "search_sorted")] SearchSorted(_) => "search_sorted", #[cfg(feature = "range")] @@ -918,6 +922,9 @@ 
impl From for SpecialEq> { ArgWhere => { wrap!(arg_where::arg_where) }, + IndexOf => { + wrap!(index_of::index_of) + } #[cfg(feature = "search_sorted")] SearchSorted(side) => { map_as_slice!(search_sorted::search_sorted_impl, side) diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index bc09cca94215..7b57e8241d94 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -49,6 +49,7 @@ impl FunctionExpr { Hash(..) => mapper.with_dtype(DataType::UInt64), #[cfg(feature = "arg_where")] ArgWhere => mapper.with_dtype(IDX_DTYPE), + IndexOf => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "search_sorted")] SearchSorted(_) => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "range")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index a88ff858e6ee..98812ef6f7e2 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -377,6 +377,24 @@ impl Expr { ) } + /// Find the index of a value. + pub fn index_of<E: Into<Expr>>(self, element: E) -> Expr { + let element = element.into(); + Expr::Function { + input: vec![self, element], + function: FunctionExpr::IndexOf, + options: FunctionOptions { + // TODO which ApplyOptions, if any? + //collect_groups: ApplyOptions::GroupWise, + flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, + fmt_str: "index_of", + // TODO can we rely on casting here instead of doing it in the + // function? + ..Default::default() + }, + } + } + #[cfg(feature = "search_sorted")] /// Find indices where elements should be inserted to maintain order. 
pub fn search_sorted>(self, element: E, side: SearchSortedSide) -> Expr { diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index 604049f62b66..7ec8331089a3 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -315,6 +315,10 @@ impl PyExpr { self.inner.clone().arg_min().into() } + fn index_of(&self, element: Self) -> Self { + self.inner.clone().index_of(element.inner).into() + } + #[cfg(feature = "search_sorted")] fn search_sorted(&self, element: Self, side: Wrap) -> Self { self.inner @@ -322,6 +326,7 @@ impl PyExpr { .search_sorted(element.inner, side.0) .into() } + fn gather(&self, idx: Self) -> Self { self.inner.clone().gather(idx.inner).into() } diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 07d2f872437c..3a660c4b5bf3 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1103,6 +1103,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { ("hash", seed, seed_1, seed_2, seed_3).to_object(py) }, FunctionExpr::ArgWhere => ("argwhere",).to_object(py), + FunctionExpr::IndexOf => ("index_of",).to_object(py), #[cfg(feature = "search_sorted")] FunctionExpr::SearchSorted(side) => ( "search_sorted", diff --git a/crates/polars-python/src/series/scatter.rs b/crates/polars-python/src/series/scatter.rs index 2fa6a0f729bc..1769bf673e7b 100644 --- a/crates/polars-python/src/series/scatter.rs +++ b/crates/polars-python/src/series/scatter.rs @@ -1,6 +1,5 @@ use polars::export::arrow::array::Array; use polars::prelude::*; -use polars_core::downcast_as_macro_arg_physical; use pyo3::prelude::*; use super::PySeries; diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 98f638b0846a..0c7ebbc3584e 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ 
-2303,6 +2303,13 @@ def arg_min(self) -> Expr: """ return self._from_pyexpr(self._pyexpr.arg_min()) + def index_of(self, element: IntoExpr | np.ndarray[Any, Any]) -> Expr: + """ + TODO + """ + element = parse_into_expression(element, str_as_lit=True, list_as_series=True) # type: ignore[arg-type] + return self._from_pyexpr(self._pyexpr.index_of(element)) + def search_sorted( self, element: IntoExpr | np.ndarray[Any, Any], side: SearchSortedSide = "any" ) -> Expr: diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 756433114b18..2c474657266b 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4732,32 +4732,24 @@ def scatter( self._s.scatter(indices._s, values._s) return self - def index_of( - self, - value: Series | Iterable[PythonLiteral] | PythonLiteral | None, - ) -> int | None: + def index_of(self, element) -> int | None: """ Get the first index of a value, or ``None`` if it's not found. Parameters ---------- - value + element Value to find. Examples -------- TODO """ - if isinstance(value, Series): - # Searching for lists or arrays: - value = value.implode() + df = F.select(F.lit(self).index_of(element)) + if isinstance(element, (list, Series, pl.Expr, np.ndarray)): + return df.to_series() else: - value = Series(values=[value]) - - if isinstance(self.dtype, Array): - value = value.cast(Array(self.dtype.inner, len(value[0]))) - - return self._s.index_of(value._s) + return df.item() def clear(self, n: int = 0) -> Series: """