Skip to content

Commit fc71133

Browse files
committed
move is_volatile() check out of visitor
1 parent be59824 commit fc71133

File tree

1 file changed

+41
-33
lines changed

1 file changed

+41
-33
lines changed

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -201,19 +201,25 @@ impl CommonSubexprEliminate {
201201
id_array: &mut IdArray<'n>,
202202
expr_mask: ExprMask,
203203
) -> Result<bool> {
204-
let mut visitor = ExprIdentifierVisitor {
205-
expr_stats,
206-
id_array,
207-
visit_stack: vec![],
208-
down_index: 0,
209-
up_index: 0,
210-
expr_mask,
211-
random_state: &self.random_state,
212-
found_common: false,
213-
};
214-
expr.visit(&mut visitor)?;
204+
// related to https://github.com/apache/arrow-datafusion/issues/8814
205+
// If the expr contain volatile expression or is a short-circuit expression, skip it.
206+
Ok(if expr.is_volatile()? {
207+
false
208+
} else {
209+
let mut visitor = ExprIdentifierVisitor {
210+
expr_stats,
211+
id_array,
212+
visit_stack: vec![],
213+
down_index: 0,
214+
up_index: 0,
215+
expr_mask,
216+
random_state: &self.random_state,
217+
found_common: false,
218+
};
219+
expr.visit(&mut visitor)?;
215220

216-
Ok(visitor.found_common)
221+
visitor.found_common
222+
})
217223
}
218224

219225
/// Rewrites `exprs_list` with common sub-expressions replaced with a new
@@ -950,11 +956,9 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
950956
type Node = Expr;
951957

952958
fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
953-
// related to https://github.com/apache/arrow-datafusion/issues/8814
954-
// If the expr contain volatile expression or is a short-circuit expression, skip it.
955-
// TODO: propagate is_volatile state bottom-up + consider non-volatile sub-expressions for CSE
959+
// TODO: consider non-volatile sub-expressions for CSE
956960
// TODO: consider surely executed children of "short circuited"s for CSE
957-
if expr.short_circuits() || expr.is_volatile()? {
961+
if expr.short_circuits() {
958962
self.visit_stack.push(VisitRecord::JumpMark);
959963

960964
return Ok(TreeNodeRecursion::Jump);
@@ -1013,14 +1017,6 @@ struct CommonSubexprRewriter<'a, 'n> {
10131017
impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
10141018
type Node = Expr;
10151019

1016-
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
1017-
if matches!(expr, Expr::Alias(_)) {
1018-
self.alias_counter -= 1
1019-
}
1020-
1021-
Ok(Transformed::no(expr))
1022-
}
1023-
10241020
fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
10251021
if matches!(expr, Expr::Alias(_)) {
10261022
self.alias_counter += 1;
@@ -1029,7 +1025,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
10291025
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
10301026
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
10311027
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
1032-
if expr.short_circuits() || expr.is_volatile()? {
1028+
if expr.short_circuits() {
10331029
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
10341030
}
10351031

@@ -1069,6 +1065,14 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
10691065
Ok(Transformed::no(expr))
10701066
}
10711067
}
1068+
1069+
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
1070+
if matches!(expr, Expr::Alias(_)) {
1071+
self.alias_counter -= 1
1072+
}
1073+
1074+
Ok(Transformed::no(expr))
1075+
}
10721076
}
10731077

10741078
/// Replace common sub-expression in `expr` with the corresponding temporary
@@ -1080,14 +1084,18 @@ fn replace_common_expr<'n>(
10801084
common_exprs: &mut CommonExprs<'n>,
10811085
alias_generator: &AliasGenerator,
10821086
) -> Result<Transformed<Expr>> {
1083-
expr.rewrite(&mut CommonSubexprRewriter {
1084-
expr_stats,
1085-
id_array,
1086-
common_exprs,
1087-
down_index: 0,
1088-
alias_counter: 0,
1089-
alias_generator,
1090-
})
1087+
if id_array.is_empty() {
1088+
Ok(Transformed::no(expr))
1089+
} else {
1090+
expr.rewrite(&mut CommonSubexprRewriter {
1091+
expr_stats,
1092+
id_array,
1093+
common_exprs,
1094+
down_index: 0,
1095+
alias_counter: 0,
1096+
alias_generator,
1097+
})
1098+
}
10911099
}
10921100

10931101
#[cfg(test)]

0 commit comments

Comments
 (0)