Skip to content

Commit

Permalink
fix: reorganize parity tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Aug 28, 2024
1 parent abaa9d0 commit f2ac381
Showing 1 changed file with 126 additions and 99 deletions.
225 changes: 126 additions & 99 deletions ibis_substrait/tests/compiler/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,154 +79,181 @@ def datafusion_consumer():
return DatafusionSubstraitConsumer().with_tables(datasets)


all_consumers = ["acero_consumer", "datafusion_consumer"]
def run_parity_test(consumer: SubstraitConsumer, expr):
res_duckdb = sort_pyarrow_table(run_query_duckdb(expr, datasets))

compiler = SubstraitCompiler()

@pytest.fixture
def ibis_projection():
return orders["order_id", "order_total"]
res_compare = sort_pyarrow_table(consumer.execute(compiler.compile(expr)))

assert res_compare.equals(res_duckdb)

@pytest.fixture
def ibis_mutate():
return orders.mutate(order_total_plus_1=orders["order_total"] + 1)

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_projection(consumer: str, request):
expr = orders["order_id", "order_total"]
run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_sort():
return orders.order_by("order_total")

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_mutate(consumer: str, request):
expr = orders.mutate(order_total_plus_1=orders["order_total"] + 1)
run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_sort_limit():
return orders.order_by("order_total").limit(2)

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_sort(consumer: str, request):
expr = orders.order_by("order_total")
run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_filter():
return orders.filter(lambda t: t.order_total > 30)

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_sort_limit(consumer: str, request):
expr = orders.order_by("order_total").limit(2)
run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_inner_join():
return orders.join(stores, orders["fk_store_id"] == stores["store_id"])

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_filter(consumer: str, request):
expr = orders.filter(lambda t: t.order_total > 30)
run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_left_join():
return orders.join(stores, orders["fk_store_id"] == stores["store_id"], how="left")

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_inner_join(consumer: str, request):
expr = orders.join(stores, orders["fk_store_id"] == stores["store_id"])
run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_filter_groupby():

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_left_join(consumer: str, request):
expr = orders.join(stores, orders["fk_store_id"] == stores["store_id"], how="left")
run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.mark.parametrize(
"consumer",
[
"acero_consumer",
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="")],
),
],
)
def test_filter_groupby(consumer: str, request):
filter_table = orders.join(
stores, orders["fk_store_id"] == stores["store_id"]
).filter(lambda t: t.order_total > 30)

return filter_table.group_by("city").aggregate(
expr = filter_table.group_by("city").aggregate(
sales=filter_table["order_id"].count()
)

run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_filter_groupby_count_distinct():

@pytest.mark.parametrize(
"consumer",
[
pytest.param(
"acero_consumer",
marks=[
pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented")
],
),
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="")],
),
],
)
def test_filter_groupby_count_distinct(consumer: str, request):
filter_table = orders.join(
stores, orders["fk_store_id"] == stores["store_id"]
).filter(lambda t: t.order_total > 30)

return filter_table.group_by("city").aggregate(sales=filter_table["city"].nunique())
expr = filter_table.group_by("city").aggregate(sales=filter_table["city"].nunique())

run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_aggregate_having():
return orders.aggregate(

@pytest.mark.parametrize(
"consumer",
[
"acero_consumer",
pytest.param(
"datafusion_consumer",
marks=[pytest.mark.xfail(Exception, reason="")],
),
],
)
def test_aggregate_having(consumer: str, request):
expr = orders.aggregate(
[orders.order_id.max().name("amax"), orders.order_id.count().name("acount")],
by="fk_store_id",
having=(_.order_id.count() > 1),
)


@pytest.fixture
def ibis_inner_join_chain():
return orders.join(stores, orders["fk_store_id"] == stores["store_id"]).join(
customers, orders["fk_customer_id"] == customers["customer_id"]
)


@pytest.fixture
def ibis_union():
return orders.union(orders)
run_parity_test(request.getfixturevalue(consumer), expr)


@pytest.fixture
def ibis_window():
return orders.select(
orders["order_total"].mean().over(ibis.window(group_by="fk_store_id"))
@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_inner_join_chain(consumer: str, request):
expr = orders.join(stores, orders["fk_store_id"] == stores["store_id"]).join(
customers, orders["fk_customer_id"] == customers["customer_id"]
)

run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.fixture
def ibis_is_in():
return stores.filter(stores.city.isin(["NY", "LA"]))

@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_union(consumer: str, request):
expr = orders.union(orders)

@pytest.fixture
def ibis_scalar_subquery():
return orders.filter(orders["order_total"] == orders["order_total"].max())


all_exprs = [
"ibis_projection",
"ibis_mutate",
"ibis_sort",
"ibis_sort_limit",
"ibis_filter",
"ibis_inner_join",
"ibis_left_join",
("ibis_filter_groupby", {"datafusion_consumer": (Exception, "")}),
(
"ibis_filter_groupby_count_distinct",
{
"acero_consumer": (pa.ArrowNotImplementedError, "Unimplemented"),
"datafusion_consumer": (Exception, ""),
},
),
("ibis_aggregate_having", {"datafusion_consumer": (Exception, "")}),
"ibis_inner_join_chain",
"ibis_union",
("ibis_window", {"acero_consumer": (pa.ArrowNotImplementedError, "Unimplemented")}),
"ibis_is_in",
(
"ibis_scalar_subquery",
{"acero_consumer": (pa.ArrowNotImplementedError, "Unimplemented")},
),
]
run_parity_test(request.getfixturevalue(consumer), expr)


all_fixtures = [
pytest.param(
c,
e[0] if isinstance(e, tuple) else e,
marks=(
[pytest.mark.xfail(raises=e[1][c][0], reason=e[1][c][1])]
if isinstance(e, tuple) and c in e[1]
else []
@pytest.mark.parametrize(
"consumer",
[
pytest.param(
"acero_consumer",
marks=[
pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented")
],
),
"datafusion_consumer",
],
)
def test_window(consumer: str, request):
expr = orders.select(
orders["order_total"].mean().over(ibis.window(group_by="fk_store_id"))
)
for e in all_exprs
for c in all_consumers
]

run_parity_test(request.getfixturevalue(consumer), expr)

@pytest.mark.parametrize(("consumer", "expr"), all_fixtures)
def test_parity(consumer: str, expr, request):
consumer: SubstraitConsumer = request.getfixturevalue(consumer)
expr = request.getfixturevalue(expr)

res_duckdb = sort_pyarrow_table(run_query_duckdb(expr, datasets))
@pytest.mark.parametrize("consumer", ["acero_consumer", "datafusion_consumer"])
def test_is_in(consumer: str, request):
expr = stores.filter(stores.city.isin(["NY", "LA"]))

compiler = SubstraitCompiler()
run_parity_test(request.getfixturevalue(consumer), expr)

res_compare = sort_pyarrow_table(consumer.execute(compiler.compile(expr)))

assert res_compare.equals(res_duckdb)
@pytest.mark.parametrize(
"consumer",
[
pytest.param(
"acero_consumer",
marks=[
pytest.mark.xfail(pa.ArrowNotImplementedError, reason="Unimplemented")
],
),
"datafusion_consumer",
],
)
def test_scalar_subquery(consumer: str, request):
expr = orders.filter(orders["order_total"] == orders["order_total"].max())

run_parity_test(request.getfixturevalue(consumer), expr)

0 comments on commit f2ac381

Please sign in to comment.