apply formatting after iter_arrow to speed up format -> map, filter for iterable datasets #7207

Merged
50 commits merged on Jan 14, 2025
Changes from 5 commits

Commits (50)
a73bb02
apply formatting after iter_arrow
alex-hh Oct 8, 2024
4a761a9
add support for formatting to map iteration
alex-hh Oct 8, 2024
3b65d99
formatted iterator for filter
alex-hh Oct 8, 2024
d906b9f
fix filtered formatting
alex-hh Oct 8, 2024
421917d
option to disable formatting for outputs of map
alex-hh Oct 8, 2024
e7b67c3
remove format_outputs kwarg
alex-hh Oct 9, 2024
a4f9700
rename batched_examples_iterator -> inputs_iterator
alex-hh Oct 9, 2024
a465abd
support arbitrary input formatting in filtered examples iterable iter…
alex-hh Oct 9, 2024
1863f8c
preserve formatting on filtered shuffle
alex-hh Oct 9, 2024
205e0d6
pass token_per_repo_id to python_feature_decoder in formatters
alex-hh Oct 9, 2024
42dc44f
implement FormattedExamplesIterator
alex-hh Oct 9, 2024
4a8fed5
fix formatted examples iterable
alex-hh Oct 9, 2024
8cdf6a6
Merge branch 'main' into iterable-map-with-format
alex-hh Oct 9, 2024
2ddaa7d
restore is_typed property
alex-hh Oct 9, 2024
dcd5017
pass formatting config to formatted examples iterable
alex-hh Oct 9, 2024
1ae947e
fix formatter init
alex-hh Oct 9, 2024
20330e8
Merge branch 'main' into iterable-map-with-format
alex-hh Oct 9, 2024
8f6845f
map examples iterable expects to receive rebatchedarrowexamplesiterab…
alex-hh Oct 9, 2024
3a91aac
only apply features if they exist
alex-hh Oct 9, 2024
84fcf74
fix shuffle and shard
alex-hh Oct 9, 2024
4fac60a
remove formatting from FilteredExamplesIterable
alex-hh Oct 10, 2024
afa78aa
run pre commit
alex-hh Oct 10, 2024
5a8389b
filtered iter_arrow always allowed if available
alex-hh Oct 10, 2024
c97f02e
filtered examples iterable needs formatting when iter_arrow enabled
alex-hh Oct 10, 2024
76e09a1
only iter arrow on filter if formatting is set
alex-hh Oct 10, 2024
ee45f7f
add features property to support feature inference
alex-hh Oct 10, 2024
b828575
fix features property
alex-hh Oct 10, 2024
f76701b
dont re-encode featuers
alex-hh Oct 10, 2024
15a8cfe
avoid re-encoding outputs of map
alex-hh Oct 10, 2024
884bba1
map should not preserve formatting
alex-hh Oct 10, 2024
d979672
update comment
alex-hh Oct 10, 2024
190d062
update map features property
alex-hh Oct 10, 2024
85b7d4d
return bool for mapped ex iterable is typed
alex-hh Oct 11, 2024
3129274
pass return features to mapped exampels iterable constructor
alex-hh Oct 11, 2024
45f55b4
don't iter arrow with formatted filter to avoid re formatting
alex-hh Oct 11, 2024
5e31fe0
avoid re-formatting data
alex-hh Oct 12, 2024
49a84fe
rename return features -> features
alex-hh Oct 14, 2024
002f5b4
update refs to return_features
alex-hh Oct 14, 2024
2479264
decode features in batched map
alex-hh Oct 14, 2024
68bfa39
preserve formatting in with_format
alex-hh Oct 15, 2024
38f78d2
fix features (mapped ex iterable
alex-hh Oct 16, 2024
f59a8e6
Merge branch 'main' into iterable-map-with-format
alex-hh Oct 31, 2024
ca2deb4
update shard
alex-hh Oct 31, 2024
4efcf11
remove formatted examples iterable from with_format
alex-hh Nov 2, 2024
f997f8c
avoid reapplying features when chaining filter, map
alex-hh Nov 2, 2024
bd8bbd3
preserve formatting in map
alex-hh Nov 11, 2024
7d1c48d
Merge branch 'main' into iterable-map-with-format
lhoestq Jan 13, 2025
441a95b
fix tests
lhoestq Jan 13, 2025
9a0e112
style
lhoestq Jan 13, 2025
66d59c7
fix tests
lhoestq Jan 14, 2025
138 changes: 89 additions & 49 deletions src/datasets/iterable_dataset.py
@@ -885,6 +885,41 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "RandomlyCycli
)


def formatted_arrow_examples_iterator(ex_iterable, formatter, batched: bool = False):
for key, pa_table in ex_iterable.iter_arrow():
if batched:
yield key, formatter.format_batch(pa_table)
else:
yield key, formatter.format_row(pa_table)


def formatted_python_examples_iterator(ex_iterable, batch_size, formatter, batched: bool = False):
iterator = iter(ex_iterable)
if formatter:
format_dict = (
formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
)
else:
format_dict = None
if batched:
for key, example in iterator:
# If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
iterator_batch = (
iterator if batch_size is None or batch_size <= 0 else islice(iterator, batch_size - 1)
) # take the next batch_size - 1 examples from iterator
key_examples_list = [(key, example)] + list(iterator_batch)
keys, examples = zip(*key_examples_list)
batch = _examples_to_batch(examples)
batch = format_dict(batch) if format_dict else batch
# the new key is the concatenation of the examples keys from the batch
new_key = "_".join(str(key) for key in keys)
yield new_key, batch
else:
for key, example in iterator:
example = format_dict(example) if format_dict else example
yield key, example
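The batched branch of `formatted_python_examples_iterator` above groups a keyed example stream with `islice`, consuming one `(key, example)` pair and then up to `batch_size - 1` more, and joins the keys with `"_"`. A minimal standalone sketch of that grouping pattern (`batched_key_examples` is an illustrative name, not part of the `datasets` API):

```python
from itertools import islice


def batched_key_examples(pairs, batch_size):
    # Group (key, example) pairs into batches; the new key is the "_"-joined
    # concatenation of the batch's keys, mirroring the helper above.
    # batch_size None or <= 0 means a single batch holding everything left.
    iterator = iter(pairs)
    for key, example in iterator:
        rest = iterator if batch_size is None or batch_size <= 0 else islice(iterator, batch_size - 1)
        key_examples_list = [(key, example)] + list(rest)
        keys, examples = zip(*key_examples_list)
        yield "_".join(str(k) for k in keys), list(examples)


pairs = [(i, {"x": i}) for i in range(5)]
print(list(batched_key_examples(pairs, 2)))
# → [('0_1', [{'x': 0}, {'x': 1}]), ('2_3', [{'x': 2}, {'x': 3}]), ('4', [{'x': 4}])]
```

Passing `batch_size=None` (or `<= 0`) puts everything remaining into one batch, matching the comment in the code above.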


class MappedExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
@@ -953,44 +988,43 @@ def _iter(self):
num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
else:
num_examples_to_skip = 0
iterator = iter(self.ex_iterable)

if self.formatting:
formatter = get_formatter(self.formatting.format_type)
format_dict = (
formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
formatter = get_formatter(self.formatting.format_type) if self.formatting else None
if self.formatting and self.ex_iterable.iter_arrow:
# we still want to use an arrow iterator, yielding single batches of size self.batch_size
# to which the formatter can be applied
ex_iterable = RebatchedArrowExamplesIterable(
self.ex_iterable, batch_size=self.batch_size if self.batched else 1, drop_last_batch=False
)
batched_examples_iterator = formatted_arrow_examples_iterator(ex_iterable, formatter, batched=self.batched)

else:
format_dict = None
batched_examples_iterator = formatted_python_examples_iterator(
self.ex_iterable, batch_size=self.batch_size, formatter=formatter, batched=self.batched
)
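As a rough illustration of what the rebatching step above is for — `RebatchedArrowExamplesIterable` re-chunks the upstream arrow stream so the formatter sees batches of exactly `self.batch_size` rows — here is a list-based sketch (plain lists stand in for `pa.Table` chunks; `rebatch` is an illustrative name, not the real implementation):

```python
def rebatch(chunks, batch_size):
    # Re-chunk a stream of variable-length lists into fixed-size batches,
    # keeping a trailing partial batch (i.e. drop_last_batch=False).
    buffer = []
    for chunk in chunks:
        buffer.extend(chunk)
        while len(buffer) >= batch_size:
            yield buffer[:batch_size]
            buffer = buffer[batch_size:]
    if buffer:
        yield buffer


print(list(rebatch([[1, 2, 3], [4], [5, 6, 7, 8]], 3)))
# → [[1, 2, 3], [4, 5, 6], [7, 8]]
```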

if self.batched:
if self._state_dict:
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
for key, example in iterator:
# If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
iterator_batch = (
iterator
if self.batch_size is None or self.batch_size <= 0
else islice(iterator, self.batch_size - 1)
)
key_examples_list = [(key, example)] + list(iterator_batch)
keys, examples = zip(*key_examples_list)
for new_key, batch in batched_examples_iterator:
if batch:
batch_len = len(batch[next(iter(batch))])
else:
batch_len = 0
if (
self.drop_last_batch
and self.batch_size is not None
and self.batch_size > 0
and len(examples) < self.batch_size
and batch_len < self.batch_size
): # ignore last batch
return
batch = _examples_to_batch(examples)
batch = format_dict(batch) if format_dict else batch
# then apply the transform
inputs = batch
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append([current_idx + i for i in range(len(key_examples_list))])
function_args.append([current_idx + i for i in range(batch_len)])
transformed_batch = dict(batch) # this will be updated with the function output
transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
# then remove the unwanted columns
@@ -1006,10 +1040,10 @@ def _iter(self):
]
if bad_cols:
raise ValueError(
f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]} while {first_col} has length {len(transformed_batch[first_col])}."
f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]}"
f" while {first_col} has length {len(transformed_batch[first_col])}."
)
# the new key is the concatenation of the examples keys from the batch
new_key = "_".join(str(key) for key in keys)

# yield one example at a time from the transformed batch
for example in _batch_to_examples(transformed_batch):
current_idx += 1
Expand All @@ -1024,11 +1058,10 @@ def _iter(self):
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
else:
for key, example in iterator:
for key, example in batched_examples_iterator:
# If not batched, we can apply the transform and yield the example directly
# first copy the example, since we might drop some keys
example = dict(example)
example = format_dict(example) if format_dict else example
# then apply the transform
inputs = example
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
@@ -1082,7 +1115,8 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key
output_table = self.function(*function_args, **self.fn_kwargs)
if not isinstance(output_table, pa.Table):
raise TypeError(
f"Provided `function` which is applied to pyarrow tables returns a variable of type {type(output_table)}. Make sure provided `function` returns a a pyarrow table to update the dataset."
f"Provided `function` which is applied to pyarrow tables returns a variable of type "
f"{type(output_table)}. Make sure provided `function` returns a pyarrow table to update the dataset."
)
# we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts
# then remove the unwanted columns
@@ -1209,57 +1243,62 @@ def _iter(self):
num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
else:
num_examples_to_skip = 0
iterator = iter(self.ex_iterable)

if self.formatting:
formatter = get_formatter(self.formatting.format_type)
format_dict = (
formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
formatter = get_formatter(self.formatting.format_type) if self.formatting else None
if self.formatting and self.ex_iterable.iter_arrow:
# we still want to use an arrow iterator, yielding single batches of size self.batch_size
# to which the formatter can be applied
ex_iterable = RebatchedArrowExamplesIterable(
self.ex_iterable, batch_size=self.batch_size if self.batched else 1, drop_last_batch=False
)
batched_examples_iterator = formatted_arrow_examples_iterator(ex_iterable, formatter, batched=self.batched)
Member: maybe name it input_iterator or something like that since it's not necessarily batched ?

Contributor Author: updated


else:
format_dict = None
batched_examples_iterator = formatted_python_examples_iterator(
self.ex_iterable, batch_size=self.batch_size, formatter=formatter, batched=self.batched
)

if self.batched:
if self._state_dict:
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
for key, example in iterator:
# If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
iterator_batch = (
iterator
if self.batch_size is None or self.batch_size <= 0
else islice(iterator, self.batch_size - 1)
)
key_examples_list = [(key, example)] + list(iterator_batch)
keys, examples = zip(*key_examples_list)
batch = _examples_to_batch(examples)
batch = format_dict(batch) if format_dict else batch
for combined_key, batch in batched_examples_iterator:
if batch:
batch_len = len(batch[next(iter(batch))])
else:
batch_len = 0
# then compute the mask for the batch
inputs = batch
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append([current_idx + i for i in range(len(key_examples_list))])
function_args.append([current_idx + i for i in range(batch_len)])
mask = self.function(*function_args, **self.fn_kwargs)
# yield one example at a time from the batch
for key_example, to_keep in zip(key_examples_list, mask):
examples = _batch_to_examples(batch)
# TODO: nicer way to handle keys?
if not self.formatting:
keys = combined_key.split("_")
else:
keys = [combined_key] * len(mask)
alex-hh marked this conversation as resolved.
for key, example, to_keep in zip(keys, examples, mask):
current_idx += 1
if self._state_dict:
self._state_dict["num_examples_since_previous_state"] += 1
if num_examples_to_skip > 0:
num_examples_to_skip -= 1
continue
if to_keep:
yield key_example
yield key, example
if self._state_dict:
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
else:
for key, example in iterator:
for key, example in batched_examples_iterator:
# If not batched, we can apply the filtering function directly
example = dict(example)
inputs = format_dict(example) if format_dict else example
inputs = example
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append(current_idx)
@@ -2211,6 +2250,7 @@ def map(
remove_columns: Optional[Union[str, List[str]]] = None,
features: Optional[Features] = None,
fn_kwargs: Optional[dict] = None,
format_outputs: bool = True,
Member: is this useful ?

Contributor Author (alex-hh, Oct 9, 2024): was useful for optimising performance in my case -

idea is it allows the user to determine the formatting of the examples returned by map (within the mapped function) rather than via self.formatting, which determines the formatting of the inputs to the map function

if map returns lists / arbitrary types then re-applying the formatter to the outputs can be expensive

However, haven't thought about how to handle preserving a consistent self._formatting FormattingConfig for downstream dataset transformations...might require something better than this kwarg

Member (lhoestq, Oct 9, 2024): You can still do ds = ds.map(...).with_format(None) and no formatting will be applied to the output

Contributor Author: perfect - will remove the kwarg

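The resolution above — map() keeps the parent's formatting for its inputs, and a user opts out of output formatting with `.with_format(None)` rather than via a `format_outputs` kwarg — can be modeled with a toy class (`ToyIterableDataset` is purely illustrative, not the real `datasets` API):

```python
class ToyIterableDataset:
    # Toy model of the design settled on in this thread: map() propagates
    # the parent's formatting config, and with_format(None) clears it.
    def __init__(self, rows, formatting=None):
        self.rows = rows
        self.formatting = formatting

    def with_format(self, format_type):
        # Return a new dataset with the requested output formatting.
        return ToyIterableDataset(self.rows, formatting=format_type)

    def map(self, fn):
        # Outputs inherit the parent's formatting; the user can clear it
        # afterwards with .with_format(None) instead of a dedicated kwarg.
        return ToyIterableDataset([fn(r) for r in self.rows], formatting=self.formatting)


ds = ToyIterableDataset([{"x": 1}]).with_format("numpy")
assert ds.map(lambda r: r).formatting == "numpy"                  # outputs stay formatted
assert ds.map(lambda r: r).with_format(None).formatting is None   # explicit opt-out
```

This keeps a single `self._formatting` config flowing through chained transformations, which is the consistency concern raised in the thread.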
) -> "IterableDataset":
"""
Apply a function to all the examples in the iterable dataset (individually or in batches) and update them.
@@ -2307,15 +2347,15 @@ def map(
drop_last_batch=drop_last_batch,
remove_columns=remove_columns,
fn_kwargs=fn_kwargs,
formatting=self._formatting,
formatting=copy.deepcopy(self._formatting),
)
info = self.info.copy()
info.features = features
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
split=self._split,
formatting=self._formatting,
formatting=self._formatting if format_outputs else None,
shuffling=copy.deepcopy(self._shuffling),
distributed=copy.deepcopy(self._distributed),
token_per_repo_id=self._token_per_repo_id,