From 91e5f7d2fd72a1e8348504b01ab7bd7e056aaf15 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 7 Sep 2024 10:45:40 -0400 Subject: [PATCH] Update first_value and last_value with the builder parameters that are relevant --- python/datafusion/functions.py | 46 ++++++++---- python/datafusion/tests/test_aggregation.py | 83 +++++++++++++++++++++ python/datafusion/tests/test_functions.py | 24 ------ 3 files changed, 113 insertions(+), 40 deletions(-) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 59d3efb2..8b917d00 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1864,51 +1864,65 @@ def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr: def first_value( - arg: Expr, - distinct: bool = False, + expression: Expr, filter: Optional[Expr] = None, order_by: Optional[list[Expr]] = None, - null_treatment: Optional[NullTreatment] = None, + null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: - """Returns the first value in a group of values.""" + """Returns the first value in a group of values. + + This aggregate function will return the first value in the partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the option ``distinct``. + + Args: + expression: Argument to perform bitwise calculation on + filter: If provided, only compute against rows for which the filter is true + order_by: Set the ordering of the expression to evaluate + null_treatment: Assign whether to respect or ignull null values. + """ order_by_raw = expr_list_to_raw_expr_list(order_by) filter_raw = filter.expr if filter is not None else None - null_treatment_raw = null_treatment.value if null_treatment is not None else None return Expr( f.first_value( - arg.expr, - distinct=distinct, + expression.expr, filter=filter_raw, order_by=order_by_raw, - null_treatment=null_treatment_raw, + null_treatment=null_treatment.value, ) ) def last_value( - arg: Expr, - distinct: bool = False, + expression: Expr, filter: Optional[Expr] = None, order_by: Optional[list[Expr]] = None, null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS, ) -> Expr: """Returns the last value in a group of values. - To set parameters on this expression, use ``.order_by()``, ``.distinct()``, - ``.filter()``, or ``.null_treatment()``. + This aggregate function will return the last value in the partition. + + If using the builder functions described in ref:`_aggregation` this function ignores + the option ``distinct``. + + Args: + expression: Argument to perform bitwise calculation on + filter: If provided, only compute against rows for which the filter is true + order_by: Set the ordering of the expression to evaluate + null_treatment: Assign whether to respect or ignull null values. """ order_by_raw = expr_list_to_raw_expr_list(order_by) filter_raw = filter.expr if filter is not None else None - null_treatment_raw = null_treatment.value if null_treatment is not None else None return Expr( f.last_value( - arg.expr, - distinct=distinct, + expression.expr, filter=filter_raw, order_by=order_by_raw, - null_treatment=null_treatment_raw, + null_treatment=null_treatment.value, ) ) diff --git a/python/datafusion/tests/test_aggregation.py b/python/datafusion/tests/test_aggregation.py index 00ff7ca0..e859e17f 100644 --- a/python/datafusion/tests/test_aggregation.py +++ b/python/datafusion/tests/test_aggregation.py @@ -21,6 +21,7 @@ from datafusion import SessionContext, column, lit from datafusion import functions as f +from datafusion.common import NullTreatment @pytest.fixture @@ -41,6 +42,23 @@ def df(): return ctx.create_dataframe([[batch]]) +@pytest.fixture +def df_partitioned(): + ctx = SessionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [ + pa.array([0, 1, 2, 3, 4, 5, 6]), + pa.array([7, None, 7, 8, 9, None, 9]), + pa.array(["A", "A", "A", "A", "B", "B", "B"]), + ], + names=["a", "b", "c"], + ) + + return ctx.create_dataframe([[batch]]) + + @pytest.fixture def df_aggregate_100(): ctx = SessionContext() @@ -256,3 +274,68 @@ def test_bit_and_bool_fns(df, name, expr, result): } assert df.collect()[0].to_pydict() == expected + + +@pytest.mark.parametrize( + "name,expr,result", + [ + ("first_value", f.first_value(column("a")), [0, 4]), + ( + "first_value_ordered", + f.first_value(column("a"), order_by=[column("a").sort(ascending=False)]), + [3, 6], + ), + ( + "first_value_with_null", + f.first_value( + column("b"), + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.RESPECT_NULLS, + ), + [None, None], + ), + ( + "first_value_ignore_null", + f.first_value( + column("b"), + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.IGNORE_NULLS, + ), + [7, 9], + ), + ("last_value", f.last_value(column("a")), [3, 6]), + ( + "last_value_ordered", + f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]), + [0, 4], + ), + ( + "last_value_with_null", + f.last_value( + column("b"), + order_by=[column("b").sort(ascending=True, nulls_first=False)], + null_treatment=NullTreatment.RESPECT_NULLS, + ), + [None, None], + ), + ( + "last_value_ignore_null", + f.last_value( + column("b"), + order_by=[column("b").sort(ascending=True)], + null_treatment=NullTreatment.IGNORE_NULLS, + ), + [8, 9], + ), + ], +) +def test_first_last_value(df_partitioned, name, expr, result) -> None: + df = df_partitioned.aggregate([column("c")], [expr.alias(name)]).sort(column("c")) + df.show() + + expected = { + "c": ["A", "B"], + name: result, + } + + assert df.collect()[0].to_pydict() == expected diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index e7e6d79e..bc5d50cc 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -942,30 +942,6 @@ def test_regr_funcs_df(func, expected): assert result_df[0].column(0) == expected -def test_first_last_value(df): - df = df.aggregate( - [], - [ - f.first_value(column("a")), - f.first_value(column("b")), - f.first_value(column("d")), - f.last_value(column("a")), - f.last_value(column("b")), - f.last_value(column("d")), - ], - ) - - result = df.collect() - result = result[0] - assert result.column(0) == pa.array(["Hello"]) - assert result.column(1) == pa.array([4]) - assert result.column(2) == pa.array([datetime(2022, 12, 31)]) - assert result.column(3) == pa.array(["!"]) - assert result.column(4) == pa.array([6]) - assert result.column(5) == pa.array([datetime(2020, 7, 2)]) - df.show() - - def test_binary_string_functions(df): df = df.select( f.encode(column("a"), literal("base64")),