Skip to content

Commit 4faae4e

Browse files
jcsherinalamb
authored andcommitted
Convert nth_value to UDAF (apache#11287)
* Copies `NthValueAccumulator` to `functions-aggregate` * Partial implementation of `AggregateUDFImpl` Pending methods are: - `accumulator` - `state_fields` - `reverse_expr` * Implements `accumulator` method * Retains existing comments verbatim * Removes unnecessary path prefix * Implements `reverse_expr` method * Adds `nullable` field to `NthValue` * Revert to existing name * Implements `state_fields` method * Removes `nth_value` from `physical-expr` * Adds default * Exports `nth_value` * Fixes build error in physical plan roundtrip test * Minor: formatting * Parses `N` from input expression * Fixes build error by using `nth_value_udaf` * Fixes `reverse_expr` by passing correct `N` * Update plan with lowercase UDF name * Updates error message for incorrect no. of arguments This error message is manually formatted to remain consistent with existing error statements. It is not formatted by running: ``` cargo test -p datafusion-sqllogictest --test sqllogictests errors -- --complete ``` * Fixes nullable "item" in `state_fields` * Minor: fix formatting after resolving conflicts * Updates multiple existing plans with lowercase name * Implements `retract_batch` for window aggregations * Fixes: regex mismatch for error message in CI * Revert "Updates multiple existing plans with lowercase name" This reverts commit 1913efda49e585816286b54b371d4166ac894d1f. * Revert "Implements `retract_batch` for window aggregations" This reverts commit 4bb204f6ec8028c4e3313db5af3fabfcdaf7fea8. * Fixes: use builtin window function instead of udaf * Revert "Updates error message for incorrect no. of arguments" This reverts commit fa61ce62dcae6eae6f8e9c9900ebf8cff5023bc0. * Refactor: renames field and method * Removes hack for nullability * Minor: refactors `reverse_expr` * Minor: removes unncessary path prefix * Minor: cleanup arguments for creating aggregate expr * Refactor: extracts `merge_ordered_arrays` to `physical-expr-common` * Minor: adds todo for configuring nullability * Retrigger CI --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 37da1e4 commit 4faae4e

File tree

22 files changed

+350
-357
lines changed

22 files changed

+350
-357
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/expr/src/aggregate_function.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ pub enum AggregateFunction {
3939
Max,
4040
/// Aggregation into an array
4141
ArrayAgg,
42-
/// N'th value in a group according to some ordering
43-
NthValue,
4442
}
4543

4644
impl AggregateFunction {
@@ -50,7 +48,6 @@ impl AggregateFunction {
5048
Min => "MIN",
5149
Max => "MAX",
5250
ArrayAgg => "ARRAY_AGG",
53-
NthValue => "NTH_VALUE",
5451
}
5552
}
5653
}
@@ -69,7 +66,6 @@ impl FromStr for AggregateFunction {
6966
"max" => AggregateFunction::Max,
7067
"min" => AggregateFunction::Min,
7168
"array_agg" => AggregateFunction::ArrayAgg,
72-
"nth_value" => AggregateFunction::NthValue,
7369
_ => {
7470
return plan_err!("There is no built-in function named {name}");
7571
}
@@ -114,7 +110,6 @@ impl AggregateFunction {
114110
coerced_data_types[0].clone(),
115111
input_expr_nullable[0],
116112
)))),
117-
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
118113
}
119114
}
120115

@@ -124,7 +119,6 @@ impl AggregateFunction {
124119
match self {
125120
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
126121
AggregateFunction::ArrayAgg => Ok(false),
127-
AggregateFunction::NthValue => Ok(true),
128122
}
129123
}
130124
}
@@ -147,7 +141,6 @@ impl AggregateFunction {
147141
.collect::<Vec<_>>();
148142
Signature::uniform(1, valid, Volatility::Immutable)
149143
}
150-
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
151144
}
152145
}
153146
}

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ pub fn coerce_types(
101101
// unpack the dictionary to get the value
102102
get_min_max_result_type(input_types)
103103
}
104-
AggregateFunction::NthValue => Ok(input_types.to_vec()),
105104
}
106105
}
107106

datafusion/functions-aggregate/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ pub mod average;
7474
pub mod bit_and_or_xor;
7575
pub mod bool_and_or;
7676
pub mod grouping;
77+
pub mod nth_value;
7778
pub mod string_agg;
7879

7980
use crate::approx_percentile_cont::approx_percentile_cont_udaf;
@@ -105,6 +106,7 @@ pub mod expr_fn {
105106
pub use super::first_last::last_value;
106107
pub use super::grouping::grouping;
107108
pub use super::median::median;
109+
pub use super::nth_value::nth_value;
108110
pub use super::regr::regr_avgx;
109111
pub use super::regr::regr_avgy;
110112
pub use super::regr::regr_count;
@@ -157,6 +159,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
157159
bool_and_or::bool_or_udaf(),
158160
average::avg_udaf(),
159161
grouping::grouping_udaf(),
162+
nth_value::nth_value_udaf(),
160163
]
161164
}
162165

datafusion/physical-expr/src/aggregate/nth_value.rs renamed to datafusion/functions-aggregate/src/nth_value.rs

Lines changed: 103 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -22,149 +22,149 @@ use std::any::Any;
2222
use std::collections::VecDeque;
2323
use std::sync::Arc;
2424

25-
use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
26-
use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
27-
use crate::expressions::{format_state_name, Literal};
28-
use crate::{
29-
reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr,
30-
};
31-
32-
use arrow_array::cast::AsArray;
33-
use arrow_array::{new_empty_array, ArrayRef, StructArray};
25+
use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
3426
use arrow_schema::{DataType, Field, Fields};
27+
3528
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
36-
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
37-
use datafusion_expr::utils::AggregateOrderSensitivity;
38-
use datafusion_expr::Accumulator;
29+
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
30+
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
31+
use datafusion_expr::utils::format_state_name;
32+
use datafusion_expr::{
33+
Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature,
34+
Volatility,
35+
};
36+
use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays;
37+
use datafusion_physical_expr_common::aggregate::utils::ordering_fields;
38+
use datafusion_physical_expr_common::sort_expr::{
39+
limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr,
40+
};
41+
42+
make_udaf_expr_and_func!(
43+
NthValueAgg,
44+
nth_value,
45+
"Returns the nth value in a group of values.",
46+
nth_value_udaf
47+
);
3948

4049
/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
4150
/// partition setting, partial aggregations are computed for every partition,
4251
/// and then their results are merged.
4352
#[derive(Debug)]
4453
pub struct NthValueAgg {
45-
/// Column name
46-
name: String,
47-
/// The `DataType` for the input expression
48-
input_data_type: DataType,
49-
/// The input expression
50-
expr: Arc<dyn PhysicalExpr>,
51-
/// The `N` value.
52-
n: i64,
53-
/// If the input expression can have `NULL`s
54-
nullable: bool,
55-
/// Ordering data types
56-
order_by_data_types: Vec<DataType>,
57-
/// Ordering requirement
58-
ordering_req: LexOrdering,
54+
signature: Signature,
55+
/// Determines whether `N` is relative to the beginning or the end
56+
/// of the aggregation. When set to `true`, then `N` is from the end.
57+
reversed: bool,
5958
}
6059

6160
impl NthValueAgg {
6261
/// Create a new `NthValueAgg` aggregate function
63-
pub fn new(
64-
expr: Arc<dyn PhysicalExpr>,
65-
n: i64,
66-
name: impl Into<String>,
67-
input_data_type: DataType,
68-
nullable: bool,
69-
order_by_data_types: Vec<DataType>,
70-
ordering_req: LexOrdering,
71-
) -> Self {
62+
pub fn new() -> Self {
7263
Self {
73-
name: name.into(),
74-
input_data_type,
75-
expr,
76-
n,
77-
nullable,
78-
order_by_data_types,
79-
ordering_req,
64+
signature: Signature::any(2, Volatility::Immutable),
65+
reversed: false,
8066
}
8167
}
68+
69+
pub fn with_reversed(mut self, reversed: bool) -> Self {
70+
self.reversed = reversed;
71+
self
72+
}
8273
}
8374

84-
impl AggregateExpr for NthValueAgg {
75+
impl Default for NthValueAgg {
76+
fn default() -> Self {
77+
Self::new()
78+
}
79+
}
80+
81+
impl AggregateUDFImpl for NthValueAgg {
8582
fn as_any(&self) -> &dyn Any {
8683
self
8784
}
8885

89-
fn field(&self) -> Result<Field> {
90-
Ok(Field::new(&self.name, self.input_data_type.clone(), true))
86+
fn name(&self) -> &str {
87+
"nth_value"
88+
}
89+
90+
fn signature(&self) -> &Signature {
91+
&self.signature
9192
}
9293

93-
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
94-
Ok(Box::new(NthValueAccumulator::try_new(
95-
self.n,
96-
&self.input_data_type,
97-
&self.order_by_data_types,
98-
self.ordering_req.clone(),
99-
)?))
94+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95+
Ok(arg_types[0].clone())
10096
}
10197

102-
fn state_fields(&self) -> Result<Vec<Field>> {
98+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
99+
let n = match acc_args.input_exprs[1] {
100+
Expr::Literal(ScalarValue::Int64(Some(value))) => {
101+
if self.reversed {
102+
Ok(-value)
103+
} else {
104+
Ok(value)
105+
}
106+
}
107+
_ => not_impl_err!(
108+
"{} not supported for n: {}",
109+
self.name(),
110+
&acc_args.input_exprs[1]
111+
),
112+
}?;
113+
114+
let ordering_req = limited_convert_logical_sort_exprs_to_physical(
115+
acc_args.sort_exprs,
116+
acc_args.schema,
117+
)?;
118+
119+
let ordering_dtypes = ordering_req
120+
.iter()
121+
.map(|e| e.expr.data_type(acc_args.schema))
122+
.collect::<Result<Vec<_>>>()?;
123+
124+
NthValueAccumulator::try_new(
125+
n,
126+
acc_args.input_type,
127+
&ordering_dtypes,
128+
ordering_req,
129+
)
130+
.map(|acc| Box::new(acc) as _)
131+
}
132+
133+
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
103134
let mut fields = vec![Field::new_list(
104-
format_state_name(&self.name, "nth_value"),
105-
Field::new("item", self.input_data_type.clone(), true),
106-
self.nullable, // This should be the same as field()
135+
format_state_name(self.name(), "nth_value"),
136+
// TODO: The nullability of the list element should be configurable.
137+
// The hard-coded `true` should be changed once the field for
138+
// nullability is added to `StateFieldArgs` struct.
139+
// See: https://github.com/apache/datafusion/pull/11063
140+
Field::new("item", args.input_type.clone(), true),
141+
false,
107142
)];
108-
if !self.ordering_req.is_empty() {
109-
let orderings =
110-
ordering_fields(&self.ordering_req, &self.order_by_data_types);
143+
let orderings = args.ordering_fields.to_vec();
144+
if !orderings.is_empty() {
111145
fields.push(Field::new_list(
112-
format_state_name(&self.name, "nth_value_orderings"),
146+
format_state_name(self.name(), "nth_value_orderings"),
113147
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
114-
self.nullable,
148+
false,
115149
));
116150
}
117151
Ok(fields)
118152
}
119153

120-
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
121-
let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _;
122-
vec![Arc::clone(&self.expr), n]
154+
fn aliases(&self) -> &[String] {
155+
&[]
123156
}
124157

125-
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
126-
(!self.ordering_req.is_empty()).then_some(&self.ordering_req)
127-
}
128-
129-
fn order_sensitivity(&self) -> AggregateOrderSensitivity {
130-
AggregateOrderSensitivity::HardRequirement
131-
}
132-
133-
fn name(&self) -> &str {
134-
&self.name
135-
}
136-
137-
fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
138-
Some(Arc::new(Self {
139-
name: self.name.to_string(),
140-
input_data_type: self.input_data_type.clone(),
141-
expr: Arc::clone(&self.expr),
142-
// index should be from the opposite side
143-
n: -self.n,
144-
nullable: self.nullable,
145-
order_by_data_types: self.order_by_data_types.clone(),
146-
// reverse requirement
147-
ordering_req: reverse_order_bys(&self.ordering_req),
148-
}) as _)
149-
}
150-
}
151-
152-
impl PartialEq<dyn Any> for NthValueAgg {
153-
fn eq(&self, other: &dyn Any) -> bool {
154-
down_cast_any_ref(other)
155-
.downcast_ref::<Self>()
156-
.map(|x| {
157-
self.name == x.name
158-
&& self.input_data_type == x.input_data_type
159-
&& self.order_by_data_types == x.order_by_data_types
160-
&& self.expr.eq(&x.expr)
161-
})
162-
.unwrap_or(false)
158+
fn reverse_expr(&self) -> ReversedUDAF {
159+
ReversedUDAF::Reversed(Arc::from(AggregateUDF::from(
160+
Self::new().with_reversed(!self.reversed),
161+
)))
163162
}
164163
}
165164

166165
#[derive(Debug)]
167-
pub(crate) struct NthValueAccumulator {
166+
pub struct NthValueAccumulator {
167+
/// The `N` value.
168168
n: i64,
169169
/// Stores entries in the `NTH_VALUE` result.
170170
values: VecDeque<ScalarValue>,

datafusion/functions-array/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ datafusion-common = { workspace = true }
4949
datafusion-execution = { workspace = true }
5050
datafusion-expr = { workspace = true }
5151
datafusion-functions = { workspace = true }
52+
datafusion-functions-aggregate = { workspace = true }
5253
itertools = { version = "0.12", features = ["use_std"] }
5354
log = { workspace = true }
5455
paste = "1.0.14"

datafusion/functions-array/src/planner.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use datafusion_expr::{
2323
sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess,
2424
};
2525
use datafusion_functions::expr_fn::get_field;
26+
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
2627

2728
use crate::{
2829
array_has::array_has_all,
@@ -119,8 +120,8 @@ impl UserDefinedSQLPlanner for FieldAccessPlanner {
119120
// Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
120121
Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => {
121122
Ok(PlannerResult::Planned(Expr::AggregateFunction(
122-
datafusion_expr::expr::AggregateFunction::new(
123-
AggregateFunction::NthValue,
123+
datafusion_expr::expr::AggregateFunction::new_udf(
124+
nth_value_udaf(),
124125
agg_func
125126
.args
126127
.into_iter()

0 commit comments

Comments
 (0)