diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index d6f6b6261833..7c86570fe487 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -229,6 +229,14 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 math.ceil(len(self.dataset) / self._shard_size) ) num_usable_shards_in_dataset = self._num_shards_used + if num_usable_shards_in_dataset < self._num_partitions: + raise ValueError( + "Based on the given arguments the creation of the partitions " + "is impossible. The implied number of partitions that can be " + "used is lower than the number of requested partitions " + "resulting in empty partitions. Please decrease the size of " + "shards: `shard_size`." + ) else: raise ValueError( "The keep_incomplete_shards need to be specified " @@ -251,6 +259,13 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 "keep_incomplete_shards is not correct." ) + if num_usable_shards_in_dataset < self._num_partitions: + raise ValueError( + "The specified configuration results in empty partitions because the " + "number of usable shards is smaller than the number of partitions. " + "Try decreasing the shard size or the number of partitions." + ) + indices_on_which_to_split_shards = np.cumsum( num_shards_per_node_array, dtype=int )