From 79c91891fe566312f5719c31d980eee4aa55a99b Mon Sep 17 00:00:00 2001
From: Ronan Lamy
Date: Fri, 18 Oct 2024 20:05:20 +0100
Subject: [PATCH 1/2] Use threading in AsyncMapper.produce()
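
AsyncMapper.produce() previously awaited work_queue.put() directly on the
event loop, so a blocking input iterable (e.g. one backed by queue.Queue)
could stall the loop and deadlock iterate(). The producer loop now runs in a
worker thread via to_thread() and hands each item to the loop with
asyncio.run_coroutine_threadsafe(), blocking only the producer thread. The
new test_mapper_deadlock test exercises this with a queue-backed iterable.

A minimal standalone sketch of the pattern (not part of the patch; names and
the Python 3.10+ Queue behaviour are assumptions):

```python
import asyncio
import threading

# Dedicated event loop running in its own thread, as AsyncMapper does.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

# On Python 3.10+, asyncio.Queue does not bind to a loop at construction time.
work_queue: asyncio.Queue = asyncio.Queue()

def produce(iterable):
    # Runs in a plain thread; may block freely without stalling the event loop.
    for item in iterable:
        fut = asyncio.run_coroutine_threadsafe(work_queue.put(item), loop)
        fut.result()  # wait until the item is actually in the queue

produce(range(3))
assert work_queue.qsize() == 3
```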
---
 src/datachain/asyn.py   | 19 +++++++++++++++----
 tests/unit/test_asyn.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/src/datachain/asyn.py b/src/datachain/asyn.py
index 4e42f4a8c..7c94190fa 100644
--- a/src/datachain/asyn.py
+++ b/src/datachain/asyn.py
@@ -1,5 +1,12 @@
 import asyncio
-from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
+from collections.abc import (
+    AsyncIterable,
+    Awaitable,
+    Coroutine,
+    Generator,
+    Iterable,
+    Iterator,
+)
 from concurrent.futures import ThreadPoolExecutor
 from heapq import heappop, heappush
 from typing import Any, Callable, Generic, Optional, TypeVar
@@ -54,9 +61,13 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
         task.add_done_callback(self._tasks.discard)
         return task
 
-    async def produce(self) -> None:
+    def _produce(self) -> None:
         for item in self.iterable:
-            await self.work_queue.put(item)
+            fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
+            fut.result()  # wait until the item is in the queue
+
+    async def produce(self) -> None:
+        await self.to_thread(self._produce)
 
     async def worker(self) -> None:
         while (item := await self.work_queue.get()) is not None:
@@ -132,7 +143,7 @@ async def _break_iteration(self) -> None:
                 self.result_queue.get_nowait()
         await self.result_queue.put(None)
 
-    def iterate(self, timeout=None) -> Iterable[ResultT]:
+    def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
         init = asyncio.run_coroutine_threadsafe(self.init(), self.loop)
         init.result(timeout=1)
         async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop)
diff --git a/tests/unit/test_asyn.py b/tests/unit/test_asyn.py
index a97bb732a..e102dd3a2 100644
--- a/tests/unit/test_asyn.py
+++ b/tests/unit/test_asyn.py
@@ -2,6 +2,7 @@
 import functools
 from collections import Counter
 from contextlib import contextmanager
+from queue import Queue
 
 import pytest
 from fsspec.asyn import sync
@@ -111,6 +112,37 @@ async def process(row):
         list(mapper.iterate(timeout=4))
 
 
+@pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
+def test_mapper_deadlock(create_mapper):
+    queue = Queue()
+    inputs = range(50)
+
+    def as_iter(queue):
+        while (item := queue.get()) is not None:
+            yield item
+
+    async def process(x):
+        return x
+
+    mapper = create_mapper(process, as_iter(queue), workers=10, loop=get_loop())
+    it = mapper.iterate(timeout=4)
+    for i in inputs:
+        queue.put(i)
+
+    # Check that we can get as many objects out as we put in, without deadlock
+    result = []
+    for _ in range(len(inputs)):
+        result.append(next(it))
+    if mapper.order_preserving:
+        assert result == list(inputs)
+    else:
+        assert set(result) == set(inputs)
+
+    # Check that iteration terminates cleanly
+    queue.put(None)
+    assert list(it) == []
+
+
 @pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
 @settings(deadline=None)
 @given(

From acdc96916a0ebe32b7d9cf90987094fda2623bab Mon Sep 17 00:00:00 2001
From: Ronan Lamy
Date: Wed, 9 Oct 2024 15:39:21 +0100
Subject: [PATCH 2/2] Implement prefetching in .gen() and .map()
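
Adds a `prefetch` option to DataChain.settings(), defaulting to 2 for Mapper
and Generator UDFs. When the cache is enabled, upcoming File inputs are
downloaded in the background through AsyncMapper (up to `prefetch` at a time)
while the current row is being processed, so .map() and .gen() no longer wait
on I/O between rows; prefetch=0 disables this. dataset_select_paginated()
now clones the warehouse to get a thread-local connection and materialises
each page, since pages are consumed from a worker thread.

An illustrative usage sketch (the bucket URI and signal function are made up
for this example and are not part of the patch):

```python
from datachain.lib.dc import DataChain
from datachain.lib.file import File

def signal(file: File) -> str:
    # With cache=True and prefetch > 0, the file is already in the local
    # cache by the time the UDF body runs.
    with file.open() as f:
        return f"{file.name}: {len(f.read())} bytes"

chain = (
    DataChain.from_storage("s3://example-bucket/")  # hypothetical bucket URI
    .settings(cache=True, prefetch=4)
    .map(signal=signal)
    .save()  # materialize the chain, as the updated test_map_file does
)
```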
---
 src/datachain/data_storage/warehouse.py |  5 +-
 src/datachain/lib/dc.py                 |  7 ++-
 src/datachain/lib/file.py               |  5 ++
 src/datachain/lib/settings.py           | 12 ++++-
 src/datachain/lib/udf.py                | 63 ++++++++++++++++++-------
 src/datachain/query/dataset.py          | 52 ++++++++++----------
 tests/func/test_datachain.py            |  7 ++-
 7 files changed, 101 insertions(+), 50 deletions(-)

diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py
index 74b25045d..ecc50eba2 100644
--- a/src/datachain/data_storage/warehouse.py
+++ b/src/datachain/data_storage/warehouse.py
@@ -227,7 +227,10 @@ def dataset_select_paginated(
                 if limit < page_size:
                     paginated_query = paginated_query.limit(None).limit(limit)
 
-            results = self.dataset_rows_select(paginated_query.offset(offset))
+            # Ensure we're using a thread-local connection
+            with self.clone() as wh:
+                # Cursor results are not thread-safe, so we convert them to a list
+                results = list(wh.dataset_rows_select(paginated_query.offset(offset)))
 
             processed = False
             for row in results:
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 138e2a131..9aaa227bf 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -325,6 +325,7 @@ def settings(
         parallel=None,
         workers=None,
         min_task_size=None,
+        prefetch: Optional[int] = None,
         sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.
@@ -351,7 +352,7 @@ def settings(
         if sys is None:
             sys = self._sys
         settings = copy.copy(self._settings)
-        settings.add(Settings(cache, parallel, workers, min_task_size))
+        settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
         return self._evolve(settings=settings, _sys=sys)
 
     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
@@ -801,6 +802,8 @@ def map(
         ```
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
+        if (prefetch := self._settings.prefetch) is not None:
+            udf_obj.prefetch = prefetch
 
         return self._evolve(
             query=self._query.add_signals(
@@ -838,6 +841,8 @@ def gen(
         ```
         """
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
+        if (prefetch := self._settings.prefetch) is not None:
+            udf_obj.prefetch = prefetch
         return self._evolve(
             query=self._query.generate(
                 udf_obj.to_udf_wrapper(),
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index 41cd6369f..4373bd889 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -271,6 +271,11 @@ def ensure_cached(self) -> None:
         client = self._catalog.get_client(self.source)
         client.download(self, callback=self._download_cb)
 
+    async def _prefetch(self) -> None:
+        if self._caching_enabled:
+            client = self._catalog.get_client(self.source)
+            await client._download(self, callback=self._download_cb)
+
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
 
diff --git a/src/datachain/lib/settings.py b/src/datachain/lib/settings.py
index 1f3722a44..fe294d950 100644
--- a/src/datachain/lib/settings.py
+++ b/src/datachain/lib/settings.py
@@ -7,11 +7,19 @@ def __init__(self, msg):
 
 
 class Settings:
-    def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
+    def __init__(
+        self,
+        cache=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        prefetch=None,
+    ):
         self._cache = cache
         self.parallel = parallel
         self._workers = workers
         self.min_task_size = min_task_size
+        self.prefetch = prefetch
 
         if not isinstance(cache, bool) and cache is not None:
             raise SettingsError(
@@ -66,3 +74,5 @@ def add(self, settings: "Settings"):
         self.parallel = settings.parallel or self.parallel
         self._workers = settings._workers or self._workers
         self.min_task_size = settings.min_task_size or self.min_task_size
+        if settings.prefetch is not None:
+            self.prefetch = settings.prefetch
diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index 2ce0257d5..ff9b25d64 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -1,3 +1,4 @@
+import contextlib
 import sys
 import traceback
 from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -7,6 +8,7 @@
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel
 
+from datachain.asyn import AsyncMapper
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -22,6 +24,8 @@
 )
 
 if TYPE_CHECKING:
+    from collections import abc
+
     from typing_extensions import Self
 
     from datachain.catalog import Catalog
@@ -276,9 +280,18 @@ def process_safe(self, obj_rows):
         return result_objs
 
 
+async def _prefetch_input(row):
+    for obj in row:
+        if isinstance(obj, File):
+            await obj._prefetch()
+    return row
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""
 
+    prefetch: int = 2
+
     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -290,16 +303,22 @@ def run(
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()
-
-        for row in udf_inputs:
-            id_, *udf_args = self._prepare_row_and_id(
-                row, udf_fields, cache, download_cb
-            )
-            result_objs = self.process_safe(udf_args)
-            udf_output = self._flatten_row(result_objs)
-            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
-            processed_cb.relative_update(1)
-            yield output
+        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+            for row in udf_inputs
+        )
+        if self.prefetch > 0:
+            prepared_inputs = AsyncMapper(
+                _prefetch_input, prepared_inputs, workers=self.prefetch
+            ).iterate()
+
+        with contextlib.closing(prepared_inputs):
+            for id_, *udf_args in prepared_inputs:
+                result_objs = self.process_safe(udf_args)
+                udf_output = self._flatten_row(result_objs)
+                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                processed_cb.relative_update(1)
+                yield output
 
         self.teardown()
 
@@ -349,6 +368,7 @@ class Generator(UDFBase):
     """Inherit from this class to pass to `DataChain.gen()`."""
 
     is_output_batched = True
+    prefetch: int = 2
 
     def run(
         self,
@@ -361,14 +381,21 @@ def run(
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()
-
-        for row in udf_inputs:
-            udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
-            result_objs = self.process_safe(udf_args)
-            udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(1)
-            yield output
+        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
+        )
+        if self.prefetch > 0:
+            prepared_inputs = AsyncMapper(
+                _prefetch_input, prepared_inputs, workers=self.prefetch
+            ).iterate()
+
+        with contextlib.closing(prepared_inputs):
+            for row in prepared_inputs:
+                result_objs = self.process_safe(row)
+                udf_outputs = (self._flatten_row(row) for row in result_objs)
+                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+                processed_cb.relative_update(1)
+                yield output
 
         self.teardown()
 
udf_outputs) - processed_cb.relative_update(1) - yield output + prepared_inputs: abc.Generator[Sequence[Any], None, None] = ( + self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs + ) + if self.prefetch > 0: + prepared_inputs = AsyncMapper( + _prefetch_input, prepared_inputs, workers=self.prefetch + ).iterate() + + with contextlib.closing(prepared_inputs): + for row in prepared_inputs: + result_objs = self.process_safe(row) + udf_outputs = (self._flatten_row(row) for row in result_objs) + output = (dict(zip(self.signal_names, row)) for row in udf_outputs) + processed_cb.relative_update(1) + yield output self.teardown() diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index a94d773fe..80f76fdd8 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -472,33 +472,31 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: # Otherwise process single-threaded (faster for smaller UDFs) warehouse = self.catalog.warehouse - with contextlib.closing( - batching(warehouse.dataset_select_paginated, query) - ) as udf_inputs: - download_cb = get_download_callback() - processed_cb = get_processed_callback() - generated_cb = get_generated_callback(self.is_generator) - try: - udf_results = self.udf.run( - udf_fields, - udf_inputs, - self.catalog, - self.is_generator, - self.cache, - download_cb, - processed_cb, - ) - process_udf_outputs( - warehouse, - udf_table, - udf_results, - self.udf, - cb=generated_cb, - ) - finally: - download_cb.close() - processed_cb.close() - generated_cb.close() + udf_inputs = batching(warehouse.dataset_select_paginated, query) + download_cb = get_download_callback() + processed_cb = get_processed_callback() + generated_cb = get_generated_callback(self.is_generator) + try: + udf_results = self.udf.run( + udf_fields, + udf_inputs, + self.catalog, + self.is_generator, + self.cache, + download_cb, + processed_cb, + ) + process_udf_outputs( + warehouse, + udf_table, + udf_results, + self.udf, + cb=generated_cb, + ) + finally: + download_cb.close() + processed_cb.close() + generated_cb.close() warehouse.insert_rows_done(udf_table) diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index b256ba296..4779f1580 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -208,17 +208,20 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type): @pytest.mark.parametrize("use_cache", [True, False]) -def test_map_file(cloud_test_catalog, use_cache): +@pytest.mark.parametrize("prefetch", [0, 2]) +def test_map_file(cloud_test_catalog, use_cache, prefetch): ctc = cloud_test_catalog def new_signal(file: File) -> str: + assert bool(file.get_local_path()) is (use_cache and prefetch > 0) with file.open() as f: return file.name + " -> " + f.read().decode("utf-8") dc = ( DataChain.from_storage(ctc.src_uri, session=ctc.session) - .settings(cache=use_cache) + .settings(cache=use_cache, prefetch=prefetch) .map(signal=new_signal) + .save() ) expected = {