From 5953afe9a82ffab111c4934e0ec598439a7e8440 Mon Sep 17 00:00:00 2001 From: Vladimir Rudnyh Date: Sat, 14 Dec 2024 01:11:33 +0700 Subject: [PATCH 1/3] Optimize UDF with parallel execution --- src/datachain/lib/udf.py | 1 - src/datachain/query/batch.py | 39 +++++-- src/datachain/query/dataset.py | 14 ++- src/datachain/query/dispatch.py | 178 ++++++++++++++++++-------------- src/datachain/query/utils.py | 42 ++++++++ src/datachain/utils.py | 2 +- 6 files changed, 183 insertions(+), 93 deletions(-) create mode 100644 src/datachain/query/utils.py diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py index d708c0330..c59442d6b 100644 --- a/src/datachain/lib/udf.py +++ b/src/datachain/lib/udf.py @@ -85,7 +85,6 @@ def run( udf_fields: "Sequence[str]", udf_inputs: "Iterable[RowsOutput]", catalog: "Catalog", - is_generator: bool, cache: bool, download_cb: Callback = DEFAULT_CALLBACK, processed_cb: Callback = DEFAULT_CALLBACK, diff --git a/src/datachain/query/batch.py b/src/datachain/query/batch.py index 8f24ec895..6be29b1fe 100644 --- a/src/datachain/query/batch.py +++ b/src/datachain/query/batch.py @@ -7,6 +7,7 @@ from datachain.data_storage.schema import PARTITION_COLUMN_ID from datachain.data_storage.warehouse import SELECT_BATCH_SIZE +from datachain.query.utils import get_query_column, get_query_id_column if TYPE_CHECKING: from sqlalchemy import Select @@ -23,11 +24,14 @@ class RowsOutputBatch: class BatchingStrategy(ABC): """BatchingStrategy provides means of batching UDF executions.""" + is_batching: bool + @abstractmethod def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[RowsOutput, None, None]: """Apply the provided parameters to the UDF.""" @@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy): batch UDF calls. """ + is_batching = False + def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[Sequence, None, None]: + if ids_only: + query = query.with_only_columns(get_query_id_column(query)) return execute(query) @@ -52,14 +61,20 @@ class Batch(BatchingStrategy): is passed a sequence of multiple parameter sets. """ + is_batching = True + def __init__(self, count: int): self.count = count def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[RowsOutputBatch, None, None]: + if ids_only: + query = query.with_only_columns(get_query_id_column(query)) + # choose page size that is a multiple of the batch size page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count @@ -84,19 +99,31 @@ class Partition(BatchingStrategy): Dataset rows need to be sorted by the grouping column. """ + is_batching = True + def __call__( self, - execute: Callable[..., Generator[Sequence, None, None]], + execute: Callable, query: "Select", + ids_only: bool = False, ) -> Generator[RowsOutputBatch, None, None]: + id_col = get_query_id_column(query) + if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None: + raise RuntimeError("partition column not found in query") + + if ids_only: + query = query.with_only_columns(id_col, partition_col) + current_partition: Optional[int] = None batch: list[Sequence] = [] query_fields = [str(c.name) for c in query.selected_columns] + # query_fields = [column_name(col) for col in query.inner_columns] + id_column_idx = query_fields.index("sys__id") partition_column_idx = query_fields.index(PARTITION_COLUMN_ID) ordered_query = query.order_by(None).order_by( - PARTITION_COLUMN_ID, + partition_col, *query._order_by_clauses, ) @@ -108,7 +135,7 @@ def __call__( if len(batch) > 0: yield RowsOutputBatch(batch) batch = [] - batch.append(row) + batch.append([row[id_column_idx]] if ids_only else row) if len(batch) > 0: yield RowsOutputBatch(batch) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 46058ba83..b4294ac7d 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -44,7 +44,6 @@ from datachain.dataset import DatasetStatus, RowDict from datachain.error import DatasetNotFoundError, QueryScriptCancelError from datachain.func.base import Function -from datachain.lib.udf import UDFAdapter from datachain.progress import CombinedDownloadCallback from datachain.sql.functions.random import rand from datachain.utils import ( @@ -66,7 +65,7 @@ from datachain.catalog import Catalog from datachain.data_storage import AbstractWarehouse from datachain.dataset import DatasetRecord - from datachain.lib.udf import UDFResult + from datachain.lib.udf import UDFAdapter, UDFResult P = ParamSpec("P") @@ -302,7 +301,7 @@ def adjust_outputs( return row -def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]: +def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]: """Optimization: Precompute UDF column types so these don't have to be computed in the convert_type function for each row in a loop.""" dialect = warehouse.db.dialect @@ -323,7 +322,7 @@ def process_udf_outputs( warehouse: "AbstractWarehouse", udf_table: "Table", udf_results: Iterator[Iterable["UDFResult"]], - udf: UDFAdapter, + udf: "UDFAdapter", batch_size: int = INSERT_BATCH_SIZE, cb: Callback = DEFAULT_CALLBACK, ) -> None: @@ -367,7 +366,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback: @frozen class UDFStep(Step, ABC): - udf: UDFAdapter + udf: "UDFAdapter" catalog: "Catalog" partition_by: Optional[PartitionByType] = None parallel: Optional[int] = None @@ -478,7 +477,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: udf_fields, udf_inputs, self.catalog, - self.is_generator, self.cache, download_cb, processed_cb, @@ -1487,7 +1485,7 @@ def chunk(self, index: int, total: int) -> "Self": @detach def add_signals( self, - udf: UDFAdapter, + udf: "UDFAdapter", parallel: Optional[int] = None, workers: Union[bool, int] = False, min_task_size: Optional[int] = None, @@ -1531,7 +1529,7 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self": @detach def generate( self, - udf: UDFAdapter, + udf: "UDFAdapter", parallel: Optional[int] = None, workers: Union[bool, int] = False, min_task_size: Optional[int] = None, diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py index 5392cf491..2d85fe551 100644 --- a/src/datachain/query/dispatch.py +++ b/src/datachain/query/dispatch.py @@ -1,9 +1,10 @@ import contextlib -from collections.abc import Iterator, Sequence +from collections.abc import Sequence from itertools import chain from multiprocessing import cpu_count from sys import stdin -from typing import Optional +from threading import Timer +from typing import TYPE_CHECKING, Optional import attrs import multiprocess @@ -13,22 +14,23 @@ from datachain.catalog import Catalog from datachain.catalog.loader import get_distributed_class -from datachain.lib.udf import UDFAdapter, UDFResult +from datachain.query.batch import RowsOutputBatch from datachain.query.dataset import ( get_download_callback, get_generated_callback, get_processed_callback, process_udf_outputs, ) -from datachain.query.queue import ( - get_from_queue, - marshal, - msgpack_pack, - msgpack_unpack, - put_into_queue, - unmarshal, -) -from datachain.utils import batched_it +from datachain.query.queue import get_from_queue, put_into_queue +from datachain.query.utils import get_query_id_column +from datachain.utils import batched, flatten + +if TYPE_CHECKING: + from sqlalchemy import Select, Table + + from datachain.data_storage import AbstractMetastore, AbstractWarehouse + from datachain.lib.udf import UDFAdapter + from datachain.query.batch import BatchingStrategy DEFAULT_BATCH_SIZE = 10000 STOP_SIGNAL = "STOP" @@ -54,12 +56,9 @@ def udf_entrypoint() -> int: # Load UDF info from stdin udf_info = load(stdin.buffer) - ( - warehouse_class, - warehouse_args, - warehouse_kwargs, - ) = udf_info["warehouse_clone_params"] - warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs) + query: Select = udf_info["query"] + table: Table = udf_info["table"] + batching: BatchingStrategy = udf_info["batching"] # Parallel processing (faster for more CPU-heavy UDFs) dispatch = UDFDispatcher( @@ -67,41 +66,39 @@ def udf_entrypoint() -> int: udf_info["catalog_init"], udf_info["metastore_clone_params"], udf_info["warehouse_clone_params"], + query=query, + table=table, udf_fields=udf_info["udf_fields"], cache=udf_info["cache"], is_generator=udf_info.get("is_generator", False), + is_batching=batching.is_batching, ) - query = udf_info["query"] - batching = udf_info["batching"] - table = udf_info["table"] n_workers = udf_info["processes"] - udf = loads(udf_info["udf_data"]) if n_workers is True: - # Use default number of CPUs (cores) - n_workers = None + n_workers = None # Use default number of CPUs (cores) + + wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"] + warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs) with contextlib.closing( - batching(warehouse.dataset_select_paginated, query) + batching(warehouse.db.execute, query, ids_only=True) ) as udf_inputs: download_cb = get_download_callback() processed_cb = get_processed_callback() generated_cb = get_generated_callback(dispatch.is_generator) try: - udf_results = dispatch.run_udf_parallel( - marshal(udf_inputs), + dispatch.run_udf_parallel( + udf_inputs, n_workers=n_workers, processed_cb=processed_cb, download_cb=download_cb, ) - process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb) finally: download_cb.close() processed_cb.close() generated_cb.close() - warehouse.insert_rows_done(table) - return 0 @@ -120,26 +117,24 @@ def __init__( catalog_init_params, metastore_clone_params, warehouse_clone_params, + query: "Select", + table: "Table", udf_fields: "Sequence[str]", cache: bool, is_generator: bool = False, + is_batching: bool = False, buffer_size: int = DEFAULT_BATCH_SIZE, ): self.udf_data = udf_data self.catalog_init_params = catalog_init_params - ( - self.metastore_class, - self.metastore_args, - self.metastore_kwargs, - ) = metastore_clone_params - ( - self.warehouse_class, - self.warehouse_args, - self.warehouse_kwargs, - ) = warehouse_clone_params + self.metastore_clone_params = metastore_clone_params + self.warehouse_clone_params = warehouse_clone_params + self.query = query + self.table = table self.udf_fields = udf_fields self.cache = cache self.is_generator = is_generator + self.is_batching = is_batching self.buffer_size = buffer_size self.catalog = None self.task_queue = None @@ -148,12 +143,10 @@ def __init__( def _create_worker(self) -> "UDFWorker": if not self.catalog: - metastore = self.metastore_class( - *self.metastore_args, **self.metastore_kwargs - ) - warehouse = self.warehouse_class( - *self.warehouse_args, **self.warehouse_kwargs - ) + ms_cls, ms_args, ms_kwargs = self.metastore_clone_params + metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs) + ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params + warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs) self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params) self.udf = loads(self.udf_data) return UDFWorker( @@ -161,7 +154,9 @@ def _create_worker(self) -> "UDFWorker": self.udf, self.task_queue, self.done_queue, - self.is_generator, + self.query, + self.table, + self.is_batching, self.cache, self.udf_fields, ) @@ -194,7 +189,7 @@ def run_udf_parallel( # noqa: C901, PLR0912 input_queue=None, processed_cb: Callback = DEFAULT_CALLBACK, download_cb: Callback = DEFAULT_CALLBACK, - ) -> Iterator[Sequence[UDFResult]]: + ) -> None: n_workers = get_n_workers_from_arg(n_workers) if self.buffer_size < n_workers: @@ -224,6 +219,9 @@ def run_udf_parallel( # noqa: C901, PLR0912 input_finished = False if not streaming_mode: + if not self.is_batching: + input_rows = batched(flatten(input_rows), DEFAULT_BATCH_SIZE) + # Stop all workers after the input rows have finished processing input_data = chain(input_rows, [STOP_SIGNAL] * n_workers) @@ -238,19 +236,17 @@ def run_udf_parallel( # noqa: C901, PLR0912 # Process all tasks while n_workers > 0: result = get_from_queue(self.done_queue) + + if downloaded := result.get("downloaded"): + download_cb.relative_update(downloaded) + if processed := result.get("processed"): + processed_cb.relative_update(processed) + status = result["status"] - if status == NOTIFY_STATUS: - if downloaded := result.get("downloaded"): - download_cb.relative_update(downloaded) - if processed := result.get("processed"): - processed_cb.relative_update(processed) + if status in (OK_STATUS, NOTIFY_STATUS): + pass # Do nothing here elif status == FINISHED_STATUS: - # Worker finished - n_workers -= 1 - elif status == OK_STATUS: - if processed := result.get("processed"): - processed_cb.relative_update(processed) - yield msgpack_unpack(result["result"]) + n_workers -= 1 # Worker finished else: # Failed / error n_workers -= 1 if exc := result.get("exception"): @@ -311,11 +307,13 @@ def relative_update(self, inc: int = 1) -> None: @attrs.define class UDFWorker: - catalog: Catalog - udf: UDFAdapter + catalog: "Catalog" + udf: "UDFAdapter" task_queue: "multiprocess.Queue" done_queue: "multiprocess.Queue" - is_generator: bool + query: "Select" + table: "Table" + is_batching: bool cache: bool udf_fields: Sequence[str] cb: Callback = attrs.field() @@ -325,31 +323,57 @@ def _default_callback(self) -> WorkerCallback: return WorkerCallback(self.done_queue) def run(self) -> None: + warehouse = self.catalog.warehouse.clone() processed_cb = ProcessedCallback() + udf_results = self.udf.run( self.udf_fields, - unmarshal(self.get_inputs()), + self.get_inputs(), self.catalog, - self.is_generator, self.cache, download_cb=self.cb, processed_cb=processed_cb, ) - for udf_output in udf_results: - for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE): - put_into_queue( - self.done_queue, - { - "status": OK_STATUS, - "result": msgpack_pack(list(batch)), - }, - ) + process_udf_outputs( + warehouse, + self.table, + self.notify_and_process(udf_results, processed_cb), + self.udf, + cb=processed_cb, + ) + warehouse.insert_rows_done(self.table) + + put_into_queue( + self.done_queue, + {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows}, + ) + + def notify_and_process(self, udf_results, processed_cb): + for row in udf_results: put_into_queue( self.done_queue, - {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows}, + {"status": OK_STATUS, "processed": processed_cb.processed_rows}, ) - put_into_queue(self.done_queue, {"status": FINISHED_STATUS}) + yield row def get_inputs(self): - while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: - yield batch + warehouse = self.catalog.warehouse.clone() + col_id = get_query_id_column(self.query) + + if self.is_batching: + while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: + ids = [row[0] for row in batch.rows] + rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids))) + yield RowsOutputBatch(list(rows)) + else: + while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: + rows = warehouse.dataset_rows_select( + self.query.where(col_id.in_(batch)) + ) + yield from rows + + +class RepeatTimer(Timer): + def run(self): + while not self.finished.wait(self.interval): + self.function(*self.args, **self.kwargs) diff --git a/src/datachain/query/utils.py b/src/datachain/query/utils.py new file mode 100644 index 000000000..0d92226b1 --- /dev/null +++ b/src/datachain/query/utils.py @@ -0,0 +1,42 @@ +from typing import TYPE_CHECKING, Optional, Union + +from sqlalchemy import Column + +if TYPE_CHECKING: + from sqlalchemy import ColumnElement, Select, TextClause + + +ColT = Union[Column, "ColumnElement", "TextClause"] + + +def column_name(col: ColT) -> str: + """Returns column name from column element.""" + return col.name if isinstance(col, Column) else str(col) + + +def get_query_column(query: "Select", name: str) -> Optional[ColT]: + """Returns column element from query by name or None if column not found.""" + return next((col for col in query.inner_columns if column_name(col) == name), None) + + +def get_query_id_column(query: "Select") -> ColT: + """Returns ID column element from query or None if column not found.""" + col = get_query_column(query, "sys__id") + if col is None: + raise RuntimeError("sys__id column not found in query") + return col + + +def select_only_columns(query: "Select", *names: str) -> "Select": + """Returns query selecting defined columns only.""" + if not names: + return query + + cols: list[ColT] = [] + for name in names: + col = get_query_column(query, name) + if col is None: + raise ValueError(f"Column '{name}' not found in query") + cols.append(col) + + return query.with_only_columns(*cols) diff --git a/src/datachain/utils.py b/src/datachain/utils.py index 21fcd6e49..11018df08 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -263,7 +263,7 @@ def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]: def flatten(items): for item in items: - if isinstance(item, list): + if isinstance(item, (list, tuple)): yield from item else: yield item From 699fb96be9929c1df408f9b498b7a48e87fadd21 Mon Sep 17 00:00:00 2001 From: Vladimir Rudnyh Date: Tue, 17 Dec 2024 18:22:47 +0700 Subject: [PATCH 2/3] Finish UDF optimizations: SaaS support --- src/datachain/data_storage/warehouse.py | 1 - src/datachain/query/dataset.py | 16 +-- src/datachain/query/dispatch.py | 130 ++++++++++-------------- src/datachain/query/udf.py | 20 ++++ 4 files changed, 81 insertions(+), 86 deletions(-) create mode 100644 src/datachain/query/udf.py diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index fdb8f3c17..cd0c4376e 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -216,7 +216,6 @@ def dataset_select_paginated( limit = query._limit paginated_query = query.limit(page_size) - results = None offset = 0 num_yielded = 0 diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 921329228..078ae0a86 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -44,6 +44,8 @@ from datachain.error import DatasetNotFoundError, QueryScriptCancelError from datachain.func.base import Function from datachain.progress import CombinedDownloadCallback +from datachain.query.schema import C, UDFParamSpec, normalize_param +from datachain.query.session import Session from datachain.sql.functions.random import rand from datachain.utils import ( batched, @@ -52,9 +54,6 @@ get_datachain_executable, ) -from .schema import C, UDFParamSpec, normalize_param -from .session import Session - if TYPE_CHECKING: from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.schema import Table @@ -65,6 +64,7 @@ from datachain.data_storage import AbstractWarehouse from datachain.dataset import DatasetRecord from datachain.lib.udf import UDFAdapter, UDFResult + from datachain.query.udf import UdfInfo P = ParamSpec("P") @@ -346,6 +346,8 @@ def process_udf_outputs( for row_chunk in batched(rows, batch_size): warehouse.insert_rows(udf_table, row_chunk) + warehouse.insert_rows_done(udf_table) + def get_download_callback() -> Callback: return CombinedDownloadCallback( @@ -439,7 +441,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: raise RuntimeError( "In-memory databases cannot be used with parallel processing." ) - udf_info = { + udf_info: UdfInfo = { "udf_data": filtered_cloudpickle_dumps(self.udf), "catalog_init": self.catalog.get_init_params(), "metastore_clone_params": self.catalog.metastore.clone_params(), @@ -463,8 +465,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process: # noqa: S603 process.communicate(process_data) - if process.poll(): - raise RuntimeError("UDF Execution Failed!") + if ret := process.poll(): + raise RuntimeError(f"UDF Execution Failed! Exit code: {ret}") else: # Otherwise process single-threaded (faster for smaller UDFs) warehouse = self.catalog.warehouse @@ -494,8 +496,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: processed_cb.close() generated_cb.close() - warehouse.insert_rows_done(udf_table) - except QueryScriptCancelError: self.catalog.warehouse.close() sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE) diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py index 2d85fe551..e1d62c9e2 100644 --- a/src/datachain/query/dispatch.py +++ b/src/datachain/query/dispatch.py @@ -1,5 +1,5 @@ import contextlib -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from itertools import chain from multiprocessing import cpu_count from sys import stdin @@ -11,10 +11,11 @@ from cloudpickle import load, loads from fsspec.callbacks import DEFAULT_CALLBACK, Callback from multiprocess import get_context +from sqlalchemy.sql import func from datachain.catalog import Catalog from datachain.catalog.loader import get_distributed_class -from datachain.query.batch import RowsOutputBatch +from datachain.query.batch import RowsOutput, RowsOutputBatch from datachain.query.dataset import ( get_download_callback, get_generated_callback, @@ -22,6 +23,7 @@ process_udf_outputs, ) from datachain.query.queue import get_from_queue, put_into_queue +from datachain.query.udf import UdfInfo from datachain.query.utils import get_query_id_column from datachain.utils import batched, flatten @@ -30,7 +32,6 @@ from datachain.data_storage import AbstractMetastore, AbstractWarehouse from datachain.lib.udf import UDFAdapter - from datachain.query.batch import BatchingStrategy DEFAULT_BATCH_SIZE = 10000 STOP_SIGNAL = "STOP" @@ -40,10 +41,6 @@ NOTIFY_STATUS = "NOTIFY" -def full_module_type_path(typ: type) -> str: - return f"{typ.__module__}.{typ.__qualname__}" - - def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int: if not n_workers: return cpu_count() @@ -54,26 +51,13 @@ def get_n_workers_from_arg(n_workers: Optional[int] = None) -> int: def udf_entrypoint() -> int: # Load UDF info from stdin - udf_info = load(stdin.buffer) - - query: Select = udf_info["query"] - table: Table = udf_info["table"] - batching: BatchingStrategy = udf_info["batching"] + udf_info: UdfInfo = load(stdin.buffer) # Parallel processing (faster for more CPU-heavy UDFs) - dispatch = UDFDispatcher( - udf_info["udf_data"], - udf_info["catalog_init"], - udf_info["metastore_clone_params"], - udf_info["warehouse_clone_params"], - query=query, - table=table, - udf_fields=udf_info["udf_fields"], - cache=udf_info["cache"], - is_generator=udf_info.get("is_generator", False), - is_batching=batching.is_batching, - ) + dispatch = UDFDispatcher(udf_info) + query = udf_info["query"] + batching = udf_info["batching"] n_workers = udf_info["processes"] if n_workers is True: n_workers = None # Use default number of CPUs (cores) @@ -81,15 +65,21 @@ def udf_entrypoint() -> int: wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"] warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs) + total_rows = next( + warehouse.db.execute( + query.with_only_columns(func.count(query.c.sys__id)).order_by(None) + ) + )[0] + with contextlib.closing( batching(warehouse.db.execute, query, ids_only=True) ) as udf_inputs: download_cb = get_download_callback() processed_cb = get_processed_callback() - generated_cb = get_generated_callback(dispatch.is_generator) try: dispatch.run_udf_parallel( udf_inputs, + total_rows=total_rows, n_workers=n_workers, processed_cb=processed_cb, download_cb=download_cb, @@ -97,7 +87,6 @@ def udf_entrypoint() -> int: finally: download_cb.close() processed_cb.close() - generated_cb.close() return 0 @@ -111,30 +100,17 @@ class UDFDispatcher: task_queue: Optional[multiprocess.Queue] = None done_queue: Optional[multiprocess.Queue] = None - def __init__( - self, - udf_data, - catalog_init_params, - metastore_clone_params, - warehouse_clone_params, - query: "Select", - table: "Table", - udf_fields: "Sequence[str]", - cache: bool, - is_generator: bool = False, - is_batching: bool = False, - buffer_size: int = DEFAULT_BATCH_SIZE, - ): - self.udf_data = udf_data - self.catalog_init_params = catalog_init_params - self.metastore_clone_params = metastore_clone_params - self.warehouse_clone_params = warehouse_clone_params - self.query = query - self.table = table - self.udf_fields = udf_fields - self.cache = cache - self.is_generator = is_generator - self.is_batching = is_batching + def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE): + self.udf_data = udf_info["udf_data"] + self.catalog_init_params = udf_info["catalog_init"] + self.metastore_clone_params = udf_info["metastore_clone_params"] + self.warehouse_clone_params = udf_info["warehouse_clone_params"] + self.query = udf_info["query"] + self.table = udf_info["table"] + self.udf_fields = udf_info["udf_fields"] + self.cache = udf_info["cache"] + self.is_generator = udf_info["is_generator"] + self.is_batching = udf_info["batching"].is_batching self.buffer_size = buffer_size self.catalog = None self.task_queue = None @@ -156,6 +132,7 @@ def _create_worker(self) -> "UDFWorker": self.done_queue, self.query, self.table, + self.is_generator, self.is_batching, self.cache, self.udf_fields, @@ -184,26 +161,27 @@ def create_input_queue(self): def run_udf_parallel( # noqa: C901, PLR0912 self, - input_rows, + input_rows: Iterable[RowsOutput], + total_rows: int, n_workers: Optional[int] = None, - input_queue=None, processed_cb: Callback = DEFAULT_CALLBACK, download_cb: Callback = DEFAULT_CALLBACK, ) -> None: n_workers = get_n_workers_from_arg(n_workers) + input_batch_size = total_rows // n_workers + if input_batch_size == 0: + input_batch_size = 1 + elif input_batch_size > DEFAULT_BATCH_SIZE: + input_batch_size = DEFAULT_BATCH_SIZE + if self.buffer_size < n_workers: raise RuntimeError( "Parallel run error: buffer size is smaller than " f"number of workers: {self.buffer_size} < {n_workers}" ) - if input_queue: - streaming_mode = True - self.task_queue = input_queue - else: - streaming_mode = False - self.task_queue = self.ctx.Queue() + self.task_queue = self.ctx.Queue() self.done_queue = self.ctx.Queue() pool = [ self.ctx.Process(name=f"Worker-UDF-{i}", target=self._run_worker) @@ -218,20 +196,19 @@ def run_udf_parallel( # noqa: C901, PLR0912 # Will be set to True when the input is exhausted input_finished = False - if not streaming_mode: - if not self.is_batching: - input_rows = batched(flatten(input_rows), DEFAULT_BATCH_SIZE) + if not self.is_batching: + input_rows = batched(flatten(input_rows), input_batch_size) - # Stop all workers after the input rows have finished processing - input_data = chain(input_rows, [STOP_SIGNAL] * n_workers) + # Stop all workers after the input rows have finished processing + input_data = chain(input_rows, [STOP_SIGNAL] * n_workers) - # Add initial buffer of tasks - for _ in range(self.buffer_size): - try: - put_into_queue(self.task_queue, next(input_data)) - except StopIteration: - input_finished = True - break + # Add initial buffer of tasks + for _ in range(self.buffer_size): + try: + put_into_queue(self.task_queue, next(input_data)) + except StopIteration: + input_finished = True + break # Process all tasks while n_workers > 0: @@ -253,7 +230,7 @@ def run_udf_parallel( # noqa: C901, PLR0912 raise exc raise RuntimeError("Internal error: Parallel UDF execution failed") - if status == OK_STATUS and not streaming_mode and not input_finished: + if status == OK_STATUS and not input_finished: try: put_into_queue(self.task_queue, next(input_data)) except StopIteration: @@ -313,6 +290,7 @@ class UDFWorker: done_queue: "multiprocess.Queue" query: "Select" table: "Table" + is_generator: bool is_batching: bool cache: bool udf_fields: Sequence[str] @@ -323,8 +301,8 @@ def _default_callback(self) -> WorkerCallback: return WorkerCallback(self.done_queue) def run(self) -> None: - warehouse = self.catalog.warehouse.clone() processed_cb = ProcessedCallback() + generated_cb = get_generated_callback(self.is_generator) udf_results = self.udf.run( self.udf_fields, @@ -335,13 +313,12 @@ def run(self) -> None: processed_cb=processed_cb, ) process_udf_outputs( - warehouse, + self.catalog.warehouse, self.table, self.notify_and_process(udf_results, processed_cb), self.udf, - cb=processed_cb, + cb=generated_cb, ) - warehouse.insert_rows_done(self.table) put_into_queue( self.done_queue, @@ -367,10 +344,9 @@ def get_inputs(self): yield RowsOutputBatch(list(rows)) else: while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL: - rows = warehouse.dataset_rows_select( + yield from warehouse.dataset_rows_select( self.query.where(col_id.in_(batch)) ) - yield from rows class RepeatTimer(Timer): diff --git a/src/datachain/query/udf.py b/src/datachain/query/udf.py new file mode 100644 index 000000000..a6046deae --- /dev/null +++ b/src/datachain/query/udf.py @@ -0,0 +1,20 @@ +from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict + +if TYPE_CHECKING: + from sqlalchemy import Select, Table + + from datachain.query.batch import BatchingStrategy + + +class UdfInfo(TypedDict): + udf_data: bytes + catalog_init: dict[str, Any] + metastore_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]] + warehouse_clone_params: tuple[Callable[..., Any], list[Any], dict[str, Any]] + table: "Table" + query: "Select" + udf_fields: list[str] + batching: "BatchingStrategy" + processes: Optional[int] + is_generator: bool + cache: bool From dc2672015112d4f64b6b14aa4e49a5156c0d4fbd Mon Sep 17 00:00:00 2001 From: Vladimir Rudnyh Date: Mon, 23 Dec 2024 21:40:56 +0700 Subject: [PATCH 3/3] Code review update --- src/datachain/query/batch.py | 1 - src/datachain/query/dataset.py | 4 ++-- src/datachain/query/dispatch.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/datachain/query/batch.py b/src/datachain/query/batch.py index 6be29b1fe..65b2f5742 100644 --- a/src/datachain/query/batch.py +++ b/src/datachain/query/batch.py @@ -118,7 +118,6 @@ def __call__( batch: list[Sequence] = [] query_fields = [str(c.name) for c in query.selected_columns] - # query_fields = [column_name(col) for col in query.inner_columns] id_column_idx = query_fields.index("sys__id") partition_column_idx = query_fields.index(PARTITION_COLUMN_ID) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 9ecab714e..727a14219 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -465,8 +465,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process: # noqa: S603 process.communicate(process_data) - if ret := process.poll(): - raise RuntimeError(f"UDF Execution Failed! Exit code: {ret}") + if retval := process.poll(): + raise RuntimeError(f"UDF Execution Failed! Exit code: {retval}") else: # Otherwise process single-threaded (faster for smaller UDFs) warehouse = self.catalog.warehouse diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py index e1d62c9e2..722f68c10 100644 --- a/src/datachain/query/dispatch.py +++ b/src/datachain/query/dispatch.py @@ -72,7 +72,7 @@ def udf_entrypoint() -> int: )[0] with contextlib.closing( - batching(warehouse.db.execute, query, ids_only=True) + batching(warehouse.dataset_select_paginated, query, ids_only=True) ) as udf_inputs: download_cb = get_download_callback() processed_cb = get_processed_callback()