17
17
use std:: collections:: { HashMap , HashSet } ;
18
18
use std:: sync:: Arc ;
19
19
20
- use crate :: optimizer:: ApplyOrder ;
21
- use crate :: { OptimizerConfig , OptimizerRule } ;
20
+ use itertools:: Itertools ;
22
21
23
22
use datafusion_common:: tree_node:: {
24
23
Transformed , TransformedResult , TreeNode , TreeNodeRecursion ,
@@ -29,6 +28,7 @@ use datafusion_common::{
29
28
} ;
30
29
use datafusion_expr:: expr:: Alias ;
31
30
use datafusion_expr:: expr_rewriter:: replace_col;
31
+ use datafusion_expr:: logical_plan:: tree_node:: unwrap_arc;
32
32
use datafusion_expr:: logical_plan:: {
33
33
CrossJoin , Join , JoinType , LogicalPlan , TableScan , Union ,
34
34
} ;
@@ -38,7 +38,8 @@ use datafusion_expr::{
38
38
ScalarFunctionDefinition , TableProviderFilterPushDown ,
39
39
} ;
40
40
41
- use itertools:: Itertools ;
41
+ use crate :: optimizer:: ApplyOrder ;
42
+ use crate :: { OptimizerConfig , OptimizerRule } ;
42
43
43
44
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
44
45
/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
407
408
right : & LogicalPlan ,
408
409
on_filter : Vec < Expr > ,
409
410
is_inner_join : bool ,
410
- ) -> Result < LogicalPlan > {
411
+ ) -> Result < Transformed < LogicalPlan > > {
411
412
let on_filter_empty = on_filter. is_empty ( ) ;
412
413
// Get pushable predicates from current optimizer state
413
414
let ( left_preserved, right_preserved) = lr_is_preserved ( join_plan) ?;
@@ -505,41 +506,43 @@ fn push_down_all_join(
505
506
// wrap the join on the filter whose predicates must be kept
506
507
match conjunction ( keep_predicates) {
507
508
Some ( predicate) => {
508
- Filter :: try_new ( predicate, Arc :: new ( plan) ) . map ( LogicalPlan :: Filter )
509
+ let new_filter_plan = Filter :: try_new ( predicate, Arc :: new ( plan) ) ?;
510
+ Ok ( Transformed :: yes ( LogicalPlan :: Filter ( new_filter_plan) ) )
509
511
}
510
- None => Ok ( plan) ,
512
+ None => Ok ( Transformed :: no ( plan) ) ,
511
513
}
512
514
}
513
515
514
516
fn push_down_join (
515
517
plan : & LogicalPlan ,
516
518
join : & Join ,
517
519
parent_predicate : Option < & Expr > ,
518
- ) -> Result < Option < LogicalPlan > > {
519
- let predicates = match parent_predicate {
520
- Some ( parent_predicate) => split_conjunction_owned ( parent_predicate. clone ( ) ) ,
521
- None => vec ! [ ] ,
522
- } ;
520
+ ) -> Result < Transformed < LogicalPlan > > {
521
+ // Split the parent predicate into individual conjunctive parts.
522
+ let predicates = parent_predicate
523
+ . map_or_else ( Vec :: new, |pred| split_conjunction_owned ( pred. clone ( ) ) ) ;
523
524
524
- // Convert JOIN ON predicate to Predicates
525
+ // Extract conjunctions from the JOIN's ON filter, if present.
525
526
let on_filters = join
526
527
. filter
527
528
. as_ref ( )
528
- . map ( |e| split_conjunction_owned ( e. clone ( ) ) )
529
- . unwrap_or_default ( ) ;
529
+ . map_or_else ( Vec :: new, |filter| split_conjunction_owned ( filter. clone ( ) ) ) ;
530
530
531
531
let mut is_inner_join = false ;
532
532
let infer_predicates = if join. join_type == JoinType :: Inner {
533
533
is_inner_join = true ;
534
+
534
535
// Only allow both side key is column.
535
536
let join_col_keys = join
536
537
. on
537
538
. iter ( )
538
- . flat_map ( |( l, r) | match ( l. try_into_col ( ) , r. try_into_col ( ) ) {
539
- ( Ok ( l_col) , Ok ( r_col) ) => Some ( ( l_col, r_col) ) ,
540
- _ => None ,
539
+ . filter_map ( |( l, r) | {
540
+ let left_col = l. try_into_col ( ) . ok ( ) ?;
541
+ let right_col = r. try_into_col ( ) . ok ( ) ?;
542
+ Some ( ( left_col, right_col) )
541
543
} )
542
544
. collect :: < Vec < _ > > ( ) ;
545
+
543
546
// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
544
547
// For inner joins, duplicate filters for joined columns so filters can be pushed down
545
548
// to both sides. Take the following query as an example:
@@ -559,6 +562,7 @@ fn push_down_join(
559
562
. chain ( on_filters. iter ( ) )
560
563
. filter_map ( |predicate| {
561
564
let mut join_cols_to_replace = HashMap :: new ( ) ;
565
+
562
566
let columns = match predicate. to_columns ( ) {
563
567
Ok ( columns) => columns,
564
568
Err ( e) => return Some ( Err ( e) ) ,
@@ -596,20 +600,32 @@ fn push_down_join(
596
600
} ;
597
601
598
602
if on_filters. is_empty ( ) && predicates. is_empty ( ) && infer_predicates. is_empty ( ) {
599
- return Ok ( None ) ;
603
+ return Ok ( Transformed :: no ( plan . clone ( ) ) ) ;
600
604
}
601
- Ok ( Some ( push_down_all_join (
605
+
606
+ match push_down_all_join (
602
607
predicates,
603
608
infer_predicates,
604
609
plan,
605
610
& join. left ,
606
611
& join. right ,
607
612
on_filters,
608
613
is_inner_join,
609
- ) ?) )
614
+ ) {
615
+ Ok ( plan) => Ok ( Transformed :: yes ( plan. data ) ) ,
616
+ Err ( e) => Err ( e) ,
617
+ }
610
618
}
611
619
612
620
impl OptimizerRule for PushDownFilter {
621
+ fn try_optimize (
622
+ & self ,
623
+ _plan : & LogicalPlan ,
624
+ _config : & dyn OptimizerConfig ,
625
+ ) -> Result < Option < LogicalPlan > > {
626
+ internal_err ! ( "Should have called PushDownFilter::rewrite" )
627
+ }
628
+
613
629
fn name ( & self ) -> & str {
614
630
"push_down_filter"
615
631
}
@@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
618
634
Some ( ApplyOrder :: TopDown )
619
635
}
620
636
621
- fn try_optimize (
637
+ fn supports_rewrite ( & self ) -> bool {
638
+ true
639
+ }
640
+
641
+ fn rewrite (
622
642
& self ,
623
- plan : & LogicalPlan ,
643
+ plan : LogicalPlan ,
624
644
_config : & dyn OptimizerConfig ,
625
- ) -> Result < Option < LogicalPlan > > {
645
+ ) -> Result < Transformed < LogicalPlan > > {
626
646
let filter = match plan {
627
- LogicalPlan :: Filter ( filter) => filter,
628
- // we also need to pushdown filter in Join.
629
- LogicalPlan :: Join ( join) => return push_down_join ( plan, join, None ) ,
630
- _ => return Ok ( None ) ,
647
+ LogicalPlan :: Filter ( ref filter) => filter,
648
+ LogicalPlan :: Join ( ref join) => return push_down_join ( & plan, join, None ) ,
649
+ _ => return Ok ( Transformed :: no ( plan) ) ,
631
650
} ;
632
651
633
652
let child_plan = filter. input . as_ref ( ) ;
634
653
let new_plan = match child_plan {
635
- LogicalPlan :: Filter ( child_filter) => {
654
+ LogicalPlan :: Filter ( ref child_filter) => {
636
655
let parents_predicates = split_conjunction ( & filter. predicate ) ;
637
656
let set: HashSet < & & Expr > = parents_predicates. iter ( ) . collect ( ) ;
638
657
@@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
652
671
new_predicate,
653
672
child_filter. input . clone ( ) ,
654
673
) ?) ;
655
- self . try_optimize ( & new_filter, _config) ?
656
- . unwrap_or ( new_filter)
674
+ self . rewrite ( new_filter, _config) ?. data
657
675
}
658
676
LogicalPlan :: Repartition ( _)
659
677
| LogicalPlan :: Distinct ( _)
660
678
| LogicalPlan :: Sort ( _) => {
661
- // commutable
662
679
let new_filter = plan. with_new_exprs (
663
680
plan. expressions ( ) ,
664
681
vec ! [ child_plan. inputs( ) [ 0 ] . clone( ) ] ,
665
682
) ?;
666
683
child_plan. with_new_exprs ( child_plan. expressions ( ) , vec ! [ new_filter] ) ?
667
684
}
668
- LogicalPlan :: SubqueryAlias ( subquery_alias) => {
685
+ LogicalPlan :: SubqueryAlias ( ref subquery_alias) => {
669
686
let mut replace_map = HashMap :: new ( ) ;
670
687
for ( i, ( qualifier, field) ) in
671
688
subquery_alias. input . schema ( ) . iter ( ) . enumerate ( )
@@ -685,7 +702,7 @@ impl OptimizerRule for PushDownFilter {
685
702
) ?) ;
686
703
child_plan. with_new_exprs ( child_plan. expressions ( ) , vec ! [ new_filter] ) ?
687
704
}
688
- LogicalPlan :: Projection ( projection) => {
705
+ LogicalPlan :: Projection ( ref projection) => {
689
706
// A projection is filter-commutable if it do not contain volatile predicates or contain volatile
690
707
// predicates that are not used in the filter. However, we should re-writes all predicate expressions.
691
708
// collect projection.
@@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
742
759
}
743
760
}
744
761
}
745
- None => return Ok ( None ) ,
762
+ None => return Ok ( Transformed :: no ( plan ) ) ,
746
763
}
747
764
}
748
- LogicalPlan :: Union ( union) => {
765
+ LogicalPlan :: Union ( ref union) => {
749
766
let mut inputs = Vec :: with_capacity ( union. inputs . len ( ) ) ;
750
767
for input in & union. inputs {
751
768
let mut replace_map = HashMap :: new ( ) ;
@@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
770
787
schema : plan. schema ( ) . clone ( ) ,
771
788
} )
772
789
}
773
- LogicalPlan :: Aggregate ( agg) => {
790
+ LogicalPlan :: Aggregate ( ref agg) => {
774
791
// We can push down Predicate which in groupby_expr.
775
792
let group_expr_columns = agg
776
793
. group_expr
@@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
821
838
None => new_agg,
822
839
}
823
840
}
824
- LogicalPlan :: Join ( join) => {
825
- match push_down_join ( & filter. input , join, Some ( & filter. predicate ) ) ? {
826
- Some ( optimized_plan) => optimized_plan,
827
- None => return Ok ( None ) ,
828
- }
841
+ LogicalPlan :: Join ( ref join) => {
842
+ push_down_join (
843
+ & unwrap_arc ( filter. clone ( ) . input ) ,
844
+ join,
845
+ Some ( & filter. predicate ) ,
846
+ ) ?
847
+ . data
829
848
}
830
- LogicalPlan :: CrossJoin ( cross_join) => {
849
+ LogicalPlan :: CrossJoin ( ref cross_join) => {
831
850
let predicates = split_conjunction_owned ( filter. predicate . clone ( ) ) ;
832
851
let join = convert_cross_join_to_inner_join ( cross_join. clone ( ) ) ?;
833
852
let join_plan = LogicalPlan :: Join ( join) ;
@@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
843
862
vec ! [ ] ,
844
863
true ,
845
864
) ?;
846
- convert_to_cross_join_if_beneficial ( plan) ?
865
+ convert_to_cross_join_if_beneficial ( plan. data ) ?
847
866
}
848
- LogicalPlan :: TableScan ( scan) => {
867
+ LogicalPlan :: TableScan ( ref scan) => {
849
868
let filter_predicates = split_conjunction ( & filter. predicate ) ;
850
869
let results = scan
851
870
. source
@@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
892
911
None => new_scan,
893
912
}
894
913
}
895
- LogicalPlan :: Extension ( extension_plan) => {
914
+ LogicalPlan :: Extension ( ref extension_plan) => {
896
915
let prevent_cols =
897
916
extension_plan. node . prevent_predicate_push_down_columns ( ) ;
898
917
@@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
935
954
None => new_extension,
936
955
}
937
956
}
938
- _ => return Ok ( None ) ,
957
+ _ => return Ok ( Transformed :: no ( plan ) ) ,
939
958
} ;
940
- Ok ( Some ( new_plan) )
959
+
960
+ Ok ( Transformed :: yes ( new_plan) )
941
961
}
942
962
}
943
963
@@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1024
1044
1025
1045
#[ cfg( test) ]
1026
1046
mod tests {
1027
- use super :: * ;
1028
1047
use std:: any:: Any ;
1029
1048
use std:: fmt:: { Debug , Formatter } ;
1030
1049
1031
- use crate :: optimizer:: Optimizer ;
1032
- use crate :: rewrite_disjunctive_predicate:: RewriteDisjunctivePredicate ;
1033
- use crate :: test:: * ;
1034
- use crate :: OptimizerContext ;
1035
-
1036
1050
use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
1051
+ use async_trait:: async_trait;
1052
+
1037
1053
use datafusion_common:: ScalarValue ;
1038
1054
use datafusion_expr:: expr:: ScalarFunction ;
1039
1055
use datafusion_expr:: logical_plan:: table_scan;
@@ -1043,7 +1059,13 @@ mod tests {
1043
1059
Volatility ,
1044
1060
} ;
1045
1061
1046
- use async_trait:: async_trait;
1062
+ use crate :: optimizer:: Optimizer ;
1063
+ use crate :: rewrite_disjunctive_predicate:: RewriteDisjunctivePredicate ;
1064
+ use crate :: test:: * ;
1065
+ use crate :: OptimizerContext ;
1066
+
1067
+ use super :: * ;
1068
+
1047
1069
fn observe ( _plan : & LogicalPlan , _rule : & dyn OptimizerRule ) { }
1048
1070
1049
1071
fn assert_optimized_plan_eq ( plan : LogicalPlan , expected : & str ) -> Result < ( ) > {
@@ -2298,9 +2320,9 @@ mod tests {
2298
2320
table_scan_with_pushdown_provider ( TableProviderFilterPushDown :: Inexact ) ?;
2299
2321
2300
2322
let optimized_plan = PushDownFilter :: new ( )
2301
- . try_optimize ( & plan, & OptimizerContext :: new ( ) )
2323
+ . rewrite ( plan, & OptimizerContext :: new ( ) )
2302
2324
. expect ( "failed to optimize plan" )
2303
- . unwrap ( ) ;
2325
+ . data ;
2304
2326
2305
2327
let expected = "\
2306
2328
Filter: a = Int64(1)\
@@ -2667,8 +2689,9 @@ Projection: a, b
2667
2689
// Originally global state which can help to avoid duplicate Filters been generated and pushed down.
2668
2690
// Now the global state is removed. Need to double confirm that avoid duplicate Filters.
2669
2691
let optimized_plan = PushDownFilter :: new ( )
2670
- . try_optimize ( & plan, & OptimizerContext :: new ( ) ) ?
2671
- . expect ( "failed to optimize plan" ) ;
2692
+ . rewrite ( plan, & OptimizerContext :: new ( ) )
2693
+ . expect ( "failed to optimize plan" )
2694
+ . data ;
2672
2695
assert_optimized_plan_eq ( optimized_plan, expected)
2673
2696
}
2674
2697
0 commit comments