Skip to content

Convert nth_value to UDAF #11287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
bef3ede
Copies `NthValueAccumulator` to `functions-aggregate`
jcsherin Jun 27, 2024
c684040
Partial implementation of `AggregateUDFImpl`
jcsherin Jun 27, 2024
165c4f5
Implements `accumulator` method
jcsherin Jun 27, 2024
e82f055
Retains existing comments verbatim
jcsherin Jun 27, 2024
a45349f
Removes unnecessary path prefix
jcsherin Jun 27, 2024
9c9a6c4
Implements `reverse_expr` method
jcsherin Jun 27, 2024
60370ca
Adds `nullable` field to `NthValue`
jcsherin Jun 27, 2024
3631812
Revert to existing name
jcsherin Jun 27, 2024
98039e5
Implements `state_fields` method
jcsherin Jun 27, 2024
415d9db
Removes `nth_value` from `physical-expr`
jcsherin Jun 27, 2024
d9ebdbe
Adds default
jcsherin Jun 28, 2024
c4e5417
Exports `nth_value`
jcsherin Jun 28, 2024
f74459b
Fixes build error in physical plan roundtrip test
jcsherin Jun 28, 2024
0fc98f1
Minor: formatting
jcsherin Jun 28, 2024
02c01fc
Parses `N` from input expression
jcsherin Jun 28, 2024
729d9c5
Fixes build error by using `nth_value_udaf`
jcsherin Jul 2, 2024
40b4607
Fixes `reverse_expr` by passing correct `N`
jcsherin Jul 2, 2024
a86ca1f
Update plan with lowercase UDF name
jcsherin Jul 2, 2024
02f4497
Updates error message for incorrect no. of arguments
jcsherin Jul 2, 2024
2e90028
Fixes nullable "item" in `state_fields`
jcsherin Jul 2, 2024
972f118
Minor: fix formatting after resolving conflicts
jcsherin Jul 3, 2024
cf974c0
Updates multiple existing plans with lowercase name
jcsherin Jul 3, 2024
427a8bb
Implements `retract_batch` for window aggregations
jcsherin Jul 3, 2024
488881f
Fixes: regex mismatch for error message in CI
jcsherin Jul 3, 2024
91bacb7
Revert "Updates multiple existing plans with lowercase name"
jcsherin Jul 5, 2024
97e8955
Revert "Implements `retract_batch` for window aggregations"
jcsherin Jul 5, 2024
4634641
Fixes: use builtin window function instead of udaf
jcsherin Jul 5, 2024
7b57cce
Revert "Updates error message for incorrect no. of arguments"
jcsherin Jul 5, 2024
20a804b
Refactor: renames field and method
jcsherin Jul 5, 2024
cc08872
Removes hack for nullability
jcsherin Jul 5, 2024
35b0c0d
Minor: refactors `reverse_expr`
jcsherin Jul 5, 2024
f6215d9
Minor: removes unnecessary path prefix
jcsherin Jul 5, 2024
5874f0f
Minor: cleanup arguments for creating aggregate expr
jcsherin Jul 6, 2024
6ce6679
Refactor: extracts `merge_ordered_arrays` to `physical-expr-common`
jcsherin Jul 6, 2024
5ea5421
Minor: adds todo for configuring nullability
jcsherin Jul 6, 2024
d0a3c3d
Retrigger CI
jcsherin Jul 6, 2024
e3a8026
Merge branch 'main' into convert-udaf-nth-value
jcsherin Jul 6, 2024
1990bad
Merge remote-tracking branch 'apache/main' into convert-udaf-nth-value
alamb Jul 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 0 additions & 7 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ pub enum AggregateFunction {
Max,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this list is quite close to empty 🤞

/// Aggregation into an array
ArrayAgg,
/// N'th value in a group according to some ordering
NthValue,
}

impl AggregateFunction {
Expand All @@ -50,7 +48,6 @@ impl AggregateFunction {
Min => "MIN",
Max => "MAX",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
}
}
}
Expand All @@ -69,7 +66,6 @@ impl FromStr for AggregateFunction {
"max" => AggregateFunction::Max,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
_ => {
return plan_err!("There is no built-in function named {name}");
}
Expand Down Expand Up @@ -114,7 +110,6 @@ impl AggregateFunction {
coerced_data_types[0].clone(),
input_expr_nullable[0],
)))),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
}
}

Expand All @@ -124,7 +119,6 @@ impl AggregateFunction {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(false),
AggregateFunction::NthValue => Ok(true),
}
}
}
Expand All @@ -147,7 +141,6 @@ impl AggregateFunction {
.collect::<Vec<_>>();
Signature::uniform(1, valid, Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
}
}
}
Expand Down
1 change: 0 additions & 1 deletion datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ pub fn coerce_types(
// unpack the dictionary to get the value
get_min_max_result_type(input_types)
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
}
}

Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub mod average;
pub mod bit_and_or_xor;
pub mod bool_and_or;
pub mod grouping;
pub mod nth_value;
pub mod string_agg;

use crate::approx_percentile_cont::approx_percentile_cont_udaf;
Expand Down Expand Up @@ -105,6 +106,7 @@ pub mod expr_fn {
pub use super::first_last::last_value;
pub use super::grouping::grouping;
pub use super::median::median;
pub use super::nth_value::nth_value;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
Expand Down Expand Up @@ -157,6 +159,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
bool_and_or::bool_or_udaf(),
average::avg_udaf(),
grouping::grouping_udaf(),
nth_value::nth_value_udaf(),
]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,149 +22,149 @@ use std::any::Any;
use std::collections::VecDeque;
use std::sync::Arc;

use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
use crate::expressions::{format_state_name, Literal};
use crate::{
reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr,
};

use arrow_array::cast::AsArray;
use arrow_array::{new_empty_array, ArrayRef, StructArray};
use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray};
use arrow_schema::{DataType, Field, Fields};

use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx};
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::Accumulator;
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature,
Volatility,
};
use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays;
use datafusion_physical_expr_common::aggregate::utils::ordering_fields;
use datafusion_physical_expr_common::sort_expr::{
limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr,
};

make_udaf_expr_and_func!(
NthValueAgg,
nth_value,
"Returns the nth value in a group of values.",
nth_value_udaf
);

/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub struct NthValueAgg {
/// Column name
name: String,
/// The `DataType` for the input expression
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// The `N` value.
n: i64,
/// If the input expression can have `NULL`s
nullable: bool,
/// Ordering data types
order_by_data_types: Vec<DataType>,
/// Ordering requirement
ordering_req: LexOrdering,
signature: Signature,
/// Determines whether `N` is relative to the beginning or the end
/// of the aggregation. When set to `true`, then `N` is from the end.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

reversed: bool,
}

impl NthValueAgg {
/// Create a new `NthValueAgg` aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
n: i64,
name: impl Into<String>,
input_data_type: DataType,
nullable: bool,
order_by_data_types: Vec<DataType>,
ordering_req: LexOrdering,
) -> Self {
pub fn new() -> Self {
Self {
name: name.into(),
input_data_type,
expr,
n,
nullable,
order_by_data_types,
ordering_req,
signature: Signature::any(2, Volatility::Immutable),
reversed: false,
}
}

pub fn with_reversed(mut self, reversed: bool) -> Self {
self.reversed = reversed;
self
}
}

impl AggregateExpr for NthValueAgg {
impl Default for NthValueAgg {
fn default() -> Self {
Self::new()
}
}

impl AggregateUDFImpl for NthValueAgg {
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.input_data_type.clone(), true))
fn name(&self) -> &str {
"nth_value"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
self.n,
&self.input_data_type,
&self.order_by_data_types,
self.ordering_req.clone(),
)?))
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].clone())
}

fn state_fields(&self) -> Result<Vec<Field>> {
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let n = match acc_args.input_exprs[1] {
Expr::Literal(ScalarValue::Int64(Some(value))) => {
if self.reversed {
Ok(-value)
} else {
Ok(value)
}
}
_ => not_impl_err!(
"{} not supported for n: {}",
self.name(),
&acc_args.input_exprs[1]
),
}?;

let ordering_req = limited_convert_logical_sort_exprs_to_physical(
acc_args.sort_exprs,
acc_args.schema,
)?;

let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;

NthValueAccumulator::try_new(
n,
acc_args.input_type,
&ordering_dtypes,
ordering_req,
)
.map(|acc| Box::new(acc) as _)
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "nth_value"),
Field::new("item", self.input_data_type.clone(), true),
self.nullable, // This should be the same as field()
format_state_name(self.name(), "nth_value"),
// TODO: The nullability of the list element should be configurable.
// The hard-coded `true` should be changed once the field for
// nullability is added to the `StateFieldsArgs` struct.
// See: https://github.com/apache/datafusion/pull/11063
Field::new("item", args.input_type.clone(), true),
false,
Copy link
Contributor Author

@jcsherin jcsherin Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should nullable be configurable? It is currently unavailable in StateFieldsArgs, though. I think this is related to #11274 and #11094.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the existing nth function, I think we should make nullable configurable. Also, the nullability is actually for the list element, so we should add nullable to StateFieldsArgs.

let mut fields = vec![Field::new_list(
            format_state_name(self.name(), "nth_value"),
            Field::new("item", args.input_type.clone(), self.nullable),
            false)]

@eejbyfeldt is working on it in #11063

Copy link
Contributor Author

@jcsherin jcsherin Jul 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marked it as a TODO in the comments so that it can be completed once nullability is added to StateFieldsArgs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on #11299, I think we will have non-null for nth_value too (similar to first value).

)];
if !self.ordering_req.is_empty() {
let orderings =
ordering_fields(&self.ordering_req, &self.order_by_data_types);
let orderings = args.ordering_fields.to_vec();
if !orderings.is_empty() {
fields.push(Field::new_list(
format_state_name(&self.name, "nth_value_orderings"),
format_state_name(self.name(), "nth_value_orderings"),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
self.nullable,
false,
));
}
Ok(fields)
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _;
vec![Arc::clone(&self.expr), n]
fn aliases(&self) -> &[String] {
&[]
}

fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
(!self.ordering_req.is_empty()).then_some(&self.ordering_req)
}

fn order_sensitivity(&self) -> AggregateOrderSensitivity {
AggregateOrderSensitivity::HardRequirement
}

fn name(&self) -> &str {
&self.name
}

fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
Some(Arc::new(Self {
name: self.name.to_string(),
input_data_type: self.input_data_type.clone(),
expr: Arc::clone(&self.expr),
// index should be from the opposite side
n: -self.n,
nullable: self.nullable,
order_by_data_types: self.order_by_data_types.clone(),
// reverse requirement
ordering_req: reverse_order_bys(&self.ordering_req),
}) as _)
}
}

impl PartialEq<dyn Any> for NthValueAgg {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.input_data_type == x.input_data_type
&& self.order_by_data_types == x.order_by_data_types
&& self.expr.eq(&x.expr)
})
.unwrap_or(false)
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Reversed(Arc::from(AggregateUDF::from(
Self::new().with_reversed(!self.reversed),
)))
}
}

#[derive(Debug)]
pub(crate) struct NthValueAccumulator {
pub struct NthValueAccumulator {
/// The `N` value.
n: i64,
/// Stores entries in the `NTH_VALUE` result.
values: VecDeque<ScalarValue>,
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
itertools = { version = "0.12", features = ["use_std"] }
log = { workspace = true }
paste = "1.0.14"
Expand Down
5 changes: 3 additions & 2 deletions datafusion/functions-array/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion_expr::{
sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess,
};
use datafusion_functions::expr_fn::get_field;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;

use crate::{
array_has::array_has_all,
Expand Down Expand Up @@ -119,8 +120,8 @@ impl UserDefinedSQLPlanner for FieldAccessPlanner {
// Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => {
Ok(PlannerResult::Planned(Expr::AggregateFunction(
datafusion_expr::expr::AggregateFunction::new(
AggregateFunction::NthValue,
datafusion_expr::expr::AggregateFunction::new_udf(
nth_value_udaf(),
agg_func
.args
.into_iter()
Expand Down
Loading