From f1a63ec6693c06a593c36971f04bb720477b1b89 Mon Sep 17 00:00:00 2001
From: Yoshitomo Matsubara
Date: Mon, 28 Apr 2025 21:36:31 -0700
Subject: [PATCH 1/3] Add try_original_type to DatasetDict.map

---
 src/datasets/dataset_dict.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index 9393d0960d3..bf94177b5ff 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -830,6 +830,7 @@ def map(
         fn_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         desc: Optional[str] = None,
+        try_original_type: Optional[bool] = True,
     ) -> "DatasetDict":
         """
         Apply a function to all the examples in the table (individually or in batches) and update the table.
@@ -908,6 +909,9 @@ def map(
                 use multiprocessing.
             desc (`str`, *optional*, defaults to `None`):
                 Meaningful description to be displayed alongside with the progress bar while mapping examples.
+            try_original_type (`Optional[bool]`, defaults to `True`):
+                Try to keep the types of the original columns (e.g. int32 -> int32).
+                Set to False if you want to always infer new types.
 
         Example:
 
@@ -956,6 +960,7 @@ def map(
                 fn_kwargs=fn_kwargs,
                 num_proc=num_proc,
                 desc=desc,
+                try_original_type=try_original_type
             )
 
         if with_split:

From 5e6cb8f5fe631d9be98f37b2e060855dfe0ab2b5 Mon Sep 17 00:00:00 2001
From: Yoshitomo Matsubara
Date: Mon, 28 Apr 2025 21:37:08 -0700
Subject: [PATCH 2/3] Add test cases

---
 tests/test_dataset_dict.py | 36 ++++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py
index 83f0395718b..d2d43bf296f 100644
--- a/tests/test_dataset_dict.py
+++ b/tests/test_dataset_dict.py
@@ -24,21 +24,27 @@
 
 
 class DatasetDictTest(TestCase):
-    def _create_dummy_dataset(self, multiple_columns=False):
+    def _create_dummy_dataset(self, multiple_columns=False, int_to_float=False):
         if multiple_columns:
             data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
             dset = Dataset.from_dict(data)
+        elif int_to_float:
+            data = {
+                "text": ["text1", "text2", "text3", "text4"],
+                "labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]],
+            }
+            dset = Dataset.from_dict(data)
         else:
             dset = Dataset.from_dict(
                 {"filename": ["my_name-train" + "_" + f"{x:03d}" for x in np.arange(30).tolist()]}
             )
         return dset
 
-    def _create_dummy_dataset_dict(self, multiple_columns=False) -> DatasetDict:
+    def _create_dummy_dataset_dict(self, multiple_columns=False, int_to_float=False) -> DatasetDict:
         return DatasetDict(
             {
-                "train": self._create_dummy_dataset(multiple_columns=multiple_columns),
-                "test": self._create_dummy_dataset(multiple_columns=multiple_columns),
+                "train": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float),
+                "test": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float),
             }
         )
 
@@ -325,6 +331,28 @@ def test_map(self):
         self.assertListEqual(sorted(mapped_dsets_2["train"].column_names), sorted(["filename", "foo", "bar"]))
         del dsets, mapped_dsets_1, mapped_dsets_2
 
+        # casting int labels to float labels
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dset_dict = self._create_dummy_dataset_dict(int_to_float=True)
+
+            def _preprocess(examples):
+                result = {"labels": [list(map(float, labels)) for labels in examples["labels"]]}
+                return result
+
+            with dset_dict.map(
+                _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True
+            ) as dset_test:
+                for labels in dset_test['test']["labels"]:
+                    for label in labels:
+                        self.assertIsInstance(label, int)
+
+            with dset_dict.map(
+                _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False
+            ) as dset_test:
+                for labels in dset_test['test']["labels"]:
+                    for label in labels:
+                        self.assertIsInstance(label, float)
+
     def test_iterable_map(self):
         dsets = self._create_dummy_iterable_dataset_dict()
         fn_kwargs = {"n": 3}

From 8e53f1c777cc04a26869ed3cabea03f51ea5d6ae Mon Sep 17 00:00:00 2001
From: Yoshitomo Matsubara
Date: Tue, 29 Apr 2025 08:41:03 -0700
Subject: [PATCH 3/3] Apply make style

---
 src/datasets/dataset_dict.py | 2 +-
 tests/test_dataset_dict.py   | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index bf94177b5ff..26b32c51b62 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -960,7 +960,7 @@ def map(
                 fn_kwargs=fn_kwargs,
                 num_proc=num_proc,
                 desc=desc,
-                try_original_type=try_original_type
+                try_original_type=try_original_type,
             )
 
         if with_split:
diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py
index d2d43bf296f..72f84071e1c 100644
--- a/tests/test_dataset_dict.py
+++ b/tests/test_dataset_dict.py
@@ -342,14 +342,14 @@ def _preprocess(examples):
             with dset_dict.map(
                 _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True
             ) as dset_test:
-                for labels in dset_test['test']["labels"]:
+                for labels in dset_test["test"]["labels"]:
                     for label in labels:
                         self.assertIsInstance(label, int)
 
             with dset_dict.map(
                 _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False
            ) as dset_test:
-                for labels in dset_test['test']["labels"]:
+                for labels in dset_test["test"]["labels"]:
                     for label in labels:
                         self.assertIsInstance(label, float)
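
Note for reviewers (not part of the patch series): a minimal usage sketch of the new flag, assuming a datasets build that already contains these three commits. The dictionary keys, the helper name cast_labels_to_float, and the commented feature types are illustrative, not taken from the patches.

from datasets import Dataset, DatasetDict

# Toy DatasetDict with integer label lists, mirroring the new test fixture.
data = {
    "text": ["text1", "text2", "text3", "text4"],
    "labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]],
}
dset_dict = DatasetDict({"train": Dataset.from_dict(data), "test": Dataset.from_dict(data)})


def cast_labels_to_float(examples):
    # The mapped batch returns float labels even though the source column is int.
    return {"labels": [[float(label) for label in labels] for labels in examples["labels"]]}


# Default behaviour (try_original_type=True): the writer tries to keep the
# original integer column type, so the floats are cast back to ints.
kept = dset_dict.map(
    cast_labels_to_float, remove_columns=["labels", "text"], batched=True, try_original_type=True
)
print(kept["test"].features["labels"])  # expected to stay an int64 sequence

# try_original_type=False: the column type is re-inferred from the mapped
# values, so the labels come back as floats.
inferred = dset_dict.map(
    cast_labels_to_float, remove_columns=["labels", "text"], batched=True, try_original_type=False
)
print(inferred["test"].features["labels"])  # expected to become a float64 sequence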