-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Convert `nth_value` to UDAF
#11287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert `nth_value` to UDAF
#11287
Changes from all commits
bef3ede
c684040
165c4f5
e82f055
a45349f
9c9a6c4
60370ca
3631812
98039e5
415d9db
d9ebdbe
c4e5417
f74459b
0fc98f1
02c01fc
729d9c5
40b4607
a86ca1f
02f4497
2e90028
972f118
cf974c0
427a8bb
488881f
91bacb7
97e8955
4634641
7b57cce
20a804b
cc08872
35b0c0d
f6215d9
5874f0f
6ce6679
5ea5421
d0a3c3d
e3a8026
1990bad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,149 +22,149 @@ use std::any::Any; | |
use std::collections::VecDeque; | ||
use std::sync::Arc; | ||
|
||
use crate::aggregate::array_agg_ordered::merge_ordered_arrays; | ||
use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; | ||
use crate::expressions::{format_state_name, Literal}; | ||
use crate::{ | ||
reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, | ||
}; | ||
|
||
use arrow_array::cast::AsArray; | ||
use arrow_array::{new_empty_array, ArrayRef, StructArray}; | ||
use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; | ||
use arrow_schema::{DataType, Field, Fields}; | ||
|
||
use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; | ||
use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; | ||
use datafusion_expr::utils::AggregateOrderSensitivity; | ||
use datafusion_expr::Accumulator; | ||
use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; | ||
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; | ||
use datafusion_expr::utils::format_state_name; | ||
use datafusion_expr::{ | ||
Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature, | ||
Volatility, | ||
}; | ||
use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; | ||
use datafusion_physical_expr_common::aggregate::utils::ordering_fields; | ||
use datafusion_physical_expr_common::sort_expr::{ | ||
limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, | ||
}; | ||
|
||
make_udaf_expr_and_func!( | ||
NthValueAgg, | ||
nth_value, | ||
"Returns the nth value in a group of values.", | ||
nth_value_udaf | ||
); | ||
|
||
/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi | ||
/// partition setting, partial aggregations are computed for every partition, | ||
/// and then their results are merged. | ||
#[derive(Debug)] | ||
pub struct NthValueAgg { | ||
/// Column name | ||
name: String, | ||
/// The `DataType` for the input expression | ||
input_data_type: DataType, | ||
/// The input expression | ||
expr: Arc<dyn PhysicalExpr>, | ||
/// The `N` value. | ||
n: i64, | ||
/// If the input expression can have `NULL`s | ||
nullable: bool, | ||
/// Ordering data types | ||
order_by_data_types: Vec<DataType>, | ||
/// Ordering requirement | ||
ordering_req: LexOrdering, | ||
signature: Signature, | ||
/// Determines whether `N` is relative to the beginning or the end | ||
/// of the aggregation. When set to `true`, then `N` is from the end. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ❤️ |
||
reversed: bool, | ||
} | ||
|
||
impl NthValueAgg { | ||
/// Create a new `NthValueAgg` aggregate function | ||
pub fn new( | ||
expr: Arc<dyn PhysicalExpr>, | ||
n: i64, | ||
name: impl Into<String>, | ||
input_data_type: DataType, | ||
nullable: bool, | ||
order_by_data_types: Vec<DataType>, | ||
ordering_req: LexOrdering, | ||
) -> Self { | ||
pub fn new() -> Self { | ||
Self { | ||
name: name.into(), | ||
input_data_type, | ||
expr, | ||
n, | ||
nullable, | ||
order_by_data_types, | ||
ordering_req, | ||
signature: Signature::any(2, Volatility::Immutable), | ||
reversed: false, | ||
} | ||
} | ||
|
||
pub fn with_reversed(mut self, reversed: bool) -> Self { | ||
self.reversed = reversed; | ||
self | ||
} | ||
} | ||
|
||
impl AggregateExpr for NthValueAgg { | ||
impl Default for NthValueAgg { | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} | ||
|
||
impl AggregateUDFImpl for NthValueAgg { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn field(&self) -> Result<Field> { | ||
Ok(Field::new(&self.name, self.input_data_type.clone(), true)) | ||
fn name(&self) -> &str { | ||
"nth_value" | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
Ok(Box::new(NthValueAccumulator::try_new( | ||
self.n, | ||
&self.input_data_type, | ||
&self.order_by_data_types, | ||
self.ordering_req.clone(), | ||
)?)) | ||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(arg_types[0].clone()) | ||
} | ||
|
||
fn state_fields(&self) -> Result<Vec<Field>> { | ||
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { | ||
let n = match acc_args.input_exprs[1] { | ||
Expr::Literal(ScalarValue::Int64(Some(value))) => { | ||
if self.reversed { | ||
Ok(-value) | ||
} else { | ||
jayzhan211 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Ok(value) | ||
} | ||
} | ||
_ => not_impl_err!( | ||
"{} not supported for n: {}", | ||
self.name(), | ||
&acc_args.input_exprs[1] | ||
), | ||
}?; | ||
|
||
let ordering_req = limited_convert_logical_sort_exprs_to_physical( | ||
acc_args.sort_exprs, | ||
acc_args.schema, | ||
)?; | ||
|
||
let ordering_dtypes = ordering_req | ||
.iter() | ||
.map(|e| e.expr.data_type(acc_args.schema)) | ||
.collect::<Result<Vec<_>>>()?; | ||
|
||
NthValueAccumulator::try_new( | ||
n, | ||
acc_args.input_type, | ||
&ordering_dtypes, | ||
ordering_req, | ||
) | ||
.map(|acc| Box::new(acc) as _) | ||
} | ||
|
||
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { | ||
let mut fields = vec![Field::new_list( | ||
format_state_name(&self.name, "nth_value"), | ||
Field::new("item", self.input_data_type.clone(), true), | ||
self.nullable, // This should be the same as field() | ||
format_state_name(self.name(), "nth_value"), | ||
// TODO: The nullability of the list element should be configurable. | ||
// The hard-coded `true` should be changed once the field for | ||
// nullability is added to `StateFieldsArgs` struct. | |
// See: https://github.com/apache/datafusion/pull/11063 | ||
Field::new("item", args.input_type.clone(), true), | ||
false, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I think, given the existing nth function, we should make nullability configurable. Also, the nullability is actually for the list element. We should add `nullable` to `StateFieldsArgs`:
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
Field::new("item", args.input_type.clone(), self.nullable),
false)]
@eejbyfeldt is working on it in #11063.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Marked it as a TODO in the comments so that it can be completed once it is added to `StateFieldsArgs`.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Based on #11299, I think we will have non-null for nth_value too (similar to first value). |
||
)]; | ||
if !self.ordering_req.is_empty() { | ||
let orderings = | ||
ordering_fields(&self.ordering_req, &self.order_by_data_types); | ||
let orderings = args.ordering_fields.to_vec(); | ||
if !orderings.is_empty() { | ||
fields.push(Field::new_list( | ||
format_state_name(&self.name, "nth_value_orderings"), | ||
format_state_name(self.name(), "nth_value_orderings"), | ||
Field::new("item", DataType::Struct(Fields::from(orderings)), true), | ||
self.nullable, | ||
false, | ||
)); | ||
} | ||
Ok(fields) | ||
} | ||
|
||
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { | ||
let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _; | ||
vec![Arc::clone(&self.expr), n] | ||
fn aliases(&self) -> &[String] { | ||
&[] | ||
} | ||
|
||
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { | ||
(!self.ordering_req.is_empty()).then_some(&self.ordering_req) | ||
} | ||
|
||
fn order_sensitivity(&self) -> AggregateOrderSensitivity { | ||
AggregateOrderSensitivity::HardRequirement | ||
} | ||
|
||
fn name(&self) -> &str { | ||
&self.name | ||
} | ||
|
||
fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> { | ||
Some(Arc::new(Self { | ||
name: self.name.to_string(), | ||
input_data_type: self.input_data_type.clone(), | ||
expr: Arc::clone(&self.expr), | ||
// index should be from the opposite side | ||
n: -self.n, | ||
nullable: self.nullable, | ||
order_by_data_types: self.order_by_data_types.clone(), | ||
// reverse requirement | ||
ordering_req: reverse_order_bys(&self.ordering_req), | ||
}) as _) | ||
} | ||
} | ||
|
||
impl PartialEq<dyn Any> for NthValueAgg { | ||
fn eq(&self, other: &dyn Any) -> bool { | ||
down_cast_any_ref(other) | ||
.downcast_ref::<Self>() | ||
.map(|x| { | ||
self.name == x.name | ||
&& self.input_data_type == x.input_data_type | ||
&& self.order_by_data_types == x.order_by_data_types | ||
&& self.expr.eq(&x.expr) | ||
}) | ||
.unwrap_or(false) | ||
fn reverse_expr(&self) -> ReversedUDAF { | ||
ReversedUDAF::Reversed(Arc::from(AggregateUDF::from( | ||
Self::new().with_reversed(!self.reversed), | ||
))) | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
pub(crate) struct NthValueAccumulator { | ||
pub struct NthValueAccumulator { | ||
/// The `N` value. | ||
n: i64, | ||
/// Stores entries in the `NTH_VALUE` result. | ||
values: VecDeque<ScalarValue>, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this list is quite close to empty 🤞