Skip to content

Commit 22a10ae

Browse files
committed
Handle subqueries
1 parent 86d60e0 commit 22a10ae

File tree

2 files changed

+68
-58
lines changed

2 files changed

+68
-58
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,8 +649,12 @@ impl<T> Transformed<T> {
649649
}
650650

651651
impl Transformed<()> {
652-
/// Invoke the given function `f` and combine the transformed state with the
653-
/// current state
652+
/// Invoke the given function `f` and combine the transformed state with
653+
/// the current state,
654+
///
655+
/// if f() returns an Err, returns that err
656+
/// If f() returns Ok, returns a true transformed flag if either self or
657+
/// the result of f() was transformed
654658
pub fn and_then<F>(self, f: F) -> Result<Transformed<()>>
655659
where
656660
F: FnOnce() -> Result<Transformed<()>>,

datafusion/expr/src/logical_plan/mutate.rs

Lines changed: 62 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
// under the License.
1717

1818
use super::plan::*;
19-
use crate::{lit, Expr};
19+
use crate::expr::{Exists, InSubquery};
20+
use crate::Expr;
2021
use datafusion_common::tree_node::Transformed;
2122
use datafusion_common::Result;
2223
use datafusion_common::{DFSchema, DFSchemaRef};
@@ -27,25 +28,13 @@ use std::sync::Arc;
2728
fn rewrite_expr_iter_mut<'a, F>(
2829
i: impl IntoIterator<Item = &'a mut Expr>,
2930
mut f: F,
30-
) -> Result<()>
31+
) -> Result<Transformed<()>>
3132
where
32-
F: FnMut(Expr) -> Result<Expr>,
33+
F: FnMut(&mut Expr) -> Result<Transformed<()>>,
3334
{
34-
i.into_iter().try_for_each(|e| rewrite_expr(e, &mut f))
35-
}
36-
37-
/// Rewrites an expression in place using a function that takes the expression by ownership
38-
fn rewrite_expr<'a, F>(e: &'a mut Expr, mut f: F) -> Result<()>
39-
where
40-
F: FnMut(Expr) -> Result<Expr>,
41-
{
42-
let mut t = lit(0);
43-
std::mem::swap(e, &mut t);
44-
// transform
45-
let mut t = f(t)?;
46-
// put it back
47-
std::mem::swap(e, &mut t);
48-
Ok(())
35+
i.into_iter().fold(Ok(Transformed::no(())), |acc, expr| {
36+
acc.and_then(|acc| acc.and_then(|| f(expr)))
37+
})
4938
}
5039

5140
impl LogicalPlan {
@@ -54,61 +43,58 @@ impl LogicalPlan {
5443
///
5544
/// If the closure returns an error, the error is returned and the expressions
5645
/// are left in a partially modified state
57-
pub fn rewrite_exprs<F>(mut self, mut f: F) -> Result<Self>
46+
pub fn rewrite_exprs<F>(&mut self, mut f: F) -> Result<Transformed<()>>
5847
where
59-
F: FnMut(Expr) -> Result<Expr>,
48+
F: FnMut(&mut Expr) -> Result<Transformed<()>>,
6049
{
61-
match &mut self {
50+
match self {
6251
LogicalPlan::Projection(Projection { expr, .. }) => {
63-
rewrite_expr_iter_mut(expr.iter_mut(), &mut f)?;
52+
rewrite_expr_iter_mut(expr.iter_mut(), f)
6453
}
6554
LogicalPlan::Values(Values { values, .. }) => {
66-
rewrite_expr_iter_mut(values.iter_mut().flatten(), &mut f)?;
67-
}
68-
LogicalPlan::Filter(Filter { predicate, .. }) => {
69-
rewrite_expr(predicate, &mut f)?
55+
rewrite_expr_iter_mut(values.iter_mut().flatten(), f)
7056
}
57+
LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate),
7158
LogicalPlan::Repartition(Repartition {
7259
partitioning_scheme,
7360
..
7461
}) => match partitioning_scheme {
75-
Partitioning::Hash(expr, _) => {
76-
rewrite_expr_iter_mut(expr.iter_mut(), &mut f)?
77-
}
62+
Partitioning::Hash(expr, _) => rewrite_expr_iter_mut(expr.iter_mut(), f),
7863
Partitioning::DistributeBy(expr) => {
79-
rewrite_expr_iter_mut(expr.iter_mut(), &mut f)?
64+
rewrite_expr_iter_mut(expr.iter_mut(), f)
8065
}
81-
Partitioning::RoundRobinBatch(_) => {}
66+
Partitioning::RoundRobinBatch(_) => Ok(Transformed::no(())),
8267
},
8368
LogicalPlan::Window(Window { window_expr, .. }) => {
84-
rewrite_expr_iter_mut(window_expr.iter_mut(), &mut f)?;
69+
rewrite_expr_iter_mut(window_expr.iter_mut(), f)
8570
}
8671
LogicalPlan::Aggregate(Aggregate {
8772
group_expr,
8873
aggr_expr,
8974
..
90-
}) => rewrite_expr_iter_mut(
91-
group_expr.iter_mut().chain(aggr_expr.iter_mut()),
92-
&mut f,
93-
)?,
75+
}) => {
76+
let exprs = group_expr.iter_mut().chain(aggr_expr.iter_mut());
77+
rewrite_expr_iter_mut(exprs, f)
78+
}
9479
// There are two part of expression for join, equijoin(on) and non-equijoin(filter).
9580
// 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`.
9681
// 2. the second part is non-equijoin(filter).
9782
LogicalPlan::Join(Join { on, filter, .. }) => {
9883
// don't look at the equijoin expressions as a whole
99-
rewrite_expr_iter_mut(
100-
on.iter_mut().flat_map(|(e1, e2)| {
101-
std::iter::once(e1).chain(std::iter::once(e2))
102-
}),
103-
&mut f,
104-
)?;
84+
let exprs = on
85+
.iter_mut()
86+
.flat_map(|(e1, e2)| std::iter::once(e1).chain(std::iter::once(e2)));
87+
88+
let result = rewrite_expr_iter_mut(exprs, &mut f)?;
10589

10690
if let Some(filter) = filter.as_mut() {
107-
rewrite_expr(filter, &mut f)?;
91+
result.and_then(|| f(filter))
92+
} else {
93+
Ok(result)
10894
}
10995
}
11096
LogicalPlan::Sort(Sort { expr, .. }) => {
111-
rewrite_expr_iter_mut(expr.iter_mut(), &mut f)?
97+
rewrite_expr_iter_mut(expr.iter_mut(), f)
11298
}
11399
LogicalPlan::Extension(extension) => {
114100
// would be nice to avoid this copy -- maybe can
@@ -117,7 +103,7 @@ impl LogicalPlan {
117103
todo!();
118104
}
119105
LogicalPlan::TableScan(TableScan { filters, .. }) => {
120-
rewrite_expr_iter_mut(filters.iter_mut(), &mut f)?;
106+
rewrite_expr_iter_mut(filters.iter_mut(), f)
121107
}
122108
LogicalPlan::Unnest(Unnest { column, .. }) => {
123109
//f(&Expr::Column(column.clone()))
@@ -128,13 +114,14 @@ impl LogicalPlan {
128114
select_expr,
129115
sort_expr,
130116
..
131-
})) => rewrite_expr_iter_mut(
132-
on_expr
117+
})) => {
118+
let exprs = on_expr
133119
.iter_mut()
134120
.chain(select_expr.iter_mut())
135-
.chain(sort_expr.iter_mut().flat_map(|x| x.iter_mut())),
136-
&mut f,
137-
)?,
121+
.chain(sort_expr.iter_mut().flat_map(|x| x.iter_mut()));
122+
123+
rewrite_expr_iter_mut(exprs, f)
124+
}
138125
// plans without expressions
139126
LogicalPlan::EmptyRelation(_)
140127
| LogicalPlan::RecursiveQuery(_)
@@ -151,10 +138,8 @@ impl LogicalPlan {
151138
| LogicalPlan::Ddl(_)
152139
| LogicalPlan::Copy(_)
153140
| LogicalPlan::DescribeTable(_)
154-
| LogicalPlan::Prepare(_) => {}
141+
| LogicalPlan::Prepare(_) => Ok(Transformed::no(())),
155142
}
156-
157-
Ok(self)
158143
}
159144
}
160145

@@ -168,7 +153,7 @@ const PLACEHOLDER: OnceCell<Arc<LogicalPlan>> = OnceCell::new();
168153
/// of the Arc to avoid cloning the entire plan
169154
///
170155
/// On error, the node will be partially rewritten (left with a placeholder logical plan)
171-
fn rewrite_arc<F>(node: &mut Arc<LogicalPlan>, f: &mut F) -> Result<Transformed<()>>
156+
fn rewrite_arc<F>(node: &mut Arc<LogicalPlan>, mut f: F) -> Result<Transformed<()>>
172157
where
173158
F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>,
174159
{
@@ -215,7 +200,7 @@ impl LogicalPlan {
215200
where
216201
F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>,
217202
{
218-
match self {
203+
let children_result = match self {
219204
LogicalPlan::Projection(Projection { input, .. }) => {
220205
rewrite_arc(input, &mut f)
221206
}
@@ -269,6 +254,27 @@ impl LogicalPlan {
269254
| LogicalPlan::EmptyRelation { .. }
270255
| LogicalPlan::Values { .. }
271256
| LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())),
272-
}
257+
}?;
258+
259+
// after visiting the actual children we we need to visit any subqueries
260+
// that are inside the expressions
261+
children_result.and_then(|| self.rewrite_subqueries(&mut f))
262+
}
263+
264+
/// applies the closure `f` to LogicalPlans in any subquery expressions
265+
///
266+
/// If Err is returned, the plan may be left in a partially modified state
267+
fn rewrite_subqueries<F>(&mut self, mut f: F) -> Result<Transformed<()>>
268+
where
269+
F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>,
270+
{
271+
self.rewrite_exprs(|expr| match expr {
272+
Expr::Exists(Exists { subquery, .. })
273+
| Expr::InSubquery(InSubquery { subquery, .. })
274+
| Expr::ScalarSubquery(subquery) => {
275+
rewrite_arc(&mut subquery.subquery, &mut f)
276+
}
277+
_ => Ok(Transformed::no(())),
278+
})
273279
}
274280
}

0 commit comments

Comments
 (0)