Skip to content

Commit

Permalink
Make run query generic
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Apr 5, 2024
1 parent ca7703e commit 8188896
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 110 deletions.
34 changes: 34 additions & 0 deletions queries/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import sys
from importlib.metadata import version
from pathlib import Path
from subprocess import run
from typing import TYPE_CHECKING, Any
Expand All @@ -11,6 +12,8 @@
from settings import Settings

if TYPE_CHECKING:
from collections.abc import Callable

import pandas as pd
import polars as pl

Expand Down Expand Up @@ -93,6 +96,37 @@ def _get_query_numbers(library_name: str) -> list[int]:
return sorted(query_numbers)


def run_query_generic(
    query: Callable[..., Any],
    query_number: int,
    library_name: str,
    query_checker: Callable[..., None] | None = None,
) -> None:
    """Run *query* and, per the global settings, log timing, check, and show the result.

    Parameters
    ----------
    query
        Zero-argument callable that executes the query and returns its result.
    query_number
        TPC-H query number, used for logging and result checking.
    library_name
        Solution name; also the installed package whose version is logged.
    query_checker
        Callable asserting the result is correct; required only when
        ``settings.run.check_results`` is enabled.

    Raises
    ------
    ValueError
        If result checking is enabled but no checker was provided.
    RuntimeError
        If result checking is enabled at a scale factor other than 1.
    """
    # Only the query callable itself is timed.
    with CodeTimer(name=f"Run {library_name} query {query_number}", unit="s") as code_timer:
        result = query()

    if settings.run.log_timings:
        log_query_timing(
            solution=library_name,
            version=version(library_name),
            query_number=query_number,
            time=code_timer.took,
        )

    if settings.run.check_results:
        if query_checker is None:
            raise ValueError(
                "cannot check results if no query checking function is provided"
            )
        if settings.scale_factor != 1:
            raise RuntimeError(
                f"cannot check results when scale factor is not 1, got {settings.scale_factor}"
            )
        query_checker(result, query_number)

    if settings.run.show_results:
        print(result)


def check_query_result_pl(result: pl.DataFrame, query_number: int) -> None:
"""Assert that the Polars result of the query is correct."""
from polars.testing import assert_frame_equal
Expand Down
27 changes: 6 additions & 21 deletions queries/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

import dask
import dask.dataframe as dd
from linetimer import CodeTimer

from queries.common_utils import check_query_result_pd, log_query_timing, on_second_call
from queries.common_utils import (
check_query_result_pd,
on_second_call,
run_query_generic,
)
from settings import Settings

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,22 +80,4 @@ def get_part_supp_ds() -> DataFrame:


def run_query(query_number: int, query: Callable[..., Any]) -> None:
    """Execute a Dask query through the shared generic runner."""
    run_query_generic(
        query,
        query_number,
        "dask",
        query_checker=check_query_result_pd,
    )
27 changes: 5 additions & 22 deletions queries/duckdb/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from importlib.metadata import version
from pathlib import Path

import duckdb
from duckdb import DuckDBPyRelation
from linetimer import CodeTimer

from queries.common_utils import check_query_result_pl, log_query_timing
from queries.common_utils import check_query_result_pl, run_query_generic
from settings import Settings

settings = Settings()
Expand Down Expand Up @@ -65,22 +63,7 @@ def get_part_supp_ds() -> str:


def run_query(query_number: int, context: DuckDBPyRelation) -> None:
    """Execute a DuckDB query through the shared generic runner."""
    # Passing the bound ``context.pl`` method means the runner times the
    # conversion to Polars, which forces DuckDB to materialize the relation.
    run_query_generic(
        context.pl, query_number, "duckdb", query_checker=check_query_result_pl
    )
31 changes: 9 additions & 22 deletions queries/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from typing import TYPE_CHECKING, Any

import pandas as pd
from linetimer import CodeTimer

from queries.common_utils import check_query_result_pd, log_query_timing, on_second_call
from queries.common_utils import (
check_query_result_pd,
on_second_call,
run_query_generic,
)
from settings import Settings

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,23 +71,7 @@ def get_part_supp_ds() -> pd.DataFrame:
return _read_ds(settings.dataset_base_dir / "partsupp")


def run_query(query_number: int, query: Callable[..., Any]) -> None:
    """Execute a pandas query through the shared generic runner."""
    run_query_generic(
        query,
        query_number,
        "pandas",
        query_checker=check_query_result_pd,
    )
33 changes: 10 additions & 23 deletions queries/polars/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import partial
from pathlib import Path

import polars as pl
from linetimer import CodeTimer

from queries.common_utils import check_query_result_pl, log_query_timing
from queries.common_utils import check_query_result_pl, run_query_generic
from settings import Settings

settings = Settings()
Expand Down Expand Up @@ -56,25 +56,12 @@ def get_part_supp_ds() -> pl.LazyFrame:


def run_query(query_number: int, lf: pl.LazyFrame) -> None:
    """Execute a Polars query through the shared generic runner."""
    streaming = settings.run.polars_streaming

    # Optionally print the query plan before executing it.
    if settings.run.polars_show_plan:
        print(lf.explain(streaming=streaming))

    # The runner times the collect call, i.e. the actual query execution.
    run_query_generic(
        partial(lf.collect, streaming=streaming),
        query_number,
        "polars",
        query_checker=check_query_result_pl,
    )
32 changes: 10 additions & 22 deletions queries/pyspark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from typing import TYPE_CHECKING

from linetimer import CodeTimer
from pyspark.sql import SparkSession

from queries.common_utils import check_query_result_pd, log_query_timing, on_second_call
from queries.common_utils import (
check_query_result_pd,
on_second_call,
run_query_generic,
)
from settings import Settings

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,23 +86,8 @@ def drop_temp_view() -> None:
]


def run_query(query_number: int, df: SparkDF) -> None:
    """Execute a PySpark query through the shared generic runner."""
    # The bound ``df.toPandas`` method is what gets timed, so the measured
    # time includes converting the Spark dataframe to pandas.
    run_query_generic(
        df.toPandas, query_number, "pyspark", query_checker=check_query_result_pd
    )

0 comments on commit 8188896

Please sign in to comment.