Optimize UDF with parallel execution (#713)
---------
Co-authored-by: skshetry <[email protected]>
dreadatour authored Dec 26, 2024
1 parent 46aa4ad commit b1ce093
Showing 8 changed files with 234 additions and 150 deletions.
src/datachain/data_storage/warehouse.py (1 change: 0 additions & 1 deletion)

@@ -216,7 +216,6 @@ def dataset_select_paginated(
         limit = query._limit
         paginated_query = query.limit(page_size)
 
-        results = None
         offset = 0
         num_yielded = 0
 
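The deleted `results = None` was evidently dead code: the pagination loop below it never reads it. A minimal, runnable sketch of the limit/offset pattern this method uses (illustrative names, a plain sequence stands in for SQL pages, not DataChain's real API):

```python
# Limit/offset pagination in miniature: fetch a page, yield it, and stop on a
# short page. No intermediate "results" accumulator is needed.
from collections.abc import Iterator, Sequence


def paginate(rows: Sequence[int], page_size: int) -> Iterator[int]:
    offset = 0
    while True:
        page = rows[offset : offset + page_size]  # stands in for SELECT ... LIMIT/OFFSET
        yield from page
        if len(page) < page_size:  # short page: no more rows
            return
        offset += page_size


print(list(paginate(range(7), page_size=3)))  # [0, 1, 2, 3, 4, 5, 6]
```
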
src/datachain/lib/udf.py (1 change: 0 additions & 1 deletion)

@@ -85,7 +85,6 @@ def run(
         udf_fields: "Sequence[str]",
         udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
-        is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
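
`run` drops the `is_generator` flag. A plausible reading (an assumption, not stated in the diff): whether a UDF produces rows generator-style is a property of the UDF object itself, so callers need not thread a flag through alongside it. A toy illustration with invented class names:

```python
# Hypothetical sketch: the behavior flag lives on the UDF class, so run() can
# consult self instead of taking an is_generator parameter.
class MapperUDF:
    is_generator = False


class GeneratorUDF:
    is_generator = True


for udf in (MapperUDF(), GeneratorUDF()):
    print(type(udf).__name__, "generator:", udf.is_generator)
```
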
src/datachain/query/batch.py (38 changes: 32 additions & 6 deletions)

@@ -7,6 +7,7 @@
 
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+from datachain.query.utils import get_query_column, get_query_id_column
 
 if TYPE_CHECKING:
     from sqlalchemy import Select
@@ -23,11 +24,14 @@ class RowsOutputBatch:
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
 
+    is_batching: bool
+
     @abstractmethod
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
 
@@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy):
     batch UDF calls.
     """
 
+    is_batching = False
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[Sequence, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
         return execute(query)
 
 
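The new `ids_only` flag narrows the SELECT to the id column, so that when work is handed to parallel workers only row ids travel through the batching layer and full rows can be fetched later. A hedged sketch of the mechanics; `get_query_id_column`'s real implementation lives in `datachain.query.utils`, and the version below is an assumption for illustration:

```python
# Narrow a SELECT to its "sys__id" column with SQLAlchemy's with_only_columns.
import sqlalchemy as sa


def get_query_id_column(query: sa.Select) -> sa.ColumnElement:
    # assumed implementation: pick the id column out of the selected columns
    return next(c for c in query.selected_columns if c.name == "sys__id")


tbl = sa.table("rows", sa.column("sys__id"), sa.column("payload"))
query = sa.select(tbl.c.sys__id, tbl.c.payload)
ids_query = query.with_only_columns(get_query_id_column(query))
print(ids_query)  # SELECT rows.sys__id FROM rows (modulo whitespace)
```
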
@@ -52,14 +61,20 @@ class Batch(BatchingStrategy):
     is passed a sequence of multiple parameter sets.
     """
 
+    is_batching = True
+
     def __init__(self, count: int):
         self.count = count
 
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
+
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
 
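Rounding the page size up to a multiple of `self.count` means every fetched page divides evenly into UDF batches, so no partial batch has to be carried across page boundaries. A worked example (the `SELECT_BATCH_SIZE` value is assumed for illustration; the real constant lives in warehouse.py):

```python
import math

SELECT_BATCH_SIZE = 2000  # assumed value for illustration

for count in (10, 300, 7):
    page_size = math.ceil(SELECT_BATCH_SIZE / count) * count
    print(count, page_size, page_size % count)
# 10 -> 2000, 300 -> 2100, 7 -> 2002; the remainder is always 0
```
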
@@ -84,19 +99,30 @@ class Partition(BatchingStrategy):
     Dataset rows need to be sorted by the grouping column.
     """
 
+    is_batching = True
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        id_col = get_query_id_column(query)
+        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+            raise RuntimeError("partition column not found in query")
+
+        if ids_only:
+            query = query.with_only_columns(id_col, partition_col)
+
         current_partition: Optional[int] = None
         batch: list[Sequence] = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
         id_column_idx = query_fields.index("sys__id")
         partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)
 
         ordered_query = query.order_by(None).order_by(
-            PARTITION_COLUMN_ID,
+            partition_col,
             *query._order_by_clauses,
         )
 
@@ -108,7 +134,7 @@ def __call__(
                 if len(batch) > 0:
                     yield RowsOutputBatch(batch)
                     batch = []
-                batch.append(row)
+                batch.append([row[id_column_idx]] if ids_only else row)
 
         if len(batch) > 0:
             yield RowsOutputBatch(batch)
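
Partition batching relies on rows arriving ordered by partition id (hence the `order_by` rewrite above): one linear pass can then cut a batch at every change of partition. The same idea with `itertools.groupby`, plain tuples standing in for SQL result rows:

```python
from itertools import groupby
from operator import itemgetter

rows = [(1, "a"), (1, "b"), (2, "c"), (3, "d"), (3, "e")]  # (partition_id, value)

for partition, group in groupby(rows, key=itemgetter(0)):
    print(partition, [value for _, value in group])
# 1 ['a', 'b']
# 2 ['c']
# 3 ['d', 'e']
```
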
src/datachain/query/dataset.py (30 changes: 14 additions & 16 deletions)

@@ -43,8 +43,9 @@
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
+from datachain.query.schema import C, UDFParamSpec, normalize_param
+from datachain.query.session import Session
 from datachain.sql.functions.random import rand
 from datachain.utils import (
     batched,
@@ -53,9 +54,6 @@
     get_datachain_executable,
 )
 
-from .schema import C, UDFParamSpec, normalize_param
-from .session import Session
-
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import Table
@@ -65,7 +63,8 @@
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.lib.udf import UDFResult
+    from datachain.lib.udf import UDFAdapter, UDFResult
+    from datachain.query.udf import UdfInfo
 
 P = ParamSpec("P")
 
@@ -301,7 +300,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -322,7 +321,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFAdapter,
+    udf: "UDFAdapter",
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -347,6 +346,8 @@ def process_udf_outputs(
     for row_chunk in batched(rows, batch_size):
         warehouse.insert_rows(udf_table, row_chunk)
 
+    warehouse.insert_rows_done(udf_table)
+
 
 def get_download_callback() -> Callback:
     return CombinedDownloadCallback(
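
Moving `insert_rows_done` into `process_udf_outputs` (it is deleted from `populate_udf_table` further down) keeps the final flush next to the inserts it completes. The shape of the pattern, with warehouse calls stubbed out as prints and a `batched` equivalent in spirit to `datachain.utils.batched` (or `itertools.batched` on 3.12+):

```python
from itertools import islice


def batched(iterable, n):
    # yield fixed-size tuples until the input runs out
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


def insert_all(rows, batch_size=3):
    for chunk in batched(rows, batch_size):
        print("insert_rows:", chunk)  # warehouse.insert_rows(udf_table, chunk)
    print("insert_rows_done")  # warehouse.insert_rows_done(udf_table): flush once, at the end


insert_all(range(7))
```
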
@@ -366,7 +367,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFAdapter
+    udf: "UDFAdapter"
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -440,7 +441,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
                 raise RuntimeError(
                     "In-memory databases cannot be used with parallel processing."
                 )
-            udf_info = {
+            udf_info: UdfInfo = {
                 "udf_data": filtered_cloudpickle_dumps(self.udf),
                 "catalog_init": self.catalog.get_init_params(),
                 "metastore_clone_params": self.catalog.metastore.clone_params(),
@@ -464,8 +465,8 @@
 
             with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process:  # noqa: S603
                 process.communicate(process_data)
-                if process.poll():
-                    raise RuntimeError("UDF Execution Failed!")
+                if retval := process.poll():
+                    raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
         else:
             # Otherwise process single-threaded (faster for smaller UDFs)
             warehouse = self.catalog.warehouse
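
The walrus operator both tests and captures the child's exit status, so the error message can report the code instead of a bare failure. A standalone equivalent with a deliberately failing child process:

```python
import subprocess
import sys

with subprocess.Popen([sys.executable, "-c", "raise SystemExit(3)"]) as process:
    process.communicate()
    if retval := process.poll():  # a non-zero exit status is truthy
        print(f"UDF Execution Failed! Exit code: {retval}")  # -> Exit code: 3
```
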
@@ -479,7 +480,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
                     udf_fields,
                     udf_inputs,
                     self.catalog,
-                    self.is_generator,
                     self.cache,
                     download_cb,
                     processed_cb,
@@ -496,8 +496,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
                     processed_cb.close()
                     generated_cb.close()
 
-                warehouse.insert_rows_done(udf_table)
-
             except QueryScriptCancelError:
                 self.catalog.warehouse.close()
                 sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE)
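
The `populate_udf_table` changes above sketch the parallel path: the UDF and its context are pickled into the `UdfInfo` mapping and piped over stdin to a dispatcher subprocess. A self-contained imitation of that handshake; the payload keys and the inline dispatcher are stand-ins, not DataChain's real ones:

```python
import pickle
import subprocess
import sys

udf_info = {
    "udf_data": b"<pickled UDF bytes>",  # stands in for filtered_cloudpickle_dumps(self.udf)
    "processes": 4,
}
dispatcher = [
    sys.executable,
    "-c",
    "import pickle, sys; info = pickle.load(sys.stdin.buffer); "
    "print('dispatcher got:', sorted(info))",
]
with subprocess.Popen(dispatcher, stdin=subprocess.PIPE) as process:
    process.communicate(pickle.dumps(udf_info))  # send the payload over stdin
    if retval := process.poll():
        raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}")
```
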
@@ -1491,7 +1489,7 @@ def chunk(self, index: int, total: int) -> "Self":
     @detach
     def add_signals(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1535,7 +1533,7 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self":
     @detach
     def generate(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
[Diffs for the remaining 4 changed files were not loaded.]
