Skip to content

Commit

Permalink
More missing array funcs (#605)
Browse files Browse the repository at this point in the history
* Add `array_distinct` function

* Add `range` function

* Add `list_distinct` alias

* Add `array_intersect` scalar function

* Add `array_union` scalar function

* Add `array_except` scalar function

* Add `array_resize` scalar function
  • Loading branch information
judahrand authored Mar 21, 2024
1 parent 18ac182 commit 6a895c6
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
53 changes: 53 additions & 0 deletions datafusion/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def py_arr_replace(arr, from_, to, n=None):

return new_arr

def py_arr_resize(arr, size, value):
arr = np.asarray(arr)
return np.pad(
arr,
[(0, size - arr.shape[0])],
"constant",
constant_values=value,
)

def py_flatten(arr):
result = []
for elem in arr:
Expand Down Expand Up @@ -259,6 +268,14 @@ def py_flatten(arr):
f.array_dims(col),
lambda: [[len(r)] for r in data],
],
[
f.array_distinct(col),
lambda: [list(set(r)) for r in data],
],
[
f.list_distinct(col),
lambda: [list(set(r)) for r in data],
],
[
f.list_dims(col),
lambda: [[len(r)] for r in data],
Expand Down Expand Up @@ -415,7 +432,43 @@ def py_flatten(arr):
f.list_slice(col, literal(-1), literal(2)),
lambda: [arr[-1:2] for arr in data],
],
[
f.array_intersect(col, literal([3.0, 4.0])),
lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
],
[
f.list_intersect(col, literal([3.0, 4.0])),
lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
],
[
f.array_union(col, literal([12.0, 999.0])),
lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
],
[
f.list_union(col, literal([12.0, 999.0])),
lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
],
[
f.array_except(col, literal([3.0])),
lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
],
[
f.list_except(col, literal([3.0])),
lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
],
[
f.array_resize(col, literal(10), literal(0.0)),
lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
],
[
f.list_resize(col, literal(10), literal(0.0)),
lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
],
[f.flatten(literal(data)), lambda: [py_flatten(data)]],
[
f.range(literal(1), literal(5), literal(2)),
lambda: [np.arange(1, 5, 2)],
],
]

for stmt, py_expr in test_items:
Expand Down
22 changes: 22 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ scalar_function!(trunc, Trunc);
scalar_function!(upper, Upper, "Converts the string to all upper case.");
scalar_function!(make_array, MakeArray);
scalar_function!(array, MakeArray);
scalar_function!(range, Range);
scalar_function!(uuid, Uuid);
scalar_function!(r#struct, Struct); // Use raw identifier since struct is a keyword
scalar_function!(from_unixtime, FromUnixtime);
Expand All @@ -405,6 +406,8 @@ scalar_function!(list_push_back, ArrayAppend);
scalar_function!(array_concat, ArrayConcat);
scalar_function!(array_cat, ArrayConcat);
scalar_function!(array_dims, ArrayDims);
scalar_function!(array_distinct, ArrayDistinct);
scalar_function!(list_distinct, ArrayDistinct);
scalar_function!(list_dims, ArrayDims);
scalar_function!(array_element, ArrayElement);
scalar_function!(array_extract, ArrayElement);
Expand Down Expand Up @@ -444,6 +447,14 @@ scalar_function!(array_replace_all, ArrayReplaceAll);
scalar_function!(list_replace_all, ArrayReplaceAll);
scalar_function!(array_slice, ArraySlice);
scalar_function!(list_slice, ArraySlice);
scalar_function!(array_intersect, ArrayIntersect);
scalar_function!(list_intersect, ArrayIntersect);
scalar_function!(array_union, ArrayUnion);
scalar_function!(list_union, ArrayUnion);
scalar_function!(array_except, ArrayExcept);
scalar_function!(list_except, ArrayExcept);
scalar_function!(array_resize, ArrayResize);
scalar_function!(list_resize, ArrayResize);
scalar_function!(flatten, Flatten);

aggregate_function!(approx_distinct, ApproxDistinct);
Expand Down Expand Up @@ -499,6 +510,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
m.add_wrapped(wrap_pyfunction!(array))?;
m.add_wrapped(wrap_pyfunction!(range))?;
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
Expand Down Expand Up @@ -644,6 +656,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_concat))?;
m.add_wrapped(wrap_pyfunction!(array_cat))?;
m.add_wrapped(wrap_pyfunction!(array_dims))?;
m.add_wrapped(wrap_pyfunction!(array_distinct))?;
m.add_wrapped(wrap_pyfunction!(list_distinct))?;
m.add_wrapped(wrap_pyfunction!(list_dims))?;
m.add_wrapped(wrap_pyfunction!(array_element))?;
m.add_wrapped(wrap_pyfunction!(array_extract))?;
Expand All @@ -661,6 +675,14 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_positions))?;
m.add_wrapped(wrap_pyfunction!(list_positions))?;
m.add_wrapped(wrap_pyfunction!(array_to_string))?;
m.add_wrapped(wrap_pyfunction!(array_intersect))?;
m.add_wrapped(wrap_pyfunction!(list_intersect))?;
m.add_wrapped(wrap_pyfunction!(array_union))?;
m.add_wrapped(wrap_pyfunction!(list_union))?;
m.add_wrapped(wrap_pyfunction!(array_except))?;
m.add_wrapped(wrap_pyfunction!(list_except))?;
m.add_wrapped(wrap_pyfunction!(array_resize))?;
m.add_wrapped(wrap_pyfunction!(list_resize))?;
m.add_wrapped(wrap_pyfunction!(array_join))?;
m.add_wrapped(wrap_pyfunction!(list_to_string))?;
m.add_wrapped(wrap_pyfunction!(list_join))?;
Expand Down

0 comments on commit 6a895c6

Please sign in to comment.