17
17
use std:: collections:: { HashMap , HashSet } ;
18
18
use std:: sync:: Arc ;
19
19
20
- use crate :: optimizer:: ApplyOrder ;
21
- use crate :: { OptimizerConfig , OptimizerRule } ;
20
+ use itertools:: Itertools ;
22
21
23
22
use datafusion_common:: tree_node:: {
24
23
Transformed , TransformedResult , TreeNode , TreeNodeRecursion ,
@@ -29,6 +28,7 @@ use datafusion_common::{
29
28
} ;
30
29
use datafusion_expr:: expr:: Alias ;
31
30
use datafusion_expr:: expr_rewriter:: replace_col;
31
+ use datafusion_expr:: logical_plan:: tree_node:: unwrap_arc;
32
32
use datafusion_expr:: logical_plan:: {
33
33
CrossJoin , Join , JoinType , LogicalPlan , TableScan , Union ,
34
34
} ;
@@ -38,7 +38,8 @@ use datafusion_expr::{
38
38
ScalarFunctionDefinition , TableProviderFilterPushDown ,
39
39
} ;
40
40
41
- use itertools:: Itertools ;
41
+ use crate :: optimizer:: ApplyOrder ;
42
+ use crate :: { OptimizerConfig , OptimizerRule } ;
42
43
43
44
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
44
45
/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
407
408
right : & LogicalPlan ,
408
409
on_filter : Vec < Expr > ,
409
410
is_inner_join : bool ,
410
- ) -> Result < LogicalPlan > {
411
+ ) -> Result < Transformed < LogicalPlan > > {
411
412
let on_filter_empty = on_filter. is_empty ( ) ;
412
413
// Get pushable predicates from current optimizer state
413
414
let ( left_preserved, right_preserved) = lr_is_preserved ( join_plan) ?;
@@ -502,44 +503,45 @@ fn push_down_all_join(
502
503
exprs. extend ( join_conditions. into_iter ( ) . reduce ( Expr :: and) ) ;
503
504
let plan = join_plan. with_new_exprs ( exprs, vec ! [ left, right] ) ?;
504
505
505
- // wrap the join on the filter whose predicates must be kept
506
506
match conjunction ( keep_predicates) {
507
507
Some ( predicate) => {
508
- Filter :: try_new ( predicate, Arc :: new ( plan) ) . map ( LogicalPlan :: Filter )
508
+ let new_filter_plan = Filter :: try_new ( predicate, Arc :: new ( plan) ) ?;
509
+ Ok ( Transformed :: yes ( LogicalPlan :: Filter ( new_filter_plan) ) )
509
510
}
510
- None => Ok ( plan) ,
511
+ None => Ok ( Transformed :: no ( plan) ) ,
511
512
}
512
513
}
513
514
514
515
fn push_down_join (
515
- plan : & LogicalPlan ,
516
+ plan : LogicalPlan ,
516
517
join : & Join ,
517
518
parent_predicate : Option < & Expr > ,
518
- ) -> Result < Option < LogicalPlan > > {
519
- let predicates = match parent_predicate {
520
- Some ( parent_predicate) => split_conjunction_owned ( parent_predicate. clone ( ) ) ,
521
- None => vec ! [ ] ,
522
- } ;
519
+ ) -> Result < Transformed < LogicalPlan > > {
520
+ // Split the parent predicate into individual conjunctive parts.
521
+ let predicates = parent_predicate
522
+ . map_or_else ( Vec :: new, |pred| split_conjunction_owned ( pred. clone ( ) ) ) ;
523
523
524
- // Convert JOIN ON predicate to Predicates
524
+ // Extract conjunctions from the JOIN's ON filter, if present.
525
525
let on_filters = join
526
526
. filter
527
527
. as_ref ( )
528
- . map ( |e| split_conjunction_owned ( e. clone ( ) ) )
529
- . unwrap_or_default ( ) ;
528
+ . map_or_else ( Vec :: new, |filter| split_conjunction_owned ( filter. clone ( ) ) ) ;
530
529
531
530
let mut is_inner_join = false ;
532
531
let infer_predicates = if join. join_type == JoinType :: Inner {
533
532
is_inner_join = true ;
533
+
534
534
// Only keep join key pairs where both sides are plain columns.
535
535
let join_col_keys = join
536
536
. on
537
537
. iter ( )
538
- . flat_map ( |( l, r) | match ( l. try_into_col ( ) , r. try_into_col ( ) ) {
539
- ( Ok ( l_col) , Ok ( r_col) ) => Some ( ( l_col, r_col) ) ,
540
- _ => None ,
538
+ . filter_map ( |( l, r) | {
539
+ let left_col = l. try_into_col ( ) . ok ( ) ?;
540
+ let right_col = r. try_into_col ( ) . ok ( ) ?;
541
+ Some ( ( left_col, right_col) )
541
542
} )
542
543
. collect :: < Vec < _ > > ( ) ;
544
+
543
545
// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
544
546
// For inner joins, duplicate filters for joined columns so filters can be pushed down
545
547
// to both sides. Take the following query as an example:
@@ -559,6 +561,7 @@ fn push_down_join(
559
561
. chain ( on_filters. iter ( ) )
560
562
. filter_map ( |predicate| {
561
563
let mut join_cols_to_replace = HashMap :: new ( ) ;
564
+
562
565
let columns = match predicate. to_columns ( ) {
563
566
Ok ( columns) => columns,
564
567
Err ( e) => return Some ( Err ( e) ) ,
@@ -596,20 +599,32 @@ fn push_down_join(
596
599
} ;
597
600
598
601
if on_filters. is_empty ( ) && predicates. is_empty ( ) && infer_predicates. is_empty ( ) {
599
- return Ok ( None ) ;
602
+ return Ok ( Transformed :: no ( plan . clone ( ) ) ) ;
600
603
}
601
- Ok ( Some ( push_down_all_join (
604
+
605
+ match push_down_all_join (
602
606
predicates,
603
607
infer_predicates,
604
- plan,
608
+ & plan,
605
609
& join. left ,
606
610
& join. right ,
607
611
on_filters,
608
612
is_inner_join,
609
- ) ?) )
613
+ ) {
614
+ Ok ( plan) => Ok ( Transformed :: yes ( plan. data ) ) ,
615
+ Err ( e) => Err ( e) ,
616
+ }
610
617
}
611
618
612
619
impl OptimizerRule for PushDownFilter {
620
+ fn try_optimize (
621
+ & self ,
622
+ _plan : & LogicalPlan ,
623
+ _config : & dyn OptimizerConfig ,
624
+ ) -> Result < Option < LogicalPlan > > {
625
+ internal_err ! ( "Should have called PushDownFilter::rewrite" )
626
+ }
627
+
613
628
fn name ( & self ) -> & str {
614
629
"push_down_filter"
615
630
}
@@ -618,21 +633,26 @@ impl OptimizerRule for PushDownFilter {
618
633
Some ( ApplyOrder :: TopDown )
619
634
}
620
635
621
- fn try_optimize (
636
+ fn supports_rewrite ( & self ) -> bool {
637
+ true
638
+ }
639
+
640
+ fn rewrite (
622
641
& self ,
623
- plan : & LogicalPlan ,
642
+ plan : LogicalPlan ,
624
643
_config : & dyn OptimizerConfig ,
625
- ) -> Result < Option < LogicalPlan > > {
644
+ ) -> Result < Transformed < LogicalPlan > > {
626
645
let filter = match plan {
627
- LogicalPlan :: Filter ( filter) => filter,
628
- // we also need to pushdown filter in Join.
629
- LogicalPlan :: Join ( join) => return push_down_join ( plan, join, None ) ,
630
- _ => return Ok ( None ) ,
646
+ LogicalPlan :: Filter ( ref filter) => filter,
647
+ LogicalPlan :: Join ( ref join) => {
648
+ return push_down_join ( plan. clone ( ) , join, None )
649
+ }
650
+ _ => return Ok ( Transformed :: no ( plan) ) ,
631
651
} ;
632
652
633
- let child_plan = filter. input . as_ref ( ) ;
653
+ let child_plan = unwrap_arc ( filter. clone ( ) . input ) ;
634
654
let new_plan = match child_plan {
635
- LogicalPlan :: Filter ( child_filter) => {
655
+ LogicalPlan :: Filter ( ref child_filter) => {
636
656
let parents_predicates = split_conjunction ( & filter. predicate ) ;
637
657
let set: HashSet < & & Expr > = parents_predicates. iter ( ) . collect ( ) ;
638
658
@@ -652,20 +672,18 @@ impl OptimizerRule for PushDownFilter {
652
672
new_predicate,
653
673
child_filter. input . clone ( ) ,
654
674
) ?) ;
655
- self . try_optimize ( & new_filter, _config) ?
656
- . unwrap_or ( new_filter)
675
+ self . rewrite ( new_filter, _config) ?. data
657
676
}
658
677
LogicalPlan :: Repartition ( _)
659
678
| LogicalPlan :: Distinct ( _)
660
679
| LogicalPlan :: Sort ( _) => {
661
- // commutable
662
680
let new_filter = plan. with_new_exprs (
663
681
plan. expressions ( ) ,
664
682
vec ! [ child_plan. inputs( ) [ 0 ] . clone( ) ] ,
665
683
) ?;
666
684
child_plan. with_new_exprs ( child_plan. expressions ( ) , vec ! [ new_filter] ) ?
667
685
}
668
- LogicalPlan :: SubqueryAlias ( subquery_alias) => {
686
+ LogicalPlan :: SubqueryAlias ( ref subquery_alias) => {
669
687
let mut replace_map = HashMap :: new ( ) ;
670
688
for ( i, ( qualifier, field) ) in
671
689
subquery_alias. input . schema ( ) . iter ( ) . enumerate ( )
@@ -685,7 +703,7 @@ impl OptimizerRule for PushDownFilter {
685
703
) ?) ;
686
704
child_plan. with_new_exprs ( child_plan. expressions ( ) , vec ! [ new_filter] ) ?
687
705
}
688
- LogicalPlan :: Projection ( projection) => {
706
+ LogicalPlan :: Projection ( ref projection) => {
689
707
// A projection is filter-commutable if it does not contain volatile predicates or contains volatile
690
708
// predicates that are not used in the filter. However, we should rewrite all predicate expressions.
691
709
// collect projection.
@@ -742,10 +760,10 @@ impl OptimizerRule for PushDownFilter {
742
760
}
743
761
}
744
762
}
745
- None => return Ok ( None ) ,
763
+ None => return Ok ( Transformed :: no ( plan ) ) ,
746
764
}
747
765
}
748
- LogicalPlan :: Union ( union) => {
766
+ LogicalPlan :: Union ( ref union) => {
749
767
let mut inputs = Vec :: with_capacity ( union. inputs . len ( ) ) ;
750
768
for input in & union. inputs {
751
769
let mut replace_map = HashMap :: new ( ) ;
@@ -770,7 +788,7 @@ impl OptimizerRule for PushDownFilter {
770
788
schema : plan. schema ( ) . clone ( ) ,
771
789
} )
772
790
}
773
- LogicalPlan :: Aggregate ( agg) => {
791
+ LogicalPlan :: Aggregate ( ref agg) => {
774
792
// We can push down predicates that are in groupby_expr.
775
793
let group_expr_columns = agg
776
794
. group_expr
@@ -821,13 +839,11 @@ impl OptimizerRule for PushDownFilter {
821
839
None => new_agg,
822
840
}
823
841
}
824
- LogicalPlan :: Join ( join) => {
825
- match push_down_join ( & filter. input , join, Some ( & filter. predicate ) ) ? {
826
- Some ( optimized_plan) => optimized_plan,
827
- None => return Ok ( None ) ,
828
- }
842
+ LogicalPlan :: Join ( ref join) => {
843
+ let unwrapped_plan = unwrap_arc ( filter. clone ( ) . input ) ;
844
+ push_down_join ( unwrapped_plan, join, Some ( & filter. predicate ) ) ?. data
829
845
}
830
- LogicalPlan :: CrossJoin ( cross_join) => {
846
+ LogicalPlan :: CrossJoin ( ref cross_join) => {
831
847
let predicates = split_conjunction_owned ( filter. predicate . clone ( ) ) ;
832
848
let join = convert_cross_join_to_inner_join ( cross_join. clone ( ) ) ?;
833
849
let join_plan = LogicalPlan :: Join ( join) ;
@@ -843,9 +859,9 @@ impl OptimizerRule for PushDownFilter {
843
859
vec ! [ ] ,
844
860
true ,
845
861
) ?;
846
- convert_to_cross_join_if_beneficial ( plan) ?
862
+ convert_to_cross_join_if_beneficial ( plan. data ) ?
847
863
}
848
- LogicalPlan :: TableScan ( scan) => {
864
+ LogicalPlan :: TableScan ( ref scan) => {
849
865
let filter_predicates = split_conjunction ( & filter. predicate ) ;
850
866
let results = scan
851
867
. source
@@ -892,7 +908,7 @@ impl OptimizerRule for PushDownFilter {
892
908
None => new_scan,
893
909
}
894
910
}
895
- LogicalPlan :: Extension ( extension_plan) => {
911
+ LogicalPlan :: Extension ( ref extension_plan) => {
896
912
let prevent_cols =
897
913
extension_plan. node . prevent_predicate_push_down_columns ( ) ;
898
914
@@ -935,9 +951,10 @@ impl OptimizerRule for PushDownFilter {
935
951
None => new_extension,
936
952
}
937
953
}
938
- _ => return Ok ( None ) ,
954
+ _ => return Ok ( Transformed :: no ( plan ) ) ,
939
955
} ;
940
- Ok ( Some ( new_plan) )
956
+
957
+ Ok ( Transformed :: yes ( new_plan) )
941
958
}
942
959
}
943
960
@@ -1024,16 +1041,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1024
1041
1025
1042
#[ cfg( test) ]
1026
1043
mod tests {
1027
- use super :: * ;
1028
1044
use std:: any:: Any ;
1029
1045
use std:: fmt:: { Debug , Formatter } ;
1030
1046
1031
- use crate :: optimizer:: Optimizer ;
1032
- use crate :: rewrite_disjunctive_predicate:: RewriteDisjunctivePredicate ;
1033
- use crate :: test:: * ;
1034
- use crate :: OptimizerContext ;
1035
-
1036
1047
use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
1048
+ use async_trait:: async_trait;
1049
+
1037
1050
use datafusion_common:: ScalarValue ;
1038
1051
use datafusion_expr:: expr:: ScalarFunction ;
1039
1052
use datafusion_expr:: logical_plan:: table_scan;
@@ -1043,7 +1056,13 @@ mod tests {
1043
1056
Volatility ,
1044
1057
} ;
1045
1058
1046
- use async_trait:: async_trait;
1059
+ use crate :: optimizer:: Optimizer ;
1060
+ use crate :: rewrite_disjunctive_predicate:: RewriteDisjunctivePredicate ;
1061
+ use crate :: test:: * ;
1062
+ use crate :: OptimizerContext ;
1063
+
1064
+ use super :: * ;
1065
+
1047
1066
fn observe ( _plan : & LogicalPlan , _rule : & dyn OptimizerRule ) { }
1048
1067
1049
1068
fn assert_optimized_plan_eq ( plan : LogicalPlan , expected : & str ) -> Result < ( ) > {
@@ -2298,9 +2317,9 @@ mod tests {
2298
2317
table_scan_with_pushdown_provider ( TableProviderFilterPushDown :: Inexact ) ?;
2299
2318
2300
2319
let optimized_plan = PushDownFilter :: new ( )
2301
- . try_optimize ( & plan, & OptimizerContext :: new ( ) )
2320
+ . rewrite ( plan, & OptimizerContext :: new ( ) )
2302
2321
. expect ( "failed to optimize plan" )
2303
- . unwrap ( ) ;
2322
+ . data ;
2304
2323
2305
2324
let expected = "\
2306
2325
Filter: a = Int64(1)\
@@ -2667,8 +2686,9 @@ Projection: a, b
2667
2686
// Originally, global state helped to avoid duplicate Filters being generated and pushed down.
2668
2687
// Now that the global state is removed, double-check that duplicate Filters are still avoided.
2669
2688
let optimized_plan = PushDownFilter :: new ( )
2670
- . try_optimize ( & plan, & OptimizerContext :: new ( ) ) ?
2671
- . expect ( "failed to optimize plan" ) ;
2689
+ . rewrite ( plan, & OptimizerContext :: new ( ) )
2690
+ . expect ( "failed to optimize plan" )
2691
+ . data ;
2672
2692
assert_optimized_plan_eq ( optimized_plan, expected)
2673
2693
}
2674
2694
0 commit comments