Skip to content

Commit

Permalink
adding unit tests for user defined window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Sep 21, 2024
1 parent 7232b4e commit 8663e77
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 18 deletions.
105 changes: 105 additions & 0 deletions python/datafusion/tests/test_udwf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pyarrow as pa
import pytest

from datafusion import SessionContext, column, udwf, lit, functions as f
from datafusion.udf import WindowEvaluator


class ExponentialSmooth(WindowEvaluator):
"""Interface of a user-defined accumulation."""

def __init__(self) -> None:
self.alpha = 0.9

def evaluate_all(self, values: pa.Array, num_rows: int) -> pa.Array:
results = []
curr_value = 0.0
for idx in range(num_rows):
if idx == 0:
curr_value = values[idx].as_py()
else:
curr_value = values[idx].as_py() * self.alpha + curr_value * (
1.0 - self.alpha
)
results.append(curr_value)

return pa.array(results)


class NotSubclassOfWindowEvaluator:
pass


@pytest.fixture
def df():
ctx = SessionContext()

# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array([0, 1, 2, 3, 4, 5, 6]),
pa.array([7, 4, 3, 8, 9, 1, 6]),
pa.array(["A", "A", "A", "A", "B", "B", "B"]),
],
names=["a", "b", "c"],
)
return ctx.create_dataframe([[batch]])


def test_udwf_errors(df):
with pytest.raises(TypeError):
udwf(
NotSubclassOfWindowEvaluator,
pa.float64(),
pa.float64(),
volatility="immutable",
)


smooth = udwf(
ExponentialSmooth,
pa.float64(),
pa.float64(),
volatility="immutable",
)

data_test_udwf_functions = [
("smooth_udwf", smooth(column("a")), [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889]),
(
"partitioned_udwf",
smooth(column("a")).partition_by(column("c")).build(),
[0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89],
),
(
"ordered_udwf",
smooth(column("a")).order_by(column("b")).build(),
[0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513],
),
]


@pytest.mark.parametrize("name,expr,expected", data_test_udwf_functions)
def test_udwf_functions(df, name, expr, expected):
df = df.select("a", f.round(expr, lit(3)).alias(name))

# execute and collect the first (and only) batch
result = df.sort(column("a")).select(column(name)).collect()[0]

assert result.column(0) == pa.array(expected)
43 changes: 25 additions & 18 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import datafusion._internal as df_internal
from datafusion.expr import Expr
from typing import Callable, TYPE_CHECKING, TypeVar
from typing import Callable, TYPE_CHECKING, TypeVar, Type
from abc import ABCMeta, abstractmethod
from typing import List
from enum import Enum
Expand Down Expand Up @@ -251,10 +251,15 @@ def udaf(
class WindowEvaluator(metaclass=ABCMeta):
"""Evaluator class for user defined window functions (UDWF).
Users should inherit from this class and implement ``evaluate``, ``evaluate_all``,
and/or ``evaluate_all_with_rank``. If using `evaluate` only you will need to
override ``supports_bounded_execution``.
"""
It is up to the user to decide which evaluate function is appropriate.
|``uses_window_frame``|``supports_bounded_execution``|``include_rank``|function_to_implement|
|---|---|----|----|
|False (default) |False (default) |False (default) | ``evaluate_all`` |
|False |True |False | ``evaluate`` |
|False |True/False |True | ``evaluate_all_with_rank`` |
|True |True/False |True/False | ``evaluate`` |
""" # noqa: W505

def memoize(self) -> None:
"""Perform a memoize operation to improve performance.
Expand Down Expand Up @@ -329,15 +334,8 @@ def evaluate_all(self, values: pyarrow.Array, num_rows: int) -> pyarrow.Array:
avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
```
"""
if self.supports_bounded_execution() and not self.uses_window_frame():
res = []
for idx in range(0, num_rows):
res.append(self.evaluate(values, self.get_range(idx, num_rows)))
return pyarrow.array(res)
else:
raise
pass

@abstractmethod
def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Scalar:
"""Evaluate window function on a range of rows in an input partition.
Expand All @@ -355,7 +353,6 @@ def evaluate(self, values: pyarrow.Array, range: tuple[int, int]) -> pyarrow.Sca
"""
pass

@abstractmethod
def evaluate_all_with_rank(
self, num_rows: int, ranks_in_partition: list[tuple[int, int]]
) -> pyarrow.Array:
Expand Down Expand Up @@ -383,6 +380,8 @@ def evaluate_all_with_rank(
(2,2),
(3,4),
]
The user must implement this method if ``include_rank`` returns True.
"""
pass

Expand All @@ -399,6 +398,10 @@ def include_rank(self) -> bool:
return False


if TYPE_CHECKING:
_W = TypeVar("_W", bound=WindowEvaluator)


class WindowUDF:
"""Class for performing window user defined functions (UDF).
Expand All @@ -409,9 +412,9 @@ class WindowUDF:
def __init__(
self,
name: str | None,
func: WindowEvaluator,
func: Type[WindowEvaluator],
input_type: pyarrow.DataType,
return_type: _R,
return_type: pyarrow.DataType,
volatility: Volatility | str,
) -> None:
"""Instantiate a user defined window function (UDWF).
Expand All @@ -434,9 +437,9 @@ def __call__(self, *args: Expr) -> Expr:

@staticmethod
def udwf(
func: Callable[..., _R],
func: Type[WindowEvaluator],
input_type: pyarrow.DataType,
return_type: _R,
return_type: pyarrow.DataType,
volatility: Volatility | str,
name: str | None = None,
) -> WindowUDF:
Expand All @@ -452,6 +455,10 @@ def udwf(
Returns:
A user defined window function.
"""
if not issubclass(func, WindowEvaluator):
raise TypeError(
"`func` must implement the abstract base class WindowEvaluator"
)
if name is None:
name = func.__qualname__.lower()
return WindowUDF(
Expand Down

0 comments on commit 8663e77

Please sign in to comment.