Implement Eq, PartialEq, Hash for dyn PhysicalExpr #13005

Merged
merged 7 commits into from
Nov 5, 2024
Changes from 6 commits
3 changes: 1 addition & 2 deletions datafusion/core/tests/sql/path_partition.rs
@@ -47,7 +47,6 @@ use bytes::Bytes;
use chrono::{TimeZone, Utc};
use datafusion_expr::{col, lit, Expr, Operator};
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
-use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{self, BoxStream};
use object_store::{
path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta,
@@ -97,7 +96,7 @@ async fn parquet_partition_pruning_filter() -> Result<()> {
assert!(pred.as_any().is::<BinaryExpr>());
let pred = pred.as_any().downcast_ref::<BinaryExpr>().unwrap();

-assert_eq!(pred, expected.as_any());
+assert_eq!(pred, expected.as_ref());

Ok(())
}
84 changes: 37 additions & 47 deletions datafusion/physical-expr-common/src/physical_expr.rs
@@ -52,7 +52,7 @@ use datafusion_expr_common::sort_properties::ExprProperties;
/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
-pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
+pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
/// Returns the physical expression as [`Any`] so that it can be
/// downcast to a specific implementation.
fn as_any(&self) -> &dyn Any;
@@ -141,38 +141,6 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
Ok(Some(vec![]))
}

-/// Update the hash `state` with this expression requirements from
-/// [`Hash`].
-///
-/// This method is required to support hashing [`PhysicalExpr`]s. To
-/// implement it, typically the type implementing
-/// [`PhysicalExpr`] implements [`Hash`] and
-/// then the following boiler plate is used:
-///
-/// # Example:
-/// ```
-/// // User defined expression that derives Hash
-/// #[derive(Hash, Debug, PartialEq, Eq)]
-/// struct MyExpr {
-/// val: u64
-/// }
-///
-/// // impl PhysicalExpr {
-/// // ...
-/// # impl MyExpr {
-/// // Boiler plate to call the derived Hash impl
-/// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
-/// use std::hash::Hash;
-/// let mut s = state;
-/// self.hash(&mut s);
-/// }
-/// // }
-/// # }
-/// ```
-/// Note: [`PhysicalExpr`] is not constrained by [`Hash`]
-/// directly because it must remain object safe.
-fn dyn_hash(&self, _state: &mut dyn Hasher);

/// Calculates the properties of this [`PhysicalExpr`] based on its
/// children's properties (i.e. order and range), recursively aggregating
/// the information from its children. In cases where the [`PhysicalExpr`]
@@ -183,6 +151,42 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
}
}

+/// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object
+/// safe. To ease implementation, a blanket implementation is provided for [`Eq`] types.
+pub trait DynEq {
Contributor
Maybe we can add some documentation here explaining why this is needed and what it is used for.

Contributor Author
Added in f1917fa.

+fn dyn_eq(&self, other: &dyn Any) -> bool;
+}
+
+impl<T: Eq + Any> DynEq for T {
+fn dyn_eq(&self, other: &dyn Any) -> bool {
+other
+.downcast_ref::<Self>()
+.map_or(false, |other| other == self)
+}
+}
+
+impl PartialEq for dyn PhysicalExpr {
+fn eq(&self, other: &Self) -> bool {
+self.dyn_eq(other.as_any())
+}
+}
+
+impl Eq for dyn PhysicalExpr {}
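A minimal, self-contained sketch of the equality pattern in this hunk may help when reading the diff. `Expr`, `Col`, `Lit`, `col`, `lit`, and `exprs_equal` are hypothetical stand-ins introduced only for illustration (the real trait is `PhysicalExpr`); the `DynEq` trait and its blanket impl follow the lines added here.

```rust
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

/// Object-safe equality hook, mirroring the `DynEq` trait in this hunk.
trait DynEq {
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

/// Any `Eq + 'static` type gets `DynEq` for free.
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means `other` has a different concrete type.
        other
            .downcast_ref::<Self>()
            .map_or(false, |other| other == self)
    }
}

/// Hypothetical, simplified stand-in for `PhysicalExpr`.
trait Expr: Debug + DynEq {
    fn as_any(&self) -> &dyn Any;
}

impl PartialEq for dyn Expr {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other.as_any())
    }
}

impl Eq for dyn Expr {}

/// Hypothetical column expression.
#[derive(Debug, PartialEq, Eq)]
struct Col(usize);

impl Expr for Col {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

/// Hypothetical literal expression.
#[derive(Debug, PartialEq, Eq)]
struct Lit(i64);

impl Expr for Lit {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn col(index: usize) -> Arc<dyn Expr> {
    Arc::new(Col(index))
}

fn lit(value: i64) -> Arc<dyn Expr> {
    Arc::new(Lit(value))
}

/// Compares two trait objects through `impl PartialEq for dyn Expr`.
fn exprs_equal(a: &Arc<dyn Expr>, b: &Arc<dyn Expr>) -> bool {
    a.as_ref() == b.as_ref()
}
```

With these definitions, `exprs_equal(&col(0), &col(0))` is true, while comparing `col(0)` with `lit(0)` is false because the downcast to `Col` fails. Implementors only derive `PartialEq` and `Eq`; no per-type `PartialEq<dyn Any>` impl is needed.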

+/// [`PhysicalExpr`] can't be constrained by [`Hash`] directly because it must remain
+/// object safe. To ease implementation, a blanket implementation is provided for [`Hash`]
+/// types.
+pub trait DynHash {
+fn dyn_hash(&self, _state: &mut dyn Hasher);
+}
+
+impl<T: Hash + Any> DynHash for T {
+fn dyn_hash(&self, mut state: &mut dyn Hasher) {
+self.type_id().hash(&mut state);
+self.hash(&mut state)
+}
+}
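The hashing side can be sketched the same way. This is an illustration under assumptions rather than DataFusion code: `Expr`, `Col`, `col`, and `hash_of` are hypothetical names, and `DefaultHasher` is used only to obtain a concrete hash value.

```rust
use std::any::Any;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Object-safe hashing hook, mirroring the `DynHash` trait in this hunk.
trait DynHash {
    fn dyn_hash(&self, state: &mut dyn Hasher);
}

/// Any `Hash + 'static` type gets `DynHash` for free.
impl<T: Hash + Any> DynHash for T {
    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
        // Feed the concrete type id first, so two different types with
        // identical fields do not feed identical input to the hasher.
        self.type_id().hash(&mut state);
        self.hash(&mut state)
    }
}

/// Hypothetical, simplified stand-in for `PhysicalExpr`.
trait Expr: DynHash {}

impl Hash for dyn Expr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // `&mut H` coerces to `&mut dyn Hasher` here.
        self.dyn_hash(state);
    }
}

/// Hypothetical column expression that simply derives `Hash`.
#[derive(Hash)]
struct Col(usize);

impl Expr for Col {}

fn col(index: usize) -> Arc<dyn Expr> {
    Arc::new(Col(index))
}

/// Hashes a trait object through `impl Hash for dyn Expr`.
fn hash_of(expr: &Arc<dyn Expr>) -> u64 {
    let mut hasher = DefaultHasher::new();
    expr.hash(&mut hasher);
    hasher.finish()
}
```

Two calls to `hash_of(&col(1))` return the same value, which is what hash-based collections rely on. The per-type `dyn_hash` boilerplate removed elsewhere in this PR is no longer needed, because the blanket impl supplies it.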

impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
@@ -210,20 +214,6 @@ pub fn with_new_children_if_necessary(
}
}

-pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
-if any.is::<Arc<dyn PhysicalExpr>>() {
-any.downcast_ref::<Arc<dyn PhysicalExpr>>()
-.unwrap()
-.as_any()
-} else if any.is::<Box<dyn PhysicalExpr>>() {
-any.downcast_ref::<Box<dyn PhysicalExpr>>()
-.unwrap()
-.as_any()
-} else {
-any
-}
-}

/// Returns a [`Display`]able list of [`PhysicalExpr`]
///
/// Example output: `[a + 1, b]`
4 changes: 0 additions & 4 deletions datafusion/physical-expr-common/src/sort_expr.rs
@@ -56,14 +56,10 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {todo!() }
/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {todo!()}
/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> {todo!()}
-/// # fn dyn_hash(&self, _state: &mut dyn Hasher) {todo!()}
/// # }
/// # impl Display for MyPhysicalExpr {
/// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") }
/// # }
-/// # impl PartialEq<dyn Any> for MyPhysicalExpr {
-/// # fn eq(&self, _other: &dyn Any) -> bool { true }
-/// # }
/// # fn col(name: &str) -> Arc<dyn PhysicalExpr> { Arc::new(MyPhysicalExpr) }
/// // Sort by a ASC
/// let options = SortOptions::default();
7 changes: 3 additions & 4 deletions datafusion/physical-expr/src/equivalence/class.rs
@@ -67,8 +67,7 @@ pub struct ConstExpr {

impl PartialEq for ConstExpr {
fn eq(&self, other: &Self) -> bool {
-self.across_partitions == other.across_partitions
-&& self.expr.eq(other.expr.as_any())
+self.across_partitions == other.across_partitions && self.expr.eq(&other.expr)
}
}

@@ -121,7 +120,7 @@ impl ConstExpr {

/// Returns true if this constant expression is equal to the given expression
pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool {
-self.expr.eq(other.as_ref().as_any())
+self.expr.as_ref() == other.as_ref()
}

/// Returns a [`Display`]able list of `ConstExpr`.
@@ -557,7 +556,7 @@ impl EquivalenceGroup {
new_classes.push((source, vec![Arc::clone(target)]));
}
if let Some((_, values)) =
-new_classes.iter_mut().find(|(key, _)| key.eq(source))
+new_classes.iter_mut().find(|(key, _)| *key == source)
{
if !physical_exprs_contains(values, target) {
values.push(Arc::clone(target));
42 changes: 20 additions & 22 deletions datafusion/physical-expr/src/expressions/binary.rs
@@ -17,11 +17,10 @@

mod kernels;

-use std::hash::{Hash, Hasher};
+use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
-use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
@@ -48,7 +47,7 @@ use kernels::{
};

/// Binary expression
-#[derive(Debug, Hash, Clone)]
+#[derive(Debug, Clone, Eq)]
pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
@@ -57,6 +56,24 @@ pub struct BinaryExpr {
fail_on_overflow: bool,
}

+// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808
+impl PartialEq for BinaryExpr {
+fn eq(&self, other: &Self) -> bool {
+self.left.eq(&other.left)
+&& self.op.eq(&other.op)
+&& self.right.eq(&other.right)
+&& self.fail_on_overflow.eq(&other.fail_on_overflow)
+}
+}
+impl Hash for BinaryExpr {
+fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+self.left.hash(state);
+self.op.hash(state);
+self.right.hash(state);
+self.fail_on_overflow.hash(state);
+}
+}
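To show what the manual impls buy, here is a compact sketch of a composite expression stored in a `HashSet`. It is an illustration under assumptions: `Expr`, `Col`, `Plus`, `col`, `plus`, and `distinct_count` are hypothetical names, and the trait is a reduced stand-in for `PhysicalExpr` that combines the `DynEq` and `DynHash` pattern from this PR.

```rust
use std::any::Any;
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

trait DynEq {
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}
impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        other.downcast_ref::<Self>().map_or(false, |o| o == self)
    }
}

trait DynHash {
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
impl<T: Hash + Any> DynHash for T {
    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
        self.type_id().hash(&mut state);
        self.hash(&mut state)
    }
}

/// Hypothetical, simplified stand-in for `PhysicalExpr`.
trait Expr: Debug + DynEq + DynHash {
    fn as_any(&self) -> &dyn Any;
}
impl PartialEq for dyn Expr {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other.as_any())
    }
}
impl Eq for dyn Expr {}
impl Hash for dyn Expr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.dyn_hash(state);
    }
}

/// Leaf expression: plain derives are enough.
#[derive(Debug, PartialEq, Eq, Hash)]
struct Col(usize);
impl Expr for Col {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

/// Composite expression with singly wrapped trait objects, so `PartialEq`
/// and `Hash` are written by hand, as this PR does for `BinaryExpr`.
#[derive(Debug, Eq)]
struct Plus {
    left: Arc<dyn Expr>,
    right: Arc<dyn Expr>,
}
impl PartialEq for Plus {
    fn eq(&self, other: &Self) -> bool {
        self.left.eq(&other.left) && self.right.eq(&other.right)
    }
}
impl Hash for Plus {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.left.hash(state);
        self.right.hash(state);
    }
}
impl Expr for Plus {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn col(index: usize) -> Arc<dyn Expr> {
    Arc::new(Col(index))
}

fn plus(left: Arc<dyn Expr>, right: Arc<dyn Expr>) -> Arc<dyn Expr> {
    Arc::new(Plus { left, right })
}

/// Counts structurally distinct expressions using a hash set.
fn distinct_count(exprs: Vec<Arc<dyn Expr>>) -> usize {
    exprs.into_iter().collect::<HashSet<_>>().len()
}
```

For example, `distinct_count(vec![plus(col(0), col(1)), plus(col(0), col(1)), col(0)])` returns 2, since the two `Plus` values compare equal and hash identically.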

impl BinaryExpr {
/// Create new binary expression
pub fn new(
@@ -477,11 +494,6 @@ impl PhysicalExpr for BinaryExpr {
}
}

-fn dyn_hash(&self, state: &mut dyn Hasher) {
-let mut s = state;
-self.hash(&mut s);
-}

/// For each operator, [`BinaryExpr`] has distinct rules.
/// TODO: There may be rules specific to some data types and expression ranges.
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
@@ -525,20 +537,6 @@ impl PhysicalExpr for BinaryExpr {
}
}

-impl PartialEq<dyn Any> for BinaryExpr {
Contributor
Removing this is great.

-fn eq(&self, other: &dyn Any) -> bool {
-down_cast_any_ref(other)
-.downcast_ref::<Self>()
-.map(|x| {
-self.left.eq(&x.left)
-&& self.op == x.op
-&& self.right.eq(&x.right)
-&& self.fail_on_overflow.eq(&x.fail_on_overflow)
-})
-.unwrap_or(false)
-}
-}

/// Casts dictionary array to result type for binary numerical operators. Such operators
/// between array and scalar produce a dictionary array other than primitive array of the
/// same operators between array and array. This leads to inconsistent result types causing
40 changes: 3 additions & 37 deletions datafusion/physical-expr/src/expressions/case.rs
@@ -16,11 +16,10 @@
// under the License.

use std::borrow::Cow;
-use std::hash::{Hash, Hasher};
+use std::hash::Hash;
use std::{any::Any, sync::Arc};

use crate::expressions::try_cast;
-use crate::physical_expr::down_cast_any_ref;
use crate::PhysicalExpr;

use arrow::array::*;
@@ -37,7 +36,7 @@ use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);

-#[derive(Debug, Hash)]
+#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
/// CASE WHEN condition THEN result
/// [WHEN ...]
@@ -80,7 +79,7 @@ enum EvalMethod {
/// [WHEN ...]
/// [ELSE result]
/// END
-#[derive(Debug, Hash)]
+#[derive(Debug, Hash, PartialEq, Eq)]
Contributor
How can it permit auto-derivation of `PartialEq` for `CaseExpr`? 🤔 We need manual impls for the others.

Contributor Author
Yeah, this is probably the strangest part of the bug. If you have double wrappers around `dyn Trait`, then there is no issue: rust-lang/rust#78808 (comment)
`CaseExpr` doesn't have any `dyn Trait` fields with a single wrapper, so the derive macro has no problem.

Contributor
Should we create a ticket to remove those impls once the bug is resolved?

Contributor Author
How about #13196?
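The distinction discussed in this thread can be sketched as follows. This is an illustration under assumptions, not code from the PR: `Expr`, `Col`, `Case`, `Binary`, and the helper constructors are hypothetical, and equality on `dyn Expr` is reduced to comparing an `id` to keep the example short.

```rust
use std::sync::Arc;

/// Hypothetical trait; equality on the trait object compares ids only.
trait Expr {
    fn id(&self) -> u32;
}

impl PartialEq for dyn Expr {
    fn eq(&self, other: &Self) -> bool {
        self.id() == other.id()
    }
}

struct Col(u32);

impl Expr for Col {
    fn id(&self) -> u32 {
        self.0
    }
}

fn col(id: u32) -> Arc<dyn Expr> {
    Arc::new(Col(id))
}

/// Every trait object here sits behind two wrappers (`Option<Arc<..>>`,
/// `Vec<(Arc<..>, Arc<..>)>`), the shape this thread reports as unaffected
/// by rust-lang/rust#78808, so the derive is used.
#[derive(PartialEq)]
struct Case {
    base: Option<Arc<dyn Expr>>,
    when_then: Vec<(Arc<dyn Expr>, Arc<dyn Expr>)>,
}

/// Fields with a single wrapper (`Arc<dyn Expr>`): `PartialEq` is written
/// by hand, following the workaround applied to `BinaryExpr`.
struct Binary {
    left: Arc<dyn Expr>,
    right: Arc<dyn Expr>,
}

impl PartialEq for Binary {
    fn eq(&self, other: &Self) -> bool {
        self.left.eq(&other.left) && self.right.eq(&other.right)
    }
}

fn case(base: u32, when: u32, then: u32) -> Case {
    Case {
        base: Some(col(base)),
        when_then: vec![(col(when), col(then))],
    }
}

fn binary(left: u32, right: u32) -> Binary {
    Binary {
        left: col(left),
        right: col(right),
    }
}
```

Both structs end up with the same field-by-field comparison; only the way the impl is produced differs. Once the upstream bug is fixed, the hand-written impl on `Binary` could presumably be replaced by a derive, which is what the follow-up ticket tracks.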

pub struct CaseExpr {
/// Optional base expression that can be compared to literal values in the "when" expressions
expr: Option<Arc<dyn PhysicalExpr>>,
@@ -506,39 +505,6 @@ impl PhysicalExpr for CaseExpr {
)?))
}
}

-fn dyn_hash(&self, state: &mut dyn Hasher) {
-let mut s = state;
-self.hash(&mut s);
-}
}

-impl PartialEq<dyn Any> for CaseExpr {
-fn eq(&self, other: &dyn Any) -> bool {
-down_cast_any_ref(other)
-.downcast_ref::<Self>()
-.map(|x| {
-let expr_eq = match (&self.expr, &x.expr) {
-(Some(expr1), Some(expr2)) => expr1.eq(expr2),
-(None, None) => true,
-_ => false,
-};
-let else_expr_eq = match (&self.else_expr, &x.else_expr) {
-(Some(expr1), Some(expr2)) => expr1.eq(expr2),
-(None, None) => true,
-_ => false,
-};
-expr_eq
-&& else_expr_eq
-&& self.when_then_expr.len() == x.when_then_expr.len()
-&& self.when_then_expr.iter().zip(x.when_then_expr.iter()).all(
-|((when1, then1), (when2, then2))| {
-when1.eq(when2) && then1.eq(then2)
-},
-)
-})
-.unwrap_or(false)
-}
-}

/// Create a CASE expression