From 5a3f45ce5981debe389a1e623e7d7d91fe9722ef Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:39:36 +0100 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Javier --- .../flwr_datasets/partitioner/dirichlet_partitioner.py | 10 +++++----- .../partitioner/dirichlet_partitioner_test.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index aad90697f7bc..5f1df71991bb 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -158,7 +158,7 @@ def _initialize_alpha( elif isinstance(alpha, List): if len(alpha) != self._num_partitions: raise ValueError( - "The alpha parameter needs to be of length of equal to the " + "If passing alpha as a List, it needs to be of length of equal to " "num_partitions." ) alpha = np.asarray(alpha) @@ -166,15 +166,15 @@ def _initialize_alpha( # pylint: disable=R1720 if alpha.ndim == 1 and alpha.shape[0] != self._num_partitions: raise ValueError( - "The alpha parameter needs to be of length of equal to" - "the num_partitions." + "If passing alpha as an NDArray, its length needs to be of length " + "equal to num_partitions." ) elif alpha.ndim == 2: alpha = alpha.flatten() if alpha.shape[0] != self._num_partitions: raise ValueError( - "The alpha parameter needs to be of length of equal to " - "the num_partitions." + "If passing alpha as an NDArray, its size needs to be of length" + " equal to num_partitions." ) else: raise ValueError("The given alpha format is not supported.") diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py index d8cfb3cb854e..c123f84effb7 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================== """Test DirichletPartitioner.""" + + # pylint: disable=W0212 import unittest from typing import Tuple, Union