Skip to content

Commit

Permalink
add pandas and polars formatting in iterabledataset
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Jan 14, 2025
1 parent 75e61d1 commit d48bf53
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/datasets/formatting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Formatter,
PandasFormatter,
PythonFormatter,
TableFormatter,
TensorFormatter,
format_table,
query_table,
Expand Down
15 changes: 13 additions & 2 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,15 @@ def recursive_tensorize(self, data_struct: dict):
raise NotImplementedError


class ArrowFormatter(Formatter[pa.Table, pa.Array, pa.Table]):
class TableFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]):
table_type: str
column_type: str


class ArrowFormatter(TableFormatter[pa.Table, pa.Array, pa.Table]):
table_type = "arrow table"
column_type = "arrow array"

def format_row(self, pa_table: pa.Table) -> pa.Table:
return self.simple_arrow_extractor().extract_row(pa_table)

Expand Down Expand Up @@ -465,7 +473,10 @@ def format_batch(self, pa_table: pa.Table) -> Mapping:
return batch


class PandasFormatter(Formatter[pd.DataFrame, pd.Series, pd.DataFrame]):
class PandasFormatter(TableFormatter[pd.DataFrame, pd.Series, pd.DataFrame]):
table_type = "pandas dataframe"
column_type = "pandas series"

def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
row = self.pandas_arrow_extractor().extract_row(pa_table)
row = self.pandas_features_decoder.decode_row(row)
Expand Down
8 changes: 5 additions & 3 deletions src/datasets/formatting/polars_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import sys
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Optional

Expand All @@ -23,7 +22,7 @@
from ..features import Features
from ..features.features import decode_nested_example
from ..utils.py_utils import no_op_if_value_is_null
from .formatting import BaseArrowExtractor, TensorFormatter
from .formatting import BaseArrowExtractor, TableFormatter


if TYPE_CHECKING:
Expand Down Expand Up @@ -98,7 +97,10 @@ def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
return self.decode_row(batch)


class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]):
class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
table_type = "polars dataframe"
column_type = "polars series"

def __init__(self, features=None, **np_array_kwargs):
super().__init__(features=features)
self.np_array_kwargs = np_array_kwargs
Expand Down
94 changes: 67 additions & 27 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,21 @@

import fsspec.asyn
import numpy as np
import pandas as pd
import pyarrow as pa

from . import config
from .arrow_dataset import Dataset, DatasetInfoMixin
from .features import Features
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
from .formatting import (
ArrowFormatter,
PythonFormatter,
TableFormatter,
TensorFormatter,
get_format_type_from_alias,
get_formatter,
)
from .info import DatasetInfo
from .splits import NamedSplit, Split
from .table import cast_table_to_features, read_schema_from_file, table_cast
Expand Down Expand Up @@ -966,6 +974,19 @@ def shard_data_sources(
)


def _table_output_to_arrow(output) -> pa.Table:
if isinstance(output, pa.Table):
return output
if isinstance(output, (pd.DataFrame, pd.Series)):
return pa.Table.from_pandas(output)
if config.POLARS_AVAILABLE and "polars" in sys.modules:
import polars as pl

if isinstance(output, (pl.DataFrame, pl.Series)):
return output.to_arrow()
return output


class MappedExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
Expand Down Expand Up @@ -994,22 +1015,22 @@ def __init__(
self.formatting = formatting # required for iter_arrow
self._features = features
# sanity checks
if formatting and formatting.format_type == "arrow":
if formatting and formatting.is_table:
# batch_size should match for iter_arrow
if not isinstance(ex_iterable, RebatchedArrowExamplesIterable):
raise ValueError(
"The Arrow-formatted MappedExamplesIterable has underlying iterable"
f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has underlying iterable"
f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable."
)
elif ex_iterable.batch_size != (batch_size if batched else 1):
raise ValueError(
f"The Arrow-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is"
f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is"
f"different from {ex_iterable.batch_size=} from its underlying iterable."
)

@property
def iter_arrow(self):
if self.formatting and self.formatting.format_type == "arrow":
if self.formatting and self.formatting.is_table:
return self._iter_arrow

@property
Expand All @@ -1030,7 +1051,7 @@ def _init_state_dict(self) -> dict:
return self._state_dict

def __iter__(self):
if self.formatting and self.formatting.format_type == "arrow":
if self.formatting and self.formatting.is_table:
formatter = PythonFormatter()
for key, pa_table in self._iter_arrow(max_chunksize=1):
yield key, formatter.format_row(pa_table)
Expand Down Expand Up @@ -1156,6 +1177,7 @@ def _iter(self):
yield key, transformed_example

def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key, pa.Table]]:
formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter()
if self.ex_iterable.iter_arrow:
iterator = self.ex_iterable.iter_arrow()
else:
Expand All @@ -1182,18 +1204,23 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key
):
return
# first build the batch
function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns]
function_args = (
[formatter.format_batch(pa_table)]
if self.input_columns is None
else [pa_table[col] for col in self.input_columns]
)
if self.with_indices:
if self.batched:
function_args.append([current_idx + i for i in range(len(pa_table))])
else:
function_args.append(current_idx)
# then apply the transform
output_table = self.function(*function_args, **self.fn_kwargs)
output = self.function(*function_args, **self.fn_kwargs)
output_table = _table_output_to_arrow(output)
if not isinstance(output_table, pa.Table):
raise TypeError(
f"Provided `function` which is applied to pyarrow tables returns a variable of type "
f"{type(output_table)}. Make sure provided `function` returns a a pyarrow table to update the dataset."
f"Provided `function` which is applied to {formatter.table_type} returns a variable of type "
f"{type(output_table)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset."
)
# we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts
# then remove the unwanted columns
Expand Down Expand Up @@ -1280,16 +1307,16 @@ def __init__(
self.fn_kwargs = fn_kwargs or {}
self.formatting = formatting # required for iter_arrow
# sanity checks
if formatting and formatting.format_type == "arrow":
if formatting and formatting.is_table:
# batch_size should match for iter_arrow
if not isinstance(ex_iterable, RebatchedArrowExamplesIterable):
raise ValueError(
"The Arrow-formatted FilteredExamplesIterable has underlying iterable"
f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has underlying iterable"
f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable."
)
elif ex_iterable.batch_size != (batch_size if batched else 1):
raise ValueError(
f"The Arrow-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is"
f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is"
f"different from {ex_iterable.batch_size=} from its underlying iterable."
)

Expand Down Expand Up @@ -1392,6 +1419,7 @@ def _iter(self):
yield key, example

def _iter_arrow(self, max_chunksize: Optional[int] = None):
formatter = get_formatter(self.formatting) if self.formatting else ArrowFormatter()
if self.ex_iterable.iter_arrow:
iterator = self.ex_iterable.iter_arrow()
else:
Expand All @@ -1415,14 +1443,24 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None):
):
return

function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns]
function_args = (
[formatter.format_batch(pa_table)]
if self.input_columns is None
else [pa_table[col] for col in self.input_columns]
)
if self.with_indices:
if self.batched:
function_args.append([current_idx + i for i in range(len(pa_table))])
else:
function_args.append(current_idx)
# then apply the transform
mask = self.function(*function_args, **self.fn_kwargs)
output = self.function(*function_args, **self.fn_kwargs)
mask = _table_output_to_arrow(output)
if not isinstance(mask, (pa.Array, pa.BooleanScalar)):
raise TypeError(
f"Provided `function` which is applied to {formatter.table_type} returns a variable of type "
f"{type(output_table)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset."
)
# return output
if self.batched:
output_table = pa_table.filter(mask)
Expand Down Expand Up @@ -1734,11 +1772,13 @@ def _apply_feature_types_on_batch(
class FormattingConfig:
format_type: Optional[str]

def __post_init__(self):
if self.format_type == "pandas":
raise NotImplementedError(
"The 'pandas' formatting is not implemented for iterable datasets. You can use 'numpy' or 'arrow' instead."
)
@property
def is_table(self) -> bool:
return isinstance(get_formatter(self.format_type), TableFormatter)

@property
def is_tensor(self) -> bool:
return isinstance(get_formatter(self.format_type), TensorFormatter)


class FormattedExamplesIterable(_BaseExamplesIterable):
Expand All @@ -1757,7 +1797,7 @@ def __init__(

@property
def iter_arrow(self):
if self.ex_iterable.iter_arrow and (not self.formatting or self.formatting.format_type == "arrow"):
if self.ex_iterable.iter_arrow and (not self.formatting or self.formatting.is_table):
return self._iter_arrow

@property
Expand All @@ -1773,7 +1813,7 @@ def _init_state_dict(self) -> dict:
return self._state_dict

def __iter__(self):
if not self.formatting or self.formatting.format_type == "arrow":
if not self.formatting or self.formatting.is_table:
formatter = PythonFormatter()
else:
formatter = get_formatter(
Expand Down Expand Up @@ -2093,7 +2133,7 @@ def _iter_pytorch(self):
else:
format_dict = None

if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"):
if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
if ex_iterable.iter_arrow:
iterator = ex_iterable.iter_arrow()
else:
Expand Down Expand Up @@ -2133,7 +2173,7 @@ def _prepare_ex_iterable_for_iteration(
self, batch_size: int = 1, drop_last_batch: bool = False
) -> _BaseExamplesIterable:
ex_iterable = self._ex_iterable
if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"):
if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
ex_iterable = RebatchedArrowExamplesIterable(
ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch
)
Expand Down Expand Up @@ -2189,7 +2229,7 @@ def __iter__(self):
else:
format_dict = None

if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"):
if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
if ex_iterable.iter_arrow:
iterator = ex_iterable.iter_arrow()
else:
Expand Down Expand Up @@ -2225,7 +2265,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
format_dict = None

ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=batch_size, drop_last_batch=drop_last_batch)
if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"):
if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
if ex_iterable.iter_arrow:
iterator = ex_iterable.iter_arrow()
else:
Expand Down Expand Up @@ -2516,7 +2556,7 @@ def map(
else self._info.features
)

if self._formatting and self._formatting.format_type == "arrow":
if self._formatting and self._formatting.is_table:
# apply formatting before iter_arrow to keep map examples iterable happy
ex_iterable = FormattedExamplesIterable(
ex_iterable,
Expand Down

0 comments on commit d48bf53

Please sign in to comment.