Implement pre-fetching in map() and gen() #521

Open: wants to merge 2 commits into base: main
19 changes: 15 additions & 4 deletions src/datachain/asyn.py
@@ -1,5 +1,12 @@
 import asyncio
-from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
+from collections.abc import (
+    AsyncIterable,
+    Awaitable,
+    Coroutine,
+    Generator,
+    Iterable,
+    Iterator,
+)
 from concurrent.futures import ThreadPoolExecutor
 from heapq import heappop, heappush
 from typing import Any, Callable, Generic, Optional, TypeVar
@@ -54,9 +61,13 @@ def start_task(self, coro: Coroutine) -> asyncio.Task:
         task.add_done_callback(self._tasks.discard)
         return task

-    async def produce(self) -> None:
+    def _produce(self) -> None:
         for item in self.iterable:
-            await self.work_queue.put(item)
+            fut = asyncio.run_coroutine_threadsafe(self.work_queue.put(item), self.loop)
+            fut.result()  # wait until the item is in the queue
+
+    async def produce(self) -> None:
+        await self.to_thread(self._produce)

     async def worker(self) -> None:
         while (item := await self.work_queue.get()) is not None:
@@ -132,7 +143,7 @@ async def _break_iteration(self) -> None:
             self.result_queue.get_nowait()
         await self.result_queue.put(None)

-    def iterate(self, timeout=None) -> Iterable[ResultT]:
+    def iterate(self, timeout=None) -> Generator[ResultT, None, None]:
         init = asyncio.run_coroutine_threadsafe(self.init(), self.loop)
         init.result(timeout=1)
         async_run = asyncio.run_coroutine_threadsafe(self.run(), self.loop)
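Why the producer moved to a thread: self.iterable may be a plain, blocking iterator (see test_mapper_deadlock below), so pulling from it inside a coroutine can block the event loop while consumers wait on the queue. A minimal, self-contained sketch of the same pattern using only the standard library (all names here are illustrative, not datachain's):

    import asyncio

    async def main() -> None:
        loop = asyncio.get_running_loop()
        work_queue: asyncio.Queue = asyncio.Queue(maxsize=2)

        def produce() -> None:
            # Runs in a worker thread: each put() is scheduled on the loop, and
            # fut.result() blocks *this thread*, never the event loop, until
            # the item is actually enqueued.
            for item in range(5):
                fut = asyncio.run_coroutine_threadsafe(work_queue.put(item), loop)
                fut.result()
            asyncio.run_coroutine_threadsafe(work_queue.put(None), loop).result()

        producer = asyncio.create_task(asyncio.to_thread(produce))
        while (item := await work_queue.get()) is not None:
            print("consumed", item)
        await producer

    asyncio.run(main())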
5 changes: 4 additions & 1 deletion src/datachain/data_storage/warehouse.py
@@ -227,7 +227,10 @@ def dataset_select_paginated(
         if limit < page_size:
             paginated_query = paginated_query.limit(None).limit(limit)

-        results = self.dataset_rows_select(paginated_query.offset(offset))
+        # Ensure we're using a thread-local connection
+        with self.clone() as wh:
+            # Cursor results are not thread-safe, so we convert them to a list
+            results = list(wh.dataset_rows_select(paginated_query.offset(offset)))

         processed = False
         for row in results:

Review comment from @shcheklein (Member), Oct 18, 2024:
    Q: Why do the cursor results have to be thread-safe, given that we now run the producer in a separate thread in the async mapper?
    Q: Are there any implications in terms of memory usage?
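Some background on the list() conversion (an illustration, not datachain code): DB-API cursors are generally tied to the connection and thread that created them, so handing a live cursor to another thread is unsafe; materializing the page eagerly sidesteps that. On the reviewer's memory question, the cost per page should be bounded by page_size, since the query is paginated. A sqlite3 sketch of the same idea:

    import sqlite3

    conn = sqlite3.connect(":memory:", check_same_thread=False)
    conn.execute("CREATE TABLE t (x INTEGER)")
    conn.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(3)])

    cursor = conn.execute("SELECT x FROM t")
    page = list(cursor)  # materialize one page: safe to pass across threads,
                         # at the cost of holding the whole page in memory
    print(page)          # [(0,), (1,), (2,)]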
7 changes: 6 additions & 1 deletion src/datachain/lib/dc.py
@@ -325,6 +325,7 @@
         parallel=None,
         workers=None,
         min_task_size=None,
+        prefetch: Optional[int] = None,
         sys: Optional[bool] = None,
     ) -> "Self":
         """Change settings for chain.

Review comment from @shcheklein (Member), Oct 19, 2024:
    Q: Why int? Let's update the docs here. (Btw, do we have some CI to detect these discrepancies, i.e. missing docs? cc @skshetry)
@@ -351,7 +352,7 @@
         if sys is None:
             sys = self._sys
         settings = copy.copy(self._settings)
-        settings.add(Settings(cache, parallel, workers, min_task_size))
+        settings.add(Settings(cache, parallel, workers, min_task_size, prefetch))
        return self._evolve(settings=settings, _sys=sys)

     def reset_settings(self, settings: Optional[Settings] = None) -> "Self":
@@ -801,6 +802,8 @@
         ```
         """
         udf_obj = self._udf_to_obj(Mapper, func, params, output, signal_map)
+        if (prefetch := self._settings.prefetch) is not None:
+            udf_obj.prefetch = prefetch

         return self._evolve(
             query=self._query.add_signals(
@@ -838,6 +841,8 @@
         ```
         """
         udf_obj = self._udf_to_obj(Generator, func, params, output, signal_map)
+        if (prefetch := self._settings.prefetch) is not None:
+            udf_obj.prefetch = prefetch

Codecov / codecov/patch annotation: added line src/datachain/lib/dc.py#L845 was not covered by tests.

         return self._evolve(
             query=self._query.generate(
                 udf_obj.to_udf_wrapper(),
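Taken together, these dc.py changes make prefetching configurable per chain. A hypothetical usage sketch (the bucket URI and signal name are made up; import paths are assumed from this diff's file layout):

    from datachain.lib.dc import DataChain
    from datachain.lib.file import File

    def to_text(file: File) -> str:
        with file.open() as f:
            return f.read().decode("utf-8")

    chain = (
        DataChain.from_storage("s3://my-bucket/docs/")  # hypothetical URI
        .settings(cache=True, prefetch=4)  # keep up to 4 downloads ahead of the UDF
        .map(text=to_text)
        .save()
    )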
5 changes: 5 additions & 0 deletions src/datachain/lib/file.py
@@ -271,6 +271,11 @@ def ensure_cached(self) -> None:
             client = self._catalog.get_client(self.source)
             client.download(self, callback=self._download_cb)

+    async def _prefetch(self) -> None:
+        if self._caching_enabled:
+            client = self._catalog.get_client(self.source)
+            await client._download(self, callback=self._download_cb)
+
     def get_local_path(self) -> Optional[str]:
         """Return path to a file in a local cache.
12 changes: 11 additions & 1 deletion src/datachain/lib/settings.py
@@ -7,11 +7,19 @@ def __init__(self, msg):


 class Settings:
-    def __init__(self, cache=None, parallel=None, workers=None, min_task_size=None):
+    def __init__(
+        self,
+        cache=None,
+        parallel=None,
+        workers=None,
+        min_task_size=None,
+        prefetch=None,
+    ):
         self._cache = cache
         self.parallel = parallel
         self._workers = workers
         self.min_task_size = min_task_size
+        self.prefetch = prefetch

         if not isinstance(cache, bool) and cache is not None:
             raise SettingsError(
@@ -66,3 +74,5 @@ def add(self, settings: "Settings"):
         self.parallel = settings.parallel or self.parallel
         self._workers = settings._workers or self._workers
         self.min_task_size = settings.min_task_size or self.min_task_size
+        if settings.prefetch is not None:
+            self.prefetch = settings.prefetch

Review comment from @shcheklein (Member):
    Is there a reason to have a mix of styles here: some protected vars, some not; some `self._cache = settings._cache or self._cache`, and some `if settings.prefetch is not None:`?
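On the style question: the `is not None` check does buy something over the `or` style used for the other fields; an explicit `prefetch=0` (disable prefetching) survives the merge, whereas `or` would discard it as falsy. A quick illustration (constructor arguments per this diff):

    from datachain.lib.settings import Settings

    base = Settings(parallel=4, prefetch=8)
    override = Settings(prefetch=0)  # explicitly disable prefetching
    base.add(override)
    assert base.prefetch == 0  # kept, thanks to the `is not None` check
    assert base.parallel == 4  # None falls through with `or`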
63 changes: 45 additions & 18 deletions src/datachain/lib/udf.py
@@ -1,3 +1,4 @@
+import contextlib
 import sys
 import traceback
 from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -7,6 +8,7 @@
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
 from pydantic import BaseModel

+from datachain.asyn import AsyncMapper
 from datachain.dataset import RowDict
 from datachain.lib.convert.flatten import flatten
 from datachain.lib.data_model import DataValue
@@ -22,6 +24,8 @@
 )

 if TYPE_CHECKING:
+    from collections import abc
+
     from typing_extensions import Self

     from datachain.catalog import Catalog
@@ -276,9 +280,18 @@ def process_safe(self, obj_rows):
         return result_objs


+async def _prefetch_input(row):
+    for obj in row:
+        if isinstance(obj, File):
+            await obj._prefetch()
+    return row
+
+
 class Mapper(UDFBase):
     """Inherit from this class to pass to `DataChain.map()`."""

+    prefetch: int = 2

     def run(
         self,
         udf_fields: "Sequence[str]",
@@ -290,16 +303,22 @@
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()

-        for row in udf_inputs:
-            id_, *udf_args = self._prepare_row_and_id(
-                row, udf_fields, cache, download_cb
-            )
-            result_objs = self.process_safe(udf_args)
-            udf_output = self._flatten_row(result_objs)
-            output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
-            processed_cb.relative_update(1)
-            yield output
+        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+            self._prepare_row_and_id(row, udf_fields, cache, download_cb)
+            for row in udf_inputs
+        )
+        if self.prefetch > 0:
+            prepared_inputs = AsyncMapper(
+                _prefetch_input, prepared_inputs, workers=self.prefetch
+            ).iterate()
+
+        with contextlib.closing(prepared_inputs):
+            for id_, *udf_args in prepared_inputs:
+                result_objs = self.process_safe(udf_args)
+                udf_output = self._flatten_row(result_objs)
+                output = [{"sys__id": id_} | dict(zip(self.signal_names, udf_output))]
+                processed_cb.relative_update(1)
+                yield output

         self.teardown()

@@ -349,6 +368,7 @@
 class Generator(UDFBase):
     """Inherit from this class to pass to `DataChain.gen()`."""

     is_output_batched = True
+    prefetch: int = 2

     def run(
         self,
@@ -361,14 +381,21 @@
     ) -> Iterator[Iterable[UDFResult]]:
         self.catalog = catalog
         self.setup()

-        for row in udf_inputs:
-            udf_args = self._prepare_row(row, udf_fields, cache, download_cb)
-            result_objs = self.process_safe(udf_args)
-            udf_outputs = (self._flatten_row(row) for row in result_objs)
-            output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
-            processed_cb.relative_update(1)
-            yield output
+        prepared_inputs: abc.Generator[Sequence[Any], None, None] = (
+            self._prepare_row(row, udf_fields, cache, download_cb) for row in udf_inputs
+        )
+        if self.prefetch > 0:
+            prepared_inputs = AsyncMapper(
+                _prefetch_input, prepared_inputs, workers=self.prefetch
+            ).iterate()
+
+        with contextlib.closing(prepared_inputs):
+            for row in prepared_inputs:
+                result_objs = self.process_safe(row)
+                udf_outputs = (self._flatten_row(row) for row in result_objs)
+                output = (dict(zip(self.signal_names, row)) for row in udf_outputs)
+                processed_cb.relative_update(1)
+                yield output

         self.teardown()

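Both run() methods now share one shape: build a lazy generator of prepared rows, optionally wrap it in AsyncMapper so up to `prefetch` inputs have their files downloaded ahead of the consumer, and close the pipeline even if iteration stops early. A stripped-down sketch with plain values in place of rows (assumes AsyncMapper's default event loop, as the call sites above do):

    import contextlib

    from datachain.asyn import AsyncMapper

    async def prefetch_input(item: int) -> int:
        # Stand-in for _prefetch_input: await the download, pass the row through.
        return item

    prepared = (i for i in range(10))  # like the prepared_inputs generator
    rows = AsyncMapper(prefetch_input, prepared, workers=2).iterate(timeout=10)

    results = []
    with contextlib.closing(rows):  # shuts the workers down on early exit too
        for item in rows:
            results.append(item)

    assert sorted(results) == list(range(10))  # AsyncMapper need not preserve order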
52 changes: 25 additions & 27 deletions src/datachain/query/dataset.py
@@ -472,33 +472,31 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
         # Otherwise process single-threaded (faster for smaller UDFs)
         warehouse = self.catalog.warehouse

-        with contextlib.closing(
-            batching(warehouse.dataset_select_paginated, query)
-        ) as udf_inputs:
-            download_cb = get_download_callback()
-            processed_cb = get_processed_callback()
-            generated_cb = get_generated_callback(self.is_generator)
-            try:
-                udf_results = self.udf.run(
-                    udf_fields,
-                    udf_inputs,
-                    self.catalog,
-                    self.is_generator,
-                    self.cache,
-                    download_cb,
-                    processed_cb,
-                )
-                process_udf_outputs(
-                    warehouse,
-                    udf_table,
-                    udf_results,
-                    self.udf,
-                    cb=generated_cb,
-                )
-            finally:
-                download_cb.close()
-                processed_cb.close()
-                generated_cb.close()
+        udf_inputs = batching(warehouse.dataset_select_paginated, query)
+        download_cb = get_download_callback()
+        processed_cb = get_processed_callback()
+        generated_cb = get_generated_callback(self.is_generator)
+        try:
+            udf_results = self.udf.run(
+                udf_fields,
+                udf_inputs,
+                self.catalog,
+                self.is_generator,
+                self.cache,
+                download_cb,
+                processed_cb,
+            )
+            process_udf_outputs(
+                warehouse,
+                udf_table,
+                udf_results,
+                self.udf,
+                cb=generated_cb,
+            )
+        finally:
+            download_cb.close()
+            processed_cb.close()
+            generated_cb.close()

         warehouse.insert_rows_done(udf_table)
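The outer contextlib.closing around batching(...) is dropped here, presumably because cleanup now lives downstream: UDF.run wraps its prepared-input pipeline in contextlib.closing (see the udf.py hunks above). For reference, what closing a half-consumed generator guarantees (plain-Python sketch):

    import contextlib

    def pages():
        try:
            yield from range(3)  # stand-in for paginated warehouse reads
        finally:
            print("cursor cleaned up")  # runs even if iteration stops early

    with contextlib.closing(pages()) as rows:
        next(rows)  # consume only part of the stream
    # "cursor cleaned up" prints on exit despite the early stop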
7 changes: 5 additions & 2 deletions tests/func/test_datachain.py
@@ -208,17 +208,20 @@ def test_from_storage_dependencies(cloud_test_catalog, cloud_type):


 @pytest.mark.parametrize("use_cache", [True, False])
-def test_map_file(cloud_test_catalog, use_cache):
+@pytest.mark.parametrize("prefetch", [0, 2])
+def test_map_file(cloud_test_catalog, use_cache, prefetch):
     ctc = cloud_test_catalog

     def new_signal(file: File) -> str:
+        assert bool(file.get_local_path()) is (use_cache and prefetch > 0)
         with file.open() as f:
             return file.name + " -> " + f.read().decode("utf-8")

     dc = (
         DataChain.from_storage(ctc.src_uri, session=ctc.session)
-        .settings(cache=use_cache)
+        .settings(cache=use_cache, prefetch=prefetch)
         .map(signal=new_signal)
         .save()
     )

     expected = {
32 changes: 32 additions & 0 deletions tests/unit/test_asyn.py
@@ -2,6 +2,7 @@
 import functools
 from collections import Counter
 from contextlib import contextmanager
+from queue import Queue

 import pytest
 from fsspec.asyn import sync
@@ -111,6 +112,37 @@ async def process(row):
         list(mapper.iterate(timeout=4))


+@pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
+def test_mapper_deadlock(create_mapper):
+    queue = Queue()
+    inputs = range(50)
+
+    def as_iter(queue):
+        while (item := queue.get()) is not None:
+            yield item
+
+    async def process(x):
+        return x
+
+    mapper = create_mapper(process, as_iter(queue), workers=10, loop=get_loop())
+    it = mapper.iterate(timeout=4)
+    for i in inputs:
+        queue.put(i)
+
+    # Check that we can get as many objects out as we put in, without deadlock
+    result = []
+    for _ in range(len(inputs)):
+        result.append(next(it))
+    if mapper.order_preserving:
+        assert result == list(inputs)
+    else:
+        assert set(result) == set(inputs)
+
+    # Check that iteration terminates cleanly
+    queue.put(None)
+    assert list(it) == []
+
+
 @pytest.mark.parametrize("create_mapper", [AsyncMapper, OrderedMapper])
 @settings(deadline=None)
 @given(

Review comment from @shcheklein (Member) on test_mapper_deadlock, resolved:
    Will it deadlock if we don't wrap the producer into a thread? I.e., were you trying to test (make sure) that the producer is wrapped?