From 16a121d7821a7691815a966270f577e2c503473f Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 9 Oct 2024 17:04:07 +0100 Subject: [PATCH] Preserve features in iterable dataset.filter (#7209) * add is_typed property to example iterables to prevent applying decode_examples multiple times * Update src/datasets/iterable_dataset.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --------- Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/iterable_dataset.py | 68 ++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 8 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index ca8b5fe1de0..d1b54131b61 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -130,6 +130,10 @@ def __iter__(self) -> Iterator[Tuple[Key, dict]]: def iter_arrow(self) -> Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]]: return None + @property + def is_typed(self) -> bool: + return False + def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamplesIterable": """ Either shuffle the shards/sources of the dataset, or propagate the shuffling to the underlying iterable. @@ -393,6 +397,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int] def iter_arrow(self): return self._iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = { "ex_iterable": self.ex_iterable._init_state_dict(), @@ -518,6 +526,10 @@ def iter_arrow(self): if self.ex_iterable.iter_arrow: return self._iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() return self._state_dict @@ -550,6 +562,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int): self.offset = offset # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() return self._state_dict @@ -593,6 +609,10 @@ def __init__( self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterables[0].is_typed + def _get_indices_iterator(self): # this is an infinite iterator to keep track of which iterator we want to pick examples from ex_iterable_idx = self._state_dict["ex_iterable_idx"] if self._state_dict else 0 @@ -687,6 +707,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]): super().__init__() self.ex_iterables = ex_iterables + @property + def is_typed(self): + return self.ex_iterables[0].is_typed + @property def iter_arrow(self): if all(ex_iterable.iter_arrow is not None for ex_iterable in self.ex_iterables): @@ -767,6 +791,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]): self.ex_iterables = ex_iterables # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterables[0].is_typed + def _init_state_dict(self) -> dict: self._state_dict = {"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables]} return self._state_dict @@ -826,6 +854,10 @@ def __init__( self.probabilities = probabilities # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterables[0].is_typed + def _get_indices_iterator(self): rng = deepcopy(self.generator) num_sources = len(self.ex_iterables) @@ -929,6 +961,10 @@ def iter_arrow(self): if self.formatting and self.formatting.format_type == "arrow": return self._iter_arrow + @property + def is_typed(self): + return False + def _init_state_dict(self) -> dict: self._state_dict = { "ex_iterable": self.ex_iterable._init_state_dict(), @@ -1185,6 +1221,10 @@ def iter_arrow(self): if self.formatting and self.formatting.format_type == "arrow": return self._iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = { "ex_iterable": self.ex_iterable._init_state_dict(), @@ -1365,6 +1405,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat self.generator = generator # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() self._original_state_dict = self.state_dict() @@ -1435,6 +1479,10 @@ def __init__( self.split_when_sharding = split_when_sharding # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = {"skipped": False, "ex_iterable": self.ex_iterable._init_state_dict()} return self._state_dict @@ -1498,6 +1546,10 @@ def __init__( self.split_when_sharding = split_when_sharding # TODO(QL): implement iter_arrow + @property + def is_typed(self): + return self.ex_iterable.is_typed + def _init_state_dict(self) -> dict: self._state_dict = {"num_taken": 0, "ex_iterable": self.ex_iterable._init_state_dict()} return self._state_dict @@ -1600,6 +1652,10 @@ def iter_arrow(self): if self.ex_iterable.iter_arrow is not None: return self._iter_arrow + @property + def is_typed(self): + return True + def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() return self._state_dict @@ -1914,7 +1970,7 @@ def _iter_pytorch(self): return else: for key, example in ex_iterable: - if self.features: + if self.features and not ex_iterable.is_typed: # `IterableDataset` automatically fills missing columns with None. # This is done with `_apply_feature_types_on_example`. example = _apply_feature_types_on_example( @@ -2010,7 +2066,7 @@ def __iter__(self): return for key, example in ex_iterable: - if self.features: + if self.features and not ex_iterable.is_typed: # `IterableDataset` automatically fills missing columns with None. # This is done with `_apply_feature_types_on_example`. example = _apply_feature_types_on_example( @@ -2052,7 +2108,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): if drop_last_batch and len(examples) < batch_size: # ignore last batch return batch = _examples_to_batch(examples) - if self.features: + if self.features and not ex_iterable.is_typed: # `IterableDataset` automatically fills missing columns with None. # This is done with `_apply_feature_types_on_batch`. batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id) @@ -2405,10 +2461,6 @@ def filter( if isinstance(input_columns, str): input_columns = [input_columns] - # TODO(QL): keep the features (right now if we keep it it would call decode_example again on an already decoded example) - info = copy.deepcopy(self._info) - info.features = None - # We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here ex_iterable = FilteredExamplesIterable( TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id) @@ -2424,7 +2476,7 @@ def filter( ) return IterableDataset( ex_iterable=ex_iterable, - info=info, + info=self._info, split=self._split, formatting=self._formatting, shuffling=copy.deepcopy(self._shuffling),