Skip to content

Commit 5bdc745

Browse files
authored
Enhance short circuit handling in CommonSubexprEliminate (#11197)
* Enhance short circuit handling in `CommonSubexprEliminate` * explain is_tree * adjust test
1 parent a753c37 commit 5bdc745

File tree

1 file changed

+68
-32
lines changed

1 file changed

+68
-32
lines changed

datafusion/optimizer/src/common_subexpr_eliminate.rs

+68-32
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,13 @@ struct Identifier<'n> {
5656
}
5757

5858
impl<'n> Identifier<'n> {
59-
fn new(expr: &'n Expr, random_state: &RandomState) -> Self {
59+
fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self {
6060
let mut hasher = random_state.build_hasher();
61-
expr.hash_node(&mut hasher);
61+
if is_tree {
62+
expr.hash(&mut hasher);
63+
} else {
64+
expr.hash_node(&mut hasher);
65+
}
6266
let hash = hasher.finish();
6367
Self { hash, expr }
6468
}
@@ -911,31 +915,30 @@ struct ExprIdentifierVisitor<'a, 'n> {
911915
found_common: bool,
912916
}
913917

914-
/// Record item that used when traversing a expression tree.
918+
/// Record item that is used when traversing an expression tree.
915919
enum VisitRecord<'n> {
916-
/// `usize` postorder index assigned in `f-down`(). Starts from 0.
917-
EnterMark(usize),
918-
/// the node's children were skipped => jump to f_up on same node
919-
JumpMark,
920+
/// Contains the post-order index assigned during the first, visiting traversal and
921+
/// a boolean flag to indicate if the record marks an expression subtree (not just a
922+
/// single node).
923+
EnterMark(usize, bool),
920924
/// Accumulated identifier of sub expression.
921925
ExprItem(Identifier<'n>),
922926
}
923927

924928
impl<'n> ExprIdentifierVisitor<'_, 'n> {
925929
/// Finds the first `EnterMark` in the stack, and accumulates every `ExprItem`
926930
/// before it.
927-
fn pop_enter_mark(&mut self) -> Option<(usize, Option<Identifier<'n>>)> {
931+
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>) {
928932
let mut expr_id = None;
929933

930934
while let Some(item) = self.visit_stack.pop() {
931935
match item {
932-
VisitRecord::EnterMark(idx) => {
933-
return Some((idx, expr_id));
936+
VisitRecord::EnterMark(down_index, tree) => {
937+
return (down_index, tree, expr_id);
934938
}
935939
VisitRecord::ExprItem(id) => {
936940
expr_id = Some(id.combine(expr_id));
937941
}
938-
VisitRecord::JumpMark => return None,
939942
}
940943
}
941944
unreachable!("Enter mark should paired with node number");
@@ -947,30 +950,32 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
947950

948951
fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
949952
// TODO: consider non-volatile sub-expressions for CSE
950-
// TODO: consider surely executed children of "short circuited"s for CSE
951-
952-
// If an expression can short circuit its children then don't consider it for CSE
953-
// (https://github.com/apache/arrow-datafusion/issues/8814).
954-
if expr.short_circuits() {
955-
self.visit_stack.push(VisitRecord::JumpMark);
956953

957-
return Ok(TreeNodeRecursion::Jump);
958-
}
954+
// If an expression can short circuit its children then don't consider its
955+
// children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
956+
// This means that we don't recurse into its children, but handle the expression
957+
// as a subtree when we calculate its identifier.
958+
// TODO: consider surely executed children of "short circuited"s for CSE
959+
let is_tree = expr.short_circuits();
960+
let tnr = if is_tree {
961+
TreeNodeRecursion::Jump
962+
} else {
963+
TreeNodeRecursion::Continue
964+
};
959965

960966
self.id_array.push((0, None));
961967
self.visit_stack
962-
.push(VisitRecord::EnterMark(self.down_index));
968+
.push(VisitRecord::EnterMark(self.down_index, is_tree));
963969
self.down_index += 1;
964970

965-
Ok(TreeNodeRecursion::Continue)
971+
Ok(tnr)
966972
}
967973

968974
fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
969-
let Some((down_index, sub_expr_id)) = self.pop_enter_mark() else {
970-
return Ok(TreeNodeRecursion::Continue);
971-
};
975+
let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark();
972976

973-
let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id);
977+
let expr_id =
978+
Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id);
974979

975980
self.id_array[down_index].0 = self.up_index;
976981
if !self.expr_mask.ignores(expr) {
@@ -1015,19 +1020,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
10151020
self.alias_counter += 1;
10161021
}
10171022

1018-
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
1019-
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
1023+
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the
1024+
// `id_array`, which records the expr's identifier used to rewrite expr. So if we
10201025
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
1021-
if expr.short_circuits() {
1022-
return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
1023-
}
1026+
let is_tree = expr.short_circuits();
1027+
let tnr = if is_tree {
1028+
TreeNodeRecursion::Jump
1029+
} else {
1030+
TreeNodeRecursion::Continue
1031+
};
10241032

10251033
let (up_index, expr_id) = self.id_array[self.down_index];
10261034
self.down_index += 1;
10271035

10281036
// skip `Expr`s without identifier (empty identifier).
10291037
let Some(expr_id) = expr_id else {
1030-
return Ok(Transformed::no(expr));
1038+
return Ok(Transformed::new(expr, false, tnr));
10311039
};
10321040

10331041
let count = self.expr_stats.get(&expr_id).unwrap();
@@ -1055,7 +1063,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> {
10551063

10561064
Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
10571065
} else {
1058-
Ok(Transformed::no(expr))
1066+
Ok(Transformed::new(expr, false, tnr))
10591067
}
10601068
}
10611069

@@ -1802,4 +1810,32 @@ mod test {
18021810
assert!(result.len() == 1);
18031811
Ok(())
18041812
}
1813+
1814+
#[test]
1815+
fn test_short_circuits() -> Result<()> {
1816+
let table_scan = test_table_scan()?;
1817+
1818+
let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1819+
let not_extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1820+
let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1821+
let plan = LogicalPlanBuilder::from(table_scan.clone())
1822+
.project(vec![
1823+
extracted_short_circuit.clone().alias("c1"),
1824+
extracted_short_circuit.alias("c2"),
1825+
not_extracted_short_circuit_leg_1.clone().alias("c3"),
1826+
not_extracted_short_circuit_leg_2.clone().alias("c4"),
1827+
not_extracted_short_circuit_leg_1
1828+
.or(not_extracted_short_circuit_leg_2)
1829+
.alias("c5"),
1830+
])?
1831+
.build()?;
1832+
1833+
let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\
1834+
\n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\
1835+
\n TableScan: test";
1836+
1837+
assert_optimized_plan_eq(expected, plan, None);
1838+
1839+
Ok(())
1840+
}
18051841
}

0 commit comments

Comments
 (0)