Skip to content

Commit 5d98c32

Browse files
committed
aggregate expr builder
Signed-off-by: jayzhan211 <[email protected]>
1 parent deef834 commit 5d98c32

File tree

2 files changed

+196
-90
lines changed

2 files changed

+196
-90
lines changed

datafusion/core/src/physical_optimizer/aggregate_statistics.rs

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ pub(crate) mod tests {
326326
use datafusion_functions_aggregate::count::count_udaf;
327327
use datafusion_physical_expr::expressions::cast;
328328
use datafusion_physical_expr::PhysicalExpr;
329-
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
329+
use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
330330
use datafusion_physical_plan::aggregates::AggregateMode;
331331

332332
/// Mock data using a MemoryExec which has an exact count statistic
@@ -419,19 +419,11 @@ pub(crate) mod tests {
419419

420420
// Return the appropriate expr depending on whether COUNT is for a column or the table (*)
421421
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
422-
create_aggregate_expr(
423-
&count_udaf(),
424-
&[self.column()],
425-
&[],
426-
&[],
427-
&[],
428-
schema,
429-
self.column_name(),
430-
false,
431-
false,
432-
false,
433-
)
434-
.unwrap()
422+
AggregateExprBuilder::new(count_udaf(), vec![self.column()])
423+
.schema(schema.clone())
424+
.name(self.column_name())
425+
.build()
426+
.unwrap()
435427
}
436428

437429
/// what argument would this aggregate need in the plan?

datafusion/physical-expr-common/src/aggregate/mod.rs

Lines changed: 190 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub mod tdigest;
2323
pub mod utils;
2424

2525
use arrow::datatypes::{DataType, Field, Schema};
26-
use datafusion_common::{not_impl_err, DFSchema, Result};
26+
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
2727
use datafusion_expr::function::StateFieldsArgs;
2828
use datafusion_expr::type_coercion::aggregates::check_arg_count;
2929
use datafusion_expr::ReversedUDAF;
@@ -33,7 +33,7 @@ use datafusion_expr::{
3333
use std::fmt::Debug;
3434
use std::{any::Any, sync::Arc};
3535

36-
use self::utils::{down_cast_any_ref, ordering_fields};
36+
use self::utils::down_cast_any_ref;
3737
use crate::physical_expr::PhysicalExpr;
3838
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
3939
use crate::utils::reverse_order_bys;
@@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
5555
/// `is_reversed` is used to indicate whether the aggregation is running in reverse order,
5656
/// it could be used to hint Accumulator to accumulate in the reversed order,
5757
/// you can just set it to false if you are not reversing the expression
58+
///
59+
/// You can also create the expression with [`AggregateExprBuilder`]
5860
#[allow(clippy::too_many_arguments)]
5961
pub fn create_aggregate_expr(
6062
fun: &AggregateUDF,
@@ -66,45 +68,24 @@ pub fn create_aggregate_expr(
6668
name: impl Into<String>,
6769
ignore_nulls: bool,
6870
is_distinct: bool,
69-
is_reversed: bool,
71+
_is_reversed: bool,
7072
) -> Result<Arc<dyn AggregateExpr>> {
71-
debug_assert_eq!(sort_exprs.len(), ordering_req.len());
72-
73-
let input_exprs_types = input_phy_exprs
74-
.iter()
75-
.map(|arg| arg.data_type(schema))
76-
.collect::<Result<Vec<_>>>()?;
77-
78-
check_arg_count(
79-
fun.name(),
80-
&input_exprs_types,
81-
&fun.signature().type_signature,
82-
)?;
83-
84-
let ordering_types = ordering_req
85-
.iter()
86-
.map(|e| e.expr.data_type(schema))
87-
.collect::<Result<Vec<_>>>()?;
88-
89-
let ordering_fields = ordering_fields(ordering_req, &ordering_types);
90-
let name = name.into();
91-
92-
Ok(Arc::new(AggregateFunctionExpr {
93-
fun: fun.clone(),
94-
args: input_phy_exprs.to_vec(),
95-
logical_args: input_exprs.to_vec(),
96-
data_type: fun.return_type(&input_exprs_types)?,
97-
name,
98-
schema: schema.clone(),
99-
dfschema: DFSchema::empty(),
100-
sort_exprs: sort_exprs.to_vec(),
101-
ordering_req: ordering_req.to_vec(),
102-
ignore_nulls,
103-
ordering_fields,
104-
is_distinct,
105-
input_type: input_exprs_types[0].clone(),
106-
is_reversed,
107-
}))
73+
let mut builder =
74+
AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec());
75+
builder = builder.sort_exprs(sort_exprs.to_vec());
76+
builder = builder.order_by(ordering_req.to_vec());
77+
builder = builder.logical_exprs(input_exprs.to_vec());
78+
builder = builder.schema(schema.clone());
79+
builder = builder.name(name);
80+
81+
if ignore_nulls {
82+
builder = builder.ignore_nulls();
83+
}
84+
if is_distinct {
85+
builder = builder.distinct();
86+
}
87+
88+
builder.build()
10889
}
10990

11091
#[allow(clippy::too_many_arguments)]
@@ -121,44 +102,177 @@ pub fn create_aggregate_expr_with_dfschema(
121102
is_distinct: bool,
122103
is_reversed: bool,
123104
) -> Result<Arc<dyn AggregateExpr>> {
124-
debug_assert_eq!(sort_exprs.len(), ordering_req.len());
125-
105+
let mut builder =
106+
AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec());
107+
builder = builder.sort_exprs(sort_exprs.to_vec());
108+
builder = builder.order_by(ordering_req.to_vec());
109+
builder = builder.logical_exprs(input_exprs.to_vec());
110+
builder = builder.dfschema(dfschema.clone());
126111
let schema: Schema = dfschema.into();
112+
builder = builder.schema(schema);
113+
builder = builder.name(name);
114+
115+
if ignore_nulls {
116+
builder = builder.ignore_nulls();
117+
}
118+
if is_distinct {
119+
builder = builder.distinct();
120+
}
121+
if is_reversed {
122+
builder = builder.reversed();
123+
}
124+
125+
builder.build()
126+
}
127127

128-
let input_exprs_types = input_phy_exprs
129-
.iter()
130-
.map(|arg| arg.data_type(&schema))
131-
.collect::<Result<Vec<_>>>()?;
132-
133-
check_arg_count(
134-
fun.name(),
135-
&input_exprs_types,
136-
&fun.signature().type_signature,
137-
)?;
138-
139-
let ordering_types = ordering_req
140-
.iter()
141-
.map(|e| e.expr.data_type(&schema))
142-
.collect::<Result<Vec<_>>>()?;
143-
144-
let ordering_fields = ordering_fields(ordering_req, &ordering_types);
145-
146-
Ok(Arc::new(AggregateFunctionExpr {
147-
fun: fun.clone(),
148-
args: input_phy_exprs.to_vec(),
149-
logical_args: input_exprs.to_vec(),
150-
data_type: fun.return_type(&input_exprs_types)?,
151-
name: name.into(),
152-
schema: schema.clone(),
153-
dfschema: dfschema.clone(),
154-
sort_exprs: sort_exprs.to_vec(),
155-
ordering_req: ordering_req.to_vec(),
156-
ignore_nulls,
157-
ordering_fields,
158-
is_distinct,
159-
input_type: input_exprs_types[0].clone(),
160-
is_reversed,
161-
}))
128+
#[derive(Debug, Clone)]
129+
pub struct AggregateExprBuilder {
130+
fun: Arc<AggregateUDF>,
131+
/// Physical expressions of the aggregate function
132+
args: Vec<Arc<dyn PhysicalExpr>>,
133+
/// Logical expressions of the aggregate function, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
134+
logical_args: Vec<Expr>,
135+
name: String,
136+
/// Arrow Schema for the aggregate function
137+
schema: Schema,
138+
/// Datafusion Schema for the aggregate function
139+
dfschema: DFSchema,
140+
/// The logical order by expressions, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
141+
sort_exprs: Vec<Expr>,
142+
/// The physical order by expressions
143+
ordering_req: LexOrdering,
144+
/// Whether to ignore null values
145+
ignore_nulls: bool,
146+
/// Whether is distinct aggregate function
147+
is_distinct: bool,
148+
/// Whether the expression is reversed
149+
is_reversed: bool,
150+
}
151+
152+
impl AggregateExprBuilder {
153+
pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
154+
Self {
155+
fun,
156+
args,
157+
logical_args: vec![],
158+
name: String::new(),
159+
schema: Schema::empty(),
160+
dfschema: DFSchema::empty(),
161+
sort_exprs: vec![],
162+
ordering_req: vec![],
163+
ignore_nulls: false,
164+
is_distinct: false,
165+
is_reversed: false,
166+
}
167+
}
168+
169+
pub fn build(self) -> Result<Arc<dyn AggregateExpr>> {
170+
let Self {
171+
fun,
172+
args,
173+
logical_args,
174+
name,
175+
schema,
176+
dfschema,
177+
sort_exprs,
178+
ordering_req,
179+
ignore_nulls,
180+
is_distinct,
181+
is_reversed,
182+
} = self;
183+
if args.is_empty() {
184+
return internal_err!("args should not be empty");
185+
}
186+
187+
let mut ordering_fields = vec![];
188+
189+
debug_assert_eq!(sort_exprs.len(), ordering_req.len());
190+
if !ordering_req.is_empty() {
191+
let ordering_types = ordering_req
192+
.iter()
193+
.map(|e| e.expr.data_type(&schema))
194+
.collect::<Result<Vec<_>>>()?;
195+
196+
ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types);
197+
}
198+
199+
let input_exprs_types = args
200+
.iter()
201+
.map(|arg| arg.data_type(&schema))
202+
.collect::<Result<Vec<_>>>()?;
203+
204+
check_arg_count(
205+
fun.name(),
206+
&input_exprs_types,
207+
&fun.signature().type_signature,
208+
)?;
209+
210+
let data_type = fun.return_type(&input_exprs_types)?;
211+
212+
Ok(Arc::new(AggregateFunctionExpr {
213+
fun: Arc::unwrap_or_clone(fun),
214+
args,
215+
logical_args,
216+
data_type,
217+
name,
218+
schema,
219+
dfschema,
220+
sort_exprs,
221+
ordering_req,
222+
ignore_nulls,
223+
ordering_fields,
224+
is_distinct,
225+
input_type: input_exprs_types[0].clone(),
226+
is_reversed,
227+
}))
228+
}
229+
230+
pub fn name(mut self, name: impl Into<String>) -> Self {
231+
self.name = name.into();
232+
self
233+
}
234+
235+
pub fn schema(mut self, schema: Schema) -> Self {
236+
self.schema = schema;
237+
self
238+
}
239+
240+
pub fn dfschema(mut self, dfschema: DFSchema) -> Self {
241+
self.dfschema = dfschema;
242+
self
243+
}
244+
245+
pub fn order_by(mut self, order_by: LexOrdering) -> Self {
246+
self.ordering_req = order_by;
247+
self
248+
}
249+
250+
pub fn reversed(mut self) -> Self {
251+
self.is_reversed = true;
252+
self
253+
}
254+
255+
pub fn distinct(mut self) -> Self {
256+
self.is_distinct = true;
257+
self
258+
}
259+
260+
pub fn ignore_nulls(mut self) -> Self {
261+
self.ignore_nulls = true;
262+
self
263+
}
264+
265+
/// This method will be deprecated in <https://github.com/apache/datafusion/issues/11359>
266+
pub fn sort_exprs(mut self, sort_exprs: Vec<Expr>) -> Self {
267+
self.sort_exprs = sort_exprs;
268+
self
269+
}
270+
271+
/// This method will be deprecated in <https://github.com/apache/datafusion/issues/11359>
272+
pub fn logical_exprs(mut self, logical_args: Vec<Expr>) -> Self {
273+
self.logical_args = logical_args;
274+
self
275+
}
162276
}
163277

164278
/// An aggregate expression that:

0 commit comments

Comments
 (0)