Skip to content

Commit

Permalink
feat(datafusion): add some array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Oct 30, 2023
1 parent b4e3f15 commit 0b96b68
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 16 deletions.
3 changes: 2 additions & 1 deletion ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

import pandas as pd

_exclude_exp = (exp.Pow,)
_exclude_exp = (exp.Pow, exp.ArrayContains)


# the DataFusion dialect was created to skip the power function to operator transformation
Expand All @@ -66,6 +66,7 @@ class Backend(BaseBackend, CanCreateDatabase, CanCreateSchema):
dialect = "datafusion"
builder = None
supports_in_memory_tables = True
supports_arrays = True

@property
def version(self):
Expand Down
23 changes: 23 additions & 0 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def translate_val(op, **_):
ops.Degrees: "degrees",
ops.NullIf: "nullif",
ops.Pi: "pi",
ops.ArrayContains: "array_contains",
ops.ArrayLength: "array_length",
ops.ArrayRemove: "array_remove_all",
}

for _op, _name in _simple_ops.items():
Expand Down Expand Up @@ -710,3 +713,23 @@ def _if_else(op, *, bool_expr, true_expr, false_null_expr, **_):
@translate_val.register(ops.NotNull)
def _not_null(op, *, arg, **_):
    """Translate ``ops.NotNull`` into the negation of ``arg.is_(NULL)``.

    ``op`` and any extra keyword arguments are accepted for the
    ``translate_val`` dispatch signature but are not used here.
    """
    null_check = arg.is_(NULL)
    return sg.not_(null_check)


@translate_val.register(ops.ArrayColumn)
def array_column(op, *, cols, **_):
    """Translate ``ops.ArrayColumn`` into a ``make_array`` call.

    Each item of ``cols`` is passed as a separate positional argument to
    ``F.make_array``. ``op`` itself is not inspected.
    """
    elements = tuple(cols)
    return F.make_array(*elements)


@translate_val.register(ops.ArrayRepeat)
def array_repeat(op, *, arg, times, **_):
    """Translate ``ops.ArrayRepeat`` into ``flatten(array_repeat(arg, times))``.

    ``op`` and any extra keyword arguments are ignored.
    """
    # NOTE(review): presumably array_repeat on an array argument yields a
    # nested array, which is why the result is flattened -- confirm against
    # the DataFusion function reference.
    repeated = F.array_repeat(arg, times)
    return F.flatten(repeated)


@translate_val.register(ops.ArrayConcat)
def array_concat(op, *, arg, **_):
    """Translate ``ops.ArrayConcat`` into a variadic ``array_concat`` call.

    ``arg`` is iterated and each item is forwarded as its own positional
    argument to ``F.array_concat``. ``op`` itself is not inspected.
    """
    operands = tuple(arg)
    return F.array_concat(*operands)


@translate_val.register(ops.ArrayPosition)
def array_position(op, *, arg, other, **_):
    """Translate ``ops.ArrayPosition`` into ``coalesce(array_position(arg, other), 0)``.

    ``op`` and any extra keyword arguments are ignored.
    """
    # NOTE(review): presumably array_position yields NULL when ``other`` is
    # absent from ``arg`` and the coalesce maps that case to 0 -- confirm the
    # indexing base and the not-found convention expected by callers.
    position = F.array_position(arg, other)
    return F.coalesce(position, 0)
4 changes: 3 additions & 1 deletion ibis/backends/datafusion/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.data import array_types


class TestConf(BackendTest, RoundAwayFromZero):
Expand All @@ -15,7 +16,7 @@ class TestConf(BackendTest, RoundAwayFromZero):
# returned_timestamp_unit = 'ns'
supports_structs = False
supports_json = False
supports_arrays = False
supports_arrays = True
stateful = False
deps = ("datafusion",)

Expand All @@ -24,6 +25,7 @@ def _load_data(self, **_: Any) -> None:
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.register(path, table_name=table_name)
con.register(array_types, table_name="array_types")

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
Expand Down
46 changes: 32 additions & 14 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
# list.


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_array_column(backend, alltypes, df):
expr = ibis.array([alltypes["double_col"], alltypes["double_col"]])
assert isinstance(expr, ir.ArrayColumn)
Expand Down Expand Up @@ -91,7 +90,7 @@ def test_array_scalar(con, backend):
assert con.execute(expr.typeof()) == ARRAY_BACKEND_TYPES[backend_name]


@pytest.mark.notimpl(["polars", "datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
def test_array_repeat(con):
expr = ibis.array([1.0, 2.0]) * 2

Expand All @@ -102,7 +101,6 @@ def test_array_repeat(con):


# Issues #2370
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_array_concat(con):
left = ibis.literal([1, 2, 3])
right = ibis.literal([2, 1])
Expand All @@ -113,7 +111,6 @@ def test_array_concat(con):


# Issues #2370
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_array_concat_variadic(con):
left = ibis.literal([1, 2, 3])
right = ibis.literal([2, 1])
Expand All @@ -124,7 +121,7 @@ def test_array_concat_variadic(con):


# Issues #2370
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=BaseException)
@pytest.mark.notyet(
["postgres", "trino"],
raises=sa.exc.ProgrammingError,
Expand All @@ -139,7 +136,6 @@ def test_array_concat_some_empty(con):
assert np.array_equal(result, expected)


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_array_radd_concat(con):
left = [1]
right = ibis.literal([2])
Expand All @@ -150,7 +146,6 @@ def test_array_radd_concat(con):
assert np.array_equal(result, expected)


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_array_length(con):
expr = ibis.literal([1, 2, 3]).length()
assert con.execute(expr.name("tmp")) == 3
Expand Down Expand Up @@ -193,13 +188,21 @@ def test_array_index(con, idx):
["sqlite"], reason="array types are unsupported", raises=NotImplementedError
),
# someone just needs to implement these
pytest.mark.notimpl(["datafusion"], raises=Exception),
)


@builtin_array
@pytest.mark.never(
["clickhouse", "duckdb", "pandas", "pyspark", "snowflake", "polars", "trino"],
[
"clickhouse",
"duckdb",
"pandas",
"pyspark",
"snowflake",
"polars",
"trino",
"datafusion",
],
reason="backend does not flatten array types",
raises=AssertionError,
)
Expand Down Expand Up @@ -234,7 +237,16 @@ def test_array_discovery_postgres(backend):
raises=AssertionError,
)
@pytest.mark.never(
["duckdb", "pandas", "postgres", "pyspark", "snowflake", "polars", "trino"],
[
"duckdb",
"pandas",
"postgres",
"pyspark",
"snowflake",
"polars",
"trino",
"datafusion",
],
reason="backend supports nullable nested types",
raises=AssertionError,
)
Expand Down Expand Up @@ -334,6 +346,7 @@ def test_array_discovery_snowflake(backend):
raises=BadRequest,
)
@pytest.mark.notimpl(["dask"], raises=ValueError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_simple(backend):
array_types = backend.array_types
expected = (
Expand All @@ -350,6 +363,7 @@ def test_unnest_simple(backend):

@builtin_array
@pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_complex(backend):
array_types = backend.array_types
df = array_types.execute()
Expand Down Expand Up @@ -389,6 +403,7 @@ def test_unnest_complex(backend):
raises=AssertionError,
)
@pytest.mark.notimpl(["dask"], raises=ValueError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_idempotent(backend):
array_types = backend.array_types
df = array_types.execute()
Expand All @@ -409,6 +424,7 @@ def test_unnest_idempotent(backend):

@builtin_array
@pytest.mark.notimpl("dask", raises=ValueError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_no_nulls(backend):
array_types = backend.array_types
df = array_types.execute()
Expand All @@ -435,6 +451,7 @@ def test_unnest_no_nulls(backend):

@builtin_array
@pytest.mark.notimpl("dask", raises=ValueError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_default_name(backend):
array_types = backend.array_types
df = array_types.execute()
Expand Down Expand Up @@ -561,10 +578,9 @@ def test_array_filter(backend, con, input, output):

@builtin_array
@pytest.mark.notimpl(
["datafusion", "mssql", "pandas", "polars", "postgres"],
["mssql", "pandas", "polars", "postgres"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["datafusion"], raises=Exception)
@pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError)
@pytest.mark.never(["impala"], reason="array_types table isn't defined")
def test_array_contains(backend, con):
Expand All @@ -577,7 +593,7 @@ def test_array_contains(backend, con):

@builtin_array
@pytest.mark.notimpl(
["dask", "datafusion", "impala", "mssql", "pandas", "polars"],
["dask", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
def test_array_position(backend, con):
Expand All @@ -590,7 +606,7 @@ def test_array_position(backend, con):

@builtin_array
@pytest.mark.notimpl(
["dask", "datafusion", "impala", "mssql", "pandas", "polars"],
["dask", "impala", "mssql", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
def test_array_remove(backend, con):
Expand Down Expand Up @@ -708,6 +724,7 @@ def test_array_intersect(con):
reason="ClickHouse won't accept dicts for struct type values",
)
@pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_unnest_struct(con):
data = {"value": [[{"a": 1}, {"a": 2}], [{"a": 3}, {"a": 4}]]}
t = ibis.memtable(data, schema=ibis.schema({"value": "!array<!struct<a: !int>>"}))
Expand Down Expand Up @@ -754,6 +771,7 @@ def test_zip(backend):
reason="https://github.com/ClickHouse/ClickHouse/issues/41112",
)
@pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
["polars"],
raises=com.OperationNotDefinedError,
Expand Down

0 comments on commit 0b96b68

Please sign in to comment.