diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e659e62d7ac7..62b6ad287aa6 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1140,6 +1140,7 @@ dependencies = [ "datafusion-functions-array", "datafusion-optimizer", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-sql", "flate2", @@ -1324,6 +1325,7 @@ dependencies = [ "chrono", "datafusion-common", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-physical-expr", "hashbrown 0.14.5", "indexmap 2.2.6", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 54ca38af675f..9f1f7484357b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -105,6 +105,7 @@ datafusion-functions-aggregate = { workspace = true } datafusion-functions-array = { workspace = true, optional = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } diff --git a/datafusion/core/src/physical_optimizer/convert_first_last.rs b/datafusion/core/src/physical_optimizer/convert_first_last.rs deleted file mode 100644 index 62537169cfc6..000000000000 --- a/datafusion/core/src/physical_optimizer/convert_first_last.rs +++ /dev/null @@ -1,260 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion_common::Result; -use datafusion_common::{ - config::ConfigOptions, - tree_node::{Transformed, TransformedResult, TreeNode}, -}; -use datafusion_physical_expr::expressions::{FirstValue, LastValue}; -use datafusion_physical_expr::{ - equivalence::ProjectionMapping, reverse_order_bys, AggregateExpr, - EquivalenceProperties, PhysicalSortRequirement, -}; -use datafusion_physical_plan::aggregates::concat_slices; -use datafusion_physical_plan::{ - aggregates::{AggregateExec, AggregateMode}, - ExecutionPlan, ExecutionPlanProperties, InputOrderMode, -}; -use std::sync::Arc; - -use datafusion_physical_plan::windows::get_ordered_partition_by_indices; - -use super::PhysicalOptimizerRule; - -/// The optimizer rule check the ordering requirements of the aggregate expressions. -/// And convert between FIRST_VALUE and LAST_VALUE if possible. -/// For example, If we have an ascending values and we want LastValue from the descending requirement, -/// it is equivalent to FirstValue with the current ascending ordering. -/// -/// The concrete example is that, says we have values c1 with [1, 2, 3], which is an ascending order. -/// If we want LastValue(c1 order by desc), which is the first value of reversed c1 [3, 2, 1], -/// so we can convert the aggregate expression to FirstValue(c1 order by asc), -/// since the current ordering is already satisfied, it saves our time! 
-#[derive(Default)] -pub struct OptimizeAggregateOrder {} - -impl OptimizeAggregateOrder { - pub fn new() -> Self { - Self::default() - } -} - -impl PhysicalOptimizerRule for OptimizeAggregateOrder { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> Result> { - plan.transform_up(get_common_requirement_of_aggregate_input) - .data() - } - - fn name(&self) -> &str { - "OptimizeAggregateOrder" - } - - fn schema_check(&self) -> bool { - true - } -} - -fn get_common_requirement_of_aggregate_input( - plan: Arc, -) -> Result>> { - if let Some(aggr_exec) = plan.as_any().downcast_ref::() { - let input = aggr_exec.input(); - let mut aggr_expr = try_get_updated_aggr_expr_from_child(aggr_exec); - let group_by = aggr_exec.group_expr(); - let mode = aggr_exec.mode(); - - let input_eq_properties = input.equivalence_properties(); - let groupby_exprs = group_by.input_exprs(); - // If existing ordering satisfies a prefix of the GROUP BY expressions, - // prefix requirements with this section. In this case, aggregation will - // work more efficiently. 
- let indices = get_ordered_partition_by_indices(&groupby_exprs, input); - let requirement = indices - .iter() - .map(|&idx| PhysicalSortRequirement { - expr: groupby_exprs[idx].clone(), - options: None, - }) - .collect::>(); - - try_convert_first_last_if_better( - &requirement, - &mut aggr_expr, - input_eq_properties, - )?; - - let required_input_ordering = (!requirement.is_empty()).then_some(requirement); - - let input_order_mode = - if indices.len() == groupby_exprs.len() && !indices.is_empty() { - InputOrderMode::Sorted - } else if !indices.is_empty() { - InputOrderMode::PartiallySorted(indices) - } else { - InputOrderMode::Linear - }; - let projection_mapping = - ProjectionMapping::try_new(group_by.expr(), &input.schema())?; - - let cache = AggregateExec::compute_properties( - input, - plan.schema().clone(), - &projection_mapping, - mode, - &input_order_mode, - ); - - let aggr_exec = aggr_exec.new_with_aggr_expr_and_ordering_info( - required_input_ordering, - aggr_expr, - cache, - input_order_mode, - ); - - Ok(Transformed::yes( - Arc::new(aggr_exec) as Arc - )) - } else { - Ok(Transformed::no(plan)) - } -} - -/// In `create_initial_plan` for LogicalPlan::Aggregate, we have a nested AggregateExec where the first layer -/// is in Partial mode and the second layer is in Final or Finalpartitioned mode. -/// If the first layer of aggregate plan is transformed, we need to update the child of the layer with final mode. -/// Therefore, we check it and get the updated aggregate expressions. -/// -/// If AggregateExec is created from elsewhere, we skip the check and return the original aggregate expressions. -fn try_get_updated_aggr_expr_from_child( - aggr_exec: &AggregateExec, -) -> Vec> { - let input = aggr_exec.input(); - if aggr_exec.mode() == &AggregateMode::Final - || aggr_exec.mode() == &AggregateMode::FinalPartitioned - { - // Some aggregators may be modified during initialization for - // optimization purposes. 
For example, a FIRST_VALUE may turn - // into a LAST_VALUE with the reverse ordering requirement. - // To reflect such changes to subsequent stages, use the updated - // `AggregateExpr`/`PhysicalSortExpr` objects. - // - // The bottom up transformation is the mirror of LogicalPlan::Aggregate creation in [create_initial_plan] - if let Some(c_aggr_exec) = input.as_any().downcast_ref::() { - if c_aggr_exec.mode() == &AggregateMode::Partial { - // If the input is an AggregateExec in Partial mode, then the - // input is a CoalescePartitionsExec. In this case, the - // AggregateExec is the second stage of aggregation. The - // requirements of the second stage are the requirements of - // the first stage. - return c_aggr_exec.aggr_expr().to_vec(); - } - } - } - - aggr_exec.aggr_expr().to_vec() -} - -/// Get the common requirement that satisfies all the aggregate expressions. -/// -/// # Parameters -/// -/// - `aggr_exprs`: A slice of `Arc` containing all the -/// aggregate expressions. -/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the -/// physical GROUP BY expression. -/// - `eq_properties`: A reference to an `EquivalenceProperties` instance -/// representing equivalence properties for ordering. -/// - `agg_mode`: A reference to an `AggregateMode` instance representing the -/// mode of aggregation. -/// -/// # Returns -/// -/// A `LexRequirement` instance, which is the requirement that satisfies all the -/// aggregate requirements. Returns an error in case of conflicting requirements. 
-/// -/// Similar to the one in datafusion/physical-plan/src/aggregates/mod.rs, but this -/// function care only the possible conversion between FIRST_VALUE and LAST_VALUE -fn try_convert_first_last_if_better( - prefix_requirement: &[PhysicalSortRequirement], - aggr_exprs: &mut [Arc], - eq_properties: &EquivalenceProperties, -) -> Result<()> { - for aggr_expr in aggr_exprs.iter_mut() { - let aggr_req = aggr_expr.order_bys().unwrap_or(&[]); - let reverse_aggr_req = reverse_order_bys(aggr_req); - let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); - let reverse_aggr_req = - PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); - - if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { - let mut first_value = first_value.clone(); - - if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &aggr_req, - )) { - first_value = first_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(first_value) as _; - } else if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &reverse_aggr_req, - )) { - // Converting to LAST_VALUE enables more efficient execution - // given the existing ordering: - let mut last_value = first_value.convert_to_last(); - last_value = last_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(last_value) as _; - } else { - // Requirement is not satisfied with existing ordering. 
- first_value = first_value.with_requirement_satisfied(false); - *aggr_expr = Arc::new(first_value) as _; - } - continue; - } - if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { - let mut last_value = last_value.clone(); - if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &aggr_req, - )) { - last_value = last_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(last_value) as _; - } else if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &reverse_aggr_req, - )) { - // Converting to FIRST_VALUE enables more efficient execution - // given the existing ordering: - let mut first_value = last_value.convert_to_first(); - first_value = first_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(first_value) as _; - } else { - // Requirement is not satisfied with existing ordering. - last_value = last_value.with_requirement_satisfied(false); - *aggr_expr = Arc::new(last_value) as _; - } - continue; - } - } - - Ok(()) -} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index c80668c6da74..7cc9a0fb75d4 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -24,7 +24,6 @@ pub mod aggregate_statistics; pub mod coalesce_batches; pub mod combine_partial_final_agg; -mod convert_first_last; pub mod enforce_distribution; pub mod enforce_sorting; pub mod join_selection; @@ -37,6 +36,7 @@ pub mod pruning; pub mod replace_with_order_preserving_variants; mod sort_pushdown; pub mod topk_aggregation; +mod update_aggr_exprs; mod utils; #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 416985983dfe..e3b60a0cca80 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -19,8 +19,8 @@ use std::sync::Arc; -use 
super::convert_first_last::OptimizeAggregateOrder; use super::projection_pushdown::ProjectionPushdown; +use super::update_aggr_exprs::OptimizeAggregateOrder; use crate::config::ConfigOptions; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs new file mode 100644 index 000000000000..6a6ca815c510 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An optimizer rule that checks ordering requirements of aggregate expressions +//! and modifies the expressions to work more efficiently if possible. 
+ +use std::sync::Arc; + +use super::PhysicalOptimizerRule; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_physical_expr::{ + reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalSortRequirement, +}; +use datafusion_physical_plan::aggregates::concat_slices; +use datafusion_physical_plan::windows::get_ordered_partition_by_indices; +use datafusion_physical_plan::{ + aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties, +}; + +/// This optimizer rule checks ordering requirements of aggregate expressions. +/// +/// There are 3 kinds of aggregators in terms of ordering requirements: +/// - `AggregateOrderSensitivity::Insensitive`, meaning that ordering is not +/// important. +/// - `AggregateOrderSensitivity::HardRequirement`, meaning that the aggregator +/// requires a specific ordering. +/// - `AggregateOrderSensitivity::Beneficial`, meaning that the aggregator can +/// handle unordered input, but can run more efficiently if its input conforms +/// to a specific ordering. +/// +/// This rule analyzes aggregate expressions of type `Beneficial` to see whether +/// their input ordering requirements are satisfied. If this is the case, the +/// aggregators are modified to run in a more efficient mode. +#[derive(Default)] +pub struct OptimizeAggregateOrder {} + +impl OptimizeAggregateOrder { + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for OptimizeAggregateOrder { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(|plan| { + if let Some(aggr_exec) = plan.as_any().downcast_ref::() { + // Final stage implementations do not rely on ordering -- those + // ordering fields may be pruned out by first stage aggregates. 
+ // Hence, necessary information for proper merge is added during + // the first stage to the state field, which the final stage uses. + if !aggr_exec.mode().is_first_stage() { + return Ok(Transformed::no(plan)); + } + let input = aggr_exec.input(); + let mut aggr_expr = aggr_exec.aggr_expr().to_vec(); + + let groupby_exprs = aggr_exec.group_expr().input_exprs(); + // If the existing ordering satisfies a prefix of the GROUP BY + // expressions, prefix requirements with this section. In this + // case, aggregation will work more efficiently. + let indices = get_ordered_partition_by_indices(&groupby_exprs, input); + let requirement = indices + .iter() + .map(|&idx| { + PhysicalSortRequirement::new(groupby_exprs[idx].clone(), None) + }) + .collect::>(); + + aggr_expr = try_convert_aggregate_if_better( + aggr_expr, + &requirement, + input.equivalence_properties(), + )?; + + let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_expr); + + Ok(Transformed::yes(Arc::new(aggr_exec) as _)) + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "OptimizeAggregateOrder" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Tries to convert each aggregate expression to a potentially more efficient +/// version. +/// +/// # Parameters +/// +/// * `aggr_exprs` - A vector of `Arc` representing the +/// aggregate expressions to be optimized. +/// * `prefix_requirement` - An array slice representing the ordering +/// requirements preceding the aggregate expressions. +/// * `eq_properties` - A reference to the `EquivalenceProperties` object +/// containing ordering information. +/// +/// # Returns +/// +/// Returns `Ok(converted_aggr_exprs)` if the conversion process completes +/// successfully. Any errors occurring during the conversion process are +/// passed through. 
+fn try_convert_aggregate_if_better( + aggr_exprs: Vec>, + prefix_requirement: &[PhysicalSortRequirement], + eq_properties: &EquivalenceProperties, +) -> Result>> { + aggr_exprs + .into_iter() + .map(|aggr_expr| { + let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(&[]); + let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); + let aggr_sort_reqs = + PhysicalSortRequirement::from_sort_exprs(aggr_sort_exprs); + let reverse_aggr_req = + PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_sort_exprs); + + // If the aggregate expression benefits from input ordering, and + // there is an actual ordering enabling this, try to update the + // aggregate expression to benefit from the existing ordering. + // Otherwise, leave it as is. + if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty() + { + let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs); + if eq_properties.ordering_satisfy_requirement(&reqs) { + // Existing ordering satisfies the aggregator requirements: + aggr_expr.with_beneficial_ordering(true)? + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to reverse enables more efficient execution + // given the existing ordering (if possible): + aggr_expr + .reverse_expr() + .unwrap_or(aggr_expr) + .with_beneficial_ordering(true)? + } else { + // There is no beneficial ordering present -- aggregation + // will still work albeit in a less efficient mode. + aggr_expr.with_beneficial_ordering(false)? 
+ } + .ok_or_else(|| { + plan_datafusion_err!( + "Expects an aggregate expression that can benefit from input ordering" + ) + }) + } else { + Ok(aggr_expr) + } + }) + .collect() +} diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d82a5a2cc1a1..0d8d06f49bc3 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -39,6 +39,7 @@ pub use datafusion_expr::{ Expr, }; pub use datafusion_functions::expr_fn::*; +pub use datafusion_functions_aggregate::expr_fn::*; #[cfg(feature = "array_expressions")] pub use datafusion_functions_array::expr_fn::*; diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index f251969ca618..fb5a8db550e3 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -47,10 +47,6 @@ pub enum AggregateFunction { ApproxDistinct, /// Aggregation into an array ArrayAgg, - /// First value in a group according to some ordering - FirstValue, - /// Last value in a group according to some ordering - LastValue, /// N'th value in a group according to some ordering NthValue, /// Variance (Sample) @@ -114,8 +110,6 @@ impl AggregateFunction { Avg => "AVG", ApproxDistinct => "APPROX_DISTINCT", ArrayAgg => "ARRAY_AGG", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", NthValue => "NTH_VALUE", Variance => "VAR", VariancePop => "VAR_POP", @@ -168,8 +162,6 @@ impl FromStr for AggregateFunction { "min" => AggregateFunction::Min, "sum" => AggregateFunction::Sum, "array_agg" => AggregateFunction::ArrayAgg, - "first_value" => AggregateFunction::FirstValue, - "last_value" => AggregateFunction::LastValue, "nth_value" => AggregateFunction::NthValue, "string_agg" => AggregateFunction::StringAgg, // statistical @@ -273,9 +265,7 @@ impl AggregateFunction { } AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), AggregateFunction::Grouping => Ok(DataType::Int32), - AggregateFunction::FirstValue - | 
AggregateFunction::LastValue - | AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), + AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } @@ -329,9 +319,7 @@ impl AggregateFunction { | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop - | AggregateFunction::ApproxMedian - | AggregateFunction::FirstValue - | AggregateFunction::LastValue => { + | AggregateFunction::ApproxMedian => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 5e43c160ba0a..0c05355cde1b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -504,6 +504,15 @@ impl Sort { nulls_first, } } + + /// Create a new Sort expression with the opposite sort direction + pub fn reverse(&self) -> Self { + Self { + expr: self.expr.clone(), + asc: !self.asc, + nulls_first: !self.nulls_first, + } + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index c491a2656470..d0114a472541 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -80,7 +80,7 @@ pub use signature::{ ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index ce4a2a709842..6bd204c53c61 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ 
b/datafusion/expr/src/type_coercion/aggregates.rs @@ -283,9 +283,6 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::FirstValue | AggregateFunction::LastValue => { - Ok(input_types.to_vec()) - } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b620a897bcc9..0274038a36bf 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -22,10 +22,11 @@ use crate::function::{ }; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; +use crate::utils::AggregateOrderSensitivity; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -193,6 +194,33 @@ impl AggregateUDF { self.inner.create_groups_accumulator() } + /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details. + pub fn with_beneficial_ordering( + self, + beneficial_ordering: bool, + ) -> Result> { + self.inner + .with_beneficial_ordering(beneficial_ordering) + .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf })) + } + + /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`] + /// for possible options. + pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { + self.inner.order_sensitivity() + } + + /// Reverses the `AggregateUDF` (i.e. returns the `AggregateUDF` that will + /// generate the same result as this `AggregateUDF` when iterated in reverse + /// order, and `None` if there is no such `AggregateUDF`). 
+ pub fn reverse_udf(&self) -> Option { + match self.inner.reverse_expr() { + ReversedUDAF::NotSupported => None, + ReversedUDAF::Identical => Some(self.clone()), + ReversedUDAF::Reversed(reverse) => Some(Self { inner: reverse }), + } + } + pub fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) } @@ -361,6 +389,39 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { &[] } + /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is + /// satisfied by its input. If this is not the case, UDFs with order + /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce + /// the correct result with possibly more work internally. + /// + /// # Returns + /// + /// Returns `Ok(Some(updated_udf))` if the process completes successfully. + /// If the expression can benefit from existing input ordering, but does + /// not implement the method, returns an error. Order insensitive and hard + /// requirement aggregators return `Ok(None)`. + fn with_beneficial_ordering( + self: Arc, + _beneficial_ordering: bool, + ) -> Result>> { + if self.order_sensitivity().is_beneficial() { + return exec_err!( + "Should implement with satisfied for aggregator :{:?}", + self.name() + ); + } + Ok(None) + } + + /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`] + /// for possible options. + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + // We have hard ordering requirements by default, meaning that order + // sensitive UDFs need their input orderings to satisfy their ordering + // requirements to generate correct results. + AggregateOrderSensitivity::HardRequirement + } + /// Optionally apply per-UDaF simplification / rewrite rules. 
/// /// This can be used to apply function specific simplification rules during diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 0d25a3443f47..e5b7bddab837 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1217,6 +1217,37 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } +/// Represents the sensitivity of an aggregate expression to ordering. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum AggregateOrderSensitivity { + /// Indicates that the aggregate expression is insensitive to ordering. + /// Ordering at the input is not important for the result of the aggregator. + Insensitive, + /// Indicates that the aggregate expression has a hard requirement on ordering. + /// The aggregator can not produce a correct result unless its ordering + /// requirement is satisfied. + HardRequirement, + /// Indicates that ordering is beneficial for the aggregate expression in terms + /// of evaluation efficiency. The aggregator can produce its result efficiently + /// when its required ordering is satisfied; however, it can still produce the + /// correct result (albeit less efficiently) when its required ordering is not met. + Beneficial, +} + +impl AggregateOrderSensitivity { + pub fn is_insensitive(&self) -> bool { + self.eq(&AggregateOrderSensitivity::Insensitive) + } + + pub fn is_beneficial(&self) -> bool { + self.eq(&AggregateOrderSensitivity::Beneficial) + } + + pub fn hard_requires(&self) -> bool { + self.eq(&AggregateOrderSensitivity::HardRequirement) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 5d3d48344014..fd4e21971028 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -17,8 +17,12 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. 
+use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn, SortOptions}; +use arrow::compute::{self, lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; use datafusion_common::{ @@ -26,23 +30,15 @@ use datafusion_common::{ }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_expr::utils::format_state_name; +use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Signature, TypeSignature, + Volatility, }; -use datafusion_physical_expr_common::aggregate::utils::{ - down_cast_any_ref, get_sort_options, ordering_fields, +use datafusion_physical_expr_common::aggregate::utils::get_sort_options; +use datafusion_physical_expr_common::sort_expr::{ + limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, }; -use datafusion_physical_expr_common::aggregate::AggregateExpr; -use datafusion_physical_expr_common::expressions; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr_common::utils::reverse_order_bys; - -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; make_udaf_expr_and_func!( FirstValue, @@ -54,6 +50,7 @@ make_udaf_expr_and_func!( pub struct FirstValue { signature: Signature, aliases: Vec, + requirement_satisfied: bool, } impl Debug for FirstValue { @@ -75,7 +72,7 @@ impl Default for FirstValue { impl FirstValue { pub fn new() -> Self { Self { - aliases: vec![String::from("FIRST_VALUE")], + 
aliases: vec![String::from("FIRST_VALUE"), String::from("first_value")], signature: Signature::one_of( vec![ // TODO: we can introduce more strict signature that only numeric of array types are allowed @@ -84,8 +81,14 @@ impl FirstValue { ], Volatility::Immutable, ), + requirement_satisfied: false, } } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl AggregateUDFImpl for FirstValue { @@ -106,37 +109,19 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let mut all_sort_orders = vec![]; - - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in acc_args.sort_exprs { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::column::col(name, acc_args.schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - } - } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); - } - - let ordering_req = all_sort_orders; + let ordering_req = limited_convert_logical_sort_exprs_to_physical( + acc_args.sort_exprs, + acc_args.schema, + )?; let ordering_dtypes = ordering_req .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + // When requirement is empty, or it is signalled by outside caller that + // the ordering requirement is/will be satisfied. 
+ let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied; FirstValueAccumulator::try_new( acc_args.data_type, @@ -161,6 +146,23 @@ impl AggregateUDFImpl for FirstValue { fn aliases(&self) -> &[String] { &self.aliases } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new( + FirstValue::new().with_requirement_satisfied(beneficial_ordering), + ))) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Beneficial + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(last_value_udaf().inner()) + } } #[derive(Debug)] @@ -338,355 +340,133 @@ impl Accumulator for FirstValueAccumulator { } } -/// TO BE DEPRECATED: Builtin FIRST_VALUE physical aggregate expression will be replaced by udf in the future -#[derive(Debug, Clone)] -pub struct FirstValuePhysicalExpr { - name: String, - input_data_type: DataType, - order_by_data_types: Vec, - expr: Arc, - ordering_req: LexOrdering, +make_udaf_expr_and_func!( + LastValue, + last_value, + "Returns the last value in a group of values.", + last_value_udaf +); + +pub struct LastValue { + signature: Signature, + aliases: Vec, requirement_satisfied: bool, - ignore_nulls: bool, - state_fields: Vec, } -impl FirstValuePhysicalExpr { - /// Creates a new FIRST_VALUE aggregation function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ordering_req: LexOrdering, - order_by_data_types: Vec, - state_fields: Vec, - ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); - Self { - name: name.into(), - input_data_type, - order_by_data_types, - expr, - ordering_req, - requirement_satisfied, - ignore_nulls: false, - state_fields, - } - } - - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; - self - } - - /// Returns the name of the aggregate expression. 
- pub fn name(&self) -> &str { - &self.name - } - - /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type - } - - /// Returns the data types of the order-by columns. - pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types +impl Debug for LastValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("LastValue") + .field("name", &self.name()) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() } +} - /// Returns the expression associated with the aggregate function. - pub fn expr(&self) -> &Arc { - &self.expr +impl Default for LastValue { + fn default() -> Self { + Self::new() } +} - /// Returns the lexical ordering requirements of the aggregate expression. - pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req +impl LastValue { + pub fn new() -> Self { + Self { + aliases: vec![String::from("LAST_VALUE"), String::from("last_value")], + signature: Signature::one_of( + vec![ + // TODO: we can introduce more strict signature that only numeric of array types are allowed + TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::Uniform(1, NUMERICS.to_vec()), + ], + Volatility::Immutable, + ), + requirement_satisfied: false, + } } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { self.requirement_satisfied = requirement_satisfied; self } - - pub fn convert_to_last(self) -> LastValuePhysicalExpr { - let mut name = format!("LAST{}", &self.name[5..]); - replace_order_by_clause(&mut name); - - let FirstValuePhysicalExpr { - expr, - input_data_type, - ordering_req, - order_by_data_types, - .. 
- } = self; - LastValuePhysicalExpr::new( - expr, - name, - input_data_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - ) - } } -impl AggregateExpr for FirstValuePhysicalExpr { - /// Return a reference to Any that can be used for downcasting +impl AggregateUDFImpl for LastValue { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } - - fn state_fields(&self) -> Result> { - if !self.state_fields.is_empty() { - return Ok(self.state_fields.clone()); - } - - let mut fields = vec![Field::new( - format_state_name(&self.name, "first_value"), - self.input_data_type.clone(), - true, - )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); - Ok(fields) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone().convert_to_last())) - } - - fn create_sliding_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } -} - -impl PartialEq for FirstValuePhysicalExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && 
self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// TO BE DEPRECATED: Builtin LAST_VALUE physical aggregate expression will be replaced by udf in the future -#[derive(Debug, Clone)] -pub struct LastValuePhysicalExpr { - name: String, - input_data_type: DataType, - order_by_data_types: Vec, - expr: Arc, - ordering_req: LexOrdering, - requirement_satisfied: bool, - ignore_nulls: bool, -} - -impl LastValuePhysicalExpr { - /// Creates a new LAST_VALUE aggregation function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ordering_req: LexOrdering, - order_by_data_types: Vec, - ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); - Self { - name: name.into(), - input_data_type, - order_by_data_types, - expr, - ordering_req, - requirement_satisfied, - ignore_nulls: false, - } - } - - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; - self - } - - /// Returns the name of the aggregate expression. - pub fn name(&self) -> &str { - &self.name + "LAST_VALUE" } - /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type - } - - /// Returns the data types of the order-by columns. - pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types - } - - /// Returns the expression associated with the aggregate function. - pub fn expr(&self) -> &Arc { - &self.expr - } - - /// Returns the lexical ordering requirements of the aggregate expression. 
- pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req + fn signature(&self) -> &Signature { + &self.signature } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } - pub fn convert_to_first(self) -> FirstValuePhysicalExpr { - let mut name = format!("FIRST{}", &self.name[4..]); - replace_order_by_clause(&mut name); - - let LastValuePhysicalExpr { - expr, - input_data_type, - ordering_req, - order_by_data_types, - .. - } = self; - FirstValuePhysicalExpr::new( - expr, - name, - input_data_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - vec![], - ) - } -} + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let ordering_req = limited_convert_logical_sort_exprs_to_physical( + acc_args.sort_exprs, + acc_args.schema, + )?; -impl AggregateExpr for LastValuePhysicalExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) - } + let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied; - fn create_accumulator(&self) -> Result> { LastValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, + acc_args.data_type, + &ordering_dtypes, + ordering_req, + acc_args.ignore_nulls, ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let StateFieldsArgs { + name, + input_type, + 
return_type: _, + ordering_fields, + is_distinct: _, + } = args; let mut fields = vec![Field::new( - format_state_name(&self.name, "last_value"), - self.input_data_type.clone(), + format_state_name(name, "last_value"), + input_type.clone(), true, )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); + fields.extend(ordering_fields.to_vec()); + fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn name(&self) -> &str { - &self.name + fn aliases(&self) -> &[String] { + &self.aliases } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone().convert_to_first())) + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new( + LastValue::new().with_requirement_satisfied(beneficial_ordering), + ))) } - fn create_sliding_accumulator(&self) -> Result> { - LastValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Beneficial } -} -impl PartialEq for LastValuePhysicalExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(first_value_udaf().inner()) } } @@ -896,31 +676,6 @@ fn 
convert_to_sort_cols( .collect::>() } -fn replace_order_by_clause(order_by: &mut String) { - let suffixes = [ - (" DESC NULLS FIRST]", " ASC NULLS LAST]"), - (" ASC NULLS FIRST]", " DESC NULLS LAST]"), - (" DESC NULLS LAST]", " ASC NULLS FIRST]"), - (" ASC NULLS LAST]", " DESC NULLS FIRST]"), - ]; - - if let Some(start) = order_by.find("ORDER BY [") { - if let Some(end) = order_by[start..].find(']') { - let order_by_start = start + 9; - let order_by_end = start + end; - - let column_order = &order_by[order_by_start..=order_by_end]; - for &(suffix, replacement) in &suffixes { - if column_order.ends_with(suffix) { - let new_order = column_order.replace(suffix, replacement); - order_by.replace_range(order_by_start..=order_by_end, &new_order); - break; - } - } - } - } -} - #[cfg(test)] mod tests { use arrow::array::Int64Array; diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 3e80174eec33..ac40a90aaec6 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -76,6 +76,7 @@ pub mod expr_fn { pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ first_last::first_value_udaf(), + first_last::last_value_udaf(), covariance::covar_samp_udaf(), covariance::covar_pop_udaf(), median::median_udaf(), diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 67d5c9b23b74..59c0b476c7cf 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs 
b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 404f054cb9fa..c232935f9e23 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,11 +23,9 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{internal_err, Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{ - aggregate_function::AggregateFunction as AggregateFunctionFunc, col, - expr::AggregateFunction, LogicalPlanBuilder, -}; +use datafusion_expr::{col, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; +use datafusion_functions_aggregate::first_last::first_value; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -99,17 +97,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { // Construct the aggregation expression to be used to fetch the selected expressions. 
let aggr_expr = select_expr .into_iter() - .map(|e| { - Expr::AggregateFunction(AggregateFunction::new( - AggregateFunctionFunc::FirstValue, - vec![e], - false, - None, - sort_expr.clone(), - None, - )) - }) - .collect::>(); + .map(|e| first_value(vec![e], false, None, sort_expr.clone(), None)); let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; let group_expr = normalize_cols(on_expr, input.as_ref())?; diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 4e9414bc5a11..503e2d8f9758 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -19,20 +19,22 @@ pub mod groups_accumulator; pub mod stats; pub mod utils; -use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::function::StateFieldsArgs; -use datafusion_expr::type_coercion::aggregates::check_arg_count; -use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, -}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; +use self::utils::{down_cast_any_ref, ordering_fields}; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; +use crate::utils::reverse_order_bys; -use self::utils::{down_cast_any_ref, ordering_fields}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::type_coercion::aggregates::check_arg_count; +use datafusion_expr::utils::AggregateOrderSensitivity; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, +}; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. 
@@ -47,6 +49,7 @@ pub fn create_aggregate_expr( ignore_nulls: bool, is_distinct: bool, ) -> Result> { + debug_assert_eq!(sort_exprs.len(), ordering_req.len()); let input_exprs_types = input_phy_exprs .iter() .map(|arg| arg.data_type(schema)) @@ -117,6 +120,37 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { None } + /// Indicates whether aggregator can produce the correct result with any + /// arbitrary input ordering. By default, we assume that aggregate expressions + /// are order insensitive. + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + /// Sets the indicator whether ordering requirements of the aggregator is + /// satisfied by its input. If this is not the case, aggregators with order + /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce + /// the correct result with possibly more work internally. + /// + /// # Returns + /// + /// Returns `Ok(Some(updated_expr))` if the process completes successfully. + /// If the expression can benefit from existing input ordering, but does + /// not implement the method, returns an error. Order insensitive and hard + /// requirement aggregators return `Ok(None)`. + fn with_beneficial_ordering( + self: Arc, + _requirement_satisfied: bool, + ) -> Result>> { + if self.order_bys().is_some() && self.order_sensitivity().is_beneficial() { + return exec_err!( + "Should implement with satisfied for aggregator :{:?}", + self.name() + ); + } + Ok(None) + } + /// Human readable name such as `"MIN(c2)"`. The default /// implementation returns placeholder text. 
fn name(&self) -> &str { @@ -305,6 +339,74 @@ impl AggregateExpr for AggregateFunctionExpr { fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + if !self.ordering_req.is_empty() { + // If there is requirement, use the sensitivity of the implementation + self.fun.order_sensitivity() + } else { + // If no requirement, aggregator is order insensitive + AggregateOrderSensitivity::Insensitive + } + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + let Some(updated_fn) = self + .fun + .clone() + .with_beneficial_ordering(beneficial_ordering)? + else { + return Ok(None); + }; + create_aggregate_expr( + &updated_fn, + &self.args, + &self.sort_exprs, + &self.ordering_req, + &self.schema, + self.name(), + self.ignore_nulls, + self.is_distinct, + ) + .map(Some) + } + + fn reverse_expr(&self) -> Option> { + if let Some(reverse_udf) = self.fun.reverse_udf() { + let reverse_ordering_req = reverse_order_bys(&self.ordering_req); + let reverse_sort_exprs = self + .sort_exprs + .iter() + .map(|e| { + if let Expr::Sort(s) = e { + Expr::Sort(s.reverse()) + } else { + // Expects to receive `Expr::Sort`. 
+ unreachable!() + } + }) + .collect::>(); + let mut name = self.name().to_string(); + replace_order_by_clause(&mut name); + replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); + let reverse_aggr = create_aggregate_expr( + &reverse_udf, + &self.args, + &reverse_sort_exprs, + &reverse_ordering_req, + &self.schema, + name, + self.ignore_nulls, + self.is_distinct, + ) + .unwrap(); + return Some(reverse_aggr); + } + None + } } impl PartialEq for AggregateFunctionExpr { @@ -325,3 +427,32 @@ impl PartialEq for AggregateFunctionExpr { .unwrap_or(false) } } + +fn replace_order_by_clause(order_by: &mut String) { + let suffixes = [ + (" DESC NULLS FIRST]", " ASC NULLS LAST]"), + (" ASC NULLS FIRST]", " DESC NULLS LAST]"), + (" DESC NULLS LAST]", " ASC NULLS FIRST]"), + (" ASC NULLS LAST]", " DESC NULLS FIRST]"), + ]; + + if let Some(start) = order_by.find("ORDER BY [") { + if let Some(end) = order_by[start..].find(']') { + let order_by_start = start + 9; + let order_by_end = start + end; + + let column_order = &order_by[order_by_start..=order_by_end]; + for (suffix, replacement) in suffixes { + if column_order.ends_with(suffix) { + let new_order = column_order.replace(suffix, replacement); + order_by.replace_range(order_by_start..=order_by_end, &new_order); + break; + } + } + } + } +} + +fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) { + *aggr_name = aggr_name.replace(fn_name_old, fn_name_new); +} diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr-common/src/expressions/cast.rs similarity index 99% rename from datafusion/physical-expr/src/expressions/cast.rs rename to datafusion/physical-expr-common/src/expressions/cast.rs index 4f940a792bb9..8ef3d16f6334 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr-common/src/expressions/cast.rs @@ -20,8 +20,7 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use 
crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; +use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; @@ -229,7 +228,8 @@ pub fn cast( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; + + use crate::expressions::column::col; use arrow::{ array::{ diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index d102422081dc..4b5965e164b5 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -15,4 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod cast; pub mod column; + +pub use cast::{cast, cast_with_options, CastExpr}; diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 1e1187212d96..f637355519af 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -21,13 +21,14 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::physical_expr::PhysicalExpr; +use crate::utils::limited_convert_logical_expr_to_physical_expr; + use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; -use datafusion_expr::ColumnarValue; - -use crate::physical_expr::PhysicalExpr; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Expr}; /// Represents Sort operation for a column in a RecordBatch #[derive(Clone, Debug)] @@ -267,3 +268,29 @@ pub type LexRequirement = Vec; ///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which /// represents a reference to a lexicographical ordering requirement. 
pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; + +/// Converts each [`Expr::Sort`] into a corresponding [`PhysicalSortExpr`]. +/// Returns an error if the given logical expression is not a [`Expr::Sort`]. +pub fn limited_convert_logical_sort_exprs_to_physical( + exprs: &[Expr], + schema: &Schema, +) -> Result> { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + let Expr::Sort(sort) = expr else { + return exec_err!("Expects to receive sort expression"); + }; + sort_exprs.push(PhysicalSortExpr { + expr: limited_convert_logical_expr_to_physical_expr( + sort.expr.as_ref(), + schema, + )?, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + Ok(sort_exprs) +} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 487aba945aa5..f661400fcb10 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,14 +17,17 @@ use std::sync::Arc; -use crate::{ - physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr, tree_node::ExprContext, -}; +use crate::expressions::{self, CastExpr}; +use crate::physical_expr::PhysicalExpr; +use crate::sort_expr::PhysicalSortExpr; +use crate::tree_node::ExprContext; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; -use datafusion_common::Result; +use arrow::datatypes::Schema; +use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::Expr; /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. @@ -105,15 +108,41 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec`. +/// If conversion is not supported yet, returns Error. 
+pub fn limited_convert_logical_expr_to_physical_expr( + expr: &Expr, + schema: &Schema, +) -> Result> { + match expr { + Expr::Column(col) => expressions::column::col(&col.name, schema), + Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( + limited_convert_logical_expr_to_physical_expr( + cast_expr.expr.as_ref(), + schema, + )?, + cast_expr.data_type.clone(), + None, + ))), + Expr::Alias(alias_expr) => limited_convert_logical_expr_to_physical_expr( + alias_expr.expr.as_ref(), + schema, + ), + _ => exec_err!( + "Unsupported expression: {expr} for conversion to Arc" + ), + } +} + #[cfg(test)] mod tests { use std::sync::Arc; + use super::*; + use arrow::array::Int32Array; use datafusion_common::cast::{as_boolean_array, as_int32_array}; - use super::*; - #[test] fn scatter_int() -> Result<()> { let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 7e2c7bb27144..837a9d551153 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -30,15 +30,13 @@ use crate::{ reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, }; -use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; -use arrow_array::{new_empty_array, StructArray}; +use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; use arrow_schema::{Fields, SortOptions}; - -use datafusion_common::utils::array_into_list_array; -use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::utils::{array_into_list_array, compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; /// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. 
In a multi @@ -131,6 +129,10 @@ impl AggregateExpr for OrderSensitiveArrayAgg { (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::HardRequirement + } + fn name(&self) -> &str { &self.name } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 18252ea370eb..e10008995463 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -28,15 +28,14 @@ use std::sync::Arc; -use arrow::datatypes::Schema; - -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; - use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; +use arrow::datatypes::Schema; +use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_expr::AggregateFunction; + /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
pub fn create_aggregate_expr( @@ -46,7 +45,7 @@ pub fn create_aggregate_expr( ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, - ignore_nulls: bool, + _ignore_nulls: bool, ) -> Result> { let name = name.into(); // get the result data type for this aggregate function @@ -332,27 +331,6 @@ pub fn create_aggregate_expr( "APPROX_MEDIAN(DISTINCT) aggregations are not available" ); } - (AggregateFunction::FirstValue, _) => Arc::new( - expressions::FirstValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - vec![], - ) - .with_ignore_nulls(ignore_nulls), - ), - (AggregateFunction::LastValue, _) => Arc::new( - expressions::LastValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - ) - .with_ignore_nulls(ignore_nulls), - ), (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] @@ -396,17 +374,16 @@ pub fn create_aggregate_expr( mod tests { use arrow::datatypes::{DataType, Field}; - use datafusion_common::{plan_err, DataFusionError, ScalarValue}; - use datafusion_expr::type_coercion::aggregates::NUMERICS; - use datafusion_expr::{type_coercion, Signature}; - + use super::*; use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; - use super::*; + use datafusion_common::{plan_err, DataFusionError, ScalarValue}; + use datafusion_expr::type_coercion::aggregates::NUMERICS; + use datafusion_expr::{type_coercion, Signature}; #[test] fn test_count_arragg_approx_expr() -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 039c8814e987..d8220db4d90d 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ 
b/datafusion/physical-expr/src/aggregate/mod.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use crate::expressions::{NthValueAgg, OrderSensitiveArrayAgg}; - pub use datafusion_physical_expr_common::aggregate::AggregateExpr; mod hyperloglog; @@ -59,11 +55,3 @@ pub mod utils { get_sort_options, ordering_fields, DecimalAverager, Hashable, }; } - -/// Checks whether the given aggregate expression is order-sensitive. -/// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. -/// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. -pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { - aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() -} diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs index dba259a507fd..ee7426a897b3 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -34,6 +34,7 @@ use arrow_array::{new_empty_array, ArrayRef, StructArray}; use arrow_schema::{DataType, Field, Fields}; use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. 
In a multi @@ -125,6 +126,10 @@ impl AggregateExpr for NthValueAgg { (!self.ordering_req.is_empty()).then_some(&self.ordering_req) } + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::HardRequirement + } + fn name(&self) -> &str { &self.name } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 016c4c4ae107..7bf389ecfdf3 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -22,7 +22,7 @@ use super::ordering::collapse_lex_ordering; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::{CastExpr, Literal}; +use crate::expressions::Literal; use crate::{ physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, @@ -35,6 +35,7 @@ use datafusion_common::{JoinSide, JoinType, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::expressions::CastExpr; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::{IndexMap, IndexSet}; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index a7921800fccd..1e9644f75afe 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -20,7 +20,6 @@ #[macro_use] mod binary; mod case; -mod cast; mod column; mod datum; mod in_list; @@ -53,8 +52,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::grouping::Grouping; -pub use 
crate::aggregate::min_max::{Max, Min}; -pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; +pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; @@ -63,26 +61,17 @@ pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; -pub use crate::window::cume_dist::cume_dist; -pub use crate::window::cume_dist::CumeDist; -pub use crate::window::lead_lag::WindowShift; -pub use crate::window::lead_lag::{lag, lead}; +pub use crate::window::cume_dist::{cume_dist, CumeDist}; +pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; pub use crate::window::ntile::Ntile; -pub use crate::window::rank::{dense_rank, percent_rank, rank}; -pub use crate::window::rank::{Rank, RankType}; +pub use crate::window::rank::{dense_rank, percent_rank, rank, Rank, RankType}; pub use crate::window::row_number::RowNumber; pub use crate::PhysicalSortExpr; -pub use datafusion_functions_aggregate::first_last::{ - FirstValuePhysicalExpr as FirstValue, LastValuePhysicalExpr as LastValue, -}; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; -pub use cast::{cast, cast_with_options, CastExpr}; pub use column::UnKnownColumn; -pub use datafusion_expr::utils::format_state_name; -pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; @@ -93,11 +82,17 @@ pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; +pub use datafusion_expr::utils::format_state_name; +pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; 
+pub use datafusion_physical_expr_common::expressions::column::{col, Column}; +pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; + #[cfg(test)] pub(crate) mod tests { use std::sync::Arc; use crate::AggregateExpr; + use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b0e2af82e6e2..2bb95852ff43 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -39,15 +39,11 @@ use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; -use datafusion_physical_expr::aggregate::is_order_sensitive; -use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ - equivalence::ProjectionMapping, + equivalence::{collapse_lex_req, ProjectionMapping}, expressions::{Column, Max, Min, UnKnownColumn}, - AggregateExpr, LexRequirement, PhysicalExpr, -}; -use datafusion_physical_expr::{ - physical_exprs_contains, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, + physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, + LexRequirement, PhysicalExpr, PhysicalSortRequirement, }; use itertools::Itertools; @@ -274,20 +270,15 @@ pub struct AggregateExec { impl AggregateExec { /// Function used in `ConvertFirstLast` optimizer rule, /// where we need parts of the new value, others cloned from the old one - pub fn new_with_aggr_expr_and_ordering_info( - &self, - required_input_ordering: Option, - aggr_expr: Vec>, - cache: PlanProperties, - input_order_mode: InputOrderMode, - ) -> Self { + /// Rewrites aggregate exec with new aggregate expressions. 
+ pub fn with_new_aggr_exprs(&self, aggr_expr: Vec>) -> Self { Self { aggr_expr, - required_input_ordering, - metrics: ExecutionPlanMetricsSet::new(), - input_order_mode, - cache, // clone the rest of the fields + required_input_ordering: self.required_input_ordering.clone(), + metrics: ExecutionPlanMetricsSet::new(), + input_order_mode: self.input_order_mode.clone(), + cache: self.cache.clone(), mode: self.mode, group_by: self.group_by.clone(), filter_expr: self.filter_expr.clone(), @@ -844,11 +835,10 @@ fn get_aggregate_expr_req( group_by: &PhysicalGroupBy, agg_mode: &AggregateMode, ) -> LexOrdering { - // If the aggregation function is not order sensitive, or the aggregation - // is performing a "second stage" calculation, or all aggregate function - // requirements are inside the GROUP BY expression, then ignore the ordering - // requirement. - if !is_order_sensitive(aggr_expr) || !agg_mode.is_first_stage() { + // If the aggregation function's ordering requirement is not absolutely + // necessary, or the aggregation is performing a "second stage" calculation, + // then ignore the ordering requirement. 
+ if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() { return vec![]; } @@ -1203,11 +1193,12 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_expr::expr::Sort; use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ lit, ApproxDistinct, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, }; - use datafusion_physical_expr::{reverse_order_bys, PhysicalSortExpr}; + use datafusion_physical_expr::PhysicalSortExpr; use futures::{FutureExt, Stream}; @@ -1958,6 +1949,66 @@ mod tests { Ok(()) } + // FIRST_VALUE(b ORDER BY b ) + fn test_first_value_agg_expr( + schema: &Schema, + sort_options: SortOptions, + ) -> Result> { + let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { + expr: Box::new(datafusion_expr::Expr::Column( + datafusion_common::Column::new(Some("table1"), "b"), + )), + asc: !sort_options.descending, + nulls_first: sort_options.nulls_first, + })]; + let ordering_req = vec![PhysicalSortExpr { + expr: col("b", schema)?, + options: sort_options, + }]; + let args = vec![col("b", schema)?]; + let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); + datafusion_physical_expr_common::aggregate::create_aggregate_expr( + &func, + &args, + &sort_exprs, + &ordering_req, + schema, + "FIRST_VALUE(b)", + false, + false, + ) + } + + // LAST_VALUE(b ORDER BY b ) + fn test_last_value_agg_expr( + schema: &Schema, + sort_options: SortOptions, + ) -> Result> { + let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { + expr: Box::new(datafusion_expr::Expr::Column( + datafusion_common::Column::new(Some("table1"), "b"), + )), + asc: !sort_options.descending, + nulls_first: sort_options.nulls_first, + })]; + let ordering_req = vec![PhysicalSortExpr { + expr: col("b", schema)?, + options: sort_options, + }]; + let args = 
vec![col("b", schema)?]; + let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); + datafusion_physical_expr_common::aggregate::create_aggregate_expr( + &func, + &args, + &sort_exprs, + &ordering_req, + schema, + "LAST_VALUE(b)", + false, + false, + ) + } + // This function either constructs the physical plan below, // // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", @@ -1995,27 +2046,14 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let ordering_req = vec![PhysicalSortExpr { - expr: col("b", &schema)?, - options: SortOptions::default(), - }]; + let sort_options = SortOptions { + descending: false, + nulls_first: false, + }; let aggregates: Vec> = if is_first_acc { - vec![Arc::new(FirstValue::new( - col("b", &schema)?, - "FIRST_VALUE(b)".to_string(), - DataType::Float64, - ordering_req.clone(), - vec![DataType::Float64], - vec![], - ))] + vec![test_first_value_agg_expr(&schema, sort_options)?] } else { - vec![Arc::new(LastValue::new( - col("b", &schema)?, - "LAST_VALUE(b)".to_string(), - DataType::Float64, - ordering_req.clone(), - vec![DataType::Float64], - ))] + vec![test_last_value_agg_expr(&schema, sort_options)?] 
}; let memory_exec = Arc::new(MemoryExec::try_new( @@ -2170,34 +2208,15 @@ mod tests { ])); let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; let option_desc = SortOptions { descending: true, nulls_first: true, }; - let sort_expr = vec![PhysicalSortExpr { - expr: col_b.clone(), - options: option_desc, - }]; - let sort_expr_reverse = reverse_order_bys(&sort_expr); let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); let aggregates: Vec> = vec![ - Arc::new(FirstValue::new( - col_b.clone(), - "FIRST_VALUE(b)".to_string(), - DataType::Float64, - sort_expr_reverse.clone(), - vec![DataType::Float64], - vec![], - )), - Arc::new(LastValue::new( - col_b.clone(), - "LAST_VALUE(b)".to_string(), - DataType::Float64, - sort_expr.clone(), - vec![DataType::Float64], - )), + test_first_value_agg_expr(&schema, option_desc)?, + test_last_value_agg_expr(&schema, option_desc)?, ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let aggregate_exec = Arc::new(AggregateExec::try_new( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 434ec9f81f15..fecfa2bc33ae 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -557,10 +557,6 @@ enum AggregateFunction { BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; - // When a function with the same name exists among built-in window functions, - // we append "_AGG" to obey name scoping rules. 
- FIRST_VALUE_AGG = 24; - LAST_VALUE_AGG = 25; REGR_SLOPE = 26; REGR_INTERCEPT = 27; REGR_COUNT = 28; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 86a5975c8bb8..91bf3170e51f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -442,8 +442,6 @@ impl serde::Serialize for AggregateFunction { Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::FirstValueAgg => "FIRST_VALUE_AGG", - Self::LastValueAgg => "LAST_VALUE_AGG", Self::RegrSlope => "REGR_SLOPE", Self::RegrIntercept => "REGR_INTERCEPT", Self::RegrCount => "REGR_COUNT", @@ -487,8 +485,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR", "BOOL_AND", "BOOL_OR", - "FIRST_VALUE_AGG", - "LAST_VALUE_AGG", "REGR_SLOPE", "REGR_INTERCEPT", "REGR_COUNT", @@ -561,8 +557,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg), - "LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg), "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), "REGR_COUNT" => Ok(AggregateFunction::RegrCount), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cb2de710075a..979ce692450e 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2854,10 +2854,6 @@ pub enum AggregateFunction { BitXor = 21, BoolAnd = 22, BoolOr = 23, - /// When a function with the same name exists among built-in window functions, - /// we append "_AGG" to obey name scoping rules. 
- FirstValueAgg = 24, - LastValueAgg = 25, RegrSlope = 26, RegrIntercept = 27, RegrCount = 28, @@ -2900,8 +2896,6 @@ impl AggregateFunction { AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG", - AggregateFunction::LastValueAgg => "LAST_VALUE_AGG", AggregateFunction::RegrSlope => "REGR_SLOPE", AggregateFunction::RegrIntercept => "REGR_INTERCEPT", AggregateFunction::RegrCount => "REGR_COUNT", @@ -2941,8 +2935,6 @@ impl AggregateFunction { "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), - "FIRST_VALUE_AGG" => Some(Self::FirstValueAgg), - "LAST_VALUE_AGG" => Some(Self::LastValueAgg), "REGR_SLOPE" => Some(Self::RegrSlope), "REGR_INTERCEPT" => Some(Self::RegrIntercept), "REGR_COUNT" => Some(Self::RegrCount), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 00c62fc32b98..eaba9c0c12ce 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -450,8 +450,6 @@ impl From for AggregateFunction { } protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, - protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, - protobuf::AggregateFunction::LastValueAgg => Self::LastValue, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::StringAgg => Self::StringAgg, } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index f2ee679ac129..16ba166d9f47 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -386,8 +386,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { } AggregateFunction::ApproxMedian => Self::ApproxMedian, 
AggregateFunction::Grouping => Self::Grouping, - AggregateFunction::FirstValue => Self::FirstValueAgg, - AggregateFunction::LastValue => Self::LastValueAgg, AggregateFunction::NthValue => Self::NthValueAgg, AggregateFunction::StringAgg => Self::StringAgg, } @@ -696,12 +694,6 @@ pub fn serialize_expr( protobuf::AggregateFunction::ApproxMedian } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index d3badee3efff..c0da4cc0cdd4 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -26,11 +26,10 @@ use datafusion::physical_plan::expressions::{ ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, - DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, IsNotNullExpr, - IsNullExpr, LastValue, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, - NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, - RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance, VariancePop, - WindowShift, + DistinctCount, DistinctSum, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, + Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, + StringAgg, Sum, TryCastExpr, Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, 
PlainAggregateWindowExpr}; @@ -318,10 +317,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::ApproxPercentileContWithWeight } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ApproxMedian - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::FirstValueAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::LastValueAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d83d6cd1c297..4e2534227ef9 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -32,8 +32,6 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; -use datafusion::functions_aggregate::expr_fn::first_value; -use datafusion::functions_aggregate::median::median; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions};