diff --git a/crates/polars-core/src/chunked_array/ops/search.rs b/crates/polars-core/src/chunked_array/ops/search.rs
new file mode 100644
index 000000000000..a19cd229846d
--- /dev/null
+++ b/crates/polars-core/src/chunked_array/ops/search.rs
@@ -0,0 +1,43 @@
+use polars_utils::float::IsFloat;
+
+use crate::prelude::*;
+
+impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray<T>
+where
+    T: PolarsNumericType,
+{
+    fn index_of(&'a self, value: Option<T::Native>) -> Option<usize> {
+        // A NaN is never equal to anything, including itself. But we still want
+        // to be able to search for NaNs, so we handle them specially.
+        if value.map(|v| v.is_nan()) == Some(true) {
+            return self
+                .iter()
+                .position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true));
+        }
+
+        self.iter().position(|opt_val| opt_val == value)
+    }
+}
+
+impl ChunkSearch<'_, &Series> for ListChunked {
+    fn index_of(&self, value: Option<&Series>) -> Option<usize> {
+        self.amortized_iter()
+            .position(|opt_val| match (opt_val, value) {
+                (Some(in_series), Some(value)) => in_series.as_ref() == value,
+                (None, None) => true,
+                _ => false,
+            })
+    }
+}
+
+#[cfg(feature = "dtype-array")]
+impl ChunkSearch<'_, &Series> for ArrayChunked {
+    fn index_of(&self, value: Option<&Series>) -> Option<usize> {
+        self.amortized_iter()
+            .position(|opt_val| match (opt_val, value) {
+                (Some(in_series), Some(value)) => in_series.as_ref() == value,
+                (None, None) => true,
+                _ => false,
+            })
+    }
+}
diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs
new file mode 100644
index 000000000000..4aa4ef51580e
--- /dev/null
+++ b/crates/polars-ops/src/series/ops/index_of.rs
@@ -0,0 +1,100 @@
+use polars_core::downcast_as_macro_arg_physical;
+use polars_core::prelude::*;
+use polars_utils::float::IsFloat;
+
+/// Search for an item, typically in a ChunkedArray.
+trait ChunkSearch<'a, T> {
+    /// Return the index of the given value within self, or `None` if not found.
+    fn index_of(&'a self, value: Option<T>) -> Option<usize>;
+}
+
+impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray<T>
+where
+    T: PolarsNumericType,
+{
+    fn index_of(&'a self, value: Option<T::Native>) -> Option<usize> {
+        // A NaN is never equal to anything, including itself. But we still want
+        // to be able to search for NaNs, so we handle them specially.
+        if value.map(|v| v.is_nan()) == Some(true) {
+            return self
+                .iter()
+                .position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true));
+        }
+
+        self.iter().position(|opt_val| opt_val == value)
+    }
+}
+
+/// Try casting the value to the correct type, then call index_of().
+macro_rules! try_index_of_numeric_ca {
+    ($ca:expr, $value:expr) => {{
+        let ca = $ca;
+        let cast_value = $value.map(|v| AnyValue::from(v).strict_cast(ca.dtype()));
+        if cast_value == Some(None) {
+            // We can't cast the searched-for value to a valid data point
+            // within the dtype of the Series we're searching, which means we
+            // will never find that value.
+            None
+        } else {
+            let cast_value = cast_value.flatten();
+            ca.index_of(cast_value.map(|v| v.extract().unwrap()))
+        }
+    }};
+}
+
+/// Find the index of the first entry of `series` that equals the single value
+/// held in `value_series`, or `None` if there is no such entry.
+///
+/// Only integer and float search values are supported so far.
+///
+/// # Errors
+///
+/// Returns an error if `value_series` does not hold exactly one value, if
+/// casting the search value fails, or if the search value has a dtype that is
+/// not supported yet.
+pub fn index_of(series: &Series, value_series: &Series) -> PolarsResult<Option<usize>> {
+    // Only the first entry of `value_series` is read below, so reject any other
+    // length up front instead of silently ignoring the remaining values.
+    polars_ensure!(
+        value_series.len() == 1,
+        ComputeError: "`index_of` expects exactly one search value, but {} were given",
+        value_series.len(),
+    );
+    // TODO passing in a Series for the value is kinda meh, is there a way to pass in an AnyValue instead?
+    let value_series = if value_series.dtype().is_null() {
+        // Should be able to cast null dtype to anything, so cast it to dtype of
+        // Series we're searching.
+        &value_series.cast(series.dtype())?
+    } else {
+        value_series
+    };
+    let value_dtype = value_series.dtype();
+
+    if value_dtype.is_signed_integer() {
+        let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0);
+        let result = downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value);
+        return Ok(result);
+    }
+    if value_dtype.is_unsigned_integer() {
+        let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0);
+        return Ok(downcast_as_macro_arg_physical!(
+            series,
+            try_index_of_numeric_ca,
+            value
+        ));
+    }
+    if value_dtype.is_float() {
+        let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0);
+        return Ok(downcast_as_macro_arg_physical!(
+            series,
+            try_index_of_numeric_ca,
+            value
+        ));
+    }
+    // At this point we're done handling integers and floats. Other dtypes are
+    // not supported yet, so report that as an error instead of panicking.
+    polars_bail!(
+        InvalidOperation: "`index_of` is not yet supported for values of dtype {}",
+        value_dtype,
+    )
+}
diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs
index b684815238f7..b5c57374afb1 100644
--- a/crates/polars-ops/src/series/ops/mod.rs
+++ b/crates/polars-ops/src/series/ops/mod.rs
@@ -53,6 +53,8 @@ mod rle;
 mod rolling;
 #[cfg(feature = "round_series")]
 mod round;
+// TODO add a feature?
+mod index_of;
 #[cfg(feature = "search_sorted")]
 mod search_sorted;
 #[cfg(feature = "to_dummies")]
@@ -122,6 +124,7 @@ pub use rle::*;
 pub use rolling::*;
 #[cfg(feature = "round_series")]
 pub use round::*;
+pub use index_of::*;
 #[cfg(feature = "search_sorted")]
 pub use search_sorted::*;
 #[cfg(feature = "to_dummies")]
diff --git a/crates/polars-plan/src/dsl/function_expr/index_of.rs b/crates/polars-plan/src/dsl/function_expr/index_of.rs
new file mode 100644
index 000000000000..c583171fb405
--- /dev/null
+++ b/crates/polars-plan/src/dsl/function_expr/index_of.rs
@@ -0,0 +1,20 @@
+use polars_ops::series::index_of as index_of_op;
+use super::*;
+
+/// Expression-level `index_of`: search `s[0]` for the single value held in `s[1]`.
+///
+/// Returns an error if `s[1]` does not hold exactly one value, and `Ok(None)`
+/// (no output column) if the value is not found.
+pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Option<Column>> {
+    let series = s[0].as_materialized_series();
+    let value = s[1].as_materialized_series();
+    if value.len() != 1 {
+        polars_bail!(
+            ComputeError:
+            "there can only be a single value searched for in `index_of` expressions, but {} values were given",
+            value.len(),
+        );
+    }
+    let result = index_of_op(series, value)?;
+    Ok(result.map(|r| Column::new(series.name().clone(), [r as IdxSize])))
+}
diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs
index 6cebaa301b85..86aabd2314e0 100644
--- a/crates/polars-plan/src/dsl/function_expr/mod.rs
+++ b/crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -57,6 +57,7 @@ mod round;
 #[cfg(feature = "row_hash")]
 mod row_hash;
 pub(super) mod schema;
+mod index_of;
 #[cfg(feature = "search_sorted")]
 mod search_sorted;
 mod shift_and_fill;
@@ -154,6 +155,7 @@ pub enum FunctionExpr {
     Hash(u64, u64, u64, u64),
     #[cfg(feature = "arg_where")]
     ArgWhere,
+    IndexOf,
     #[cfg(feature = "search_sorted")]
     SearchSorted(SearchSortedSide),
     #[cfg(feature = "range")]
@@ -392,6 +394,7 @@ impl Hash for FunctionExpr {
             #[cfg(feature = "business")]
             Business(f) => f.hash(state),
             Pow(f) => f.hash(state),
+            IndexOf => {},
             #[cfg(feature = "search_sorted")]
             SearchSorted(f) => f.hash(state),
             #[cfg(feature = "random")]
@@ -629,6 +632,7 @@ impl Display for FunctionExpr {
             Hash(_, _, _, _) => "hash",
             #[cfg(feature = "arg_where")]
             ArgWhere => "arg_where",
+            IndexOf => "index_of",
             #[cfg(feature = "search_sorted")]
             SearchSorted(_) => "search_sorted",
             #[cfg(feature = "range")]
@@ -918,6 +922,9 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
             ArgWhere => {
                 wrap!(arg_where::arg_where)
             },
+            IndexOf => {
+                wrap!(index_of::index_of)
+            },
             #[cfg(feature = "search_sorted")]
             SearchSorted(side) => {
                 map_as_slice!(search_sorted::search_sorted_impl, side)
diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs
index bc09cca94215..7b57e8241d94 100644
--- a/crates/polars-plan/src/dsl/function_expr/schema.rs
+++ b/crates/polars-plan/src/dsl/function_expr/schema.rs
@@ -49,6 +49,7 @@ impl FunctionExpr {
             Hash(..) => mapper.with_dtype(DataType::UInt64),
             #[cfg(feature = "arg_where")]
             ArgWhere => mapper.with_dtype(IDX_DTYPE),
+            IndexOf => mapper.with_dtype(IDX_DTYPE),
             #[cfg(feature = "search_sorted")]
             SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
             #[cfg(feature = "range")]
diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs
index a88ff858e6ee..98812ef6f7e2 100644
--- a/crates/polars-plan/src/dsl/mod.rs
+++ b/crates/polars-plan/src/dsl/mod.rs
@@ -377,6 +377,24 @@ impl Expr {
         )
     }
 
+    /// Find the index of a value.
+    pub fn index_of<E: Into<Expr>>(self, element: E) -> Expr {
+        let element = element.into();
+        Expr::Function {
+            input: vec![self, element],
+            function: FunctionExpr::IndexOf,
+            options: FunctionOptions {
+                // TODO which ApplyOptions, if any?
+                //collect_groups: ApplyOptions::GroupWise,
+                flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
+                fmt_str: "index_of",
+                // TODO can we rely on casting here instead of doing it in the
+                // function?
+                ..Default::default()
+            },
+        }
+    }
+
     #[cfg(feature = "search_sorted")]
     /// Find indices where elements should be inserted to maintain order.
     pub fn search_sorted<E: Into<Expr>>(self, element: E, side: SearchSortedSide) -> Expr {
diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs
index 604049f62b66..7ec8331089a3 100644
--- a/crates/polars-python/src/expr/general.rs
+++ b/crates/polars-python/src/expr/general.rs
@@ -315,6 +315,10 @@ impl PyExpr {
         self.inner.clone().arg_min().into()
     }
 
+    fn index_of(&self, element: Self) -> Self {
+        self.inner.clone().index_of(element.inner).into()
+    }
+
     #[cfg(feature = "search_sorted")]
     fn search_sorted(&self, element: Self, side: Wrap<SearchSortedSide>) -> Self {
         self.inner
@@ -322,6 +326,7 @@ impl PyExpr {
             .search_sorted(element.inner, side.0)
             .into()
     }
+
     fn gather(&self, idx: Self) -> Self {
         self.inner.clone().gather(idx.inner).into()
     }
diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
index 07d2f872437c..3a660c4b5bf3 100644
--- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
+++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
@@ -1103,6 +1103,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
                     ("hash", seed, seed_1, seed_2, seed_3).to_object(py)
                 },
                 FunctionExpr::ArgWhere => ("argwhere",).to_object(py),
+                FunctionExpr::IndexOf => ("index_of",).to_object(py),
                 #[cfg(feature = "search_sorted")]
                 FunctionExpr::SearchSorted(side) => (
                     "search_sorted",
diff --git a/crates/polars-python/src/series/scatter.rs b/crates/polars-python/src/series/scatter.rs
index 97df60ef205b..1769bf673e7b 100644
--- a/crates/polars-python/src/series/scatter.rs
+++ b/crates/polars-python/src/series/scatter.rs
@@ -139,3 +139,13 @@ fn scatter_impl(
 
     s.and_then(|s| s.cast(&logical_dtype))
 }
+
+#[pymethods]
+impl PySeries {
+    /// Given a `PySeries` of length 1, find the index of the first occurrence of
+    /// its value within self.
+    fn index_of(&self, value: PySeries) -> PyResult<Option<usize>> {
+        // `index_of` itself rejects a `value` whose length is not 1.
+        index_of(&self.series, &value.series).map_err(|e| PyErr::from(PyPolarsErr::from(e)))
+    }
+}
diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py
index 98f638b0846a..0c7ebbc3584e 100644
--- a/py-polars/polars/expr/expr.py
+++ b/py-polars/polars/expr/expr.py
@@ -2303,6 +2303,18 @@ def arg_min(self) -> Expr:
         """
         return self._from_pyexpr(self._pyexpr.arg_min())
 
+    def index_of(self, element: IntoExpr | np.ndarray[Any, Any]) -> Expr:
+        """
+        Get the index of the first occurrence of a value.
+
+        Parameters
+        ----------
+        element
+            Value to find; must evaluate to exactly one value.
+        """
+        element = parse_into_expression(element, str_as_lit=True, list_as_series=True)  # type: ignore[arg-type]
+        return self._from_pyexpr(self._pyexpr.index_of(element))
+
     def search_sorted(
         self, element: IntoExpr | np.ndarray[Any, Any], side: SearchSortedSide = "any"
     ) -> Expr:
diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py
index ea37a64aa778..2c474657266b 100644
--- a/py-polars/polars/series/series.py
+++ b/py-polars/polars/series/series.py
@@ -4732,6 +4732,27 @@ def scatter(
         self._s.scatter(indices._s, values._s)
         return self
 
+    def index_of(self, element) -> int | Series | None:
+        """
+        Get the first index of a value, or ``None`` if it's not found.
+
+        Parameters
+        ----------
+        element
+            Value to find.
+
+        Examples
+        --------
+        >>> s = pl.Series("a", [1, 2, 3])
+        >>> s.index_of(2)
+        1
+        """
+        df = F.select(F.lit(self).index_of(element))
+        if isinstance(element, (list, Series, pl.Expr, np.ndarray)):
+            return df.to_series()
+        else:
+            return df.item()
+
     def clear(self, n: int = 0) -> Series:
         """
         Create an empty copy of the current Series, with zero to 'n' elements.