From d48bf53f4af0496f5b7be5087cf6ba22d6897234 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 14 Jan 2025 19:10:09 +0100 Subject: [PATCH] add pandas and polars formatting in iterabledataset --- src/datasets/formatting/__init__.py | 1 + src/datasets/formatting/formatting.py | 15 +++- src/datasets/formatting/polars_formatter.py | 8 +- src/datasets/iterable_dataset.py | 94 +++++++++++++++------ 4 files changed, 86 insertions(+), 32 deletions(-) diff --git a/src/datasets/formatting/__init__.py b/src/datasets/formatting/__init__.py index 8aa21d37bd2..9771618c7b9 100644 --- a/src/datasets/formatting/__init__.py +++ b/src/datasets/formatting/__init__.py @@ -22,6 +22,7 @@ Formatter, PandasFormatter, PythonFormatter, + TableFormatter, TensorFormatter, format_table, query_table, diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index ddd77353519..c0b31cc16c5 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -429,7 +429,15 @@ def recursive_tensorize(self, data_struct: dict): raise NotImplementedError -class ArrowFormatter(Formatter[pa.Table, pa.Array, pa.Table]): +class TableFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]): + table_type: str + column_type: str + + +class ArrowFormatter(TableFormatter[pa.Table, pa.Array, pa.Table]): + table_type = "arrow table" + column_type = "arrow array" + def format_row(self, pa_table: pa.Table) -> pa.Table: return self.simple_arrow_extractor().extract_row(pa_table) @@ -465,7 +473,10 @@ def format_batch(self, pa_table: pa.Table) -> Mapping: return batch -class PandasFormatter(Formatter[pd.DataFrame, pd.Series, pd.DataFrame]): +class PandasFormatter(TableFormatter[pd.DataFrame, pd.Series, pd.DataFrame]): + table_type = "pandas dataframe" + column_type = "pandas series" + def format_row(self, pa_table: pa.Table) -> pd.DataFrame: row = self.pandas_arrow_extractor().extract_row(pa_table) row = self.pandas_features_decoder.decode_row(row) diff --git a/src/datasets/formatting/polars_formatter.py b/src/datasets/formatting/polars_formatter.py index 543bde52dd0..7ea2f783aec 100644 --- a/src/datasets/formatting/polars_formatter.py +++ b/src/datasets/formatting/polars_formatter.py @@ -13,7 +13,6 @@ # limitations under the License. import sys -from collections.abc import Mapping from functools import partial from typing import TYPE_CHECKING, Optional @@ -23,7 +22,7 @@ from ..features import Features from ..features.features import decode_nested_example from ..utils.py_utils import no_op_if_value_is_null -from .formatting import BaseArrowExtractor, TensorFormatter +from .formatting import BaseArrowExtractor, TableFormatter if TYPE_CHECKING: @@ -98,7 +97,10 @@ def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame": return self.decode_row(batch) -class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]): +class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]): + table_type = "polars dataframe" + column_type = "polars series" + def __init__(self, features=None, **np_array_kwargs): super().__init__(features=features) self.np_array_kwargs = np_array_kwargs diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index f2d47ff64b9..c21ffc02d7c 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -10,13 +10,21 @@ import fsspec.asyn import numpy as np +import pandas as pd import pyarrow as pa from . import config from .arrow_dataset import Dataset, DatasetInfoMixin from .features import Features from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects -from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter +from .formatting import ( + ArrowFormatter, + PythonFormatter, + TableFormatter, + TensorFormatter, + get_format_type_from_alias, + get_formatter, +) from .info import DatasetInfo from .splits import NamedSplit, Split from .table import cast_table_to_features, read_schema_from_file, table_cast @@ -966,6 +974,19 @@ def shard_data_sources( ) +def _table_output_to_arrow(output) -> pa.Table: + if isinstance(output, pa.Table): + return output + if isinstance(output, (pd.DataFrame, pd.Series)): + return pa.Table.from_pandas(output) + if config.POLARS_AVAILABLE and "polars" in sys.modules: + import polars as pl + + if isinstance(output, (pl.DataFrame, pl.Series)): + return output.to_arrow() + return output + + class MappedExamplesIterable(_BaseExamplesIterable): def __init__( self, @@ -994,22 +1015,22 @@ def __init__( self.formatting = formatting # required for iter_arrow self._features = features # sanity checks - if formatting and formatting.format_type == "arrow": + if formatting and formatting.is_table: # batch_size should match for iter_arrow if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): raise ValueError( - "The Arrow-formatted MappedExamplesIterable has underlying iterable" + f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has underlying iterable" f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." ) elif ex_iterable.batch_size != (batch_size if batched else 1): raise ValueError( - f"The Arrow-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is" f"different from {ex_iterable.batch_size=} from its underlying iterable." ) @property def iter_arrow(self): - if self.formatting and self.formatting.format_type == "arrow": + if self.formatting and self.formatting.is_table: return self._iter_arrow @property @@ -1030,7 +1051,7 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): - if self.formatting and self.formatting.format_type == "arrow": + if self.formatting and self.formatting.is_table: formatter = PythonFormatter() for key, pa_table in self._iter_arrow(max_chunksize=1): yield key, formatter.format_row(pa_table) @@ -1156,6 +1177,7 @@ def _iter(self): yield key, transformed_example def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key, pa.Table]]: + formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter() if self.ex_iterable.iter_arrow: iterator = self.ex_iterable.iter_arrow() else: @@ -1182,18 +1204,23 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key ): return # first build the batch - function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + function_args = ( + [formatter.format_batch(pa_table)] + if self.input_columns is None + else [pa_table[col] for col in self.input_columns] + ) if self.with_indices: if self.batched: function_args.append([current_idx + i for i in range(len(pa_table))]) else: function_args.append(current_idx) # then apply the transform - output_table = self.function(*function_args, **self.fn_kwargs) + output = self.function(*function_args, **self.fn_kwargs) + output_table = _table_output_to_arrow(output) if not isinstance(output_table, pa.Table): raise TypeError( - f"Provided `function` which is applied to pyarrow tables returns a variable of type " - f"{type(output_table)}. Make sure provided `function` returns a a pyarrow table to update the dataset." + f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " + f"{type(output_table)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset." ) # we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts # then remove the unwanted columns @@ -1280,16 +1307,16 @@ def __init__( self.fn_kwargs = fn_kwargs or {} self.formatting = formatting # required for iter_arrow # sanity checks - if formatting and formatting.format_type == "arrow": + if formatting and formatting.is_table: # batch_size should match for iter_arrow if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): raise ValueError( - "The Arrow-formatted FilteredExamplesIterable has underlying iterable" + f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has underlying iterable" f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." ) elif ex_iterable.batch_size != (batch_size if batched else 1): raise ValueError( - f"The Arrow-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is" f"different from {ex_iterable.batch_size=} from its underlying iterable." ) @@ -1392,6 +1419,7 @@ def _iter(self): yield key, example def _iter_arrow(self, max_chunksize: Optional[int] = None): + formatter = get_formatter(self.formatting) if self.formatting else ArrowFormatter() if self.ex_iterable.iter_arrow: iterator = self.ex_iterable.iter_arrow() else: @@ -1415,14 +1443,24 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None): ): return - function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + function_args = ( + [formatter.format_batch(pa_table)] + if self.input_columns is None + else [pa_table[col] for col in self.input_columns] + ) if self.with_indices: if self.batched: function_args.append([current_idx + i for i in range(len(pa_table))]) else: function_args.append(current_idx) # then apply the transform - mask = self.function(*function_args, **self.fn_kwargs) + output = self.function(*function_args, **self.fn_kwargs) + mask = _table_output_to_arrow(output) + if not isinstance(mask, (pa.Array, pa.BooleanScalar)): + raise TypeError( + f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " + f"{type(output_table)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset." + ) # return output if self.batched: output_table = pa_table.filter(mask) @@ -1734,11 +1772,13 @@ def _apply_feature_types_on_batch( class FormattingConfig: format_type: Optional[str] - def __post_init__(self): - if self.format_type == "pandas": - raise NotImplementedError( - "The 'pandas' formatting is not implemented for iterable datasets. You can use 'numpy' or 'arrow' instead." - ) + @property + def is_table(self) -> bool: + return isinstance(get_formatter(self.format_type), TableFormatter) + + @property + def is_tensor(self) -> bool: + return isinstance(get_formatter(self.format_type), TensorFormatter) class FormattedExamplesIterable(_BaseExamplesIterable): @@ -1757,7 +1797,7 @@ def __init__( @property def iter_arrow(self): - if self.ex_iterable.iter_arrow and (not self.formatting or self.formatting.format_type == "arrow"): + if self.ex_iterable.iter_arrow and (not self.formatting or self.formatting.is_table): return self._iter_arrow @property @@ -1773,7 +1813,7 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): - if not self.formatting or self.formatting.format_type == "arrow": + if not self.formatting or self.formatting.is_table: formatter = PythonFormatter() else: formatter = get_formatter( @@ -2093,7 +2133,7 @@ def _iter_pytorch(self): else: format_dict = None - if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): if ex_iterable.iter_arrow: iterator = ex_iterable.iter_arrow() else: @@ -2133,7 +2173,7 @@ def _prepare_ex_iterable_for_iteration( self, batch_size: int = 1, drop_last_batch: bool = False ) -> _BaseExamplesIterable: ex_iterable = self._ex_iterable - if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): ex_iterable = RebatchedArrowExamplesIterable( ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch ) @@ -2189,7 +2229,7 @@ def __iter__(self): else: format_dict = None - if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): if ex_iterable.iter_arrow: iterator = ex_iterable.iter_arrow() else: @@ -2225,7 +2265,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): format_dict = None ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=batch_size, drop_last_batch=drop_last_batch) - if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): if ex_iterable.iter_arrow: iterator = ex_iterable.iter_arrow() else: @@ -2516,7 +2556,7 @@ def map( else self._info.features ) - if self._formatting and self._formatting.format_type == "arrow": + if self._formatting and self._formatting.is_table: # apply formatting before iter_arrow to keep map examples iterable happy ex_iterable = FormattedExamplesIterable( ex_iterable,