Skip to content

Commit 616a748

Browse files
committed
Make first_value and last_value identical in the interface
1 parent 0fc0895 commit 616a748

File tree

3 files changed

+151
-68
lines changed

3 files changed

+151
-68
lines changed

python/datafusion/functions.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -1699,29 +1699,47 @@ def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
16991699
def first_value(
17001700
arg: Expr,
17011701
distinct: bool = False,
1702-
filter: bool = None,
1703-
order_by: Expr | None = None,
1704-
null_treatment: common.NullTreatment | None = None,
1702+
filter: Optional[bool] = None,
1703+
order_by: Optional[list[Expr]] = None,
1704+
null_treatment: Optional[common.NullTreatment] = None,
17051705
) -> Expr:
17061706
"""Returns the first value in a group of values."""
1707+
order_by_cols = [e.expr for e in order_by] if order_by is not None else None
1708+
17071709
return Expr(
17081710
f.first_value(
17091711
arg.expr,
17101712
distinct=distinct,
17111713
filter=filter,
1712-
order_by=order_by,
1714+
order_by=order_by_cols,
17131715
null_treatment=null_treatment,
17141716
)
17151717
)
17161718

17171719

1718-
def last_value(arg: Expr) -> Expr:
1720+
def last_value(
1721+
arg: Expr,
1722+
distinct: bool = False,
1723+
filter: Optional[bool] = None,
1724+
order_by: Optional[list[Expr]] = None,
1725+
null_treatment: Optional[common.NullTreatment] = None,
1726+
) -> Expr:
17191727
"""Returns the last value in a group of values.
17201728
17211729
To set parameters on this expression, use ``.order_by()``, ``.distinct()``,
17221730
``.filter()``, or ``.null_treatment()``.
17231731
"""
1724-
return Expr(f.last_value(arg.expr))
1732+
order_by_cols = [e.expr for e in order_by] if order_by is not None else None
1733+
1734+
return Expr(
1735+
f.last_value(
1736+
arg.expr,
1737+
distinct=distinct,
1738+
filter=filter,
1739+
order_by=order_by_cols,
1740+
null_treatment=null_treatment,
1741+
)
1742+
)
17251743

17261744

17271745
def bit_and(arg: Expr, distinct: bool = False) -> Expr:

python/datafusion/tests/test_functions.py

+99-53
Original file line numberDiff line numberDiff line change
@@ -567,45 +567,86 @@ def test_array_function_obj_tests(stmt, py_expr):
567567
assert a == b
568568

569569

570-
@pytest.mark.parametrize("function, expected_result", [
571-
(f.ascii(column("a")), pa.array([72, 87, 33], type=pa.int32())), # H = 72; W = 87; ! = 33
572-
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
573-
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
574-
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
575-
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
576-
(f.concat_ws("-", column("a"), literal("test")), pa.array(["Hello-test", "World-test", "!-test"])),
577-
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
578-
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
579-
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
580-
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
581-
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
582-
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
583-
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
584-
(f.md5(column("a")), pa.array([
585-
"8b1a9953c4611296a827abf8c47804d7",
586-
"f5a7924e621e84c9280a9a27e1bcb7f6",
587-
"9033e0e305f247c0c3c80d0c7848c8b3",
588-
])),
589-
(f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
590-
(f.repeat(column("a"), literal(2)), pa.array(["HelloHello", "WorldWorld", "!!"])),
591-
(f.replace(column("a"), literal("l"), literal("?")), pa.array(["He??o", "Wor?d", "!"])),
592-
(f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
593-
(f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
594-
(f.rpad(column("a"), literal(8)), pa.array(["Hello ", "World ", "! "])),
595-
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
596-
(f.split_part(column("a"), literal("l"), literal(1)), pa.array(["He", "Wor", "!"])),
597-
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
598-
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
599-
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
600-
(f.translate(column("a"), literal("or"), literal("ld")), pa.array(["Helll", "Wldld", "!"])),
601-
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
602-
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
603-
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
604-
(f.overlay(column("a"), literal("--"), literal(2)), pa.array(["H--lo", "W--ld", "--"])),
605-
(f.regexp_like(column("a"), literal("(ell|orl)")), pa.array([True, True, False])),
606-
(f.regexp_match(column("a"), literal("(ell|orl)")), pa.array([["ell"], ["orl"], None])),
607-
(f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")), pa.array(["H-o", "W-d", "!"])),
608-
])
570+
@pytest.mark.parametrize(
571+
"function, expected_result",
572+
[
573+
(
574+
f.ascii(column("a")),
575+
pa.array([72, 87, 33], type=pa.int32()),
576+
), # H = 72; W = 87; ! = 33
577+
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
578+
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
579+
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
580+
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
581+
(
582+
f.concat_ws("-", column("a"), literal("test")),
583+
pa.array(["Hello-test", "World-test", "!-test"]),
584+
),
585+
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
586+
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
587+
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
588+
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
589+
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
590+
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
591+
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
592+
(
593+
f.md5(column("a")),
594+
pa.array(
595+
[
596+
"8b1a9953c4611296a827abf8c47804d7",
597+
"f5a7924e621e84c9280a9a27e1bcb7f6",
598+
"9033e0e305f247c0c3c80d0c7848c8b3",
599+
]
600+
),
601+
),
602+
(f.octet_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
603+
(
604+
f.repeat(column("a"), literal(2)),
605+
pa.array(["HelloHello", "WorldWorld", "!!"]),
606+
),
607+
(
608+
f.replace(column("a"), literal("l"), literal("?")),
609+
pa.array(["He??o", "Wor?d", "!"]),
610+
),
611+
(f.reverse(column("a")), pa.array(["olleH", "dlroW", "!"])),
612+
(f.right(column("a"), literal(4)), pa.array(["ello", "orld", "!"])),
613+
(
614+
f.rpad(column("a"), literal(8)),
615+
pa.array(["Hello ", "World ", "! "]),
616+
),
617+
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
618+
(
619+
f.split_part(column("a"), literal("l"), literal(1)),
620+
pa.array(["He", "Wor", "!"]),
621+
),
622+
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
623+
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
624+
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
625+
(
626+
f.translate(column("a"), literal("or"), literal("ld")),
627+
pa.array(["Helll", "Wldld", "!"]),
628+
),
629+
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
630+
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
631+
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
632+
(
633+
f.overlay(column("a"), literal("--"), literal(2)),
634+
pa.array(["H--lo", "W--ld", "--"]),
635+
),
636+
(
637+
f.regexp_like(column("a"), literal("(ell|orl)")),
638+
pa.array([True, True, False]),
639+
),
640+
(
641+
f.regexp_match(column("a"), literal("(ell|orl)")),
642+
pa.array([["ell"], ["orl"], None]),
643+
),
644+
(
645+
f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
646+
pa.array(["H-o", "W-d", "!"]),
647+
),
648+
],
649+
)
609650
def test_string_functions(df, function, expected_result):
610651
df = df.select(function)
611652
result = df.collect()
@@ -849,27 +890,30 @@ def test_regr_funcs_sql_2():
849890
assert result_sql[0].column(8) == pa.array([4], type=pa.float64())
850891

851892

852-
@pytest.mark.parametrize("func, expected", [
853-
pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"),
854-
pytest.param(f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"),
855-
pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"),
856-
pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"),
857-
pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"),
858-
pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"),
859-
pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"),
860-
pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"),
861-
pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy")
862-
])
893+
@pytest.mark.parametrize(
894+
"func, expected",
895+
[
896+
pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"),
897+
pytest.param(
898+
f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"
899+
),
900+
pytest.param(f.regr_count, pa.array([3], type=pa.uint64()), id="regr_count"),
901+
pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"),
902+
pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"),
903+
pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"),
904+
pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"),
905+
pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"),
906+
pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy"),
907+
],
908+
)
863909
def test_regr_funcs_df(func, expected):
864-
865910
# test case based on `regr_*() basic tests
866911
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
867912

868-
869913
ctx = SessionContext()
870914

871915
# Create a DataFrame
872-
data = {'column1': [1, 2, 3], 'column2': [2, 4, 6]}
916+
data = {"column1": [1, 2, 3], "column2": [2, 4, 6]}
873917
df = ctx.from_pydict(data, name="test_table")
874918

875919
# Perform the regression function using DataFrame API
@@ -900,6 +944,8 @@ def test_first_last_value(df):
900944
assert result.column(3) == pa.array(["!"])
901945
assert result.column(4) == pa.array([6])
902946
assert result.column(5) == pa.array([datetime(2020, 7, 2)])
947+
df.show()
948+
assert False
903949

904950

905951
def test_binary_string_functions(df):

src/functions.rs

+28-9
Original file line numberDiff line numberDiff line change
@@ -319,18 +319,15 @@ pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyEx
319319
}
320320
}
321321

322-
#[pyfunction]
323-
pub fn first_value(
324-
expr: PyExpr,
322+
fn add_builder_fns_to_aggregate(
323+
agg_fn: Expr,
325324
distinct: bool,
326325
filter: Option<PyExpr>,
327326
order_by: Option<Vec<PyExpr>>,
328327
null_treatment: Option<NullTreatment>,
329328
) -> PyResult<PyExpr> {
330-
// If we initialize the UDAF with order_by directly, then it gets over-written by the builder
331-
let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
332-
333-
// luckily, I can guarantee initializing a builder with an `order_by` default of empty vec
329+
// Since ExprFuncBuilder::new() is private, we can guarantee initializing
330+
// a builder with an `order_by` default of empty vec
334331
let order_by = order_by
335332
.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>())
336333
.unwrap_or_default();
@@ -351,8 +348,30 @@ pub fn first_value(
351348
}
352349

353350
#[pyfunction]
354-
pub fn last_value(expr: PyExpr) -> PyExpr {
355-
functions_aggregate::expr_fn::last_value(vec![expr.expr]).into()
351+
pub fn first_value(
352+
expr: PyExpr,
353+
distinct: bool,
354+
filter: Option<PyExpr>,
355+
order_by: Option<Vec<PyExpr>>,
356+
null_treatment: Option<NullTreatment>,
357+
) -> PyResult<PyExpr> {
358+
// If we initialize the UDAF with order_by directly, then it gets over-written by the builder
359+
let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
360+
361+
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
362+
}
363+
364+
#[pyfunction]
365+
pub fn last_value(
366+
expr: PyExpr,
367+
distinct: bool,
368+
filter: Option<PyExpr>,
369+
order_by: Option<Vec<PyExpr>>,
370+
null_treatment: Option<NullTreatment>,
371+
) -> PyResult<PyExpr> {
372+
let agg_fn = functions_aggregate::expr_fn::last_value(vec![expr.expr]);
373+
374+
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
356375
}
357376

358377
#[pyfunction]

0 commit comments

Comments
 (0)