Skip to content

Commit a00cfbf

Browse files
authored
feat: aggregates as windows (#871)
* Add to turn any aggregate function into a window function * Rename Window to WindowExpr so we can define Window to mean a window definition to be reused * Add unit test to cover default frames * Improve error report
1 parent 6c8bf5f commit a00cfbf

File tree

6 files changed

+183
-48
lines changed

6 files changed

+183
-48
lines changed

python/datafusion/expr.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
Union = expr_internal.Union
9393
Unnest = expr_internal.Unnest
9494
UnnestExpr = expr_internal.UnnestExpr
95-
Window = expr_internal.Window
95+
WindowExpr = expr_internal.WindowExpr
9696

9797
__all__ = [
9898
"Expr",
@@ -154,6 +154,7 @@
154154
"Partitioning",
155155
"Repartition",
156156
"Window",
157+
"WindowExpr",
157158
"WindowFrame",
158159
"WindowFrameBound",
159160
]
@@ -542,6 +543,36 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
542543
"""
543544
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))
544545

546+
def over(self, window: Window) -> Expr:
547+
"""Turn an aggregate function into a window function.
548+
549+
This function turns any aggregate function into a window function. With the
550+
exception of ``partition_by``, how each of the parameters is used is determined
551+
by the underlying aggregate function.
552+
553+
Args:
554+
window: Window definition
555+
"""
556+
partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
557+
order_by_raw = sort_list_to_raw_sort_list(window._order_by)
558+
window_frame_raw = (
559+
window._window_frame.window_frame
560+
if window._window_frame is not None
561+
else None
562+
)
563+
null_treatment_raw = (
564+
window._null_treatment.value if window._null_treatment is not None else None
565+
)
566+
567+
return Expr(
568+
self.expr.over(
569+
partition_by=partition_by_raw,
570+
order_by=order_by_raw,
571+
window_frame=window_frame_raw,
572+
null_treatment=null_treatment_raw,
573+
)
574+
)
575+
545576

546577
class ExprFuncBuilder:
547578
def __init__(self, builder: expr_internal.ExprFuncBuilder):
@@ -584,6 +615,30 @@ def build(self) -> Expr:
584615
return Expr(self.builder.build())
585616

586617

618+
class Window:
619+
"""Define reusable window parameters."""
620+
621+
def __init__(
622+
self,
623+
partition_by: Optional[list[Expr]] = None,
624+
window_frame: Optional[WindowFrame] = None,
625+
order_by: Optional[list[SortExpr | Expr]] = None,
626+
null_treatment: Optional[NullTreatment] = None,
627+
) -> None:
628+
"""Construct a window definition.
629+
630+
Args:
631+
partition_by: Partitions for window operation
632+
window_frame: Define the start and end bounds of the window frame
633+
order_by: Set ordering
634+
null_treatment: Indicate how nulls are to be treated
635+
"""
636+
self._partition_by = partition_by
637+
self._window_frame = window_frame
638+
self._order_by = order_by
639+
self._null_treatment = null_treatment
640+
641+
587642
class WindowFrame:
588643
"""Defines a window frame for performing window operations."""
589644

python/datafusion/tests/test_dataframe.py

+54-21
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
literal,
3232
udf,
3333
)
34+
from datafusion.expr import Window
3435

3536

3637
@pytest.fixture
@@ -386,38 +387,32 @@ def test_distinct():
386387
),
387388
[-1, -1, None, 7, -1, -1, None],
388389
),
389-
# TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833
390-
pytest.param(
390+
(
391391
"first_value",
392-
f.window(
393-
"first_value",
394-
[column("a")],
395-
order_by=[f.order_by(column("b"))],
396-
partition_by=[column("c")],
392+
f.first_value(column("a")).over(
393+
Window(partition_by=[column("c")], order_by=[column("b")])
397394
),
398395
[1, 1, 1, 1, 5, 5, 5],
399396
),
400-
pytest.param(
397+
(
401398
"last_value",
402-
f.window("last_value", [column("a")])
403-
.window_frame(WindowFrame("rows", 0, None))
404-
.order_by(column("b"))
405-
.partition_by(column("c"))
406-
.build(),
399+
f.last_value(column("a")).over(
400+
Window(
401+
partition_by=[column("c")],
402+
order_by=[column("b")],
403+
window_frame=WindowFrame("rows", None, None),
404+
)
405+
),
407406
[3, 3, 3, 3, 6, 6, 6],
408407
),
409-
pytest.param(
408+
(
410409
"3rd_value",
411-
f.window(
412-
"nth_value",
413-
[column("b"), literal(3)],
414-
order_by=[f.order_by(column("a"))],
415-
),
410+
f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
416411
[None, None, 7, 7, 7, 7, 7],
417412
),
418-
pytest.param(
413+
(
419414
"avg",
420-
f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)),
415+
f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)),
421416
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
422417
),
423418
]
@@ -473,6 +468,44 @@ def test_invalid_window_frame(units, start_bound, end_bound):
473468
WindowFrame(units, start_bound, end_bound)
474469

475470

471+
def test_window_frame_defaults_match_postgres(partitioned_df):
472+
# ref: https://github.com/apache/datafusion-python/issues/688
473+
474+
window_frame = WindowFrame("rows", None, None)
475+
476+
col_a = column("a")
477+
478+
# Using `f.window` with or without an unbounded window_frame produces the same
479+
# results. These tests are included as a regression check but can be removed when
480+
# f.window() is deprecated in favor of using the .over() approach.
481+
no_frame = f.window("avg", [col_a]).alias("no_frame")
482+
with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame")
483+
df_1 = partitioned_df.select(col_a, no_frame, with_frame)
484+
485+
expected = {
486+
"a": [0, 1, 2, 3, 4, 5, 6],
487+
"no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
488+
"with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
489+
}
490+
491+
assert df_1.sort(col_a).to_pydict() == expected
492+
493+
# When order is not set, the default frame should be unounded preceeding to
494+
# unbounded following. When order is set, the default frame is unbounded preceeding
495+
# to current row.
496+
no_order = f.avg(col_a).over(Window()).alias("over_no_order")
497+
with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order")
498+
df_2 = partitioned_df.select(col_a, no_order, with_order)
499+
500+
expected = {
501+
"a": [0, 1, 2, 3, 4, 5, 6],
502+
"over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
503+
"over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
504+
}
505+
506+
assert df_2.sort(col_a).to_pydict() == expected
507+
508+
476509
def test_get_dataframe(tmp_path):
477510
ctx = SessionContext()
478511

src/expr.rs

+44-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use datafusion::logical_expr::utils::exprlist_to_fields;
19-
use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
19+
use datafusion::logical_expr::{
20+
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
21+
};
2022
use pyo3::{basic::CompareOp, prelude::*};
2123
use std::convert::{From, Into};
2224
use std::sync::Arc;
@@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
3941
use crate::expr::binary_expr::PyBinaryExpr;
4042
use crate::expr::column::PyColumn;
4143
use crate::expr::literal::PyLiteral;
44+
use crate::functions::add_builder_fns_to_window;
4245
use crate::sql::logical::PyLogicalPlan;
4346

4447
use self::alias::PyAlias;
@@ -558,6 +561,45 @@ impl PyExpr {
558561
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
559562
self.expr.clone().window_frame(window_frame.into()).into()
560563
}
564+
565+
#[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
566+
pub fn over(
567+
&self,
568+
partition_by: Option<Vec<PyExpr>>,
569+
window_frame: Option<PyWindowFrame>,
570+
order_by: Option<Vec<PySortExpr>>,
571+
null_treatment: Option<NullTreatment>,
572+
) -> PyResult<PyExpr> {
573+
match &self.expr {
574+
Expr::AggregateFunction(agg_fn) => {
575+
let window_fn = Expr::WindowFunction(WindowFunction::new(
576+
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
577+
agg_fn.args.clone(),
578+
));
579+
580+
add_builder_fns_to_window(
581+
window_fn,
582+
partition_by,
583+
window_frame,
584+
order_by,
585+
null_treatment,
586+
)
587+
}
588+
Expr::WindowFunction(_) => add_builder_fns_to_window(
589+
self.expr.clone(),
590+
partition_by,
591+
window_frame,
592+
order_by,
593+
null_treatment,
594+
),
595+
_ => Err(
596+
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
597+
format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()),
598+
))
599+
.into(),
600+
),
601+
}
602+
}
561603
}
562604

563605
#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
@@ -749,7 +791,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
749791
m.add_class::<drop_table::PyDropTable>()?;
750792
m.add_class::<repartition::PyPartitioning>()?;
751793
m.add_class::<repartition::PyRepartition>()?;
752-
m.add_class::<window::PyWindow>()?;
794+
m.add_class::<window::PyWindowExpr>()?;
753795
m.add_class::<window::PyWindowFrame>()?;
754796
m.add_class::<window::PyWindowFrameBound>()?;
755797
Ok(())

src/expr/window.rs

+10-10
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ use super::py_expr_list;
3232

3333
use crate::errors::py_datafusion_err;
3434

35-
#[pyclass(name = "Window", module = "datafusion.expr", subclass)]
35+
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
3636
#[derive(Clone)]
37-
pub struct PyWindow {
37+
pub struct PyWindowExpr {
3838
window: Window,
3939
}
4040

@@ -62,15 +62,15 @@ pub struct PyWindowFrameBound {
6262
frame_bound: WindowFrameBound,
6363
}
6464

65-
impl From<PyWindow> for Window {
66-
fn from(window: PyWindow) -> Window {
65+
impl From<PyWindowExpr> for Window {
66+
fn from(window: PyWindowExpr) -> Window {
6767
window.window
6868
}
6969
}
7070

71-
impl From<Window> for PyWindow {
72-
fn from(window: Window) -> PyWindow {
73-
PyWindow { window }
71+
impl From<Window> for PyWindowExpr {
72+
fn from(window: Window) -> PyWindowExpr {
73+
PyWindowExpr { window }
7474
}
7575
}
7676

@@ -80,7 +80,7 @@ impl From<WindowFrameBound> for PyWindowFrameBound {
8080
}
8181
}
8282

83-
impl Display for PyWindow {
83+
impl Display for PyWindowExpr {
8484
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
8585
write!(
8686
f,
@@ -103,7 +103,7 @@ impl Display for PyWindowFrame {
103103
}
104104

105105
#[pymethods]
106-
impl PyWindow {
106+
impl PyWindowExpr {
107107
/// Returns the schema of the Window
108108
pub fn schema(&self) -> PyResult<PyDFSchema> {
109109
Ok(self.window.schema.as_ref().clone().into())
@@ -283,7 +283,7 @@ impl PyWindowFrameBound {
283283
}
284284
}
285285

286-
impl LogicalNode for PyWindow {
286+
impl LogicalNode for PyWindowExpr {
287287
fn inputs(&self) -> Vec<PyLogicalPlan> {
288288
vec![self.window.input.as_ref().clone().into()]
289289
}

0 commit comments

Comments
 (0)