Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Native selector XOR set operation, guarantee consistent selector column-order #16833

Merged
merged 4 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions crates/polars-plan/src/dsl/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ impl MetaNameSpace {
}
}

pub fn _selector_and(self, other: Expr) -> PolarsResult<Expr> {
if let Expr::Selector(mut s) = self.0 {
if let Expr::Selector(s_other) = other {
s = s.bitand(s_other);
} else {
s = s.bitand(Selector::Root(Box::new(other)))
}
Ok(Expr::Selector(s))
} else {
polars_bail!(ComputeError: "expected selector, got {:?}", self.0)
}
}

pub fn _selector_sub(self, other: Expr) -> PolarsResult<Expr> {
if let Expr::Selector(mut s) = self.0 {
if let Expr::Selector(s_other) = other {
Expand All @@ -122,12 +135,12 @@ impl MetaNameSpace {
}
}

pub fn _selector_and(self, other: Expr) -> PolarsResult<Expr> {
pub fn _selector_xor(self, other: Expr) -> PolarsResult<Expr> {
if let Expr::Selector(mut s) = self.0 {
if let Expr::Selector(s_other) = other {
s = s.bitand(s_other);
s = s ^ s_other;
} else {
s = s.bitand(Selector::Root(Box::new(other)))
s = s ^ Selector::Root(Box::new(other))
}
Ok(Expr::Selector(s))
} else {
Expand Down
24 changes: 17 additions & 7 deletions crates/polars-plan/src/dsl/selector.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, BitAnd, Sub};
use std::ops::{Add, BitAnd, BitXor, Sub};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand All @@ -10,6 +10,7 @@ use super::*;
pub enum Selector {
Add(Box<Selector>, Box<Selector>),
Sub(Box<Selector>, Box<Selector>),
ExclusiveOr(Box<Selector>, Box<Selector>),
InterSect(Box<Selector>, Box<Selector>),
Root(Box<Expr>),
}
Expand All @@ -29,20 +30,29 @@ impl Add for Selector {
}
}

impl Sub for Selector {
impl BitAnd for Selector {
type Output = Selector;

#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self::Output {
Selector::Sub(Box::new(self), Box::new(rhs))
fn bitand(self, rhs: Self) -> Self::Output {
Selector::InterSect(Box::new(self), Box::new(rhs))
}
}

impl BitAnd for Selector {
impl BitXor for Selector {
type Output = Selector;

#[allow(clippy::suspicious_arithmetic_impl)]
fn bitand(self, rhs: Self) -> Self::Output {
Selector::InterSect(Box::new(self), Box::new(rhs))
fn bitxor(self, rhs: Self) -> Self::Output {
Selector::ExclusiveOr(Box::new(self), Box::new(rhs))
}
}

impl Sub for Selector {
type Output = Selector;

#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self::Output {
Selector::Sub(Box::new(self), Box::new(rhs))
}
}
66 changes: 42 additions & 24 deletions crates/polars-plan/src/logical_plan/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
//! this contains code used for rewriting projections, expanding wildcards, regex selection etc.
use std::ops::BitXor;

use super::*;

pub(crate) fn prepare_projection(
Expand Down Expand Up @@ -787,11 +789,27 @@ fn replace_selector_inner(
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;
members.extend(rhs_members)
},
Selector::ExclusiveOr(lhs, rhs) => {
let mut lhs_members = Default::default();
replace_selector_inner(*lhs, &mut lhs_members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

let xor_members = lhs_members.bitxor(&rhs_members);
*members = xor_members;
},
Selector::InterSect(lhs, rhs) => {
replace_selector_inner(*lhs, members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

*members = members.intersection(&rhs_members).cloned().collect()
},
Selector::Sub(lhs, rhs) => {
// fill lhs
replace_selector_inner(*lhs, members, scratch, schema, keys)?;

// subtract rhs
let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

Expand All @@ -801,19 +819,8 @@ fn replace_selector_inner(
new_members.insert(e);
}
}

*members = new_members;
},
Selector::InterSect(lhs, rhs) => {
// fill lhs
replace_selector_inner(*lhs, members, scratch, schema, keys)?;

// fill rhs
let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

*members = members.intersection(&rhs_members).cloned().collect()
},
}
Ok(())
}
Expand All @@ -829,17 +836,28 @@ fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult<
let mut members = PlIndexSet::new();
replace_selector_inner(swapped, &mut members, &mut vec![], schema, keys)?;

Ok(Expr::Columns(
members
.into_iter()
.map(|e| {
let Expr::Column(name) = e else {
unreachable!()
};
name
})
.collect(),
))
if members.len() <= 1 {
Ok(Expr::Columns(
members
.into_iter()
.map(|e| {
let Expr::Column(name) = e else {
unreachable!()
};
name
})
.collect(),
))
} else {
// Ensure that multiple columns returned from combined/nested selectors remain in schema order
let selected = schema
.iter_fields()
.map(|field| ColumnName::from(field.name().as_ref()))
.filter(|field_name| members.contains(&Expr::Column(field_name.clone())))
.collect();

Ok(Expr::Columns(selected))
}
},
e => Ok(e),
})
Expand Down
12 changes: 11 additions & 1 deletion py-polars/docs/source/reference/selectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ Importing
Set operations
--------------

Selectors support ``set`` operations such as:
Selectors support the following ``set`` operations:

- UNION: ``A | B``
- INTERSECTION: ``A & B``
- DIFFERENCE: ``A - B``
- EXCLUSIVE OR: ``A ^ B``
- COMPLEMENT: ``~A``

Note that both individual selector results and selector set operations will always return
matching columns in the same order as the underlying frame schema.

Examples
========
Expand Down Expand Up @@ -88,6 +91,13 @@ Examples
"Lmn": pl.Duration,
}

# Select the EXCLUSIVE OR of numeric columns and columns that contain an "e"
assert df.select(cs.contains("e") ^ cs.numeric()).schema == {
"abc": UInt16,
"bbb": UInt32,
"eee": Boolean,
}

# Select the COMPLEMENT of all columns of dtypes Duration and Time
assert df.select(~cs.by_dtype([pl.Duration, pl.Time])).schema == {
"abc": pl.UInt16,
Expand Down
14 changes: 9 additions & 5 deletions py-polars/polars/expr/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,20 @@ def _as_selector(self) -> Expr:
return wrap_expr(self._pyexpr._meta_as_selector())

def _selector_add(self, other: Expr) -> Expr:
"""Add selectors."""
"""Add ('+') selectors."""
return wrap_expr(self._pyexpr._meta_selector_add(other._pyexpr))

def _selector_and(self, other: Expr) -> Expr:
"""And ('&') selectors."""
return wrap_expr(self._pyexpr._meta_selector_and(other._pyexpr))

def _selector_sub(self, other: Expr) -> Expr:
"""Subtract selectors."""
"""Subtract ('-') selectors."""
return wrap_expr(self._pyexpr._meta_selector_sub(other._pyexpr))

def _selector_and(self, other: Expr) -> Expr:
"""& selectors."""
return wrap_expr(self._pyexpr._meta_selector_and(other._pyexpr))
def _selector_xor(self, other: Expr) -> Expr:
"""Xor ('^') selectors."""
return wrap_expr(self._pyexpr._meta_selector_xor(other._pyexpr))

@overload
def serialize(self, file: None = ...) -> str: ...
Expand Down
35 changes: 29 additions & 6 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,15 @@ def __repr__(self) -> str:
elif hasattr(self, "_repr_override"):
return self._repr_override
else:
selector_name, params = self._attrs["name"], self._attrs["params"]
set_ops = {"and": "&", "or": "|", "sub": "-"}
selector_name, params = self._attrs["name"], self._attrs["params"] or {}
set_ops = {"and": "&", "or": "|", "sub": "-", "xor": "^"}
if selector_name in set_ops:
op = set_ops[selector_name]
return "({})".format(f" {op} ".join(repr(p) for p in params.values()))
else:
str_params = ", ".join(
(repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}")
for k, v in (params or {}).items()
for k, v in params.items()
).rstrip(",")
return f"cs.{selector_name}({str_params})"

Expand Down Expand Up @@ -381,6 +381,24 @@ def __or__(self, other: Any) -> SelectorType | Expr:
else:
return self.as_expr().__or__(other)

@overload # type: ignore[override]
def __xor__(self, other: SelectorType) -> SelectorType: ...

@overload
def __xor__(self, other: Any) -> Expr: ...

def __xor__(self, other: Any) -> SelectorType | Expr:
if is_column(other):
other = by_name(other.meta.output_name())
if is_selector(other):
return _selector_proxy_(
self.meta._as_selector().meta._selector_xor(other),
parameters={"self": self, "other": other},
name="xor",
)
else:
return self.as_expr().__or__(other)

def __rand__(self, other: Any) -> Expr: # type: ignore[override]
if is_column(other):
colname = other.meta.output_name()
Expand All @@ -396,6 +414,11 @@ def __ror__(self, other: Any) -> Expr: # type: ignore[override]
other = by_name(other.meta.output_name())
return self.as_expr().__ror__(other)

def __rxor__(self, other: Any) -> Expr: # type: ignore[override]
if is_column(other):
other = by_name(other.meta.output_name())
return self.as_expr().__rxor__(other)

def as_expr(self) -> Expr:
"""
Materialize the `selector` as a normal expression.
Expand Down Expand Up @@ -1149,7 +1172,7 @@ def categorical() -> SelectorType:
return _selector_proxy_(F.col(Categorical), name="categorical")


def contains(substring: str | Collection[str]) -> SelectorType:
def contains(*substring: str) -> SelectorType:
"""
Select columns whose names contain the given literal substring(s).

Expand Down Expand Up @@ -1191,7 +1214,7 @@ def contains(substring: str | Collection[str]) -> SelectorType:

Select columns that contain the substring 'ba' or the letter 'z':

>>> df.select(cs.contains(("ba", "z")))
>>> df.select(cs.contains("ba", "z"))
shape: (2, 3)
┌─────┬─────┬───────┐
│ bar ┆ baz ┆ zap │
Expand Down Expand Up @@ -1221,7 +1244,7 @@ def contains(substring: str | Collection[str]) -> SelectorType:
return _selector_proxy_(
F.col(raw_params),
name="contains",
parameters={"substring": escaped_substring},
parameters={"*substring": escaped_substring},
)


Expand Down
14 changes: 12 additions & 2 deletions py-polars/src/expr/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ impl PyExpr {
Ok(out.into())
}

fn _meta_selector_and(&self, other: PyExpr) -> PyResult<PyExpr> {
let out = self
.inner
.clone()
.meta()
._selector_and(other.inner)
.map_err(PyPolarsErr::from)?;
Ok(out.into())
}

fn _meta_selector_sub(&self, other: PyExpr) -> PyResult<PyExpr> {
let out = self
.inner
Expand All @@ -81,12 +91,12 @@ impl PyExpr {
Ok(out.into())
}

fn _meta_selector_and(&self, other: PyExpr) -> PyResult<PyExpr> {
fn _meta_selector_xor(&self, other: PyExpr) -> PyResult<PyExpr> {
let out = self
.inner
.clone()
.meta()
._selector_and(other.inner)
._selector_xor(other.inner)
.map_err(PyPolarsErr::from)?;
Ok(out.into())
}
Expand Down
Loading