diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 100e9943c530..bb6d46f266e9 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -144,10 +144,10 @@ def test_multiple_partitioners(self) -> None: dataset_test_partition0 = dataset_fds.load_partition(0, self.test_split) dataset = datasets.load_dataset(self.dataset_name) - self.assertEqual( - len(dataset_test_partition0), - len(dataset[self.test_split]) // num_test_partitions, - ) + expected_len = len(dataset[self.test_split]) // num_test_partitions + mod = len(dataset[self.test_split]) % num_test_partitions + expected_len += 1 if 0 < mod else 0 + self.assertEqual(len(dataset_test_partition0), expected_len) def test_no_need_for_split_keyword_if_one_partitioner(self) -> None: """Test if partitions got with and without split args are the same."""