Skip to content

Commit

Permalink
Start refactoring to be in polars-ops and use expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Oct 23, 2024
1 parent 47ac42f commit 7f02952
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 68 deletions.
7 changes: 0 additions & 7 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub(crate) mod nulls;
mod reverse;
#[cfg(feature = "rolling_window")]
pub(crate) mod rolling_window;
mod search;
pub mod search_sorted;
mod set;
mod shift;
Expand Down Expand Up @@ -238,12 +237,6 @@ pub trait ChunkApply<'a, T> {
F: Fn(Option<T>, &S) -> S;
}

/// Search for an item.
pub trait ChunkSearch<'a, T> {
/// Return the index of the given value within self, or `None` if not found.
fn index_of(&'a self, value: Option<T>) -> Option<usize>;
}

/// Aggregation operations.
pub trait ChunkAgg<T> {
/// Aggregate the sum of the ChunkedArray.
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ mod rle;
mod rolling;
#[cfg(feature = "round_series")]
mod round;
// TODO add a feature?
mod search;
#[cfg(feature = "search_sorted")]
mod search_sorted;
#[cfg(feature = "to_dummies")]
Expand Down
77 changes: 77 additions & 0 deletions crates/polars-ops/src/series/ops/search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use polars_core::downcast_as_macro_arg_physical;
use polars_core::prelude::*;
use polars_utils::float::IsFloat;

use crate::series::ops::SeriesSealed;

/// Search for an item, typically in a ChunkedArray.
pub trait ChunkSearch<'a, T> {
/// Return the index of the given value within self, or `None` if not found.
fn index_of(&'a self, value: Option<T>) -> Option<usize>;
}


impl<'a, T> ChunkSearch<'a, T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn index_of(&'a self, value: Option<T::Native>) -> Option<usize> {
// A NaN is never equal to anything, including itself. But we still want
// to be able to search for NaNs, so we handle them specially.
if value.map(|v| v.is_nan()) == Some(true) {
return self
.iter()
.position(|opt_val| opt_val.map(|v| v.is_nan()) == Some(true));
}

self.iter().position(|opt_val| opt_val == value)
}
}

/// Try casting the value to the correct type, then call index_of().
macro_rules! try_index_of_numeric_ca {
($ca:expr, $value:expr) => {{
let ca = $ca;
let cast_value = $value.map(|v| AnyValue::from(v).strict_cast(ca.dtype()));
if cast_value == Some(None) {
// We can can't cast the searched-for value to a valid data point
// within the dtype of the Series we're searching, which means we
// will never find that value.
None
} else {
let cast_value = cast_value.flatten();
ca.index_of(cast_value.map(|v| v.extract().unwrap()))
}
}};
}

pub trait SearchSeries: SeriesSealed {
fn index_of(&self, value_series: &Series) -> PolarsResult<Option<usize>> {
let series = self.as_series();
let value_series = if value_series.dtype().is_null() {
// Should be able to cast null dtype to anything, so cast it to dtype of
// Series we're searching.
&value_series.cast(series.dtype())?
} else {
value_series
};
let value_dtype = value_series.dtype();

if value_dtype.is_signed_integer() {
let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0);
let result = downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value);
return Ok(result);
}
if value_dtype.is_unsigned_integer() {
let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value));
}
if value_dtype.is_float() {
let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(series, try_index_of_numeric_ca, value));
}
// At this point we're done handling integers and floats.
unimplemented!("TODO")
}
}

61 changes: 0 additions & 61 deletions crates/polars-python/src/series/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,64 +150,3 @@ impl PySeries {
index_of(&self.series, &value.series).map_err(|e| PyErr::from(PyPolarsErr::from(e)))
}
}

/// Try casting the value to the correct type, then call index_of().
macro_rules! try_index_of {
($self:expr, $value:expr) => {{
let cast_value = $value.map(|v| AnyValue::from(v).strict_cast($self.dtype()));
if cast_value == Some(None) {
// We can can't cast the searched-for value to a valid data point
// within the dtype of the Series we're searching, which means we
// will never find that value.
None
} else {
let cast_value = cast_value.flatten();
$self.index_of(cast_value.map(|v| v.extract().unwrap()))
}
}};
}

fn index_of(series: &Series, value_series: &Series) -> PolarsResult<Option<usize>> {
let value_series = if value_series.dtype().is_null() {
// Should be able to cast null dtype to anything, so cast it to dtype of
// Series we're searching.
&value_series.cast(series.dtype())?
} else {
value_series
};
let value_dtype = value_series.dtype();

if value_dtype.is_signed_integer() {
let value = value_series.cast(&DataType::Int64)?.i64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(series, try_index_of, value));
}
if value_dtype.is_unsigned_integer() {
let value = value_series.cast(&DataType::UInt64)?.u64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(series, try_index_of, value));
}
if value_dtype.is_float() {
let value = value_series.cast(&DataType::Float64)?.f64().unwrap().get(0);
return Ok(downcast_as_macro_arg_physical!(series, try_index_of, value));
}
// At this point we're done handling integers and floats.
match value_series.dtype() {
DataType::List(_) => {
let value = value_series
.list()
.unwrap()
.get(0)
.map(|arr| Series::from_arrow("".into(), arr).unwrap());
Ok(series.list()?.index_of(value.as_ref()))
},
#[cfg(feature="dtype-array")]
DataType::Array(_, _) => {
let value = value_series
.array()
.unwrap()
.get(0)
.map(|arr| Series::from_arrow("".into(), arr).unwrap());
Ok(series.array()?.index_of(value.as_ref()))
},
_ => unimplemented!("TODO"),
}
}

0 comments on commit 7f02952

Please sign in to comment.