From ca1cf3ad989bdfff62ad5f573f6c631907349e80 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Thu, 14 Nov 2024 14:05:31 +0800
Subject: [PATCH] Add make_list and tests for make_list, make_array

---
Notes (not part of the commit message):

* functions.py: adds make_list(*args) that forwards to make_array(*args)
  and lists "make_list" in the exported-names list.
* test_functions.py: adds test_make_array_functions, parametrized over
  f.make_array and f.make_list, comparing each result row with the
  expected list of strings.
* NOTE(review): the visible entries of the exported-names list are in
  alphabetical order; "make_list" would sort after "make_date" - confirm
  whether the ordering is enforced before merging.
* NOTE(review): the line structure of this patch was restored from a
  whitespace-collapsed copy. Leading whitespace on context lines is
  inferred; verify against the target files or apply with
  `git apply --ignore-whitespace`.

 python/datafusion/functions.py |  9 +++++++++
 python/tests/test_functions.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 88ea7280..6ad4c50c 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -184,6 +184,7 @@
     "lpad",
     "ltrim",
     "make_array",
+    "make_list",
     "make_date",
     "max",
     "md5",
@@ -1044,6 +1045,14 @@ def make_array(*args: Expr) -> Expr:
     return Expr(f.make_array(args))
 
 
+def make_list(*args: Expr) -> Expr:
+    """Returns an array using the specified input expressions.
+
+    This is an alias for :py:func:`make_array`.
+    """
+    return make_array(*args)
+
+
 def array(*args: Expr) -> Expr:
     """Returns an array using the specified input expressions.
 
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index c14cfc2d..0d40032b 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -576,6 +576,37 @@ def test_array_function_cardinality():
         )
 
 
+@pytest.mark.parametrize("make_func", [f.make_array, f.make_list])
+def test_make_array_functions(make_func):
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays(
+        [
+            pa.array(["Hello", "World", "!"], type=pa.string()),
+            pa.array([4, 5, 6]),
+            pa.array(["hello ", " world ", " !"], type=pa.string()),
+        ],
+        names=["a", "b", "c"],
+    )
+    df = ctx.create_dataframe([[batch]])
+
+    stmt = make_func(
+        column("a").cast(pa.string()),
+        column("b").cast(pa.string()),
+        column("c").cast(pa.string()),
+    )
+    py_expr = [
+        ["Hello", "4", "hello "],
+        ["World", "5", " world "],
+        ["!", "6", " !"],
+    ]
+
+    query_result = df.select(stmt).collect()[0].column(0)
+    for a, b in zip(query_result, py_expr):
+        np.testing.assert_array_equal(
+            np.array(a.as_py(), dtype=str), np.array(b, dtype=str)
+        )
+
+
 @pytest.mark.parametrize(
     ("stmt", "py_expr"),
     [