From e015482750e9e08bd426bfcf649445d53705c51a Mon Sep 17 00:00:00 2001 From: kosiew Date: Tue, 29 Oct 2024 18:16:50 +0800 Subject: [PATCH] feat: add `cardinality` function to calculate total elements in an array (#937) --- .../common-operations/expressions.rst | 14 ++++++++++++++ python/datafusion/functions.py | 6 ++++++ python/tests/test_functions.py | 18 ++++++++++++++++++ src/functions.rs | 2 ++ 4 files changed, 40 insertions(+) diff --git a/docs/source/user-guide/common-operations/expressions.rst b/docs/source/user-guide/common-operations/expressions.rst index 77f3359f..23430d35 100644 --- a/docs/source/user-guide/common-operations/expressions.rst +++ b/docs/source/user-guide/common-operations/expressions.rst @@ -96,6 +96,20 @@ This function returns a boolean indicating whether the array is empty. In this example, the `is_empty` column will contain `True` for the first row and `False` for the second row. +To get the total number of elements in an array, you can use the function :py:func:`datafusion.functions.cardinality`. +This function returns an integer indicating the total number of elements in the array. + +.. ipython:: python + + from datafusion import SessionContext, col + from datafusion.functions import cardinality + + ctx = SessionContext() + df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]}) + df.select(cardinality(col("a")).alias("num_elements")) + +In this example, the `num_elements` column will contain `3` for both rows. + Structs ------- diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 570a6ce5..e67ba4ae 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -132,6 +132,7 @@ "find_in_set", "first_value", "flatten", + "cardinality", "floor", "from_unixtime", "gcd", @@ -1516,6 +1517,11 @@ def flatten(array: Expr) -> Expr: return Expr(f.flatten(array.expr)) +def cardinality(array: Expr) -> Expr: + """Returns the total number of elements in the array.""" + return Expr(f.cardinality(array.expr)) + + # aggregate functions def approx_distinct( expression: Expr, diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index e6fd41d8..37943e57 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -540,6 +540,24 @@ def test_array_function_flatten(): ) +def test_array_function_cardinality(): + data = [[1, 2, 3], [4, 4, 5, 6]] + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"]) + df = ctx.create_dataframe([[batch]]) + + stmt = f.cardinality(column("arr")) + py_expr = [len(arr) for arr in data] # Expected lengths: [3, 3] + # assert py_expr lengths + + query_result = df.select(stmt).collect()[0].column(0) + + for a, b in zip(query_result, py_expr): + np.testing.assert_array_equal( + np.array([a.as_py()], dtype=int), np.array([b], dtype=int) + ) + + @pytest.mark.parametrize( ("stmt", "py_expr"), [ diff --git a/src/functions.rs b/src/functions.rs index 4facb6cf..fe3531ba 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -594,6 +594,7 @@ array_fn!(array_intersect, first_array second_array); array_fn!(array_union, array1 array2); array_fn!(array_except, first_array second_array); array_fn!(array_resize, array size value); +array_fn!(cardinality, array); array_fn!(flatten, array); array_fn!(range, start stop step); @@ -1030,6 +1031,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(array_sort))?; m.add_wrapped(wrap_pyfunction!(array_slice))?; m.add_wrapped(wrap_pyfunction!(flatten))?; + m.add_wrapped(wrap_pyfunction!(cardinality))?; // Window Functions m.add_wrapped(wrap_pyfunction!(lead))?;