From 3759c0f3b3830c96c29c9dad2388f3e56ca65569 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Wed, 3 Apr 2024 23:09:53 +0200 Subject: [PATCH] Fix divide_dataset in Federated Datasets (#3192) --- datasets/flwr_datasets/utils.py | 6 ++++-- datasets/flwr_datasets/utils_test.py | 23 +++++++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index a6e4fa8d0f0b..346d897ccdd6 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -133,6 +133,7 @@ def divide_dataset( >>> train_test = divide_dataset(dataset=partition, division=division) >>> train, test = train_test["train"], train_test["test"] """ + _check_division_config_correctness(division) dataset_length = len(dataset) ranges = _create_division_indices_ranges(dataset_length, division) if isinstance(division, (list, tuple)): @@ -162,7 +163,7 @@ def _create_division_indices_ranges( for fraction in division: end_idx += int(dataset_length * fraction) ranges.append(range(start_idx, end_idx)) - start_idx += end_idx + start_idx = end_idx elif isinstance(division, dict): ranges = [] start_idx = 0 @@ -170,7 +171,7 @@ def _create_division_indices_ranges( for fraction in division.values(): end_idx += int(dataset_length * fraction) ranges.append(range(start_idx, end_idx)) - start_idx += end_idx + start_idx = end_idx else: TypeError( f"The type of the `division` should be dict, " @@ -274,6 +275,7 @@ def concatenate_divisions( concatenated_divisions : Dataset A dataset created as concatenation of the divisions from all partitions. """ + _check_division_config_correctness(partition_division) divisions = [] zero_len_divisions = 0 for partition_id in range(partitioner.num_partitions): diff --git a/datasets/flwr_datasets/utils_test.py b/datasets/flwr_datasets/utils_test.py index 3bf5afddf978..4add9f88eeb5 100644 --- a/datasets/flwr_datasets/utils_test.py +++ b/datasets/flwr_datasets/utils_test.py @@ -31,13 +31,32 @@ "expected_concatenation_size", ), [ + # Create 1 division + ((1.0,), [40], 0, 40), + ({"train": 1.0}, [40], "train", 40), + # Create 2 divisions ((0.8, 0.2), [32, 8], 1, 8), - ([0.8, 0.2], [32, 8], 1, 8), ({"train": 0.8, "test": 0.2}, [32, 8], "test", 8), + # Create 3 divisions + ([0.6, 0.2, 0.2], [24, 8, 8], 1, 8), + ({"train": 0.6, "valid": 0.2, "test": 0.2}, [24, 8, 8], "test", 8), + # Create 4 divisions + ([0.4, 0.2, 0.2, 0.2], [16, 8, 8, 8], 1, 8), + ({"0": 0.4, "1": 0.2, "2": 0.2, "3": 0.2}, [16, 8, 8, 8], "1", 8), # Not full dataset + # Create 1 division + ([0.8], [32], 0, 32), + ({"train": 0.8}, [32], "train", 32), + # Create 2 divisions ([0.2, 0.1], [8, 4], 1, 4), ((0.2, 0.1), [8, 4], 0, 8), ({"train": 0.2, "test": 0.1}, [8, 4], "test", 4), + # Create 3 divisions + ([0.6, 0.2, 0.1], [24, 8, 4], 2, 4), + ({"train": 0.6, "valid": 0.2, "test": 0.1}, [24, 8, 4], "test", 4), + # Create 4 divisions + ([0.4, 0.2, 0.1, 0.2], [16, 8, 4, 8], 2, 4), + ({"0": 0.4, "1": 0.2, "2": 0.1, "3": 0.2}, [16, 8, 4, 8], "2", 4), ], ) class UtilsTests(unittest.TestCase): @@ -60,7 +79,7 @@ def test_correct_sizes(self) -> None: else: lengths = [len(split) for split in divided_dataset.values()] - self.assertEqual(lengths, self.sizes) + self.assertEqual(self.sizes, lengths) def test_correct_return_types(self) -> None: """Test correct types of the divided dataset based on the config."""