diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py index 1a79721a4c24..0dd115b39cf0 100644 --- a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py @@ -60,7 +60,7 @@ def __init__( # pylint: disable=R0913 self, partition_sizes: Union[List[int], NDArrayInt], partition_by: str, - alpha: Union[float, List[float], NDArrayFloat], + alpha: Union[int, float, List[float], NDArrayFloat], shuffle: bool = True, seed: Optional[int] = 42, ) -> None: @@ -109,7 +109,7 @@ def load_partition(self, node_id: int) -> datasets.Dataset: return self.dataset.select(self._node_id_to_indices[node_id]) def _initialize_alpha_if_needed( - self, alpha: Union[float, List[float], NDArrayFloat] + self, alpha: Union[int, float, List[float], NDArrayFloat] ) -> NDArrayFloat: """Convert alpha to the used format in the code a NDArrayFloat. @@ -130,7 +130,10 @@ def _initialize_alpha_if_needed( if self._initialized_alpha: assert self._alpha is not None return self._alpha - if isinstance(alpha, float): + if isinstance(alpha, int): + assert self._num_unique_classes is not None + alpha = np.array([float(alpha)], dtype=float).repeat(self._num_unique_classes) + elif isinstance(alpha, float): assert self._num_unique_classes is not None alpha = np.array([alpha], dtype=float).repeat(self._num_unique_classes) elif isinstance(alpha, List):