Skip to content

Commit

Permalink
Expose nth value
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Sep 7, 2024
1 parent 85df127 commit 3296e1a
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 6 deletions.
66 changes: 60 additions & 6 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
"named_struct",
"nanvl",
"now",
"nth_value",
"nullif",
"octet_length",
"order_by",
Expand Down Expand Up @@ -1739,9 +1740,18 @@ def covar(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
return covar_samp(value_y, value_x, filter)


def max(arg: Expr, distinct: bool = False) -> Expr:
"""Returns the maximum value of the argument."""
return Expr(f.max(arg.expr, distinct=distinct))
def max(expression: Expr, filter: Optional[Expr] = None) -> Expr:
"""Aggregate function that returns the maximum value of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
the options ``order_by``, ``null_treatment``, and ``distinct``.
Args:
expression: The value to find the maximum of
filter: If provided, only compute against rows for which the filter is true
"""
filter_raw = filter.expr if filter is not None else None
return Expr(f.max(expression.expr, filter=filter_raw))


def mean(expression: Expr, filter: Optional[Expr] = None) -> Expr:
Expand Down Expand Up @@ -1772,9 +1782,18 @@ def median(
return Expr(f.median(expression.expr, distinct=distinct, filter=filter_raw))


def min(arg: Expr, distinct: bool = False) -> Expr:
"""Returns the minimum value of the argument."""
return Expr(f.min(arg.expr, distinct=distinct))
def min(expression: Expr, filter: Optional[Expr] = None) -> Expr:
"""Returns the minimum value of the argument.
If using the builder functions described in ref:`_aggregation` this function ignores
the options ``order_by``, ``null_treatment``, and ``distinct``.
Args:
expression: The value to find the minimum of
filter: If provided, only compute against rows for which the filter is true
"""
filter_raw = filter.expr if filter is not None else None
return Expr(f.min(expression.expr, filter=filter_raw))


def sum(arg: Expr) -> Expr:
Expand Down Expand Up @@ -1933,6 +1952,41 @@ def last_value(
)


def nth_value(
expression: Expr,
n: int,
filter: Optional[Expr] = None,
order_by: Optional[list[Expr]] = None,
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
) -> Expr:
"""Returns the n-th value in a group of values.
This aggregate function will return the n-th value in the partition.
If using the builder functions described in ref:`_aggregation` this function ignores
the option ``distinct``.
Args:
expression: Argument to perform bitwise calculation on
n: Index of value to return. Starts at 1.
filter: If provided, only compute against rows for which the filter is true
order_by: Set the ordering of the expression to evaluate
null_treatment: Assign whether to respect or ignull null values.
"""
order_by_raw = expr_list_to_raw_expr_list(order_by)
filter_raw = filter.expr if filter is not None else None

return Expr(
f.nth_value(
expression.expr,
n,
filter=filter_raw,
order_by=order_by_raw,
null_treatment=null_treatment.value,
)
)


def bit_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
"""Computes the bitwise AND of the argument.
Expand Down
28 changes: 28 additions & 0 deletions python/datafusion/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
(f.count(), pa.array([3]), False),
(f.count(column("e")), pa.array([2]), False),
(f.count_star(filter=column("a") != 3), pa.array([2]), False),
(f.max(column("a"), filter=column("a") != lit(3)), pa.array([2]), False),
(f.min(column("a"), filter=column("a") != lit(1)), pa.array([2]), False),
],
)
def test_aggregation(df, agg_expr, expected, array_sort):
Expand Down Expand Up @@ -329,6 +331,32 @@ def test_bit_and_bool_fns(df, name, expr, result):
),
[8, 9],
),
("first_value", f.first_value(column("a")), [0, 4]),
(
"nth_value_ordered",
f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]),
[2, 5],
),
(
"nth_value_with_null",
f.nth_value(
column("b"),
3,
order_by=[column("b").sort(ascending=True, nulls_first=False)],
null_treatment=NullTreatment.RESPECT_NULLS,
),
[8, None],
),
(
"nth_value_ignore_null",
f.nth_value(
column("b"),
2,
order_by=[column("b").sort(ascending=True)],
null_treatment=NullTreatment.IGNORE_NULLS,
),
[7, 9],
),
],
)
def test_first_last_value(df_partitioned, name, expr, result) -> None:
Expand Down
15 changes: 15 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,20 @@ pub fn first_value(
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}

// nth_value requires a non-expr argument
#[pyfunction]
pub fn nth_value(
expr: PyExpr,
n: i64,
distinct: Option<bool>,
filter: Option<PyExpr>,
order_by: Option<Vec<PyExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(vec![expr.expr, lit(n)]);
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}

fn add_builder_fns_to_window(
window_fn: Expr,
partition_by: Option<Vec<PyExpr>>,
Expand Down Expand Up @@ -1058,6 +1072,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(regr_syy))?;
m.add_wrapped(wrap_pyfunction!(first_value))?;
m.add_wrapped(wrap_pyfunction!(last_value))?;
m.add_wrapped(wrap_pyfunction!(nth_value))?;
m.add_wrapped(wrap_pyfunction!(bit_and))?;
m.add_wrapped(wrap_pyfunction!(bit_or))?;
m.add_wrapped(wrap_pyfunction!(bit_xor))?;
Expand Down

0 comments on commit 3296e1a

Please sign in to comment.