Skip to content

Commit a38089a

Browse files
committed
first draft
Signed-off-by: jayzhan211 <[email protected]>
1 parent 82ea059 commit a38089a

File tree

10 files changed

+36
-81
lines changed

10 files changed

+36
-81
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ use datafusion_expr::{
8585
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan,
8686
WindowFrame, WindowFrameBound, WriteOp,
8787
};
88-
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
8988
use datafusion_physical_expr::expressions::Literal;
9089
use datafusion_physical_expr::LexOrdering;
9190
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
@@ -1840,11 +1839,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18401839
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
18411840
== NullTreatment::IgnoreNulls;
18421841

1842+
// TODO: Remove this fallback to the built-in ArrayAgg once the DISTINCT and ORDER BY variants of array_agg are also implemented as UDAFs
18431843
let (agg_expr, filter, order_by) = match func_def {
1844-
AggregateFunctionDefinition::BuiltIn(
1845-
datafusion_expr::AggregateFunction::ArrayAgg,
1846-
) if !distinct && order_by.is_none() => {
1847-
let sort_exprs = order_by.clone().unwrap_or(vec![]);
1844+
AggregateFunctionDefinition::UDF(udf)
1845+
if udf.name() == "ARRAY_AGG" && (*distinct || order_by.is_some()) =>
1846+
{
18481847
let physical_sort_exprs = match order_by {
18491848
Some(exprs) => Some(create_physical_sort_exprs(
18501849
exprs,
@@ -1855,16 +1854,15 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18551854
};
18561855
let ordering_reqs: Vec<PhysicalSortExpr> =
18571856
physical_sort_exprs.clone().unwrap_or(vec![]);
1858-
let agg_expr = udaf::create_aggregate_expr(
1859-
&array_agg_udaf(),
1857+
let fun = aggregates::AggregateFunction::ArrayAgg;
1858+
let agg_expr = aggregates::create_aggregate_expr(
1859+
&fun,
1860+
*distinct,
18601861
&physical_args,
1861-
args,
1862-
&sort_exprs,
18631862
&ordering_reqs,
18641863
physical_input_schema,
18651864
name,
18661865
ignore_nulls,
1867-
*distinct,
18681866
)?;
18691867
(agg_expr, filter, physical_sort_exprs)
18701868
}
@@ -1916,6 +1914,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
19161914
(agg_expr, filter, physical_sort_exprs)
19171915
}
19181916
};
1917+
19191918
Ok((agg_expr, filter, order_by))
19201919
}
19211920
other => internal_err!("Invalid aggregate expression '{other:?}'"),

datafusion/expr/src/function.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,6 @@ pub struct StateFieldsArgs<'a> {
106106
/// The input type of the aggregate function.
107107
pub input_type: &'a DataType,
108108

109-
/// If the input type is nullable.
110-
pub input_nullable: bool,
111-
112109
/// The return type of the aggregate function.
113110
pub return_type: &'a DataType,
114111

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,12 @@ use arrow::datatypes::DataType;
2222
use arrow_schema::Field;
2323

2424
use datafusion_common::cast::as_list_array;
25-
use datafusion_common::utils::array_into_list_array;
25+
use datafusion_common::utils::array_into_list_array_nullable;
2626
use datafusion_common::Result;
2727
use datafusion_common::ScalarValue;
28-
use datafusion_expr::expr::AggregateFunction;
29-
use datafusion_expr::expr::AggregateFunctionDefinition;
30-
use datafusion_expr::function::AccumulatorArgs;
31-
use datafusion_expr::simplify::SimplifyInfo;
28+
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3229
use datafusion_expr::utils::format_state_name;
3330
use datafusion_expr::AggregateUDFImpl;
34-
use datafusion_expr::Expr;
3531
use datafusion_expr::{Accumulator, Signature, Volatility};
3632
use std::sync::Arc;
3733

@@ -84,47 +80,17 @@ impl AggregateUDFImpl for ArrayAgg {
8480
))))
8581
}
8682

87-
fn state_fields(
88-
&self,
89-
args: datafusion_expr::function::StateFieldsArgs,
90-
) -> Result<Vec<Field>> {
83+
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
9184
Ok(vec![Field::new_list(
9285
format_state_name(args.name, "array_agg"),
9386
Field::new("item", args.input_type.clone(), true),
94-
args.input_nullable,
87+
true,
9588
)])
9689
}
9790

9891
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
9992
Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?))
10093
}
101-
102-
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
103-
datafusion_expr::ReversedUDAF::Identical
104-
}
105-
106-
fn simplify(
107-
&self,
108-
) -> Option<datafusion_expr::function::AggregateFunctionSimplification> {
109-
let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
110-
if aggregate_function.order_by.is_some() || aggregate_function.distinct {
111-
Ok(Expr::AggregateFunction(AggregateFunction {
112-
func_def: AggregateFunctionDefinition::BuiltIn(
113-
datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg,
114-
),
115-
args: aggregate_function.args,
116-
distinct: aggregate_function.distinct,
117-
filter: aggregate_function.filter,
118-
order_by: aggregate_function.order_by,
119-
null_treatment: aggregate_function.null_treatment,
120-
}))
121-
} else {
122-
Ok(Expr::AggregateFunction(aggregate_function))
123-
}
124-
};
125-
126-
Some(Box::new(simplify))
127-
}
12894
}
12995

13096
#[derive(Debug)]
@@ -150,8 +116,11 @@ impl Accumulator for ArrayAggAccumulator {
150116
return Ok(());
151117
}
152118
assert!(values.len() == 1, "array_agg can only take 1 param!");
153-
let val = values[0].clone();
154-
self.values.push(val);
119+
120+
let val = Arc::clone(&values[0]);
121+
if val.len() > 0 {
122+
self.values.push(val);
123+
}
155124
Ok(())
156125
}
157126

@@ -175,17 +144,15 @@ impl Accumulator for ArrayAggAccumulator {
175144

176145
fn evaluate(&mut self) -> Result<ScalarValue> {
177146
// Transform Vec<ListArr> to ListArr
178-
179147
let element_arrays: Vec<&dyn Array> =
180148
self.values.iter().map(|a| a.as_ref()).collect();
181149

182150
if element_arrays.is_empty() {
183-
let arr = ScalarValue::new_list(&[], &self.datatype);
184-
return Ok(ScalarValue::List(arr));
151+
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
185152
}
186153

187154
let concated_array = arrow::compute::concat(&element_arrays)?;
188-
let list_array = array_into_list_array(concated_array);
155+
let list_array = array_into_list_array_nullable(concated_array);
189156

190157
Ok(ScalarValue::List(Arc::new(list_array)))
191158
}

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,6 @@ impl AggregateUDFImpl for LastValue {
440440
let StateFieldsArgs {
441441
name,
442442
input_type,
443-
input_nullable: _,
444443
return_type: _,
445444
ordering_fields,
446445
is_distinct: _,

datafusion/functions-aggregate/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@
5656
pub mod macros;
5757

5858
pub mod approx_distinct;
59-
pub mod correlation;
6059
pub mod array_agg;
60+
pub mod correlation;
6161
pub mod count;
6262
pub mod covariance;
6363
pub mod first_last;
@@ -92,8 +92,8 @@ pub mod expr_fn {
9292
pub use super::approx_median::approx_median;
9393
pub use super::approx_percentile_cont::approx_percentile_cont;
9494
pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight;
95-
pub use super::average::avg;
9695
pub use super::array_agg::array_agg;
96+
pub use super::average::avg;
9797
pub use super::bit_and_or_xor::bit_and;
9898
pub use super::bit_and_or_xor::bit_or;
9999
pub use super::bit_and_or_xor::bit_xor;

datafusion/functions-array/src/planner.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
2020
use datafusion_common::{utils::list_ndims, DFSchema, Result};
2121
use datafusion_expr::{
22+
expr::AggregateFunctionDefinition,
2223
planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
23-
sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess,
24+
sqlparser, Expr, ExprSchemable, GetFieldAccess,
2425
};
2526
use datafusion_functions::expr_fn::get_field;
2627
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
@@ -153,8 +154,9 @@ impl ExprPlanner for FieldAccessPlanner {
153154
}
154155

155156
fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
156-
agg_func.func_def
157-
== datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(
158-
AggregateFunction::ArrayAgg,
159-
)
157+
if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def {
158+
return udf.name() == "ARRAY_AGG";
159+
}
160+
161+
false
160162
}

datafusion/physical-expr-common/src/aggregate/mod.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ pub fn create_aggregate_expr(
8888
ordering_fields,
8989
is_distinct,
9090
input_type: input_exprs_types[0].clone(),
91-
input_nullable: input_phy_exprs[0].nullable(schema)?,
9291
}))
9392
}
9493

@@ -264,7 +263,6 @@ pub struct AggregateFunctionExpr {
264263
ordering_fields: Vec<Field>,
265264
is_distinct: bool,
266265
input_type: DataType,
267-
input_nullable: bool,
268266
}
269267

270268
impl AggregateFunctionExpr {
@@ -293,7 +291,6 @@ impl AggregateExpr for AggregateFunctionExpr {
293291
let args = StateFieldsArgs {
294292
name: &self.name,
295293
input_type: &self.input_type,
296-
input_nullable: self.input_nullable,
297294
return_type: &self.data_type,
298295
ordering_fields: &self.ordering_fields,
299296
is_distinct: self.is_distinct,
@@ -303,11 +300,7 @@ impl AggregateExpr for AggregateFunctionExpr {
303300
}
304301

305302
fn field(&self) -> Result<Field> {
306-
Ok(Field::new(
307-
&self.name,
308-
self.data_type.clone(),
309-
self.input_nullable,
310-
))
303+
Ok(Field::new(&self.name, self.data_type.clone(), true))
311304
}
312305

313306
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use std::sync::Arc;
3030

3131
use arrow::datatypes::Schema;
3232

33-
use datafusion_common::{not_impl_err, Result};
33+
use datafusion_common::{internal_err, not_impl_err, Result};
3434
use datafusion_expr::AggregateFunction;
3535

3636
use crate::expressions::{self};

datafusion/physical-plan/src/aggregates/no_grouping.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ fn aggregate_batch(
218218
Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
219219
None => Cow::Borrowed(&batch),
220220
};
221+
221222
// 1.3
222223
let values = &expr
223224
.iter()

datafusion/proto/src/physical_plan/to_proto.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@ use datafusion::datasource::file_format::parquet::ParquetSink;
2323
use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr};
2424
use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
2525
use datafusion::physical_plan::expressions::{
26-
BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg,
27-
InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr,
28-
NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr,
29-
WindowShift,
26+
BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, InListExpr,
27+
IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile,
28+
OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift,
3029
};
3130
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
3231
use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr};
@@ -245,9 +244,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
245244
let mut distinct = false;
246245

247246
// TODO: remove
248-
let inner = if aggr_expr.downcast_ref::<ArrayAgg>().is_some() {
249-
protobuf::AggregateFunction::ArrayAgg
250-
} else if aggr_expr.downcast_ref::<DistinctArrayAgg>().is_some() {
247+
let inner = if aggr_expr.downcast_ref::<DistinctArrayAgg>().is_some() {
251248
distinct = true;
252249
protobuf::AggregateFunction::ArrayAgg
253250
} else if aggr_expr.downcast_ref::<OrderSensitiveArrayAgg>().is_some() {

0 commit comments

Comments
 (0)