diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 9393d0960d3..26b32c51b62 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -830,6 +830,7 @@ def map( fn_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, desc: Optional[str] = None, + try_original_type: Optional[bool] = True, ) -> "DatasetDict": """ Apply a function to all the examples in the table (individually or in batches) and update the table. @@ -908,6 +909,9 @@ def map( use multiprocessing. desc (`str`, *optional*, defaults to `None`): Meaningful description to be displayed alongside with the progress bar while mapping examples. + try_original_type (`Optional[bool]`, defaults to `True`): + Try to keep the types of the original columns (e.g. int32 -> int32). + Set to False if you want to always infer new types. Example: @@ -956,6 +960,7 @@ def map( fn_kwargs=fn_kwargs, num_proc=num_proc, desc=desc, + try_original_type=try_original_type, ) if with_split: diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index 83f0395718b..72f84071e1c 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -24,21 +24,27 @@ class DatasetDictTest(TestCase): - def _create_dummy_dataset(self, multiple_columns=False): + def _create_dummy_dataset(self, multiple_columns=False, int_to_float=False): if multiple_columns: data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]} dset = Dataset.from_dict(data) + elif int_to_float: + data = { + "text": ["text1", "text2", "text3", "text4"], + "labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]], + } + dset = Dataset.from_dict(data) else: dset = Dataset.from_dict( {"filename": ["my_name-train" + "_" + f"{x:03d}" for x in np.arange(30).tolist()]} ) return dset - def _create_dummy_dataset_dict(self, multiple_columns=False) -> DatasetDict: + def _create_dummy_dataset_dict(self, multiple_columns=False, int_to_float=False) -> DatasetDict: return DatasetDict( { - "train": self._create_dummy_dataset(multiple_columns=multiple_columns), - "test": self._create_dummy_dataset(multiple_columns=multiple_columns), + "train": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float), + "test": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float), } ) @@ -325,6 +331,28 @@ def test_map(self): self.assertListEqual(sorted(mapped_dsets_2["train"].column_names), sorted(["filename", "foo", "bar"])) del dsets, mapped_dsets_1, mapped_dsets_2 + # casting int labels to float labels + with tempfile.TemporaryDirectory() as tmp_dir: + dset_dict = self._create_dummy_dataset_dict(int_to_float=True) + + def _preprocess(examples): + result = {"labels": [list(map(float, labels)) for labels in examples["labels"]]} + return result + + with dset_dict.map( + _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True + ) as dset_test: + for labels in dset_test["test"]["labels"]: + for label in labels: + self.assertIsInstance(label, int) + + with dset_dict.map( + _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False + ) as dset_test: + for labels in dset_test["test"]["labels"]: + for label in labels: + self.assertIsInstance(label, float) + def test_iterable_map(self): dsets = self._create_dummy_iterable_dataset_dict() fn_kwargs = {"n": 3}