Skip to content

Commit 43e8e92

Browse files
committed
Optimized push down filter apache#10291
1 parent 3b245ff commit 43e8e92

File tree

1 file changed

+82
-62
lines changed

1 file changed

+82
-62
lines changed

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 82 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
use std::collections::{HashMap, HashSet};
1818
use std::sync::Arc;
1919

20-
use crate::optimizer::ApplyOrder;
21-
use crate::{OptimizerConfig, OptimizerRule};
20+
use itertools::Itertools;
2221

2322
use datafusion_common::tree_node::{
2423
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
@@ -29,6 +28,7 @@ use datafusion_common::{
2928
};
3029
use datafusion_expr::expr::Alias;
3130
use datafusion_expr::expr_rewriter::replace_col;
31+
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
3232
use datafusion_expr::logical_plan::{
3333
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
3434
};
@@ -38,7 +38,8 @@ use datafusion_expr::{
3838
ScalarFunctionDefinition, TableProviderFilterPushDown,
3939
};
4040

41-
use itertools::Itertools;
41+
use crate::optimizer::ApplyOrder;
42+
use crate::{OptimizerConfig, OptimizerRule};
4243

4344
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
4445
/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
407408
right: &LogicalPlan,
408409
on_filter: Vec<Expr>,
409410
is_inner_join: bool,
410-
) -> Result<LogicalPlan> {
411+
) -> Result<Transformed<LogicalPlan>> {
411412
let on_filter_empty = on_filter.is_empty();
412413
// Get pushable predicates from current optimizer state
413414
let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
@@ -502,44 +503,45 @@ fn push_down_all_join(
502503
exprs.extend(join_conditions.into_iter().reduce(Expr::and));
503504
let plan = join_plan.with_new_exprs(exprs, vec![left, right])?;
504505

505-
// wrap the join on the filter whose predicates must be kept
506506
match conjunction(keep_predicates) {
507507
Some(predicate) => {
508-
Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter)
508+
let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
509+
Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
509510
}
510-
None => Ok(plan),
511+
None => Ok(Transformed::no(plan)),
511512
}
512513
}
513514

514515
fn push_down_join(
515-
plan: &LogicalPlan,
516+
plan: LogicalPlan,
516517
join: &Join,
517518
parent_predicate: Option<&Expr>,
518-
) -> Result<Option<LogicalPlan>> {
519-
let predicates = match parent_predicate {
520-
Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()),
521-
None => vec![],
522-
};
519+
) -> Result<Transformed<LogicalPlan>> {
520+
// Split the parent predicate into individual conjunctive parts.
521+
let predicates = parent_predicate
522+
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
523523

524-
// Convert JOIN ON predicate to Predicates
524+
// Extract conjunctions from the JOIN's ON filter, if present.
525525
let on_filters = join
526526
.filter
527527
.as_ref()
528-
.map(|e| split_conjunction_owned(e.clone()))
529-
.unwrap_or_default();
528+
.map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
530529

531530
let mut is_inner_join = false;
532531
let infer_predicates = if join.join_type == JoinType::Inner {
533532
is_inner_join = true;
533+
534534
// Only use join keys where both sides are plain columns.
535535
let join_col_keys = join
536536
.on
537537
.iter()
538-
.flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
539-
(Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
540-
_ => None,
538+
.filter_map(|(l, r)| {
539+
let left_col = l.try_into_col().ok()?;
540+
let right_col = r.try_into_col().ok()?;
541+
Some((left_col, right_col))
541542
})
542543
.collect::<Vec<_>>();
544+
543545
// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
544546
// For inner joins, duplicate filters for joined columns so filters can be pushed down
545547
// to both sides. Take the following query as an example:
@@ -559,6 +561,7 @@ fn push_down_join(
559561
.chain(on_filters.iter())
560562
.filter_map(|predicate| {
561563
let mut join_cols_to_replace = HashMap::new();
564+
562565
let columns = match predicate.to_columns() {
563566
Ok(columns) => columns,
564567
Err(e) => return Some(Err(e)),
@@ -596,20 +599,32 @@ fn push_down_join(
596599
};
597600

598601
if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
599-
return Ok(None);
602+
return Ok(Transformed::no(plan.clone()));
600603
}
601-
Ok(Some(push_down_all_join(
604+
605+
match push_down_all_join(
602606
predicates,
603607
infer_predicates,
604-
plan,
608+
&plan,
605609
&join.left,
606610
&join.right,
607611
on_filters,
608612
is_inner_join,
609-
)?))
613+
) {
614+
Ok(plan) => Ok(Transformed::yes(plan.data)),
615+
Err(e) => Err(e),
616+
}
610617
}
611618

612619
impl OptimizerRule for PushDownFilter {
620+
fn try_optimize(
621+
&self,
622+
_plan: &LogicalPlan,
623+
_config: &dyn OptimizerConfig,
624+
) -> Result<Option<LogicalPlan>> {
625+
internal_err!("Should have called PushDownFilter::rewrite")
626+
}
627+
613628
fn name(&self) -> &str {
614629
"push_down_filter"
615630
}
@@ -618,21 +633,26 @@ impl OptimizerRule for PushDownFilter {
618633
Some(ApplyOrder::TopDown)
619634
}
620635

621-
fn try_optimize(
636+
fn supports_rewrite(&self) -> bool {
637+
true
638+
}
639+
640+
fn rewrite(
622641
&self,
623-
plan: &LogicalPlan,
642+
plan: LogicalPlan,
624643
_config: &dyn OptimizerConfig,
625-
) -> Result<Option<LogicalPlan>> {
644+
) -> Result<Transformed<LogicalPlan>> {
626645
let filter = match plan {
627-
LogicalPlan::Filter(filter) => filter,
628-
// we also need to pushdown filter in Join.
629-
LogicalPlan::Join(join) => return push_down_join(plan, join, None),
630-
_ => return Ok(None),
646+
LogicalPlan::Filter(ref filter) => filter,
647+
LogicalPlan::Join(ref join) => {
648+
return push_down_join(plan.clone(), join, None)
649+
}
650+
_ => return Ok(Transformed::no(plan)),
631651
};
632652

633-
let child_plan = filter.input.as_ref();
653+
let child_plan = unwrap_arc(filter.clone().input);
634654
let new_plan = match child_plan {
635-
LogicalPlan::Filter(child_filter) => {
655+
LogicalPlan::Filter(ref child_filter) => {
636656
let parents_predicates = split_conjunction(&filter.predicate);
637657
let set: HashSet<&&Expr> = parents_predicates.iter().collect();
638658

@@ -652,20 +672,18 @@ impl OptimizerRule for PushDownFilter {
652672
new_predicate,
653673
child_filter.input.clone(),
654674
)?);
655-
self.try_optimize(&new_filter, _config)?
656-
.unwrap_or(new_filter)
675+
self.rewrite(new_filter, _config)?.data
657676
}
658677
LogicalPlan::Repartition(_)
659678
| LogicalPlan::Distinct(_)
660679
| LogicalPlan::Sort(_) => {
661-
// commutable
662680
let new_filter = plan.with_new_exprs(
663681
plan.expressions(),
664682
vec![child_plan.inputs()[0].clone()],
665683
)?;
666684
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
667685
}
668-
LogicalPlan::SubqueryAlias(subquery_alias) => {
686+
LogicalPlan::SubqueryAlias(ref subquery_alias) => {
669687
let mut replace_map = HashMap::new();
670688
for (i, (qualifier, field)) in
671689
subquery_alias.input.schema().iter().enumerate()
@@ -685,7 +703,7 @@ impl OptimizerRule for PushDownFilter {
685703
)?);
686704
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
687705
}
688-
LogicalPlan::Projection(projection) => {
706+
LogicalPlan::Projection(ref projection) => {
689707
// A projection is filter-commutable if it does not contain volatile predicates or contains volatile
690708
// predicates that are not used in the filter. However, we should rewrite all predicate expressions.
691709
// collect projection.
@@ -742,10 +760,10 @@ impl OptimizerRule for PushDownFilter {
742760
}
743761
}
744762
}
745-
None => return Ok(None),
763+
None => return Ok(Transformed::no(plan)),
746764
}
747765
}
748-
LogicalPlan::Union(union) => {
766+
LogicalPlan::Union(ref union) => {
749767
let mut inputs = Vec::with_capacity(union.inputs.len());
750768
for input in &union.inputs {
751769
let mut replace_map = HashMap::new();
@@ -770,7 +788,7 @@ impl OptimizerRule for PushDownFilter {
770788
schema: plan.schema().clone(),
771789
})
772790
}
773-
LogicalPlan::Aggregate(agg) => {
791+
LogicalPlan::Aggregate(ref agg) => {
774792
// We can push down predicates which are in group_expr.
775793
let group_expr_columns = agg
776794
.group_expr
@@ -821,13 +839,11 @@ impl OptimizerRule for PushDownFilter {
821839
None => new_agg,
822840
}
823841
}
824-
LogicalPlan::Join(join) => {
825-
match push_down_join(&filter.input, join, Some(&filter.predicate))? {
826-
Some(optimized_plan) => optimized_plan,
827-
None => return Ok(None),
828-
}
842+
LogicalPlan::Join(ref join) => {
843+
let unwrapped_plan = unwrap_arc(filter.clone().input);
844+
push_down_join(unwrapped_plan, join, Some(&filter.predicate))?.data
829845
}
830-
LogicalPlan::CrossJoin(cross_join) => {
846+
LogicalPlan::CrossJoin(ref cross_join) => {
831847
let predicates = split_conjunction_owned(filter.predicate.clone());
832848
let join = convert_cross_join_to_inner_join(cross_join.clone())?;
833849
let join_plan = LogicalPlan::Join(join);
@@ -843,9 +859,9 @@ impl OptimizerRule for PushDownFilter {
843859
vec![],
844860
true,
845861
)?;
846-
convert_to_cross_join_if_beneficial(plan)?
862+
convert_to_cross_join_if_beneficial(plan.data)?
847863
}
848-
LogicalPlan::TableScan(scan) => {
864+
LogicalPlan::TableScan(ref scan) => {
849865
let filter_predicates = split_conjunction(&filter.predicate);
850866
let results = scan
851867
.source
@@ -892,7 +908,7 @@ impl OptimizerRule for PushDownFilter {
892908
None => new_scan,
893909
}
894910
}
895-
LogicalPlan::Extension(extension_plan) => {
911+
LogicalPlan::Extension(ref extension_plan) => {
896912
let prevent_cols =
897913
extension_plan.node.prevent_predicate_push_down_columns();
898914

@@ -935,9 +951,10 @@ impl OptimizerRule for PushDownFilter {
935951
None => new_extension,
936952
}
937953
}
938-
_ => return Ok(None),
954+
_ => return Ok(Transformed::no(plan)),
939955
};
940-
Ok(Some(new_plan))
956+
957+
Ok(Transformed::yes(new_plan))
941958
}
942959
}
943960

@@ -1024,16 +1041,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
10241041

10251042
#[cfg(test)]
10261043
mod tests {
1027-
use super::*;
10281044
use std::any::Any;
10291045
use std::fmt::{Debug, Formatter};
10301046

1031-
use crate::optimizer::Optimizer;
1032-
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
1033-
use crate::test::*;
1034-
use crate::OptimizerContext;
1035-
10361047
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1048+
use async_trait::async_trait;
1049+
10371050
use datafusion_common::ScalarValue;
10381051
use datafusion_expr::expr::ScalarFunction;
10391052
use datafusion_expr::logical_plan::table_scan;
@@ -1043,7 +1056,13 @@ mod tests {
10431056
Volatility,
10441057
};
10451058

1046-
use async_trait::async_trait;
1059+
use crate::optimizer::Optimizer;
1060+
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
1061+
use crate::test::*;
1062+
use crate::OptimizerContext;
1063+
1064+
use super::*;
1065+
10471066
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
10481067

10491068
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -2298,9 +2317,9 @@ mod tests {
22982317
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
22992318

23002319
let optimized_plan = PushDownFilter::new()
2301-
.try_optimize(&plan, &OptimizerContext::new())
2320+
.rewrite(plan, &OptimizerContext::new())
23022321
.expect("failed to optimize plan")
2303-
.unwrap();
2322+
.data;
23042323

23052324
let expected = "\
23062325
Filter: a = Int64(1)\
@@ -2667,8 +2686,9 @@ Projection: a, b
26672686
// Originally there was global state that helped to avoid duplicate Filters being generated and pushed down.
26682687
// Now the global state is removed, so we need to double-check that duplicate Filters are still avoided.
26692688
let optimized_plan = PushDownFilter::new()
2670-
.try_optimize(&plan, &OptimizerContext::new())?
2671-
.expect("failed to optimize plan");
2689+
.rewrite(plan, &OptimizerContext::new())
2690+
.expect("failed to optimize plan")
2691+
.data;
26722692
assert_optimized_plan_eq(optimized_plan, expected)
26732693
}
26742694

0 commit comments

Comments
 (0)