Skip to content

Commit

Permalink
Update first_value and last_value with the builder parameters that ar…
Browse files Browse the repository at this point in the history
…e relevant
  • Loading branch information
timsaucer committed Sep 7, 2024
1 parent b7262ba commit 91e5f7d
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 40 deletions.
46 changes: 30 additions & 16 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,51 +1864,65 @@ def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:


def first_value(
arg: Expr,
distinct: bool = False,
expression: Expr,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr]] = None,
null_treatment: Optional[NullTreatment] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the first value in a group of values."""
"""Returns the first value in a group of values.
This aggregate function will return the first value in the partition.
If using the builder functions described in ref:`_aggregation` this function ignores
the option ``distinct``.
Args:
expression: Argument to perform bitwise calculation on
filter: If provided, only compute against rows for which the filter is true
order_by: Set the ordering of the expression to evaluate
null_treatment: Assign whether to respect or ignull null values.
"""
order_by_raw = expr_list_to_raw_expr_list(order_by)
filter_raw = filter.expr if filter is not None else None
null_treatment_raw = null_treatment.value if null_treatment is not None else None

return Expr(
f.first_value(
arg.expr,
distinct=distinct,
expression.expr,
filter=filter_raw,
order_by=order_by_raw,
null_treatment=null_treatment_raw,
null_treatment=null_treatment.value,
)
)


def last_value(
arg: Expr,
distinct: bool = False,
expression: Expr,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr]] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the last value in a group of values.
To set parameters on this expression, use ``.order_by()``, ``.distinct()``,
``.filter()``, or ``.null_treatment()``.
This aggregate function will return the last value in the partition.
If using the builder functions described in ref:`_aggregation` this function ignores
the option ``distinct``.
Args:
expression: Argument to perform bitwise calculation on
filter: If provided, only compute against rows for which the filter is true
order_by: Set the ordering of the expression to evaluate
null_treatment: Assign whether to respect or ignull null values.
"""
order_by_raw = expr_list_to_raw_expr_list(order_by)
filter_raw = filter.expr if filter is not None else None
null_treatment_raw = null_treatment.value if null_treatment is not None else None

return Expr(
f.last_value(
arg.expr,
distinct=distinct,
expression.expr,
filter=filter_raw,
order_by=order_by_raw,
null_treatment=null_treatment_raw,
null_treatment=null_treatment.value,
)
)

Expand Down
83 changes: 83 additions & 0 deletions python/datafusion/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from datafusion import SessionContext, column, lit
from datafusion import functions as f
from datafusion.common import NullTreatment


@pytest.fixture
Expand All @@ -41,6 +42,23 @@ def df():
return ctx.create_dataframe([[batch]])


@pytest.fixture
def df_partitioned():
ctx = SessionContext()

# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array([0, 1, 2, 3, 4, 5, 6]),
pa.array([7, None, 7, 8, 9, None, 9]),
pa.array(["A", "A", "A", "A", "B", "B", "B"]),
],
names=["a", "b", "c"],
)

return ctx.create_dataframe([[batch]])


@pytest.fixture
def df_aggregate_100():
ctx = SessionContext()
Expand Down Expand Up @@ -256,3 +274,68 @@ def test_bit_and_bool_fns(df, name, expr, result):
}

assert df.collect()[0].to_pydict() == expected


@pytest.mark.parametrize(
"name,expr,result",
[
("first_value", f.first_value(column("a")), [0, 4]),
(
"first_value_ordered",
f.first_value(column("a"), order_by=[column("a").sort(ascending=False)]),
[3, 6],
),
(
"first_value_with_null",
f.first_value(
column("b"),
order_by=[column("b").sort(ascending=True)],
null_treatment=NullTreatment.RESPECT_NULLS,
),
[None, None],
),
(
"first_value_ignore_null",
f.first_value(
column("b"),
order_by=[column("b").sort(ascending=True)],
null_treatment=NullTreatment.IGNORE_NULLS,
),
[7, 9],
),
("last_value", f.last_value(column("a")), [3, 6]),
(
"last_value_ordered",
f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]),
[0, 4],
),
(
"last_value_with_null",
f.last_value(
column("b"),
order_by=[column("b").sort(ascending=True, nulls_first=False)],
null_treatment=NullTreatment.RESPECT_NULLS,
),
[None, None],
),
(
"last_value_ignore_null",
f.last_value(
column("b"),
order_by=[column("b").sort(ascending=True)],
null_treatment=NullTreatment.IGNORE_NULLS,
),
[8, 9],
),
],
)
def test_first_last_value(df_partitioned, name, expr, result) -> None:
df = df_partitioned.aggregate([column("c")], [expr.alias(name)]).sort(column("c"))
df.show()

expected = {
"c": ["A", "B"],
name: result,
}

assert df.collect()[0].to_pydict() == expected
24 changes: 0 additions & 24 deletions python/datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,30 +942,6 @@ def test_regr_funcs_df(func, expected):
assert result_df[0].column(0) == expected


def test_first_last_value(df):
df = df.aggregate(
[],
[
f.first_value(column("a")),
f.first_value(column("b")),
f.first_value(column("d")),
f.last_value(column("a")),
f.last_value(column("b")),
f.last_value(column("d")),
],
)

result = df.collect()
result = result[0]
assert result.column(0) == pa.array(["Hello"])
assert result.column(1) == pa.array([4])
assert result.column(2) == pa.array([datetime(2022, 12, 31)])
assert result.column(3) == pa.array(["!"])
assert result.column(4) == pa.array([6])
assert result.column(5) == pa.array([datetime(2020, 7, 2)])
df.show()


def test_binary_string_functions(df):
df = df.select(
f.encode(column("a"), literal("base64")),
Expand Down

0 comments on commit 91e5f7d

Please sign in to comment.