Skip to content

Commit

Permalink
Add missing array functions (#551)
Browse files Browse the repository at this point in the history
* Add array_append, array_concat and array_cat

* Add tests for array functions array_append, array_concat and array_cat

* Add array_dims and list_dims

* Add tests for array_dims and list_dims

* Add array_element, array_extract, list_element and list_extract

* Add tests for array_element, array_extract, list_element and list_extract

* Add array_length and list_length
  • Loading branch information
ongchi authored Dec 28, 2023
1 parent 76d7fcf commit b22f82f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
64 changes: 64 additions & 0 deletions datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from datafusion import functions as f
from datafusion import literal

np.seterr(invalid="ignore")


@pytest.fixture
def df():
Expand Down Expand Up @@ -197,6 +199,68 @@ def test_math_functions():
)


def test_array_functions():
data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
[np.array(data, dtype=object)], names=["arr"]
)
df = ctx.create_dataframe([[batch]])

col = column("arr")
test_items = [
[
f.array_append(col, literal(99.0)),
lambda: [np.append(arr, 99.0) for arr in data],
],
[
f.array_concat(col, col),
lambda: [np.concatenate([arr, arr]) for arr in data],
],
[
f.array_cat(col, col),
lambda: [np.concatenate([arr, arr]) for arr in data],
],
[
f.array_dims(col),
lambda: [[len(r)] for r in data],
],
[
f.list_dims(col),
lambda: [[len(r)] for r in data],
],
[
f.array_element(col, literal(1)),
lambda: [r[0] for r in data],
],
[
f.array_extract(col, literal(1)),
lambda: [r[0] for r in data],
],
[
f.list_element(col, literal(1)),
lambda: [r[0] for r in data],
],
[
f.list_extract(col, literal(1)),
lambda: [r[0] for r in data],
],
[
f.array_length(col),
lambda: [len(r) for r in data],
],
[
f.list_length(col),
lambda: [len(r) for r in data],
],
]

for stmt, py_expr in test_items:
query_result = df.select(stmt).collect()[0].column(0).tolist()
for a, b in zip(query_result, py_expr()):
np.testing.assert_array_almost_equal(a, b)


def test_string_functions(df):
df = df.select(
f.ascii(column("a")),
Expand Down
27 changes: 27 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,19 @@ scalar_function!(random, Random);
scalar_function!(encode, Encode);
scalar_function!(decode, Decode);

// Array Functions
scalar_function!(array_append, ArrayAppend);
scalar_function!(array_concat, ArrayConcat);
scalar_function!(array_cat, ArrayConcat);
scalar_function!(array_dims, ArrayDims);
scalar_function!(list_dims, ArrayDims);
scalar_function!(array_element, ArrayElement);
scalar_function!(array_extract, ArrayElement);
scalar_function!(list_element, ArrayElement);
scalar_function!(list_extract, ArrayElement);
scalar_function!(array_length, ArrayLength);
scalar_function!(list_length, ArrayLength);

aggregate_function!(approx_distinct, ApproxDistinct);
aggregate_function!(approx_median, ApproxMedian);
aggregate_function!(approx_percentile_cont, ApproxPercentileCont);
Expand Down Expand Up @@ -546,5 +559,19 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
//Binary String Functions
m.add_wrapped(wrap_pyfunction!(encode))?;
m.add_wrapped(wrap_pyfunction!(decode))?;

// Array Functions
m.add_wrapped(wrap_pyfunction!(array_append))?;
m.add_wrapped(wrap_pyfunction!(array_concat))?;
m.add_wrapped(wrap_pyfunction!(array_cat))?;
m.add_wrapped(wrap_pyfunction!(array_dims))?;
m.add_wrapped(wrap_pyfunction!(list_dims))?;
m.add_wrapped(wrap_pyfunction!(array_element))?;
m.add_wrapped(wrap_pyfunction!(array_extract))?;
m.add_wrapped(wrap_pyfunction!(list_element))?;
m.add_wrapped(wrap_pyfunction!(list_extract))?;
m.add_wrapped(wrap_pyfunction!(array_length))?;
m.add_wrapped(wrap_pyfunction!(list_length))?;

Ok(())
}

0 comments on commit b22f82f

Please sign in to comment.