Skip to content

Commit 3296e1a

Browse files
committed
Expose nth value
1 parent 85df127 commit 3296e1a

File tree

3 files changed

+103
-6
lines changed

3 files changed

+103
-6
lines changed

python/datafusion/functions.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@
180180
"named_struct",
181181
"nanvl",
182182
"now",
183+
"nth_value",
183184
"nullif",
184185
"octet_length",
185186
"order_by",
@@ -1739,9 +1740,18 @@ def covar(value_y: Expr, value_x: Expr, filter: Optional[Expr] = None) -> Expr:
17391740
return covar_samp(value_y, value_x, filter)
17401741

17411742

1742-
def max(arg: Expr, distinct: bool = False) -> Expr:
1743-
"""Returns the maximum value of the argument."""
1744-
return Expr(f.max(arg.expr, distinct=distinct))
1743+
def max(expression: Expr, filter: Optional[Expr] = None) -> Expr:
1744+
"""Aggregate function that returns the maximum value of the argument.
1745+
1746+
If using the builder functions described in :ref:`aggregation` this function ignores
1747+
the options ``order_by``, ``null_treatment``, and ``distinct``.
1748+
1749+
Args:
1750+
expression: The value to find the maximum of
1751+
filter: If provided, only compute against rows for which the filter is true
1752+
"""
1753+
filter_raw = filter.expr if filter is not None else None
1754+
return Expr(f.max(expression.expr, filter=filter_raw))
17451755

17461756

17471757
def mean(expression: Expr, filter: Optional[Expr] = None) -> Expr:
@@ -1772,9 +1782,18 @@ def median(
17721782
return Expr(f.median(expression.expr, distinct=distinct, filter=filter_raw))
17731783

17741784

1775-
def min(arg: Expr, distinct: bool = False) -> Expr:
1776-
"""Returns the minimum value of the argument."""
1777-
return Expr(f.min(arg.expr, distinct=distinct))
1785+
def min(expression: Expr, filter: Optional[Expr] = None) -> Expr:
1786+
"""Returns the minimum value of the argument.
1787+
1788+
If using the builder functions described in :ref:`aggregation` this function ignores
1789+
the options ``order_by``, ``null_treatment``, and ``distinct``.
1790+
1791+
Args:
1792+
expression: The value to find the minimum of
1793+
filter: If provided, only compute against rows for which the filter is true
1794+
"""
1795+
filter_raw = filter.expr if filter is not None else None
1796+
return Expr(f.min(expression.expr, filter=filter_raw))
17781797

17791798

17801799
def sum(arg: Expr) -> Expr:
@@ -1933,6 +1952,41 @@ def last_value(
19331952
)
19341953

19351954

1955+
def nth_value(
1956+
expression: Expr,
1957+
n: int,
1958+
filter: Optional[Expr] = None,
1959+
order_by: Optional[list[Expr]] = None,
1960+
null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
1961+
) -> Expr:
1962+
"""Returns the n-th value in a group of values.
1963+
1964+
This aggregate function will return the n-th value in the partition.
1965+
1966+
If using the builder functions described in :ref:`aggregation` this function ignores
1967+
the option ``distinct``.
1968+
1969+
Args:
1970+
expression: Argument to select the n-th value from
1971+
n: Index of value to return. Starts at 1.
1972+
filter: If provided, only compute against rows for which the filter is true
1973+
order_by: Set the ordering of the expression to evaluate
1974+
null_treatment: Assign whether to respect or ignore null values.
1975+
"""
1976+
order_by_raw = expr_list_to_raw_expr_list(order_by)
1977+
filter_raw = filter.expr if filter is not None else None
1978+
1979+
return Expr(
1980+
f.nth_value(
1981+
expression.expr,
1982+
n,
1983+
filter=filter_raw,
1984+
order_by=order_by_raw,
1985+
null_treatment=null_treatment.value,
1986+
)
1987+
)
1988+
1989+
19361990
def bit_and(expression: Expr, filter: Optional[Expr] = None) -> Expr:
19371991
"""Computes the bitwise AND of the argument.
19381992

python/datafusion/tests/test_aggregation.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
161161
(f.count(), pa.array([3]), False),
162162
(f.count(column("e")), pa.array([2]), False),
163163
(f.count_star(filter=column("a") != 3), pa.array([2]), False),
164+
(f.max(column("a"), filter=column("a") != lit(3)), pa.array([2]), False),
165+
(f.min(column("a"), filter=column("a") != lit(1)), pa.array([2]), False),
164166
],
165167
)
166168
def test_aggregation(df, agg_expr, expected, array_sort):
@@ -329,6 +331,32 @@ def test_bit_and_bool_fns(df, name, expr, result):
329331
),
330332
[8, 9],
331333
),
334+
("first_value", f.first_value(column("a")), [0, 4]),
335+
(
336+
"nth_value_ordered",
337+
f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]),
338+
[2, 5],
339+
),
340+
(
341+
"nth_value_with_null",
342+
f.nth_value(
343+
column("b"),
344+
3,
345+
order_by=[column("b").sort(ascending=True, nulls_first=False)],
346+
null_treatment=NullTreatment.RESPECT_NULLS,
347+
),
348+
[8, None],
349+
),
350+
(
351+
"nth_value_ignore_null",
352+
f.nth_value(
353+
column("b"),
354+
2,
355+
order_by=[column("b").sort(ascending=True)],
356+
null_treatment=NullTreatment.IGNORE_NULLS,
357+
),
358+
[7, 9],
359+
),
332360
],
333361
)
334362
def test_first_last_value(df_partitioned, name, expr, result) -> None:

src/functions.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,20 @@ pub fn first_value(
807807
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
808808
}
809809

810+
// nth_value requires a non-expr argument
811+
#[pyfunction]
812+
pub fn nth_value(
813+
expr: PyExpr,
814+
n: i64,
815+
distinct: Option<bool>,
816+
filter: Option<PyExpr>,
817+
order_by: Option<Vec<PyExpr>>,
818+
null_treatment: Option<NullTreatment>,
819+
) -> PyResult<PyExpr> {
820+
let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(vec![expr.expr, lit(n)]);
821+
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
822+
}
823+
810824
fn add_builder_fns_to_window(
811825
window_fn: Expr,
812826
partition_by: Option<Vec<PyExpr>>,
@@ -1058,6 +1072,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
10581072
m.add_wrapped(wrap_pyfunction!(regr_syy))?;
10591073
m.add_wrapped(wrap_pyfunction!(first_value))?;
10601074
m.add_wrapped(wrap_pyfunction!(last_value))?;
1075+
m.add_wrapped(wrap_pyfunction!(nth_value))?;
10611076
m.add_wrapped(wrap_pyfunction!(bit_and))?;
10621077
m.add_wrapped(wrap_pyfunction!(bit_or))?;
10631078
m.add_wrapped(wrap_pyfunction!(bit_xor))?;

0 commit comments

Comments
 (0)