Skip to content

Commit

Permalink
Add python wrappers for UDWF
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Sep 21, 2024
1 parent 0c48f4f commit 7232b4e
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 12 deletions.
5 changes: 4 additions & 1 deletion python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from .record_batch import RecordBatchStream, RecordBatch

from .udf import ScalarUDF, AggregateUDF, Accumulator
from .udf import ScalarUDF, AggregateUDF, Accumulator, WindowUDF

from .common import (
DFSchema,
Expand Down Expand Up @@ -78,6 +78,7 @@
"Database",
"Table",
"AggregateUDF",
"WindowUDF",
"LogicalPlan",
"ExecutionPlan",
"RecordBatch",
Expand Down Expand Up @@ -113,3 +114,5 @@ def lit(value):
udf = ScalarUDF.udf

udaf = AggregateUDF.udaf

udwf = WindowUDF.udwf
7 changes: 5 additions & 2 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@
from ._internal import SessionContext as SessionContextInternal
from ._internal import LogicalPlan, ExecutionPlan

from datafusion._internal import AggregateUDF
from datafusion.catalog import Catalog, Table
from datafusion.dataframe import DataFrame
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
from datafusion.record_batch import RecordBatchStream
from datafusion.udf import ScalarUDF
from datafusion.udf import ScalarUDF, AggregateUDF, WindowUDF

from typing import Any, TYPE_CHECKING
from typing_extensions import deprecated
Expand Down Expand Up @@ -833,6 +832,10 @@ def register_udaf(self, udaf: AggregateUDF) -> None:
"""Register a user-defined aggregation function (UDAF) with the context."""
self.ctx.register_udaf(udaf._udaf)

def register_udwf(self, udwf: WindowUDF) -> None:
"""Register a user-defined window function (UDWF) with the context."""
self.ctx.register_udwf(udwf._udwf)

def catalog(self, name: str = "datafusion") -> Catalog:
"""Retrieve a catalog by name."""
return self.ctx.catalog(name)
Expand Down
215 changes: 215 additions & 0 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,218 @@ def udaf(
state_type=state_type,
volatility=volatility,
)


class WindowEvaluator(metaclass=ABCMeta):
"""Evaluator class for user defined window functions (UDWF).
Users should inherit from this class and implement ``evaluate``, ``evaluate_all``,
and/or ``evaluate_all_with_rank``. If using `evaluate` only you will need to
override ``supports_bounded_execution``.
"""

def memoize(self) -> None:
"""Perform a memoize operation to improve performance.
When the window frame has a fixed beginning (e.g UNBOUNDED
PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
NTH_VALUE do not need the (unbounded) input once they have
seen a certain amount of input.
`memoize` is called after each input batch is processed, and
such functions can save whatever they need
"""
pass

def get_range(self, idx: int, n_rows: int) -> tuple[int, int]:
"""Return the range for the window fuction.
If `uses_window_frame` flag is `false`. This method is used to
calculate required range for the window function during
stateful execution.
Generally there is no required range, hence by default this
returns smallest range(current row). e.g seeing current row is
enough to calculate window result (such as row_number, rank,
etc)
Args:
idx:: Current index
n_rows: Number of rows.
"""
return (idx, idx + 1)

def is_causal(self) -> bool:
"""Get whether evaluator needs future data for its result."""
return False

def evaluate_all(self, values: pyarrow.Array, num_rows: int) -> pyarrow.Array:
"""Evaluate a window function on an entire input partition.
This function is called once per input *partition* for window
functions that *do not use* values from the window frame,
such as `ROW_NUMBER`, `RANK`, `DENSE_RANK`, `PERCENT_RANK`,
`CUME_DIST`, `LEAD`, `LAG`).
It produces the result of all rows in a single pass. It
expects to receive the entire partition as the `value` and
must produce an output column with one output row for every
input row.
`num_rows` is required to correctly compute the output in case
`values.len() == 0`
Implementing this function is an optimization: certain window
functions are not affected by the window frame definition or
the query doesn't have a frame, and `evaluate` skips the
(costly) window frame boundary calculation and the overhead of
calling `evaluate` for each output row.
For example, the `LAG` built in window function does not use
the values of its window frame (it can be computed in one shot
on the entire partition with `Self::evaluate_all` regardless of the
window defined in the `OVER` clause)
```sql
lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
```
However, `avg()` computes the average in the window and thus
does use its window frame
```sql
avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
```
"""
if self.supports_bounded_execution() and not self.uses_window_frame():
res = []
for idx in range(0, num_rows):
res.append(self.evaluate(values, self.get_range(idx, num_rows)))
return pyarrow.array(res)
else:
raise

@abstractmethod
def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Scalar:
"""Evaluate window function on a range of rows in an input partition.
This is the simplest and most general function to implement
but also the least performant as it creates output one row at
a time. It is typically much faster to implement stateful
evaluation using one of the other specialized methods on this
trait.
Returns a [`ScalarValue`] that is the value of the window
function within `range` for the entire partition. Argument
`values` contains the evaluation result of function arguments
and evaluation results of ORDER BY expressions. If function has a
single argument, `values[1..]` will contain ORDER BY expression results.
"""
pass

@abstractmethod
def evaluate_all_with_rank(
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
) -> pyarrow.Array:
"""Called for window functions that only need the rank of a row.
Evaluate the partition evaluator against the partition using
the row ranks. For example, `RANK(col)` produces
```text
col | rank
--- + ----
A | 1
A | 1
C | 3
D | 4
D | 5
```
For this case, `num_rows` would be `5` and the
`ranks_in_partition` would be called with
```text
[
(0,1),
(2,2),
(3,4),
]
"""
pass

def supports_bounded_execution(self) -> bool:
"""Can the window function be incrementally computed using bounded memory?"""
return False

def uses_window_frame(self) -> bool:
"""Does the window function use the values from the window frame?"""
return False

def include_rank(self) -> bool:
"""Can this function be evaluated with (only) rank?"""
return False


class WindowUDF:
"""Class for performing window user defined functions (UDF).
Window UDFs operate on a partition of rows. See
also :py:class:`ScalarUDF` for operating on a row by row basis.
"""

def __init__(
self,
name: str | None,
func: WindowEvaluator,
input_type: pyarrow.DataType,
return_type: _R,
volatility: Volatility | str,
) -> None:
"""Instantiate a user defined window function (UDWF).
See :py:func:`udwf` for a convenience function and argument
descriptions.
"""
self._udwf = df_internal.WindowUDF(
name, func, input_type, return_type, str(volatility)
)

def __call__(self, *args: Expr) -> Expr:
"""Execute the UDWF.
This function is not typically called by an end user. These calls will
occur during the evaluation of the dataframe.
"""
args_raw = [arg.expr for arg in args]
return Expr(self._udwf.__call__(*args_raw))

@staticmethod
def udwf(
func: Callable[..., _R],
input_type: pyarrow.DataType,
return_type: _R,
volatility: Volatility | str,
name: str | None = None,
) -> WindowUDF:
"""Create a new User Defined Window Function.
Args:
func: The python function.
input_type: The data type of the arguments to ``func``.
return_type: The data type of the return value.
volatility: See :py:class:`Volatility` for allowed values.
name: A descriptive name for the function.
Returns:
A user defined window function.
"""
if name is None:
name = func.__qualname__.lower()
return WindowUDF(
name=name,
func=func,
input_type=input_type,
return_type=return_type,
volatility=volatility,
)
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<dataframe::PyDataFrame>()?;
m.add_class::<udf::PyScalarUDF>()?;
m.add_class::<udaf::PyAggregateUDF>()?;
m.add_class::<udwf::PyWindowUDF>()?;
m.add_class::<config::PyConfig>()?;
m.add_class::<sql::logical::PyLogicalPlan>()?;
m.add_class::<physical_plan::PyExecutionPlan>()?;
Expand Down
17 changes: 8 additions & 9 deletions src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,27 @@ impl PartitionEvaluator for RustPartitionEvaluator {
.bind(py)
.call_method0("is_causal")
.and_then(|v| v.extract())
.unwrap_or(false)
})
.unwrap_or(false)
}

fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
Python::with_gil(|py| {
// 1. cast args to Pyarrow array
let mut py_args = values
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
py_args.push(num_rows.to_object(py));
let py_args = PyTuple::new_bound(py, py_args);

// 2. call function
self.evaluator
.bind(py)
.call_method1("evaluate_all", py_args)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
.map(|v| {
let array_data = ArrayData::from_pyarrow_bound(&v).unwrap();
make_array(array_data)
})
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})
}

Expand All @@ -116,8 +114,9 @@ impl PartitionEvaluator for RustPartitionEvaluator {
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
py_args.push(range.start.to_object(py));
py_args.push(range.end.to_object(py));
let range_tuple =
PyTuple::new_bound(py, vec![range.start.to_object(py), range.end.to_object(py)]);
py_args.push(range_tuple.into());
let py_args = PyTuple::new_bound(py, py_args);

// 2. call function
Expand Down Expand Up @@ -162,8 +161,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
.bind(py)
.call_method0("supports_bounded_execution")
.and_then(|v| v.extract())
.unwrap_or(false)
})
.unwrap_or(false)
}

fn uses_window_frame(&self) -> bool {
Expand All @@ -172,8 +171,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
.bind(py)
.call_method0("uses_window_frame")
.and_then(|v| v.extract())
.unwrap_or(false)
})
.unwrap_or(false)
}

fn include_rank(&self) -> bool {
Expand All @@ -182,8 +181,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
.bind(py)
.call_method0("include_rank")
.and_then(|v| v.extract())
.unwrap_or(false)
})
.unwrap_or(false)
}
}

Expand Down

0 comments on commit 7232b4e

Please sign in to comment.