@@ -626,6 +626,59 @@ impl EquivalenceGroup {
626
626
JoinType :: RightSemi | JoinType :: RightAnti => right_equivalences. clone ( ) ,
627
627
}
628
628
}
629
+
630
+ /// Checks if two expressions are equal either directly or through equivalence classes.
631
+ /// For complex expressions (e.g. a + b), checks that the expression trees are structurally
632
+ /// identical and their leaf nodes are equivalent either directly or through equivalence classes.
633
+ pub fn exprs_equal (
634
+ & self ,
635
+ left : & Arc < dyn PhysicalExpr > ,
636
+ right : & Arc < dyn PhysicalExpr > ,
637
+ ) -> bool {
638
+ // Direct equality check
639
+ if left. eq ( right) {
640
+ return true ;
641
+ }
642
+
643
+ // Check if expressions are equivalent through equivalence classes
644
+ // We need to check both directions since expressions might be in different classes
645
+ if let Some ( left_class) = self . get_equivalence_class ( left) {
646
+ if left_class. contains ( right) {
647
+ return true ;
648
+ }
649
+ }
650
+ if let Some ( right_class) = self . get_equivalence_class ( right) {
651
+ if right_class. contains ( left) {
652
+ return true ;
653
+ }
654
+ }
655
+
656
+ // For non-leaf nodes, check structural equality
657
+ let left_children = left. children ( ) ;
658
+ let right_children = right. children ( ) ;
659
+
660
+ // If either expression is a leaf node and we haven't found equality yet,
661
+ // they must be different
662
+ if left_children. is_empty ( ) || right_children. is_empty ( ) {
663
+ return false ;
664
+ }
665
+
666
+ // Type equality check through reflection
667
+ if left. as_any ( ) . type_id ( ) != right. as_any ( ) . type_id ( ) {
668
+ return false ;
669
+ }
670
+
671
+ // Check if the number of children is the same
672
+ if left_children. len ( ) != right_children. len ( ) {
673
+ return false ;
674
+ }
675
+
676
+ // Check if all children are equal
677
+ left_children
678
+ . into_iter ( )
679
+ . zip ( right_children)
680
+ . all ( |( left_child, right_child) | self . exprs_equal ( left_child, right_child) )
681
+ }
629
682
}
630
683
631
684
impl Display for EquivalenceGroup {
@@ -647,9 +700,10 @@ mod tests {
647
700
648
701
use super :: * ;
649
702
use crate :: equivalence:: tests:: create_test_params;
650
- use crate :: expressions:: { lit, Literal } ;
703
+ use crate :: expressions:: { lit, BinaryExpr , Literal } ;
651
704
652
705
use datafusion_common:: { Result , ScalarValue } ;
706
+ use datafusion_expr:: Operator ;
653
707
654
708
#[ test]
655
709
fn test_bridge_groups ( ) -> Result < ( ) > {
@@ -777,4 +831,159 @@ mod tests {
777
831
assert ! ( !cls1. contains_any( & cls3) ) ;
778
832
assert ! ( !cls2. contains_any( & cls3) ) ;
779
833
}
834
+
835
#[test]
fn test_exprs_equal() -> Result<()> {
    /// A single comparison scenario: the two operands, the verdict that
    /// `exprs_equal` is expected to return, and a label for failure output.
    struct TestCase {
        left: Arc<dyn PhysicalExpr>,
        right: Arc<dyn PhysicalExpr>,
        expected: bool,
        description: &'static str,
    }

    /// Builds a column reference expression.
    fn column(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    /// Builds an `Int32` literal expression.
    fn int_literal(value: i32) -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Int32(Some(value))))
    }

    /// Builds the binary expression `lhs <op> rhs`.
    fn binary(
        lhs: &Arc<dyn PhysicalExpr>,
        op: Operator,
        rhs: &Arc<dyn PhysicalExpr>,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(Arc::clone(lhs), op, Arc::clone(rhs)))
    }

    // Leaf expressions shared by the scenarios below.
    let col_a = column("a", 0);
    let col_b = column("b", 1);
    let col_x = column("x", 2);
    let col_y = column("y", 3);
    let lit_1 = int_literal(1);
    let lit_2 = int_literal(2);

    // Equivalence group under test, with classes (a = x) and (b = y).
    let class_a_x = EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]);
    let class_b_y = EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]);
    let eq_group = EquivalenceGroup::new(vec![class_a_x, class_b_y]);

    let test_cases = vec![
        // Basic equality
        TestCase {
            left: Arc::clone(&col_a),
            right: Arc::clone(&col_a),
            expected: true,
            description: "Same column should be equal",
        },
        // Equivalence classes
        TestCase {
            left: Arc::clone(&col_a),
            right: Arc::clone(&col_x),
            expected: true,
            description: "Columns in same equivalence class should be equal",
        },
        TestCase {
            left: Arc::clone(&col_b),
            right: Arc::clone(&col_y),
            expected: true,
            description: "Columns in same equivalence class should be equal",
        },
        TestCase {
            left: Arc::clone(&col_a),
            right: Arc::clone(&col_b),
            expected: false,
            description: "Columns in different equivalence classes should not be equal",
        },
        // Literals
        TestCase {
            left: Arc::clone(&lit_1),
            right: Arc::clone(&lit_1),
            expected: true,
            description: "Same literal should be equal",
        },
        TestCase {
            left: Arc::clone(&lit_1),
            right: Arc::clone(&lit_2),
            expected: false,
            description: "Different literals should not be equal",
        },
        // Composite expressions
        TestCase {
            left: binary(&col_a, Operator::Plus, &col_b),
            right: binary(&col_x, Operator::Plus, &col_y),
            expected: true,
            description: "Binary expressions with equivalent operands should be equal",
        },
        TestCase {
            left: binary(&col_a, Operator::Plus, &col_b),
            right: binary(&col_x, Operator::Plus, &col_a),
            expected: false,
            description: "Binary expressions with non-equivalent operands should not be equal",
        },
        TestCase {
            left: binary(&col_a, Operator::Plus, &lit_1),
            right: binary(&col_x, Operator::Plus, &lit_1),
            expected: true,
            description: "Binary expressions with equivalent column and same literal should be equal",
        },
        TestCase {
            left: binary(
                &binary(&col_a, Operator::Plus, &col_b),
                Operator::Multiply,
                &lit_1,
            ),
            right: binary(
                &binary(&col_x, Operator::Plus, &col_y),
                Operator::Multiply,
                &lit_1,
            ),
            expected: true,
            description: "Nested binary expressions with equivalent operands should be equal",
        },
    ];

    // Run every scenario and report the offending pair on mismatch.
    for case in &test_cases {
        let actual = eq_group.exprs_equal(&case.left, &case.right);
        assert_eq!(
            actual, case.expected,
            "{}: Failed comparing {:?} and {:?}, expected {}, got {}",
            case.description, case.left, case.right, case.expected, actual
        );
    }

    Ok(())
}
780
989
}
0 commit comments