Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regression on register_udaf #878

Merged
merged 7 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions python/datafusion/tests/test_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pyarrow.compute as pc
import pytest

from datafusion import Accumulator, SessionContext, column, udaf
from datafusion import Accumulator, column, udaf, udf


class Summarize(Accumulator):
Expand Down Expand Up @@ -60,18 +60,15 @@ def state(self) -> List[pa.Scalar]:


@pytest.fixture
def df():
ctx = SessionContext()

def df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
names=["a", "b"],
)
return ctx.create_dataframe([[batch]])
return ctx.create_dataframe([[batch]], name="test_table")


@pytest.mark.skip(reason="df.collect() will hang, need more investigations")
def test_errors(df):
with pytest.raises(TypeError):
udaf(
Expand All @@ -92,8 +89,8 @@ def test_errors(df):
df = df.aggregate([], [accum(column("a"))])

msg = (
"Can't instantiate abstract class MissingMethods with abstract "
"methods evaluate, merge, update"
"Execution error: TypeError: Can't instantiate abstract class MissingMethods "
"without an implementation for abstract methods 'evaluate', 'merge', 'update'"
)
with pytest.raises(Exception, match=msg):
df.collect()
Expand Down Expand Up @@ -132,3 +129,36 @@ def test_group_by(df):
arrays = [batch.column(1) for batch in batches]
joined = pa.concat_arrays(arrays)
assert joined == pa.array([1.0 + 2.0, 3.0])


def test_register_udaf(ctx, df) -> None:
    """Register a ``Summarize``-backed UDAF on the context and invoke it via SQL.

    ``df`` is not used directly in the body; it is requested so the fixture
    runs before the query, which reads ``test_table`` by name.
    """
    state_types = [pa.float64()]
    aggregate_fn = udaf(
        Summarize,
        pa.float64(),
        pa.float64(),
        state_types,
        volatility="immutable",
    )
    ctx.register_udaf(aggregate_fn)

    batches = ctx.sql("select summarize(b) from test_table").collect()

    # First batch -> first column -> first row holds the single aggregate value.
    aggregated = batches[0][0][0]
    assert aggregated.as_py() == 14.0


def test_register_udf(ctx, df) -> None:
    """Register a scalar ``is_null`` UDF on the context and invoke it via SQL.

    ``df`` is not used directly in the body; it is requested so the fixture
    runs before the query, which reads ``test_table`` by name.
    """

    def _null_mask(x):
        # Delegates to the array's own null check.
        return x.is_null()

    scalar_fn = udf(
        _null_mask,
        [pa.float64()],
        pa.bool_(),
        volatility="immutable",
        name="is_null",
    )
    ctx.register_udf(scalar_fn)

    batches = ctx.sql("select is_null(a) from test_table").collect()
    first_column = batches[0].column(0)

    expected = pa.array([False, False, False])
    assert first_column == expected
4 changes: 2 additions & 2 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(
See :py:func:`udaf` for a convenience function and argument
descriptions.
"""
self._udf = df_internal.AggregateUDF(
self._udaf = df_internal.AggregateUDF(
name, accumulator, input_types, return_type, state_type, str(volatility)
)

Expand All @@ -203,7 +203,7 @@ def __call__(self, *args: Expr) -> Expr:
occur during the evaluation of the dataframe.
"""
args = [arg.expr for arg in args]
return Expr(self._udf.__call__(*args))
return Expr(self._udaf.__call__(*args))

@staticmethod
def udaf(
Expand Down
Loading