diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 7450735fd312..a3e7f04cc9e1 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -34,7 +34,6 @@ pub(crate) mod nulls; mod reverse; #[cfg(feature = "rolling_window")] pub(crate) mod rolling_window; -mod search; pub mod search_sorted; mod set; mod shift; @@ -238,12 +237,6 @@ pub trait ChunkApply<'a, T> { F: Fn(Option<T>, &S) -> S; } -/// Search for an item. -pub trait ChunkSearch<'a, T> { - /// Return the index of the given value within self, or `None` if not found. - fn index_of(&'a self, value: Option<T>) -> Option<usize>; -} - /// Aggregation operations. pub trait ChunkAgg { /// Aggregate the sum of the ChunkedArray. diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index b684815238f7..1e1123d77d30 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -53,6 +53,8 @@ mod rle; mod rolling; #[cfg(feature = "round_series")] mod round; +// TODO add a feature? +mod search; #[cfg(feature = "search_sorted")] mod search_sorted; #[cfg(feature = "to_dummies")] diff --git a/crates/polars-ops/src/series/ops/search.rs b/crates/polars-ops/src/series/ops/search.rs new file mode 100644 index 000000000000..a5249e1107a3 --- /dev/null +++ b/crates/polars-ops/src/series/ops/search.rs @@ -0,0 +1,77 @@ +use polars_core::downcast_as_macro_arg_physical; +use polars_core::prelude::*; +use polars_utils::float::IsFloat; + +use crate::series::ops::SeriesSealed; + +/// Search for an item, typically in a ChunkedArray. +pub trait ChunkSearch<'a, T> { + /// Return the index of the given value within self, or `None` if not found.
+ fn index_of(&'a self, value: Option<T>) -> Option<usize>; +} + + +impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray<T> +where + T: PolarsNumericType, +{ + fn index_of(&'a self, value: Option<T::Native>) -> Option<usize> { + // A NaN is never equal to anything, including itself. But we still want + // to be able to search for NaNs, so we handle them specially. + if value.map(|v| v.is_nan()) == Some(true) { + return self + .iter() + .position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true)); + } + + self.iter().position(|opt_val| opt_val == value) + } +} + +/// Try casting the value to the correct type, then call index_of(). +macro_rules! try_index_of_numeric_ca { + ($ca:expr, $value:expr) => {{ + let ca = $ca; + let cast_value = $value.map(|v| AnyValue::from(v).strict_cast(ca.dtype())); + if cast_value == Some(None) { + // We can't cast the searched-for value to a valid data point + // within the dtype of the Series we're searching, which means we + // will never find that value. + None + } else { + let cast_value = cast_value.flatten(); + ca.index_of(cast_value.map(|v| v.extract().unwrap())) + } + }}; +} + +pub trait SearchSeries: SeriesSealed { + fn index_of(&self, value_series: &Series) -> PolarsResult<Option<usize>> { + let series = self.as_series(); + let value_series = if value_series.dtype().is_null() { + // Should be able to cast null dtype to anything, so cast it to dtype of + // Series we're searching. + &value_series.cast(series.dtype())?
+ } else { + value_series + }; + let value_dtype = value_series.dtype(); + + if value_dtype.is_signed_integer() { + let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0); + let result = downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value); + return Ok(result); + } + if value_dtype.is_unsigned_integer() { + let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0); + return Ok(downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value)); + } + if value_dtype.is_float() { + let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0); + return Ok(downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value)); + } + // At this point we're done handling integers and floats. + unimplemented!("TODO") + } +} + diff --git a/crates/polars-python/src/series/scatter.rs b/crates/polars-python/src/series/scatter.rs index 284b90368158..2fa6a0f729bc 100644 --- a/crates/polars-python/src/series/scatter.rs +++ b/crates/polars-python/src/series/scatter.rs @@ -150,64 +150,3 @@ impl PySeries { index_of(&self.series, &value.series).map_err(|e| PyErr::from(PyPolarsErr::from(e))) } } - -/// Try casting the value to the correct type, then call index_of(). -macro_rules! try_index_of { - ($self:expr, $value:expr) => {{ - let cast_value = $value.map(|v| AnyValue::from(v).strict_cast($self.dtype())); - if cast_value == Some(None) { - // We can can't cast the searched-for value to a valid data point - // within the dtype of the Series we're searching, which means we - // will never find that value. - None - } else { - let cast_value = cast_value.flatten(); - $self.index_of(cast_value.map(|v| v.extract().unwrap())) - } - }}; -} - -fn index_of(series: &Series, value_series: &Series) -> PolarsResult<Option<usize>> { - let value_series = if value_series.dtype().is_null() { - // Should be able to cast null dtype to anything, so cast it to dtype of - // Series we're searching. - &value_series.cast(series.dtype())?
- } else { - value_series - }; - let value_dtype = value_series.dtype(); - - if value_dtype.is_signed_integer() { - let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0); - return Ok(downcast_as_macro_arg_physical!(series, try_index_of, value)); - } - if value_dtype.is_unsigned_integer() { - let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0); - return Ok(downcast_as_macro_arg_physical!(series, try_index_of, value)); - } - if value_dtype.is_float() { - let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0); - return Ok(downcast_as_macro_arg_physical!(series, try_index_of, value)); - } - // At this point we're done handling integers and floats. - match value_series.dtype() { - DataType::List(_) => { - let value = value_series - .list() - .unwrap() - .get(0) - .map(|arr| Series::from_arrow("".into(), arr).unwrap()); - Ok(series.list()?.index_of(value.as_ref())) - }, - #[cfg(feature="dtype-array")] - DataType::Array(_, _) => { - let value = value_series - .array() - .unwrap() - .get(0) - .map(|arr| Series::from_arrow("".into(), arr).unwrap()); - Ok(series.array()?.index_of(value.as_ref())) - }, - _ => unimplemented!("TODO"), - } -}