Skip to content

Commit

Permalink
fix: datafusion aggregations in parity tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Sep 5, 2024
1 parent c6e10bd commit ac5c557
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions ibis_substrait/tests/compiler/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,37 @@ def test_left_join(consumer: str, request):

@pytest.mark.parametrize(
"consumer",
[
"acero_consumer",
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="")],
),
],
["acero_consumer", "datafusion_consumer"],
)
def test_filter_groupby(consumer: str, request):
filter_table = orders.join(
stores, orders["fk_store_id"] == stores["store_id"]
).filter(lambda t: t.order_total > 30)

expr = filter_table.group_by("city").aggregate(
sales=filter_table["order_id"].count()
expr = (
filter_table.group_by("city")
.aggregate(sales=filter_table["order_id"].count())
.filter(ibis.literal(True))
)

run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.mark.parametrize(
"consumer",
[
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="")],
),
],
)
def test_groupby_datafusion(consumer: str, request):
expr = orders.group_by("fk_store_id").aggregate(sales=orders["order_id"].count())

run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.mark.parametrize(
"consumer",
[
Expand All @@ -187,18 +198,19 @@ def test_filter_groupby(consumer: str, request):
pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented")
],
),
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="")],
),
"datafusion_consumer",
],
)
def test_filter_groupby_count_distinct(consumer: str, request):
filter_table = orders.join(
stores, orders["fk_store_id"] == stores["store_id"]
).filter(lambda t: t.order_total > 30)

expr = filter_table.group_by("city").aggregate(sales=filter_table["city"].nunique())
expr = (
filter_table.group_by("city")
.aggregate(sales=filter_table["city"].nunique())
.filter(ibis.literal(True))
)

run_parity_test(request.getfixturevalue(consumer), expr)

Expand Down

0 comments on commit ac5c557

Please sign in to comment.