Skip to content

Commit b0d7cd0

Browse files
Preserve ordering equivalencies on with_reorder (#13770)
* Preserve ordering equivalencies on `with_reorder` * Add assertions * Return early if filtered_exprs is empty * Add clarify comment * Refactor * Add comprehensive test case * Add comment for exprs_equal * Cargo fmt * Clippy fix * Update properties.rs * Update exprs_equal and add tests * Update properties.rs --------- Co-authored-by: berkaysynnada <[email protected]>
1 parent f3b1141 commit b0d7cd0

File tree

2 files changed

+494
-4
lines changed

2 files changed

+494
-4
lines changed

datafusion/physical-expr/src/equivalence/class.rs

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,59 @@ impl EquivalenceGroup {
626626
JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(),
627627
}
628628
}
629+
630+
/// Checks if two expressions are equal either directly or through equivalence classes.
631+
/// For complex expressions (e.g. a + b), checks that the expression trees are structurally
632+
/// identical and their leaf nodes are equivalent either directly or through equivalence classes.
633+
pub fn exprs_equal(
634+
&self,
635+
left: &Arc<dyn PhysicalExpr>,
636+
right: &Arc<dyn PhysicalExpr>,
637+
) -> bool {
638+
// Direct equality check
639+
if left.eq(right) {
640+
return true;
641+
}
642+
643+
// Check if expressions are equivalent through equivalence classes
644+
// We need to check both directions since expressions might be in different classes
645+
if let Some(left_class) = self.get_equivalence_class(left) {
646+
if left_class.contains(right) {
647+
return true;
648+
}
649+
}
650+
if let Some(right_class) = self.get_equivalence_class(right) {
651+
if right_class.contains(left) {
652+
return true;
653+
}
654+
}
655+
656+
// For non-leaf nodes, check structural equality
657+
let left_children = left.children();
658+
let right_children = right.children();
659+
660+
// If either expression is a leaf node and we haven't found equality yet,
661+
// they must be different
662+
if left_children.is_empty() || right_children.is_empty() {
663+
return false;
664+
}
665+
666+
// Type equality check through reflection
667+
if left.as_any().type_id() != right.as_any().type_id() {
668+
return false;
669+
}
670+
671+
// Check if the number of children is the same
672+
if left_children.len() != right_children.len() {
673+
return false;
674+
}
675+
676+
// Check if all children are equal
677+
left_children
678+
.into_iter()
679+
.zip(right_children)
680+
.all(|(left_child, right_child)| self.exprs_equal(left_child, right_child))
681+
}
629682
}
630683

631684
impl Display for EquivalenceGroup {
@@ -647,9 +700,10 @@ mod tests {
647700

648701
use super::*;
649702
use crate::equivalence::tests::create_test_params;
650-
use crate::expressions::{lit, Literal};
703+
use crate::expressions::{lit, BinaryExpr, Literal};
651704

652705
use datafusion_common::{Result, ScalarValue};
706+
use datafusion_expr::Operator;
653707

654708
#[test]
655709
fn test_bridge_groups() -> Result<()> {
@@ -777,4 +831,159 @@ mod tests {
777831
assert!(!cls1.contains_any(&cls3));
778832
assert!(!cls2.contains_any(&cls3));
779833
}
834+
835+
#[test]
fn test_exprs_equal() -> Result<()> {
    /// One comparison scenario: the two operands, the expected verdict and a
    /// label that is printed when the assertion fails.
    struct Case {
        lhs: Arc<dyn PhysicalExpr>,
        rhs: Arc<dyn PhysicalExpr>,
        want: bool,
        label: &'static str,
    }

    /// Builds a binary expression over borrowed operands as a trait object.
    fn binary(
        lhs: &Arc<dyn PhysicalExpr>,
        op: Operator,
        rhs: &Arc<dyn PhysicalExpr>,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(Arc::clone(lhs), op, Arc::clone(rhs)))
    }

    // Columns used as leaves
    let col_a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
    let col_b: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
    let col_x: Arc<dyn PhysicalExpr> = Arc::new(Column::new("x", 2));
    let col_y: Arc<dyn PhysicalExpr> = Arc::new(Column::new("y", 3));

    // Literals used as leaves
    let lit_1: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
    let lit_2: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Int32(Some(2))));

    // Equivalence group holding the classes (a = x) and (b = y)
    let class_ax = EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]);
    let class_by = EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]);
    let eq_group = EquivalenceGroup::new(vec![class_ax, class_by]);

    // Composite expressions shared by the scenarios below
    let a_plus_b = binary(&col_a, Operator::Plus, &col_b);
    let x_plus_y = binary(&col_x, Operator::Plus, &col_y);

    let cases = vec![
        // Basic equality
        Case {
            lhs: Arc::clone(&col_a),
            rhs: Arc::clone(&col_a),
            want: true,
            label: "Same column should be equal",
        },
        // Equivalence classes
        Case {
            lhs: Arc::clone(&col_a),
            rhs: Arc::clone(&col_x),
            want: true,
            label: "Columns in same equivalence class should be equal",
        },
        Case {
            lhs: Arc::clone(&col_b),
            rhs: Arc::clone(&col_y),
            want: true,
            label: "Columns in same equivalence class should be equal",
        },
        Case {
            lhs: Arc::clone(&col_a),
            rhs: Arc::clone(&col_b),
            want: false,
            label: "Columns in different equivalence classes should not be equal",
        },
        // Literals
        Case {
            lhs: Arc::clone(&lit_1),
            rhs: Arc::clone(&lit_1),
            want: true,
            label: "Same literal should be equal",
        },
        Case {
            lhs: Arc::clone(&lit_1),
            rhs: Arc::clone(&lit_2),
            want: false,
            label: "Different literals should not be equal",
        },
        // Composite expressions
        Case {
            lhs: Arc::clone(&a_plus_b),
            rhs: Arc::clone(&x_plus_y),
            want: true,
            label: "Binary expressions with equivalent operands should be equal",
        },
        Case {
            lhs: Arc::clone(&a_plus_b),
            rhs: binary(&col_x, Operator::Plus, &col_a),
            want: false,
            label: "Binary expressions with non-equivalent operands should not be equal",
        },
        Case {
            lhs: binary(&col_a, Operator::Plus, &lit_1),
            rhs: binary(&col_x, Operator::Plus, &lit_1),
            want: true,
            label: "Binary expressions with equivalent column and same literal should be equal",
        },
        Case {
            lhs: binary(&a_plus_b, Operator::Multiply, &lit_1),
            rhs: binary(&x_plus_y, Operator::Multiply, &lit_1),
            want: true,
            label: "Nested binary expressions with equivalent operands should be equal",
        },
    ];

    for case in &cases {
        let got = eq_group.exprs_equal(&case.lhs, &case.rhs);
        assert_eq!(
            got, case.want,
            "{}: Failed comparing {:?} and {:?}, expected {}, got {}",
            case.label, case.lhs, case.rhs, case.want, got
        );
    }

    Ok(())
}
780989
}

0 commit comments

Comments
 (0)