Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python) Series.index_of() and equivalent expression - Request for architectural review, take 2 #19408

Closed
wants to merge 11 commits into from
43 changes: 43 additions & 0 deletions crates/polars-core/src/chunked_array/ops/search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use polars_utils::float::IsFloat;

use crate::prelude::*;

impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn index_of(&'a self, value: Option<T::Native>) -> Option<usize> {
// A NaN is never equal to anything, including itself. But we still want
// to be able to search for NaNs, so we handle them specially.
if value.map(|v| v.is_nan()) == Some(true) {
return self
.iter()
.position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true));
}

self.iter().position(|opt_val| opt_val == value)
}
}

impl ChunkSearch<'_, &Series> for ListChunked {
fn index_of(&self, value: Option<&Series>) -> Option<usize> {
self.amortized_iter()
itamarst marked this conversation as resolved.
Show resolved Hide resolved
.position(|opt_val| match (opt_val, value) {
(Some(in_series), Some(value)) => in_series.as_ref() == value,
(None, None) => true,
_ => false,
})
}
}

#[cfg(feature="dtype-array")]
impl ChunkSearch<'_, &Series> for ArrayChunked {
fn index_of(&self, value: Option<&Series>) -> Option<usize> {
itamarst marked this conversation as resolved.
Show resolved Hide resolved
self.amortized_iter()
.position(|opt_val| match (opt_val, value) {
(Some(in_series), Some(value)) => in_series.as_ref() == value,
(None, None) => true,
_ => false,
})
}
}
82 changes: 82 additions & 0 deletions crates/polars-ops/src/series/ops/index_of.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use polars_core::downcast_as_macro_arg_physical;
use polars_core::prelude::*;
use polars_utils::float::IsFloat;

/// Search for an item, typically in a ChunkedArray.
trait ChunkSearch<'a, T> {
/// Return the index of the given value within self, or `None` if not found.
fn index_of(&'a self, value: Option<T>) -> Option<usize>;
}

impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn index_of(&'a self, value: Option<T::Native>) -> Option<usize> {
// A NaN is never equal to anything, including itself. But we still want
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be implemented over the chunks and not using an iterator over ChunkedArray's as they are really slow.

In the case we are looking for a None we can also fast path to first null.

// to be able to search for NaNs, so we handle them specially.
if value.map(|v| v.is_nan()) == Some(true) {
return self
.iter()
.position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true));
}

self.iter().position(|opt_val| opt_val == value)
}
}

/// Try casting the value to the correct type, then call index_of().
macro_rules! try_index_of_numeric_ca {
($ca:expr, $value:expr) => {{
let ca = $ca;
let cast_value = $value.map(|v| AnyValue::from(v).strict_cast(ca.dtype()));
if cast_value == Some(None) {
// We can can't cast the searched-for value to a valid data point
// within the dtype of the Series we're searching, which means we
// will never find that value.
None
} else {
let cast_value = cast_value.flatten();
ca.index_of(cast_value.map(|v| v.extract().unwrap()))
}
}};
}

/// Find the index of a given value (the first and only entry in
/// `value_series`), find its index within `series`.
pub fn index_of(series: &Series, value_series: &Series) -> PolarsResult<Option<usize>> {
// TODO ensure value_series length is 1
itamarst marked this conversation as resolved.
Show resolved Hide resolved
// TODO passing in a Series for the value is kinda meh, is there a way to pass in an AnyValue instead?
let value_series = if value_series.dtype().is_null() {
// Should be able to cast null dtype to anything, so cast it to dtype of
// Series we're searching.
&value_series.cast(series.dtype())?
} else {
value_series
};
let value_dtype = value_series.dtype();

if value_dtype.is_signed_integer() {
let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0);
let result = downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value);
itamarst marked this conversation as resolved.
Show resolved Hide resolved
return Ok(result);
}
if value_dtype.is_unsigned_integer() {
let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(
series,
try_index_of_numeric_ca,
value
));
}
if value_dtype.is_float() {
let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(
series,
try_index_of_numeric_ca,
value
));
}
// At this point we're done handling integers and floats.
unimplemented!("TODO")
}
3 changes: 3 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ mod rle;
mod rolling;
#[cfg(feature = "round_series")]
mod round;
// TODO add a feature?
mod index_of;
#[cfg(feature = "search_sorted")]
mod search_sorted;
#[cfg(feature = "to_dummies")]
Expand Down Expand Up @@ -122,6 +124,7 @@ pub use rle::*;
pub use rolling::*;
#[cfg(feature = "round_series")]
pub use round::*;
pub use index_of::*;
#[cfg(feature = "search_sorted")]
pub use search_sorted::*;
#[cfg(feature = "to_dummies")]
Expand Down
16 changes: 16 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/index_of.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use polars_ops::series::index_of as index_of_op;
use super::*;

pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Option<Column>> {
let series = s[0].as_materialized_series();
itamarst marked this conversation as resolved.
Show resolved Hide resolved
let value = s[1].as_materialized_series();
if value.len() != 1 {
polars_bail!(
ComputeError:
"there can only be a single value searched for in `index_of` expressions, but {} values were give",
value.len(),
);
}
let result = index_of_op(series, value)?;
Ok(result.map(|r| Column::new(series.name().clone(), [r as IdxSize])))
}
7 changes: 7 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ mod round;
#[cfg(feature = "row_hash")]
mod row_hash;
pub(super) mod schema;
mod index_of;
#[cfg(feature = "search_sorted")]
mod search_sorted;
mod shift_and_fill;
Expand Down Expand Up @@ -154,6 +155,7 @@ pub enum FunctionExpr {
Hash(u64, u64, u64, u64),
#[cfg(feature = "arg_where")]
ArgWhere,
IndexOf,
#[cfg(feature = "search_sorted")]
SearchSorted(SearchSortedSide),
#[cfg(feature = "range")]
Expand Down Expand Up @@ -392,6 +394,7 @@ impl Hash for FunctionExpr {
#[cfg(feature = "business")]
Business(f) => f.hash(state),
Pow(f) => f.hash(state),
IndexOf => {},
#[cfg(feature = "search_sorted")]
SearchSorted(f) => f.hash(state),
#[cfg(feature = "random")]
Expand Down Expand Up @@ -629,6 +632,7 @@ impl Display for FunctionExpr {
Hash(_, _, _, _) => "hash",
#[cfg(feature = "arg_where")]
ArgWhere => "arg_where",
IndexOf => "index_of",
#[cfg(feature = "search_sorted")]
SearchSorted(_) => "search_sorted",
#[cfg(feature = "range")]
Expand Down Expand Up @@ -918,6 +922,9 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
ArgWhere => {
wrap!(arg_where::arg_where)
},
IndexOf => {
wrap!(index_of::index_of)
}
#[cfg(feature = "search_sorted")]
SearchSorted(side) => {
map_as_slice!(search_sorted::search_sorted_impl, side)
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl FunctionExpr {
Hash(..) => mapper.with_dtype(DataType::UInt64),
#[cfg(feature = "arg_where")]
ArgWhere => mapper.with_dtype(IDX_DTYPE),
IndexOf => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "search_sorted")]
SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "range")]
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,24 @@ impl Expr {
)
}

/// Find the index of a value.
pub fn index_of<E: Into<Expr>>(self, element: E) -> Expr {
let element = element.into();
Expr::Function {
input: vec![self, element],
function: FunctionExpr::IndexOf,
options: FunctionOptions {
// TODO which ApplyOptions, if any?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific choice that makes sense here? I don't yet understand the enum options' meanings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think wildcard expansion should be set to 2..

//collect_groups: ApplyOptions::GroupWise,
flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
fmt_str: "index_of",
// TODO can we rely on casting here instead of doing it in the
itamarst marked this conversation as resolved.
Show resolved Hide resolved
// function?
..Default::default()
},
}
}

#[cfg(feature = "search_sorted")]
/// Find indices where elements should be inserted to maintain order.
pub fn search_sorted<E: Into<Expr>>(self, element: E, side: SearchSortedSide) -> Expr {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-python/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,18 @@ impl PyExpr {
self.inner.clone().arg_min().into()
}

fn index_of(&self, element: Self) -> Self {
self.inner.clone().index_of(element.inner).into()
}

#[cfg(feature = "search_sorted")]
fn search_sorted(&self, element: Self, side: Wrap<SearchSortedSide>) -> Self {
self.inner
.clone()
.search_sorted(element.inner, side.0)
.into()
}

fn gather(&self, idx: Self) -> Self {
self.inner.clone().gather(idx.inner).into()
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
("hash", seed, seed_1, seed_2, seed_3).to_object(py)
},
FunctionExpr::ArgWhere => ("argwhere",).to_object(py),
FunctionExpr::IndexOf => ("index_of",).to_object(py),
#[cfg(feature = "search_sorted")]
FunctionExpr::SearchSorted(side) => (
"search_sorted",
Expand Down
10 changes: 10 additions & 0 deletions crates/polars-python/src/series/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,13 @@ fn scatter_impl(

s.and_then(|s| s.cast(&logical_dtype))
}

#[pymethods]
impl PySeries {
/// Given a `PySeries` of length 0, find the index of the first value within
/// self.
fn index_of(&self, value: PySeries) -> PyResult<Option<usize>> {
// TODO assert length of value is 1?
index_of(&self.series, &value.series).map_err(|e| PyErr::from(PyPolarsErr::from(e)))
}
}
7 changes: 7 additions & 0 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2303,6 +2303,13 @@ def arg_min(self) -> Expr:
"""
return self._from_pyexpr(self._pyexpr.arg_min())

def index_of(self, element: IntoExpr | np.ndarray[Any, Any]) -> Expr:
"""
TODO
"""
element = parse_into_expression(element, str_as_lit=True, list_as_series=True) # type: ignore[arg-type]
return self._from_pyexpr(self._pyexpr.index_of(element))

def search_sorted(
self, element: IntoExpr | np.ndarray[Any, Any], side: SearchSortedSide = "any"
) -> Expr:
Expand Down
19 changes: 19 additions & 0 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4732,6 +4732,25 @@ def scatter(
self._s.scatter(indices._s, values._s)
return self

def index_of(self, element) -> int | None:
"""
Get the first index of a value, or ``None`` if it's not found.

Parameters
----------
element
Value to find.

Examples
--------
TODO
"""
df = F.select(F.lit(self).index_of(element))
if isinstance(element, (list, Series, pl.Expr, np.ndarray)):
return df.to_series()
else:
return df.item()

def clear(self, n: int = 0) -> Series:
"""
Create an empty copy of the current Series, with zero to 'n' elements.
Expand Down
Loading