diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index c4e46079f64..4ef9fd7cf67 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -15,6 +15,8 @@ # Lint as: python3 """Simple Dataset wrapping an Arrow Table.""" +from __future__ import annotations + import asyncio import contextlib import copy @@ -44,10 +46,15 @@ from pathlib import Path from random import sample from typing import ( + IO, TYPE_CHECKING, Any, BinaryIO, Callable, + Generator, + Iterable, + Iterator, + MutableMapping, Optional, Union, overload, @@ -100,7 +107,7 @@ ) from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table from .formatting.formatting import LazyDict, _is_range_contiguous -from .info import DatasetInfo, DatasetInfosDict +from .info import DatasetInfo, DatasetInfosDict, SupervisedKeysData, Version from .naming import _split_re from .search import IndexableMixin from .splits import NamedSplit, Split, SplitDict, SplitInfo @@ -138,13 +145,19 @@ if TYPE_CHECKING: import sqlite3 + import elasticsearch + import faiss import polars as pl import pyspark import sqlalchemy from .dataset_dict import DatasetDict + from .info import SupervisedKeysData, Version from .iterable_dataset import IterableDataset + if config.TF_AVAILABLE: + import tensorflow as tf + logger = logging.get_logger(__name__) PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED = ( @@ -157,22 +170,22 @@ class DatasetInfoMixin: at the base level of the Dataset for easy access. """ - def __init__(self, info: DatasetInfo, split: Optional[NamedSplit]): + def __init__(self, info: DatasetInfo, split: NamedSplit) -> None: self._info = info self._split = split @property - def info(self): + def info(self) -> DatasetInfo: """[`~datasets.DatasetInfo`] object containing all the metadata in the dataset.""" return self._info @property - def split(self): + def split(self) -> Optional[NamedSplit]: """[`~datasets.NamedSplit`] object corresponding to a named dataset split.""" return self._split @property - def builder_name(self) -> str: + def builder_name(self) -> Optional[str]: return self._info.builder_name @property @@ -180,7 +193,7 @@ def citation(self) -> str: return self._info.citation @property - def config_name(self) -> str: + def config_name(self) -> Optional[str]: return self._info.config_name @property @@ -216,11 +229,11 @@ def size_in_bytes(self) -> Optional[int]: return self._info.size_in_bytes @property - def supervised_keys(self): + def supervised_keys(self) -> Optional[SupervisedKeysData]: return self._info.supervised_keys @property - def version(self): + def version(self) -> Optional[Union[str, Version]]: return self._info.version @@ -229,13 +242,13 @@ class TensorflowDatasetMixin: @staticmethod def _get_output_signature( - dataset: "Dataset", + dataset: Dataset, collate_fn: Callable, collate_fn_args: dict, cols_to_retain: Optional[list[str]] = None, batch_size: Optional[int] = None, num_test_batches: int = 20, - ): + ) -> tuple[dict[str, tf.TensorSpec], dict[str, Union[np.int64, np.float32, np.str_]]]: """Private method used by `to_tf_dataset()` to find the shapes and dtypes of samples from this dataset after being passed through the collate_fn. Tensorflow needs an exact signature for tf.numpy_function, so the only way to do this is to run test batches - the collator may add or rename columns, so we can't figure @@ -247,7 +260,7 @@ def _get_output_signature( validation/evaluation. 
collate_fn(`Callable`): A function or callable object (such as a `DataCollator`) that will collate lists of samples into a batch. - collate_fn_args (`Dict`): A `dict` of keyword arguments to be passed to the + collate_fn_args (`dict`): A `dict` of keyword arguments to be passed to the `collate_fn`. batch_size (`int`, optional): The size of batches loaded from the dataset. Used for shape inference. Can be None, which indicates that batch sizes can be variable. @@ -337,7 +350,7 @@ def to_tf_dataset( prefetch: bool = True, num_workers: int = 0, num_test_batches: int = 20, - ): + ) -> tf.data.Dataset: """Create a `tf.data.Dataset` from the underlying Dataset. This `tf.data.Dataset` will load and collate batches from the Dataset, and is suitable for passing to methods like `model.fit()` or `model.predict()`. The dataset will yield `dicts` for both inputs and labels unless the `dict` would contain only a single key, in which case a raw @@ -347,7 +360,7 @@ def to_tf_dataset( batch_size (`int`, *optional*): Size of batches to load from the dataset. Defaults to `None`, which implies that the dataset won't be batched, but the returned dataset can be batched later with `tf_dataset.batch(batch_size)`. - columns (`List[str]` or `str`, *optional*): + columns (`list[str]` or `str`, *optional*): Dataset column(s) to load in the `tf.data.Dataset`. Column names that are created by the `collate_fn` and that do not exist in the original dataset can be used. shuffle(`bool`, defaults to `False`): @@ -359,10 +372,10 @@ def to_tf_dataset( collate_fn(`Callable`, *optional*): A function or callable object (such as a `DataCollator`) that will collate lists of samples into a batch. - collate_fn_args (`Dict`, *optional*): + collate_fn_args (`dict`, *optional*): An optional `dict` of keyword arguments to be passed to the `collate_fn`. - label_cols (`List[str]` or `str`, defaults to `None`): + label_cols (`list[str]` or `str`, defaults to `None`): Dataset column(s) to load as labels. Note that many models compute loss internally rather than letting Keras do it, in which case passing the labels here is optional, as long as they're in the input `columns`. 
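As a usage note for this hunk, here is a minimal sketch of how the annotated `to_tf_dataset()` is typically called (it assumes `tensorflow` is installed; the dataset and column names are illustrative):

```python
from datasets import Dataset

# Toy dataset with two feature columns and one label column.
ds = Dataset.from_dict({"x1": [0.0, 1.0, 2.0, 3.0], "x2": [1, 2, 3, 4], "y": [0, 1, 0, 1]})

# Returns a tf.data.Dataset; with both `columns` and `label_cols` set it yields
# (features, labels) pairs suitable for `model.fit()`.
tf_ds = ds.to_tf_dataset(columns=["x1", "x2"], label_cols=["y"], batch_size=2, shuffle=False)

for features, labels in tf_ds.take(1):
    print(features["x1"].shape, labels.shape)  # (2,) (2,)
```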
@@ -504,7 +517,7 @@ def to_tf_dataset( else: raise ValueError("num_workers must be >= 0") - def split_features_and_labels(input_batch): + def split_features_and_labels(input_batch: dict) -> Union[dict, tuple[dict, dict]]: # TODO(Matt, QL): deprecate returning the dict content when there's only one key features = {key: tensor for key, tensor in input_batch.items() if key in columns} labels = {key: tensor for key, tensor in input_batch.items() if key in label_cols} @@ -524,7 +537,7 @@ def split_features_and_labels(input_batch): tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE) # Remove a reference to the open Arrow file on delete - def cleanup_callback(ref): + def cleanup_callback(ref) -> None: dataset.__del__() self._TF_DATASET_REFS.remove(ref) @@ -537,16 +550,16 @@ class DatasetTransformationNotAllowedError(Exception): pass -def transmit_format(func): +def transmit_format(func: Callable) -> Callable: """Wrapper for dataset transforms that recreate a new Dataset to transmit the format of the original dataset to the new dataset""" @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Union[dict, Dataset, DatasetDict]: if args: - self: "Dataset" = args[0] + self: Dataset = args[0] args = args[1:] else: - self: "Dataset" = kwargs.pop("self") + self: Dataset = kwargs.pop("self") # don't use self.format since it returns a list of columns for 'columns' even if self_format_columns is None unformatted_columns = set(self.column_names) - set(self._format_columns or []) self_format = { @@ -556,8 +569,8 @@ def wrapper(*args, **kwargs): "output_all_columns": self._output_all_columns, } # apply actual function - out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs) - datasets: list["Dataset"] = list(out.values()) if isinstance(out, dict) else [out] + out: Union[Dataset, DatasetDict] = func(self, *args, **kwargs) + datasets: list[Dataset] = list(out.values()) if isinstance(out, dict) else [out] # re-apply format to the output for dataset in datasets: new_format = self_format.copy() @@ -580,7 +593,7 @@ def wrapper(*args, **kwargs): return wrapper -def update_metadata_with_features(table: Table, features: Features): +def update_metadata_with_features(table: Table, features: Features) -> Table: """To be used in dataset transforms that modify the features of the dataset, in order to update the features stored in the metadata of its schema.""" features = Features({col_name: features[col_name] for col_name in table.column_names}) if table.schema.metadata is None or b"huggingface" not in table.schema.metadata: @@ -596,7 +609,7 @@ def update_metadata_with_features(table: Table, features: Features): return table -def _check_table(table) -> Table: +def _check_table(table: Union[pa.Table, Table]) -> Union[InMemoryTable, Table]: """We check the table type to make sure it's an instance of :class:`datasets.table.Table`""" if isinstance(table, pa.Table): # for a pyarrow table, we can just consider it as a in-memory table @@ -608,7 +621,7 @@ def _check_table(table) -> Table: raise TypeError(f"Expected a pyarrow.Table or a datasets.table.Table object, but got {table}.") -def _check_column_names(column_names: list[str]): +def _check_column_names(column_names: list[str]) -> None: """Check the column names to make sure they don't contain duplicates.""" counter = Counter(column_names) if not all(count == 1 for count in counter.values()): @@ -616,7 +629,7 @@ def _check_column_names(column_names: list[str]): raise ValueError(f"The table can't have duplicated columns 
but columns {duplicated_columns} are duplicated.") -def _check_valid_indices_value(index, size): +def _check_valid_indices_value(index: int, size: int) -> None: if (index < 0 and index + size < 0) or (index >= size): raise IndexError(f"Index {index} out of range for dataset of size {size}.") @@ -637,7 +650,7 @@ def __init__( split: Optional[NamedSplit] = None, indices_table: Optional[Table] = None, fingerprint: Optional[str] = None, - ): + ) -> None: info = info.copy() if info is not None else DatasetInfo() DatasetInfoMixin.__init__(self, info=info, split=split) IndexableMixin.__init__(self) @@ -650,7 +663,7 @@ def __init__( self._format_kwargs: dict = {} self._format_columns: Optional[list] = None self._output_all_columns: bool = False - self._fingerprint: str = fingerprint + self._fingerprint: Optional[str] = fingerprint # Read metadata @@ -718,7 +731,7 @@ def from_file( split: Optional[NamedSplit] = None, indices_filename: Optional[str] = None, in_memory: bool = False, - ) -> "Dataset": + ) -> Dataset: """Instantiate a Dataset backed by an Arrow table at filename. Args: @@ -757,7 +770,7 @@ def from_buffer( info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, indices_buffer: Optional[pa.Buffer] = None, - ) -> "Dataset": + ) -> Dataset: """Instantiate a Dataset backed by an Arrow buffer. Args: @@ -790,7 +803,7 @@ def from_pandas( info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, preserve_index: Optional[bool] = None, - ) -> "Dataset": + ) -> Dataset: """ Convert `pandas.DataFrame` to a `pyarrow.Table` to create a [`Dataset`]. @@ -857,7 +870,7 @@ def from_polars( features: Optional[Features] = None, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, - ) -> "Dataset": + ) -> Dataset: """ Collect the underlying arrow arrays in an Arrow Table. @@ -895,11 +908,11 @@ def from_polars( @classmethod def from_dict( cls, - mapping: dict, + mapping: Mapping, features: Optional[Features] = None, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, - ) -> "Dataset": + ) -> Dataset: """ Convert `dict` to a `pyarrow.Table` to create a [`Dataset`]. @@ -961,7 +974,7 @@ def from_list( features: Optional[Features] = None, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, - ) -> "Dataset": + ) -> Dataset: """ Convert a list of dicts to a `pyarrow.Table` to create a [`Dataset`]`. @@ -975,7 +988,7 @@ def from_list( and reload using e.g. save_to_disk / load_from_disk. Args: - mapping (`List[dict]`): A list of mappings of strings to row values. + mapping (`list[dict]`): A list of mappings of strings to row values. features (`Features`, optional): Dataset features. info (`DatasetInfo`, optional): Dataset information, like description, citation, etc. split (`NamedSplit`, optional): Name of the dataset split. 
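To ground the constructor annotations above, a small sketch of the `from_pandas` path (the `from_polars` call is analogous and commented out because it needs the optional `polars` dependency):

```python
import pandas as pd

from datasets import Dataset

df = pd.DataFrame({"text": ["foo", "bar"], "label": [0, 1]})
ds = Dataset.from_pandas(df)
print(ds.column_names)  # ['text', 'label']

# import polars as pl
# ds = Dataset.from_polars(pl.DataFrame({"text": ["foo", "bar"], "label": [0, 1]}))
```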
@@ -983,20 +996,24 @@ def from_list(
         Returns:
             [`Dataset`]
         """
         # for simplicity and consistency wrt OptimizedTypedSequence we do not use InMemoryTable.from_pylist here
-        mapping = {k: [r.get(k) for r in mapping] for k in mapping[0]} if mapping else {}
-        return cls.from_dict(mapping, features, info, split)
+        transformed_mapping = {k: [r.get(k) for r in mapping] for k in mapping[0]} if mapping else {}
+        return cls.from_dict(transformed_mapping, features, info, split)

     @staticmethod
     def from_csv(
         path_or_paths: Union[PathLike, list[PathLike]],
         split: Optional[NamedSplit] = None,
         features: Optional[Features] = None,
-        cache_dir: str = None,
+        cache_dir: Optional[str] = None,
         keep_in_memory: bool = False,
         num_proc: Optional[int] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union[dict[str, IterableDataset], IterableDataset, Dataset, DatasetDict]:
         """Create Dataset from CSV file(s).

         Args:
@@ -1044,13 +1061,13 @@ def from_csv(
     def from_generator(
         generator: Callable,
         features: Optional[Features] = None,
-        cache_dir: str = None,
+        cache_dir: Optional[str] = None,
         keep_in_memory: bool = False,
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         split: NamedSplit = Split.TRAIN,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union[dict[str, IterableDataset], IterableDataset, Dataset, DatasetDict]:
         """Create a Dataset from a generator.

         Args:
@@ -1120,12 +1137,12 @@ def from_json(
         path_or_paths: Union[PathLike, list[PathLike]],
         split: Optional[NamedSplit] = None,
         features: Optional[Features] = None,
-        cache_dir: str = None,
+        cache_dir: Optional[str] = None,
         keep_in_memory: bool = False,
         field: Optional[str] = None,
         num_proc: Optional[int] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union[dict[str, IterableDataset], IterableDataset, Dataset, DatasetDict]:
         """Create Dataset from JSON or JSON Lines file(s).

         Args:
@@ -1177,12 +1194,12 @@ def from_parquet(
         path_or_paths: Union[PathLike, list[PathLike]],
         split: Optional[NamedSplit] = None,
         features: Optional[Features] = None,
-        cache_dir: str = None,
+        cache_dir: Optional[str] = None,
         keep_in_memory: bool = False,
         columns: Optional[list[str]] = None,
         num_proc: Optional[int] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union[dict[str, IterableDataset], IterableDataset, Dataset, DatasetDict]:
         """Create Dataset from Parquet file(s).

         Args:
@@ -1196,7 +1213,7 @@ def from_parquet(
                 Directory to cache data.
             keep_in_memory (`bool`, defaults to `False`):
                 Whether to copy the data in-memory.
-            columns (`List[str]`, *optional*):
+            columns (`list[str]`, *optional*):
                 If not `None`, only these columns will be read from the file.
                 A column name may be a prefix of a nested field, e.g. 'a' will select
                 'a.b', 'a.c', and 'a.d.e'.
@@ -1236,11 +1253,11 @@ def from_text(
         path_or_paths: Union[PathLike, list[PathLike]],
         split: Optional[NamedSplit] = None,
         features: Optional[Features] = None,
-        cache_dir: str = None,
+        cache_dir: Optional[str] = None,
         keep_in_memory: bool = False,
         num_proc: Optional[int] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union[dict[str, IterableDataset], IterableDataset, Dataset, DatasetDict]:
         """Create Dataset from text file(s).
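A companion sketch for the in-memory and generator constructors in this hunk; the file-based readers (`from_csv`, `from_json`, `from_parquet`, `from_text`) share the `cache_dir` / `keep_in_memory` / `num_proc` keywords and are shown with placeholder paths only:

```python
from datasets import Dataset

ds_from_dict = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})
ds_from_list = Dataset.from_list([{"text": "foo", "label": 0}, {"text": "bar", "label": 1}])
assert ds_from_dict.column_names == ds_from_list.column_names

def gen():
    for i in range(3):
        yield {"idx": i}

ds_from_gen = Dataset.from_generator(gen)

# ds_csv = Dataset.from_csv("data.csv")      # "data.csv" is a placeholder path
# ds_json = Dataset.from_json("data.jsonl")  # placeholder path
```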
Args: @@ -1286,15 +1303,15 @@ def from_text( @staticmethod def from_spark( - df: "pyspark.sql.DataFrame", + df: pyspark.sql.DataFrame, split: Optional[NamedSplit] = None, features: Optional[Features] = None, keep_in_memory: bool = False, - cache_dir: str = None, - working_dir: str = None, + cache_dir: Optional[str] = None, + working_dir: Optional[str] = None, load_from_cache_file: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> Union[dict[str, IterableDataset], "IterableDataset", Dataset, DatasetDict]: """Create a Dataset from Spark DataFrame. Dataset downloading is distributed over Spark workers. Args: @@ -1348,13 +1365,13 @@ def from_spark( @staticmethod def from_sql( - sql: Union[str, "sqlalchemy.sql.Selectable"], - con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], + sql: Union[str, sqlalchemy.sql.Selectable], + con: Union[str, sqlalchemy.engine.Connection, sqlalchemy.engine.Engine, sqlite3.Connection], features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, - **kwargs, - ): + **kwargs: Any, + ) -> Union[dict[str, IterableDataset], IterableDataset, Dataset, DatasetDict]: """Create Dataset from SQL query or database table. Args: @@ -1404,32 +1421,32 @@ def from_sql( **kwargs, ).read() - def __setstate__(self, state): + def __setstate__(self, state) -> Dataset: self.__dict__.update(state) maybe_register_dataset_for_temp_dir_deletion(self) return self - def __del__(self): + def __del__(self) -> None: if hasattr(self, "_data"): del self._data if hasattr(self, "_indices"): del self._indices - def __enter__(self): + def __enter__(self) -> object: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # Here `del` is used to del the pyarrow tables. This properly closes the files used for memory mapped tables self.__del__() def save_to_disk( self, - dataset_path: PathLike, + dataset_path: Union[str, os.PathLike], max_shard_size: Optional[Union[str, int]] = None, num_shards: Optional[int] = None, num_proc: Optional[int] = None, storage_options: Optional[dict] = None, - ): + ) -> None: """ Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. @@ -1576,7 +1593,7 @@ def save_to_disk( json.dump(sorted_keys_dataset_info, dataset_info_file, indent=2) @staticmethod - def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]): + def _save_to_disk_single(job_id: int, shard: Dataset, fpath: str, storage_options: Optional[dict]) -> Iterator: batch_size = config.DEFAULT_MAX_BATCH_SIZE num_examples_progress_update = 0 @@ -1603,7 +1620,7 @@ def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_opti yield job_id, True, (num_examples, num_bytes) @staticmethod - def _build_local_temp_path(uri_or_path: str) -> Path: + def _build_local_temp_path(uri_or_path: os.PathLike) -> Path: """ Builds and returns a Path concatenating a local temporary dir with the dir path (or absolute/relative path extracted from the uri) passed. 
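A round-trip sketch for the `save_to_disk` annotations above and the `load_from_disk` hunk that follows (the target is a temporary directory):

```python
import tempfile

from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar"]})

with tempfile.TemporaryDirectory() as tmp_dir:
    ds.save_to_disk(tmp_dir)  # returns None, as annotated
    reloaded = Dataset.load_from_disk(tmp_dir)
    assert reloaded.column_names == ds.column_names
```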
@@ -1621,10 +1638,10 @@ def _build_local_temp_path(uri_or_path: str) -> Path: @staticmethod def load_from_disk( - dataset_path: PathLike, + dataset_path: os.PathLike, keep_in_memory: Optional[bool] = None, storage_options: Optional[dict] = None, - ) -> "Dataset": + ) -> Dataset: """ Loads a dataset that was previously saved using [`save_to_disk`] from a dataset directory, or from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`. @@ -1873,7 +1890,7 @@ def unique(self, column: str) -> list: return dataset._data.column(column).unique().to_pylist() - def class_encode_column(self, column: str, include_nulls: bool = False) -> "Dataset": + def class_encode_column(self, column: str, include_nulls: bool = False) -> Dataset: """Casts the given column as [`~datasets.features.ClassLabel`] and updates the table. Args: @@ -1903,6 +1920,7 @@ def class_encode_column(self, column: str, include_nulls: bool = False) -> "Data # Sanity checks if column not in self._data.column_names: raise ValueError(f"Column ({column}) not in table columns ({self._data.column_names}).") + src_feat = self._info.features[column] if not isinstance(src_feat, Value): raise ValueError( @@ -1911,7 +1929,7 @@ def class_encode_column(self, column: str, include_nulls: bool = False) -> "Data if src_feat.dtype != "string" or (include_nulls and None in self.unique(column)): - def stringify_column(batch): + def stringify_column(batch: dict) -> dict: batch[column] = [ str(sample) if include_nulls or sample is not None else None for sample in batch[column] ] @@ -1929,7 +1947,7 @@ def stringify_column(batch): class_names = sorted(str(sample) for sample in dset.unique(column) if include_nulls or sample is not None) dst_feat = ClassLabel(names=class_names) - def cast_to_class_labels(batch): + def cast_to_class_labels(batch: dict) -> dict: batch[column] = [ dst_feat.str2int(str(sample)) if include_nulls or sample is not None else None for sample in batch[column] @@ -1949,7 +1967,7 @@ def cast_to_class_labels(batch): return dset @fingerprint_transform(inplace=False) - def flatten(self, new_fingerprint: Optional[str] = None, max_depth=16) -> "Dataset": + def flatten(self, new_fingerprint: Optional[str] = None, max_depth: int = 16) -> Dataset: """Flatten the table. Each column with a struct type is flattened into one column per struct field. Other columns are left unchanged. @@ -1996,13 +2014,13 @@ def flatten(self, new_fingerprint: Optional[str] = None, max_depth=16) -> "Datas def cast( self, features: Features, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, keep_in_memory: bool = False, load_from_cache_file: Optional[bool] = None, cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, num_proc: Optional[int] = None, - ) -> "Dataset": + ) -> Dataset: """ Cast the dataset to a new set of features. @@ -2077,7 +2095,7 @@ def cast( return dataset @fingerprint_transform(inplace=False) - def cast_column(self, column: str, feature: FeatureType, new_fingerprint: Optional[str] = None) -> "Dataset": + def cast_column(self, column: str, feature: FeatureType, new_fingerprint: Optional[str] = None) -> Dataset: """Cast column to feature for decoding. 
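A sketch exercising `class_encode_column` and `cast_column`, whose signatures are annotated above (column names are illustrative):

```python
from datasets import Dataset, Value

ds = Dataset.from_dict({"label": ["neg", "pos", "neg"], "score": [1, 2, 3]})

# String column -> ClassLabel feature.
ds = ds.class_encode_column("label")
print(ds.features["label"].names)  # ['neg', 'pos']

# Cast a single column to another feature type.
ds = ds.cast_column("score", Value("float32"))
print(ds.features["score"].dtype)  # float32
```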
Args: @@ -2120,7 +2138,7 @@ def cast_column(self, column: str, feature: FeatureType, new_fingerprint: Option @transmit_format @fingerprint_transform(inplace=False) - def remove_columns(self, column_names: Union[str, list[str]], new_fingerprint: Optional[str] = None) -> "Dataset": + def remove_columns(self, column_names: Union[str, list[str]], new_fingerprint: Optional[str] = None) -> Dataset: """ Remove one or several column(s) in the dataset and the features associated to them. @@ -2128,7 +2146,7 @@ def remove_columns(self, column_names: Union[str, list[str]], new_fingerprint: O doesn't copy the data of the remaining columns and is thus faster. Args: - column_names (`Union[str, List[str]]`): + column_names (`Union[str, list[str]]`): Name of the column(s) to remove. new_fingerprint (`str`, *optional*): The new fingerprint of the dataset after transform. @@ -2176,7 +2194,7 @@ def remove_columns(self, column_names: Union[str, list[str]], new_fingerprint: O @fingerprint_transform(inplace=False) def rename_column( self, original_column_name: str, new_column_name: str, new_fingerprint: Optional[str] = None - ) -> "Dataset": + ) -> Dataset: """ Rename a column in the dataset, and move the features associated to the original column under the new column name. @@ -2220,7 +2238,7 @@ def rename_column( if not new_column_name: raise ValueError("New column name is empty.") - def rename(columns): + def rename(columns: Iterable) -> list: return [new_column_name if col == original_column_name else col for col in columns] new_column_names = rename(self._data.column_names) @@ -2240,13 +2258,13 @@ def rename(columns): return dataset @fingerprint_transform(inplace=False) - def rename_columns(self, column_mapping: dict[str, str], new_fingerprint: Optional[str] = None) -> "Dataset": + def rename_columns(self, column_mapping: dict[str, str], new_fingerprint: Optional[str] = None) -> Dataset: """ Rename several columns in the dataset, and move the features associated to the original columns under the new column names. Args: - column_mapping (`Dict[str, str]`): + column_mapping (`dict[str, str]`): A mapping of columns to rename to their new names new_fingerprint (`str`, *optional*): The new fingerprint of the dataset after transform. @@ -2287,7 +2305,7 @@ def rename_columns(self, column_mapping: dict[str, str], new_fingerprint: Option if empty_new_columns: raise ValueError(f"New column names {empty_new_columns} are empty.") - def rename(columns): + def rename(columns: Iterable[str]) -> list: return [column_mapping[col] if col in column_mapping else col for col in columns] new_column_names = rename(self._data.column_names) @@ -2308,12 +2326,12 @@ def rename(columns): @transmit_format @fingerprint_transform(inplace=False) - def select_columns(self, column_names: Union[str, list[str]], new_fingerprint: Optional[str] = None) -> "Dataset": + def select_columns(self, column_names: Union[str, list[str]], new_fingerprint: Optional[str] = None) -> Dataset: """Select one or several column(s) in the dataset and the features associated to them. Args: - column_names (`Union[str, List[str]]`): + column_names (`Union[str, list[str]]`): Name of the column(s) to keep. new_fingerprint (`str`, *optional*): The new fingerprint of the dataset after transform. If `None`, @@ -2354,7 +2372,7 @@ def select_columns(self, column_names: Union[str, list[str]], new_fingerprint: O dataset._fingerprint = new_fingerprint return dataset - def __len__(self): + def __len__(self) -> int: """Number of rows in the dataset. 
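The column-manipulation methods annotated above can be exercised as follows (a minimal sketch):

```python
from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2], "b": [3, 4], "c": [5, 6]})

ds = ds.rename_column("a", "x")     # single rename
ds = ds.rename_columns({"b": "y"})  # mapping-based rename
ds = ds.remove_columns("c")         # accepts a str or a list[str]
print(ds.column_names)              # ['x', 'y']

ds = ds.select_columns(["x"])       # keep only the listed columns
print(ds.column_names)              # ['x']
```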
Example: @@ -2371,7 +2389,7 @@ def __len__(self): """ return self.num_rows - def __iter__(self): + def __iter__(self) -> Generator[Optional[Union[Mapping[str, Any], list]]]: """Iterate through the examples. If a formatting is set with [`Dataset.set_format`] rows will be returned with the @@ -2400,7 +2418,7 @@ def __iter__(self): i, ) - def iter(self, batch_size: int, drop_last_batch: bool = False): + def iter(self, batch_size: int, drop_last_batch: bool = False) -> Generator[Union[Mapping[str, Any], list]]: """Iterate through the batches of size `batch_size`. If a formatting is set with [`~datasets.Dataset.set_format`] rows will be returned with the @@ -2432,11 +2450,11 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): slice(i, i + batch_size), ) - def __repr__(self): + def __repr__(self) -> str: return f"Dataset({{\n features: {list(self._info.features.keys())},\n num_rows: {self.num_rows}\n}})" @property - def format(self): + def format(self) -> dict[str, Any]: return { "type": self._format_type, "format_kwargs": self._format_kwargs, @@ -2450,15 +2468,15 @@ def formatted_as( type: Optional[str] = None, columns: Optional[list] = None, output_all_columns: bool = False, - **format_kwargs, - ): + **format_kwargs: dict[str, Any], + ) -> Iterator: """To be used in a `with` statement. Set `__getitem__` return format (type and columns). Args: type (`str`, *optional*): Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__`` returns python objects (default). - columns (`List[str]`, *optional*): + columns (`list[str]`, *optional*): Columns to format in the output. `None` means `__getitem__` returns all columns (default). output_all_columns (`bool`, defaults to `False`): @@ -2482,8 +2500,8 @@ def set_format( type: Optional[str] = None, columns: Optional[list] = None, output_all_columns: bool = False, - **format_kwargs, - ): + **format_kwargs: dict[str, Any], + ) -> None: """Set `__getitem__` return format (type and columns). The data formatting is applied on-the-fly. The format `type` (for example "numpy") is used to format batches when using `__getitem__`. It's also possible to use custom transforms for formatting using [`~datasets.Dataset.set_transform`]. @@ -2492,7 +2510,7 @@ def set_format( type (`str`, *optional*): Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). - columns (`List[str]`, *optional*): + columns (`list[str]`, *optional*): Columns to format in the output. `None` means `__getitem__` returns all columns (default). output_all_columns (`bool`, defaults to `False`): @@ -2555,7 +2573,7 @@ def set_format( "do" if output_all_columns else "don't", ) - def reset_format(self): + def reset_format(self) -> None: """Reset `__getitem__` return format to python objects and all columns. Same as `self.set_format()` @@ -2589,7 +2607,7 @@ def set_transform( transform: Optional[Callable], columns: Optional[list] = None, output_all_columns: bool = False, - ): + ) -> None: """Set `__getitem__` return format using this transform. The transform is applied on-the-fly on batches when `__getitem__` is called. As [`~datasets.Dataset.set_format`], this can be reset using [`~datasets.Dataset.reset_format`]. @@ -2598,7 +2616,7 @@ def set_transform( User-defined formatting transform, replaces the format defined by [`~datasets.Dataset.set_format`]. 
A formatting function is a callable that takes a batch (as a `dict`) as input and returns a batch. This function is applied right before returning the objects in `__getitem__`. - columns (`List[str]`, *optional*): + columns (`list[str]`, *optional*): Columns to format in the output. If specified, then the input batch of the transform only contains those columns. output_all_columns (`bool`, defaults to `False`): @@ -2633,7 +2651,7 @@ def with_format( columns: Optional[list] = None, output_all_columns: bool = False, **format_kwargs, - ): + ) -> Dataset: """Set `__getitem__` return format (type and columns). The data formatting is applied on-the-fly. The format `type` (for example "numpy") is used to format batches when using `__getitem__`. @@ -2645,7 +2663,7 @@ def with_format( type (`str`, *optional*): Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). - columns (`List[str]`, *optional*): + columns (`list[str]`, *optional*): Columns to format in the output. `None` means `__getitem__` returns all columns (default). output_all_columns (`bool`, defaults to `False`): @@ -2703,7 +2721,7 @@ def with_transform( transform: Optional[Callable], columns: Optional[list] = None, output_all_columns: bool = False, - ): + ) -> Dataset: """Set `__getitem__` return format using this transform. The transform is applied on-the-fly on batches when `__getitem__` is called. As [`~datasets.Dataset.set_format`], this can be reset using [`~datasets.Dataset.reset_format`]. @@ -2715,7 +2733,7 @@ def with_transform( User-defined formatting transform, replaces the format defined by [`~datasets.Dataset.set_format`]. A formatting function is a callable that takes a batch (as a `dict`) as input and returns a batch. This function is applied right before returning the objects in `__getitem__`. - columns (`List[str]`, `optional`): + columns (`list[str]`, `optional`): Columns to format in the output. If specified, then the input batch of the transform only contains those columns. output_all_columns (`bool`, defaults to `False`): @@ -2821,8 +2839,8 @@ def cleanup_cache_files(self) -> int: os.remove(file_path) return len(files_to_remove) - def _get_cache_file_path(self, fingerprint): - if is_caching_enabled() and self.cache_files: + def _get_cache_file_path(self, fingerprint: str) -> str: + if is_caching_enabled() and self.cache_files and fingerprint is not None: cache_file_name = "cache-" + fingerprint + ".arrow" cache_directory = os.path.dirname(self.cache_files[0]["filename"]) else: @@ -2839,13 +2857,13 @@ def map( with_rank: bool = False, input_columns: Optional[Union[str, list[str]]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, drop_last_batch: bool = False, remove_columns: Optional[Union[str, list[str]]] = None, keep_in_memory: bool = False, load_from_cache_file: Optional[bool] = None, cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, features: Optional[Features] = None, disable_nullable: bool = False, fn_kwargs: Optional[dict] = None, @@ -2854,7 +2872,7 @@ def map( new_fingerprint: Optional[str] = None, desc: Optional[str] = None, try_original_type: Optional[bool] = True, - ) -> "Dataset": + ) -> Dataset: """ Apply a function to all the examples in the table (individually or in batches) and update the table. If your function returns a column that already exists, then it overwrites it. 
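A minimal `map` sketch matching the non-batched signature documented above (function and column names are illustrative):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar"]})

# batched=False: the function receives one example (a dict) at a time.
def add_length(example):
    example["length"] = len(example["text"])
    return example

ds = ds.map(add_length)
print(ds[0])  # {'text': 'foo', 'length': 3}
```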
@@ -2875,10 +2893,10 @@ def map( Args: function (`Callable`): Function with one of the following signatures: - - `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False` - - `function(example: Dict[str, Any], *extra_args) -> Dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) - - `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` and `with_rank=False` - - `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(example: dict[str, Any]) -> dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False` + - `function(example: dict[str, Any], *extra_args) -> dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(batch: dict[str, list]) -> dict[str, list]` if `batched=True` and `with_indices=False` and `with_rank=False` + - `function(batch: dict[str, list], *extra_args) -> dict[str, list]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) For advanced usage, the function can also return a `pyarrow.Table`. If the function is asynchronous, then `map` will run your function in parallel. @@ -2890,7 +2908,7 @@ def map( with_rank (`bool`, defaults to `False`): Provide process rank to `function`. Note that in this case the signature of `function` should be `def function(example[, idx], rank): ...`. - input_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): + input_columns (`Optional[Union[str, list[str]]]`, defaults to `None`): The columns to be passed into `function` as positional arguments. If `None`, a `dict` mapping to all formatted columns is passed as one argument. batched (`bool`, defaults to `False`): @@ -2901,7 +2919,7 @@ def map( drop_last_batch (`bool`, defaults to `False`): Whether a last batch smaller than the batch_size should be dropped instead of being processed by the function. - remove_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): + remove_columns (`Optional[Union[str, list[str]]]`, defaults to `None`): Remove a selection of columns while doing the mapping. Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding columns with names in `remove_columns`, these columns will be kept. @@ -2922,7 +2940,7 @@ def map( instead of the automatically generated one. disable_nullable (`bool`, defaults to `False`): Disallow null values in the table. - fn_kwargs (`Dict`, *optional*, defaults to `None`): + fn_kwargs (`dict`, *optional*, defaults to `None`): Keyword arguments to be passed to `function`. num_proc (`int`, *optional*, defaults to `None`): Max number of processes when generating cache. Already cached shards are loaded sequentially. 
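And the batched / multiprocess variant of the same call, touching the `batched`, `batch_size`, `with_indices` and `num_proc` parameters described above (a sketch; multiprocessing only pays off on larger datasets):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar", "baz", "qux"]})

# batched=True: the function receives a dict of lists and returns a dict of lists;
# with_indices=True adds the list of example indices as a second argument.
def upper_batch(batch, indices):
    return {"text_upper": [t.upper() for t in batch["text"]], "idx": indices}

# num_proc=2 distributes the work over worker processes.
ds = ds.map(upper_batch, batched=True, batch_size=2, with_indices=True, num_proc=2)
print(ds.column_names)  # ['text', 'text_upper', 'idx']
```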
@@ -3161,7 +3179,7 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory) for rank in range(num_shards) ] - kwargs_per_job = [ + kwargs_per_job: list[dict[str, Any]] = [ { **dataset_kwargs, "shard": shards[rank], @@ -3267,50 +3285,49 @@ def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, @staticmethod def _map_single( - shard: "Dataset", + shard: Dataset, function: Optional[Callable] = None, with_indices: bool = False, with_rank: bool = False, input_columns: Optional[list[str]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, drop_last_batch: bool = False, remove_columns: Optional[list[str]] = None, keep_in_memory: bool = False, - cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + cache_file_name: Optional[Union[str, os.PathLike]] = None, + writer_batch_size: int = 1000, features: Optional[Features] = None, disable_nullable: bool = False, fn_kwargs: Optional[dict] = None, new_fingerprint: Optional[str] = None, rank: Optional[int] = None, offset: int = 0, - try_original_type: Optional[bool] = True, - ) -> Iterable[tuple[Optional[int], bool, Union[int, "Dataset"]]]: + ) -> Iterable[tuple[int, bool, Union[int, Dataset]]]: """Apply a function to all the elements in the table (individually or in batches) and update the table (if function does update examples). Args: shard (`datasets.Dataset`): Dataset to map the transform on. function (`Callable`): with one of the following signature: - - `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False` - - `function(example: Dict[str, Any], *extra_args) -> Dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) - - `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` and `with_rank=False` - - `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(example: dict[str, Any]) -> dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False` + - `function(example: dict[str, Any], *extra_args) -> dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(batch: dict[str, list]) -> dict[str, list]` if `batched=True` and `with_indices=False` and `with_rank=False` + - `function(batch: dict[str, list], *extra_args) -> dict[str, list]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) For advanced usage, the function can also return a `pyarrow.Table`. Moreover if your function returns nothing (`None`), then `map` will run your function and return the dataset unchanged. If no function is provided, default to identity function: lambda x: x with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx[, rank]): ...`. with_rank (`bool`, default `False`): Provide process rank to `function`. Note that in this case the signature of `function` should be `def function(example[, idx], rank): ...`. 
- input_columns (`Optional[List[str]]`, defaults to `None`): The columns to be passed into `function` as + input_columns (`Optional[list[str]]`, defaults to `None`): The columns to be passed into `function` as positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument. batched (`bool`, defaults to `False`): Provide batch of examples to `function` batch_size (`int`, optional, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True` `batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function` drop_last_batch (`bool`, default: `False`): Whether a last batch smaller than the batch_size should be dropped instead of being processed by the function. - remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping. + remove_columns (`Optional[list[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping. Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding columns with names in `remove_columns`, these columns will be kept. keep_in_memory (`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file. @@ -3322,7 +3339,7 @@ def _map_single( features (`Optional[datasets.Features]`, defaults to `None`): Use a specific Features to store the cache file instead of the automatically generated one. disable_nullable (`bool`, defaults to `False`): Disallow null values in the table. - fn_kwargs (`Dict`, optional, defaults to `None`): Keyword arguments to be passed to `function` + fn_kwargs (`dict`, optional, defaults to `None`): Keyword arguments to be passed to `function` new_fingerprint (`str`, optional, defaults to `None`): the new fingerprint of the dataset after transform. If `None`, the new fingerprint is computed using a hash of the previous fingerprint, and the transform arguments rank: (`int`, optional, defaults to `None`): If specified, this is the process rank when doing multiprocessing @@ -3356,7 +3373,7 @@ def _map_single( check_same_num_examples = batched and len(shard.list_indexes()) > 0 - def validate_function_output(processed_inputs): + def validate_function_output(processed_inputs: Union[Mapping, pa.Table, pd.DataFrame]) -> None: """Validate output of the map function.""" allowed_processed_inputs_types = (Mapping, pa.Table, pd.DataFrame) if config.POLARS_AVAILABLE and "polars" in sys.modules: @@ -3393,7 +3410,10 @@ def validate_function_output(processed_inputs): f"Provided `function` which is applied to all elements of table returns a `dict` of types {[type(x) for x in processed_inputs.values()]}. When using `batched=True`, make sure provided `function` returns a `dict` of types like `{allowed_batch_return_types}`." 
) - def prepare_inputs(pa_inputs, indices, offset=0): + def prepare_inputs( + pa_inputs: Table, indices: Union[list, int], offset: int = 0 + ) -> tuple[Any, Union[list, list[list]], tuple, dict[str, Any]]: + # TODO type correctly `inputs` """Utility to apply the function on a selection of columns.""" inputs = format_table( pa_inputs, @@ -3413,7 +3433,11 @@ def prepare_inputs(pa_inputs, indices, offset=0): additional_args += (rank,) return inputs, fn_args, additional_args, fn_kwargs - def prepare_outputs(pa_inputs, inputs, processed_inputs): + def prepare_outputs( + pa_inputs: Table, + inputs: Union[LazyDict, MutableMapping[str, Any]], + processed_inputs: Union[MutableMapping, pa.Table, pd.DataFrame, LazyDict], + ) -> Union[MutableMapping, pa.Table, pd.DataFrame, LazyDict]: nonlocal update_data if not (update_data := (processed_inputs is not None)): return None @@ -3456,19 +3480,23 @@ def prepare_outputs(pa_inputs, inputs, processed_inputs): else: return processed_inputs - def apply_function(pa_inputs, indices, offset=0): + def apply_function( + pa_inputs: Table, indices: list, offset: int = 0 + ) -> Union[MutableMapping, pa.Table, pd.DataFrame, LazyDict]: """Utility to apply the function on a selection of columns.""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset) processed_inputs = function(*fn_args, *additional_args, **fn_kwargs) return prepare_outputs(pa_inputs, inputs, processed_inputs) - async def async_apply_function(pa_inputs, indices, offset=0): + async def async_apply_function( + pa_inputs: Table, indices: list, offset: int = 0 + ) -> Union[MutableMapping, pa.Table, pd.DataFrame, LazyDict]: """Utility to apply the function on a selection of columns. Same code but async""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset) processed_inputs = await function(*fn_args, *additional_args, **fn_kwargs) return prepare_outputs(pa_inputs, inputs, processed_inputs) - def init_buffer_and_writer(): + def init_buffer_and_writer() -> tuple[Optional[pa.BufferOutputStream], ArrowWriter, Optional[IO]]: # Prepare output buffer and batched writer in memory or on file if we update the table writer_features = features if writer_features is None: @@ -3655,7 +3683,7 @@ def batch( drop_last_batch: bool = False, num_proc: Optional[int] = None, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """ Group samples from the dataset into batches. @@ -3709,17 +3737,17 @@ def filter( with_rank: bool = False, input_columns: Optional[Union[str, list[str]]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, keep_in_memory: bool = False, load_from_cache_file: Optional[bool] = None, cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, fn_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, suffix_template: str = "_{rank:05d}_of_{num_proc:05d}", new_fingerprint: Optional[str] = None, desc: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Apply a filter function to all the elements in the table in batches and update the table so that the dataset only includes examples according to the filter function. 
@@ -3729,10 +3757,10 @@ def filter( Args: function (`Callable`): Callable with one of the following signatures: - - `function(example: Dict[str, Any]) -> bool` if `batched=False` and `with_indices=False` and `with_rank=False` - - `function(example: Dict[str, Any], *extra_args) -> bool` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) - - `function(batch: Dict[str, List]) -> List[bool]` if `batched=True` and `with_indices=False` and `with_rank=False` - - `function(batch: Dict[str, List], *extra_args) -> List[bool]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(example: dict[str, Any]) -> bool` if `batched=False` and `with_indices=False` and `with_rank=False` + - `function(example: dict[str, Any], *extra_args) -> bool` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(batch: dict[str, list]) -> list[bool]` if `batched=True` and `with_indices=False` and `with_rank=False` + - `function(batch: dict[str, list], *extra_args) -> list[bool]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) If the function is asynchronous, then `filter` will run your function in parallel. If no function is provided, defaults to an always `True` function: `lambda x: True`. @@ -3742,7 +3770,7 @@ def filter( with_rank (`bool`, defaults to `False`): Provide process rank to `function`. Note that in this case the signature of `function` should be `def function(example[, idx], rank): ...`. - input_columns (`str` or `List[str]`, *optional*): + input_columns (`str` or `list[str]`, *optional*): The columns to be passed into `function` as positional arguments. If `None`, a `dict` mapping to all formatted columns is passed as one argument. batched (`bool`, defaults to `False`): @@ -3847,12 +3875,12 @@ def flatten_indices( self, keep_in_memory: bool = False, cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, features: Optional[Features] = None, disable_nullable: bool = False, num_proc: Optional[int] = None, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Create and cache a new Dataset by flattening the indices mapping. Args: @@ -3894,7 +3922,7 @@ def _new_dataset_with_indices( indices_cache_file_name: Optional[str] = None, indices_buffer: Optional[pa.Buffer] = None, fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Return a new Dataset obtained by adding indices (provided in indices_cache_file_name or in a buffer) to the current Dataset. """ @@ -3924,12 +3952,12 @@ def _new_dataset_with_indices( @fingerprint_transform(inplace=False, ignore_kwargs=["indices_cache_file_name"]) def select( self, - indices: Iterable, + indices: Union[pa.Array, pa.ChunkedArray, range, list[int], pd.Series, Iterator[int]], keep_in_memory: bool = False, indices_cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Create a new dataset with rows selected following the list/array of indices. Args: @@ -4016,7 +4044,7 @@ def _select_contiguous( start: int, length: int, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Create a new dataset with rows from a contiguous slice of data. The slice is defined by that start index and its length. 
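A sketch for `filter` and `select`, whose annotations change above:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar", "baz"], "label": [0, 1, 1]})

# filter keeps the rows for which the predicate returns True.
positives = ds.filter(lambda example: example["label"] == 1)
print(len(positives))  # 2

# select builds an indices mapping on top of the Arrow table; no data is copied.
subset = ds.select(range(2))
print(subset["text"])  # ['foo', 'bar']
```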
@@ -4072,9 +4100,9 @@ def _select_with_indices_mapping( indices: Iterable, keep_in_memory: bool = False, indices_cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Create a new dataset with rows selected following the list/array of indices. The new dataset is made by creating a new indices mapping on top of the main arrow table. @@ -4172,7 +4200,7 @@ def _select_with_indices_mapping( else: return self._new_dataset_with_indices(indices_buffer=buf_writer.getvalue(), fingerprint=new_fingerprint) - def skip(self, n: int) -> "Dataset": + def skip(self, n: int) -> Dataset: """ Create a new [`Dataset`] that skips the first `n` elements. @@ -4202,7 +4230,7 @@ def skip(self, n: int) -> "Dataset": """ return self.select(range(n, len(self))) - def repeat(self, num_times: int) -> "Dataset": + def repeat(self, num_times: int) -> Dataset: """ Create a new [`Dataset`] that repeats the underlying dataset `num_times` times. @@ -4234,7 +4262,7 @@ def repeat(self, num_times: int) -> "Dataset": raise ValueError("Map style datasets do not support indefinite repetition.") return _concatenate_map_style_datasets([self] * num_times) if num_times > 0 else self.select([]) - def take(self, n: int) -> "Dataset": + def take(self, n: int) -> Dataset: """ Create a new [`Dataset`] with only the first `n` elements. @@ -4267,9 +4295,9 @@ def sort( keep_in_memory: bool = False, load_from_cache_file: Optional[bool] = None, indices_cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Create a new dataset sorted according to a single or multiple columns. Args: @@ -4396,9 +4424,9 @@ def shuffle( keep_in_memory: bool = False, load_from_cache_file: Optional[bool] = None, indices_cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, new_fingerprint: Optional[str] = None, - ) -> "Dataset": + ) -> Dataset: """Create a new Dataset where the rows are shuffled. Currently shuffling uses numpy random generators. @@ -4536,10 +4564,10 @@ def train_test_split( load_from_cache_file: Optional[bool] = None, train_indices_cache_file_name: Optional[str] = None, test_indices_cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, + writer_batch_size: int = 1000, train_new_fingerprint: Optional[str] = None, test_new_fingerprint: Optional[str] = None, - ) -> "DatasetDict": + ) -> DatasetDict: """Return a dictionary ([`datasets.DatasetDict`]) with two random train and test subsets (`train` and `test` `Dataset` splits). Splits are created from the dataset according to `test_size`, `train_size` and `shuffle`. @@ -4807,8 +4835,8 @@ def shard( contiguous: bool = True, keep_in_memory: bool = False, indices_cache_file_name: Optional[str] = None, - writer_batch_size: Optional[int] = 1000, - ) -> "Dataset": + writer_batch_size: int = 1000, + ) -> Dataset: """Return the `index`-nth shard from dataset split into `num_shards` pieces. This shards deterministically. 
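The reordering and splitting methods annotated above (`sort`, `shuffle`, `train_test_split`, `shard`) can be exercised as in the sketch below; the `shard` docstring continues right after it:

```python
from datasets import Dataset

ds = Dataset.from_dict({"idx": [3, 1, 2, 0]})

print(ds.sort("idx")["idx"])       # [0, 1, 2, 3]
print(ds.shuffle(seed=42)["idx"])  # deterministic permutation for a fixed seed

splits = ds.train_test_split(test_size=0.25, seed=42)  # returns a DatasetDict
print(len(splits["train"]), len(splits["test"]))       # 3 1

first_shard = ds.shard(num_shards=2, index=0)  # contiguous by default
print(len(first_shard))  # 2
```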
`dataset.shard(n, i)` splits the dataset into contiguous chunks, @@ -5082,7 +5110,7 @@ def to_polars( batched: bool = False, schema_overrides: Optional[dict] = None, rechunk: bool = True, - ) -> Union["pl.DataFrame", Iterator["pl.DataFrame"]]: + ) -> Union[pl.DataFrame, pl.Series, Iterator[Union[pl.DataFrame, pl.Series]]]: """Returns the dataset as a `polars.DataFrame`. Can also return a generator for large datasets. Args: @@ -5237,7 +5265,7 @@ def _estimate_nbytes(self) -> int: # Approximate the space needed to store the bytes from the external files by analyzing the first 1000 examples extra_nbytes = 0 - def extra_nbytes_visitor(array, feature): + def extra_nbytes_visitor(array, feature: Union[Audio, Image, Video]): nonlocal extra_nbytes if isinstance(feature, (Audio, Image, Video)): for x in array.to_pylist(): @@ -5257,17 +5285,17 @@ def extra_nbytes_visitor(array, feature): return dataset_nbytes @staticmethod - def _generate_tables_from_shards(shards: list["Dataset"], batch_size: int): + def _generate_tables_from_shards(shards: list[Dataset], batch_size: int): for shard_idx, shard in enumerate(shards): for pa_table in shard.with_format("arrow").iter(batch_size): yield shard_idx, pa_table @staticmethod - def _generate_tables_from_cache_file(filename: str): + def _generate_tables_from_cache_file(filename: str) -> tuple[int, pa.Table]: for batch_idx, batch in enumerate(_memory_mapped_record_batch_reader_from_file(filename)): yield batch_idx, pa.Table.from_batches([batch]) - def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset": + def to_iterable_dataset(self, num_shards: int = 1) -> IterableDataset: """Get an [`datasets.IterableDataset`] from a map-style [`datasets.Dataset`]. This is equivalent to loading a dataset in streaming mode with [`datasets.load_dataset`], but much faster since the data is streamed from local files. @@ -5403,11 +5431,11 @@ def _push_parquet_shards_to_hub( max_shard_size: Optional[Union[int, str]] = None, num_shards: Optional[int] = None, embed_external_files: bool = True, - ) -> tuple[list[CommitOperationAdd], int, int]: + ) -> tuple[list, int, int]: """Pushes the dataset shards as Parquet files to the hub. Returns: - additions (`List[CommitOperation]`): list of the `CommitOperationAdd` of the uploaded shards + additions (`list[CommitOperation]`): list of the `CommitOperationAdd` of the uploaded shards uploaded_size (`int`): number of uploaded bytes to the repository dataset_nbytes (`int`): approximate size in bytes of the uploaded dataset after uncompression """ @@ -5880,12 +5908,12 @@ def add_faiss_index( device: Optional[int] = None, string_factory: Optional[str] = None, metric_type: Optional[int] = None, - custom_index: Optional["faiss.Index"] = None, # noqa: F821 + custom_index: Optional[faiss.Index] = None, # noqa: F821 batch_size: int = 1000, train_size: Optional[int] = None, faiss_verbose: bool = False, dtype=np.float32, - ): + ) -> Dataset: """Add a dense index using Faiss for fast retrieval. By default the index is done over the vectors of the specified column. You can specify `device` if you want to run it on GPU (`device` must be the GPU index). @@ -5900,7 +5928,7 @@ def add_faiss_index( The `index_name`/identifier of the index. This is the `index_name` that is used to call [`~datasets.Dataset.get_nearest_examples`] or [`~datasets.Dataset.search`]. By default it corresponds to `column`. 
- device (`Union[int, List[int]]`, *optional*): + device (`Union[int, list[int]]`, *optional*): If positive integer, this is the index of the GPU to use. If negative integer, use all GPUs. If a list of positive integers is passed in, run only on those GPUs. By default it uses the CPU. string_factory (`str`, *optional*): @@ -5960,7 +5988,7 @@ def add_faiss_index_from_external_arrays( device: Optional[int] = None, string_factory: Optional[str] = None, metric_type: Optional[int] = None, - custom_index: Optional["faiss.Index"] = None, # noqa: F821 + custom_index: Optional[faiss.Index] = None, # noqa: F821 batch_size: int = 1000, train_size: Optional[int] = None, faiss_verbose: bool = False, @@ -5980,7 +6008,7 @@ def add_faiss_index_from_external_arrays( index_name (`str`): The `index_name`/identifier of the index. This is the `index_name` that is used to call [`~datasets.Dataset.get_nearest_examples`] or [`~datasets.Dataset.search`]. - device (Optional `Union[int, List[int]]`, *optional*): + device (Optional `Union[int, list[int]]`, *optional*): If positive integer, this is the index of the GPU to use. If negative integer, use all GPUs. If a list of positive integers is passed in, run only on those GPUs. By default it uses the CPU. string_factory (`str`, *optional*): @@ -6018,7 +6046,7 @@ def add_elasticsearch_index( index_name: Optional[str] = None, host: Optional[str] = None, port: Optional[int] = None, - es_client: Optional["elasticsearch.Elasticsearch"] = None, # noqa: F821 + es_client: Optional[elasticsearch.Elasticsearch] = None, # noqa: F821 es_index_name: Optional[str] = None, es_index_config: Optional[dict] = None, ): @@ -6082,7 +6110,7 @@ def add_elasticsearch_index( @transmit_format @fingerprint_transform(inplace=False) - def add_item(self, item: dict, new_fingerprint: str): + def add_item(self, item: dict, new_fingerprint: str) -> Dataset: """Add item to Dataset. @@ -6134,7 +6162,7 @@ def add_item(self, item: dict, new_fingerprint: str): fingerprint=new_fingerprint, ) - def align_labels_with_mapping(self, label2id: dict, label_column: str) -> "Dataset": + def align_labels_with_mapping(self, label2id: dict, label_column: str) -> Dataset: """Align the dataset's label ID and label name mapping to match an input `label2id` mapping. This is useful when you want to ensure that a model's predicted labels are aligned with the dataset. The alignment in done using the lowercase label names. @@ -6217,13 +6245,13 @@ def _concatenate_map_style_datasets( info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, axis: int = 0, -): +) -> Dataset: """ Converts a list of :class:`Dataset` with the same schema into a single :class:`Dataset`. When you concatenate on axis 0, missing data are filled with None values. Args: - dsets (`List[datasets.Dataset]`): List of Datasets to concatenate. + dsets (`list[datasets.Dataset]`): List of Datasets to concatenate. info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. split (:class:`NamedSplit`, optional): Name of the dataset split. axis (``{0, 1}``, default ``0``, meaning over rows): @@ -6259,7 +6287,7 @@ def _concatenate_map_style_datasets( format = {} logger.info("Some of the datasets have disparate format. 
Resetting the format of the concatenated dataset.") - def apply_offset_to_indices_table(table, offset): + def apply_offset_to_indices_table(table, offset: int): if offset == 0: return table else: @@ -6327,14 +6355,14 @@ def apply_offset_to_indices_table(table, offset): def _interleave_map_style_datasets( - datasets: list["Dataset"], + datasets: list[Dataset], probabilities: Optional[list[float]] = None, seed: Optional[int] = None, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", **kwargs, -) -> "Dataset": +) -> Dataset: """ Interleave several map-style datasets (sources) into a single map-style dataset. The new dataset is constructed by alternating between the sources to get the examples. @@ -6342,8 +6370,8 @@ def _interleave_map_style_datasets( If `probabilities` is not `None, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities. Args: - datasets (`List[Dataset]`): list of datasets to interleave - probabilities (`List[float]`, optional, default None): If specified, the new dataset is constructed by sampling + datasets (`list[Dataset]`): list of datasets to interleave + probabilities (`list[float]`, optional, default None): If specified, the new dataset is constructed by sampling examples from one source at a time according to these probabilities. seed (`int`, optional, default None): The random seed used to choose a source for each example. info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. @@ -6405,7 +6433,7 @@ def _interleave_map_style_datasets( # if oversampling ("all_exhausted"), we stop as soons as every dataset is exhausted, i.e as soon as every samples of every dataset has been visited at least once bool_strategy_func = np.all if oversampling else np.any - def iter_random_indices(): + def iter_random_indices() -> Generator: """Get an infinite iterator that randomly samples the index of the source to pick examples from.""" rng = np.random.default_rng(seed) while True: @@ -6464,7 +6492,7 @@ def get_indices_from_mask_function( indices_mapping: Optional[Table] = None, *args, **fn_kwargs, -): +) -> dict[str, Union[list, pa.Array]]: if batched: # we extract indices and rank from args *inputs, indices, rank = args @@ -6520,9 +6548,9 @@ async def async_get_indices_from_mask_function( with_rank: bool, input_columns: Optional[Union[str, list[str]]], indices_mapping: Optional[Table] = None, - *args, - **fn_kwargs, -): + *args: Any, + **fn_kwargs: Any, +) -> dict[str, list]: """same function but async""" if batched: # we extract indices and rank from args @@ -6546,7 +6574,7 @@ async def async_get_indices_from_mask_function( num_examples = len(batch[next(iter(batch.keys()))]) for i in range(num_examples): example = {key: batch[key][i] for key in batch} - additional_args = () + additional_args: tuple = () if with_indices: additional_args += (indices[i],) if with_rank: diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 64097b773f1..5987a06d212 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -139,7 +139,7 @@ def __init__( type: Optional[FeatureType] = None, try_type: Optional[FeatureType] = None, optimized_int_type: Optional[FeatureType] = None, - ): + ) -> None: # assert type is None or try_type is None, if type is not None and try_type is not None: raise ValueError("You cannot specify both type and try_type") @@ -359,7 
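# Sketch of the interleaving behaviour implemented by `_interleave_map_style_datasets`,
# via the public `interleave_datasets` wrapper; probabilities and data are illustrative.
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"x": [0, 1, 2]})
d2 = Dataset.from_dict({"x": [10, 11, 12]})
mixed = interleave_datasets(
    [d1, d2],
    probabilities=[0.7, 0.3],
    seed=42,
    stopping_strategy="all_exhausted",  # oversample until every source has been fully seen
)
print(mixed["x"])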
+359,7 @@ def __init__( unit: str = "examples", embed_local_files: bool = False, storage_options: Optional[dict] = None, - ): + ) -> None: if path is None and stream is None: raise ValueError("At least one of path and stream must be provided.") if features is not None: @@ -410,17 +410,17 @@ def __init__( self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None self.hkey_record = [] - def __len__(self): + def __len__(self)-> int: """Return the number of writed and staged examples""" return self._num_examples + len(self.current_examples) + len(self.current_rows) - def __enter__(self): + def __enter__(self) -> object: return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def close(self): + def close(self) -> None: # Try closing if opened; if closed: pyarrow.lib.ArrowInvalid: Invalid operation on closed file if self.pa_writer: # it might be None try: @@ -430,7 +430,7 @@ def close(self): if self._closable_stream and not self.stream.closed: self.stream.close() # This also closes self.pa_writer if it is opened - def _build_writer(self, inferred_schema: pa.Schema): + def _build_writer(self, inferred_schema: pa.Schema) -> None: schema = self.schema inferred_features = Features.from_arrow_schema(inferred_schema) if self._features is not None: @@ -476,7 +476,7 @@ def _build_metadata(info: DatasetInfo, fingerprint: Optional[str] = None) -> dic metadata["fingerprint"] = fingerprint return {"huggingface": json.dumps(metadata)} - def write_examples_on_file(self): + def write_examples_on_file(self) -> None: """Write stored examples from the write-pool of examples. It makes a table out of the examples and write it.""" if not self.current_examples: return @@ -510,7 +510,7 @@ def write_examples_on_file(self): self.write_batch(batch_examples=batch_examples) self.current_examples = [] - def write_rows_on_file(self): + def write_rows_on_file(self) -> None: """Write stored rows from the write-pool of rows. It concatenates the single-row tables and it writes the resulting table.""" if not self.current_rows: return @@ -523,7 +523,7 @@ def write( example: dict[str, Any], key: Optional[Union[str, int, bytes]] = None, writer_batch_size: Optional[int] = None, - ): + ) -> None: """Add a given (Example,Key) pair to the write-pool of examples which is written to file. Args: @@ -551,7 +551,7 @@ def write( self.write_examples_on_file() - def check_duplicate_keys(self): + def check_duplicate_keys(self) -> None: """Raises error if duplicates found in a batch""" tmp_record = set() for hash, key in self.hkey_record: @@ -566,7 +566,7 @@ def check_duplicate_keys(self): else: tmp_record.add(hash) - def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None): + def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None) -> None: """Add a given single-row Table to the write-pool of rows which is written to file. Args: @@ -629,7 +629,7 @@ def write_batch( pa_table = pa.Table.from_arrays(arrays, schema=schema) self.write_table(pa_table, writer_batch_size) - def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None): + def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = None) -> None: """Write a Table to file. 
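# Minimal sketch of the ArrowWriter write/finalize flow whose methods are typed in this
# hunk. ArrowWriter is an internal helper, so treat the exact usage as an assumption.
from datasets import Features, Value
from datasets.arrow_writer import ArrowWriter

features = Features({"text": Value("string")})
with ArrowWriter(features=features, path="tmp_dataset.arrow") as writer:
    writer.write({"text": "hello"})
    writer.write({"text": "world"})
    num_examples, num_bytes = writer.finalize()  # matches the `tuple[int, int]` annotation
print(num_examples, num_bytes)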
Args: @@ -647,7 +647,7 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non self._num_examples += pa_table.num_rows self.pa_writer.write_table(pa_table, writer_batch_size) - def finalize(self, close_stream=True): + def finalize(self, close_stream: bool = True) -> tuple[int, int]: self.write_rows_on_file() # In case current_examples < writer_batch_size, but user uses finalize() if self._check_duplicates: diff --git a/src/datasets/builder.py b/src/datasets/builder.py index d6992b9e19d..c344508c17a 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1066,10 +1066,10 @@ def _make_split_generators_kwargs(self, prepare_split_kwargs): def as_dataset( self, - split: Optional[Split] = None, - run_post_process=True, + split: Optional[Union[str, Split]] = None, + run_post_process: bool = True, verification_mode: Optional[Union[VerificationMode, str]] = None, - in_memory=False, + in_memory: bool = False, ) -> Union[Dataset, DatasetDict]: """Return a Dataset for the specified split. @@ -1249,7 +1249,7 @@ def as_streaming_dataset( self, split: Optional[str] = None, base_path: Optional[str] = None, - ) -> Union[dict[str, IterableDataset], IterableDataset]: + ) -> Union[IterableDatasetDict, IterableDataset]: if is_remote_filesystem(self._fs): raise NotImplementedError( f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet." @@ -1437,7 +1437,7 @@ def _prepare_split( self, split_generator: SplitGenerator, check_duplicate_keys: bool, - file_format="arrow", + file_format: str = "arrow", num_proc: Optional[int] = None, max_shard_size: Optional[Union[int, str]] = None, ): diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 9d9b4974aa7..f75e57f19d0 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -15,6 +15,8 @@ # Lint as: python3 """This class handle features definition in datasets and some utilities to display table type.""" +from __future__ import annotations + import copy import json import re @@ -25,7 +27,8 @@ from dataclasses import InitVar, dataclass, field, fields from functools import reduce, wraps from operator import mul -from typing import Any, Callable, ClassVar, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, TypeVar, Union +from typing import Sequence as Sequence_ import numpy as np import pandas as pd @@ -47,6 +50,13 @@ from .video import Video +if TYPE_CHECKING: + import jax.numpy as jnp + import PIL.Image + import tensorflow as tf + import torch + + logger = logging.get_logger(__name__) @@ -131,7 +141,9 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType: purpose of this function. """ - def _dtype_error_msg(dtype, pa_dtype, examples=None, urls=None): + def _dtype_error_msg( + dtype: str, pa_dtype: pa.DataType, examples: Optional[list] = None, urls: Optional[list] = None + ) -> str: msg = f"{dtype} is not a validly formatted string representation of the pyarrow {pa_dtype} type." 
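# Sketch of the builder flow annotated above: prepare a builder, then materialise a split
# with `as_dataset` or stream it with `as_streaming_dataset`. "imdb" is a placeholder name.
from datasets import load_dataset_builder

builder = load_dataset_builder("imdb")
builder.download_and_prepare()
train_ds = builder.as_dataset(split="train")                # Dataset
train_stream = builder.as_streaming_dataset(split="train")  # IterableDataset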
if examples: examples = ", ".join(examples[:-1]) + " or " + examples[-1] if len(examples) > 1 else examples[0] @@ -265,7 +277,28 @@ def _dtype_error_msg(dtype, pa_dtype, examples=None, urls=None): ) -def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_casting: bool) -> tuple[Any, bool]: +CastableOjbect = Union[ + np.ndarray, + torch.Tensor, + tf.Tensor, + jnp.ndarray, + PIL.Image.Image, + pd.Series, + pd.DataFrame, + pd.Timestamp, + pd.Timedelta, + Mapping, + dict, + list, + tuple, +] + + +def _cast_to_python_objects( + obj: CastableOjbect, + only_1d_for_numpy: bool, + optimize_list_casting: bool, +) -> tuple[Any, bool]: """ Cast pytorch/tensorflow/pandas objects to python numpy array/lists. It works recursively. @@ -442,7 +475,9 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas return obj, False -def cast_to_python_objects(obj: Any, only_1d_for_numpy=False, optimize_list_casting=True) -> Any: +def cast_to_python_objects( + obj: CastableOjbect, only_1d_for_numpy: bool = False, optimize_list_casting: bool = True +) -> Any: """ Cast numpy/pytorch/tensorflow/pandas objects to python lists. It works recursively. @@ -517,20 +552,20 @@ class Value: dtype: str id: Optional[str] = None # Automatically constructed - pa_type: ClassVar[Any] = None + pa_type: Any = None _type: str = field(default="Value", init=False, repr=False) - def __post_init__(self): + def __post_init__(self) -> None: if self.dtype == "double": # fix inferred type self.dtype = "float64" if self.dtype == "float": # fix inferred type self.dtype = "float32" self.pa_type = string_to_arrow(self.dtype) - def __call__(self): + def __call__(self) -> str: return self.pa_type - def encode_example(self, value): + def encode_example(self, value: Any) -> Any: if pa.types.is_boolean(self.pa_type): return bool(value) elif pa.types.is_integer(self.pa_type): @@ -544,14 +579,17 @@ def encode_example(self, value): class _ArrayXD: - def __post_init__(self): + shape: tuple + dtype: str + + def __post_init__(self) -> None: self.shape = tuple(self.shape) - def __call__(self): + def __call__(self) -> dict: pa_type = globals()[self.__class__.__name__ + "ExtensionType"](self.shape, self.dtype) return pa_type - def encode_example(self, value): + def encode_example(self, value: object) -> object: return value @@ -655,10 +693,13 @@ class Array5D(_ArrayXD): _type: str = field(default="Array5D", init=False, repr=False) +T_ArrayXDExtensionType = TypeVar("T_ArrayXDExtensionType", bound="_ArrayXDExtensionType") + + class _ArrayXDExtensionType(pa.ExtensionType): ndims: Optional[int] = None - def __init__(self, shape: tuple, dtype: str): + def __init__(self, shape: tuple, dtype: str) -> None: if self.ndims is None or self.ndims <= 1: raise ValueError("You must instantiate an array type with a value for dim that is > 1") if len(shape) != self.ndims: @@ -671,11 +712,13 @@ def __init__(self, shape: tuple, dtype: str): self.storage_dtype = self._generate_dtype(self.value_type) pa.ExtensionType.__init__(self, self.storage_dtype, f"{self.__class__.__module__}.{self.__class__.__name__}") - def __arrow_ext_serialize__(self): + def __arrow_ext_serialize__(self) -> bytes: return json.dumps((self.shape, self.value_type)).encode() @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized): + def __arrow_ext_deserialize__( + cls: type[T_ArrayXDExtensionType], storage_type, serialized: Union[str, bytes, bytearray] + ) -> T_ArrayXDExtensionType: args = json.loads(serialized) return cls(*args) @@ -683,13 
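# Hedged example for the `_ArrayXD` / `Array2D` feature types covered by this hunk:
# fixed-shape arrays stored through the Arrow extension types defined here.
from datasets import Array2D, Dataset, Features

features = Features({"matrix": Array2D(shape=(2, 3), dtype="float32")})
ds = Dataset.from_dict({"matrix": [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]}, features=features)
print(ds[0]["matrix"])  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]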
+726,13 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized): def __reduce__(self): return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) - def __hash__(self): + def __hash__(self) -> int: return hash((self.__class__, self.shape, self.value_type)) - def __arrow_ext_class__(self): + def __arrow_ext_class__(self) -> type[ArrayExtensionArray]: return ArrayExtensionArray - def _generate_dtype(self, dtype): + def _generate_dtype(self, dtype: str) -> pa.DataType: dtype = string_to_arrow(dtype) for d in reversed(self.shape): dtype = pa.list_(dtype) @@ -747,14 +790,14 @@ def _unnest_pa_type(pa_type: pa.DataType) -> pa.DataType: class ArrayExtensionArray(pa.ExtensionArray): - def __array__(self): + def __array__(self) -> np.ndarray: zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True) return self.to_numpy(zero_copy_only=zero_copy_only) def __getitem__(self, i): return self.storage[i] - def to_numpy(self, zero_copy_only=True): + def to_numpy(self, zero_copy_only: bool = True) -> np.ndarray: storage: pa.ListArray = self.storage null_mask = storage.is_null().to_numpy(zero_copy_only=False) @@ -798,7 +841,7 @@ def to_numpy(self, zero_copy_only=True): return numpy_arr - def to_pylist(self, maps_as_pydicts: Optional[Literal["lossy", "strict"]] = None): + def to_pylist(self, maps_as_pydicts: Optional[Literal["lossy", "strict"]] = None) -> list: zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True) numpy_arr = self.to_numpy(zero_copy_only=zero_copy_only) if self.type.shape[0] is None and numpy_arr.dtype == object: @@ -810,10 +853,10 @@ def to_pylist(self, maps_as_pydicts: Optional[Literal["lossy", "strict"]] = None class PandasArrayExtensionDtype(PandasExtensionDtype): _metadata = "value_type" - def __init__(self, value_type: Union["PandasArrayExtensionDtype", np.dtype]): + def __init__(self, value_type: Union[PandasArrayExtensionDtype, np.dtype]) -> None: self._value_type = value_type - def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]): + def __from_arrow__(self, array: Union[pa.Array, pa.ChunkedArray]) -> PandasArrayExtensionArray: if isinstance(array, pa.ChunkedArray): array = array.type.wrap_array(pa.concat_arrays([chunk.storage for chunk in array.chunks])) zero_copy_only = _is_zero_copy_only(array.storage.type, unnest=True) @@ -825,7 +868,7 @@ def construct_array_type(cls): return PandasArrayExtensionArray @property - def type(self) -> type: + def type(self) -> type[np.ndarray]: return np.ndarray @property @@ -1277,7 +1320,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType: return schema() -def encode_nested_example(schema, obj, level=0): +def encode_nested_example(schema, obj, level: int = 0): """Encode a nested example. This is used since some features (in particular ClassLabel) have some logic during encoding. @@ -1449,7 +1492,7 @@ def register_feature( _FEATURE_TYPES[feature_type] = feature_cls -def generate_from_dict(obj: Any): +def generate_from_dict(obj: dict): """Regenerate the nested feature object from a deserialized dict. We use the '_type' fields to get the dataclass name to load. @@ -1783,7 +1826,7 @@ class Features(dict): - [`Translation`] or [`TranslationVariableLanguages`] feature specific to Machine Translation. 
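# Round-trip sketch for `Features.arrow_schema` / `Features.from_arrow_schema` typed in
# this hunk; the feature dict is illustrative.
from datasets import ClassLabel, Features, Value

features = Features({"text": Value("string"), "label": ClassLabel(names=["neg", "pos"])})
schema = features.arrow_schema                 # pa.Schema carrying Hugging Face metadata
restored = Features.from_arrow_schema(schema)  # ClassLabel is recovered from that metadata
assert restored == features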
""" - def __init__(*args, **kwargs): + def __init__(*args, **kwargs) -> None: # self not in the signature to allow passing self as a kwarg if not args: raise TypeError("descriptor '__init__' of 'Features' object needs an argument") @@ -1801,11 +1844,11 @@ def __init__(*args, **kwargs): popitem = keep_features_dicts_synced(dict.popitem) clear = keep_features_dicts_synced(dict.clear) - def __reduce__(self): + def __reduce__(self) -> tuple[type[Features], tuple[dict]]: return Features, (dict(self),) @property - def type(self): + def type(self) -> pa.DataType: """ Features field types. @@ -1815,7 +1858,7 @@ def type(self): return get_nested_type(self) @property - def arrow_schema(self): + def arrow_schema(self) -> pa.Schema: """ Features schema. @@ -1826,7 +1869,7 @@ def arrow_schema(self): return pa.schema(self.type).with_metadata({"huggingface": json.dumps(hf_metadata)}) @classmethod - def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features": + def from_arrow_schema(cls, pa_schema: pa.Schema) -> Features: """ Construct [`Features`] from Arrow Schema. It also checks the schema metadata for Hugging Face Datasets features. @@ -1860,7 +1903,7 @@ def from_arrow_schema(cls, pa_schema: pa.Schema) -> "Features": return cls(**obj) @classmethod - def from_dict(cls, dic) -> "Features": + def from_dict(cls, dic: dict[str, Any]) -> Features: """ Construct [`Features`] from dict. @@ -1966,7 +2009,7 @@ def to_yaml_types(obj: dict) -> dict: return to_yaml_types(to_yaml_inner(yaml_data)["struct"]) @classmethod - def _from_yaml_list(cls, yaml_data: list) -> "Features": + def _from_yaml_list(cls, yaml_data: list) -> Features: yaml_data = copy.deepcopy(yaml_data) # we convert the list obtained from YAML data into the dict representation that is used for JSON dump @@ -2034,7 +2077,7 @@ def from_yaml_inner(obj: Union[dict, list]) -> Union[dict, list]: return cls.from_dict(from_yaml_inner(yaml_data)) - def encode_example(self, example): + def encode_example(self, example: dict[str, Any]): """ Encode example into a format for Arrow. @@ -2048,7 +2091,7 @@ def encode_example(self, example): example = cast_to_python_objects(example) return encode_nested_example(self, example) - def encode_column(self, column, column_name: str): + def encode_column(self, column: list, column_name: str) -> list: """ Encode column into a format for Arrow. @@ -2243,7 +2286,7 @@ def recursive_reorder(source, target, stack=""): return Features(recursive_reorder(self, other)) - def flatten(self, max_depth=16) -> "Features": + def flatten(self, max_depth: int = 16) -> Features: """Flatten the features. Every dictionary column is removed and is replaced by all the subfields it contains. The new fields are named by concatenating the name of the original column and the subfield name like this: `.`. 
@@ -2298,7 +2341,7 @@ def flatten(self, max_depth=16) -> "Features": def _align_features(features_list: list[Features]) -> list[Features]: """Align dictionaries of features so that the keys that are found in multiple dictionaries share the same feature.""" - name2feature = {} + name2feature: dict = {} for features in features_list: for k, v in features.items(): if k in name2feature and isinstance(v, dict): @@ -2310,12 +2353,12 @@ def _align_features(features_list: list[Features]) -> list[Features]: return [Features({k: name2feature[k] for k in features.keys()}) for features in features_list] -def _check_if_features_can_be_aligned(features_list: list[Features]): +def _check_if_features_can_be_aligned(features_list: list[Features]) -> None: """Check if the dictionaries of features can be aligned. Two dictonaries of features can be aligned if the keys they share have the same type or some of them is of type `Value("null")`. """ - name2feature = {} + name2feature: dict = {} for features in features_list: for k, v in features.items(): if k not in name2feature or (isinstance(name2feature[k], Value) and name2feature[k].dtype == "null"): diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index f8fc92c9d1f..13e1c30d192 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import operator from collections.abc import Iterable, Mapping, MutableMapping from functools import partial # Lint as: python3 -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Optional, TypeVar, Union import numpy as np import pandas as pd @@ -30,6 +32,8 @@ T = TypeVar("T") +KT = TypeVar("KT") +VT = TypeVar("VT") RowFormat = TypeVar("RowFormat") ColumnFormat = TypeVar("ColumnFormat") @@ -40,7 +44,7 @@ def _is_range_contiguous(key: range) -> bool: return key.step == 1 and key.stop >= key.start -def _raise_bad_key_type(key: Any): +def _raise_bad_key_type(key: object) -> None: raise TypeError( f"Wrong key type: '{key}' of type '{type(key)}'. Expected one of int, slice, range, str or Iterable." ) @@ -150,7 +154,7 @@ def extract_batch(self, pa_table: pa.Table) -> dict: class NumpyArrowExtractor(BaseArrowExtractor[dict, np.ndarray, dict]): - def __init__(self, **np_array_kwargs): + def __init__(self, **np_array_kwargs: Any) -> None: self.np_array_kwargs = np_array_kwargs def extract_row(self, pa_table: pa.Table) -> dict: @@ -215,7 +219,7 @@ def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame: class PythonFeaturesDecoder: def __init__( self, features: Optional[Features], token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None - ): + ) -> None: self.features = features self.token_per_repo_id = token_per_repo_id @@ -230,7 +234,7 @@ def decode_batch(self, batch: dict) -> dict: class PandasFeaturesDecoder: - def __init__(self, features: Optional[Features]): + def __init__(self, features: Optional[Features]) -> None: self.features = features def decode_row(self, row: pd.DataFrame) -> pd.DataFrame: @@ -260,18 +264,17 @@ def decode_column(self, column: pd.Series, column_name: str) -> pd.Series: def decode_batch(self, batch: pd.DataFrame) -> pd.DataFrame: return self.decode_row(batch) - class LazyDict(MutableMapping): """A dictionary backed by Arrow data. 
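# Sketch relating the formatter classes in this hunk to the public API: `with_format`
# selects a built-in formatter, while `set_transform` installs a CustomFormatter-backed
# transform that formats rows on access.
from datasets import Dataset

ds = Dataset.from_dict({"x": [1, 2, 3]})
np_ds = ds.with_format("numpy")
print(type(np_ds[0]["x"]))  # numpy scalar via the numpy formatter
ds.set_transform(lambda batch: {"x2": [v * 2 for v in batch["x"]]})
print(ds[0])  # {'x2': 2} -- the transform output replaces the formatted columns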
The values are formatted on-the-fly when accessing the dictionary.""" - def __init__(self, pa_table: pa.Table, formatter: "Formatter"): + def __init__(self, pa_table: pa.Table, formatter: Formatter) -> None: self.pa_table = pa_table self.formatter = formatter self.data = dict.fromkeys(pa_table.column_names) self.keys_to_format = set(self.data.keys()) - def __len__(self): + def __len__(self) -> int: return len(self.data) def __getitem__(self, key): @@ -282,27 +285,27 @@ def __getitem__(self, key): self.keys_to_format.remove(key) return value - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: if key in self.keys_to_format: self.keys_to_format.remove(key) self.data[key] = value - def __delitem__(self, key) -> None: + def __delitem__(self, key: str) -> None: if key in self.keys_to_format: self.keys_to_format.remove(key) del self.data[key] - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.data) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: return key in self.data - def __repr__(self): + def __repr__(self) -> str: self._format_all() return repr(self.data) - def __or__(self, other): + def __or__(self, other: Union[dict, LazyDict]) -> LazyDict: if isinstance(other, LazyDict): inst = self.copy() other = other.copy() @@ -317,7 +320,7 @@ def __or__(self, other): return inst return NotImplemented - def __ror__(self, other): + def __ror__(self, other: Union[dict, LazyDict]) -> LazyDict: if isinstance(other, LazyDict): inst = self.copy() other = other.copy() @@ -332,7 +335,7 @@ def __ror__(self, other): return inst return NotImplemented - def __ior__(self, other): + def __ior__(self, other: Union[dict, LazyDict]) -> LazyDict: if isinstance(other, LazyDict): other = other.copy() other._format_all() @@ -343,7 +346,7 @@ def __ior__(self, other): self.data |= other return self - def __copy__(self): + def __copy__(self) -> LazyDict: # Identical to `UserDict.__copy__` inst = self.__class__.__new__(self.__class__) inst.__dict__.update(self.__dict__) @@ -352,13 +355,13 @@ def __copy__(self): inst.__dict__["keys_to_format"] = self.__dict__["keys_to_format"].copy() return inst - def copy(self): + def copy(self) -> object: import copy return copy.copy(self) @classmethod - def fromkeys(cls, iterable, value=None): + def fromkeys(cls: type[object], iterable: Iterable, value=None): raise NotImplementedError def format(self, key): @@ -395,7 +398,7 @@ def __init__( self, features: Optional[Features] = None, token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None, - ): + ) -> None: self.features = features self.token_per_repo_id = token_per_repo_id self.python_features_decoder = PythonFeaturesDecoder(self.features, self.token_per_repo_id) @@ -420,7 +423,7 @@ def format_batch(self, pa_table: pa.Table) -> BatchFormat: class TensorFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]): - def recursive_tensorize(self, data_struct: dict): + def recursive_tensorize(self, data_struct: dict) -> None: raise NotImplementedError @@ -444,7 +447,12 @@ def format_batch(self, pa_table: pa.Table) -> pa.Table: class PythonFormatter(Formatter[Mapping, list, Mapping]): - def __init__(self, features=None, lazy=False, token_per_repo_id=None): + def __init__( + self, + features: Optional[Features] = None, + lazy: bool = False, + token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None, + ) -> None: super().__init__(features, token_per_repo_id) self.lazy = lazy @@ -498,7 +506,13 @@ class CustomFormatter(Formatter[dict, 
ColumnFormat, dict]): to return. """ - def __init__(self, transform: Callable[[dict], dict], features=None, token_per_repo_id=None, **kwargs): + def __init__( + self, + transform: Callable[[dict], dict], + features: Optional[Features] = None, + token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None, + **kwargs: Any, + ) -> None: super().__init__(features=features, token_per_repo_id=token_per_repo_id) self.transform = transform @@ -618,8 +632,8 @@ def format_table( key: Union[int, slice, range, str, Iterable], formatter: Formatter, format_columns: Optional[list] = None, - output_all_columns=False, -): + output_all_columns: bool = False, +) -> Union[Mapping, list]: """ Format a Table depending on the key that was used and a Formatter object. diff --git a/src/datasets/info.py b/src/datasets/info.py index 1217c58bb39..2022178297b 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -28,6 +28,8 @@ - etc. """ +from __future__ import annotations + import copy import dataclasses import json @@ -35,7 +37,7 @@ import posixpath from dataclasses import dataclass from pathlib import Path -from typing import ClassVar, Optional, Union +from typing import ClassVar, List, Optional, TypeVar, Union import fsspec from fsspec.core import url_to_fs @@ -77,7 +79,7 @@ class PostProcessedInfo: features: Optional[Features] = None resources_checksums: Optional[dict] = None - def __post_init__(self): + def __post_init__(self) -> None: # Convert back to the correct classes when we reload from dict if self.features is not None and not isinstance(self.features, Features): self.features = Features.from_dict(self.features) @@ -164,7 +166,7 @@ class DatasetInfo: "splits", ] - def __post_init__(self): + def __post_init__(self) -> None: # Convert back to the correct classes when we reload from dict if self.features is not None and not isinstance(self.features, Features): self.features = Features.from_dict(self.features) @@ -183,7 +185,9 @@ def __post_init__(self): else: self.supervised_keys = SupervisedKeysData(**self.supervised_keys) - def write_to_directory(self, dataset_info_dir, pretty_print=False, storage_options: Optional[dict] = None): + def write_to_directory( + self, dataset_info_dir: str, pretty_print: bool = False, storage_options: Optional[dict] = None + ) -> None: """Write `DatasetInfo` and license (if present) as JSON files to `dataset_info_dir`. 
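# Minimal sketch for the `DatasetInfo.write_to_directory` signature typed above, paired
# with `DatasetInfo.from_directory`; the directory path is a placeholder.
from datasets import Dataset
from datasets.info import DatasetInfo

ds = Dataset.from_dict({"x": [1, 2]})
ds.info.description = "toy dataset"
ds.info.write_to_directory("./my_dataset_info", pretty_print=True)
reloaded = DatasetInfo.from_directory("./my_dataset_info")
print(reloaded.description)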
Args: @@ -212,16 +216,16 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False, storage_optio with fs.open(posixpath.join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f: self._dump_license(f) - def _dump_info(self, file, pretty_print=False): + def _dump_info(self, file, pretty_print: bool = False) -> None: """Dump info in `file` file-like object open in bytes mode (to support remote files)""" file.write(json.dumps(asdict(self), indent=4 if pretty_print else None).encode("utf-8")) - def _dump_license(self, file): + def _dump_license(self, file) -> None: """Dump license in `file` file-like object open in bytes mode (to support remote files)""" file.write(self.license.encode("utf-8")) @classmethod - def from_merge(cls, dataset_infos: list["DatasetInfo"]): + def from_merge(cls, dataset_infos: list[DatasetInfo]) -> DatasetInfo: dataset_infos = [dset_info.copy() for dset_info in dataset_infos if dset_info is not None] if len(dataset_infos) > 0 and all(dataset_infos[0] == dset_info for dset_info in dataset_infos): @@ -283,7 +287,7 @@ def from_dict(cls, dataset_info_dict: dict) -> "DatasetInfo": field_names = {f.name for f in dataclasses.fields(cls)} return cls(**{k: v for k, v in dataset_info_dict.items() if k in field_names}) - def update(self, other_dataset_info: "DatasetInfo", ignore_none=True): + def update(self, other_dataset_info: "DatasetInfo", ignore_none: bool = True) -> None: self_dict = self.__dict__ self_dict.update( **{ @@ -321,8 +325,11 @@ def _from_yaml_dict(cls, yaml_data: dict) -> "DatasetInfo": return cls(**{k: v for k, v in yaml_data.items() if k in field_names}) +T_DatasetInfoDict = TypeVar("T_DatasetInfoDict", bound=DatasetInfosDict) + + class DatasetInfosDict(dict[str, DatasetInfo]): - def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=False) -> None: + def write_to_directory(self, dataset_infos_dir: str, overwrite: bool = False, pretty_print: bool = False) -> None: total_dataset_infos = {} dataset_infos_path = os.path.join(dataset_infos_dir, config.DATASETDICT_INFOS_FILENAME) dataset_readme_path = os.path.join(dataset_infos_dir, config.REPOCARD_FILENAME) @@ -351,7 +358,7 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa dataset_card.save(Path(dataset_readme_path)) @classmethod - def from_directory(cls, dataset_infos_dir) -> "DatasetInfosDict": + def from_directory(cls: type[T_DatasetInfoDict], dataset_infos_dir: str) -> T_DatasetInfoDict: logger.info(f"Loading Dataset Infos from {dataset_infos_dir}") # Load the info from the YAML part of README.md if os.path.exists(os.path.join(dataset_infos_dir, config.REPOCARD_FILENAME)): @@ -371,7 +378,7 @@ def from_directory(cls, dataset_infos_dir) -> "DatasetInfosDict": return cls() @classmethod - def from_dataset_card_data(cls, dataset_card_data: DatasetCardData) -> "DatasetInfosDict": + def from_dataset_card_data(cls: type[T_DatasetInfoDict], dataset_card_data: DatasetCardData) -> T_DatasetInfoDict: if isinstance(dataset_card_data.get("dataset_info"), (list, dict)): if isinstance(dataset_card_data["dataset_info"], list): return cls( diff --git a/src/datasets/io/abc.py b/src/datasets/io/abc.py index a1913cc20e3..ea07a2672d6 100644 --- a/src/datasets/io/abc.py +++ b/src/datasets/io/abc.py @@ -11,12 +11,12 @@ def __init__( path_or_paths: Optional[NestedDataStructureLike[PathLike]] = None, split: Optional[NamedSplit] = None, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, 
streaming: bool = False, num_proc: Optional[int] = None, **kwargs, - ): + ) -> None: self.path_or_paths = path_or_paths self.split = split if split or isinstance(path_or_paths, dict) else "train" self.features = features @@ -35,12 +35,12 @@ class AbstractDatasetInputStream(ABC): def __init__( self, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, streaming: bool = False, num_proc: Optional[int] = None, **kwargs, - ): + ) -> None: self.features = features self.cache_dir = cache_dir self.keep_in_memory = keep_in_memory @@ -49,5 +49,5 @@ def __init__( self.kwargs = kwargs @abstractmethod - def read(self) -> Union[Dataset, IterableDataset]: + def read(self) -> Union[Dataset, IterableDataset, dict[str, IterableDataset]]: pass diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py index 4ac2ea1135b..0d3cda40c89 100644 --- a/src/datasets/io/csv.py +++ b/src/datasets/io/csv.py @@ -4,7 +4,7 @@ import fsspec -from .. import Dataset, Features, NamedSplit, config +from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit, config from ..formatting import query_table from ..packaged_modules.csv.csv import Csv from ..utils import tqdm as hf_tqdm @@ -18,7 +18,7 @@ def __init__( path_or_paths: NestedDataStructureLike[PathLike], split: Optional[NamedSplit] = None, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, streaming: bool = False, num_proc: Optional[int] = None, @@ -42,7 +42,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[Dataset, DatasetDict, IterableDatasetDict, IterableDataset]: # Build iterable dataset if self.streaming: dataset = self.builder.as_streaming_dataset(split=self.split) @@ -75,7 +75,7 @@ def __init__( num_proc: Optional[int] = None, storage_options: Optional[dict] = None, **to_csv_kwargs, - ): + ) -> None: if num_proc is not None and num_proc <= 0: raise ValueError(f"num_proc {num_proc} must be an integer > 0.") diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index b10609cac23..8dd3373ad43 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -1,23 +1,33 @@ -from typing import Callable, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Optional, Union + +from datasets.dataset_dict import DatasetDict, IterableDatasetDict from .. 
import Features, NamedSplit, Split from ..packaged_modules.generator.generator import Generator from .abc import AbstractDatasetInputStream +if TYPE_CHECKING: + from ..arrow_dataset import Dataset, DatasetDict + from ..dataset_dict import IterableDatasetDict + from ..iterable_dataset import IterableDataset + + class GeneratorDatasetInputStream(AbstractDatasetInputStream): def __init__( self, generator: Callable, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, streaming: bool = False, gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, split: NamedSplit = Split.TRAIN, **kwargs, - ): + ) -> None: super().__init__( features=features, cache_dir=cache_dir, @@ -35,7 +45,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[Dataset, DatasetDict, IterableDatasetDict, IterableDataset]: # Build iterable dataset if self.streaming: dataset = self.builder.as_streaming_dataset(split=self.builder.config.split) diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py index 41abfa518cc..e346b71c14a 100644 --- a/src/datasets/io/json.py +++ b/src/datasets/io/json.py @@ -4,7 +4,7 @@ import fsspec -from .. import Dataset, Features, NamedSplit, config +from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit, config from ..formatting import query_table from ..packaged_modules.json.json import Json from ..utils import tqdm as hf_tqdm @@ -18,13 +18,13 @@ def __init__( path_or_paths: NestedDataStructureLike[PathLike], split: Optional[NamedSplit] = None, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, streaming: bool = False, field: Optional[str] = None, num_proc: Optional[int] = None, **kwargs, - ): + ) -> None: super().__init__( path_or_paths, split=split, @@ -45,7 +45,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[Dataset, DatasetDict, IterableDatasetDict, IterableDataset]: # Build iterable dataset if self.streaming: dataset = self.builder.as_streaming_dataset(split=self.split) @@ -78,7 +78,7 @@ def __init__( num_proc: Optional[int] = None, storage_options: Optional[dict] = None, **to_json_kwargs, - ): + ) -> None: if num_proc is not None and num_proc <= 0: raise ValueError(f"num_proc {num_proc} must be an integer > 0.") diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py index d34f5110204..13949f27c53 100644 --- a/src/datasets/io/parquet.py +++ b/src/datasets/io/parquet.py @@ -4,7 +4,7 @@ import fsspec import pyarrow.parquet as pq -from .. import Dataset, Features, NamedSplit, config +from .. 
import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit, config from ..arrow_writer import get_writer_batch_size from ..formatting import query_table from ..packaged_modules import _PACKAGED_DATASETS_MODULES @@ -20,7 +20,7 @@ def __init__( path_or_paths: NestedDataStructureLike[PathLike], split: Optional[NamedSplit] = None, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, streaming: bool = False, num_proc: Optional[int] = None, @@ -46,7 +46,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: # Build iterable dataset if self.streaming: dataset = self.builder.as_streaming_dataset(split=self.split) @@ -78,7 +78,7 @@ def __init__( batch_size: Optional[int] = None, storage_options: Optional[dict] = None, **parquet_writer_kwargs, - ): + ) -> None: self.dataset = dataset self.path_or_buf = path_or_buf self.batch_size = batch_size or get_writer_batch_size(dataset.features) diff --git a/src/datasets/io/spark.py b/src/datasets/io/spark.py index 7562ba1fb5f..7ce49b3ae88 100644 --- a/src/datasets/io/spark.py +++ b/src/datasets/io/spark.py @@ -1,8 +1,10 @@ -from typing import Optional +from __future__ import annotations + +from typing import Optional, Union import pyspark -from .. import Features, NamedSplit +from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit from ..download import DownloadMode from ..packaged_modules.spark.spark import Spark from .abc import AbstractDatasetReader @@ -21,13 +23,13 @@ def __init__( split: Optional[NamedSplit] = None, features: Optional[Features] = None, streaming: bool = True, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, - working_dir: str = None, + working_dir: Optional[str] = None, load_from_cache_file: bool = True, file_format: str = "arrow", **kwargs, - ): + ) -> None: super().__init__( split=split, features=features, @@ -46,7 +48,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[Dataset, DatasetDict, IterableDatasetDict, IterableDataset]: if self.streaming: return self.builder.as_streaming_dataset(split=self.split) download_mode = None if self._load_from_cache_file else DownloadMode.FORCE_REDOWNLOAD diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 2331e3e6407..248f6fd76eb 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import multiprocessing from typing import TYPE_CHECKING, Optional, Union @@ -13,6 +15,8 @@ import sqlalchemy + from .. 
import DatasetDict + class SqlDatasetReader(AbstractDatasetInputStream): def __init__( @@ -20,10 +24,10 @@ def __init__( sql: Union[str, "sqlalchemy.sql.Selectable"], con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, **kwargs, - ): + ) -> None: super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs) self.builder = Sql( cache_dir=cache_dir, @@ -33,7 +37,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[Dataset, DatasetDict]: download_config = None download_mode = None verification_mode = None @@ -58,11 +62,11 @@ def __init__( self, dataset: Dataset, name: str, - con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], + con: Union[str, sqlalchemy.engine.Connection, sqlalchemy.engine.Engine, sqlite3.Connection], batch_size: Optional[int] = None, num_proc: Optional[int] = None, **to_sql_kwargs, - ): + ) -> None: if num_proc is not None and num_proc <= 0: raise ValueError(f"num_proc {num_proc} must be an integer > 0.") @@ -81,7 +85,7 @@ def write(self) -> int: written = self._write(index=index, **self.to_sql_kwargs) return written - def _batch_sql(self, args): + def _batch_sql(self, args) -> int: offset, index, to_sql_kwargs = args to_sql_kwargs = {**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs batch = query_table( @@ -93,7 +97,7 @@ def _batch_sql(self, args): num_rows = df.to_sql(self.name, self.con, index=index, **to_sql_kwargs) return num_rows or len(df) - def _write(self, index, **to_sql_kwargs) -> int: + def _write(self, index: bool, **to_sql_kwargs) -> int: """Writes the pyarrow table as SQL to a database. Caller is responsible for opening and closing the SQL connection. diff --git a/src/datasets/io/text.py b/src/datasets/io/text.py index 58963f3c7ab..289380b4db1 100644 --- a/src/datasets/io/text.py +++ b/src/datasets/io/text.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import Optional, Union -from .. import Features, NamedSplit +from .. 
import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit from ..packaged_modules.text.text import Text from ..utils.typing import NestedDataStructureLike, PathLike from .abc import AbstractDatasetReader @@ -12,12 +12,12 @@ def __init__( path_or_paths: NestedDataStructureLike[PathLike], split: Optional[NamedSplit] = None, features: Optional[Features] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, keep_in_memory: bool = False, streaming: bool = False, num_proc: Optional[int] = None, **kwargs, - ): + ) -> None: super().__init__( path_or_paths, split=split, @@ -36,7 +36,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self) -> Union[IterableDatasetDict, IterableDataset, Dataset, DatasetDict]: # Build iterable dataset if self.streaming: dataset = self.builder.as_streaming_dataset(split=self.split) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 80e4b6b7292..01a18dd0054 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import copy import inspect @@ -10,7 +12,19 @@ from dataclasses import dataclass from functools import partial from itertools import cycle, islice -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Iterator, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, +) import fsspec.asyn import numpy as np @@ -18,7 +32,7 @@ import pyarrow as pa from . import config -from .arrow_dataset import Dataset, DatasetInfoMixin +from .arrow_dataset import DatasetInfoMixin from .features import Features from .features.features import ( FeatureType, @@ -45,18 +59,20 @@ if TYPE_CHECKING: + import pyspark import torch + from .arrow_dataset import Dataset logger = get_logger(__name__) Key = Union[int, str] -def identity_func(x): +def identity_func(x: object) -> object: return x -def _rename_columns_fn(example: dict, column_mapping: dict[str, str]): +def _rename_columns_fn(example: Sequence, column_mapping: Mapping) -> dict[str, Any]: if any(col not in example for col in column_mapping): raise ValueError( f"Error when renaming {list(column_mapping)} to {list(column_mapping.values())}: columns {set(column_mapping) - set(example)} are not in the dataset." 
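# Sketch tying the io reader classes above to their public entry point: the `streaming`
# flag decides between a materialised DatasetDict and an IterableDatasetDict.
# The CSV path is a placeholder.
from datasets import load_dataset

ds_dict = load_dataset("csv", data_files={"train": "train.csv"})                   # DatasetDict
streamed = load_dataset("csv", data_files={"train": "train.csv"}, streaming=True)  # IterableDatasetDict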
@@ -71,13 +87,13 @@ def _rename_columns_fn(example: dict, column_mapping: dict[str, str]): } -def add_column_fn(example: dict, idx: int, name: str, column: list[dict]): +def add_column_fn(example: Sequence, idx: int, name: str, column: Sequence[Mapping]) -> dict[str, Mapping]: if name in example: raise ValueError(f"Error when adding {name}: column {name} is already in the dataset.") return {name: column[idx]} -def _infer_features_from_batch(batch: dict[str, list], try_features: Optional[Features] = None) -> Features: +def _infer_features_from_batch(batch: Mapping, try_features: Optional[Features] = None) -> Features: pa_table = pa.Table.from_pydict(batch) if try_features is not None: try: @@ -87,7 +103,7 @@ def _infer_features_from_batch(batch: dict[str, list], try_features: Optional[Fe return Features.from_arrow_schema(pa_table.schema) -def _examples_to_batch(examples: list[dict[str, Any]]) -> dict[str, list]: +def _examples_to_batch(examples: Iterable[Mapping[str, Sequence]]) -> dict: # we order the columns by order of appearance # to do so, we use a dict as an ordered set cols = {col: None for example in examples for col in example} @@ -96,7 +112,7 @@ def _examples_to_batch(examples: list[dict[str, Any]]) -> dict[str, list]: return dict(zip(cols, arrays)) -def _batch_to_examples(batch: dict[str, list]) -> Iterator[dict[str, Any]]: +def _batch_to_examples(batch: Mapping[str, Sequence]) -> Iterator[dict[str, Sequence]]: """Convert a batch (dict of examples) to examples list""" n_examples = 0 if len(batch) == 0 else len(batch[next(iter(batch))]) for i in range(n_examples): @@ -111,7 +127,7 @@ def _convert_to_arrow( """Convert and group examples in Arrow tables of size `batch_size`. Args: - iterable (`Iterable[Tuple[Key, dict]]`): + iterable (`Iterable[tuple[Key, dict]]`): An examples iterable containing tuples (example_key, example) of type (int/str, dict) batch_size (`Optional[int]`): Size of each sub-table to yield. If None or <= 0, yields the full table. 
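# Pure-Python sketch of the batch <-> examples conversion performed by
# `_examples_to_batch` and `_batch_to_examples` above (dict of columns vs. list of rows).
def examples_to_batch(examples):
    cols = {col: None for example in examples for col in example}  # ordered set of columns
    return {col: [example.get(col) for example in examples] for col in cols}

def batch_to_examples(batch):
    n_examples = 0 if len(batch) == 0 else len(batch[next(iter(batch))])
    return [{col: batch[col][i] for col in batch} for i in range(n_examples)]

batch = examples_to_batch([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
assert batch == {"a": [1, 3], "b": [2, 4]}
assert batch_to_examples(batch) == [{"a": 1, "b": 2}, {"a": 3, "b": 4}]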
@@ -139,9 +155,9 @@ class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" def __init__(self) -> None: - self._state_dict: Optional[Union[list, dict]] = None + self._state_dict: Optional[MutableMapping[str, Any]] = None - def __iter__(self) -> Iterator[tuple[Key, dict]]: + def __iter__(self) -> Generator[tuple[Key, pa.Table]]: """An examples iterable should yield tuples (example_key, example) of type (int/str, dict)""" raise NotImplementedError(f"{type(self)} doesn't implement __iter__ yet") @@ -164,11 +180,11 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples """ raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet") - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "_BaseExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> "_BaseExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet") - def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous=True) -> list[int]: + def split_shard_indices_by_worker(self, num_shards: int, index: int, contiguous: bool = True) -> list[int]: if contiguous: div = self.num_shards // num_shards mod = self.num_shards % num_shards @@ -185,8 +201,8 @@ def num_shards(self) -> int: def _init_state_dict(self) -> dict: raise NotImplementedError(f"{type(self)} doesn't implement _init_state_dict yet") - def load_state_dict(self, state_dict: dict) -> dict: - def _inner_load_state_dict(state, new_state): + def load_state_dict(self, state_dict: Optional[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]: + def _inner_load_state_dict(state: Optional[Mapping], new_state: Optional[Mapping]) -> Optional[Mapping]: if new_state is not None and isinstance(state, dict): for key in new_state: state[key] = _inner_load_state_dict(state[key], new_state[key]) @@ -199,14 +215,14 @@ def _inner_load_state_dict(state, new_state): return _inner_load_state_dict(self._state_dict, state_dict) - def state_dict(self) -> dict: + def state_dict(self) -> Mapping[str, Any]: if self._state_dict: return copy.deepcopy(self._state_dict) raise RuntimeError("State dict is not initialized, please call ex_iterable._init_state_dict() first.") class ExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_examples_fn: Callable[..., tuple[Key, dict]], kwargs: dict): + def __init__(self, generate_examples_fn: Callable[..., tuple[Key, dict]], kwargs: dict) -> None: super().__init__() self.generate_examples_fn = generate_examples_fn self.kwargs = kwargs @@ -215,7 +231,7 @@ def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0 for gen_kwags in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None): shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0 @@ -230,7 +246,7 @@ def __iter__(self): def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable": return ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, generator) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> 
"ExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> "ExamplesIterable": """Keep only the requested shard.""" gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) @@ -245,7 +261,7 @@ def num_shards(self) -> int: class ShuffledDataSourcesExamplesIterable(ExamplesIterable): def __init__( self, generate_examples_fn: Callable[..., tuple[Key, dict]], kwargs: dict, generator: np.random.Generator - ): + ) -> None: super().__init__(generate_examples_fn, kwargs) self.generator = deepcopy(generator) @@ -253,7 +269,7 @@ def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: """Shuffle the kwargs order to shuffle shards""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) @@ -270,7 +286,7 @@ def __iter__(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> "ExamplesIterable": """Keep only the requested shard.""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) @@ -280,7 +296,7 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "E class ArrowExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_tables_fn: Callable[..., tuple[Key, pa.Table]], kwargs: dict): + def __init__(self, generate_tables_fn: Callable[..., tuple[Key, pa.Table]], kwargs: dict) -> None: super().__init__() self.generate_tables_fn = generate_tables_fn self.kwargs = kwargs @@ -331,10 +347,10 @@ def _iter_arrow(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shuffle_data_sources(self, generator: np.random.Generator) -> "ArrowExamplesIterable": + def shuffle_data_sources(self, generator: np.random.Generator) -> ArrowExamplesIterable: return ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> ArrowExamplesIterable: """Keep only the requested shard.""" gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards) shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous) @@ -352,7 +368,7 @@ def __init__( generate_tables_fn: Callable[..., tuple[Key, pa.Table]], kwargs: dict, generator: np.random.Generator, - ): + ) -> None: super().__init__(generate_tables_fn, kwargs) self.generator = deepcopy(generator) @@ -360,7 +376,7 @@ def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict - def __iter__(self): + def __iter__(self) -> Generator: """Shuffle the kwargs order to shuffle shards""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) @@ -407,7 +423,7 @@ def _iter_arrow(self): self._state_dict["shard_idx"] += 1 self._state_dict["shard_example_idx"] = 0 - def shard_data_sources(self, num_shards: int, 
index: int, contiguous=True) -> "ArrowExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> ArrowExamplesIterable: """Keep only the requested shard.""" rng = deepcopy(self.generator) kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) @@ -417,7 +433,9 @@ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "A class RebatchedArrowExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int], drop_last_batch: bool = False): + def __init__( + self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int], drop_last_batch: bool = False + ) -> None: super().__init__() self.ex_iterable = ex_iterable self.batch_size = batch_size @@ -428,11 +446,11 @@ def iter_arrow(self): return self._iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterable.features def _init_state_dict(self) -> dict: @@ -540,7 +558,9 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RebatchedArro self.ex_iterable.shuffle_data_sources(generator), self.batch_size, self.drop_last_batch ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RebatchedArrowExamplesIterable": + def shard_data_sources( + self, num_shards: int, index: int, contiguous: bool = True + ) -> "RebatchedArrowExamplesIterable": return RebatchedArrowExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.batch_size, @@ -553,7 +573,7 @@ def num_shards(self) -> int: class SelectColumnsIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: list[str]): + def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: list[str]) -> None: super().__init__() self.ex_iterable = ex_iterable self.column_names = column_names @@ -564,11 +584,11 @@ def iter_arrow(self): return self._iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterable.features def _init_state_dict(self) -> dict: @@ -587,7 +607,7 @@ def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]: def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable": return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SelectColumnsIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> "SelectColumnsIterable": return SelectColumnsIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), self.column_names ) @@ -598,7 +618,7 @@ def num_shards(self) -> int: class StepExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int): + def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int) -> None: super().__init__() self.ex_iterable = ex_iterable self.step = step @@ -606,18 +626,18 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int): # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed @property - def features(self): + def 
features(self) -> Optional[Features]: return self.ex_iterable.features def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: ex_iterator = iter(self.ex_iterable) while True: batch = list(islice(ex_iterator, self.step)) @@ -631,7 +651,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesI self.ex_iterable.shuffle_data_sources(generator), step=self.step, offset=self.offset ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "StepExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> "StepExamplesIterable": return StepExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), step=self.step, @@ -648,7 +668,7 @@ def __init__( self, ex_iterables: list[_BaseExamplesIterable], stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", - ): + ) -> None: super().__init__() self.ex_iterables = ex_iterables self.stopping_strategy = stopping_strategy @@ -659,14 +679,14 @@ def __init__( # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterables[0].is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterables[0].features - def _get_indices_iterator(self): + def _get_indices_iterator(self) -> Iterator: # this is an infinite iterator to keep track of which iterator we want to pick examples from ex_iterable_idx = self._state_dict["ex_iterable_idx"] if self._state_dict else 0 for next_ex_iterable_idx in islice(cycle(range(len(self.ex_iterables))), ex_iterable_idx + 1, None): @@ -685,7 +705,7 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: # we use this to buffer one example of each iterator to know if an iterator is exhausted nexts = [None] * len(self.ex_iterables) # because of that, we need to rewind 1 example when reloading the state dict @@ -737,7 +757,7 @@ def num_shards(self) -> int: return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) def shard_data_sources( - self, num_shards: int, index: int, contiguous=True + self, num_shards: int, index: int, contiguous: bool = True ) -> "CyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return CyclingMultiSourcesExamplesIterable( @@ -759,16 +779,16 @@ class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): This is done with `_apply_feature_types_on_example`. 
""" - def __init__(self, ex_iterables: list[_BaseExamplesIterable]): + def __init__(self, ex_iterables: list[_BaseExamplesIterable]) -> None: super().__init__() self.ex_iterables = ex_iterables @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterables[0].is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterables[0].features @property @@ -784,14 +804,14 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: ex_iterable_idx_start = self._state_dict["ex_iterable_idx"] if self._state_dict else 0 for ex_iterable in islice(self.ex_iterables, ex_iterable_idx_start, None): yield from ex_iterable if self._state_dict: self._state_dict["ex_iterable_idx"] += 1 - def _iter_arrow(self): + def _iter_arrow(self) -> Iterator: ex_iterable_idx_start = self._state_dict["ex_iterable_idx"] if self._state_dict else 0 for ex_iterable in islice(self.ex_iterables, ex_iterable_idx_start, None): yield from ex_iterable.iter_arrow() @@ -813,7 +833,7 @@ def num_shards(self) -> int: return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables) def shard_data_sources( - self, num_shards: int, index: int, contiguous=True + self, num_shards: int, index: int, contiguous: bool = True ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return VerticallyConcatenatedMultiSourcesExamplesIterable( @@ -847,17 +867,17 @@ class HorizontallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable This is done with `_apply_feature_types_on_example`. """ - def __init__(self, ex_iterables: list[_BaseExamplesIterable]): + def __init__(self, ex_iterables: list[_BaseExamplesIterable]) -> None: super().__init__() self.ex_iterables = ex_iterables # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterables[0].is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterables[0].features def _init_state_dict(self) -> dict: @@ -867,7 +887,7 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: ex_iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables] for i in itertools.count(): keys = [] @@ -901,7 +921,7 @@ def num_shards(self) -> int: return 1 def shard_data_sources( - self, num_shards: int, index: int, contiguous=True + self, num_shards: int, index: int, contiguous: bool = True ) -> "HorizontallyConcatenatedMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return HorizontallyConcatenatedMultiSourcesExamplesIterable( @@ -916,21 +936,21 @@ def __init__( generator: np.random.Generator, probabilities: Optional[list[float]] = None, stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", - ): + ) -> None: super().__init__(ex_iterables, stopping_strategy) self.generator = deepcopy(generator) self.probabilities = probabilities # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterables[0].is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterables[0].features - def _get_indices_iterator(self): + def _get_indices_iterator(self) -> Generator[int]: rng = deepcopy(self.generator) num_sources = 
len(self.ex_iterables) random_batch_size = 1000 @@ -981,7 +1001,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RandomlyCycli ) def shard_data_sources( - self, num_shards: int, index: int, contiguous=True + self, num_shards: int, index: int, contiguous: bool = True ) -> "RandomlyCyclingMultiSourcesExamplesIterable": """Either keep only the requested shard, or propagate the request to the underlying iterable.""" return RandomlyCyclingMultiSourcesExamplesIterable( @@ -992,7 +1012,9 @@ def shard_data_sources( ) -def _table_output_to_arrow(output) -> pa.Table: +def _table_output_to_arrow( + output, +) -> Union[pa.Array, pa.Table]: if isinstance(output, pa.Table): return output if isinstance(output, (pd.DataFrame, pd.Series)): @@ -1013,14 +1035,14 @@ def __init__( with_indices: bool = False, input_columns: Optional[list[str]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, drop_last_batch: bool = False, remove_columns: Optional[list[str]] = None, fn_kwargs: Optional[dict] = None, formatting: Optional["FormattingConfig"] = None, features: Optional[Features] = None, max_num_running_async_map_functions_in_parallel: Optional[int] = None, - ): + ) -> None: super().__init__() self.ex_iterable = ex_iterable self.function = function @@ -1058,11 +1080,11 @@ def iter_arrow(self): return self._iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.features is not None # user has extracted features @property - def features(self): + def features(self) -> Optional[Features]: return self._features def _init_state_dict(self) -> dict: @@ -1075,7 +1097,7 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> Generator[tuple]: if self.formatting and self.formatting.is_table: formatter = PythonFormatter() for key, pa_table in self._iter_arrow(max_chunksize=1): @@ -1083,7 +1105,7 @@ def __iter__(self): else: yield from self._iter() - def _iter(self): + def _iter(self) -> Generator[tuple[int, tuple[str, dict]]]: current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0 if self._state_dict and self._state_dict["previous_state"]: self.ex_iterable.load_state_dict(self._state_dict["previous_state"]) @@ -1101,7 +1123,7 @@ def _iter(self): else: format_dict = None - def iter_batched_inputs(): + def iter_batched_inputs() -> Generator[tuple[list[int], tuple[str, Optional[dict]]]]: nonlocal current_idx for key, example in iterator: # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset @@ -1128,7 +1150,7 @@ def iter_batched_inputs(): current_idx += len(indices) yield indices, (key, batch) - def iter_inputs(): + def iter_inputs() -> Generator[tuple[int, tuple[Key, Optional[dict]]]]: nonlocal current_idx for key, example in iterator: # If not batched, we can apply the transform and yield the example directly @@ -1138,7 +1160,7 @@ def iter_inputs(): current_idx += 1 yield current_idx - 1, (key, example) - def validate_function_output(processed_inputs): + def validate_function_output(processed_inputs: MutableMapping) -> None: if self.batched and processed_inputs: first_col = next(iter(processed_inputs)) bad_cols = [ @@ -1150,7 +1172,7 @@ def validate_function_output(processed_inputs): f"while {first_col} has length {len(processed_inputs[first_col])}." 
) - def prepare_inputs(key_example, indices): + def prepare_inputs(key_example: Iterable, indices: Any) -> tuple[dict, list, tuple[()], Optional[dict]]: key, example = key_example fn_args = [example] if self.input_columns is None else [example[col] for col in self.input_columns] additional_args = () @@ -1159,7 +1181,7 @@ def prepare_inputs(key_example, indices): inputs = dict(example) return inputs, fn_args, additional_args, self.fn_kwargs - def prepare_outputs(key_example, inputs, processed_inputs): + def prepare_outputs(key_example: Mapping, inputs: MutableMapping, processed_inputs: MutableMapping) -> dict: validate_function_output(processed_inputs) # this logic mimics the one in Dataset.map if self.remove_columns: @@ -1172,13 +1194,13 @@ def prepare_outputs(key_example, inputs, processed_inputs): # no need to do features decoding here return transformed_inputs - def apply_function(key_example, indices): + def apply_function(key_example: Mapping, indices: Iterable[int]) -> dict: """Utility to apply the function on a selection of columns.""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices) processed_inputs = self.function(*fn_args, *additional_args, **fn_kwargs) return prepare_outputs(key_example, inputs, processed_inputs) - async def async_apply_function(key_example, indices): + async def async_apply_function(key_example: Mapping, indices: Iterable[int]) -> dict: """Utility to apply the function on a selection of columns. Same code but async""" inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices) processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs) @@ -1350,7 +1372,7 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[tuple[Key self._state_dict["num_examples_since_previous_state"] = 0 self._state_dict["previous_state_example_idx"] += len(pa_table) - def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExamplesIterable": + def shuffle_data_sources(self, generator: np.random.Generator) -> MappedExamplesIterable: """Shuffle the wrapped examples iterable.""" return MappedExamplesIterable( self.ex_iterable.shuffle_data_sources(generator), @@ -1367,7 +1389,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExample max_num_running_async_map_functions_in_parallel=self.max_num_running_async_map_functions_in_parallel, ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "MappedExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> MappedExamplesIterable: """Keep only the requested shard.""" return MappedExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1393,7 +1415,7 @@ def _add_mask( input: Union[dict, pa.Table], mask: Union[bool, list, pa.Array, pa.ChunkedArray, pa.BooleanScalar], mask_column_name: str, -): +) -> Union[pa.Table, dict[str, Union[bool, list, pa.Array, pa.ChunkedArray, pa.BooleanScalar]]]: if isinstance(input, pa.Table): if not isinstance(mask, (list, pa.Array, pa.ChunkedArray)): mask = pa.array([mask], type=pa.bool_()) @@ -1402,14 +1424,16 @@ def _add_mask( return {mask_column_name: mask} -def add_mask(mask_function: Callable, input: Union[dict, pa.Table], *args, mask_column_name: str, **kwargs): +def add_mask( + mask_function: Callable, input: Union[dict, pa.Table], *args: Any, mask_column_name: str, **kwargs: Any +) -> Union[pa.Table, dict[str, Union[bool, list, pa.Array, pa.ChunkedArray, 
pa.BooleanScalar]]]: mask = mask_function(input, *args, **kwargs) return _add_mask(input, mask, mask_column_name) async def async_add_mask( - mask_function: Callable, input: Union[dict, pa.Table], *args, mask_column_name: str, **kwargs -): + mask_function: Callable, input: Union[dict, pa.Table], *args: Any, mask_column_name: str, **kwargs: Any +)-> Union[pa.Table, dict[str, Union[bool, list, pa.Array, pa.ChunkedArray, pa.BooleanScalar]]]: mask = await mask_function(input, *args, **kwargs) return _add_mask(input, mask, mask_column_name) @@ -1424,10 +1448,10 @@ def __init__( with_indices: bool = False, input_columns: Optional[list[str]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, fn_kwargs: Optional[dict] = None, - formatting: Optional["FormattingConfig"] = None, - ): + formatting: Optional[FormattingConfig] = None, + ) -> None: self.mask_function = function if ex_iterable.is_typed: features = Features({**ex_iterable.features, self.mask_column_name: Value("bool")}) @@ -1449,7 +1473,7 @@ def __init__( features=features, ) - def _iter(self): + def _iter(self) -> Generator[tuple[Any, dict]]: for key, example in super()._iter(): example = dict(example) if example.pop(self.mask_column_name): @@ -1460,7 +1484,7 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None): mask = pa_table[self.mask_column_name] yield key, pa_table.drop(self.mask_column_name).filter(mask) - def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable": + def shuffle_data_sources(self, seed: Optional[int]) -> FilteredExamplesIterable: """Shuffle the wrapped examples iterable.""" return FilteredExamplesIterable( self.ex_iterable.shuffle_data_sources(seed), @@ -1473,7 +1497,7 @@ def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable formatting=self.formatting, ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FilteredExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> FilteredExamplesIterable: """Keep only the requested shard.""" return FilteredExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1492,7 +1516,7 @@ def num_shards(self) -> int: class BufferShuffledExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator): + def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator) -> None: super().__init__() self.ex_iterable = ex_iterable self.buffer_size = buffer_size @@ -1500,11 +1524,11 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterable.features def _init_state_dict(self) -> dict: @@ -1512,7 +1536,7 @@ def _init_state_dict(self) -> dict: self._original_state_dict = self.state_dict() return self._state_dict - def load_state_dict(self, state_dict: dict) -> dict: + def load_state_dict(self, state_dict: Optional[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]: if self._state_dict: if state_dict != self._original_state_dict: logger.warning( @@ -1522,11 +1546,13 @@ def load_state_dict(self, state_dict: dict) -> dict: return super().load_state_dict(state_dict) @staticmethod - def 
_iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]: + def _iter_random_indices( + rng: np.random.Generator, buffer_size: int, random_batch_size: int = 1000 + ) -> Generator[int, None, None]: while True: yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size)) - def __iter__(self): + def __iter__(self) -> Generator[list]: buffer_size = self.buffer_size rng = deepcopy(self.generator) indices_iterator = self._iter_random_indices(rng, buffer_size) @@ -1549,7 +1575,9 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffle self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "BufferShuffledExamplesIterable": + def shard_data_sources( + self, num_shards: int, index: int, contiguous: bool = True + ) -> BufferShuffledExamplesIterable: """Keep only the requested shard.""" return BufferShuffledExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1569,7 +1597,7 @@ def __init__( n: int, block_sources_order_when_shuffling: bool = True, split_when_sharding: bool = True, - ): + ) -> None: super().__init__() self.ex_iterable = ex_iterable self.n = n @@ -1578,14 +1606,14 @@ def __init__( # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterable.features - def _init_state_dict(self) -> dict: + def _init_state_dict(self) -> dict[str, Union[bool, dict]]: self._state_dict = { "skipped": False, "examples_iterable": self.ex_iterable._init_state_dict(), @@ -1593,14 +1621,14 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> Generator[tuple[Key, dict]]: ex_iterable_idx_start = 0 if self._state_dict and self._state_dict["skipped"] else self.n if self._state_dict: self._state_dict["skipped"] = True yield from islice(self.ex_iterable, ex_iterable_idx_start, None) @staticmethod - def split_number(num, n): + def split_number(num: int, n: int) -> list[int]: quotient = num // n remainder = num % n result = [quotient] * n @@ -1608,7 +1636,7 @@ def split_number(num, n): result[i] += 1 return result - def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesIterable": + def shuffle_data_sources(self, generator: np.random.Generator) -> SkipExamplesIterable: """May not shuffle the wrapped examples iterable since it would skip examples from other shards instead.""" if self.block_sources_order_when_shuffling: return self @@ -1620,7 +1648,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "SkipExamplesI split_when_sharding=self.split_when_sharding, ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SkipExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> SkipExamplesIterable: """Keep only the requested shard.""" if self.split_when_sharding: return SkipExamplesIterable( @@ -1646,7 +1674,7 @@ def __init__( self, ex_iterable: _BaseExamplesIterable, num_times: Optional[int], - ): + ) -> None: super().__init__() self.ex_iterable = ex_iterable self.num_times = num_times @@ -1659,7 +1687,7 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> 
Generator: repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0 while True: if self.num_times is not None and repeat_index >= max(self.num_times, 0): @@ -1674,7 +1702,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExample """Shuffle the underlying iterable, then repeat.""" return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times) - def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable": + def shard_data_sources(self, worker_id: int, num_workers: int) -> RepeatExamplesIterable: """Shard, then repeat shards.""" return RepeatExamplesIterable( self.ex_iterable.shard_data_sources(worker_id, num_workers), @@ -1693,7 +1721,7 @@ def __init__( n: int, block_sources_order_when_shuffling: bool = True, split_when_sharding: bool = True, - ): + ) -> None: super().__init__() self.ex_iterable = ex_iterable self.n = n @@ -1702,11 +1730,11 @@ def __init__( # TODO(QL): implement iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed @property - def features(self): + def features(self) -> Optional[Features]: return self.ex_iterable.features def _init_state_dict(self) -> dict: @@ -1717,7 +1745,7 @@ def _init_state_dict(self) -> dict: } return self._state_dict - def __iter__(self): + def __iter__(self) -> Iterator: ex_iterable_num_taken = self._state_dict["num_taken"] if self._state_dict else 0 for key_example in islice(self.ex_iterable, self.n - ex_iterable_num_taken): if self._state_dict: @@ -1725,7 +1753,7 @@ def __iter__(self): yield key_example @staticmethod - def split_number(num, n): + def split_number(num: int, n: int) -> list[int]: quotient = num // n remainder = num % n result = [quotient] * n @@ -1745,7 +1773,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "TakeExamplesI split_when_sharding=self.split_when_sharding, ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "TakeExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> "TakeExamplesIterable": """Keep only the requested shard.""" if self.split_when_sharding: return TakeExamplesIterable( @@ -1783,7 +1811,7 @@ def _apply_feature_types_on_example( def _apply_feature_types_on_batch( - batch: dict, features: Features, token_per_repo_id: dict[str, Union[str, bool, None]] + batch: Mapping, features: Features, token_per_repo_id: dict[str, Union[str, bool, None]] ) -> dict: batch = dict(batch) # add missing columns @@ -1818,7 +1846,7 @@ def __init__( formatting: Optional[FormattingConfig], features: Optional[Features], token_per_repo_id: dict[str, Union[str, bool, None]], - ): + ) -> None: super().__init__() self.ex_iterable = ex_iterable self._features = features @@ -1831,18 +1859,18 @@ def iter_arrow(self): return self._iter_arrow @property - def is_typed(self): + def is_typed(self) -> bool: return self.ex_iterable.is_typed or self._features is not None @property - def features(self): + def features(self) -> Optional[Features]: return self._features def _init_state_dict(self) -> dict: self._state_dict = self.ex_iterable._init_state_dict() return self._state_dict - def __iter__(self): + def __iter__(self) -> Generator[tuple[Key, Optional[dict]]]: if not self.formatting or self.formatting.is_table: formatter = PythonFormatter(features=self._features if not self.ex_iterable.is_typed else None) else: @@ -1873,7 +1901,7 @@ def __iter__(self): example = 
format_dict(example) yield key, example - def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]: + def _iter_arrow(self) -> Generator[tuple[Key, pa.Table]]: if not self.features: yield from self.ex_iterable._iter_arrow() for key, pa_table in self.ex_iterable._iter_arrow(): @@ -1897,7 +1925,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "FormattedExam formatting=self.formatting, ) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "FormattedExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> FormattedExamplesIterable: """Keep only the requested shard.""" return FormattedExamplesIterable( self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous), @@ -1923,7 +1951,7 @@ class DistributedConfig: world_size: int -def _maybe_add_torch_iterable_dataset_parent_class(cls): +def _maybe_add_torch_iterable_dataset_parent_class(cls: type) -> None: """Add torch.utils.data.IterableDataset as a parent class if 'torch' is available""" if config.TORCH_AVAILABLE: import torch.utils.data @@ -1956,7 +1984,7 @@ def __init__( shuffling: Optional[ShufflingConfig] = None, distributed: Optional[DistributedConfig] = None, token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None, - ): + ) -> None: if distributed and distributed.world_size > 1 and shuffling and shuffling._original_seed is None: raise RuntimeError( "The dataset doesn't have a fixed random seed across nodes to shuffle and split the list of dataset shards by node. " @@ -2079,27 +2107,27 @@ def load_state_dict(self, state_dict: dict) -> None: """ self._starting_state_dict = state_dict - def __repr__(self): + def __repr__(self) -> str: return f"IterableDataset({{\n features: {list(self._info.features.keys()) if self._info.features is not None else 'Unknown'},\n num_shards: {self.num_shards}\n}})" - def __getstate__(self): + def __getstate__(self) -> dict: return self.__dict__ - def __setstate__(self, d): + def __setstate__(self, d: dict) -> None: self.__dict__ = d # Re-add torch shared memory, since shared memory is not always kept when pickling self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch) # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling _maybe_add_torch_iterable_dataset_parent_class(self.__class__) - def _head(self, n=5): + def _head(self, n: int = 5) -> Mapping: return next(iter(self.iter(batch_size=n))) @property def epoch(self) -> int: return int(self._epoch) - def _effective_generator(self): + def _effective_generator(self) -> np.random.Generator: if self._shuffling and self.epoch == 0: return self._shuffling.generator elif self._shuffling: @@ -2120,7 +2148,7 @@ def num_shards(self) -> int: def n_shards(self) -> int: # backward compatibility return self.num_shards - def _iter_pytorch(self): + def _iter_pytorch(self) -> Generator[Optional[dict[str, Any]]]: ex_iterable = self._prepare_ex_iterable_for_iteration() # Fix for fsspec when using multiprocess to avoid hanging in the ML training loop. (only required for fsspec >= 0.9.0) # See https://github.com/fsspec/gcsfs/issues/379 @@ -2179,7 +2207,7 @@ def _iter_pytorch(self): f"{_log_prefix}dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({ex_iterable.num_shards}<{worker_info.num_workers})." 
) - def _is_main_process(self): + def _is_main_process(self) -> bool: if self._distributed and self._distributed.rank > 0: return False if "torch" in sys.modules: @@ -2312,7 +2340,7 @@ def from_generator( features: Optional[Features] = None, gen_kwargs: Optional[dict] = None, split: NamedSplit = Split.TRAIN, - ) -> "IterableDataset": + ) -> IterableDataset: """Create an Iterable Dataset from a generator. Args: @@ -2363,11 +2391,11 @@ def from_generator( @staticmethod def from_spark( - df: "pyspark.sql.DataFrame", + df: pyspark.sql.DataFrame, split: Optional[NamedSplit] = None, features: Optional[Features] = None, **kwargs, - ) -> "IterableDataset": + ) -> IterableDataset: """Create an IterableDataset from Spark DataFrame. The dataset is streamed to the driver in batches. Args: @@ -2484,7 +2512,7 @@ def map( with_indices: bool = False, input_columns: Optional[Union[str, list[str]]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, drop_last_batch: bool = False, remove_columns: Optional[Union[str, list[str]]] = None, features: Optional[Features] = None, @@ -2513,10 +2541,10 @@ def map( Function applied on-the-fly on the examples when you iterate on the dataset. It must have one of the following signatures: - - `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` - - `function(example: Dict[str, Any], idx: int) -> Dict[str, Any]` if `batched=False` and `with_indices=True` - - `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` - - `function(batch: Dict[str, List], indices: List[int]) -> Dict[str, List]` if `batched=True` and `with_indices=True` + - `function(example: dict[str, Any]) -> dict[str, Any]` if `batched=False` and `with_indices=False` + - `function(example: dict[str, Any], idx: int) -> dict[str, Any]` if `batched=False` and `with_indices=True` + - `function(batch: dict[str, list]) -> dict[str, list]` if `batched=True` and `with_indices=False` + - `function(batch: dict[str, list], indices: list[int]) -> dict[str, list]` if `batched=True` and `with_indices=True` For advanced usage, the function can also return a `pyarrow.Table`. If the function is asynchronous, then `map` will run your function in parallel. @@ -2524,7 +2552,7 @@ def map( If no function is provided, default to identity function: `lambda x: x`. with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx[, rank]): ...`. - input_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): + input_columns (`Optional[Union[str, list[str]]]`, defaults to `None`): The columns to be passed into `function` as positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument. batched (`bool`, defaults to `False`): @@ -2535,13 +2563,13 @@ def map( drop_last_batch (`bool`, defaults to `False`): Whether a last batch smaller than the batch_size should be dropped instead of being processed by the function. - remove_columns (`[List[str]]`, *optional*, defaults to `None`): + remove_columns (`[list[str]]`, *optional*, defaults to `None`): Remove a selection of columns while doing the mapping. Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding columns with names in `remove_columns`, these columns will be kept. 
features (`[Features]`, *optional*, defaults to `None`): Feature types of the resulting dataset. - fn_kwargs (`Dict`, *optional*, default `None`): + fn_kwargs (`dict`, *optional*, default `None`): Keyword arguments to be passed to `function`. Example: @@ -2631,10 +2659,10 @@ def map( def filter( self, function: Optional[Callable] = None, - with_indices=False, + with_indices: bool = False, input_columns: Optional[Union[str, list[str]]] = None, batched: bool = False, - batch_size: Optional[int] = 1000, + batch_size: int = 1000, fn_kwargs: Optional[dict] = None, ) -> "IterableDataset": """Apply a filter function to all the elements so that the dataset only includes examples according to the filter function. @@ -2647,23 +2675,23 @@ def filter( function (`Callable`): Callable with one of the following signatures: - - `function(example: Dict[str, Any]) -> bool` if `with_indices=False, batched=False` - - `function(example: Dict[str, Any], indices: int) -> bool` if `with_indices=True, batched=False` - - `function(example: Dict[str, List]) -> List[bool]` if `with_indices=False, batched=True` - - `function(example: Dict[str, List], indices: List[int]) -> List[bool]` if `with_indices=True, batched=True` + - `function(example: dict[str, Any]) -> bool` if `with_indices=False, batched=False` + - `function(example: dict[str, Any], indices: int) -> bool` if `with_indices=True, batched=False` + - `function(example: dict[str, list]) -> list[bool]` if `with_indices=False, batched=True` + - `function(example: dict[str, list], indices: list[int]) -> list[bool]` if `with_indices=True, batched=True` If the function is asynchronous, then `filter` will run your function in parallel. If no function is provided, defaults to an always True function: `lambda x: True`. with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`. - input_columns (`str` or `List[str]`, *optional*): + input_columns (`str` or `list[str]`, *optional*): The columns to be passed into `function` as positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument. batched (`bool`, defaults to `False`): Provide batch of examples to `function`. batch_size (`int`, *optional*, default `1000`): Number of examples per batch provided to `function` if `batched=True`. - fn_kwargs (`Dict`, *optional*, default `None`): + fn_kwargs (`dict`, *optional*, default `None`): Keyword arguments to be passed to `function`. Example: @@ -2715,8 +2743,8 @@ def filter( ) def shuffle( - self, seed=None, generator: Optional[np.random.Generator] = None, buffer_size: int = 1000 - ) -> "IterableDataset": + self, seed: Optional[int] = None, generator: Optional[np.random.Generator] = None, buffer_size: int = 1000 + ) -> IterableDataset: """ Randomly shuffles the elements of this dataset. @@ -2781,7 +2809,7 @@ def shuffle( token_per_repo_id=self._token_per_repo_id, ) - def set_epoch(self, epoch: int): + def set_epoch(self, epoch: int) -> None: self._epoch += epoch - self._epoch # update torch value in shared memory in-place def skip(self, n: int) -> "IterableDataset": @@ -2871,7 +2899,7 @@ def repeat(self, num_times: Optional[int]) -> "IterableDataset": token_per_repo_id=self._token_per_repo_id, ) - def take(self, n: int) -> "IterableDataset": + def take(self, n: int) -> IterableDataset: """ Create a new [`IterableDataset`] with only the first `n` elements. 
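The `map`, `filter`, `shuffle`, and `take` docstrings retyped above describe the callables the streaming API expects. A quick usage sketch consistent with those signatures (the generator and column names are illustrative only, not taken from this diff):

```python
from datasets import IterableDataset

def gen():
    for i in range(10):
        yield {"text": f"example {i}", "label": i % 2}

ds = IterableDataset.from_generator(gen)

# batched=True: the callable gets a dict[str, list] and returns a dict[str, list]
ds = ds.map(lambda batch: {"text": [t.upper() for t in batch["text"]]},
            batched=True, batch_size=4)

# batched=False: the callable gets a dict[str, Any] and returns a bool
ds = ds.filter(lambda example: example["label"] == 1)

# approximate shuffling through a fixed-size buffer, then keep the first 3 examples
for example in ds.shuffle(seed=42, buffer_size=100).take(3):
    print(example)
```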
@@ -2980,7 +3008,7 @@ def column_names(self) -> Optional[list[str]]: """ return list(self._info.features.keys()) if self._info.features is not None else None - def add_column(self, name: str, column: Union[list, np.array]) -> "IterableDataset": + def add_column(self, name: str, column: Union[list, np.ndarray]) -> IterableDataset: """Add column to Dataset. Args: @@ -3022,13 +3050,13 @@ def rename_column(self, original_column_name: str, new_column_name: str) -> "Ite """ return self.rename_columns({original_column_name: new_column_name}) - def rename_columns(self, column_mapping: dict[str, str]) -> "IterableDataset": + def rename_columns(self, column_mapping: dict[str, str]) -> IterableDataset: """ Rename several columns in the dataset, and move the features associated to the original columns under the new column names. Args: - column_mapping (`Dict[str, str]`): A mapping of columns to rename to their new names + column_mapping (`dict[str, str]`): A mapping of columns to rename to their new names Returns: `IterableDataset`: A copy of the dataset with renamed columns @@ -3047,14 +3075,14 @@ def rename_columns(self, column_mapping: dict[str, str]) -> "IterableDataset": ) return ds_iterable - def remove_columns(self, column_names: Union[str, list[str]]) -> "IterableDataset": + def remove_columns(self, column_names: Union[str, list[str]]) -> IterableDataset: """ Remove one or several column(s) in the dataset and the features associated to them. The removal is done on-the-fly on the examples when iterating over the dataset. Args: - column_names (`Union[str, List[str]]`): + column_names (`Union[str, list[str]]`): Name of the column(s) to remove. Returns: @@ -3082,14 +3110,14 @@ def remove_columns(self, column_names: Union[str, list[str]]) -> "IterableDatase return ds_iterable - def select_columns(self, column_names: Union[str, list[str]]) -> "IterableDataset": + def select_columns(self, column_names: Union[str, list[str]]) -> IterableDataset: """Select one or several column(s) in the dataset and the features associated to them. The selection is done on-the-fly on the examples when iterating over the dataset. Args: - column_names (`Union[str, List[str]]`): + column_names (`Union[str, list[str]]`): Name of the column(s) to select. Returns: @@ -3330,7 +3358,7 @@ def _step(self, step: int, offset: int) -> "IterableDataset": token_per_repo_id=self._token_per_repo_id, ) - def _resolve_features(self): + def _resolve_features(self) -> IterableDataset: if self.features is not None: return self elif self._ex_iterable.is_typed: @@ -3349,7 +3377,7 @@ def _resolve_features(self): token_per_repo_id=self._token_per_repo_id, ) - def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableDataset": + def batch(self, batch_size: int, drop_last_batch: bool = False) -> IterableDataset: """ Group samples from the dataset into batches. @@ -3364,7 +3392,7 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableData ``` """ - def batch_fn(unbatched): + def batch_fn(unbatched: Mapping) -> dict: return {k: [v] for k, v in unbatched.items()} if self.features: @@ -3389,7 +3417,7 @@ def _concatenate_iterable_datasets( Args: - dsets (`List[datasets.IterableDataset]`): List of Datasets to concatenate. + dsets (`list[datasets.IterableDataset]`): List of Datasets to concatenate. info (`DatasetInfo`, optional): Dataset information, like description, citation, etc. split (`NamedSplit`, optional): Name of the dataset split. 
axis (``{0, 1}``, default ``0``, meaning over rows): @@ -3453,8 +3481,8 @@ def _interleave_iterable_datasets( Args: - datasets (`List[IterableDataset]`): list of datasets to interleave - probabilities (`List[float]`, optional, default None): If specified, the new iterable dataset samples + datasets (`list[IterableDataset]`): list of datasets to interleave + probabilities (`list[float]`, optional, default None): If specified, the new iterable dataset samples examples from one source at a time according to these probabilities. seed (`int`, optional, default None): The random seed used to choose a source for each example. stopping_strategy (`str`, defaults to `first_exhausted`): diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py index d730bfea502..d3e2e8d21c6 100644 --- a/src/datasets/packaged_modules/spark/spark.py +++ b/src/datasets/packaged_modules/spark/spark.py @@ -1,13 +1,28 @@ +from __future__ import annotations + import os import posixpath +import shutil import uuid from collections.abc import Iterable from dataclasses import dataclass from itertools import islice -from typing import TYPE_CHECKING, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Generator, + Iterable, + Iterator, + Mapping, + MutableMapping, + Optional, + Sequence, + Union, +) import numpy as np import pyarrow as pa +import pyspark.sql import datasets from datasets.arrow_writer import ArrowWriter, ParquetWriter @@ -25,7 +40,6 @@ if TYPE_CHECKING: import pyspark - import pyspark.sql @dataclass @@ -34,11 +48,13 @@ class SparkConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() -def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: list[int]): +def _reorder_dataframe_by_partition( + df: pyspark.sql.DataFrame, new_partition_order: Sequence[int] +) -> pyspark.sql.DataFrame: df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}") for partition_id in new_partition_order[1:]: partition_df = df.select("*").where(f"part_id = {partition_id}") @@ -47,10 +63,10 @@ def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_o def _generate_iterable_examples( - df: "pyspark.sql.DataFrame", - partition_order: list[int], - state_dict: Optional[dict] = None, -): + df: pyspark.sql.DataFrame, + partition_order: Sequence[int], + state_dict: Optional[MutableMapping] = None, +) -> Generator[tuple[str, dict]]: import pyspark df_with_partition_id = df.select("*", pyspark.sql.functions.spark_partition_id().alias("part_id")) @@ -78,30 +94,30 @@ def _generate_iterable_examples( class SparkExamplesIterable(_BaseExamplesIterable): def __init__( self, - df: "pyspark.sql.DataFrame", - partition_order=None, - ): + df: pyspark.sql.DataFrame, + partition_order: Optional[Sequence[int]] = None, + ) -> None: super().__init__() self.df = df self.partition_order = partition_order or range(self.df.rdd.getNumPartitions()) - def _init_state_dict(self) -> dict: + def _init_state_dict(self) -> dict[str, int]: self._state_dict = {"partition_idx": 0, "partition_example_idx": 0} return self._state_dict @experimental - def load_state_dict(self, state_dict: dict) -> dict: + def load_state_dict(self, state_dict: Optional[Mapping]) -> Optional[Mapping]: return super().load_state_dict(state_dict) - def __iter__(self): + def __iter__(self) -> Generator[tuple[str, dict]]: yield from _generate_iterable_examples(self.df, 
self.partition_order, self._state_dict) - def shuffle_data_sources(self, generator: np.random.Generator) -> "SparkExamplesIterable": + def shuffle_data_sources(self, generator: np.random.Generator) -> SparkExamplesIterable: partition_order = list(range(self.df.rdd.getNumPartitions())) generator.shuffle(partition_order) return SparkExamplesIterable(self.df, partition_order=partition_order) - def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "SparkExamplesIterable": + def shard_data_sources(self, num_shards: int, index: int, contiguous: bool = True) -> SparkExamplesIterable: partition_order = self.split_shard_indices_by_worker(num_shards=num_shards, index=index, contiguous=contiguous) return SparkExamplesIterable(self.df, partition_order=partition_order) @@ -115,11 +131,11 @@ class Spark(datasets.DatasetBuilder): def __init__( self, - df: "pyspark.sql.DataFrame", - cache_dir: str = None, - working_dir: str = None, - **config_kwargs, - ): + df: pyspark.sql.DataFrame, + cache_dir: Optional[str] = None, + working_dir: Optional[str] = None, + **config_kwargs: Any, + ) -> None: import pyspark self._spark = pyspark.sql.SparkSession.builder.getOrCreate() @@ -132,13 +148,13 @@ def __init__( **config_kwargs, ) - def _validate_cache_dir(self): + def _validate_cache_dir(self) -> None: # Define this so that we don't reference self in create_cache_and_write_probe, which will result in a pickling # error due to pickling the SparkContext. cache_dir = self._cache_dir # Returns the path of the created file. - def create_cache_and_write_probe(context): + def create_cache_and_write_probe(context) -> list[str]: # makedirs with exist_ok will recursively create the directory. It will not throw an error if directories # already exist. os.makedirs(cache_dir, exist_ok=True) @@ -165,16 +181,18 @@ def create_cache_and_write_probe(context): "When using Dataset.from_spark on a multi-node cluster, the driver and all workers should be able to access cache_dir" ) - def _info(self): + def _info(self) -> datasets.DatasetInfo: return datasets.DatasetInfo(features=self.config.features) - def _split_generators(self, dl_manager: datasets.download.download_manager.DownloadManager): + def _split_generators( + self, dl_manager: datasets.download.download_manager.DownloadManager + ) -> list[datasets.SplitGenerator]: return [datasets.SplitGenerator(name=datasets.Split.TRAIN)] - def _repartition_df_if_needed(self, max_shard_size): + def _repartition_df_if_needed(self, max_shard_size: int) -> None: import pyspark - def get_arrow_batch_size(it): + def get_arrow_batch_size(it: Iterable[pa.RecordBatch]) -> Generator[pa.RecordBatch]: for batch in it: yield pa.RecordBatch.from_pydict({"batch_bytes": [batch.nbytes]}) @@ -201,7 +219,7 @@ def _prepare_split_single( fpath: str, file_format: str, max_shard_size: int, - ) -> Iterable[tuple[int, bool, Union[int, tuple]]]: + ) -> Generator[tuple[int, tuple[int, int, int, int]]]: import pyspark writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter @@ -214,7 +232,7 @@ def _prepare_split_single( writer_batch_size = self._writer_batch_size storage_options = self._fs.storage_options - def write_arrow(it): + def write_arrow(it: Iterator[pa.RecordBatch]) -> Union[pa.RecordBatch, Generator[pa.RecordBatch]]: # Within the same SparkContext, no two task attempts will share the same attempt ID. 
task_id = pyspark.TaskContext().taskAttemptId() first_batch = next(it, None) @@ -282,12 +300,12 @@ def write_arrow(it): def _prepare_split( self, - split_generator: "datasets.SplitGenerator", + split_generator: datasets.SplitGenerator, file_format: str = "arrow", max_shard_size: Optional[Union[str, int]] = None, num_proc: Optional[int] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: self._validate_cache_dir() max_shard_size = convert_file_size_to_int(max_shard_size or MAX_SHARD_SIZE) @@ -336,7 +354,7 @@ def _rename_shard( task_id: int, shard_id: int, global_shard_id: int, - ): + ) -> None: rename( fs, fpath.replace("SSSSS", f"{shard_id:05d}").replace("TTTTT", f"{task_id:05d}"), @@ -362,6 +380,6 @@ def _rename_shard( def _get_examples_iterable_for_split( self, - split_generator: "datasets.SplitGenerator", + split_generator: datasets.SplitGenerator, ) -> SparkExamplesIterable: return SparkExamplesIterable(self.df) diff --git a/src/datasets/table.py b/src/datasets/table.py index e57a0c54927..46298b29ef4 100644 --- a/src/datasets/table.py +++ b/src/datasets/table.py @@ -1,9 +1,27 @@ +from __future__ import annotations + import copy import os from collections.abc import Iterator from functools import partial from itertools import groupby -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Iterator, + List, + Mapping, + Optional, + OrderedDict, + ParamSpec, + Sequence, + Tuple, + TypeVar, + Union, +) import numpy as np import pyarrow as pa @@ -14,6 +32,8 @@ if TYPE_CHECKING: + import pandas as pd + from .features.features import Features, FeatureType @@ -21,7 +41,7 @@ def inject_arrow_table_documentation(arrow_table_method): - def wrapper(fn): + def wrapper(fn: Callable): fn.__doc__ = arrow_table_method.__doc__ + (fn.__doc__ if fn.__doc__ is not None else "") fn.__doc__ = fn.__doc__.replace("pyarrow.Table", "Table") if hasattr(arrow_table_method, "__annotations__"): @@ -66,7 +86,10 @@ def _memory_mapped_arrow_table_from_file(filename: str) -> pa.Table: return pa_table -def _deepcopy(x, memo: dict): +T = TypeVar("T") + + +def _deepcopy(x: T, memo: MutableMapping) -> T: """deepcopy a regular class instance""" cls = x.__class__ result = cls.__new__(cls) @@ -76,7 +99,7 @@ def _deepcopy(x, memo: dict): return result -def _interpolation_search(arr: list[int], x: int) -> int: +def _interpolation_search(arr: Sequence[int], x: int) -> int: """ Return the position i of a sorted array so that arr[i] <= x < arr[i+1] @@ -103,14 +126,14 @@ def _interpolation_search(arr: list[int], x: int) -> int: class IndexedTableMixin: - def __init__(self, table: pa.Table): + def __init__(self, table: pa.Table) -> None: self._schema: pa.Schema = table.schema self._batches: list[pa.RecordBatch] = [ recordbatch for recordbatch in table.to_batches() if len(recordbatch) > 0 ] self._offsets: np.ndarray = np.cumsum([0] + [len(b) for b in self._batches], dtype=np.int64) - def fast_gather(self, indices: Union[list[int], np.ndarray]) -> pa.Table: + def fast_gather(self, indices: Union[Sequence[int], np.ndarray]) -> pa.Table: """ Create a pa.Table by gathering the records at the records at the specified indices. 
Should be faster than pa.concat_tables(table.fast_slice(int(i) % table.num_rows, 1) for i in indices) since NumPy can compute @@ -127,7 +150,7 @@ def fast_gather(self, indices: Union[list[int], np.ndarray]) -> pa.Table: schema=self._schema, ) - def fast_slice(self, offset=0, length=None) -> pa.Table: + def fast_slice(self, offset: int = 0, length: Optional[int] = None) -> pa.Table: """ Slice the Table using interpolation search. The behavior is the same as `pyarrow.Table.slice` but it's significantly faster. @@ -163,11 +186,11 @@ class Table(IndexedTableMixin): The implementation of these methods differs for the subclasses. """ - def __init__(self, table: pa.Table): + def __init__(self, table: pa.Table) -> None: super().__init__(table) self.table = table - def __deepcopy__(self, memo: dict): + def __deepcopy__(self, memo: dict) -> object: # arrow tables are immutable, so there's no need to copy self.table # moreover calling deepcopy on a pyarrow table seems to make pa.total_allocated_bytes() decrease for some reason # by adding it to the memo, self.table won't be copied @@ -176,7 +199,7 @@ def __deepcopy__(self, memo: dict): memo[id(self._batches)] = list(self._batches) return _deepcopy(self, memo) - def validate(self, *args, **kwargs): + def validate(self, *args: Any, **kwargs: Any): """ Perform validation checks. An exception is raised if validation fails. @@ -192,7 +215,7 @@ def validate(self, *args, **kwargs): """ return self.table.validate(*args, **kwargs) - def equals(self, *args, **kwargs): + def equals(self, *args: Any, **kwargs: Any) -> bool: """ Check if contents of two tables are equal. @@ -209,7 +232,7 @@ def equals(self, *args, **kwargs): kwargs = {k: v.table if isinstance(v, Table) else v for k, v in kwargs} return self.table.equals(*args, **kwargs) - def to_batches(self, *args, **kwargs): + def to_batches(self, *args: Any, **kwargs: Any) -> list[pyarrow.RecordBatch]: """ Convert Table to list of (contiguous) `RecordBatch` objects. @@ -223,7 +246,7 @@ def to_batches(self, *args, **kwargs): """ return self.table.to_batches(*args, **kwargs) - def to_pydict(self, *args, **kwargs): + def to_pydict(self, *args: Any, **kwargs: Any) -> Union[dict, OrderedDict]: """ Convert the Table to a `dict` or `OrderedDict`. @@ -232,7 +255,7 @@ def to_pydict(self, *args, **kwargs): """ return self.table.to_pydict(*args, **kwargs) - def to_pylist(self, *args, **kwargs): + def to_pylist(self, *args: Any, **kwargs: Any) -> list: """ Convert the Table to a list @@ -241,7 +264,7 @@ def to_pylist(self, *args, **kwargs): """ return self.table.to_pylist(*args, **kwargs) - def to_pandas(self, *args, **kwargs): + def to_pandas(self, *args: Any, **kwargs: Any) -> Union[pd.Series, pd.DataFrame]: """ Convert to a pandas-compatible NumPy array or DataFrame, as appropriate. @@ -303,10 +326,10 @@ def to_pandas(self, *args, **kwargs): """ return self.table.to_pandas(*args, **kwargs) - def to_string(self, *args, **kwargs): + def to_string(self, *args: Any, **kwargs: Any) -> str: return self.table.to_string(*args, **kwargs) - def to_reader(self, max_chunksize: Optional[int] = None): + def to_reader(self, max_chunksize: Optional[int] = None) -> pa.RecordBatchReader: """ Convert the Table to a RecordBatchReader. @@ -322,7 +345,7 @@ def to_reader(self, max_chunksize: Optional[int] = None): """ return self.table.to_reader(max_chunksize=max_chunksize) - def field(self, *args, **kwargs): + def field(self, *args: Any, **kwargs: Any) -> pa.Field: """ Select a schema field by its column name or numeric index. 
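`_interpolation_search` and `fast_slice` above rely on the cumulative batch offsets kept by `IndexedTableMixin`. A standalone sketch of that lookup, assuming strictly increasing offsets (this illustrates the documented idea, it is not the library's exact code):

```python
from typing import Sequence

def interpolation_search(arr: Sequence[int], x: int) -> int:
    """Return i such that arr[i] <= x < arr[i + 1] in a sorted, strictly increasing array."""
    low, high = 0, len(arr) - 1
    while high > low:
        # guess a position proportional to where x sits between the two bounds
        guess = low + (high - low) * (x - arr[low]) // (arr[high] - arr[low])
        guess = min(max(guess, low), high - 1)
        if arr[guess] <= x < arr[guess + 1]:
            return guess
        elif arr[guess] <= x:
            low = guess + 1
        else:
            high = guess
    raise IndexError(f"{x} is outside [{arr[0]}, {arr[-1]}]")

offsets = [0, 3, 7, 12]                       # cumulative rows per record batch
assert interpolation_search(offsets, 8) == 2  # row 8 lives in the third batch
```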
@@ -335,7 +358,7 @@ def field(self, *args, **kwargs): """ return self.table.field(*args, **kwargs) - def column(self, *args, **kwargs): + def column(self, *args: Any, **kwargs: Any) -> Union[pa.Array, pa.ChunkedArray]: """ Select a column by its column name, or numeric index. @@ -348,7 +371,7 @@ def column(self, *args, **kwargs): """ return self.table.column(*args, **kwargs) - def itercolumns(self, *args, **kwargs): + def itercolumns(self, *args: Any, **kwargs: Any) -> Iterator[pa.ChunkedArray]: """ Iterator over all columns in their numerical order. @@ -358,7 +381,7 @@ def itercolumns(self, *args, **kwargs): return self.table.itercolumns(*args, **kwargs) @property - def schema(self): + def schema(self) -> pa.Schema: """ Schema of the table and its columns. @@ -368,7 +391,7 @@ def schema(self): return self.table.schema @property - def columns(self): + def columns(self) -> list[pa.ChunkedArray]: """ List of all columns in numerical order. @@ -378,7 +401,7 @@ def columns(self): return self.table.columns @property - def num_columns(self): + def num_columns(self) -> int: """ Number of columns in this table. @@ -388,7 +411,7 @@ def num_columns(self): return self.table.num_columns @property - def num_rows(self): + def num_rows(self) -> int: """ Number of rows in this table. @@ -401,7 +424,7 @@ def num_rows(self): return self.table.num_rows @property - def shape(self): + def shape(self) -> tuple[int, int]: """ Dimensions of the table: (#rows, #columns). @@ -411,35 +434,35 @@ def shape(self): return self.table.shape @property - def nbytes(self): + def nbytes(self) -> int: """ Total number of bytes consumed by the elements of the table. """ return self.table.nbytes @property - def column_names(self): + def column_names(self) -> list[str]: """ Names of the table's columns. """ return self.table.column_names - def __eq__(self, other): + def __eq__(self, other: Table) -> bool: return self.equals(other) - def __getitem__(self, i): + def __getitem__(self, i: int) -> pa.ChunkedArray: return self.table[i] - def __len__(self): + def __len__(self) -> int: return len(self.table) - def __repr__(self): + def __repr__(self) -> str: return self.table.__repr__().replace("pyarrow.Table", self.__class__.__name__) - def __str__(self): + def __str__(self) -> str: return self.table.__str__().replace("pyarrow.Table", self.__class__.__name__) - def slice(self, *args, **kwargs): + def slice(self, *args: Any, **kwargs: Any) -> Table: """ Compute zero-copy slice of this Table. @@ -455,13 +478,13 @@ def slice(self, *args, **kwargs): """ raise NotImplementedError() - def filter(self, *args, **kwargs): + def filter(self, *args: Any, **kwargs: Any) -> None: """ Select records from a Table. See `pyarrow.compute.filter` for full usage. """ raise NotImplementedError() - def flatten(self, *args, **kwargs): + def flatten(self, *args: Any, **kwargs: Any) -> Table: """ Flatten this Table. Each column with a struct type is flattened into one column per struct field. Other columns are left unchanged. @@ -475,7 +498,7 @@ def flatten(self, *args, **kwargs): """ raise NotImplementedError() - def combine_chunks(self, *args, **kwargs): + def combine_chunks(self, *args: Any, **kwargs: Any): """ Make a new table by combining the chunks this table has. @@ -491,7 +514,7 @@ def combine_chunks(self, *args, **kwargs): """ raise NotImplementedError() - def cast(self, *args, **kwargs): + def cast(self, *args: Any, **kwargs: Any): """ Cast table values to another schema. 
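The base `Table` methods annotated above are thin wrappers over `pyarrow.Table`; the transforming ones (`slice`, `filter`, `flatten`, `cast`, ...) raise `NotImplementedError` until a subclass decides how to wrap the result. A pyarrow-only illustration of the underlying calls those docstrings reference:

```python
import pyarrow as pa
import pyarrow.compute as pc

t = pa.table({
    "id": [1, 2, 3],
    "meta": [{"lang": "en"}, {"lang": "fr"}, {"lang": "en"}],
})

print(t.slice(1, 2).num_rows)                      # zero-copy slice -> 2
print(t.flatten().column_names)                    # ['id', 'meta.lang']
print(t.filter(pc.greater(t["id"], 1)).num_rows)   # boolean-mask filter -> 2
```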
@@ -506,7 +529,7 @@ def cast(self, *args, **kwargs): """ raise NotImplementedError() - def replace_schema_metadata(self, *args, **kwargs): + def replace_schema_metadata(self, *args: Any, **kwargs: Any): """ EXPERIMENTAL: Create shallow copy of table by replacing schema key-value metadata with the indicated new metadata (which may be None, @@ -520,7 +543,7 @@ def replace_schema_metadata(self, *args, **kwargs): """ raise NotImplementedError() - def add_column(self, *args, **kwargs): + def add_column(self, *args: Any, **kwargs: Any): """ Add column to Table at position. @@ -541,7 +564,7 @@ def add_column(self, *args, **kwargs): """ raise NotImplementedError() - def append_column(self, *args, **kwargs): + def append_column(self, *args: Any, **kwargs: Any): """ Append column at end of columns. @@ -557,7 +580,7 @@ def append_column(self, *args, **kwargs): """ raise NotImplementedError() - def remove_column(self, *args, **kwargs): + def remove_column(self, *args: Any, **kwargs: Any): """ Create new Table with the indicated column removed. @@ -570,7 +593,7 @@ def remove_column(self, *args, **kwargs): """ raise NotImplementedError() - def set_column(self, *args, **kwargs): + def set_column(self, *args: Any, **kwargs: Any): """ Replace column in Table at position. @@ -588,13 +611,13 @@ def set_column(self, *args, **kwargs): """ raise NotImplementedError() - def rename_columns(self, *args, **kwargs): + def rename_columns(self, *args: Any, **kwargs: Any): """ Create new table with columns renamed to provided names. """ raise NotImplementedError() - def drop(self, *args, **kwargs): + def drop(self, *args: Any, **kwargs: Any): """ Drop one or more columns and return a new table. @@ -610,7 +633,7 @@ def drop(self, *args, **kwargs): """ raise NotImplementedError() - def select(self, *args, **kwargs): + def select(self, *args: Any, **kwargs: Any): """ Select columns of the table. @@ -636,6 +659,9 @@ class TableBlock(Table): pass +T_InMemoryTable = TypeVar("T_InMemoryTable", bound="InMemoryTable") + + class InMemoryTable(TableBlock): """ The table is said in-memory when it is loaded into the user's RAM. @@ -652,17 +678,17 @@ class InMemoryTable(TableBlock): """ @classmethod - def from_file(cls, filename: str): + def from_file(cls: type[T_InMemoryTable], filename: str) -> T_InMemoryTable: table = _in_memory_arrow_table_from_file(filename) return cls(table) @classmethod - def from_buffer(cls, buffer: pa.Buffer): + def from_buffer(cls: type[T_InMemoryTable], buffer: pa.Buffer) -> T_InMemoryTable: table = _in_memory_arrow_table_from_buffer(buffer) return cls(table) @classmethod - def from_pandas(cls, *args, **kwargs): + def from_pandas(cls: type[T_InMemoryTable], *args: Any, **kwargs: Any) -> T_InMemoryTable: """ Convert pandas.DataFrame to an Arrow Table. @@ -720,7 +746,7 @@ def from_pandas(cls, *args, **kwargs): return cls(pa.Table.from_pandas(*args, **kwargs)) @classmethod - def from_arrays(cls, *args, **kwargs): + def from_arrays(cls: type[T_InMemoryTable], *args: Any, **kwargs: Any) -> T_InMemoryTable: """ Construct a Table from Arrow arrays. @@ -740,7 +766,7 @@ def from_arrays(cls, *args, **kwargs): return cls(pa.Table.from_arrays(*args, **kwargs)) @classmethod - def from_pydict(cls, *args, **kwargs): + def from_pydict(cls: type[T_InMemoryTable], *args: Any, **kwargs: Any) -> T_InMemoryTable: """ Construct a Table from Arrow arrays or columns. 
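The `InMemoryTable` constructors above are parametrised with `T_InMemoryTable` bound to the class, so a type checker infers the concrete subclass from `cls` instead of widening to the base. Usage is unchanged by the annotations; a small sketch:

```python
from datasets.table import InMemoryTable

t = InMemoryTable.from_pydict({"id": [1, 2, 3], "text": ["a", "b", "c"]})
assert t.num_rows == 3            # static checkers infer InMemoryTable here
print(t.to_pydict())              # {'id': [1, 2, 3], 'text': ['a', 'b', 'c']}
```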
@@ -758,12 +784,14 @@ def from_pydict(cls, *args, **kwargs): return cls(pa.Table.from_pydict(*args, **kwargs)) @classmethod - def from_pylist(cls, mapping, *args, **kwargs): + def from_pylist( + cls: type[T_InMemoryTable], mapping: Sequence[Mapping[str, Any]], *args: Any, **kwargs: Any + ) -> T_InMemoryTable: """ Construct a Table from list of rows / dictionaries. Args: - mapping (`List[dict]`): + mapping (`Sequence[Mapping[str, Any]]`): A mapping of strings to row values. schema (`Schema`, defaults to `None`): If not passed, will be inferred from the Mapping values @@ -776,7 +804,7 @@ def from_pylist(cls, mapping, *args, **kwargs): return cls(pa.Table.from_pylist(mapping, *args, **kwargs)) @classmethod - def from_batches(cls, *args, **kwargs): + def from_batches(cls: type[T_InMemoryTable], *args: Any, **kwargs: Any) -> T_InMemoryTable: """ Construct a Table from a sequence or iterator of Arrow `RecordBatches`. @@ -791,7 +819,7 @@ def from_batches(cls, *args, **kwargs): """ return cls(pa.Table.from_batches(*args, **kwargs)) - def slice(self, offset=0, length=None): + def slice(self, offset: int = 0, length: Optional[int] = None) -> InMemoryTable: """ Compute zero-copy slice of this Table. @@ -808,13 +836,13 @@ def slice(self, offset=0, length=None): # Use fast slicing here return InMemoryTable(self.fast_slice(offset=offset, length=length)) - def filter(self, *args, **kwargs): + def filter(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Select records from a Table. See `pyarrow.compute.filter` for full usage. """ return InMemoryTable(self.table.filter(*args, **kwargs)) - def flatten(self, *args, **kwargs): + def flatten(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Flatten this Table. Each column with a struct type is flattened into one column per struct field. Other columns are left unchanged. @@ -828,7 +856,7 @@ def flatten(self, *args, **kwargs): """ return InMemoryTable(table_flatten(self.table, *args, **kwargs)) - def combine_chunks(self, *args, **kwargs): + def combine_chunks(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Make a new table by combining the chunks this table has. @@ -844,7 +872,7 @@ def combine_chunks(self, *args, **kwargs): """ return InMemoryTable(self.table.combine_chunks(*args, **kwargs)) - def cast(self, *args, **kwargs): + def cast(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Cast table values to another schema. @@ -859,7 +887,7 @@ def cast(self, *args, **kwargs): """ return InMemoryTable(table_cast(self.table, *args, **kwargs)) - def replace_schema_metadata(self, *args, **kwargs): + def replace_schema_metadata(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ EXPERIMENTAL: Create shallow copy of table by replacing schema key-value metadata with the indicated new metadata (which may be `None`, @@ -873,7 +901,7 @@ def replace_schema_metadata(self, *args, **kwargs): """ return InMemoryTable(self.table.replace_schema_metadata(*args, **kwargs)) - def add_column(self, *args, **kwargs): + def add_column(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Add column to Table at position. @@ -894,7 +922,7 @@ def add_column(self, *args, **kwargs): """ return InMemoryTable(self.table.add_column(*args, **kwargs)) - def append_column(self, *args, **kwargs): + def append_column(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Append column at end of columns. 
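A short usage sketch of the annotations above (assuming a reasonably recent pyarrow, since `pa.Table.from_pylist` needs pyarrow >= 7): `from_pylist` accepts any sequence of row mappings, and `slice` hands back another `InMemoryTable` view without copying data.

from datasets.table import InMemoryTable

rows = [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}, {"a": 3, "b": "z"}]
table = InMemoryTable.from_pylist(rows)

view = table.slice(offset=1, length=2)  # zero-copy view, still an InMemoryTable
assert view.num_rows == 2
assert table.num_rows == 3  # the original table is left untouched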
@@ -911,7 +939,7 @@ def append_column(self, *args, **kwargs): """ return InMemoryTable(self.table.append_column(*args, **kwargs)) - def remove_column(self, *args, **kwargs): + def remove_column(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Create new Table with the indicated column removed. @@ -925,7 +953,7 @@ def remove_column(self, *args, **kwargs): """ return InMemoryTable(self.table.remove_column(*args, **kwargs)) - def set_column(self, *args, **kwargs): + def set_column(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Replace column in Table at position. @@ -944,13 +972,13 @@ def set_column(self, *args, **kwargs): """ return InMemoryTable(self.table.set_column(*args, **kwargs)) - def rename_columns(self, *args, **kwargs): + def rename_columns(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Create new table with columns renamed to provided names. """ return InMemoryTable(self.table.rename_columns(*args, **kwargs)) - def drop(self, *args, **kwargs): + def drop(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Drop one or more columns and return a new table. @@ -967,7 +995,7 @@ def drop(self, *args, **kwargs): """ return InMemoryTable(self.table.drop(*args, **kwargs)) - def select(self, *args, **kwargs): + def select(self, *args: Any, **kwargs: Any) -> InMemoryTable: """ Select columns of the table. @@ -986,6 +1014,8 @@ def select(self, *args, **kwargs): # The MemoryMappedTable needs replays to properly reload tables from the disk Replay = tuple[str, tuple, dict] +T_MemoryMappedTable = TypeVar("T_MemoryMappedTable", bound="MemoryMappedTable") + class MemoryMappedTable(TableBlock): """ @@ -1008,21 +1038,23 @@ class MemoryMappedTable(TableBlock): stay low. """ - def __init__(self, table: pa.Table, path: str, replays: Optional[list[Replay]] = None): + def __init__(self, table: pa.Table, path: str, replays: Optional[list[Replay]] = None) -> None: super().__init__(table) self.path = os.path.abspath(path) self.replays: list[Replay] = replays if replays is not None else [] @classmethod - def from_file(cls, filename: str, replays=None): + def from_file( + cls: type[T_MemoryMappedTable], filename: str, replays: Optional[list[Replay]] = None + ) -> T_MemoryMappedTable: table = _memory_mapped_arrow_table_from_file(filename) table = cls._apply_replays(table, replays) return cls(table, filename, replays) - def __getstate__(self): + def __getstate__(self) -> dict[str, Union[list[Replay], str]]: return {"path": self.path, "replays": self.replays} - def __setstate__(self, state): + def __setstate__(self, state: Mapping) -> None: path = state["path"] replays = state["replays"] table = _memory_mapped_arrow_table_from_file(path) @@ -1046,7 +1078,7 @@ def _append_replay(self, replay: Replay) -> list[Replay]: replays.append(replay) return replays - def slice(self, offset=0, length=None): + def slice(self, offset: int = 0, length: Optional[int] = None) -> MemoryMappedTable: """ Compute zero-copy slice of this Table. @@ -1060,12 +1092,12 @@ def slice(self, offset=0, length=None): Returns: `datasets.table.Table` """ - replay = ("slice", (offset, length), {}) + replay: tuple[str, tuple[int, Optional[int]], dict] = ("slice", (offset, length), {}) replays = self._append_replay(replay) # Use fast slicing here return MemoryMappedTable(self.fast_slice(offset=offset, length=length), self.path, replays) - def filter(self, *args, **kwargs): + def filter(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Select records from a Table. See `pyarrow.compute.filter` for full usage. 
""" @@ -1073,7 +1105,7 @@ def filter(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.filter(*args, **kwargs), self.path, replays) - def flatten(self, *args, **kwargs): + def flatten(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Flatten this Table. Each column with a struct type is flattened into one column per struct field. Other columns are left unchanged. @@ -1089,7 +1121,7 @@ def flatten(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(table_flatten(self.table, *args, **kwargs), self.path, replays) - def combine_chunks(self, *args, **kwargs): + def combine_chunks(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Make a new table by combining the chunks this table has. @@ -1107,7 +1139,7 @@ def combine_chunks(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.combine_chunks(*args, **kwargs), self.path, replays) - def cast(self, *args, **kwargs): + def cast(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Cast table values to another schema @@ -1124,7 +1156,7 @@ def cast(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(table_cast(self.table, *args, **kwargs), self.path, replays) - def replace_schema_metadata(self, *args, **kwargs): + def replace_schema_metadata(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ EXPERIMENTAL: Create shallow copy of table by replacing schema key-value metadata with the indicated new metadata (which may be None, @@ -1140,7 +1172,7 @@ def replace_schema_metadata(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.replace_schema_metadata(*args, **kwargs), self.path, replays) - def add_column(self, *args, **kwargs): + def add_column(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Add column to Table at position. @@ -1163,7 +1195,7 @@ def add_column(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.add_column(*args, **kwargs), self.path, replays) - def append_column(self, *args, **kwargs): + def append_column(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Append column at end of columns. @@ -1182,7 +1214,7 @@ def append_column(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.append_column(*args, **kwargs), self.path, replays) - def remove_column(self, *args, **kwargs): + def remove_column(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Create new Table with the indicated column removed. @@ -1198,7 +1230,7 @@ def remove_column(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.remove_column(*args, **kwargs), self.path, replays) - def set_column(self, *args, **kwargs): + def set_column(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Replace column in Table at position. @@ -1219,7 +1251,7 @@ def set_column(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.set_column(*args, **kwargs), self.path, replays) - def rename_columns(self, *args, **kwargs): + def rename_columns(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Create new table with columns renamed to provided names. 
""" @@ -1227,7 +1259,7 @@ def rename_columns(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.rename_columns(*args, **kwargs), self.path, replays) - def drop(self, *args, **kwargs): + def drop(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Drop one or more columns and return a new table. @@ -1246,7 +1278,7 @@ def drop(self, *args, **kwargs): replays = self._append_replay(replay) return MemoryMappedTable(self.table.drop(*args, **kwargs), self.path, replays) - def select(self, *args, **kwargs): + def select(self, *args: Any, **kwargs: Any) -> MemoryMappedTable: """ Select columns of the table. @@ -1268,7 +1300,11 @@ def select(self, *args, **kwargs): # The ``blocks`` attributes stores a list of list of blocks. # The first axis concatenates the tables along the axis 0 (it appends rows), # while the second axis concatenates tables along the axis 1 (it appends columns). -TableBlockContainer = TypeVar("TableBlockContainer", TableBlock, list[TableBlock], list[list[TableBlock]]) +TableBlockContainer = TypeVar("TableBlockContainer", TableBlock, Iterable[TableBlock], Iterable[Iterable[TableBlock]]) + + +Blocks = list[list[TableBlock]] +T_ConcatenationTable = TypeVar("T_ConcatenationTable", bound="ConcatenationTable") class ConcatenationTable(Table): @@ -1297,7 +1333,7 @@ class ConcatenationTable(Table): and the blocks by accessing the `ConcatenationTable.blocks` attribute. """ - def __init__(self, table: pa.Table, blocks: list[list[TableBlock]]): + def __init__(self, table: pa.Table, blocks: Iterable[Iterable[TableBlock]]) -> None: super().__init__(table) self.blocks = blocks # Check that all the blocks have the right type. @@ -1310,10 +1346,10 @@ def __init__(self, table: pa.Table, blocks: list[list[TableBlock]]): f", but got {_short_str(subtable)}." 
) - def __getstate__(self): + def __getstate__(self) -> dict[str, Union[Iterable[Iterable[TableBlock]], pa.Schema]]: return {"blocks": self.blocks, "schema": self.table.schema} - def __setstate__(self, state): + def __setstate__(self, state: Mapping) -> None: blocks = state["blocks"] schema = state["schema"] table = self._concat_blocks_horizontally_and_vertically(blocks) @@ -1325,7 +1361,7 @@ def __setstate__(self, state): ConcatenationTable.__init__(self, table, blocks=blocks) @staticmethod - def _concat_blocks(blocks: list[Union[TableBlock, pa.Table]], axis: int = 0) -> pa.Table: + def _concat_blocks(blocks: Iterable[TableBlock], axis: int = 0) -> pa.Table: pa_tables = [table.table if hasattr(table, "table") else table for table in blocks] if axis == 0: # We set promote_options="default" to fill missing columns with null values @@ -1342,7 +1378,9 @@ def _concat_blocks(blocks: list[Union[TableBlock, pa.Table]], axis: int = 0) -> raise ValueError("'axis' must be either 0 or 1") @classmethod - def _concat_blocks_horizontally_and_vertically(cls, blocks: list[list[TableBlock]]) -> pa.Table: + def _concat_blocks_horizontally_and_vertically( + cls: type[T_ConcatenationTable], blocks: Iterable[Iterable[TableBlock]] + ) -> pa.Table: pa_tables_to_concat_vertically = [] for i, tables in enumerate(blocks): if not tables: @@ -1352,7 +1390,9 @@ def _concat_blocks_horizontally_and_vertically(cls, blocks: list[list[TableBlock return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0) @classmethod - def _merge_blocks(cls, blocks: TableBlockContainer, axis: Optional[int] = None) -> TableBlockContainer: + def _merge_blocks( + cls: type[T_ConcatenationTable], blocks: Iterable, axis: Optional[int] = None + ) -> Union[list, list[list]]: if axis is not None: merged_blocks = [] for is_in_memory, block_group in groupby(blocks, key=lambda x: isinstance(x, InMemoryTable)): @@ -1368,7 +1408,9 @@ def _merge_blocks(cls, blocks: TableBlockContainer, axis: Optional[int] = None) return merged_blocks @classmethod - def _consolidate_blocks(cls, blocks: TableBlockContainer) -> TableBlockContainer: + def _consolidate_blocks( + cls: type[T_ConcatenationTable], blocks: Union[TableBlock, Sequence] + ) -> Union[TableBlock, Sequence[TableBlock], Sequence[Sequence[TableBlock]]]: if isinstance(blocks, TableBlock): return blocks elif isinstance(blocks[0], TableBlock): @@ -1377,7 +1419,10 @@ def _consolidate_blocks(cls, blocks: TableBlockContainer) -> TableBlockContainer return cls._merge_blocks(blocks) @classmethod - def from_blocks(cls, blocks: TableBlockContainer) -> "ConcatenationTable": + def from_blocks( + cls: type[T_ConcatenationTable], + blocks: Union[TableBlock, Sequence[TableBlock], Sequence[Sequence[TableBlock]]], + ) -> T_ConcatenationTable: blocks = cls._consolidate_blocks(blocks) if isinstance(blocks, TableBlock): table = blocks @@ -1404,7 +1449,7 @@ def from_tables(cls, tables: list[Union[pa.Table, Table]], axis: int = 0) -> "Co """ - def to_blocks(table: Union[pa.Table, Table]) -> list[list[TableBlock]]: + def to_blocks(table: Union[pa.Table, Table]) -> Blocks: if isinstance(table, pa.Table): return [[InMemoryTable(table)]] elif isinstance(table, ConcatenationTable): @@ -1480,7 +1525,7 @@ def _slices(self): yield (offset, length) offset += length - def slice(self, offset=0, length=None): + def slice(self, offset: int = 0, length: Optional[int] = None) -> ConcatenationTable: """ Compute zero-copy slice of this Table. 
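The `Blocks` alias and the `Iterable[Iterable[TableBlock]]` annotations above describe a two-level layout: the outer axis stacks row groups (axis 0), the inner axis glues column groups side by side (axis 1). The following is only a picture of that layout in plain pyarrow, not the library's concatenation code:

import pyarrow as pa

# Outer list: row groups stacked vertically (axis 0). Inner list: column groups glued horizontally (axis 1).
blocks = [
    [pa.table({"a": [1, 2]}), pa.table({"b": ["x", "y"]})],
    [pa.table({"a": [3]}), pa.table({"b": ["z"]})],
]

# Glue each row group horizontally, then stack the row groups vertically.
row_groups = [
    pa.table({name: column for t in row for name, column in zip(t.column_names, t.columns)})
    for row in blocks
]
full = pa.concat_tables(row_groups)
assert full.column_names == ["a", "b"] and full.num_rows == 3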
@@ -1511,7 +1556,7 @@ def slice(self, offset=0, length=None): length, offset = 0, 0 return ConcatenationTable(table, blocks) - def filter(self, mask, *args, **kwargs): + def filter(self, mask, *args, **kwargs) -> ConcatenationTable: """ Select records from a Table. See `pyarrow.compute.filter` for full usage. """ @@ -1522,7 +1567,7 @@ def filter(self, mask, *args, **kwargs): blocks.append([t.filter(submask, *args, **kwargs) for t in tables]) return ConcatenationTable(table, blocks) - def flatten(self, *args, **kwargs): + def flatten(self, *args, **kwargs) -> ConcatenationTable: """ Flatten this Table. Each column with a struct type is flattened into one column per struct field. Other columns are left unchanged. @@ -1540,7 +1585,7 @@ def flatten(self, *args, **kwargs): blocks.append([t.flatten(*args, **kwargs) for t in tables]) return ConcatenationTable(table, blocks) - def combine_chunks(self, *args, **kwargs): + def combine_chunks(self, *args, **kwargs) -> ConcatenationTable: """ Make a new table by combining the chunks this table has. @@ -1560,7 +1605,7 @@ def combine_chunks(self, *args, **kwargs): blocks.append([t.combine_chunks(*args, **kwargs) for t in tables]) return ConcatenationTable(table, blocks) - def cast(self, target_schema, *args, **kwargs): + def cast(self, target_schema, *args, **kwargs) -> ConcatenationTable: """ Cast table values to another schema. @@ -1591,7 +1636,7 @@ def cast(self, target_schema, *args, **kwargs): blocks.append(new_tables) return ConcatenationTable(table, blocks) - def replace_schema_metadata(self, *args, **kwargs): + def replace_schema_metadata(self, *args, **kwargs) -> ConcatenationTable: """ EXPERIMENTAL: Create shallow copy of table by replacing schema key-value metadata with the indicated new metadata (which may be `None`, @@ -1647,7 +1692,7 @@ def append_column(self, *args, **kwargs): """ raise NotImplementedError() - def remove_column(self, i, *args, **kwargs): + def remove_column(self, i, *args, **kwargs) -> ConcatenationTable: """ Create new Table with the indicated column removed. @@ -1690,7 +1735,7 @@ def set_column(self, *args, **kwargs): """ raise NotImplementedError() - def rename_columns(self, names, *args, **kwargs): + def rename_columns(self, names, *args, **kwargs) -> ConcatenationTable: """ Create new table with columns renamed to provided names. """ @@ -1703,7 +1748,7 @@ def rename_columns(self, names, *args, **kwargs): ) return ConcatenationTable(table, blocks) - def drop(self, columns, *args, **kwargs): + def drop(self, columns, *args, **kwargs) -> ConcatenationTable: """ Drop one or more columns and return a new table. @@ -1724,7 +1769,7 @@ def drop(self, columns, *args, **kwargs): blocks.append([t.drop([c for c in columns if c in t.column_names], *args, **kwargs) for t in tables]) return ConcatenationTable(table, blocks) - def select(self, columns, *args, **kwargs): + def select(self, columns, *args, **kwargs) -> ConcatenationTable: """ Select columns of the table. @@ -1744,7 +1789,7 @@ def select(self, columns, *args, **kwargs): return ConcatenationTable(table, blocks) -def concat_tables(tables: list[Table], axis: int = 0) -> Table: +def concat_tables(tables: list[Table], axis: int = 0) -> ConcatenationTable: """ Concatenate tables. 
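With `concat_tables` now annotated as returning a `ConcatenationTable` (itself a `Table`), a type checker can follow the result through further table operations instead of falling back to the abstract base class. A small usage sketch:

from datasets.table import InMemoryTable, concat_tables

t1 = InMemoryTable.from_pydict({"a": [1, 2]})
t2 = InMemoryTable.from_pydict({"a": [3]})

combined = concat_tables([t1, t2])  # rows of t2 are appended below t1 (axis=0)
assert combined.num_rows == 3
assert combined.column_names == ["a"]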
@@ -1788,12 +1833,23 @@ def list_table_cache_files(table: Table) -> list[str]: return [] -def _wrap_for_chunked_arrays(func): +from typing import Concatenate + + +P = ParamSpec("P") + + +def _wrap_for_chunked_arrays( + func: Callable[Concatenate[Union[pa.Array, pa.ChunkedArray], P], pa.Array], +) -> Callable[Concatenate[Union[pa.Array, pa.ChunkedArray], P], Union[pa.Array, pa.ChunkedArray]]: """Apply the function on each chunk of a `pyarrow.ChunkedArray`, or on the array directly""" - def wrapper(array, *args, **kwargs): + def wrapper( + array: Union[pa.Array, pa.ChunkedArray], *args: P.args, **kwargs: P.kwargs + ) -> Union[pa.Array, pa.ChunkedArray]: if isinstance(array, pa.ChunkedArray): return pa.chunked_array([func(chunk, *args, **kwargs) for chunk in array.chunks]) + else: return func(array, *args, **kwargs) @@ -1831,7 +1887,7 @@ def _storage_type(type: pa.DataType) -> pa.DataType: return type -def _short_str(value: Any) -> str: +def _short_str(value: object) -> str: out = str(value) if len(out) > 3000: out = out[:1500] + "\n...\n" + out[-1500:] @@ -1953,7 +2009,7 @@ def array_cast( @_wrap_for_chunked_arrays def cast_array_to_feature( - array: pa.Array, feature: "FeatureType", allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True + array: pa.Array, feature: FeatureType, allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True ) -> pa.Array: """Cast an array to the arrow type that corresponds to the requested feature type. For custom features like [`Audio`] or [`Image`], it takes into account the "cast_storage" methods @@ -2110,7 +2166,7 @@ def cast_array_to_feature( @_wrap_for_chunked_arrays -def embed_array_storage(array: pa.Array, feature: "FeatureType"): +def embed_array_storage(array: pa.Array, feature: FeatureType) -> pa.Array: """Embed data into an arrays's storage. For custom features like Audio or Image, it takes into account the "embed_storage" methods they define to embed external data (e.g. an image file) into an array. @@ -2177,18 +2233,18 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"): class CastError(ValueError): """When it's not possible to cast an Arrow table to a specific schema or set of features""" - def __init__(self, *args, table_column_names: list[str], requested_column_names: list[str]) -> None: + def __init__(self, *args: Any, table_column_names: list[str], requested_column_names: list[str]) -> None: super().__init__(*args) self.table_column_names = table_column_names self.requested_column_names = requested_column_names - def __reduce__(self): + def __reduce__(self) -> tuple[partial[CastError], tuple[()]]: # Fix unpickling: TypeError: __init__() missing 2 required keyword-only arguments: 'table_column_names' and 'requested_column_names' return partial( CastError, table_column_names=self.table_column_names, requested_column_names=self.requested_column_names ), () - def details(self): + def details(self) -> str: new_columns = set(self.table_column_names) - set(self.requested_column_names) missing_columns = set(self.requested_column_names) - set(self.table_column_names) if new_columns and missing_columns: @@ -2199,7 +2255,7 @@ def details(self): return f"there are {len(missing_columns)} missing columns ({_short_str(missing_columns)})" -def cast_table_to_features(table: pa.Table, features: "Features"): +def cast_table_to_features(table: pa.Table, features: Features) -> pa.Table: """Cast a table to the arrow schema that corresponds to the requested features. 
Args: @@ -2221,7 +2277,7 @@ def cast_table_to_features(table: pa.Table, features: "Features"): return pa.Table.from_arrays(arrays, schema=features.arrow_schema) -def cast_table_to_schema(table: pa.Table, schema: pa.Schema): +def cast_table_to_schema(table: pa.Table, schema: pa.Schema) -> pa.Table: """Cast a table to the arrow schema. Different from `cast_table_to_features`, this method can preserve nullability. Args: @@ -2253,7 +2309,7 @@ def cast_table_to_schema(table: pa.Table, schema: pa.Schema): return pa.Table.from_arrays(arrays, schema=schema) -def embed_table_storage(table: pa.Table): +def embed_table_storage(table: pa.Table) -> pa.Table: """Embed external data into a table's storage. @@ -2275,7 +2331,7 @@ def embed_table_storage(table: pa.Table): return pa.Table.from_arrays(arrays, schema=features.arrow_schema) -def table_cast(table: pa.Table, schema: pa.Schema): +def table_cast(table: pa.Table, schema: pa.Schema) -> pa.Table: """Improved version of `pa.Table.cast`. It supports casting to feature types stored in the schema metadata. @@ -2297,7 +2353,7 @@ def table_cast(table: pa.Table, schema: pa.Schema): return table -def table_flatten(table: pa.Table): +def table_flatten(table: pa.Table) -> pa.Table: """Improved version of `pa.Table.flatten`. It behaves as `pa.Table.flatten` in a sense it does 1-step flatten of the columns with a struct type into one column per struct field, @@ -2339,7 +2395,9 @@ def table_flatten(table: pa.Table): return flat_table.replace_schema_metadata(flat_features.arrow_schema.metadata) -def table_visitor(table: pa.Table, function: Callable[[pa.Array], None]): +def table_visitor( + table: pa.Table, function: Callable[[pa.Array, FeatureType], None] +) -> None: """Visit all arrays in a table and apply a function to them. Args: @@ -2352,7 +2410,7 @@ def table_visitor(table: pa.Table, function: Callable[[pa.Array], None]): features = Features.from_arrow_schema(table.schema) - def _visit(array, feature): + def _visit(array: pa.Array, feature: FeatureType) -> None: if isinstance(array, pa.ChunkedArray): for chunk in array.chunks: _visit(chunk, feature) @@ -2378,7 +2436,7 @@ def _visit(array, feature): _visit(table[name], feature) -def table_iter(table: Table, batch_size: int, drop_last_batch=False) -> Iterator[pa.Table]: +def table_iter(table: Table, batch_size: int, drop_last_batch: bool = False) -> Generator[pa.Table]: """Iterate over sub-tables of size `batch_size`. Args: diff --git a/src/datasets/utils/_filelock.py b/src/datasets/utils/_filelock.py index 803ce77f5a4..67601b118f1 100644 --- a/src/datasets/utils/_filelock.py +++ b/src/datasets/utils/_filelock.py @@ -15,6 +15,7 @@ """Utilities to handle file locking in `datasets`.""" import os +from typing import Any, TypeVar from filelock import FileLock as FileLock_ from filelock import UnixFileLock @@ -22,6 +23,9 @@ from packaging import version +T = TypeVar("T", bound="FileLock") + + class FileLock(FileLock_): """ A `filelock.FileLock` initializer that handles long paths. @@ -30,7 +34,7 @@ class FileLock(FileLock_): MAX_FILENAME_LENGTH = 255 - def __init__(self, lock_file, *args, **kwargs): + def __init__(self, lock_file: str, *args: Any, **kwargs: Any) -> None: # The "mode" argument is required if we want to use the current umask in filelock >= 3.10 # In previous previous it was already using the current umask. 
if "mode" not in kwargs and version.parse(_filelock_version) >= version.parse("3.10.0"): @@ -41,7 +45,7 @@ def __init__(self, lock_file, *args, **kwargs): super().__init__(lock_file, *args, **kwargs) @classmethod - def hash_filename_if_too_long(cls, path: str) -> str: + def hash_filename_if_too_long(cls: type[T], path: str) -> str: path = os.path.abspath(os.path.expanduser(path)) filename = os.path.basename(path) max_filename_length = cls.MAX_FILENAME_LENGTH diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index d94c65fa18c..4ef84769fd6 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -15,6 +15,8 @@ # Lint as: python3 """Some python utils function and classes.""" +from __future__ import annotations + import copy import functools import itertools @@ -31,12 +33,27 @@ from pathlib import Path from queue import Empty from shutil import disk_usage -from typing import Any, Callable, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Generic, + Iterable, + Iterator, + Literal, + Optional, + Set, + TypeGuard, + TypeVar, + Union, +) from urllib.parse import urlparse import multiprocess import multiprocess.pool import numpy as np +from filelock import BaseFileLock from tqdm.auto import tqdm from .. import config @@ -53,13 +70,14 @@ try: # pragma: no branch - from typing import Final - import typing_extensions as _typing_extensions - from typing_extensions import Literal + from typing_extensions import Final, Literal except ImportError: _typing_extensions = Literal = Final = None +if TYPE_CHECKING: + from _typeshed import ConvertibleToFloat, DataclassInstance + logger = logging.get_logger(__name__) @@ -72,7 +90,7 @@ memoize = functools.lru_cache -def size_str(size_in_bytes): +def size_str(size_in_bytes: ConvertibleToFloat) -> str: """Returns a human readable size string. If size_in_bytes is None, then returns "Unknown size". @@ -140,7 +158,7 @@ def convert_file_size_to_int(size: Union[int, str]) -> int: raise ValueError(f"`size={size}` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") -def glob_pattern_to_regex(pattern): +def glob_pattern_to_regex(pattern: str) -> str: # partially taken from fsspec: # https://github.com/fsspec/filesystem_spec/blob/697d0f8133d8a5fbc3926e4761d7ecd51337ce50/fsspec/asyn.py#L735 return ( @@ -192,15 +210,14 @@ def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]: return _dict -def asdict(obj): +def asdict(obj: object) -> dict[str, Any]: """Convert an object to its dictionary representation recursively. """ # Implementation based on https://docs.python.org/3/library/dataclasses.html#dataclasses.asdict - - def _is_dataclass_instance(obj): + def _is_dataclass_instance(obj: object) -> TypeGuard[DataclassInstance]: # https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass return is_dataclass(obj) and not isinstance(obj, type) @@ -232,7 +249,7 @@ def _asdict_inner(obj): @contextmanager -def temporary_assignment(obj, attr, value): +def temporary_assignment(obj: object, attr: str, value: object) -> Generator[Any]: """Temporarily assign obj.attr to value.""" original = getattr(obj, attr, None) setattr(obj, attr, value) @@ -243,7 +260,7 @@ def temporary_assignment(obj, attr, value): @contextmanager -def temp_seed(seed: int, set_pytorch=False, set_tensorflow=False): +def temp_seed(seed: int, set_pytorch: bool = False, set_tensorflow: bool = False) -> Generator[Any]: """Temporarily set the random seed. 
This works for python numpy, pytorch and tensorflow."""
     np_state = np.random.get_state()
     np.random.seed(seed)
@@ -296,7 +313,7 @@ def temp_seed(seed: int, set_pytorch=False, set_tensorflow=False):
             delattr(tf_context, "_rng")
 
 
-def unique_values(values):
+def unique_values(values: Iterable) -> Generator[Any]:
     """Iterate over iterable and return only unique values in order."""
     seen = set()
     for value in values:
@@ -305,16 +322,20 @@ def unique_values(values):
             yield value
 
 
-def no_op_if_value_is_null(func):
+T = TypeVar("T")
+R = TypeVar("R")
+
+
+def no_op_if_value_is_null(func: Callable[[T], R]) -> Callable[[Optional[T]], Optional[R]]:
     """If the value is None, return None, else call `func`."""
 
-    def wrapper(value):
+    def wrapper(value: Optional[T]) -> Optional[R]:
         return func(value) if value is not None else None
 
     return wrapper
 
 
-def first_non_null_value(iterable):
+def first_non_null_value(iterable: Iterable) -> tuple[int, Any]:
     """Return the index and the value of the first non-null value in the iterable. If all values are None, return -1 as index."""
     for i, value in enumerate(iterable):
         if value is not None:
@@ -322,7 +343,7 @@ def first_non_null_value(iterable):
     return -1, None
 
 
-def first_non_null_non_empty_value(iterable):
+def first_non_null_non_empty_value(iterable: Iterable) -> tuple[int, Any]:
     """Return the index and the value of the first non-null non-empty value in the iterable. If all values are None or empty, return -1 as index."""
     for i, value in enumerate(iterable):
         if value is not None and not (isinstance(value, (dict, list)) and len(value) == 0):
@@ -330,14 +359,18 @@ def first_non_null_non_empty_value(iterable):
     return -1, None
 
 
-def zip_dict(*dicts):
+def zip_dict(*dicts: dict[Any, Any]) -> Iterator:
     """Iterate over items of dictionaries grouped by their keys."""
     for key in unique_values(itertools.chain(*dicts)):  # set merge all keys
         # Will raise KeyError if the dict don't have the same keys
         yield key, tuple(d[key] for d in dicts)
 
 
-class NonMutableDict(dict):
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class NonMutableDict(dict[KT, VT], Generic[KT, VT]):
     """Dict where keys can only be added but not modified.
 
     Will raise an error if the user try to overwrite one key. The error message
@@ -345,7 +378,7 @@ class NonMutableDict(dict):
     the overwritten key.
""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._error_msg = kwargs.pop( "error_msg", "Try to overwrite existing key: {key}", @@ -354,25 +387,26 @@ def __init__(self, *args, **kwargs): raise ValueError("NonMutableDict cannot be initialized with kwargs.") super().__init__(*args, **kwargs) - def __setitem__(self, key, value): + def __setitem__(self, key: KT, value: VT) -> None: if key in self: raise ValueError(self._error_msg.format(key=key)) return super().__setitem__(key, value) - def update(self, other): + def update(self, other) -> None: if any(k in self for k in other): raise ValueError(self._error_msg.format(key=set(self) & set(other))) return super().update(other) -class classproperty(property): # pylint: disable=invalid-name +class classproperty(property, Generic[T]): # pylint: disable=invalid-name """Descriptor to be used as decorator for @classmethods.""" - def __get__(self, obj, objtype=None): - return self.fget.__get__(None, objtype)() + def __get__(self, obj: Optional[T], objtype: Optional[type[T]] = None) -> Optional[T]: + if self.fget is not None: + return self.fget.__get__(None, objtype)() -def _single_map_nested(args): +def _single_map_nested(args: Iterable) -> Union[dict, list, tuple, np.ndarray]: """Apply a function recursively to each element of a nested data struct.""" function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc = args @@ -414,6 +448,7 @@ def _single_map_nested(args): return tuple(mapped) else: return np.array(mapped) +from typing import Sequence def map_nested( @@ -426,8 +461,8 @@ def map_nested( num_proc: Optional[int] = None, parallel_min_length: int = 2, batched: bool = False, - batch_size: Optional[int] = 1000, - types: Optional[tuple] = None, + batch_size: int = 1000, + types: Optional[Sequence[type]] = None, disable_tqdm: bool = True, desc: Optional[str] = None, ) -> Any: @@ -478,7 +513,7 @@ def map_nested( `Any` """ if types is None: - types = [] + types: list[None] = [] if not dict_only: if map_list: types.append(list) @@ -486,7 +521,7 @@ def map_nested( types.append(tuple) if map_numpy: types.append(np.ndarray) - types = tuple(types) + types: tuple = tuple(types) # Singleton if not isinstance(data_struct, dict) and not isinstance(data_struct, types): @@ -554,10 +589,10 @@ def map_nested( class NestedDataStructure: - def __init__(self, data=None): + def __init__(self, data: Optional[Iterable] = None) -> None: self.data = data if data is not None else [] - def flatten(self, data=None): + def flatten(self, data: Optional[Iterable] = None) -> list: data = data if data is not None else self.data if isinstance(data, dict): return self.flatten(list(data.values())) @@ -567,7 +602,7 @@ def flatten(self, data=None): return [data] -def has_sufficient_disk_space(needed_bytes, directory="."): +def has_sufficient_disk_space(needed_bytes: int, directory: str = ".") -> bool: try: free_bytes = disk_usage(os.path.abspath(directory)).free except OSError: @@ -594,7 +629,7 @@ def _convert_github_url(url_path: str) -> tuple[str, Optional[str]]: return url_path, sub_directory -def lock_importable_file(importable_local_file: str) -> FileLock: +def lock_importable_file(importable_local_file: Union[str, os.PathLike]) -> BaseFileLock: # Check the directory with a unique name in our dataset folder # path is: ./datasets/dataset_name/hash_from_code/script.py # we use a hash as subdirectory_name to be able to have multiple versions of a dataset processing file together @@ -603,7 +638,9 @@ def 
lock_importable_file(importable_local_file: str) -> FileLock:
     return FileLock(lock_path)
 
 
-def get_imports(file_path: str) -> tuple[str, str, str, str]:
+def get_imports(file_path: str) -> list[tuple[Literal["external", "internal", "library"], str, str, Optional[str]]]:
     """Find whether we should import or clone additional files for a given processing script.
         And list the import.
 
@@ -628,7 +665,7 @@ def get_imports(file_path: str) -> tuple[str, str, str, str]:
         lines.extend(f.readlines())
     logger.debug(f"Checking {file_path} for additional imports.")
 
-    imports: list[tuple[str, str, str, Optional[str]]] = []
+    imports: list[tuple[Literal["external", "internal", "library"], str, str, Optional[str]]] = []
     is_in_docstring = False
     for line in lines:
         docstr_start_match = re.findall(r'[\s\S]*?"""[\s\S]*?', line)
@@ -676,7 +713,7 @@ def get_imports(file_path: str) -> tuple[str, str, str, str]:
     return imports
 
 
-def copyfunc(func):
+def copyfunc(func: Callable) -> Callable:
     result = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__)
     result.__kwdefaults__ = func.__kwdefaults__
     return result
@@ -729,9 +766,6 @@ def iflatmap_unordered(
         [async_result.get(timeout=0.05) for async_result in async_results]
 
 
-T = TypeVar("T")
-
-
 def iter_batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]:
     if n < 1:
         raise ValueError(f"Invalid batch size {n}")
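The `_wrap_for_chunked_arrays` helper in `table.py` above pairs `ParamSpec` with `Concatenate` so the per-chunk function keeps its remaining parameters type-checked through the wrapper. A self-contained sketch of that pattern with hypothetical names (Python 3.10+, pyarrow installed):

from typing import Callable, Concatenate, ParamSpec, Union

import pyarrow as pa
import pyarrow.compute as pc

P = ParamSpec("P")


def wrap_for_chunked(
    func: Callable[Concatenate[Union[pa.Array, pa.ChunkedArray], P], pa.Array],
) -> Callable[Concatenate[Union[pa.Array, pa.ChunkedArray], P], Union[pa.Array, pa.ChunkedArray]]:
    """Apply `func` chunk by chunk on ChunkedArrays, and directly on plain Arrays."""

    def wrapper(
        array: Union[pa.Array, pa.ChunkedArray], *args: P.args, **kwargs: P.kwargs
    ) -> Union[pa.Array, pa.ChunkedArray]:
        if isinstance(array, pa.ChunkedArray):
            return pa.chunked_array([func(chunk, *args, **kwargs) for chunk in array.chunks])
        return func(array, *args, **kwargs)

    return wrapper


@wrap_for_chunked
def add_offset(array: pa.Array, offset: int) -> pa.Array:
    return pc.add(array, offset)


chunked = pa.chunked_array([[1, 2], [3]])
print(add_offset(chunked, 10))  # chunk layout is preserved; `offset` stays typed as int via P

`Concatenate` pins the first positional parameter to the array argument, while `P` forwards every other parameter of the decorated function unchanged.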