Skip to content

Commit 8190cb9

Browse files
authored
Optimized push down filter #10291 (#10366)
1 parent 2c56a3c commit 8190cb9

File tree

1 file changed

+81
-58
lines changed

1 file changed

+81
-58
lines changed

datafusion/optimizer/src/push_down_filter.rs

+81-58
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
use std::collections::{HashMap, HashSet};
1818
use std::sync::Arc;
1919

20-
use crate::optimizer::ApplyOrder;
21-
use crate::{OptimizerConfig, OptimizerRule};
20+
use itertools::Itertools;
2221

2322
use datafusion_common::tree_node::{
2423
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
@@ -29,6 +28,7 @@ use datafusion_common::{
2928
};
3029
use datafusion_expr::expr::Alias;
3130
use datafusion_expr::expr_rewriter::replace_col;
31+
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
3232
use datafusion_expr::logical_plan::{
3333
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
3434
};
@@ -38,7 +38,8 @@ use datafusion_expr::{
3838
ScalarFunctionDefinition, TableProviderFilterPushDown,
3939
};
4040

41-
use itertools::Itertools;
41+
use crate::optimizer::ApplyOrder;
42+
use crate::{OptimizerConfig, OptimizerRule};
4243

4344
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
4445
/// they are applied as early as possible.
@@ -407,7 +408,7 @@ fn push_down_all_join(
407408
right: &LogicalPlan,
408409
on_filter: Vec<Expr>,
409410
is_inner_join: bool,
410-
) -> Result<LogicalPlan> {
411+
) -> Result<Transformed<LogicalPlan>> {
411412
let on_filter_empty = on_filter.is_empty();
412413
// Get pushable predicates from current optimizer state
413414
let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
@@ -505,41 +506,43 @@ fn push_down_all_join(
505506
// wrap the join on the filter whose predicates must be kept
506507
match conjunction(keep_predicates) {
507508
Some(predicate) => {
508-
Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter)
509+
let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
510+
Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
509511
}
510-
None => Ok(plan),
512+
None => Ok(Transformed::no(plan)),
511513
}
512514
}
513515

514516
fn push_down_join(
515517
plan: &LogicalPlan,
516518
join: &Join,
517519
parent_predicate: Option<&Expr>,
518-
) -> Result<Option<LogicalPlan>> {
519-
let predicates = match parent_predicate {
520-
Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()),
521-
None => vec![],
522-
};
520+
) -> Result<Transformed<LogicalPlan>> {
521+
// Split the parent predicate into individual conjunctive parts.
522+
let predicates = parent_predicate
523+
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
523524

524-
// Convert JOIN ON predicate to Predicates
525+
// Extract conjunctions from the JOIN's ON filter, if present.
525526
let on_filters = join
526527
.filter
527528
.as_ref()
528-
.map(|e| split_conjunction_owned(e.clone()))
529-
.unwrap_or_default();
529+
.map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
530530

531531
let mut is_inner_join = false;
532532
let infer_predicates = if join.join_type == JoinType::Inner {
533533
is_inner_join = true;
534+
534535
// Only allow both side key is column.
535536
let join_col_keys = join
536537
.on
537538
.iter()
538-
.flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
539-
(Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
540-
_ => None,
539+
.filter_map(|(l, r)| {
540+
let left_col = l.try_into_col().ok()?;
541+
let right_col = r.try_into_col().ok()?;
542+
Some((left_col, right_col))
541543
})
542544
.collect::<Vec<_>>();
545+
543546
// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
544547
// For inner joins, duplicate filters for joined columns so filters can be pushed down
545548
// to both sides. Take the following query as an example:
@@ -559,6 +562,7 @@ fn push_down_join(
559562
.chain(on_filters.iter())
560563
.filter_map(|predicate| {
561564
let mut join_cols_to_replace = HashMap::new();
565+
562566
let columns = match predicate.to_columns() {
563567
Ok(columns) => columns,
564568
Err(e) => return Some(Err(e)),
@@ -596,20 +600,32 @@ fn push_down_join(
596600
};
597601

598602
if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
599-
return Ok(None);
603+
return Ok(Transformed::no(plan.clone()));
600604
}
601-
Ok(Some(push_down_all_join(
605+
606+
match push_down_all_join(
602607
predicates,
603608
infer_predicates,
604609
plan,
605610
&join.left,
606611
&join.right,
607612
on_filters,
608613
is_inner_join,
609-
)?))
614+
) {
615+
Ok(plan) => Ok(Transformed::yes(plan.data)),
616+
Err(e) => Err(e),
617+
}
610618
}
611619

612620
impl OptimizerRule for PushDownFilter {
621+
fn try_optimize(
622+
&self,
623+
_plan: &LogicalPlan,
624+
_config: &dyn OptimizerConfig,
625+
) -> Result<Option<LogicalPlan>> {
626+
internal_err!("Should have called PushDownFilter::rewrite")
627+
}
628+
613629
fn name(&self) -> &str {
614630
"push_down_filter"
615631
}
@@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
618634
Some(ApplyOrder::TopDown)
619635
}
620636

621-
fn try_optimize(
637+
fn supports_rewrite(&self) -> bool {
638+
true
639+
}
640+
641+
fn rewrite(
622642
&self,
623-
plan: &LogicalPlan,
643+
plan: LogicalPlan,
624644
_config: &dyn OptimizerConfig,
625-
) -> Result<Option<LogicalPlan>> {
645+
) -> Result<Transformed<LogicalPlan>> {
626646
let filter = match plan {
627-
LogicalPlan::Filter(filter) => filter,
628-
// we also need to pushdown filter in Join.
629-
LogicalPlan::Join(join) => return push_down_join(plan, join, None),
630-
_ => return Ok(None),
647+
LogicalPlan::Filter(ref filter) => filter,
648+
LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None),
649+
_ => return Ok(Transformed::no(plan)),
631650
};
632651

633652
let child_plan = filter.input.as_ref();
634653
let new_plan = match child_plan {
635-
LogicalPlan::Filter(child_filter) => {
654+
LogicalPlan::Filter(ref child_filter) => {
636655
let parents_predicates = split_conjunction(&filter.predicate);
637656
let set: HashSet<&&Expr> = parents_predicates.iter().collect();
638657

@@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
652671
new_predicate,
653672
child_filter.input.clone(),
654673
)?);
655-
self.try_optimize(&new_filter, _config)?
656-
.unwrap_or(new_filter)
674+
self.rewrite(new_filter, _config)?.data
657675
}
658676
LogicalPlan::Repartition(_)
659677
| LogicalPlan::Distinct(_)
660678
| LogicalPlan::Sort(_) => {
661-
// commutable
662679
let new_filter = plan.with_new_exprs(
663680
plan.expressions(),
664681
vec![child_plan.inputs()[0].clone()],
665682
)?;
666683
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
667684
}
668-
LogicalPlan::SubqueryAlias(subquery_alias) => {
685+
LogicalPlan::SubqueryAlias(ref subquery_alias) => {
669686
let mut replace_map = HashMap::new();
670687
for (i, (qualifier, field)) in
671688
subquery_alias.input.schema().iter().enumerate()
@@ -685,7 +702,7 @@ impl OptimizerRule for PushDownFilter {
685702
)?);
686703
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
687704
}
688-
LogicalPlan::Projection(projection) => {
705+
LogicalPlan::Projection(ref projection) => {
689706
// A projection is filter-commutable if it do not contain volatile predicates or contain volatile
690707
// predicates that are not used in the filter. However, we should re-writes all predicate expressions.
691708
// collect projection.
@@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
742759
}
743760
}
744761
}
745-
None => return Ok(None),
762+
None => return Ok(Transformed::no(plan)),
746763
}
747764
}
748-
LogicalPlan::Union(union) => {
765+
LogicalPlan::Union(ref union) => {
749766
let mut inputs = Vec::with_capacity(union.inputs.len());
750767
for input in &union.inputs {
751768
let mut replace_map = HashMap::new();
@@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
770787
schema: plan.schema().clone(),
771788
})
772789
}
773-
LogicalPlan::Aggregate(agg) => {
790+
LogicalPlan::Aggregate(ref agg) => {
774791
// We can push down Predicate which in groupby_expr.
775792
let group_expr_columns = agg
776793
.group_expr
@@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
821838
None => new_agg,
822839
}
823840
}
824-
LogicalPlan::Join(join) => {
825-
match push_down_join(&filter.input, join, Some(&filter.predicate))? {
826-
Some(optimized_plan) => optimized_plan,
827-
None => return Ok(None),
828-
}
841+
LogicalPlan::Join(ref join) => {
842+
push_down_join(
843+
&unwrap_arc(filter.clone().input),
844+
join,
845+
Some(&filter.predicate),
846+
)?
847+
.data
829848
}
830-
LogicalPlan::CrossJoin(cross_join) => {
849+
LogicalPlan::CrossJoin(ref cross_join) => {
831850
let predicates = split_conjunction_owned(filter.predicate.clone());
832851
let join = convert_cross_join_to_inner_join(cross_join.clone())?;
833852
let join_plan = LogicalPlan::Join(join);
@@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
843862
vec![],
844863
true,
845864
)?;
846-
convert_to_cross_join_if_beneficial(plan)?
865+
convert_to_cross_join_if_beneficial(plan.data)?
847866
}
848-
LogicalPlan::TableScan(scan) => {
867+
LogicalPlan::TableScan(ref scan) => {
849868
let filter_predicates = split_conjunction(&filter.predicate);
850869
let results = scan
851870
.source
@@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
892911
None => new_scan,
893912
}
894913
}
895-
LogicalPlan::Extension(extension_plan) => {
914+
LogicalPlan::Extension(ref extension_plan) => {
896915
let prevent_cols =
897916
extension_plan.node.prevent_predicate_push_down_columns();
898917

@@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
935954
None => new_extension,
936955
}
937956
}
938-
_ => return Ok(None),
957+
_ => return Ok(Transformed::no(plan)),
939958
};
940-
Ok(Some(new_plan))
959+
960+
Ok(Transformed::yes(new_plan))
941961
}
942962
}
943963

@@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
10241044

10251045
#[cfg(test)]
10261046
mod tests {
1027-
use super::*;
10281047
use std::any::Any;
10291048
use std::fmt::{Debug, Formatter};
10301049

1031-
use crate::optimizer::Optimizer;
1032-
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
1033-
use crate::test::*;
1034-
use crate::OptimizerContext;
1035-
10361050
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1051+
use async_trait::async_trait;
1052+
10371053
use datafusion_common::ScalarValue;
10381054
use datafusion_expr::expr::ScalarFunction;
10391055
use datafusion_expr::logical_plan::table_scan;
@@ -1043,7 +1059,13 @@ mod tests {
10431059
Volatility,
10441060
};
10451061

1046-
use async_trait::async_trait;
1062+
use crate::optimizer::Optimizer;
1063+
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
1064+
use crate::test::*;
1065+
use crate::OptimizerContext;
1066+
1067+
use super::*;
1068+
10471069
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
10481070

10491071
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
@@ -2298,9 +2320,9 @@ mod tests {
22982320
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
22992321

23002322
let optimized_plan = PushDownFilter::new()
2301-
.try_optimize(&plan, &OptimizerContext::new())
2323+
.rewrite(plan, &OptimizerContext::new())
23022324
.expect("failed to optimize plan")
2303-
.unwrap();
2325+
.data;
23042326

23052327
let expected = "\
23062328
Filter: a = Int64(1)\
@@ -2667,8 +2689,9 @@ Projection: a, b
26672689
// Originally global state which can help to avoid duplicate Filters been generated and pushed down.
26682690
// Now the global state is removed. Need to double confirm that avoid duplicate Filters.
26692691
let optimized_plan = PushDownFilter::new()
2670-
.try_optimize(&plan, &OptimizerContext::new())?
2671-
.expect("failed to optimize plan");
2692+
.rewrite(plan, &OptimizerContext::new())
2693+
.expect("failed to optimize plan")
2694+
.data;
26722695
assert_optimized_plan_eq(optimized_plan, expected)
26732696
}
26742697

0 commit comments

Comments
 (0)