Catch OSError for arrow (#7348)
* catch oserror in arrow

* Update arrow.py

* style
lhoestq authored Jan 9, 2025
1 parent 9e88687 commit 7a1a84b
Showing 12 changed files with 36 additions and 37 deletions.
10 changes: 5 additions & 5 deletions src/datasets/arrow_dataset.py
@@ -1990,7 +1990,7 @@ def flatten(self, new_fingerprint: Optional[str] = None, max_depth=16) -> "Datas
dataset.info.features = self._info.features.flatten(max_depth=max_depth)
dataset.info.features = Features({col: dataset.info.features[col] for col in dataset.data.column_names})
dataset._data = update_metadata_with_features(dataset._data, dataset.features)
logger.info(f'Flattened dataset from depth {depth} to depth {1 if depth + 1 < max_depth else "unknown"}.')
logger.info(f"Flattened dataset from depth {depth} to depth {1 if depth + 1 < max_depth else 'unknown'}.")
dataset._fingerprint = new_fingerprint
return dataset

@@ -3176,9 +3176,9 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
del kwargs["shard"]
else:
logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}")
assert (
None not in transformed_shards
), f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
assert None not in transformed_shards, (
f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results"
)
logger.info(f"Concatenating {num_proc} shards")
result = _concatenate_map_style_datasets(transformed_shards)
# update fingerprint if the dataset changed
@@ -5651,7 +5651,7 @@ def push_to_hub(
create_pr=create_pr,
)
logger.info(
f"Commit #{i+1} completed"
f"Commit #{i + 1} completed"
+ (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "")
+ "."
)
14 changes: 7 additions & 7 deletions src/datasets/builder.py
@@ -1114,7 +1114,7 @@ def as_dataset(
"datasets.load_dataset() before trying to access the Dataset object."
)

logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._output_dir}')
logger.debug(f"Constructing Dataset for split {split or ', '.join(self.info.splits)}, from {self._output_dir}")

# By default, return all splits
if split is None:
@@ -1528,9 +1528,9 @@ def _prepare_split(
# the content is the number of examples progress update
pbar.update(content)

assert (
None not in examples_per_job
), f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
assert None not in examples_per_job, (
f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
)

total_shards = sum(shards_per_job)
total_num_examples = sum(examples_per_job)
@@ -1783,9 +1783,9 @@ def _prepare_split(
# the content is the number of examples progress update
pbar.update(content)

assert (
None not in examples_per_job
), f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
assert None not in examples_per_job, (
f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at least one worker failed to return its results"
)

total_shards = sum(shards_per_job)
total_num_examples = sum(examples_per_job)
2 changes: 1 addition & 1 deletion src/datasets/dataset_dict.py
@@ -1802,7 +1802,7 @@ def push_to_hub(
create_pr=create_pr,
)
logger.info(
f"Commit #{i+1} completed"
f"Commit #{i + 1} completed"
+ (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "")
+ "."
)
4 changes: 2 additions & 2 deletions src/datasets/features/features.py
@@ -2168,8 +2168,8 @@ def recursive_reorder(source, target, stack=""):
if sorted(source) != sorted(target):
message = (
f"Keys mismatch: between {source} (source) and {target} (target).\n"
f"{source.keys()-target.keys()} are missing from target "
f"and {target.keys()-source.keys()} are missing from source" + stack_position
f"{source.keys() - target.keys()} are missing from target "
f"and {target.keys() - source.keys()} are missing from source" + stack_position
)
raise ValueError(message)
return {key: recursive_reorder(source[key], target[key], stack + f".{key}") for key in target}
2 changes: 1 addition & 1 deletion src/datasets/features/translation.py
@@ -102,7 +102,7 @@ def encode_example(self, translation_dict):
return translation_dict
elif self.languages and set(translation_dict) - lang_set:
raise ValueError(
f'Some languages in example ({", ".join(sorted(set(translation_dict) - lang_set))}) are not in valid set ({", ".join(lang_set)}).'
f"Some languages in example ({', '.join(sorted(set(translation_dict) - lang_set))}) are not in valid set ({', '.join(lang_set)})."
)

# Convert dictionary into tuples, splitting out cases where there are
2 changes: 1 addition & 1 deletion src/datasets/fingerprint.py
@@ -221,7 +221,7 @@ def generate_fingerprint(dataset: "Dataset") -> str:


def generate_random_fingerprint(nbits: int = 64) -> str:
return f"{fingerprint_rng.getrandbits(nbits):0{nbits//4}x}"
return f"{fingerprint_rng.getrandbits(nbits):0{nbits // 4}x}"


def update_fingerprint(fingerprint, transform, transform_args):
4 changes: 2 additions & 2 deletions src/datasets/packaged_modules/arrow/arrow.py
@@ -45,7 +45,7 @@ def _split_generators(self, dl_manager):
with open(file, "rb") as f:
try:
reader = pa.ipc.open_stream(f)
except pa.lib.ArrowInvalid:
except (OSError, pa.lib.ArrowInvalid):
reader = pa.ipc.open_file(f)
self.info.features = datasets.Features.from_arrow_schema(reader.schema)
break
@@ -65,7 +65,7 @@ def _generate_tables(self, files):
try:
try:
batches = pa.ipc.open_stream(f)
except pa.lib.ArrowInvalid:
except (OSError, pa.lib.ArrowInvalid):
reader = pa.ipc.open_file(f)
batches = (reader.get_batch(i) for i in range(reader.num_record_batches))
for batch_idx, record_batch in enumerate(batches):
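For context, a minimal standalone sketch of the reader fallback that this commit hardens: the streaming IPC reader is tried first, and the random-access Arrow file reader is used when it fails; with this change an OSError raised by pa.ipc.open_stream also triggers the fallback, not only ArrowInvalid. The read_arrow_batches helper and the path argument are illustrative only, not part of the repository.

import pyarrow as pa

def read_arrow_batches(path):
    with open(path, "rb") as f:
        try:
            # Arrow streaming format: record batches are read sequentially.
            reader = pa.ipc.open_stream(f)
            return list(reader)
        except (OSError, pa.lib.ArrowInvalid):
            # Some non-stream files make open_stream raise OSError rather than
            # ArrowInvalid, hence the broader except added in this commit.
            reader = pa.ipc.open_file(f)
            return [reader.get_batch(i) for i in range(reader.num_record_batches)]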
2 changes: 1 addition & 1 deletion src/datasets/search.py
@@ -175,7 +175,7 @@ def passage_generator():
successes += ok
if successes != len(documents):
logger.warning(
f"Some documents failed to be added to ElasticSearch. Failures: {len(documents)-successes}/{len(documents)}"
f"Some documents failed to be added to ElasticSearch. Failures: {len(documents) - successes}/{len(documents)}"
)
logger.info(f"Indexed {successes:d} documents")

2 changes: 1 addition & 1 deletion src/datasets/splits.py
@@ -478,7 +478,7 @@ class SplitReadInstruction:
"""

def __init__(self, split_info=None):
self._splits = NonMutableDict(error_msg="Overlap between splits. Split {key} has been added with " "itself.")
self._splits = NonMutableDict(error_msg="Overlap between splits. Split {key} has been added with itself.")

if split_info:
self.add(SlicedSplitInfo(split_info=split_info, slice_value=None))
3 changes: 1 addition & 2 deletions src/datasets/utils/logging.py
@@ -57,8 +57,7 @@ def _get_default_logging_level():
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option DATASETS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
f"Unknown option DATASETS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level

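As a hedged illustration of the setting this warning guards: DATASETS_VERBOSITY is read when datasets configures its logging at import time, so the variable has to be set before the import, and the accepted names come from the log_levels mapping referenced in the hunk.

import os

os.environ["DATASETS_VERBOSITY"] = "debug"  # an unrecognized value triggers the warning above

import datasets  # the env var is read here, when the library sets up its logging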
4 changes: 2 additions & 2 deletions src/datasets/utils/stratify.py
@@ -81,11 +81,11 @@ def stratified_shuffle_split_generate_indices(y, n_train, n_test, rng, n_splits=
raise ValueError("Minimum class count error")
if n_train < n_classes:
raise ValueError(
"The train_size = %d should be greater or " "equal to the number of classes = %d" % (n_train, n_classes)
"The train_size = %d should be greater or equal to the number of classes = %d" % (n_train, n_classes)
)
if n_test < n_classes:
raise ValueError(
"The test_size = %d should be greater or " "equal to the number of classes = %d" % (n_test, n_classes)
"The test_size = %d should be greater or equal to the number of classes = %d" % (n_test, n_classes)
)
class_indices = np.split(np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1])
for _ in range(n_splits):
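These helpers back Dataset.train_test_split(..., stratify_by_column=...); the ValueError paths above fire when the requested train or test size is smaller than the number of classes. A small hedged usage sketch follows, with made-up toy data.

from datasets import ClassLabel, Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 0, 1, 1]})
ds = ds.cast_column("label", ClassLabel(names=["neg", "pos"]))  # stratification needs a ClassLabel column
splits = ds.train_test_split(test_size=0.5, stratify_by_column="label", seed=0)
print(splits["train"]["label"], splits["test"]["label"])  # each split keeps one example per class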
24 changes: 12 additions & 12 deletions tests/test_iterable_dataset.py
@@ -1146,9 +1146,9 @@ def test_skip_examples_iterable():
skip_ex_iterable = SkipExamplesIterable(base_ex_iterable, n=count)
expected = list(generate_examples_fn(n=total))[count:]
assert list(skip_ex_iterable) == expected
assert (
skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable
), "skip examples makes the shards order fixed"
assert skip_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is skip_ex_iterable, (
"skip examples makes the shards order fixed"
)
assert_load_state_dict_resumes_iteration(skip_ex_iterable)


@@ -1158,9 +1158,9 @@ def test_take_examples_iterable():
take_ex_iterable = TakeExamplesIterable(base_ex_iterable, n=count)
expected = list(generate_examples_fn(n=total))[:count]
assert list(take_ex_iterable) == expected
assert (
take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable
), "skip examples makes the shards order fixed"
assert take_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is take_ex_iterable, (
"skip examples makes the shards order fixed"
)
assert_load_state_dict_resumes_iteration(take_ex_iterable)


@@ -1208,9 +1208,9 @@ def test_horizontally_concatenated_examples_iterable():
concatenated_ex_iterable = HorizontallyConcatenatedMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2])
expected = [{**x, **y} for (_, x), (_, y) in zip(ex_iterable1, ex_iterable2)]
assert [x for _, x in concatenated_ex_iterable] == expected
assert (
concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable
), "horizontally concatenated examples makes the shards order fixed"
assert concatenated_ex_iterable.shuffle_data_sources(np.random.default_rng(42)) is concatenated_ex_iterable, (
"horizontally concatenated examples makes the shards order fixed"
)
assert_load_state_dict_resumes_iteration(concatenated_ex_iterable)


@@ -2270,7 +2270,7 @@ def test_iterable_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Check last partial batch
assert len(batches[3]["id"]) == 1
@@ -2287,7 +2287,7 @@ def test_iterable_dataset_batch():
assert len(batch["id"]) == 3
assert len(batch["text"]) == 3
assert batch["id"] == [3 * i, 3 * i + 1, 3 * i + 2]
assert batch["text"] == [f"Text {3*i}", f"Text {3*i+1}", f"Text {3*i+2}"]
assert batch["text"] == [f"Text {3 * i}", f"Text {3 * i + 1}", f"Text {3 * i + 2}"]

# Test with batch_size=4 (doesn't evenly divide dataset size)
batched_ds = ds.batch(batch_size=4, drop_last_batch=False)
@@ -2298,7 +2298,7 @@ def test_iterable_dataset_batch():
assert len(batch["id"]) == 4
assert len(batch["text"]) == 4
assert batch["id"] == [4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3]
assert batch["text"] == [f"Text {4*i}", f"Text {4*i+1}", f"Text {4*i+2}", f"Text {4*i+3}"]
assert batch["text"] == [f"Text {4 * i}", f"Text {4 * i + 1}", f"Text {4 * i + 2}", f"Text {4 * i + 3}"]

# Check last partial batch
assert len(batches[2]["id"]) == 2
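For reference, a small usage sketch (not part of the diff) of the IterableDataset.batch() behavior this test exercises; the ten-example dataset is an assumption mirroring the test's "Text {i}" fixtures.

from datasets import Dataset

ds = Dataset.from_dict(
    {"id": list(range(10)), "text": [f"Text {i}" for i in range(10)]}
).to_iterable_dataset()

for batch in ds.batch(batch_size=4, drop_last_batch=False):
    print(batch["id"])  # [0, 1, 2, 3], then [4, 5, 6, 7], then the partial batch [8, 9]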
