diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index cc39ffb566c..0bb4821b83e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -19,6 +19,7 @@ import contextlib import copy import fnmatch +import glob import inspect import itertools import json @@ -27,12 +28,13 @@ import posixpath import re import shutil +import string import sys import tempfile import time import warnings import weakref -from collections import Counter +from collections import Counter, defaultdict from collections.abc import Iterable, Iterator, Mapping from collections.abc import Sequence as Sequence_ from copy import deepcopy @@ -2959,6 +2961,11 @@ def map( if num_proc is not None and num_proc <= 0: raise ValueError("num_proc must be an integer > 0.") + string_formatter = string.Formatter() + fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name} + if fields != {"rank", "num_proc"}: + raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}") + # If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway) if len(self) == 0: if self._indices is not None: # empty indices mapping @@ -3040,7 +3047,14 @@ def map( cache_file_name = self._get_cache_file_path(new_fingerprint) dataset_kwargs["cache_file_name"] = cache_file_name - def load_processed_shard_from_cache(shard_kwargs): + if cache_file_name is not None: + cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name) + if not cache_file_ext: + raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}") + else: + cache_file_prefix = cache_file_ext = None + + def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset: """Load a processed shard from cache if it exists, otherwise throw an error.""" shard = shard_kwargs["shard"] # Check if we've already cached this computation (indexed by a hash) @@ -3051,64 +3065,71 @@ def load_processed_shard_from_cache(shard_kwargs): return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split) raise NonExistentDatasetError - num_shards = num_proc if num_proc is not None else 1 - if batched and drop_last_batch: - pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size - else: - pbar_total = len(self) + existing_cache_file_map: dict[int, list[str]] = defaultdict(list) + if cache_file_name is not None: + if os.path.exists(cache_file_name): + existing_cache_file_map[1] = [cache_file_name] - shards_done = 0 - if num_proc is None or num_proc == 1: - transformed_dataset = None - try: - transformed_dataset = load_processed_shard_from_cache(dataset_kwargs) - logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}") - except NonExistentDatasetError: - pass - if transformed_dataset is None: - with hf_tqdm( - unit=" examples", - total=pbar_total, - desc=desc or "Map", - ) as pbar: - for rank, done, content in Dataset._map_single(**dataset_kwargs): - if done: - shards_done += 1 - logger.debug(f"Finished processing shard number {rank} of {num_shards}.") - transformed_dataset = content - else: - pbar.update(content) - assert transformed_dataset is not None, "Failed to retrieve the result from map" - # update fingerprint if the dataset changed - if transformed_dataset._fingerprint != self._fingerprint: - transformed_dataset._fingerprint = new_fingerprint - return transformed_dataset - else: + 
assert cache_file_prefix is not None and cache_file_ext is not None + cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext - def format_cache_file_name( - cache_file_name: Optional[str], - rank: Union[int, Literal["*"]], # noqa: F722 - ) -> Optional[str]: - if not cache_file_name: - return cache_file_name - sep = cache_file_name.rindex(".") - base_name, extension = cache_file_name[:sep], cache_file_name[sep:] - if isinstance(rank, int): - cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension - logger.info(f"Process #{rank} will write at {cache_file_name}") - else: - cache_file_name = ( - base_name - + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc) - + extension - ) + for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"): + suffix_variable_map = string_to_dict(cache_file, cache_file_with_suffix_pattern) + if suffix_variable_map is not None: + file_num_proc = int(suffix_variable_map["num_proc"]) + existing_cache_file_map[file_num_proc].append(cache_file) + + num_shards = num_proc or 1 + if existing_cache_file_map: + # to avoid remapping when a different num_proc is given than when originally cached, update num_shards to + # what was used originally + + def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]: + percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc + num_shards_diff = abs(mapped_num_proc - num_shards) + return ( + percent_missing, # choose the most complete set of existing cache files + num_shards_diff, # then choose the mapped_num_proc closest to the current num_proc + mapped_num_proc, # finally, choose whichever mapped_num_proc is lower + ) + + num_shards = min(existing_cache_file_map, key=select_existing_cache_files) + + existing_cache_files = existing_cache_file_map[num_shards] + + def format_cache_file_name( + cache_file_name: Optional[str], + rank: Union[int, Literal["*"]], # noqa: F722 + ) -> Optional[str]: + if not cache_file_name: return cache_file_name - def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: - new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc) - validate_fingerprint(new_fingerprint) - return new_fingerprint + assert cache_file_prefix is not None and cache_file_ext is not None + + if isinstance(rank, int): + cache_file_name = ( + cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext + ) + if not os.path.exists(cache_file_name): + process_name = ( + "Main process" if num_proc is None or num_proc == 1 else f"Process #{rank % num_shards + 1}" + ) + logger.info(f"{process_name} will write at {cache_file_name}") + else: + # TODO: this assumes the format_spec of rank in suffix_template + cache_file_name = ( + cache_file_prefix + + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards) + + cache_file_ext + ) + return cache_file_name + + def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: + new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards) + validate_fingerprint(new_fingerprint) + return new_fingerprint + if num_proc is not None and num_proc > 1: prev_env = deepcopy(os.environ) # check if parallelism if off # from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22 @@ -3123,9 +3144,17 @@ def 
format_new_fingerprint(new_fingerprint: str, rank: int) -> str: ): logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.") os.environ["TOKENIZERS_PARALLELISM"] = "false" + else: + prev_env = os.environ + + kwargs_per_job: list[Optional[dict[str, Any]]] + if num_shards == 1: + shards = [self] + kwargs_per_job = [dataset_kwargs] + else: shards = [ - self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory) - for rank in range(num_proc) + self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory) + for rank in range(num_shards) ] kwargs_per_job = [ { @@ -3139,62 +3168,97 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: for rank in range(num_shards) ] - transformed_shards = [None] * num_shards - for rank in range(num_shards): - try: - transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank]) - kwargs_per_job[rank] = None - except NonExistentDatasetError: - pass - - kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None] - - # We try to create a pool with as many workers as dataset not yet cached. - if kwargs_per_job: - if len(kwargs_per_job) < num_shards: - logger.info( - f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache." - ) - with Pool(len(kwargs_per_job)) as pool: - os.environ = prev_env - logger.info(f"Spawning {num_proc} processes") - with hf_tqdm( - unit=" examples", - total=pbar_total, - desc=(desc or "Map") + f" (num_proc={num_proc})", - ) as pbar: + transformed_shards: list[Optional[Dataset]] = [None] * num_shards + for rank in range(num_shards): + try: + job_kwargs = kwargs_per_job[rank] + assert job_kwargs is not None + transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs) + kwargs_per_job[rank] = None + except NonExistentDatasetError: + pass + + if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]: + if len(unprocessed_kwargs_per_job) != num_shards: + logger.info( + f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were " + "missing from the cache." 
+ ) + + pbar_total = len(self) + pbar_initial = len(existing_cache_files) * pbar_total // num_shards + if batched and drop_last_batch: + batch_size = batch_size or 1 + pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size + pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size + + with hf_tqdm( + unit=" examples", + initial=pbar_initial, + total=pbar_total, + desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""), + ) as pbar: + shards_done = 0 + + def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None: + nonlocal shards_done + if done: + shards_done += 1 + logger.debug(f"Finished processing shard number {rank} of {num_shards}.") + assert isinstance(content, Dataset) + transformed_shards[rank or 0] = content + else: + assert isinstance(content, int) + pbar.update(content) + + if num_proc is not None and num_proc > 1: + with Pool(num_proc) as pool: + os.environ = prev_env + logger.info(f"Spawning {num_proc} processes") + for rank, done, content in iflatmap_unordered( - pool, Dataset._map_single, kwargs_iterable=kwargs_per_job + pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job ): - if done: - shards_done += 1 - logger.debug(f"Finished processing shard number {rank} of {num_shards}.") - transformed_shards[rank] = content - else: - pbar.update(content) - pool.close() - pool.join() - # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805) - for kwargs in kwargs_per_job: - del kwargs["shard"] - else: - logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}") - if None in transformed_shards: - raise ValueError( - f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at " - "least one worker failed to return its results" - ) - logger.info(f"Concatenating {num_proc} shards") - result = _concatenate_map_style_datasets(transformed_shards) - # update fingerprint if the dataset changed + check_if_shard_done(rank, done, content) + + pool.close() + pool.join() + else: + for unprocessed_kwargs in unprocessed_kwargs_per_job: + for rank, done, content in Dataset._map_single(**unprocessed_kwargs): + check_if_shard_done(rank, done, content) + + # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805) + for job_kwargs in unprocessed_kwargs_per_job: + if "shard" in job_kwargs: + del job_kwargs["shard"] + else: + logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}") + + all_transformed_shards = [shard for shard in transformed_shards if shard is not None] + if len(transformed_shards) != len(all_transformed_shards): + raise ValueError( + f"Failed to retrieve results from map: result list {transformed_shards} still contains None - " + "at least one worker failed to return its results" + ) + + if num_shards == 1: + result = all_transformed_shards[0] + else: + logger.info(f"Concatenating {num_shards} shards") + result = _concatenate_map_style_datasets(all_transformed_shards) + + # update fingerprint if the dataset changed + result._fingerprint = ( + new_fingerprint if any( transformed_shard._fingerprint != shard._fingerprint - for transformed_shard, shard in zip(transformed_shards, shards) - ): - result._fingerprint = new_fingerprint - else: - result._fingerprint = self._fingerprint - return 
result + for transformed_shard, shard in zip(all_transformed_shards, shards) + ) + else self._fingerprint + ) + + return result @staticmethod def _map_single( @@ -3216,7 +3280,7 @@ def _map_single( new_fingerprint: Optional[str] = None, rank: Optional[int] = None, offset: int = 0, - ) -> Iterable[tuple[int, bool, Union[int, "Dataset"]]]: + ) -> Iterable[tuple[Optional[int], bool, Union[int, "Dataset"]]]: """Apply a function to all the elements in the table (individually or in batches) and update the table (if function does update examples). @@ -5750,7 +5814,7 @@ def push_to_hub( @transmit_format @fingerprint_transform(inplace=False) def add_column( - self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None + self, name: str, column: Union[list, np.ndarray], new_fingerprint: str, feature: Optional[FeatureType] = None ): """Add column to Dataset. diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 2a1caa81bf7..bedbba7a795 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -19,7 +19,7 @@ from .utils import logging from .utils import tqdm as hf_tqdm from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin -from .utils.py_utils import glob_pattern_to_regex, string_to_dict +from .utils.py_utils import string_to_dict SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]] @@ -266,7 +266,7 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> di if len(data_files) > 0: splits: set[str] = set() for p in data_files: - p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern))) + p_parts = string_to_dict(xbasename(p), xbasename(split_pattern)) assert p_parts is not None splits.add(p_parts["split"]) diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index d954f548f22..79606f71cd2 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -180,6 +180,7 @@ def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]: Optional[dict[str, str]]: dictionary of variable -> value, retrieved from the input using the pattern, or `None` if the string does not match the pattern. """ + pattern = pattern.encode("unicode_escape").decode("utf-8") # C:\\Users -> C:\\\\Users for Windows paths pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern) # remove format specifiers, e.g. 
{rank:05d} -> {rank} regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern) result = re.search(regex, string) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 2e54aadf7b6..d8083b79ca3 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1459,6 +1459,91 @@ def test_map_caching(self, in_memory): finally: datasets.enable_caching() + def test_suffix_template_format(self, in_memory): + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + self.assertRaises(ValueError) as e, + dset.map(lambda x: {"foo": "bar"}, suffix_template="_{}_of_{}"), + ): + self.assertIn( + "suffix_template must contain exactly the fields 'rank' and 'num_proc', got: ", + e.exception.args[0], + ) + + def test_cache_file_name_no_ext_raises_error(self, in_memory): + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + self.assertRaises(ValueError) as e, + dset.map(lambda x: {"foo": "bar"}, cache_file_name=os.path.join(tmp_dir, "train")), + ): + self.assertIn("Expected cache_file_name to have an extension, but got: ", e.exception.args[0]) + + def test_map_caching_reuses_cache_with_different_num_proc(self, in_memory): + for dset_test1_num_proc, dset_test2_num_proc in [(1, 2), (2, 1)]: + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + ): + # cannot mock _map_single here because mock objects aren't picklable + # see: https://github.com/python/cpython/issues/100090 + self._caplog.clear() + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test1_num_proc) as dset_test1: + dset_test1_data_files = list(dset_test1.cache_files) + self.assertFalse("Loading cached processed dataset" in self._caplog.text) + + self._caplog.clear() + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test2_num_proc) as dset_test2: + self.assertEqual(dset_test1_data_files, dset_test2.cache_files) + self.assertEqual(len(dset_test2.cache_files), 0 if in_memory else dset_test1_num_proc) + self.assertTrue(("Loading cached processed dataset" in self._caplog.text) ^ in_memory) + + def test_map_caching_partial_remap(self, in_memory): + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + ): + # cannot mock _map_single here because mock objects aren't picklable + # see: https://github.com/python/cpython/issues/100090 + self._caplog.clear() + dset_test1_num_proc = 4 + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test1_num_proc) as dset_test1: + dset_test1_data_files = list(dset_test1.cache_files) + self.assertFalse("Loading cached processed dataset" in self._caplog.text) + + num_files_to_delete = 2 + expected_msg = ( + f"Reprocessing {num_files_to_delete}/{dset_test1_num_proc} shards because some of them " + "were missing from the cache." 
+ ) + for cache_file in dset_test1_data_files[num_files_to_delete:]: + os.remove(cache_file["filename"]) + + self._caplog.clear() + dset_test2_num_proc = 1 + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test2_num_proc) as dset_test2: + self.assertEqual(dset_test1_data_files, dset_test2.cache_files) + self.assertEqual(len(dset_test2.cache_files), 0 if in_memory else dset_test1_num_proc) + self.assertTrue((expected_msg in self._caplog.text) ^ in_memory) + self.assertFalse(f"Spawning {dset_test1_num_proc} processes" in self._caplog.text) + self.assertFalse(f"Spawning {dset_test2_num_proc} processes" in self._caplog.text) + + for cache_file in dset_test1_data_files[num_files_to_delete:]: + os.remove(cache_file["filename"]) + + self._caplog.clear() + dset_test3_num_proc = 3 + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test3_num_proc) as dset_test3: + self.assertEqual(dset_test1_data_files, dset_test3.cache_files) + self.assertEqual(len(dset_test3.cache_files), 0 if in_memory else dset_test1_num_proc) + self.assertTrue((expected_msg in self._caplog.text) ^ in_memory) + self.assertTrue(f"Spawning {dset_test3_num_proc} processes" in self._caplog.text) + def test_map_return_pa_table(self, in_memory): def func_return_single_row_pa_table(x): return pa.table({"id": [0], "text": ["a"]})
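
# Illustrative sketch (not part of the diff): how the new suffix_template validation in
# Dataset.map() behaves. string.Formatter().parse() yields (literal_text, field_name,
# format_spec, conversion) tuples, so collecting the non-empty field names lets map()
# require exactly {"rank", "num_proc"}. The helper name below is made up for the example.
import string

def check_suffix_template(suffix_template: str) -> None:
    fields = {field_name for _, field_name, _, _ in string.Formatter().parse(suffix_template) if field_name}
    if fields != {"rank", "num_proc"}:
        raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}")

check_suffix_template("_{rank:05d}_of_{num_proc:05d}")  # the default template: passes
# check_suffix_template("_{}_of_{}")  # raises ValueError, as exercised by test_suffix_template_format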
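
# Illustrative sketch (not part of the diff): how map() now picks which previously cached shard
# layout to reuse. select_existing_cache_files() mirrors the tie-breaking from the diff;
# existing_cache_file_map below is a hand-written example instead of a real glob() scan.
existing_cache_file_map = {
    1: ["cache-fp.arrow"],                                                   # num_proc=1 run, fully cached
    4: ["cache-fp_00000_of_00004.arrow", "cache-fp_00001_of_00004.arrow"],   # num_proc=4 run, half cached
}
num_proc = 2
num_shards = num_proc or 1

def select_existing_cache_files(mapped_num_proc: int) -> tuple:
    percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc
    num_shards_diff = abs(mapped_num_proc - num_shards)
    return (
        percent_missing,   # prefer the most complete set of existing cache files
        num_shards_diff,   # then the shard count closest to the requested num_proc
        mapped_num_proc,   # finally the lower shard count
    )

num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
print(num_shards)  # 1 -> the fully cached single-shard layout wins over the half-complete 4-shard one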
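
# Illustrative sketch (not part of the diff): the updated datasets.utils.py_utils.string_to_dict
# accepts a str.format-style pattern directly; format specifiers such as :05d are stripped and
# each field becomes a named regex group. The re-implementation below only mirrors the idea
# (the shipped helper also escapes backslashes so Windows paths survive the regex).
import re
from typing import Optional

def string_to_dict_sketch(string: str, pattern: str) -> Optional[dict]:
    pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern)  # {rank:05d} -> {rank}
    regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern)           # {rank} -> (?P<_rank>.+)
    result = re.search(regex, string)
    if result is None:
        return None
    keys = re.findall(r"{(.+?)}", pattern)
    return dict(zip(keys, result.groups()))

print(string_to_dict_sketch(
    "cache-deadbeef_00003_of_00008.arrow",
    "cache-deadbeef_{rank:05d}_of_{num_proc:05d}.arrow",
))  # {'rank': '00003', 'num_proc': '00008'}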
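
# Illustrative usage sketch (not part of the diff): with this change, re-running the same map()
# with a different num_proc reuses the shards cached by the first run instead of recomputing them.
# The dataset, mapped function and cache path below are placeholders; an explicit cache_file_name
# is passed so that even this small in-memory dataset writes cache files.
import os
import tempfile
from datasets import Dataset

cache_file_name = os.path.join(tempfile.mkdtemp(), "mapped.arrow")
ds = Dataset.from_dict({"x": list(range(1000))})

def double(example):
    return {"y": example["x"] * 2}

ds1 = ds.map(double, num_proc=4, cache_file_name=cache_file_name)
# Second run: a different num_proc finds the four "_0000i_of_00004" shards on disk and loads
# them from cache instead of spawning workers and remapping.
ds2 = ds.map(double, num_proc=1, cache_file_name=cache_file_name)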