apply formatting after iter_arrow to speed up format -> map, filter for iterable datasets #7207
Changes from 5 commits
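For context, the user-facing pattern this PR speeds up is a formatted streaming dataset followed by `map` or `filter`. A minimal sketch of that pattern (the dataset name, columns, and batch size are illustrative, not taken from the PR):

```python
from datasets import load_dataset

# format -> map -> filter on an iterable (streaming) dataset: with this PR,
# formatting is applied to arrow batches directly instead of first decoding
# every example to python objects
ds = load_dataset("rajpurkar/squad", split="train", streaming=True)
ds = ds.with_format("numpy")
ds = ds.map(lambda batch: {"q_len": [len(q) for q in batch["question"]]}, batched=True, batch_size=16)
ds = ds.filter(lambda ex: int(ex["q_len"]) > 50)
print(next(iter(ds))["q_len"])
```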
@@ -885,6 +885,41 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "RandomlyCycli
         )


+def formatted_arrow_examples_iterator(ex_iterable, formatter, batched: bool = False):
+    for key, pa_table in ex_iterable.iter_arrow():
+        if batched:
+            yield key, formatter.format_batch(pa_table)
+        else:
+            yield key, formatter.format_row(pa_table)
+
+
+def formatted_python_examples_iterator(ex_iterable, batch_size, formatter, batched: bool = False):
+    iterator = iter(ex_iterable)
+    if formatter:
+        format_dict = (
+            formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
+        )
+    else:
+        format_dict = None
+    if batched:
+        for key, example in iterator:
+            # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
+            iterator_batch = (
+                iterator if batch_size is None or batch_size <= 0 else islice(iterator, batch_size - 1)
+            )  # take the next batch_size - 1 examples from iterator
+            key_examples_list = [(key, example)] + list(iterator_batch)
+            keys, examples = zip(*key_examples_list)
+            batch = _examples_to_batch(examples)
+            batch = format_dict(batch) if format_dict else batch
+            # the new key is the concatenation of the examples keys from the batch
+            new_key = "_".join(str(key) for key in keys)
+            yield new_key, batch
+    else:
+        for key, example in iterator:
+            example = format_dict(example) if format_dict else example
+            yield key, example
+
+
 class MappedExamplesIterable(_BaseExamplesIterable):
     def __init__(
         self,
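The first helper formats each arrow table straight into the requested format. A self-contained sketch of that idea using plain pyarrow and numpy (`to_numpy_batch` is a stand-in for the library's `formatter.format_batch`, not a real API):

```python
import numpy as np
import pyarrow as pa

def to_numpy_batch(pa_table: pa.Table) -> dict:
    # stand-in for formatter.format_batch: arrow columns become numpy arrays
    # without a detour through python lists and dicts
    return {name: pa_table[name].to_numpy() for name in pa_table.column_names}

def iter_arrow():
    # toy upstream iterator yielding (key, pa.Table) pairs, like ex_iterable.iter_arrow()
    for key in range(2):
        yield str(key), pa.table({"x": np.arange(4) + 4 * key})

# mirrors formatted_arrow_examples_iterator with batched=True
for key, batch in ((k, to_numpy_batch(t)) for k, t in iter_arrow()):
    print(key, batch["x"])
```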
@@ -953,44 +988,43 @@ def _iter(self):
             num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
         else:
             num_examples_to_skip = 0
-        iterator = iter(self.ex_iterable)

-        if self.formatting:
-            formatter = get_formatter(self.formatting.format_type)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
+        formatter = get_formatter(self.formatting.format_type) if self.formatting else None
+        if self.formatting and self.ex_iterable.iter_arrow:
+            # we still want to use an arrow iterator, yielding single batches of size self.batch_size
+            # to which the formatter can be applied
+            ex_iterable = RebatchedArrowExamplesIterable(
+                self.ex_iterable, batch_size=self.batch_size if self.batched else 1, drop_last_batch=False
             )
+            batched_examples_iterator = formatted_arrow_examples_iterator(ex_iterable, formatter, batched=self.batched)
+
         else:
-            format_dict = None
+            batched_examples_iterator = formatted_python_examples_iterator(
+                self.ex_iterable, batch_size=self.batch_size, formatter=formatter, batched=self.batched
+            )

         if self.batched:
             if self._state_dict:
                 self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                 self._state_dict["num_examples_since_previous_state"] = 0
                 self._state_dict["previous_state_example_idx"] = current_idx
-            for key, example in iterator:
-                # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
-                iterator_batch = (
-                    iterator
-                    if self.batch_size is None or self.batch_size <= 0
-                    else islice(iterator, self.batch_size - 1)
-                )
-                key_examples_list = [(key, example)] + list(iterator_batch)
-                keys, examples = zip(*key_examples_list)
+            for new_key, batch in batched_examples_iterator:
+                if batch:
+                    batch_len = len(batch[next(iter(batch))])
+                else:
+                    batch_len = 0
                 if (
                     self.drop_last_batch
                     and self.batch_size is not None
                     and self.batch_size > 0
-                    and len(examples) < self.batch_size
+                    and batch_len < self.batch_size
                 ):  # ignore last batch
                     return
-                batch = _examples_to_batch(examples)
-                batch = format_dict(batch) if format_dict else batch
                 # then apply the transform
                 inputs = batch
                 function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
                 if self.with_indices:
-                    function_args.append([current_idx + i for i in range(len(key_examples_list))])
+                    function_args.append([current_idx + i for i in range(batch_len)])
                 transformed_batch = dict(batch)  # this will be updated with the function output
                 transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
                 # then remove the unwanted columns
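In the arrow path, `RebatchedArrowExamplesIterable` re-chunks whatever table sizes the upstream yields into tables of exactly `batch_size` rows, so the formatter sees one table per map batch. A hedged sketch of the slicing idea (the real class also handles checkpoint/resume state; `rebatch` below is illustrative, not the library's implementation):

```python
import pyarrow as pa

def rebatch(tables, batch_size):
    # accumulate incoming (key, table) pairs and re-emit tables of exactly batch_size rows
    buffer, key = [], None
    for key, table in tables:
        buffer.append(table)
        merged = pa.concat_tables(buffer)
        while merged.num_rows >= batch_size:
            yield key, merged.slice(0, batch_size)
            merged = merged.slice(batch_size)
        buffer = [merged]
    if key is not None and buffer[0].num_rows:  # trailing partial batch (drop_last_batch=False)
        yield key, buffer[0]

chunks = [("0", pa.table({"x": list(range(5))})), ("1", pa.table({"x": list(range(5, 8))}))]
for key, t in rebatch(chunks, batch_size=4):
    print(key, t["x"].to_pylist())  # ('0', [0, 1, 2, 3]) then ('1', [4, 5, 6, 7])
```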
@@ -1006,10 +1040,10 @@ def _iter(self):
                 ]
                 if bad_cols:
                     raise ValueError(
-                        f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]} while {first_col} has length {len(transformed_batch[first_col])}."
+                        f"Column lengths mismatch: columns {bad_cols} have length {[len(transformed_batch[col]) for col in bad_cols]}"
+                        f" while {first_col} has length {len(transformed_batch[first_col])}."
                     )
-                # the new key is the concatenation of the examples keys from the batch
-                new_key = "_".join(str(key) for key in keys)
+
                 # yield one example at a time from the transformed batch
                 for example in _batch_to_examples(transformed_batch):
                     current_idx += 1
@@ -1024,11 +1058,10 @@ def _iter(self):
                     self._state_dict["num_examples_since_previous_state"] = 0
                     self._state_dict["previous_state_example_idx"] = current_idx
         else:
-            for key, example in iterator:
+            for key, example in batched_examples_iterator:
                 # If not batched, we can apply the transform and yield the example directly
                 # first copy the example, since we might drop some keys
                 example = dict(example)
-                example = format_dict(example) if format_dict else example
                 # then apply the transform
                 inputs = example
                 function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
@@ -1082,7 +1115,8 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key
                 output_table = self.function(*function_args, **self.fn_kwargs)
                 if not isinstance(output_table, pa.Table):
                     raise TypeError(
-                        f"Provided `function` which is applied to pyarrow tables returns a variable of type {type(output_table)}. Make sure provided `function` returns a a pyarrow table to update the dataset."
+                        f"Provided `function` which is applied to pyarrow tables returns a variable of type "
+                        f"{type(output_table)}. Make sure provided `function` returns a a pyarrow table to update the dataset."
                     )
                 # we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts
                 # then remove the unwanted columns
@@ -1209,57 +1243,62 @@ def _iter(self):
             num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
         else:
             num_examples_to_skip = 0
-        iterator = iter(self.ex_iterable)

-        if self.formatting:
-            formatter = get_formatter(self.formatting.format_type)
-            format_dict = (
-                formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
+        formatter = get_formatter(self.formatting.format_type) if self.formatting else None
+        if self.formatting and self.ex_iterable.iter_arrow:
+            # we still want to use an arrow iterator, yielding single batches of size self.batch_size
+            # to which the formatter can be applied
+            ex_iterable = RebatchedArrowExamplesIterable(
+                self.ex_iterable, batch_size=self.batch_size if self.batched else 1, drop_last_batch=False
             )
+            batched_examples_iterator = formatted_arrow_examples_iterator(ex_iterable, formatter, batched=self.batched)
+
         else:
-            format_dict = None
+            batched_examples_iterator = formatted_python_examples_iterator(
+                self.ex_iterable, batch_size=self.batch_size, formatter=formatter, batched=self.batched
+            )

         if self.batched:
             if self._state_dict:
                 self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                 self._state_dict["num_examples_since_previous_state"] = 0
                 self._state_dict["previous_state_example_idx"] = current_idx
-            for key, example in iterator:
-                # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
-                iterator_batch = (
-                    iterator
-                    if self.batch_size is None or self.batch_size <= 0
-                    else islice(iterator, self.batch_size - 1)
-                )
-                key_examples_list = [(key, example)] + list(iterator_batch)
-                keys, examples = zip(*key_examples_list)
-                batch = _examples_to_batch(examples)
-                batch = format_dict(batch) if format_dict else batch
+            for combined_key, batch in batched_examples_iterator:
+                if batch:
+                    batch_len = len(batch[next(iter(batch))])
+                else:
+                    batch_len = 0
                 # then compute the mask for the batch
                 inputs = batch
                 function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
                 if self.with_indices:
-                    function_args.append([current_idx + i for i in range(len(key_examples_list))])
+                    function_args.append([current_idx + i for i in range(batch_len)])
                 mask = self.function(*function_args, **self.fn_kwargs)
                 # yield one example at a time from the batch
-                for key_example, to_keep in zip(key_examples_list, mask):
+                examples = _batch_to_examples(batch)
+                # TODO: nicer way to handle keys?
+                if not self.formatting:
+                    keys = combined_key.split("_")
+                else:
+                    keys = [combined_key] * len(mask)
+
+                for key, example, to_keep in zip(keys, examples, mask):
                     current_idx += 1
                     if self._state_dict:
                         self._state_dict["num_examples_since_previous_state"] += 1
                     if num_examples_to_skip > 0:
                         num_examples_to_skip -= 1
                         continue
                     if to_keep:
-                        yield key_example
+                        yield key, example
             if self._state_dict:
                 self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                 self._state_dict["num_examples_since_previous_state"] = 0
                 self._state_dict["previous_state_example_idx"] = current_idx
         else:
-            for key, example in iterator:
+            for key, example in batched_examples_iterator:
                 # If not batched, we can apply the filtering function direcly
                 example = dict(example)
-                inputs = format_dict(example) if format_dict else example
+                inputs = example
                 function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
                 if self.with_indices:
                     function_args.append(current_idx)
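Because the batched iterator now yields one combined key per batch, filter has to recover per-example keys before yielding survivors. A small illustration of the `combined_key.split("_")` logic from the hunk above:

```python
# one batch of 4 examples arrives with the combined key "0_1_2_3"
combined_key = "0_1_2_3"
examples = [{"x": i} for i in range(4)]
mask = [True, False, True, False]  # what self.function would return

keys = combined_key.split("_")  # back to per-example keys: ['0', '1', '2', '3']
kept = [(key, ex) for key, ex, to_keep in zip(keys, examples, mask) if to_keep]
print(kept)  # [('0', {'x': 0}), ('2', {'x': 2})]
```

The `TODO` in the diff flags the weak spot: in the formatted path, `keys = [combined_key] * len(mask)` reuses the same combined key for every surviving example.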
@@ -2211,6 +2250,7 @@ def map(
         remove_columns: Optional[Union[str, List[str]]] = None,
         features: Optional[Features] = None,
         fn_kwargs: Optional[dict] = None,
+        format_outputs: bool = True,
     ) -> "IterableDataset":
         """
         Apply a function to all the examples in the iterable dataset (individually or in batches) and update them.

Review discussion on the new format_outputs kwarg:

Reviewer: is this useful ?

Author: was useful for optimising performance in my case - the idea is that it allows the user to determine the formatting of the examples returned by map (within the mapped function) rather than via self.formatting, which determines the formatting of the inputs to the map function. If map returns lists / arbitrary types then re-applying the formatter to the outputs can be expensive. However, I haven't thought about how to handle preserving a consistent self._formatting FormattingConfig for downstream dataset transformations... might require something better than this kwarg

Reviewer: You can still do

Author: perfect - will remove the kwarg
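The reviewer's (truncated) reply presumably points at the existing formatting API as the alternative that led to dropping the kwarg. A hedged sketch of controlling output formatting without `format_outputs` (dataset name illustrative):

```python
from datasets import load_dataset

ds = load_dataset("rajpurkar/squad", split="train", streaming=True)

# inputs to the map function are numpy-formatted...
mapped = ds.with_format("numpy").map(
    lambda batch: {"q_len": [len(q) for q in batch["question"]]}, batched=True
)
# ...but downstream consumers can opt back out of output formatting themselves,
# instead of map taking a format_outputs=False kwarg
unformatted = mapped.with_format(None)
```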
@@ -2307,15 +2347,15 @@ def map(
             drop_last_batch=drop_last_batch,
             remove_columns=remove_columns,
             fn_kwargs=fn_kwargs,
-            formatting=self._formatting,
+            formatting=copy.deepcopy(self._formatting),
         )
         info = self.info.copy()
         info.features = features
         return IterableDataset(
             ex_iterable=ex_iterable,
             info=info,
             split=self._split,
-            formatting=self._formatting,
+            formatting=self._formatting if format_outputs else None,
             shuffling=copy.deepcopy(self._shuffling),
             distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
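A note on the `copy.deepcopy(self._formatting)` change: passing a shared mutable config to the new `MappedExamplesIterable` would let later mutations of the dataset's formatting leak into already-created iterables. A toy illustration of that aliasing hazard (`Config` is a stand-in, not the library's `FormattingConfig`):

```python
import copy

class Config:  # stand-in for a formatting config object
    def __init__(self, format_type):
        self.format_type = format_type

shared = Config("numpy")
aliased = shared                  # the iterable would see later changes
snapshot = copy.deepcopy(shared)  # deepcopy pins the config at creation time
shared.format_type = "torch"
print(aliased.format_type, snapshot.format_type)  # torch numpy
```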
Review discussion on the batched_examples_iterator name:

Reviewer: maybe name it input_iterator or something like that since it's not necessarily batched ?

Author: updated