Skip to content

Commit fdd1e3d

Browse files
authored
Convert Average to UDAF #10942 (#10964)
* add avg udaf * remove avg from expr * add test stub * migrate avg udaf * change avg udaf signature remove avg phy expr * fix tests * fix state_fields fn * fix ut in phy-plan aggr * refactor Average to Avg * refactor Average to Avg * fix type coercion tests * fix example and logic tests * fix py expr failing ut * update docs * fix failing tests * formatting examples * remove duplicate code and fix uts * addressing PR comments * add ut for logical avg window * fix physical plan roundtrip_window test case
1 parent c2ea6b3 commit fdd1e3d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+564
-596
lines changed

datafusion-examples/examples/dataframe_subquery.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use arrow_schema::DataType;
1919
use std::sync::Arc;
2020

2121
use datafusion::error::Result;
22+
use datafusion::functions_aggregate::average::avg;
2223
use datafusion::prelude::*;
2324
use datafusion::test_util::arrow_test_data;
2425
use datafusion_common::ScalarValue;

datafusion-examples/examples/simplify_udaf_expression.rs

+18-20
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,21 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow_schema::{Field, Schema};
19-
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20-
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
21-
use datafusion_expr::simplify::SimplifyInfo;
22-
2318
use std::{any::Any, sync::Arc};
2419

20+
use arrow_schema::{Field, Schema};
21+
2522
use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
2623
use datafusion::error::Result;
24+
use datafusion::functions_aggregate::average::avg_udaf;
25+
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
2726
use datafusion::{assert_batches_eq, prelude::*};
2827
use datafusion_common::cast::as_float64_array;
28+
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
29+
use datafusion_expr::simplify::SimplifyInfo;
2930
use datafusion_expr::{
30-
expr::{AggregateFunction, AggregateFunctionDefinition},
31-
function::AccumulatorArgs,
32-
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
31+
expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF,
32+
AggregateUDFImpl, GroupsAccumulator, Signature,
3333
};
3434

3535
/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
@@ -92,18 +92,16 @@ impl AggregateUDFImpl for BetterAvgUdaf {
9292
// with build-in aggregate function to illustrate the use
9393
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
9494
_: &dyn SimplifyInfo| {
95-
Ok(Expr::AggregateFunction(AggregateFunction {
96-
func_def: AggregateFunctionDefinition::BuiltIn(
97-
// yes it is the same Avg, `BetterAvgUdaf` was just a
98-
// marketing pitch :)
99-
datafusion_expr::aggregate_function::AggregateFunction::Avg,
100-
),
101-
args: aggregate_function.args,
102-
distinct: aggregate_function.distinct,
103-
filter: aggregate_function.filter,
104-
order_by: aggregate_function.order_by,
105-
null_treatment: aggregate_function.null_treatment,
106-
}))
95+
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
96+
avg_udaf(),
97+
// yes it is the same Avg, `BetterAvgUdaf` was just a
98+
// marketing pitch :)
99+
aggregate_function.args,
100+
aggregate_function.distinct,
101+
aggregate_function.filter,
102+
aggregate_function.order_by,
103+
aggregate_function.null_treatment,
104+
)))
107105
};
108106

109107
Some(Box::new(simplify))

datafusion-examples/examples/simplify_udwf_expression.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
use std::any::Any;
1919

2020
use arrow_schema::DataType;
21+
2122
use datafusion::execution::context::SessionContext;
23+
use datafusion::functions_aggregate::average::avg_udaf;
2224
use datafusion::{error::Result, execution::options::CsvReadOptions};
2325
use datafusion_expr::function::WindowFunctionSimplification;
2426
use datafusion_expr::{
25-
expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
26-
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
27+
expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
28+
Volatility, WindowUDF, WindowUDFImpl,
2729
};
2830

2931
/// This UDWF will show how to use the WindowUDFImpl::simplify() API
@@ -71,9 +73,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
7173
let simplify = |window_function: datafusion_expr::expr::WindowFunction,
7274
_: &dyn SimplifyInfo| {
7375
Ok(Expr::WindowFunction(WindowFunction {
74-
fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
75-
AggregateFunction::Avg,
76-
),
76+
fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()),
7777
args: window_function.args,
7878
partition_by: window_function.partition_by,
7979
order_by: window_function.order_by,

datafusion/core/src/dataframe/mod.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,11 @@ use datafusion_common::config::{CsvOptions, FormatOptions, JsonOptions};
4848
use datafusion_common::{
4949
plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions,
5050
};
51-
use datafusion_expr::lit;
51+
use datafusion_expr::{case, is_null, lit};
5252
use datafusion_expr::{
53-
avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
54-
UNNAMED_TABLE,
53+
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
5554
};
56-
use datafusion_expr::{case, is_null};
57-
use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};
55+
use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum};
5856

5957
use async_trait::async_trait;
6058

@@ -561,7 +559,7 @@ impl DataFrame {
561559
/// # async fn main() -> Result<()> {
562560
/// let ctx = SessionContext::new();
563561
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?
564-
/// // Return a single row (a, b) for each distinct value of a
562+
/// // Return a single row (a, b) for each distinct value of a
565563
/// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?;
566564
/// # Ok(())
567565
/// # }
@@ -2045,7 +2043,7 @@ mod tests {
20452043

20462044
assert_batches_sorted_eq!(
20472045
["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
2048-
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
2046+
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
20492047
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
20502048
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
20512049
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",

datafusion/core/tests/dataframe/mod.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
5252
use datafusion_expr::expr::{GroupingSet, Sort};
5353
use datafusion_expr::var_provider::{VarProvider, VarType};
5454
use datafusion_expr::{
55-
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
56-
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
57-
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
55+
array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
56+
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
57+
WindowFrameUnits, WindowFunctionDefinition,
5858
};
59-
use datafusion_functions_aggregate::expr_fn::{count, sum};
59+
use datafusion_functions_aggregate::expr_fn::{avg, count, sum};
6060

6161
#[tokio::test]
6262
async fn test_count_wildcard_on_sort() -> Result<()> {

datafusion/core/tests/user_defined/user_defined_aggregates.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ use datafusion_expr::{
4848
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
4949
SimpleAggregateUDF,
5050
};
51-
use datafusion_physical_expr::expressions::AvgAccumulator;
51+
use datafusion_functions_aggregate::average::AvgAccumulator;
52+
5253
/// Test to show the contents of the setup
5354
#[tokio::test]
5455
async fn test_setup() {

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> {
5151
let actual = plan_and_collect(&ctx, sql).await.unwrap();
5252
let expected = [
5353
"+------------------------------------------+",
54-
"| AVG(custom_sqrt(aggregate_test_100.c11)) |",
54+
"| avg(custom_sqrt(aggregate_test_100.c11)) |",
5555
"+------------------------------------------+",
5656
"| 0.6584408483418835 |",
5757
"+------------------------------------------+",
@@ -69,7 +69,7 @@ async fn csv_query_avg_sqrt() -> Result<()> {
6969
let actual = plan_and_collect(&ctx, sql).await.unwrap();
7070
let expected = [
7171
"+------------------------------------------+",
72-
"| AVG(custom_sqrt(aggregate_test_100.c12)) |",
72+
"| avg(custom_sqrt(aggregate_test_100.c12)) |",
7373
"+------------------------------------------+",
7474
"| 0.6706002946036459 |",
7575
"+------------------------------------------+",

datafusion/expr/src/aggregate_function.rs

-22
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ pub enum AggregateFunction {
3737
Min,
3838
/// Maximum
3939
Max,
40-
/// Average
41-
Avg,
4240
/// Aggregation into an array
4341
ArrayAgg,
4442
/// N'th value in a group according to some ordering
@@ -55,7 +53,6 @@ impl AggregateFunction {
5553
match self {
5654
Min => "MIN",
5755
Max => "MAX",
58-
Avg => "AVG",
5956
ArrayAgg => "ARRAY_AGG",
6057
NthValue => "NTH_VALUE",
6158
Correlation => "CORR",
@@ -75,9 +72,7 @@ impl FromStr for AggregateFunction {
7572
fn from_str(name: &str) -> Result<AggregateFunction> {
7673
Ok(match name {
7774
// general
78-
"avg" => AggregateFunction::Avg,
7975
"max" => AggregateFunction::Max,
80-
"mean" => AggregateFunction::Avg,
8176
"min" => AggregateFunction::Min,
8277
"array_agg" => AggregateFunction::ArrayAgg,
8378
"nth_value" => AggregateFunction::NthValue,
@@ -123,7 +118,6 @@ impl AggregateFunction {
123118
AggregateFunction::Correlation => {
124119
correlation_return_type(&coerced_data_types[0])
125120
}
126-
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
127121
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
128122
"item",
129123
coerced_data_types[0].clone(),
@@ -135,19 +129,6 @@ impl AggregateFunction {
135129
}
136130
}
137131

138-
/// Returns the internal sum datatype of the avg aggregate function.
139-
pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
140-
// Note that this function *must* return the same type that the respective physical expression returns
141-
// or the execution panics.
142-
let fun = AggregateFunction::Avg;
143-
let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
144-
&fun,
145-
input_expr_types,
146-
&fun.signature(),
147-
)?;
148-
avg_sum_type(&coerced_data_types[0])
149-
}
150-
151132
impl AggregateFunction {
152133
/// the signatures supported by the function `fun`.
153134
pub fn signature(&self) -> Signature {
@@ -168,9 +149,6 @@ impl AggregateFunction {
168149
.collect::<Vec<_>>();
169150
Signature::uniform(1, valid, Volatility::Immutable)
170151
}
171-
AggregateFunction::Avg => {
172-
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
173-
}
174152
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
175153
AggregateFunction::Correlation => {
176154
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)

datafusion/expr/src/expr.rs

-7
Original file line numberDiff line numberDiff line change
@@ -2280,7 +2280,6 @@ mod test {
22802280
"nth_value",
22812281
"min",
22822282
"max",
2283-
"avg",
22842283
];
22852284
for name in names {
22862285
let fun = find_df_window_func(name).unwrap();
@@ -2309,12 +2308,6 @@ mod test {
23092308
aggregate_function::AggregateFunction::Min
23102309
))
23112310
);
2312-
assert_eq!(
2313-
find_df_window_func("avg"),
2314-
Some(WindowFunctionDefinition::AggregateFunction(
2315-
aggregate_function::AggregateFunction::Avg
2316-
))
2317-
);
23182311
assert_eq!(
23192312
find_df_window_func("cume_dist"),
23202313
Some(WindowFunctionDefinition::BuiltInWindowFunction(

datafusion/expr/src/expr_fn.rs

-12
Original file line numberDiff line numberDiff line change
@@ -183,18 +183,6 @@ pub fn array_agg(expr: Expr) -> Expr {
183183
))
184184
}
185185

186-
/// Create an expression to represent the avg() aggregate function
187-
pub fn avg(expr: Expr) -> Expr {
188-
Expr::AggregateFunction(AggregateFunction::new(
189-
aggregate_function::AggregateFunction::Avg,
190-
vec![expr],
191-
false,
192-
None,
193-
None,
194-
None,
195-
))
196-
}
197-
198186
/// Return a new expression with bitwise AND
199187
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
200188
Expr::BinaryExpr(BinaryExpr::new(

datafusion/expr/src/expr_rewriter/order_by.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ mod test {
156156
use arrow::datatypes::{DataType, Field, Schema};
157157

158158
use crate::{
159-
avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast,
160-
LogicalPlanBuilder,
159+
cast, col, lit, logical_plan::builder::LogicalTableSource, min,
160+
test::function_stub::avg, try_cast, LogicalPlanBuilder,
161161
};
162162

163163
use super::*;
@@ -246,9 +246,9 @@ mod test {
246246
expected: sort(col("c1") + col("MIN(t.c2)")),
247247
},
248248
TestCase {
249-
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
249+
desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
250250
input: sort(avg(col("c3"))),
251-
expected: sort(col("AVG(t.c3)").alias("average")),
251+
expected: sort(col("avg(t.c3)").alias("average")),
252252
},
253253
];
254254

0 commit comments

Comments
 (0)