Skip to content

Commit 4c75286

Browse files
update AggregateFunction
Upstream Changes: - The field name was switched from `func_name` to func. - AggregateFunctionDefinition was removed Ref: apache/datafusion#11803
1 parent 7207433 commit 4c75286

File tree

3 files changed

+13
-17
lines changed

3 files changed

+13
-17
lines changed

src/expr/aggregate.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ impl PyAggregate {
127127
// TODO: This Alias logic seems to be returning some strange results that we should investigate
128128
Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
129129
Expr::AggregateFunction(AggregateFunction {
130-
func_def: _, args, ..
130+
func: _, args, ..
131131
}) => Ok(args.iter().map(|e| PyExpr::from(e.clone())).collect()),
132132
_ => Err(py_type_err(
133133
"Encountered a non Aggregate type in aggregation_arguments",
@@ -138,8 +138,8 @@ impl PyAggregate {
138138
fn _agg_func_name(expr: &Expr) -> PyResult<String> {
139139
match expr {
140140
Expr::Alias(Alias { expr, .. }) => Self::_agg_func_name(expr.as_ref()),
141-
Expr::AggregateFunction(AggregateFunction { func_def, .. }) => {
142-
Ok(func_def.name().to_owned())
141+
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
142+
Ok(func.name().to_owned())
143143
}
144144
_ => Err(py_type_err(
145145
"Encountered a non Aggregate type in agg_func_name",

src/expr/aggregate_expr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ impl From<AggregateFunction> for PyAggregateFunction {
4141
impl Display for PyAggregateFunction {
4242
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
4343
let args: Vec<String> = self.aggr.args.iter().map(|expr| expr.to_string()).collect();
44-
write!(f, "{}({})", self.aggr.func_def.name(), args.join(", "))
44+
write!(f, "{}({})", self.aggr.func.name(), args.join(", "))
4545
}
4646
}
4747

4848
#[pymethods]
4949
impl PyAggregateFunction {
5050
/// Get the aggregate type, such as "MIN", or "MAX"
5151
fn aggregate_type(&self) -> String {
52-
self.aggr.func_def.name().to_string()
52+
self.aggr.func.name().to_string()
5353
}
5454

5555
/// is this a distinct aggregate such as `COUNT(DISTINCT expr)`

src/functions.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use datafusion::functions_aggregate::all_default_aggregate_functions;
19-
use datafusion_expr::AggregateExt;
19+
use datafusion_expr::ExprFunctionExt as AggregateExt;
2020
use pyo3::{prelude::*, wrap_pyfunction};
2121

2222
use crate::common::data_type::NullTreatment;
@@ -31,9 +31,7 @@ use datafusion::functions_aggregate;
3131
use datafusion_common::{Column, ScalarValue, TableReference};
3232
use datafusion_expr::expr::Alias;
3333
use datafusion_expr::{
34-
expr::{
35-
find_df_window_func, AggregateFunction, AggregateFunctionDefinition, Sort, WindowFunction,
36-
},
34+
expr::{find_df_window_func, AggregateFunction, Sort, WindowFunction},
3735
lit, Expr, WindowFunctionDefinition,
3836
};
3937

@@ -638,18 +636,16 @@ fn window(
638636
}
639637

640638
macro_rules! aggregate_function {
641-
($NAME: ident, $FUNC: ident) => {
639+
($NAME: ident, $FUNC: path) => {
642640
aggregate_function!($NAME, $FUNC, stringify!($NAME));
643641
};
644-
($NAME: ident, $FUNC: ident, $DOC: expr) => {
642+
($NAME: ident, $FUNC: path, $DOC: expr) => {
645643
#[doc = $DOC]
646644
#[pyfunction]
647645
#[pyo3(signature = (*args, distinct=false))]
648646
fn $NAME(args: Vec<PyExpr>, distinct: bool) -> PyExpr {
649647
let expr = datafusion_expr::Expr::AggregateFunction(AggregateFunction {
650-
func_def: AggregateFunctionDefinition::BuiltIn(
651-
datafusion_expr::aggregate_function::AggregateFunction::$FUNC,
652-
),
648+
func: $FUNC(),
653649
args: args.into_iter().map(|e| e.into()).collect(),
654650
distinct,
655651
filter: None,
@@ -884,9 +880,9 @@ array_fn!(array_resize, array size value);
884880
array_fn!(flatten, array);
885881
array_fn!(range, start stop step);
886882

887-
aggregate_function!(array_agg, ArrayAgg);
888-
aggregate_function!(max, Max);
889-
aggregate_function!(min, Min);
883+
aggregate_function!(array_agg, functions_aggregate::array_agg::array_agg_udaf);
884+
aggregate_function!(max, functions_aggregate::min_max::max_udaf);
885+
aggregate_function!(min, functions_aggregate::min_max::min_udaf);
890886

891887
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
892888
m.add_wrapped(wrap_pyfunction!(abs))?;

0 commit comments

Comments
 (0)