Skip to content

Commit d50aacf

Browse files
committed
Add to turn any aggregate function into a window function
1 parent 6c8bf5f commit d50aacf

File tree

4 files changed

+112
-34
lines changed

4 files changed

+112
-34
lines changed

python/datafusion/expr.py

+37
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,43 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
542542
"""
543543
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))
544544

545+
def over(
546+
self,
547+
partition_by: Optional[list[Expr]] = None,
548+
window_frame: Optional[WindowFrame] = None,
549+
order_by: Optional[list[SortExpr | Expr]] = None,
550+
null_treatment: Optional[NullTreatment] = None,
551+
) -> Expr:
552+
"""Turn an aggregate function into a window function.
553+
554+
This function turns any aggregate function into a window function. With the
555+
exception of ``partition_by``, how each of the parameters is used is determined
556+
by the underlying aggregate function.
557+
558+
Args:
559+
partition_by: Expressions to partition the window frame on
560+
window_frame: Specify the window frame parameters
561+
order_by: Set ordering within the window frame
562+
null_treatment: Set how to handle null values
563+
"""
564+
partition_by_raw = expr_list_to_raw_expr_list(partition_by)
565+
order_by_raw = sort_list_to_raw_sort_list(order_by)
566+
window_frame_raw = (
567+
window_frame.window_frame if window_frame is not None else None
568+
)
569+
null_treatment_raw = (
570+
null_treatment.value if null_treatment is not None else None
571+
)
572+
573+
return Expr(
574+
self.expr.over(
575+
partition_by=partition_by_raw,
576+
order_by=order_by_raw,
577+
window_frame=window_frame_raw,
578+
null_treatment=null_treatment_raw,
579+
)
580+
)
581+
545582

546583
class ExprFuncBuilder:
547584
def __init__(self, builder: expr_internal.ExprFuncBuilder):

python/datafusion/tests/test_dataframe.py

+13-21
Original file line numberDiff line numberDiff line change
@@ -386,38 +386,30 @@ def test_distinct():
386386
),
387387
[-1, -1, None, 7, -1, -1, None],
388388
),
389-
# TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833
390-
pytest.param(
389+
(
391390
"first_value",
392-
f.window(
393-
"first_value",
394-
[column("a")],
395-
order_by=[f.order_by(column("b"))],
396-
partition_by=[column("c")],
391+
f.first_value(column("a")).over(
392+
partition_by=[column("c")], order_by=[column("b")]
397393
),
398394
[1, 1, 1, 1, 5, 5, 5],
399395
),
400-
pytest.param(
396+
(
401397
"last_value",
402-
f.window("last_value", [column("a")])
403-
.window_frame(WindowFrame("rows", 0, None))
404-
.order_by(column("b"))
405-
.partition_by(column("c"))
406-
.build(),
398+
f.last_value(column("a")).over(
399+
partition_by=[column("c")],
400+
order_by=[column("b")],
401+
window_frame=WindowFrame("rows", None, None),
402+
),
407403
[3, 3, 3, 3, 6, 6, 6],
408404
),
409-
pytest.param(
405+
(
410406
"3rd_value",
411-
f.window(
412-
"nth_value",
413-
[column("b"), literal(3)],
414-
order_by=[f.order_by(column("a"))],
415-
),
407+
f.nth_value(column("b"), 3).over(order_by=[column("a")]),
416408
[None, None, 7, 7, 7, 7, 7],
417409
),
418-
pytest.param(
410+
(
419411
"avg",
420-
f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)),
412+
f.round(f.avg(column("b")).over(order_by=[column("a")]), literal(3)),
421413
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
422414
),
423415
]

src/expr.rs

+43-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use datafusion::logical_expr::utils::exprlist_to_fields;
19-
use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
19+
use datafusion::logical_expr::{
20+
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
21+
};
2022
use pyo3::{basic::CompareOp, prelude::*};
2123
use std::convert::{From, Into};
2224
use std::sync::Arc;
@@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
3941
use crate::expr::binary_expr::PyBinaryExpr;
4042
use crate::expr::column::PyColumn;
4143
use crate::expr::literal::PyLiteral;
44+
use crate::functions::add_builder_fns_to_window;
4245
use crate::sql::logical::PyLogicalPlan;
4346

4447
use self::alias::PyAlias;
@@ -558,6 +561,45 @@ impl PyExpr {
558561
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
559562
self.expr.clone().window_frame(window_frame.into()).into()
560563
}
564+
565+
#[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
566+
pub fn over(
567+
&self,
568+
partition_by: Option<Vec<PyExpr>>,
569+
window_frame: Option<PyWindowFrame>,
570+
order_by: Option<Vec<PySortExpr>>,
571+
null_treatment: Option<NullTreatment>,
572+
) -> PyResult<PyExpr> {
573+
match &self.expr {
574+
Expr::AggregateFunction(agg_fn) => {
575+
let window_fn = Expr::WindowFunction(WindowFunction::new(
576+
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
577+
agg_fn.args.clone(),
578+
));
579+
580+
add_builder_fns_to_window(
581+
window_fn,
582+
partition_by,
583+
window_frame,
584+
order_by,
585+
null_treatment,
586+
)
587+
}
588+
Expr::WindowFunction(_) => add_builder_fns_to_window(
589+
self.expr.clone(),
590+
partition_by,
591+
window_frame,
592+
order_by,
593+
null_treatment,
594+
),
595+
_ => Err(
596+
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
597+
"Using `over` requires an aggregate function.".to_string(),
598+
))
599+
.into(),
600+
),
601+
}
602+
}
561603
}
562604

563605
#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]

src/functions.rs

+19-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::ptr::null;
19+
1820
use datafusion::functions_aggregate::all_default_aggregate_functions;
1921
use datafusion::logical_expr::window_function;
2022
use datafusion::logical_expr::ExprFunctionExt;
@@ -711,14 +713,15 @@ pub fn string_agg(
711713
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
712714
}
713715

714-
fn add_builder_fns_to_window(
716+
pub(crate) fn add_builder_fns_to_window(
715717
window_fn: Expr,
716718
partition_by: Option<Vec<PyExpr>>,
719+
window_frame: Option<PyWindowFrame>,
717720
order_by: Option<Vec<PySortExpr>>,
721+
null_treatment: Option<NullTreatment>,
718722
) -> PyResult<PyExpr> {
719-
// Since ExprFuncBuilder::new() is private, set an empty partition and then
720-
// override later if appropriate.
721-
let mut builder = window_fn.partition_by(vec![]);
723+
let null_treatment = null_treatment.map(|n| n.into());
724+
let mut builder = window_fn.null_treatment(null_treatment);
722725

723726
if let Some(partition_cols) = partition_by {
724727
builder = builder.partition_by(
@@ -734,6 +737,10 @@ fn add_builder_fns_to_window(
734737
builder = builder.order_by(order_by_cols);
735738
}
736739

740+
if let Some(window_frame) = window_frame {
741+
builder = builder.window_frame(window_frame.into());
742+
}
743+
737744
builder.build().map(|e| e.into()).map_err(|err| err.into())
738745
}
739746

@@ -748,7 +755,7 @@ pub fn lead(
748755
) -> PyResult<PyExpr> {
749756
let window_fn = window_function::lead(arg.expr, Some(shift_offset), default_value);
750757

751-
add_builder_fns_to_window(window_fn, partition_by, order_by)
758+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
752759
}
753760

754761
#[pyfunction]
@@ -762,7 +769,7 @@ pub fn lag(
762769
) -> PyResult<PyExpr> {
763770
let window_fn = window_function::lag(arg.expr, Some(shift_offset), default_value);
764771

765-
add_builder_fns_to_window(window_fn, partition_by, order_by)
772+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
766773
}
767774

768775
#[pyfunction]
@@ -773,7 +780,7 @@ pub fn row_number(
773780
) -> PyResult<PyExpr> {
774781
let window_fn = datafusion::functions_window::expr_fn::row_number();
775782

776-
add_builder_fns_to_window(window_fn, partition_by, order_by)
783+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
777784
}
778785

779786
#[pyfunction]
@@ -784,7 +791,7 @@ pub fn rank(
784791
) -> PyResult<PyExpr> {
785792
let window_fn = window_function::rank();
786793

787-
add_builder_fns_to_window(window_fn, partition_by, order_by)
794+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
788795
}
789796

790797
#[pyfunction]
@@ -795,7 +802,7 @@ pub fn dense_rank(
795802
) -> PyResult<PyExpr> {
796803
let window_fn = window_function::dense_rank();
797804

798-
add_builder_fns_to_window(window_fn, partition_by, order_by)
805+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
799806
}
800807

801808
#[pyfunction]
@@ -806,7 +813,7 @@ pub fn percent_rank(
806813
) -> PyResult<PyExpr> {
807814
let window_fn = window_function::percent_rank();
808815

809-
add_builder_fns_to_window(window_fn, partition_by, order_by)
816+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
810817
}
811818

812819
#[pyfunction]
@@ -817,7 +824,7 @@ pub fn cume_dist(
817824
) -> PyResult<PyExpr> {
818825
let window_fn = window_function::cume_dist();
819826

820-
add_builder_fns_to_window(window_fn, partition_by, order_by)
827+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
821828
}
822829

823830
#[pyfunction]
@@ -829,7 +836,7 @@ pub fn ntile(
829836
) -> PyResult<PyExpr> {
830837
let window_fn = window_function::ntile(arg.into());
831838

832-
add_builder_fns_to_window(window_fn, partition_by, order_by)
839+
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
833840
}
834841

835842
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {

0 commit comments

Comments
 (0)