Skip to content

Commit 7232b4e

Browse files
committed
Add python wrappers for UDWF
1 parent 0c48f4f commit 7232b4e

File tree

5 files changed

+233
-12
lines changed

5 files changed

+233
-12
lines changed

python/datafusion/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
from .record_batch import RecordBatchStream, RecordBatch
4242

43-
from .udf import ScalarUDF, AggregateUDF, Accumulator
43+
from .udf import ScalarUDF, AggregateUDF, Accumulator, WindowUDF
4444

4545
from .common import (
4646
DFSchema,
@@ -78,6 +78,7 @@
7878
"Database",
7979
"Table",
8080
"AggregateUDF",
81+
"WindowUDF",
8182
"LogicalPlan",
8283
"ExecutionPlan",
8384
"RecordBatch",
@@ -113,3 +114,5 @@ def lit(value):
113114
udf = ScalarUDF.udf
114115

115116
udaf = AggregateUDF.udaf
117+
118+
udwf = WindowUDF.udwf

python/datafusion/context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525
from ._internal import SessionContext as SessionContextInternal
2626
from ._internal import LogicalPlan, ExecutionPlan
2727

28-
from datafusion._internal import AggregateUDF
2928
from datafusion.catalog import Catalog, Table
3029
from datafusion.dataframe import DataFrame
3130
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
3231
from datafusion.record_batch import RecordBatchStream
33-
from datafusion.udf import ScalarUDF
32+
from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF
3433

3534
from typing import Any, TYPE_CHECKING
3635
from typing_extensions import deprecated
@@ -833,6 +832,10 @@ def register_udaf(self, udaf: AggregateUDF) -> None:
833832
"""Register a user-defined aggregation function (UDAF) with the context."""
834833
self.ctx.register_udaf(udaf._udaf)
835834

835+
def register_udwf(self, udwf: WindowUDF) -> None:
    """Register a user-defined window function (UDWF) with the context.

    Args:
        udwf: Wrapper whose underlying ``_udwf`` object is passed on to the
            internal context's ``register_udwf``.
    """
    register = self.ctx.register_udwf
    register(udwf._udwf)
838+
836839
def catalog(self, name: str = "datafusion") -> Catalog:
837840
"""Retrieve a catalog by name."""
838841
return self.ctx.catalog(name)

python/datafusion/udf.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,218 @@ def udaf(
246246
state_type=state_type,
247247
volatility=volatility,
248248
)
249+
250+
251+
class WindowEvaluator(metaclass=ABCMeta):
    """Evaluator class for user defined window functions (UDWF).

    Users should inherit from this class and implement ``evaluate``,
    ``evaluate_all``, and/or ``evaluate_all_with_rank``. If using ``evaluate``
    only you will need to override ``supports_bounded_execution``.

    Note that ``evaluate`` and ``evaluate_all_with_rank`` are declared
    abstract, so a concrete subclass must define both of them, even if one is
    only a stub.
    """

    def memoize(self) -> None:
        """Perform a memoize operation to improve performance.

        When the window frame has a fixed beginning (e.g. UNBOUNDED
        PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
        NTH_VALUE do not need the (unbounded) input once they have
        seen a certain amount of input.

        ``memoize`` is called after each input batch is processed, and
        such functions can save whatever they need. The default
        implementation does nothing.
        """

    def get_range(self, idx: int, n_rows: int) -> tuple[int, int]:
        """Return the range for the window function.

        If the ``uses_window_frame`` flag is ``False``, this method is used
        to calculate the required range for the window function during
        stateful execution.

        Generally there is no required range, hence by default this
        returns the smallest range (the current row only), e.g. seeing the
        current row is enough to calculate the window result (such as
        row_number, rank, etc).

        Args:
            idx: Current index.
            n_rows: Number of rows. Unused by the default implementation.

        Returns:
            The half-open range ``(idx, idx + 1)``.
        """
        return (idx, idx + 1)

    def is_causal(self) -> bool:
        """Return whether the evaluator is causal.

        A causal evaluator can produce its result without seeing future
        data. The default is ``False``.
        """
        return False

    def evaluate_all(self, values: pyarrow.Array, num_rows: int) -> pyarrow.Array:
        """Evaluate a window function on an entire input partition.

        This function is called once per input *partition* for window
        functions that *do not use* values from the window frame,
        such as ``ROW_NUMBER``, ``RANK``, ``DENSE_RANK``, ``PERCENT_RANK``,
        ``CUME_DIST``, ``LEAD``, ``LAG``.

        It produces the result of all rows in a single pass. It
        expects to receive the entire partition as ``values`` and
        must produce an output column with one output row for every
        input row.

        ``num_rows`` is required to correctly compute the output in case
        ``len(values) == 0``.

        Implementing this function is an optimization: certain window
        functions are not affected by the window frame definition or
        the query doesn't have a frame, and ``evaluate_all`` skips the
        (costly) window frame boundary calculation and the overhead of
        calling ``evaluate`` for each output row.

        For example, the ``LAG`` built in window function does not use
        the values of its window frame (it can be computed in one shot
        on the entire partition with ``evaluate_all`` regardless of the
        window defined in the ``OVER`` clause)

        ```sql
        lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
        ```

        However, ``avg()`` computes the average in the window and thus
        does use its window frame

        ```sql
        avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
        ```

        The default implementation only handles evaluators for which
        ``supports_bounded_execution()`` is true and ``uses_window_frame()``
        is false: it calls ``evaluate`` once per row, using the range
        returned by ``get_range``, and collects the results into an array.

        Raises:
            NotImplementedError: If the evaluator does not meet the
                conditions above and has not overridden this method.
        """
        if not self.supports_bounded_execution() or self.uses_window_frame():
            # The original code used a bare ``raise`` here, which fails with
            # an unrelated "No active exception to re-raise" RuntimeError.
            # NotImplementedError is a RuntimeError subclass, so existing
            # handlers keep working while the message becomes actionable.
            raise NotImplementedError(
                "evaluate_all must be overridden unless "
                "supports_bounded_execution() returns True and "
                "uses_window_frame() returns False"
            )
        return pyarrow.array(
            [
                self.evaluate(values, self.get_range(idx, num_rows))
                for idx in range(num_rows)
            ]
        )

    @abstractmethod
    def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Scalar:
        """Evaluate window function on a range of rows in an input partition.

        This is the simplest and most general function to implement
        but also the least performant as it creates output one row at
        a time. It is typically much faster to implement stateful
        evaluation using one of the other specialized methods on this
        class.

        Returns a scalar that is the value of the window function
        within ``range`` for the entire partition. Argument
        ``values`` contains the evaluation result of function arguments
        and evaluation results of ORDER BY expressions. If the function has
        a single argument, ``values[1:]`` will contain ORDER BY expression
        results.

        Args:
            values: Evaluated function arguments and ORDER BY expressions.
            range: ``(start, end)`` row range to evaluate.
        """

    @abstractmethod
    def evaluate_all_with_rank(
        self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
    ) -> pyarrow.Array:
        """Called for window functions that only need the rank of a row.

        Evaluate the partition evaluator against the partition using
        the row ranks. For example, ``RANK(col)`` produces

        ```text
        col | rank
        --- + ----
        A   | 1
        A   | 1
        C   | 3
        D   | 4
        D   | 4
        ```

        For this case, ``num_rows`` would be ``5`` and the
        ``ranks_in_partition`` would be called with

        ```text
        [
            (0,1),
            (2,2),
            (3,4),
        ]
        ```

        Args:
            num_rows: Number of rows in the partition.
            ranks_in_partition: Row index ranges, one per group of peer rows.
        """

    def supports_bounded_execution(self) -> bool:
        """Can the window function be incrementally computed using bounded memory?"""
        return False

    def uses_window_frame(self) -> bool:
        """Does the window function use the values from the window frame?"""
        return False

    def include_rank(self) -> bool:
        """Can this function be evaluated with (only) rank?"""
        return False
400+
401+
402+
class WindowUDF:
    """Class for performing window user defined functions (UDF).

    Window UDFs operate on a partition of rows. See
    also :py:class:`ScalarUDF` for operating on a row by row basis.
    """

    def __init__(
        self,
        name: str | None,
        func: WindowEvaluator,
        input_type: pyarrow.DataType,
        return_type: _R,
        volatility: Volatility | str,
    ) -> None:
        """Instantiate a user defined window function (UDWF).

        See :py:func:`udwf` for a convenience function and argument
        descriptions.
        """
        # The volatility is handed to the internal class as a string.
        self._udwf = df_internal.WindowUDF(
            name, func, input_type, return_type, str(volatility)
        )

    def __call__(self, *args: Expr) -> Expr:
        """Execute the UDWF.

        This function is not typically called by an end user. These calls will
        occur during the evaluation of the dataframe.
        """
        # Unwrap each expression to its ``expr`` attribute before calling
        # the internal UDWF, then wrap the result back up.
        args_raw = [arg.expr for arg in args]
        return Expr(self._udwf.__call__(*args_raw))

    @staticmethod
    def _default_name(func: object) -> str:
        """Derive a lower-cased default UDWF name from ``func``.

        Classes and functions carry ``__qualname__`` themselves; instances
        (such as an instantiated evaluator) do not, so fall back to the
        qualified name of their class.
        """
        qualname = getattr(func, "__qualname__", None)
        if qualname is None:
            qualname = type(func).__qualname__
        return qualname.lower()

    @staticmethod
    def udwf(
        func: WindowEvaluator,
        input_type: pyarrow.DataType,
        return_type: _R,
        volatility: Volatility | str,
        name: str | None = None,
    ) -> WindowUDF:
        """Create a new User Defined Window Function.

        Args:
            func: The evaluator implementing the window function.
            input_type: The data type of the arguments to ``func``.
            return_type: The data type of the return value.
            volatility: See :py:class:`Volatility` for allowed values.
            name: A descriptive name for the function. When omitted, the
                lower-cased qualified name of ``func`` (or of its class, if
                ``func`` is an instance) is used.

        Returns:
            A user defined window function.
        """
        if name is None:
            # ``func.__qualname__`` alone raises AttributeError for evaluator
            # instances, so the helper falls back to the class name.
            name = WindowUDF._default_name(func)
        return WindowUDF(
            name=name,
            func=func,
            input_type=input_type,
            return_type=return_type,
            volatility=volatility,
        )

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
9191
m.add_class::<dataframe::PyDataFrame>()?;
9292
m.add_class::<udf::PyScalarUDF>()?;
9393
m.add_class::<udaf::PyAggregateUDF>()?;
94+
m.add_class::<udwf::PyWindowUDF>()?;
9495
m.add_class::<config::PyConfig>()?;
9596
m.add_class::<sql::logical::PyLogicalPlan>()?;
9697
m.add_class::<physical_plan::PyExecutionPlan>()?;

src/udwf.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,27 @@ impl PartitionEvaluator for RustPartitionEvaluator {
8383
.bind(py)
8484
.call_method0("is_causal")
8585
.and_then(|v| v.extract())
86+
.unwrap_or(false)
8687
})
87-
.unwrap_or(false)
8888
}
8989

9090
fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
9191
Python::with_gil(|py| {
92-
// 1. cast args to Pyarrow array
9392
let mut py_args = values
9493
.iter()
9594
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
9695
.collect::<Vec<_>>();
9796
py_args.push(num_rows.to_object(py));
9897
let py_args = PyTuple::new_bound(py, py_args);
9998

100-
// 2. call function
10199
self.evaluator
102100
.bind(py)
103101
.call_method1("evaluate_all", py_args)
104-
.map_err(|e| DataFusionError::Execution(format!("{e}")))
105102
.map(|v| {
106103
let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
107104
make_array(array_data)
108105
})
106+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
109107
})
110108
}
111109

@@ -116,8 +114,9 @@ impl PartitionEvaluator for RustPartitionEvaluator {
116114
.iter()
117115
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
118116
.collect::<Vec<_>>();
119-
py_args.push(range.start.to_object(py));
120-
py_args.push(range.end.to_object(py));
117+
let range_tuple =
118+
PyTuple::new_bound(py, vec![range.start.to_object(py), range.end.to_object(py)]);
119+
py_args.push(range_tuple.into());
121120
let py_args = PyTuple::new_bound(py, py_args);
122121

123122
// 2. call function
@@ -162,8 +161,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
162161
.bind(py)
163162
.call_method0("supports_bounded_execution")
164163
.and_then(|v| v.extract())
164+
.unwrap_or(false)
165165
})
166-
.unwrap_or(false)
167166
}
168167

169168
fn uses_window_frame(&self) -> bool {
@@ -172,8 +171,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
172171
.bind(py)
173172
.call_method0("uses_window_frame")
174173
.and_then(|v| v.extract())
174+
.unwrap_or(false)
175175
})
176-
.unwrap_or(false)
177176
}
178177

179178
fn include_rank(&self) -> bool {
@@ -182,8 +181,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
182181
.bind(py)
183182
.call_method0("include_rank")
184183
.and_then(|v| v.extract())
184+
.unwrap_or(false)
185185
})
186-
.unwrap_or(false)
187186
}
188187
}
189188

0 commit comments

Comments
 (0)