Skip to content

Commit 6b2e999

Browse files
peter-toth authored and ccciudatu committed
Improve TreeNode and LogicalPlan APIs to accept owned closures, deprecate transform_down_mut() and transform_up_mut() (apache#10126)
* Deprecate `TreeNode::transform_down_mut()` and `TreeNode::transform_up_mut()` methods
* Refactor `TreeNode` and `LogicalPlan` apply, transform, transform_up, transform_down and transform_down_up APIs to accept owned closures
1 parent 5e77491 commit 6b2e999

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed: 238 additions (+238) and 209 deletions (-209)

datafusion-examples/examples/function_factory.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
164164
impl ScalarFunctionWrapper {
165165
// replaces placeholders such as $1 with actual arguments (args[0]
166166
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
167-
let result = expr.clone().transform(&|e| {
167+
let result = expr.clone().transform(|e| {
168168
let r = match e {
169169
Expr::Placeholder(placeholder) => {
170170
let placeholder_position =

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {
9191

9292
impl MyAnalyzerRule {
9393
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
94-
plan.transform(&|plan| {
94+
plan.transform(|plan| {
9595
Ok(match plan {
9696
LogicalPlan::Filter(filter) => {
9797
let predicate = Self::analyze_expr(filter.predicate.clone())?;
@@ -107,7 +107,7 @@ impl MyAnalyzerRule {
107107
}
108108

109109
fn analyze_expr(expr: Expr) -> Result<Expr> {
110-
expr.transform(&|expr| {
110+
expr.transform(|expr| {
111111
// closure is invoked for all sub expressions
112112
Ok(match expr {
113113
Expr::Literal(ScalarValue::Int64(i)) => {
@@ -163,7 +163,7 @@ impl OptimizerRule for MyOptimizerRule {
163163

164164
/// use rewrite_expr to modify the expression tree.
165165
fn my_rewrite(expr: Expr) -> Result<Expr> {
166-
expr.transform(&|expr| {
166+
expr.transform(|expr| {
167167
// closure is invoked for all sub expressions
168168
Ok(match expr {
169169
Expr::Between(Between {

datafusion/common/src/tree_node.rs

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,6 @@ macro_rules! handle_transform_recursion {
3131
}};
3232
}
3333

34-
macro_rules! handle_transform_recursion_down {
35-
($F_DOWN:expr, $F_CHILD:expr) => {{
36-
$F_DOWN?.transform_children(|n| n.map_children($F_CHILD))
37-
}};
38-
}
39-
40-
macro_rules! handle_transform_recursion_up {
41-
($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{
42-
$SELF.map_children($F_CHILD)?.transform_parent(|n| $F_UP(n))
43-
}};
44-
}
45-
4634
/// Defines a visitable and rewriteable tree node. This trait is implemented
4735
/// for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well as expression
4836
/// trees ([`PhysicalExpr`], [`Expr`]) in DataFusion.
@@ -137,61 +125,85 @@ pub trait TreeNode: Sized {
137125
/// or run a check on the tree.
138126
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
139127
&self,
140-
f: &mut F,
128+
mut f: F,
141129
) -> Result<TreeNodeRecursion> {
142-
f(self)?.visit_children(|| self.apply_children(|c| c.apply(f)))
130+
fn apply_impl<N: TreeNode, F: FnMut(&N) -> Result<TreeNodeRecursion>>(
131+
node: &N,
132+
f: &mut F,
133+
) -> Result<TreeNodeRecursion> {
134+
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
135+
}
136+
137+
apply_impl(self, &mut f)
143138
}
144139

145140
/// Convenience utility for writing optimizer rules: Recursively apply the
146141
/// given function `f` to the tree in a bottom-up (post-order) fashion. When
147142
/// `f` does not apply to a given node, it is left unchanged.
148-
fn transform<F: Fn(Self) -> Result<Transformed<Self>>>(
143+
fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
149144
self,
150-
f: &F,
145+
f: F,
151146
) -> Result<Transformed<Self>> {
152147
self.transform_up(f)
153148
}
154149

155150
/// Convenience utility for writing optimizer rules: Recursively apply the
156151
/// given function `f` to a node and then to its children (pre-order traversal).
157152
/// When `f` does not apply to a given node, it is left unchanged.
158-
fn transform_down<F: Fn(Self) -> Result<Transformed<Self>>>(
153+
fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
159154
self,
160-
f: &F,
155+
mut f: F,
161156
) -> Result<Transformed<Self>> {
162-
handle_transform_recursion_down!(f(self), |c| c.transform_down(f))
157+
fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
158+
node: N,
159+
f: &mut F,
160+
) -> Result<Transformed<N>> {
161+
f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
162+
}
163+
164+
transform_down_impl(self, &mut f)
163165
}
164166

165167
/// Convenience utility for writing optimizer rules: Recursively apply the
166168
/// given mutable function `f` to a node and then to its children (pre-order
167169
/// traversal). When `f` does not apply to a given node, it is left unchanged.
170+
#[deprecated(since = "38.0.0", note = "Use `transform_down` instead")]
168171
fn transform_down_mut<F: FnMut(Self) -> Result<Transformed<Self>>>(
169172
self,
170173
f: &mut F,
171174
) -> Result<Transformed<Self>> {
172-
handle_transform_recursion_down!(f(self), |c| c.transform_down_mut(f))
175+
self.transform_down(f)
173176
}
174177

175178
/// Convenience utility for writing optimizer rules: Recursively apply the
176179
/// given function `f` to all children of a node, and then to the node itself
177180
/// (post-order traversal). When `f` does not apply to a given node, it is
178181
/// left unchanged.
179-
fn transform_up<F: Fn(Self) -> Result<Transformed<Self>>>(
182+
fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
180183
self,
181-
f: &F,
184+
mut f: F,
182185
) -> Result<Transformed<Self>> {
183-
handle_transform_recursion_up!(self, |c| c.transform_up(f), f)
186+
fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
187+
node: N,
188+
f: &mut F,
189+
) -> Result<Transformed<N>> {
190+
node.map_children(|c| transform_up_impl(c, f))?
191+
.transform_parent(f)
192+
}
193+
194+
transform_up_impl(self, &mut f)
184195
}
185196

186197
/// Convenience utility for writing optimizer rules: Recursively apply the
187198
/// given mutable function `f` to all children of a node, and then to the
188199
/// node itself (post-order traversal). When `f` does not apply to a given
189200
/// node, it is left unchanged.
201+
#[deprecated(since = "38.0.0", note = "Use `transform_up` instead")]
190202
fn transform_up_mut<F: FnMut(Self) -> Result<Transformed<Self>>>(
191203
self,
192204
f: &mut F,
193205
) -> Result<Transformed<Self>> {
194-
handle_transform_recursion_up!(self, |c| c.transform_up_mut(f), f)
206+
self.transform_up(f)
195207
}
196208

197209
/// Transforms the tree using `f_down` while traversing the tree top-down
@@ -200,8 +212,8 @@ pub trait TreeNode: Sized {
200212
///
201213
/// Use this method if you want to start the `f_up` process right where `f_down` jumps.
202214
/// This can make the whole process faster by reducing the number of `f_up` steps.
203-
/// If you don't need this, it's just like using `transform_down_mut` followed by
204-
/// `transform_up_mut` on the same tree.
215+
/// If you don't need this, it's just like using `transform_down` followed by
216+
/// `transform_up` on the same tree.
205217
///
206218
/// Consider the following tree structure:
207219
/// ```text
@@ -288,22 +300,34 @@ pub trait TreeNode: Sized {
288300
FU: FnMut(Self) -> Result<Transformed<Self>>,
289301
>(
290302
self,
291-
f_down: &mut FD,
292-
f_up: &mut FU,
303+
mut f_down: FD,
304+
mut f_up: FU,
293305
) -> Result<Transformed<Self>> {
294-
handle_transform_recursion!(
295-
f_down(self),
296-
|c| c.transform_down_up(f_down, f_up),
297-
f_up
298-
)
306+
fn transform_down_up_impl<
307+
N: TreeNode,
308+
FD: FnMut(N) -> Result<Transformed<N>>,
309+
FU: FnMut(N) -> Result<Transformed<N>>,
310+
>(
311+
node: N,
312+
f_down: &mut FD,
313+
f_up: &mut FU,
314+
) -> Result<Transformed<N>> {
315+
handle_transform_recursion!(
316+
f_down(node),
317+
|c| transform_down_up_impl(c, f_down, f_up),
318+
f_up
319+
)
320+
}
321+
322+
transform_down_up_impl(self, &mut f_down, &mut f_up)
299323
}
300324

301325
/// Returns true if `f` returns true for node in the tree.
302326
///
303327
/// Stops recursion as soon as a matching node is found
304328
fn exists<F: FnMut(&Self) -> bool>(&self, mut f: F) -> bool {
305329
let mut found = false;
306-
self.apply(&mut |n| {
330+
self.apply(|n| {
307331
Ok(if f(n) {
308332
found = true;
309333
TreeNodeRecursion::Stop
@@ -439,9 +463,7 @@ impl TreeNodeRecursion {
439463
/// This struct is used by tree transformation APIs such as
440464
/// - [`TreeNode::rewrite`],
441465
/// - [`TreeNode::transform_down`],
442-
/// - [`TreeNode::transform_down_mut`],
443466
/// - [`TreeNode::transform_up`],
444-
/// - [`TreeNode::transform_up_mut`],
445467
/// - [`TreeNode::transform_down_up`]
446468
///
447469
/// to control the transformation and return the transformed result.
@@ -1362,7 +1384,7 @@ mod tests {
13621384
fn $NAME() -> Result<()> {
13631385
let tree = test_tree();
13641386
let mut visits = vec![];
1365-
tree.apply(&mut |node| {
1387+
tree.apply(|node| {
13661388
visits.push(format!("f_down({})", node.data));
13671389
$F(node)
13681390
})?;
@@ -1451,10 +1473,7 @@ mod tests {
14511473
#[test]
14521474
fn $NAME() -> Result<()> {
14531475
let tree = test_tree();
1454-
assert_eq!(
1455-
tree.transform_down_up(&mut $F_DOWN, &mut $F_UP,)?,
1456-
$EXPECTED_TREE
1457-
);
1476+
assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE);
14581477

14591478
Ok(())
14601479
}
@@ -1466,7 +1485,7 @@ mod tests {
14661485
#[test]
14671486
fn $NAME() -> Result<()> {
14681487
let tree = test_tree();
1469-
assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE);
1488+
assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE);
14701489

14711490
Ok(())
14721491
}
@@ -1478,7 +1497,7 @@ mod tests {
14781497
#[test]
14791498
fn $NAME() -> Result<()> {
14801499
let tree = test_tree();
1481-
assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE);
1500+
assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE);
14821501

14831502
Ok(())
14841503
}

datafusion/core/src/datasource/listing/helpers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ use object_store::{ObjectMeta, ObjectStore};
5050
/// was performed
5151
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
5252
let mut is_applicable = true;
53-
expr.apply(&mut |expr| {
53+
expr.apply(|expr| {
5454
match expr {
5555
Expr::Column(Column { ref name, .. }) => {
5656
is_applicable &= col_names.contains(name);

datafusion/core/src/physical_optimizer/coalesce_batches.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
5454
}
5555

5656
let target_batch_size = config.execution.batch_size;
57-
plan.transform_up(&|plan| {
57+
plan.transform_up(|plan| {
5858
let plan_any = plan.as_any();
5959
// The goal here is to detect operators that could produce small batches and only
6060
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here

datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
5151
plan: Arc<dyn ExecutionPlan>,
5252
_config: &ConfigOptions,
5353
) -> Result<Arc<dyn ExecutionPlan>> {
54-
plan.transform_down(&|plan| {
54+
plan.transform_down(|plan| {
5555
let transformed =
5656
plan.as_any()
5757
.downcast_ref::<AggregateExec>()
@@ -179,7 +179,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
179179
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
180180
group_expr
181181
.clone()
182-
.transform(&|expr| {
182+
.transform(|expr| {
183183
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
184184
match expr.as_any().downcast_ref::<Column>() {
185185
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),

datafusion/core/src/physical_optimizer/convert_first_last.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder {
6060
plan: Arc<dyn ExecutionPlan>,
6161
_config: &ConfigOptions,
6262
) -> Result<Arc<dyn ExecutionPlan>> {
63-
plan.transform_up(&get_common_requirement_of_aggregate_input)
63+
plan.transform_up(get_common_requirement_of_aggregate_input)
6464
.data()
6565
}
6666

datafusion/core/src/physical_optimizer/enforce_distribution.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,12 @@ impl PhysicalOptimizerRule for EnforceDistribution {
197197
// Run a top-down process to adjust input key ordering recursively
198198
let plan_requirements = PlanWithKeyRequirements::new_default(plan);
199199
let adjusted = plan_requirements
200-
.transform_down(&adjust_input_keys_ordering)
200+
.transform_down(adjust_input_keys_ordering)
201201
.data()?;
202202
adjusted.plan
203203
} else {
204204
// Run a bottom-up process
205-
plan.transform_up(&|plan| {
205+
plan.transform_up(|plan| {
206206
Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?))
207207
})
208208
.data()?
@@ -211,7 +211,7 @@ impl PhysicalOptimizerRule for EnforceDistribution {
211211
let distribution_context = DistributionContext::new_default(adjusted);
212212
// Distribution enforcement needs to be applied bottom-up.
213213
let distribution_context = distribution_context
214-
.transform_up(&|distribution_context| {
214+
.transform_up(|distribution_context| {
215215
ensure_distribution(distribution_context, config)
216216
})
217217
.data()?;
@@ -1772,22 +1772,22 @@ pub(crate) mod tests {
17721772
let plan_requirements =
17731773
PlanWithKeyRequirements::new_default($PLAN.clone());
17741774
let adjusted = plan_requirements
1775-
.transform_down(&adjust_input_keys_ordering)
1775+
.transform_down(adjust_input_keys_ordering)
17761776
.data()
17771777
.and_then(check_integrity)?;
17781778
// TODO: End state payloads will be checked here.
17791779
adjusted.plan
17801780
} else {
17811781
// Run reorder_join_keys_to_inputs rule
1782-
$PLAN.clone().transform_up(&|plan| {
1782+
$PLAN.clone().transform_up(|plan| {
17831783
Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?))
17841784
})
17851785
.data()?
17861786
};
17871787

17881788
// Then run ensure_distribution rule
17891789
DistributionContext::new_default(adjusted)
1790-
.transform_up(&|distribution_context| {
1790+
.transform_up(|distribution_context| {
17911791
ensure_distribution(distribution_context, &config)
17921792
})
17931793
.data()

0 commit comments

Comments
 (0)