From d928d04a78b276cecda82e02d1ba4c73ba3736c4 Mon Sep 17 00:00:00 2001 From: wiedld Date: Tue, 10 Sep 2024 14:43:54 -0700 Subject: [PATCH] fix: add constants from either side to the UnionExec constants caveat: this has an unintended side effect, as the EnforceSorting removes the sort_expr from one input/side of the UnionExec (where it's not constant) --- .../physical-expr/src/equivalence/class.rs | 36 +++++++++++++++++++ .../src/equivalence/properties.rs | 16 ++------- datafusion/sqllogictest/test_files/order.slt | 4 +-- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 3fc7b0feb8a2c..3f6eb4972ef69 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::hash::{Hash, Hasher}; +use std::ops::RangeFull; use std::sync::Arc; use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; @@ -27,6 +29,7 @@ use crate::{ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; +use indexmap::IndexSet; #[derive(Debug, Clone)] /// A structure representing a expression known to be constant in a physical execution plan. @@ -124,6 +127,39 @@ pub fn const_exprs_contains( .any(|const_expr| const_expr.expr.eq(expr)) } +impl Eq for ConstExpr {} + +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(other.expr()) + } +} + +impl Hash for ConstExpr { + fn hash<H: Hasher>(&self, state: &mut H) { + self.expr.hash(state); + } +} + +/// Concats two slices of `const_exprs`, removing duplicates and +/// maintaining the order. +/// +/// Equality based upon the expression. `across_partitions` will +/// always be false as we do not validate the same constant value +/// on both sides. 
+pub fn concat_const_exprs(lhs: &[ConstExpr], rhs: &[ConstExpr]) -> Vec<ConstExpr> { + IndexSet::<&ConstExpr>::from_iter( + lhs.into_iter() + .chain(rhs.into_iter()) + .map(|const_expr| const_expr), + ) + .drain(RangeFull) + .map(|constant_expr| { + ConstExpr::new(constant_expr.expr.clone()).with_across_partitions(false) + }) + .collect() +} + /// An `EquivalenceClass` is a set of [`Arc<dyn PhysicalExpr>`]s that are known /// to have the same value for all tuples in a relation. These are generated by /// equality predicates (e.g. `a = b`), typically equi-join conditions and diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a5d54ee56cffe..d53342bcf40aa 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -18,6 +18,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; +use super::class::concat_const_exprs; use super::ordering::collapse_lex_ordering; use crate::equivalence::class::const_exprs_contains; use crate::equivalence::{ @@ -1539,19 +1540,8 @@ fn calculate_union_binary( } // First, calculate valid constants for the union. A quantity is constant - // after the union if it is constant in both sides. - let constants = lhs - .constants() - .iter() - .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr())) - .map(|const_expr| { - // TODO: When both sides' constants are valid across partitions, - // the union's constant should also be valid if values are - // the same. However, we do not have the capability to - // check this yet. - ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) - }) - .collect(); + // after the union if it is constant on one of the sides. + let constants = concat_const_exprs(lhs.constants(), rhs.constants()); // Next, calculate valid orderings for the union by searching for prefixes // in both sides. 
diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 671c18674accc..42b55427fe261 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1220,10 +1220,10 @@ physical_plan 01)ProjectionExec: expr=[b@0 as b, c@1 as c, a@2 as a, a0@3 as a0] 02)--SortPreservingMergeExec: [d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a@2 ASC NULLS LAST,a0@3 ASC NULLS LAST,b@0 ASC NULLS LAST], fetch=2 03)----UnionExec -04)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a@2 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] 05)--------ProjectionExec: expr=[b@1 as b, c@2 as c, a@0 as a, NULL as a0, d@3 as d] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true -07)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a0@3 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] +07)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] 08)--------ProjectionExec: expr=[b@1 as b, c@2 as c, NULL as a, a0@0 as a0, d@3 as d] 09)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true