Skip to content

Commit

Permalink
test: add tests for some arithmetic scalar functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Sep 20, 2024
1 parent 284cf88 commit d02537e
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 32 deletions.
59 changes: 57 additions & 2 deletions ibis_substrait/tests/compiler/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def run_query_duckdb(query, datasets):
}


@pytest.fixture
@pytest.fixture(scope="session")
def acero_consumer():
return AceroSubstraitConsumer().with_tables(datasets)


@pytest.fixture
@pytest.fixture(scope="session")
def datafusion_consumer():
return DatafusionSubstraitConsumer().with_tables(datasets)

Expand Down Expand Up @@ -288,3 +288,58 @@ def test_scalar_subquery(consumer: str, request):
expr = orders.filter(orders["order_total"] == orders["order_total"].max())

run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_projection_functions_arithmetic(consumer: str, request):
expr = orders.select(
orders["order_id"] + orders["order_total"],
orders["order_id"] - orders["order_total"],
orders["order_id"] * orders["order_total"],
orders["order_id"] / orders["order_total"],
orders["order_total"] ** ibis.literal(2),
orders["order_total"].sqrt(),
orders["order_total"].exp(),
(orders["order_id"] * 10 - orders["order_total"]).abs(),
)
run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.mark.parametrize(
"consumer",
[
"acero_consumer",
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="NotImplemented")],
),
],
)
def test_projection_functions_arithmetic_negation(consumer: str, request):
expr = orders.select(orders["order_total"].negate())
run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.mark.parametrize(
"consumer",
[
pytest.param(
"acero_consumer",
marks=[
pytest.mark.xfail(
pa.ArrowNotImplementedError,
reason="No conversion function exists to convert the Substrait function",
)
],
),
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="NotImplemented")],
),
],
)
def test_projection_functions_arithmetic_modulus(consumer: str, request):
expr = orders.select(
orders["order_id"] % orders["fk_store_id"],
)
run_parity_test(request.getfixturevalue(consumer), expr)
73 changes: 43 additions & 30 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d02537e

Please sign in to comment.