Skip to content

Commit d163c63

Browse files
migrate grouping to UDAF
Ref: apache/datafusion#10906
1 parent 70069d7 commit d163c63

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

python/datafusion/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,7 @@ def grouping(arg: Expr, distinct: bool = False) -> Expr:
13261326
13271327
Returns 1 if the value of the argument is aggregated, 0 if not.
13281328
"""
1329-
return Expr(f.grouping([arg.expr], distinct=distinct))
1329+
return Expr(f.grouping(arg.expr, distinct=distinct))
13301330

13311331

13321332
def max(arg: Expr, distinct: bool = False) -> Expr:

src/functions.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ pub fn corr(y: PyExpr, x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
108108
}
109109
}
110110

111+
#[pyfunction]
112+
pub fn grouping(expression: PyExpr, distinct: bool) -> PyResult<PyExpr> {
113+
let expr = functions_aggregate::expr_fn::grouping(expression.expr);
114+
if distinct {
115+
Ok(expr.distinct().build()?.into())
116+
} else {
117+
Ok(expr.into())
118+
}
119+
}
120+
111121
#[pyfunction]
112122
pub fn sum(args: PyExpr) -> PyExpr {
113123
functions_aggregate::expr_fn::sum(args.expr).into()
@@ -799,7 +809,6 @@ array_fn!(flatten, array);
799809
array_fn!(range, start stop step);
800810

801811
aggregate_function!(array_agg, ArrayAgg);
802-
aggregate_function!(grouping, Grouping);
803812
aggregate_function!(max, Max);
804813
aggregate_function!(mean, Avg);
805814
aggregate_function!(min, Min);

0 commit comments

Comments
 (0)