Minor refactor of data loading (#108)
stinodego authored Apr 12, 2024
1 parent 0c2a359 commit 7b40c5a
Showing 6 changed files with 72 additions and 67 deletions.
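The change is the same across all six files: each backend's reader now takes a bare table name instead of a pre-built path, and derives the full path from `settings.dataset_base_dir` and `settings.run.file_type` internally. A minimal standalone sketch of the resulting pattern (the `Settings` stand-in below is a simplified assumption; the real class lives in `settings.py`):

```python
from pathlib import Path


# Hypothetical stand-in for the repo's Settings object; only the two
# fields the refactored readers use are sketched here.
class RunSettings:
    file_type: str = "parquet"


class Settings:
    dataset_base_dir: Path = Path("data/tables")
    run: RunSettings = RunSettings()


settings = Settings()


def table_path(table_name: str) -> Path:
    # Callers now pass "lineitem" instead of a full path; the directory
    # and the extension both come from settings.
    return settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"


print(table_path("lineitem"))  # data/tables/lineitem.parquet
```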
26 changes: 13 additions & 13 deletions queries/dask/utils.py
@@ -14,7 +14,6 @@

 if TYPE_CHECKING:
     from collections.abc import Callable
-    from pathlib import Path

     from dask.dataframe.core import DataFrame

@@ -23,7 +22,7 @@
 dask.config.set(scheduler="threads")


-def read_ds(path: Path) -> DataFrame:
+def read_ds(table_name: str) -> DataFrame:
     # TODO: Load into memory before returning the Dask DataFrame.
     # Code below is tripped up by date types
     # df = pd.read_parquet(path, dtype_backend="pyarrow")
@@ -32,11 +31,12 @@ def read_ds(path: Path) -> DataFrame:
         msg = "cannot run Dask starting from an in-memory representation"
         raise RuntimeError(msg)

-    path_str = f"{path}.{settings.run.file_type}"
+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"
+
     if settings.run.file_type == "parquet":
-        return dd.read_parquet(path_str, dtype_backend="pyarrow")  # type: ignore[attr-defined,no-any-return]
+        return dd.read_parquet(path, dtype_backend="pyarrow")  # type: ignore[attr-defined,no-any-return]
     elif settings.run.file_type == "csv":
-        df = dd.read_csv(path_str, dtype_backend="pyarrow")  # type: ignore[attr-defined]
+        df = dd.read_csv(path, dtype_backend="pyarrow")  # type: ignore[attr-defined]
         for c in df.columns:
             if c.endswith("date"):
                 df[c] = df[c].astype("date32[day][pyarrow]")
@@ -48,42 +48,42 @@ def read_ds(path: Path) -> DataFrame:

 @on_second_call
 def get_line_item_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "lineitem")
+    return read_ds("lineitem")


 @on_second_call
 def get_orders_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "orders")
+    return read_ds("orders")


 @on_second_call
 def get_customer_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "customer")
+    return read_ds("customer")


 @on_second_call
 def get_region_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "region")
+    return read_ds("region")


 @on_second_call
 def get_nation_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "nation")
+    return read_ds("nation")


 @on_second_call
 def get_supplier_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "supplier")
+    return read_ds("supplier")


 @on_second_call
 def get_part_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "part")
+    return read_ds("part")


 @on_second_call
 def get_part_supp_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "partsupp")
+    return read_ds("partsupp")


 def run_query(query_number: int, query: Callable[..., Any]) -> None:
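For context, a minimal standalone read in the style of the parquet branch above, with the scheduler pinned to threads as the module does. Assumes a Dask version that supports `dtype_backend`; `lineitem.parquet` is a hypothetical local file:

```python
import dask
import dask.dataframe as dd

# Pin Dask to the threaded scheduler, as queries/dask/utils.py does,
# so runs are comparable across queries.
dask.config.set(scheduler="threads")

# Read with pyarrow-backed dtypes; "lineitem.parquet" is a hypothetical file.
df = dd.read_parquet("lineitem.parquet", dtype_backend="pyarrow")
print(df.head())
```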
24 changes: 12 additions & 12 deletions queries/duckdb/utils.py
@@ -1,5 +1,3 @@
-from pathlib import Path
-
 import duckdb
 from duckdb import DuckDBPyRelation

@@ -9,8 +7,10 @@
 settings = Settings()


-def _scan_ds(path: Path) -> str:
-    path_str = f"{path}.{settings.run.file_type}"
+def _scan_ds(table_name: str) -> str:
+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"
+    path_str = str(path)
+
     if settings.run.file_type == "parquet":
         if settings.run.include_io:
             duckdb.read_parquet(path_str)
@@ -41,35 +41,35 @@ def _scan_ds(path: Path) -> str:


 def get_line_item_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "lineitem")
+    return _scan_ds("lineitem")


 def get_orders_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "orders")
+    return _scan_ds("orders")


 def get_customer_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "customer")
+    return _scan_ds("customer")


 def get_region_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "region")
+    return _scan_ds("region")


 def get_nation_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "nation")
+    return _scan_ds("nation")


 def get_supplier_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "supplier")
+    return _scan_ds("supplier")


 def get_part_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "part")
+    return _scan_ds("part")


 def get_part_supp_ds() -> str:
-    return _scan_ds(settings.dataset_base_dir / "partsupp")
+    return _scan_ds("partsupp")


 def run_query(query_number: int, context: DuckDBPyRelation) -> None:
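The extra `path_str = str(path)` line exists because DuckDB's Python readers are typed for string paths. A minimal sketch of that branch (hypothetical local file; the truncated body below the hunk presumably registers the data and returns a table identifier, which this diff doesn't show):

```python
import duckdb
from pathlib import Path

# DuckDB's reader functions take string paths, so a pathlib.Path is
# converted first; "lineitem.parquet" is a hypothetical local file.
path = Path("lineitem.parquet")
rel = duckdb.read_parquet(str(path))
rel.limit(5).show()
```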
32 changes: 19 additions & 13 deletions queries/modin/utils.py
@@ -14,7 +14,6 @@

 if TYPE_CHECKING:
     from collections.abc import Callable
-    from pathlib import Path

 settings = Settings()

@@ -23,55 +22,62 @@
 os.environ["MODIN_MEMORY"] = str(settings.run.modin_memory)


-def _read_ds(path: Path) -> pd.DataFrame:
-    path_str = f"{path}.{settings.run.file_type}"
+def _read_ds(table_name: str) -> pd.DataFrame:
+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"
+
     if settings.run.file_type == "parquet":
-        return pd.read_parquet(path_str, dtype_backend="pyarrow")
+        return pd.read_parquet(path, dtype_backend="pyarrow")
+    elif settings.run.file_type == "csv":
+        df = pd.read_csv(path, dtype_backend="pyarrow")
+        for c in df.columns:
+            if c.endswith("date"):
+                df[c] = df[c].astype("date32[day][pyarrow]")
+        return df
     elif settings.run.file_type == "feather":
-        return pd.read_feather(path_str, dtype_backend="pyarrow")
+        return pd.read_feather(path, dtype_backend="pyarrow")
     else:
         msg = f"unsupported file type: {settings.run.file_type!r}"
         raise ValueError(msg)


 @on_second_call
 def get_line_item_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "lineitem")
+    return _read_ds("lineitem")


 @on_second_call
 def get_orders_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "orders")
+    return _read_ds("orders")


 @on_second_call
 def get_customer_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "customer")
+    return _read_ds("customer")


 @on_second_call
 def get_region_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "region")
+    return _read_ds("region")


 @on_second_call
 def get_nation_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "nation")
+    return _read_ds("nation")


 @on_second_call
 def get_supplier_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "supplier")
+    return _read_ds("supplier")


 @on_second_call
 def get_part_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "part")
+    return _read_ds("part")


 @on_second_call
 def get_part_supp_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "partsupp")
+    return _read_ds("partsupp")


 def run_query(query_number: int, query: Callable[..., Any]) -> None:
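Modin picks up its configuration from environment variables, which is why this module sets `MODIN_MEMORY` at the top, before anything imports Modin. A minimal sketch of that ordering (the engine choice is a hypothetical example, not something this diff sets):

```python
import os

# Configuration must be in the environment before the first Modin import.
os.environ["MODIN_ENGINE"] = "ray"  # hypothetical engine choice
os.environ["MODIN_MEMORY"] = str(8_000_000_000)  # matches the new default

import modin.pandas as pd  # noqa: E402  (deliberately after env setup)

print(pd.DataFrame({"a": [1, 2, 3]}).sum())
```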
28 changes: 14 additions & 14 deletions queries/pandas/utils.py
@@ -13,68 +13,68 @@

 if TYPE_CHECKING:
     from collections.abc import Callable
-    from pathlib import Path

 settings = Settings()

 pd.options.mode.copy_on_write = True


-def _read_ds(path: Path) -> pd.DataFrame:
-    path_str = f"{path}.{settings.run.file_type}"
+def _read_ds(table_name: str) -> pd.DataFrame:
+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"
+
     if settings.run.file_type == "parquet":
-        return pd.read_parquet(path_str, dtype_backend="pyarrow")
+        return pd.read_parquet(path, dtype_backend="pyarrow")
     elif settings.run.file_type == "csv":
-        df = pd.read_csv(path_str, dtype_backend="pyarrow")
+        df = pd.read_csv(path, dtype_backend="pyarrow")
         for c in df.columns:
             if c.endswith("date"):
                 df[c] = df[c].astype("date32[day][pyarrow]")  # type: ignore[call-overload]
         return df
     elif settings.run.file_type == "feather":
-        return pd.read_feather(path_str, dtype_backend="pyarrow")
+        return pd.read_feather(path, dtype_backend="pyarrow")
     else:
         msg = f"unsupported file type: {settings.run.file_type!r}"
         raise ValueError(msg)


 @on_second_call
 def get_line_item_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "lineitem")
+    return _read_ds("lineitem")


 @on_second_call
 def get_orders_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "orders")
+    return _read_ds("orders")


 @on_second_call
 def get_customer_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "customer")
+    return _read_ds("customer")


 @on_second_call
 def get_region_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "region")
+    return _read_ds("region")


 @on_second_call
 def get_nation_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "nation")
+    return _read_ds("nation")


 @on_second_call
 def get_supplier_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "supplier")
+    return _read_ds("supplier")


 @on_second_call
 def get_part_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "part")
+    return _read_ds("part")


 @on_second_call
 def get_part_supp_ds() -> pd.DataFrame:
-    return _read_ds(settings.dataset_base_dir / "partsupp")
+    return _read_ds("partsupp")


 def run_query(query_number: int, query: Callable[..., Any]) -> None:
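The CSV branch narrows any column whose name ends in `date` to an Arrow-backed date type, since CSV carries no date dtype of its own. A small demonstration of that cast (hypothetical data; assumes pandas 2.x with pyarrow installed):

```python
from datetime import date

import pandas as pd

# Hypothetical frame standing in for a CSV-loaded table; the benchmark
# casts *date columns to Arrow's day-precision date type.
df = pd.DataFrame({"o_orderdate": [date(1996, 1, 2), date(1996, 12, 1)]})
df["o_orderdate"] = df["o_orderdate"].astype("date32[day][pyarrow]")
print(df.dtypes)  # o_orderdate: date32[day][pyarrow]
```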
27 changes: 13 additions & 14 deletions queries/polars/utils.py
@@ -1,6 +1,5 @@
 import os
 from functools import partial
-from pathlib import Path

 import polars as pl

@@ -14,15 +13,15 @@
 )


-def _scan_ds(path: Path) -> pl.LazyFrame:
-    path_str = f"{path}.{settings.run.file_type}"
+def _scan_ds(table_name: str) -> pl.LazyFrame:
+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"

     if settings.run.file_type == "parquet":
-        scan = pl.scan_parquet(path_str)
+        scan = pl.scan_parquet(path)
     elif settings.run.file_type == "feather":
-        scan = pl.scan_ipc(path_str)
+        scan = pl.scan_ipc(path)
     elif settings.run.file_type == "csv":
-        scan = pl.scan_csv(path_str, try_parse_dates=True)
+        scan = pl.scan_csv(path, try_parse_dates=True)
     else:
         msg = f"unsupported file type: {settings.run.file_type!r}"
         raise ValueError(msg)
@@ -34,35 +33,35 @@ def _scan_ds(path: Path) -> pl.LazyFrame:


 def get_line_item_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "lineitem")
+    return _scan_ds("lineitem")


 def get_orders_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "orders")
+    return _scan_ds("orders")


 def get_customer_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "customer")
+    return _scan_ds("customer")


 def get_region_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "region")
+    return _scan_ds("region")


 def get_nation_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "nation")
+    return _scan_ds("nation")


 def get_supplier_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "supplier")
+    return _scan_ds("supplier")


 def get_part_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "part")
+    return _scan_ds("part")


 def get_part_supp_ds() -> pl.LazyFrame:
-    return _scan_ds(settings.dataset_base_dir / "partsupp")
+    return _scan_ds("partsupp")


 def run_query(query_number: int, lf: pl.LazyFrame) -> None:
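Unlike the eager readers above, the Polars helper returns a `LazyFrame`, so nothing is read until a query collects. A minimal sketch (hypothetical local file; assumes a Polars version whose `scan_*` functions accept `pathlib.Path`, which the diff itself relies on):

```python
from pathlib import Path

import polars as pl

# scan_parquet builds a lazy plan; no IO happens until .collect().
# "lineitem.parquet" is a hypothetical local file.
lf = pl.scan_parquet(Path("lineitem.parquet"))
print(lf.select(pl.len()).collect())
```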
2 changes: 1 addition & 1 deletion settings.py
@@ -33,7 +33,7 @@ class Run(BaseSettings):
     polars_streaming: bool = False
     polars_streaming_groupby: bool = False

-    modin_memory: int = 16_000_000_000  # Tune as needed for optimal performance
+    modin_memory: int = 8_000_000_000  # Tune as needed for optimal performance

     spark_driver_memory: str = "2g"  # Tune as needed for optimal performance
     spark_executor_memory: str = "1g"  # Tune as needed for optimal performance
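The default Modin memory budget halves from 16 GB to 8 GB. Since `Run` subclasses pydantic's `BaseSettings`, the value should still be overridable per machine without editing the file; a sketch under the assumption of pydantic-settings' default (unprefixed, case-insensitive) environment binding, which this diff doesn't show:

```python
from pydantic_settings import BaseSettings


class Run(BaseSettings):
    # Mirrors the new default above; tune as needed for optimal performance.
    modin_memory: int = 8_000_000_000


# Hypothetical override, assuming default env binding:
#   MODIN_MEMORY=16000000000 python run_benchmark.py
print(Run().modin_memory)
```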
