
Refactor Dataset.map to reuse cache files mapped with different num_proc #7434


Open · wants to merge 16 commits into main
Changes from all commits · 16 commits
5341540
Refactor Dataset.map to reuse cache files mapped with different num_proc
ringohoffman Mar 4, 2025
bdc17c9
Only give reprocessing message when doing a partial remap
ringohoffman Mar 4, 2025
d7c63fd
Update logging message to account for if a cache file will be written…
ringohoffman Mar 4, 2025
0df4132
Refactor string_to_dict to return None if there is no match instead o…
ringohoffman Mar 4, 2025
7f50b98
Merge branch 'return-none-if-string_to_dict-no-match' into reuse-cach…
ringohoffman Mar 4, 2025
79dc83b
Simplify existing_cache_file_map with string_to_dict
ringohoffman Mar 4, 2025
bb7f9b5
Set initial value if there are already existing cache files
ringohoffman Mar 4, 2025
dafe4f2
Merge branch 'main' into return-none-if-string_to_dict-no-match
ringohoffman Mar 5, 2025
e2c1a5c
Merge branch 'return-none-if-string_to_dict-no-match' into reuse-cach…
ringohoffman Mar 5, 2025
c82cab4
Allow for source_url_fields to be None
ringohoffman Mar 7, 2025
28d82dc
Merge branch 'main' into return-none-if-string_to_dict-no-match
ringohoffman Mar 7, 2025
71b6d16
Merge branch 'return-none-if-string_to_dict-no-match' into reuse-cach…
ringohoffman Mar 9, 2025
8cc0186
Merge branch 'main' into reuse-cache-on-different-num_proc
ringohoffman Mar 12, 2025
637c160
Add unicode escape to handle parsing string_to_dict in Windows paths
ringohoffman Mar 12, 2025
25c0015
Merge branch 'main' into reuse-cache-on-different-num_proc
lhoestq Mar 14, 2025
583c28e
Remove glob_pattern_to_regex
ringohoffman Mar 14, 2025
286 changes: 175 additions & 111 deletions src/datasets/arrow_dataset.py
@@ -19,6 +19,7 @@
import contextlib
import copy
import fnmatch
import glob
import inspect
import itertools
import json
@@ -27,12 +28,13 @@
import posixpath
import re
import shutil
import string
import sys
import tempfile
import time
import warnings
import weakref
from collections import Counter
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Sequence as Sequence_
from copy import deepcopy
@@ -2959,6 +2961,11 @@ def map(
if num_proc is not None and num_proc <= 0:
raise ValueError("num_proc must be an integer > 0.")

string_formatter = string.Formatter()
fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name}
if fields != {"rank", "num_proc"}:
raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")

# If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway)
if len(self) == 0:
if self._indices is not None: # empty indices mapping
@@ -3040,7 +3047,14 @@ def map(
cache_file_name = self._get_cache_file_path(new_fingerprint)
dataset_kwargs["cache_file_name"] = cache_file_name

def load_processed_shard_from_cache(shard_kwargs):
if cache_file_name is not None:
cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
if not cache_file_ext:
raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
else:
cache_file_prefix = cache_file_ext = None

def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
"""Load a processed shard from cache if it exists, otherwise throw an error."""
shard = shard_kwargs["shard"]
# Check if we've already cached this computation (indexed by a hash)
@@ -3051,64 +3065,71 @@ def load_processed_shard_from_cache(shard_kwargs):
return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
raise NonExistentDatasetError

num_shards = num_proc if num_proc is not None else 1
if batched and drop_last_batch:
pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size
else:
pbar_total = len(self)
existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
if cache_file_name is not None:
if os.path.exists(cache_file_name):
existing_cache_file_map[1] = [cache_file_name]

shards_done = 0
if num_proc is None or num_proc == 1:
transformed_dataset = None
try:
transformed_dataset = load_processed_shard_from_cache(dataset_kwargs)
logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}")
except NonExistentDatasetError:
pass
if transformed_dataset is None:
with hf_tqdm(
unit=" examples",
total=pbar_total,
desc=desc or "Map",
) as pbar:
for rank, done, content in Dataset._map_single(**dataset_kwargs):
if done:
shards_done += 1
logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
transformed_dataset = content
else:
pbar.update(content)
assert transformed_dataset is not None, "Failed to retrieve the result from map"
# update fingerprint if the dataset changed
if transformed_dataset._fingerprint != self._fingerprint:
transformed_dataset._fingerprint = new_fingerprint
return transformed_dataset
else:
assert cache_file_prefix is not None and cache_file_ext is not None
cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext

def format_cache_file_name(
cache_file_name: Optional[str],
rank: Union[int, Literal["*"]], # noqa: F722
) -> Optional[str]:
if not cache_file_name:
return cache_file_name
sep = cache_file_name.rindex(".")
base_name, extension = cache_file_name[:sep], cache_file_name[sep:]
if isinstance(rank, int):
cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension
logger.info(f"Process #{rank} will write at {cache_file_name}")
else:
cache_file_name = (
base_name
+ suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc)
+ extension
)
for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
suffix_variable_map = string_to_dict(cache_file, cache_file_with_suffix_pattern)
if suffix_variable_map is not None:
file_num_proc = int(suffix_variable_map["num_proc"])
existing_cache_file_map[file_num_proc].append(cache_file)

num_shards = num_proc or 1
if existing_cache_file_map:
# to avoid remapping when a different num_proc is given than when originally cached, update num_shards to
# what was used originally

def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
num_shards_diff = abs(mapped_num_proc - num_shards)
return (
percent_missing, # choose the most complete set of existing cache files
num_shards_diff, # then choose the mapped_num_proc closest to the current num_proc
mapped_num_proc, # finally, choose whichever mapped_num_proc is lower
)

num_shards = min(existing_cache_file_map, key=select_existing_cache_files)

existing_cache_files = existing_cache_file_map[num_shards]

def format_cache_file_name(
cache_file_name: Optional[str],
rank: Union[int, Literal["*"]], # noqa: F722
) -> Optional[str]:
if not cache_file_name:
return cache_file_name

def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc)
validate_fingerprint(new_fingerprint)
return new_fingerprint
assert cache_file_prefix is not None and cache_file_ext is not None

if isinstance(rank, int):
cache_file_name = (
cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext
)
if not os.path.exists(cache_file_name):
process_name = (
"Main process" if num_proc is None or num_proc == 1 else f"Process #{rank % num_shards + 1}"
)
logger.info(f"{process_name} will write at {cache_file_name}")
else:
# TODO: this assumes the format_spec of rank in suffix_template
cache_file_name = (
cache_file_prefix
+ suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards)
+ cache_file_ext
)
return cache_file_name

def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards)
validate_fingerprint(new_fingerprint)
return new_fingerprint

if num_proc is not None and num_proc > 1:
prev_env = deepcopy(os.environ)
# check if parallelism is off
# from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22
@@ -3123,9 +3144,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
):
logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
else:
prev_env = os.environ

kwargs_per_job: list[Optional[dict[str, Any]]]
if num_shards == 1:
shards = [self]
kwargs_per_job = [dataset_kwargs]
else:
shards = [
self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
for rank in range(num_proc)
self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
for rank in range(num_shards)
]
kwargs_per_job = [
{
@@ -3139,62 +3168,97 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
for rank in range(num_shards)
]

transformed_shards = [None] * num_shards
for rank in range(num_shards):
try:
transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank])
kwargs_per_job[rank] = None
except NonExistentDatasetError:
pass

kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None]

# We try to create a pool with as many workers as dataset not yet cached.
if kwargs_per_job:
if len(kwargs_per_job) < num_shards:
logger.info(
f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
)
with Pool(len(kwargs_per_job)) as pool:
os.environ = prev_env
logger.info(f"Spawning {num_proc} processes")
with hf_tqdm(
unit=" examples",
total=pbar_total,
desc=(desc or "Map") + f" (num_proc={num_proc})",
) as pbar:
transformed_shards: list[Optional[Dataset]] = [None] * num_shards
for rank in range(num_shards):
try:
job_kwargs = kwargs_per_job[rank]
assert job_kwargs is not None
transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs)
kwargs_per_job[rank] = None
except NonExistentDatasetError:
pass

if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]:
if len(unprocessed_kwargs_per_job) != num_shards:
logger.info(
f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were "
"missing from the cache."
)

pbar_total = len(self)
pbar_initial = len(existing_cache_files) * pbar_total // num_shards
if batched and drop_last_batch:
batch_size = batch_size or 1
pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size
pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size

with hf_tqdm(
unit=" examples",
initial=pbar_initial,
total=pbar_total,
desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
) as pbar:
shards_done = 0

def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None:
nonlocal shards_done
if done:
shards_done += 1
logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
assert isinstance(content, Dataset)
transformed_shards[rank or 0] = content
else:
assert isinstance(content, int)
pbar.update(content)

if num_proc is not None and num_proc > 1:
with Pool(num_proc) as pool:
os.environ = prev_env
logger.info(f"Spawning {num_proc} processes")

for rank, done, content in iflatmap_unordered(
pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job
):
if done:
shards_done += 1
logger.debug(f"Finished processing shard number {rank} of {num_shards}.")
transformed_shards[rank] = content
else:
pbar.update(content)
pool.close()
pool.join()
# Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
for kwargs in kwargs_per_job:
del kwargs["shard"]
else:
logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
if None in transformed_shards:
raise ValueError(
f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at "
"least one worker failed to return its results"
)
logger.info(f"Concatenating {num_proc} shards")
result = _concatenate_map_style_datasets(transformed_shards)
# update fingerprint if the dataset changed
check_if_shard_done(rank, done, content)

pool.close()
pool.join()
else:
for unprocessed_kwargs in unprocessed_kwargs_per_job:
for rank, done, content in Dataset._map_single(**unprocessed_kwargs):
check_if_shard_done(rank, done, content)

# Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805)
for job_kwargs in unprocessed_kwargs_per_job:
if "shard" in job_kwargs:
del job_kwargs["shard"]
else:
logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")

all_transformed_shards = [shard for shard in transformed_shards if shard is not None]
if len(transformed_shards) != len(all_transformed_shards):
raise ValueError(
f"Failed to retrieve results from map: result list {transformed_shards} still contains None - "
"at least one worker failed to return its results"
)

if num_shards == 1:
result = all_transformed_shards[0]
else:
logger.info(f"Concatenating {num_shards} shards")
result = _concatenate_map_style_datasets(all_transformed_shards)

# update fingerprint if the dataset changed
result._fingerprint = (
new_fingerprint
if any(
transformed_shard._fingerprint != shard._fingerprint
for transformed_shard, shard in zip(transformed_shards, shards)
):
result._fingerprint = new_fingerprint
else:
result._fingerprint = self._fingerprint
return result
for transformed_shard, shard in zip(all_transformed_shards, shards)
)
else self._fingerprint
)

return result

@staticmethod
def _map_single(
@@ -3216,7 +3280,7 @@
new_fingerprint: Optional[str] = None,
rank: Optional[int] = None,
offset: int = 0,
) -> Iterable[tuple[int, bool, Union[int, "Dataset"]]]:
) -> Iterable[tuple[Optional[int], bool, Union[int, "Dataset"]]]:
"""Apply a function to all the elements in the table (individually or in batches)
and update the table (if function does update examples).

@@ -5750,7 +5814,7 @@ def push_to_hub(
@transmit_format
@fingerprint_transform(inplace=False)
def add_column(
self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None
self, name: str, column: Union[list, np.ndarray], new_fingerprint: str, feature: Optional[FeatureType] = None
):
"""Add column to Dataset.

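The core of this diff is the cache-reuse heuristic: `Dataset.map` now globs for cache files written by earlier runs, groups them by the `num_proc` they were produced with, and picks the shard count whose cache set is most complete (ties broken by closeness to the requested `num_proc`, then by the smaller value). Below is a minimal standalone sketch of that selection logic, assuming the default suffix template `_{rank:05d}_of_{num_proc:05d}`; `pick_num_shards` is an illustrative name, not a `datasets` API.

```python
# Minimal, standalone sketch of the cache-reuse heuristic introduced above.
# `pick_num_shards` is illustrative only, and it assumes the default suffix
# template "_{rank:05d}_of_{num_proc:05d}".
import glob
import os
import re
from collections import defaultdict
from typing import Optional


def pick_num_shards(cache_file_name: str, requested_num_proc: Optional[int]) -> tuple[int, list[str]]:
    """Choose how many shards to map with, based on cache files written by earlier runs."""
    prefix, ext = os.path.splitext(cache_file_name)
    shard_regex = re.escape(prefix) + r"_(?P<rank>\d+)_of_(?P<num_proc>\d+)" + re.escape(ext)

    existing: dict[int, list[str]] = defaultdict(list)
    if os.path.exists(cache_file_name):  # a previous single-process run
        existing[1].append(cache_file_name)
    for path in glob.iglob(f"{glob.escape(prefix)}*{ext}"):
        match = re.fullmatch(shard_regex, path)
        if match is not None:
            existing[int(match["num_proc"])].append(path)

    num_shards = requested_num_proc or 1
    if existing:
        def selection_key(mapped_num_proc: int) -> tuple[float, int, int]:
            percent_missing = (mapped_num_proc - len(existing[mapped_num_proc])) / mapped_num_proc
            return (
                percent_missing,                    # most complete set of cache files first
                abs(mapped_num_proc - num_shards),  # then closest to the requested num_proc
                mapped_num_proc,                    # then the smaller num_proc
            )

        num_shards = min(existing, key=selection_key)
    return num_shards, existing[num_shards]
```

With that key, a dataset previously mapped with `num_proc=8` is reloaded from its eight shard files even if `map` is now called with `num_proc=2`, and a partially cached run only reprocesses the missing shards.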
4 changes: 2 additions & 2 deletions src/datasets/data_files.py
@@ -19,7 +19,7 @@
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin
from .utils.py_utils import glob_pattern_to_regex, string_to_dict
from .utils.py_utils import string_to_dict


SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]]
@@ -266,7 +266,7 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> di
if len(data_files) > 0:
splits: set[str] = set()
for p in data_files:
p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))
p_parts = string_to_dict(xbasename(p), xbasename(split_pattern))
assert p_parts is not None
splits.add(p_parts["split"])

1 change: 1 addition & 0 deletions src/datasets/utils/py_utils.py
@@ -180,6 +180,7 @@ def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]:
Optional[dict[str, str]]: dictionary of variable -> value, retrieved from the input using the pattern, or
`None` if the string does not match the pattern.
"""
pattern = pattern.encode("unicode_escape").decode("utf-8") # C:\\Users -> C:\\\\Users for Windows paths
pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern) # remove format specifiers, e.g. {rank:05d} -> {rank}
regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern)
result = re.search(regex, string)
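The single added line above escapes backslashes so Windows paths survive the pattern-to-regex conversion. As a reference point, here is a simplified, self-contained sketch of the technique `string_to_dict` relies on: strip format specifiers from `{field}` placeholders, turn each placeholder into a named capture group, and return `None` when nothing matches. Group naming and matching details in the library itself may differ; this is only meant to illustrate the idea.

```python
# Simplified sketch of the placeholder-to-regex technique; not the library implementation.
import re
from typing import Optional


def parse_with_pattern(string: str, pattern: str) -> Optional[dict[str, str]]:
    """Extract {field} values from `string` using a format-style `pattern`, or return None."""
    pattern = pattern.encode("unicode_escape").decode("utf-8")  # keep Windows backslashes literal
    pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern)  # {rank:05d} -> {rank}
    regex = re.sub(r"{(.+?)}", r"(?P<\1>.+)", pattern)  # {rank} -> named capture group
    match = re.search(regex, string)
    if match is None:
        return None
    return dict(match.groupdict())


# e.g. recovering the rank/num_proc a cache shard was written with:
assert parse_with_pattern(
    "cache-abc_00002_of_00008.arrow", "cache-abc_{rank:05d}_of_{num_proc:05d}.arrow"
) == {"rank": "00002", "num_proc": "00008"}
assert parse_with_pattern("cache-abc.arrow", "cache-abc_{rank:05d}_of_{num_proc:05d}.arrow") is None
```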