@@ -56,9 +56,13 @@ struct Identifier<'n> {
56
56
}
57
57
58
58
impl < ' n > Identifier < ' n > {
59
- fn new ( expr : & ' n Expr , random_state : & RandomState ) -> Self {
59
+ fn new ( expr : & ' n Expr , is_tree : bool , random_state : & RandomState ) -> Self {
60
60
let mut hasher = random_state. build_hasher ( ) ;
61
- expr. hash_node ( & mut hasher) ;
61
+ if is_tree {
62
+ expr. hash ( & mut hasher) ;
63
+ } else {
64
+ expr. hash_node ( & mut hasher) ;
65
+ }
62
66
let hash = hasher. finish ( ) ;
63
67
Self { hash, expr }
64
68
}
@@ -911,31 +915,30 @@ struct ExprIdentifierVisitor<'a, 'n> {
911
915
found_common : bool ,
912
916
}
913
917
914
- /// Record item that used when traversing a expression tree.
918
+ /// Record item that is used when traversing an expression tree.
915
919
enum VisitRecord < ' n > {
916
- /// `usize` postorder index assigned in `f-down`(). Starts from 0.
917
- EnterMark ( usize ) ,
918
- /// the node's children were skipped => jump to f_up on same node
919
- JumpMark ,
920
+ /// Contains the post-order index assigned during the first, visiting traversal and
921
+ /// a boolean flag to indicate if the record marks an expression subtree (not just a
922
+ /// single node).
923
+ EnterMark ( usize , bool ) ,
920
924
/// Accumulated identifier of sub expression.
921
925
ExprItem ( Identifier < ' n > ) ,
922
926
}
923
927
924
928
impl < ' n > ExprIdentifierVisitor < ' _ , ' n > {
925
929
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
926
930
/// before it.
927
- fn pop_enter_mark ( & mut self ) -> Option < ( usize , Option < Identifier < ' n > > ) > {
931
+ fn pop_enter_mark ( & mut self ) -> ( usize , bool , Option < Identifier < ' n > > ) {
928
932
let mut expr_id = None ;
929
933
930
934
while let Some ( item) = self . visit_stack . pop ( ) {
931
935
match item {
932
- VisitRecord :: EnterMark ( idx ) => {
933
- return Some ( ( idx , expr_id) ) ;
936
+ VisitRecord :: EnterMark ( down_index , tree ) => {
937
+ return ( down_index , tree , expr_id) ;
934
938
}
935
939
VisitRecord :: ExprItem ( id) => {
936
940
expr_id = Some ( id. combine ( expr_id) ) ;
937
941
}
938
- VisitRecord :: JumpMark => return None ,
939
942
}
940
943
}
941
944
unreachable ! ( "Enter mark should paired with node number" ) ;
@@ -947,30 +950,32 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
947
950
948
951
fn f_down ( & mut self , expr : & ' n Expr ) -> Result < TreeNodeRecursion > {
949
952
// TODO: consider non-volatile sub-expressions for CSE
950
- // TODO: consider surely executed children of "short circuited"s for CSE
951
-
952
- // If an expression can short circuit its children then don't consider it for CSE
953
- // (https://github.com/apache/arrow-datafusion/issues/8814).
954
- if expr. short_circuits ( ) {
955
- self . visit_stack . push ( VisitRecord :: JumpMark ) ;
956
953
957
- return Ok ( TreeNodeRecursion :: Jump ) ;
958
- }
954
+ // If an expression can short circuit its children then don't consider its
955
+ // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
956
+ // This means that we don't recurse into its children, but handle the expression
957
+ // as a subtree when we calculate its identifier.
958
+ // TODO: consider the surely executed children of short-circuiting expressions for CSE
959
+ let is_tree = expr. short_circuits ( ) ;
960
+ let tnr = if is_tree {
961
+ TreeNodeRecursion :: Jump
962
+ } else {
963
+ TreeNodeRecursion :: Continue
964
+ } ;
959
965
960
966
self . id_array . push ( ( 0 , None ) ) ;
961
967
self . visit_stack
962
- . push ( VisitRecord :: EnterMark ( self . down_index ) ) ;
968
+ . push ( VisitRecord :: EnterMark ( self . down_index , is_tree ) ) ;
963
969
self . down_index += 1 ;
964
970
965
- Ok ( TreeNodeRecursion :: Continue )
971
+ Ok ( tnr )
966
972
}
967
973
968
974
fn f_up ( & mut self , expr : & ' n Expr ) -> Result < TreeNodeRecursion > {
969
- let Some ( ( down_index, sub_expr_id) ) = self . pop_enter_mark ( ) else {
970
- return Ok ( TreeNodeRecursion :: Continue ) ;
971
- } ;
975
+ let ( down_index, is_tree, sub_expr_id) = self . pop_enter_mark ( ) ;
972
976
973
- let expr_id = Identifier :: new ( expr, self . random_state ) . combine ( sub_expr_id) ;
977
+ let expr_id =
978
+ Identifier :: new ( expr, is_tree, self . random_state ) . combine ( sub_expr_id) ;
974
979
975
980
self . id_array [ down_index] . 0 = self . up_index ;
976
981
if !self . expr_mask . ignores ( expr) {
@@ -1015,19 +1020,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
1015
1020
self . alias_counter += 1 ;
1016
1021
}
1017
1022
1018
- // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
1019
- // the `id_array`, which records the expr's identifier used to rewrite expr. So if we
1023
+ // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the
1024
+ // `id_array`, which records the expr's identifier used to rewrite expr. So if we
1020
1025
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
1021
- if expr. short_circuits ( ) {
1022
- return Ok ( Transformed :: new ( expr, false , TreeNodeRecursion :: Jump ) ) ;
1023
- }
1026
+ let is_tree = expr. short_circuits ( ) ;
1027
+ let tnr = if is_tree {
1028
+ TreeNodeRecursion :: Jump
1029
+ } else {
1030
+ TreeNodeRecursion :: Continue
1031
+ } ;
1024
1032
1025
1033
let ( up_index, expr_id) = self . id_array [ self . down_index ] ;
1026
1034
self . down_index += 1 ;
1027
1035
1028
1036
// skip `Expr`s without identifier (empty identifier).
1029
1037
let Some ( expr_id) = expr_id else {
1030
- return Ok ( Transformed :: no ( expr) ) ;
1038
+ return Ok ( Transformed :: new ( expr, false , tnr ) ) ;
1031
1039
} ;
1032
1040
1033
1041
let count = self . expr_stats . get ( & expr_id) . unwrap ( ) ;
@@ -1055,7 +1063,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
1055
1063
1056
1064
Ok ( Transformed :: new ( rewritten, true , TreeNodeRecursion :: Jump ) )
1057
1065
} else {
1058
- Ok ( Transformed :: no ( expr) )
1066
+ Ok ( Transformed :: new ( expr, false , tnr ) )
1059
1067
}
1060
1068
}
1061
1069
@@ -1802,4 +1810,32 @@ mod test {
1802
1810
assert ! ( result. len( ) == 1 ) ;
1803
1811
Ok ( ( ) )
1804
1812
}
1813
+
1814
+ #[ test]
1815
+ fn test_short_circuits ( ) -> Result < ( ) > {
1816
+ let table_scan = test_table_scan ( ) ?;
1817
+
1818
+ let extracted_short_circuit = col ( "a" ) . eq ( lit ( 0 ) ) . or ( col ( "b" ) . eq ( lit ( 0 ) ) ) ;
1819
+ let not_extracted_short_circuit_leg_1 = ( col ( "a" ) + col ( "b" ) ) . eq ( lit ( 0 ) ) ;
1820
+ let not_extracted_short_circuit_leg_2 = ( col ( "a" ) - col ( "b" ) ) . eq ( lit ( 0 ) ) ;
1821
+ let plan = LogicalPlanBuilder :: from ( table_scan. clone ( ) )
1822
+ . project ( vec ! [
1823
+ extracted_short_circuit. clone( ) . alias( "c1" ) ,
1824
+ extracted_short_circuit. alias( "c2" ) ,
1825
+ not_extracted_short_circuit_leg_1. clone( ) . alias( "c3" ) ,
1826
+ not_extracted_short_circuit_leg_2. clone( ) . alias( "c4" ) ,
1827
+ not_extracted_short_circuit_leg_1
1828
+ . or( not_extracted_short_circuit_leg_2)
1829
+ . alias( "c5" ) ,
1830
+ ] ) ?
1831
+ . build ( ) ?;
1832
+
1833
+ let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\
1834
+ \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
1835
+ \n TableScan: test";
1836
+
1837
+ assert_optimized_plan_eq ( expected, plan, None ) ;
1838
+
1839
+ Ok ( ( ) )
1840
+ }
1805
1841
}
0 commit comments