Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Javier <[email protected]>
  • Loading branch information
adam-narozniak and jafermarq authored Feb 20, 2024
1 parent ae4031d commit 5a3f45c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 5 additions & 5 deletions datasets/flwr_datasets/partitioner/dirichlet_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,23 @@ def _initialize_alpha(
elif isinstance(alpha, List):
if len(alpha) != self._num_partitions:
raise ValueError(
"The alpha parameter needs to be of length of equal to the "
"If passing alpha as a List, it needs to be of length of equal to "
"num_partitions."
)
alpha = np.asarray(alpha)
elif isinstance(alpha, np.ndarray):
# pylint: disable=R1720
if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions:
raise ValueError(
"The alpha parameter needs to be of length of equal to"
"the num_partitions."
"If passing alpha as an NDArray, its length needs to be of length "
"equal to num_partitions."
)
elif alpha.ndim == 2:
alpha = alpha.flatten()
if alpha.shape[0] != self._num_partitions:
raise ValueError(
"The alpha parameter needs to be of length of equal to "
"the num_partitions."
"If passing alpha as an NDArray, its size needs to be of length"
" equal to num_partitions."
)
else:
raise ValueError("The given alpha format is not supported.")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Test DirichletPartitioner."""


# pylint: disable=W0212
import unittest
from typing import Tuple, Union
Expand Down

0 comments on commit 5a3f45c

Please sign in to comment.