Skip to content

Commit f4a0828

Browse files
migrate approx_percentile_cont, approx_distinct, and approx_median to UDAF
Ref: approx_distinct apache/datafusion#10851 Ref: approx_median apache/datafusion#10840 Ref: approx_percentile_cont and _with_weight apache/datafusion#10917
1 parent 86d1ad0 commit f4a0828

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

python/datafusion/functions.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1233,9 +1233,9 @@ def flatten(array: Expr) -> Expr:
12331233

12341234

12351235
# aggregate functions
1236-
def approx_distinct(arg: Expr) -> Expr:
1236+
def approx_distinct(expression: Expr) -> Expr:
12371237
"""Returns the approximate number of distinct values."""
1238-
return Expr(f.approx_distinct(arg.expr, distinct=True))
1238+
return Expr(f.approx_distinct(expression.expr))
12391239

12401240

12411241
def approx_median(arg: Expr, distinct: bool = False) -> Expr:
@@ -1244,20 +1244,22 @@ def approx_median(arg: Expr, distinct: bool = False) -> Expr:
12441244

12451245

12461246
def approx_percentile_cont(
1247-
expr: Expr,
1247+
expression: Expr,
12481248
percentile: Expr,
1249-
num_centroids: int | None = None,
1249+
# num_centroids: int | None = None,
12501250
distinct: bool = False,
12511251
) -> Expr:
12521252
"""Returns the value that is approximately at a given percentile of ``expr``."""
1253+
# TODO: enable num_centroids
1254+
num_centroids = None
12531255
if num_centroids is None:
12541256
return Expr(
12551257
f.approx_percentile_cont(expr.expr, percentile.expr, distinct=distinct)
12561258
)
12571259

12581260
return Expr(
12591261
f.approx_percentile_cont(
1260-
expr.expr, percentile.expr, num_centroids, distinct=distinct
1262+
expr.expr, percentile.expr, distinct=distinct
12611263
)
12621264
)
12631265

src/functions.rs

+51-7
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,57 @@ use datafusion_expr::{
3737
lit, Expr, WindowFunctionDefinition,
3838
};
3939

40+
#[pyfunction]
41+
pub fn approx_distinct(expression: PyExpr) -> PyExpr {
42+
functions_aggregate::expr_fn::approx_distinct::approx_distinct(expression.expr).into()
43+
}
44+
45+
#[pyfunction]
46+
pub fn approx_median(expression: PyExpr, distinct: bool) -> PyResult<PyExpr> {
47+
// TODO: better builder pattern
48+
let expr = functions_aggregate::expr_fn::approx_median(expression.expr);
49+
if distinct {
50+
Ok(expr.distinct().build()?.into())
51+
} else {
52+
Ok(expr.into())
53+
}
54+
}
55+
56+
#[pyfunction]
57+
pub fn approx_percentile_cont(
58+
expression: PyExpr,
59+
percentile: PyExpr,
60+
distinct: bool,
61+
) -> PyResult<PyExpr> {
62+
// TODO: better builder pattern
63+
let expr =
64+
functions_aggregate::expr_fn::approx_percentile_cont(expression.expr, percentile.expr);
65+
if distinct {
66+
Ok(expr.distinct().build()?.into())
67+
} else {
68+
Ok(expr.into())
69+
}
70+
}
71+
72+
#[pyfunction]
73+
pub fn approx_percentile_cont_with_weight(
74+
expression: PyExpr,
75+
weight: PyExpr,
76+
percentile: PyExpr,
77+
distinct: bool,
78+
) -> PyResult<PyExpr> {
79+
let expr = functions_aggregate::expr_fn::approx_percentile_cont_with_weight(
80+
expression.expr,
81+
weight.expr,
82+
percentile.expr,
83+
);
84+
if distinct {
85+
Ok(expr.distinct().build()?.into())
86+
} else {
87+
Ok(expr.into())
88+
}
89+
}
90+
4091
#[pyfunction]
4192
pub fn sum(args: PyExpr) -> PyExpr {
4293
functions_aggregate::expr_fn::sum(args.expr).into()
@@ -727,13 +778,6 @@ array_fn!(list_resize, array_resize, array size value);
727778
array_fn!(flatten, array);
728779
array_fn!(range, start stop step);
729780

730-
aggregate_function!(approx_distinct, ApproxDistinct);
731-
aggregate_function!(approx_median, ApproxMedian);
732-
aggregate_function!(approx_percentile_cont, ApproxPercentileCont);
733-
aggregate_function!(
734-
approx_percentile_cont_with_weight,
735-
ApproxPercentileContWithWeight
736-
);
737781
aggregate_function!(array_agg, ArrayAgg);
738782
aggregate_function!(avg, Avg);
739783
aggregate_function!(corr, Correlation);

0 commit comments

Comments
 (0)