diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst
index 1107f6798b23..ff3dbb605846 100644
--- a/doc/source/how-to-install-flower.rst
+++ b/doc/source/how-to-install-flower.rst
@@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth
Install stable release
----------------------
+Using pip
+~~~~~~~~~
+
Stable releases are available on `PyPI `_::
python -m pip install flwr
@@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed
python -m pip install flwr[simulation]
+Using conda (or mamba)
+~~~~~~~~~~~~~~~~~~~~~~
+
+Flower can also be installed from the ``conda-forge`` channel.
+
+If you have not added ``conda-forge`` to your channels, you will first need to run the following::
+
+ conda config --add channels conda-forge
+ conda config --set channel_priority strict
+
+Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``::
+
+ conda install flwr
+
+or with ``mamba``::
+
+ mamba install flwr
+
+
Verify installation
-------------------
diff --git a/pyproject.toml b/pyproject.toml
index 0616ffdbeffd..24d20c7ced40 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,7 @@ isort = "==5.12.0"
black = { version = "==23.10.1", extras = ["jupyter"] }
docformatter = "==1.7.5"
mypy = "==1.6.1"
-pylint = "==2.13.9"
+pylint = "==3.0.3"
flake8 = "==5.0.4"
pytest = "==7.4.3"
pytest-cov = "==4.1.0"
@@ -137,7 +137,7 @@ line-length = 88
target-version = ["py38", "py39", "py310", "py311"]
[tool.pylint."MESSAGES CONTROL"]
-disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias"
+disable = "duplicate-code,too-few-public-methods,useless-import-alias"
[tool.pytest.ini_options]
minversion = "6.2"
@@ -184,7 +184,7 @@ target-version = "py38"
line-length = 88
select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
-ignore = ["B024", "B027"]
+ignore = ["B024", "B027", "D205", "D209"]
exclude = [
".bzr",
".direnv",
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index a5b285fbb7fb..91fa5468ae75 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -138,10 +138,12 @@ def _check_actionable_client(
client: Optional[Client], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
- raise Exception("Both `client_fn` and `client` are `None`, but one is required")
+ raise ValueError(
+ "Both `client_fn` and `client` are `None`, but one is required"
+ )
if client_fn is not None and client is not None:
- raise Exception(
+ raise ValueError(
"Both `client_fn` and `client` are provided, but only one is allowed"
)
@@ -150,6 +152,7 @@ def _check_actionable_client(
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
+# pylint: disable=too-many-arguments
def start_client(
*,
server_address: str,
@@ -299,7 +302,7 @@ def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> Client:
if client is None: # Added this to keep mypy happy
- raise Exception(
+ raise ValueError(
"Both `client_fn` and `client` are `None`, but one is required"
)
return client # Always return the same instance
diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py
index 7ef6410debad..56d6308a0fe2 100644
--- a/src/py/flwr/client/app_test.py
+++ b/src/py/flwr/client/app_test.py
@@ -41,19 +41,19 @@ class PlainClient(Client):
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def fit(self, ins: FitIns) -> FitRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
class NeedsWrappingClient(NumPyClient):
@@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient):
def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def get_parameters(self, config: Config) -> NDArrays:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def fit(
self, parameters: NDArrays, config: Config
) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def evaluate(
self, parameters: NDArrays, config: Config
) -> Tuple[float, int, Dict[str, Scalar]]:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def test_to_client_with_client() -> None:
diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py
index 41b4d676df43..c39b89b31da3 100644
--- a/src/py/flwr/client/dpfedavg_numpy_client.py
+++ b/src/py/flwr/client/dpfedavg_numpy_client.py
@@ -117,16 +117,16 @@ def fit(
update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]
if "dpfedavg_clip_norm" not in config:
- raise Exception("Clipping threshold not supplied by the server.")
+ raise KeyError("Clipping threshold not supplied by the server.")
if not isinstance(config["dpfedavg_clip_norm"], float):
- raise Exception("Clipping threshold should be a floating point value.")
+ raise TypeError("Clipping threshold should be a floating point value.")
# Clipping
update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])
if "dpfedavg_noise_stddev" in config:
if not isinstance(config["dpfedavg_noise_stddev"], float):
- raise Exception(
+ raise TypeError(
"Scale of noise to be added should be a floating point value."
)
# Noising
@@ -138,7 +138,7 @@ def fit(
# Calculating value of norm indicator bit, required for adaptive clipping
if "dpfedavg_adaptive_clip_enabled" in config:
if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
- raise Exception(
+ raise TypeError(
"dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
)
metrics["dpfedavg_norm_bit"] = not clipped
diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py
index 13b1948eec07..3599e1dfb254 100644
--- a/src/py/flwr/client/message_handler/task_handler.py
+++ b/src/py/flwr/client/message_handler/task_handler.py
@@ -80,8 +80,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}
# Check if certain fields are already initialized
- # pylint: disable-next=too-many-boolean-expressions
- if (
+ if ( # pylint: disable-next=too-many-boolean-expressions
"task_id" in initialized_fields_in_task_res
or "group_id" in initialized_fields_in_task_res
or "run_id" in initialized_fields_in_task_res
diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py
index 2312741f5af6..d67fb90512d4 100644
--- a/src/py/flwr/client/numpy_client.py
+++ b/src/py/flwr/client/numpy_client.py
@@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
# Return FitRes
parameters_prime, num_examples, metrics = results
@@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
# Return EvaluateRes
loss, num_examples, metrics = results
diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py
index d22b246dbd61..87b06dd0be4e 100644
--- a/src/py/flwr/client/rest_client/connection.py
+++ b/src/py/flwr/client/rest_client/connection.py
@@ -143,6 +143,7 @@ def create_node() -> None:
},
data=create_node_req_bytes,
verify=verify,
+ timeout=None,
)
# Check status code and headers
@@ -185,6 +186,7 @@ def delete_node() -> None:
},
data=delete_node_req_req_bytes,
verify=verify,
+ timeout=None,
)
# Check status code and headers
@@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]:
},
data=pull_task_ins_req_bytes,
verify=verify,
+ timeout=None,
)
# Check status code and headers
@@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None:
},
data=push_task_res_request_bytes,
verify=verify,
+ timeout=None,
)
state[KEY_TASK_INS] = None
diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py
index efbb00a9d916..4b74c1ace3de 100644
--- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py
+++ b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py
@@ -333,7 +333,7 @@ def _share_keys(
# Check if the size is larger than threshold
if len(state.public_keys_dict) < state.threshold:
- raise Exception("Available neighbours number smaller than threshold")
+ raise ValueError("Available neighbours number smaller than threshold")
# Check if all public keys are unique
pk_list: List[bytes] = []
@@ -341,14 +341,14 @@ def _share_keys(
pk_list.append(pk1)
pk_list.append(pk2)
if len(set(pk_list)) != len(pk_list):
- raise Exception("Some public keys are identical")
+ raise ValueError("Some public keys are identical")
# Check if public keys of this client are correct in the dictionary
if (
state.public_keys_dict[state.sid][0] != state.pk1
or state.public_keys_dict[state.sid][1] != state.pk2
):
- raise Exception(
+ raise ValueError(
"Own public keys are displayed in dict incorrectly, should not happen!"
)
@@ -393,7 +393,7 @@ def _collect_masked_input(
ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
if len(ciphertexts) + 1 < state.threshold:
- raise Exception("Not enough available neighbour clients.")
+ raise ValueError("Not enough available neighbour clients.")
# Decrypt ciphertexts, verify their sources, and store shares.
for src, ciphertext in zip(srcs, ciphertexts):
@@ -409,7 +409,7 @@ def _collect_masked_input(
f"from {actual_src} instead of {src}."
)
if dst != state.sid:
- ValueError(
+ raise ValueError(
f"Client {state.sid}: received an encrypted message"
f"for Client {dst} from Client {src}."
)
@@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
# Send private mask seed share for every avaliable client (including itclient)
# Send first private key share for building pairwise mask for every dropped client
if len(active_sids) < state.threshold:
- raise Exception("Available neighbours number smaller than threshold")
+ raise ValueError("Available neighbours number smaller than threshold")
sids, shares = [], []
sids += active_sids
diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py
new file mode 100644
index 000000000000..3d40c0488baa
--- /dev/null
+++ b/src/py/flwr/common/parametersrecord.py
@@ -0,0 +1,110 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ParametersRecord and Array."""
+
+
+from dataclasses import dataclass, field
+from typing import List, Optional, OrderedDict
+
+
+@dataclass
+class Array:
+ """Array type.
+
+ A dataclass containing serialized data from an array-like or tensor-like object
+ along with some metadata about it.
+
+ Parameters
+ ----------
+ dtype : str
+ A string representing the data type of the serialised object (e.g. `np.float32`)
+
+ shape : List[int]
+ A list representing the shape of the unserialized array-like object. This is
+ used to deserialize the data (depending on the serialization method) or simply
+ as a metadata field.
+
+ stype : str
+ A string indicating the type of serialisation mechanism used to generate the
+ bytes in `data` from an array-like or tensor-like object.
+
+ data: bytes
+ A buffer of bytes containing the data.
+ """
+
+ dtype: str
+ shape: List[int]
+ stype: str
+ data: bytes
+
+
+@dataclass
+class ParametersRecord:
+ """Parameters record.
+
+ A dataclass storing named Arrays in order. This means that it holds entries as an
+ OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
+ PyTorch's state_dict, but holding serialised tensors instead.
+ """
+
+ keep_input: bool
+ data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
+
+ def __init__(
+ self,
+ array_dict: Optional[OrderedDict[str, Array]] = None,
+ keep_input: bool = False,
+ ) -> None:
+ """Construct a ParametersRecord object.
+
+ Parameters
+ ----------
+ array_dict : Optional[OrderedDict[str, Array]]
+ A dictionary that stores serialized array-like or tensor-like objects.
+ keep_input : bool (default: False)
+ A boolean indicating whether parameters should be deleted from the input
+ dictionary immediately after adding them to the record. If False, the
+ dictionary passed to `set_parameters()` will be empty once exiting from that
+ function. This is the desired behaviour when working with very large
+ models/tensors/arrays. However, if you plan to continue working with your
+ parameters after adding it to the record, set this flag to True. When set
+ to True, the data is duplicated in memory.
+ """
+ self.keep_input = keep_input
+ self.data = OrderedDict()
+ if array_dict:
+ self.set_parameters(array_dict)
+
+ def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
+ """Add parameters to record.
+
+ Parameters
+ ----------
+ array_dict : OrderedDict[str, Array]
+ A dictionary that stores serialized array-like or tensor-like objects.
+ """
+ if any(not isinstance(k, str) for k in array_dict.keys()):
+ raise TypeError(f"Not all keys are of valid type. Expected {str}")
+ if any(not isinstance(v, Array) for v in array_dict.values()):
+ raise TypeError(f"Not all values are of valid type. Expected {Array}")
+
+ if self.keep_input:
+ # Copy
+ self.data = OrderedDict(array_dict)
+ else:
+ # Add entries to dataclass without duplicating memory
+ for key in list(array_dict.keys()):
+ self.data[key] = array_dict[key]
+ del array_dict[key]
diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py
index 0088b7397a6d..dc723a2cea86 100644
--- a/src/py/flwr/common/recordset.py
+++ b/src/py/flwr/common/recordset.py
@@ -14,13 +14,10 @@
# ==============================================================================
"""RecordSet."""
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Dict
-
-@dataclass
-class ParametersRecord:
- """Parameters record."""
+from .parametersrecord import ParametersRecord
@dataclass
@@ -37,9 +34,9 @@ class ConfigsRecord:
class RecordSet:
"""Definition of RecordSet."""
- parameters: Dict[str, ParametersRecord] = {}
- metrics: Dict[str, MetricsRecord] = {}
- configs: Dict[str, ConfigsRecord] = {}
+ parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
+ metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
+ configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
def set_parameters(self, name: str, record: ParametersRecord) -> None:
"""Add a ParametersRecord."""
diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py
new file mode 100644
index 000000000000..90c06dcdb109
--- /dev/null
+++ b/src/py/flwr/common/recordset_test.py
@@ -0,0 +1,147 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RecordSet tests."""
+
+
+from typing import Callable, List, OrderedDict, Type, Union
+
+import numpy as np
+import pytest
+
+from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
+from .parametersrecord import Array, ParametersRecord
+from .recordset_utils import (
+ parameters_to_parametersrecord,
+ parametersrecord_to_parameters,
+)
+from .typing import NDArray, NDArrays, Parameters
+
+
+def get_ndarrays() -> NDArrays:
+ """Return list of NumPy arrays."""
+ arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]])
+ arr2 = np.eye(2, 7, 3)
+
+ return [arr1, arr2]
+
+
+def ndarray_to_array(ndarray: NDArray) -> Array:
+ """Represent NumPy ndarray as Array."""
+ return Array(
+ data=ndarray.tobytes(),
+ dtype=str(ndarray.dtype),
+ stype="numpy.ndarray.tobytes",
+ shape=list(ndarray.shape),
+ )
+
+
+def test_ndarray_to_array() -> None:
+ """Test creation of Array object from NumPy ndarray."""
+ shape = (2, 7, 9)
+ arr = np.eye(*shape)
+
+ array = ndarray_to_array(arr)
+
+ arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)
+
+ assert np.array_equal(arr, arr_)
+
+
+def test_parameters_to_array_and_back() -> None:
+ """Test conversion between legacy Parameters and Array."""
+ ndarrays = get_ndarrays()
+
+ # Array represents a single array, unlike Paramters, which represent a
+ # list of arrays
+ ndarray = ndarrays[0]
+
+ parameters = ndarrays_to_parameters([ndarray])
+
+ array = Array(
+ data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[]
+ )
+
+ parameters = Parameters(tensors=[array.data], tensor_type=array.stype)
+
+ ndarray_ = parameters_to_ndarrays(parameters=parameters)[0]
+
+ assert np.array_equal(ndarray, ndarray_)
+
+
+def test_parameters_to_parametersrecord_and_back() -> None:
+ """Test conversion between legacy Parameters and ParametersRecords."""
+ ndarrays = get_ndarrays()
+
+ parameters = ndarrays_to_parameters(ndarrays)
+
+ params_record = parameters_to_parametersrecord(parameters=parameters)
+
+ parameters_ = parametersrecord_to_parameters(params_record)
+
+ ndarrays_ = parameters_to_ndarrays(parameters=parameters_)
+
+ for arr, arr_ in zip(ndarrays, ndarrays_):
+ assert np.array_equal(arr, arr_)
+
+
+def test_set_parameters_while_keeping_intputs() -> None:
+ """Tests keep_input functionality in ParametersRecord."""
+ # Adding parameters to a record that doesn't erase entries in the input `array_dict`
+ p_record = ParametersRecord(keep_input=True)
+ array_dict = OrderedDict(
+ {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
+ )
+ p_record.set_parameters(array_dict)
+
+ # Creating a second parametersrecord passing the same `array_dict` (not erased)
+ p_record_2 = ParametersRecord(array_dict)
+ assert p_record.data == p_record_2.data
+
+ # Now it should be empty (the second ParametersRecord wasn't flagged to keep it)
+ assert len(array_dict) == 0
+
+
+def test_set_parameters_with_correct_types() -> None:
+ """Test adding dictionary of Arrays to ParametersRecord."""
+ p_record = ParametersRecord()
+ array_dict = OrderedDict(
+ {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
+ )
+ p_record.set_parameters(array_dict)
+
+
+@pytest.mark.parametrize(
+ "key_type, value_fn",
+ [
+ (str, lambda x: x), # correct key, incorrect value
+ (str, lambda x: x.tolist()), # correct key, incorrect value
+ (int, ndarray_to_array), # incorrect key, correct value
+ (int, lambda x: x), # incorrect key, incorrect value
+ (int, lambda x: x.tolist()), # incorrect key, incorrect value
+ ],
+)
+def test_set_parameters_with_incorrect_types(
+ key_type: Type[Union[int, str]],
+ value_fn: Callable[[NDArray], Union[NDArray, List[float]]],
+) -> None:
+ """Test adding dictionary of unsupported types to ParametersRecord."""
+ p_record = ParametersRecord()
+
+ array_dict = {
+ key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays())
+ }
+
+ with pytest.raises(TypeError):
+ p_record.set_parameters(array_dict) # type: ignore
diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py
new file mode 100644
index 000000000000..c1e724fa2758
--- /dev/null
+++ b/src/py/flwr/common/recordset_utils.py
@@ -0,0 +1,87 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RecordSet utilities."""
+
+
+from typing import OrderedDict
+
+from .parametersrecord import Array, ParametersRecord
+from .typing import Parameters
+
+
+def parametersrecord_to_parameters(
+ record: ParametersRecord, keep_input: bool = False
+) -> Parameters:
+ """Convert ParameterRecord to legacy Parameters.
+
+ Warning: Because `Arrays` in `ParametersRecord` encode more information of the
+ array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it
+ might not be possible to reconstruct such data structures from `Parameters` objects
+ alone. Additional information or metadta must be provided from elsewhere.
+
+ Parameters
+ ----------
+ record : ParametersRecord
+ The record to be conveted into Parameters.
+ keep_input : bool (default: False)
+ A boolean indicating whether entries in the record should be deleted from the
+ input dictionary immediately after adding them to the record.
+ """
+ parameters = Parameters(tensors=[], tensor_type="")
+
+ for key in list(record.data.keys()):
+ parameters.tensors.append(record.data[key].data)
+
+ if not keep_input:
+ del record.data[key]
+
+ return parameters
+
+
+def parameters_to_parametersrecord(
+ parameters: Parameters, keep_input: bool = False
+) -> ParametersRecord:
+ """Convert legacy Parameters into a single ParametersRecord.
+
+ Because there is no concept of names in the legacy Parameters, arbitrary keys will
+ be used when constructing the ParametersRecord. Similarly, the shape and data type
+ won't be recorded in the Array objects.
+
+ Parameters
+ ----------
+ parameters : Parameters
+ Parameters object to be represented as a ParametersRecord.
+ keep_input : bool (default: False)
+ A boolean indicating whether parameters should be deleted from the input
+ Parameters object (i.e. a list of serialized NumPy arrays) immediately after
+ adding them to the record.
+ """
+ tensor_type = parameters.tensor_type
+
+ p_record = ParametersRecord()
+
+ num_arrays = len(parameters.tensors)
+ for idx in range(num_arrays):
+ if keep_input:
+ tensor = parameters.tensors[idx]
+ else:
+ tensor = parameters.tensors.pop(0)
+ p_record.set_parameters(
+ OrderedDict(
+ {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
+ )
+ )
+
+ return p_record
diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py
index a60fff57e7bf..5441e766983a 100644
--- a/src/py/flwr/common/retry_invoker.py
+++ b/src/py/flwr/common/retry_invoker.py
@@ -156,6 +156,7 @@ class RetryInvoker:
>>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
"""
+ # pylint: disable-next=too-many-arguments
def __init__(
self,
wait_factory: Callable[[], Generator[float, None, None]],
diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py
index c8c73e87e04a..59f5387b0a07 100644
--- a/src/py/flwr/common/serde.py
+++ b/src/py/flwr/common/serde.py
@@ -59,7 +59,9 @@ def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessa
server_message.evaluate_ins,
)
)
- raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf")
+ raise ValueError(
+ "No instruction set in ServerMessage, cannot serialize to ProtoBuf"
+ )
def server_message_from_proto(
@@ -91,7 +93,7 @@ def server_message_from_proto(
server_message_proto.evaluate_ins,
)
)
- raise Exception(
+ raise ValueError(
"Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf"
)
@@ -125,7 +127,9 @@ def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessa
client_message.evaluate_res,
)
)
- raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf")
+ raise ValueError(
+ "No instruction set in ClientMessage, cannot serialize to ProtoBuf"
+ )
def client_message_from_proto(
@@ -157,7 +161,7 @@ def client_message_from_proto(
client_message_proto.evaluate_res,
)
)
- raise Exception(
+ raise ValueError(
"Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf"
)
@@ -474,7 +478,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
if isinstance(scalar, str):
return Scalar(string=scalar)
- raise Exception(
+ raise ValueError(
f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})"
)
@@ -518,7 +522,7 @@ def _check_value(value: typing.Value) -> None:
for element in value:
if isinstance(element, data_type):
continue
- raise Exception(
+ raise TypeError(
f"Inconsistent type: the types of elements in the list must "
f"be the same (expected {data_type}, but got {type(element)})."
)
diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py
index 2c3a6d2ccddf..82747e5afb2c 100644
--- a/src/py/flwr/driver/app_test.py
+++ b/src/py/flwr/driver/app_test.py
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Flower Driver app tests."""
-# pylint: disable=no-self-use
import threading
diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py
index 92b4230a3932..8f75bbf78362 100644
--- a/src/py/flwr/driver/driver_test.py
+++ b/src/py/flwr/driver/driver_test.py
@@ -139,6 +139,7 @@ def test_del_with_initialized_driver(self) -> None:
self.driver._get_grpc_driver_and_run_id()
# Execute
+ # pylint: disable-next=unnecessary-dunder-call
self.driver.__del__()
# Assert
@@ -147,6 +148,7 @@ def test_del_with_initialized_driver(self) -> None:
def test_del_with_uninitialized_driver(self) -> None:
"""Test cleanup behavior when Driver is not initialized."""
# Execute
+ # pylint: disable-next=unnecessary-dunder-call
self.driver.__del__()
# Assert
diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py
index b6d42fe799d5..627b95cdb1b4 100644
--- a/src/py/flwr/driver/grpc_driver.py
+++ b/src/py/flwr/driver/grpc_driver.py
@@ -89,7 +89,7 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call Driver API
res: CreateRunResponse = self.stub.CreateRun(request=req)
@@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call gRPC Driver API
res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call gRPC Driver API
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -122,7 +122,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py
index 6ae38ea3d805..4e68499f018d 100644
--- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py
+++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py
@@ -113,7 +113,7 @@ def _transition(self, next_status: Status) -> None:
):
self._status = next_status
else:
- raise Exception(f"Invalid transition: {self._status} to {next_status}")
+ raise ValueError(f"Invalid transition: {self._status} to {next_status}")
self._cv.notify_all()
@@ -129,7 +129,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper:
self._raise_if_closed()
if self._status != Status.AWAITING_INS_WRAPPER:
- raise Exception("This should not happen")
+ raise ValueError("This should not happen")
self._ins_wrapper = ins_wrapper # Write
self._transition(Status.INS_WRAPPER_AVAILABLE)
@@ -146,7 +146,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper:
self._transition(Status.AWAITING_INS_WRAPPER)
if res_wrapper is None:
- raise Exception("ResWrapper can not be None")
+ raise ValueError("ResWrapper can not be None")
return res_wrapper
@@ -170,7 +170,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]:
self._transition(Status.AWAITING_RES_WRAPPER)
if ins_wrapper is None:
- raise Exception("InsWrapper can not be None")
+ raise ValueError("InsWrapper can not be None")
yield ins_wrapper
@@ -180,7 +180,7 @@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None:
self._raise_if_closed()
if self._status != Status.AWAITING_RES_WRAPPER:
- raise Exception("This should not happen")
+ raise ValueError("This should not happen")
self._res_wrapper = res_wrapper # Write
self._transition(Status.RES_WRAPPER_AVAILABLE)
diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py
index 18a2144072ed..bcfbe6e6fac8 100644
--- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py
+++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py
@@ -70,6 +70,7 @@ def test_workflow_successful() -> None:
_ = next(ins_wrapper_iterator)
bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage()))
except Exception as exception:
+ # pylint: disable-next=broad-exception-raised
raise Exception from exception
# Wait until worker_thread is finished
diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
index 1c737d31c7fc..0fa6f82a89b5 100644
--- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
+++ b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
@@ -166,6 +166,6 @@ def _call_client_proxy(
evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res)
return ClientMessage(evaluate_res=evaluate_res_proto)
- raise Exception(
+ raise ValueError(
"Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf"
)
diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py
index 63ec1021ff5c..9b5c03aeeaf9 100644
--- a/src/py/flwr/server/server_test.py
+++ b/src/py/flwr/server/server_test.py
@@ -47,14 +47,14 @@ class SuccessClient(ClientProxy):
def get_properties(
self, ins: GetPropertiesIns, timeout: Optional[float]
) -> GetPropertiesRes:
- """Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ """Raise an error because this method is not expected to be called."""
+ raise NotImplementedError()
def get_parameters(
self, ins: GetParametersIns, timeout: Optional[float]
) -> GetParametersRes:
- """Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ """Raise a error because this method is not expected to be called."""
+ raise NotImplementedError()
def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
"""Simulate fit by returning a success FitRes with simple set of weights."""
@@ -87,26 +87,26 @@ class FailingClient(ClientProxy):
def get_properties(
self, ins: GetPropertiesIns, timeout: Optional[float]
) -> GetPropertiesRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def get_parameters(
self, ins: GetParametersIns, timeout: Optional[float]
) -> GetParametersRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def test_fit_clients() -> None:
diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py
index 4f66be3ff262..26f326819971 100644
--- a/src/py/flwr/server/state/sqlite_state.py
+++ b/src/py/flwr/server/state/sqlite_state.py
@@ -134,7 +134,7 @@ def query(
) -> List[Dict[str, Any]]:
"""Execute a SQL query."""
if self.conn is None:
- raise Exception("State is not initialized.")
+ raise AttributeError("State is not initialized.")
if data is None:
data = []
@@ -459,7 +459,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
"""
if self.conn is None:
- raise Exception("State not intitialized")
+ raise AttributeError("State not intitialized")
with self.conn:
self.conn.execute(query_1, data)
diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py
index efdd288fc308..a3f899386011 100644
--- a/src/py/flwr/server/state/sqlite_state_test.py
+++ b/src/py/flwr/server/state/sqlite_state_test.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Test for utility functions."""
-# pylint: disable=no-self-use, invalid-name, disable=R0904
+# pylint: disable=invalid-name, disable=R0904
import unittest
diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py
index 88b4b53aed4c..204b4ba97b5f 100644
--- a/src/py/flwr/server/state/state_test.py
+++ b/src/py/flwr/server/state/state_test.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests all state implemenations have to conform to."""
-# pylint: disable=no-self-use, invalid-name, disable=R0904
+# pylint: disable=invalid-name, disable=R0904
import tempfile
import unittest
diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py
index 4eb76111b266..c668b55eebe6 100644
--- a/src/py/flwr/server/strategy/aggregate.py
+++ b/src/py/flwr/server/strategy/aggregate.py
@@ -27,7 +27,7 @@
def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
"""Compute weighted average."""
# Calculate the total number of examples used during training
- num_examples_total = sum([num_examples for _, num_examples in results])
+ num_examples_total = sum(num_examples for (_, num_examples) in results)
# Create a list of weights, each multiplied by the related number of examples
weighted_weights = [
@@ -45,7 +45,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
"""Compute in-place weighted average."""
# Count total examples
- num_examples_total = sum([fit_res.num_examples for _, fit_res in results])
+ num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
# Compute scaling factors for each result
scaling_factors = [
@@ -95,9 +95,9 @@ def aggregate_krum(
# For each client, take the n-f-2 closest parameters vectors
num_closest = max(1, len(weights) - num_malicious - 2)
closest_indices = []
- for i, _ in enumerate(distance_matrix):
+ for distance in distance_matrix:
closest_indices.append(
- np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist() # noqa: E203
+ np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
)
# Compute the score for each client, that is the sum of the distances
@@ -202,7 +202,7 @@ def aggregate_bulyan(
def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
"""Aggregate evaluation results obtained from multiple clients."""
- num_total_evaluation_examples = sum([num_examples for num_examples, _ in results])
+ num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results)
weighted_losses = [num_examples * loss for num_examples, loss in results]
return sum(weighted_losses) / num_total_evaluation_examples
@@ -233,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray:
"""
flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights])
distance_matrix = np.zeros((len(weights), len(weights)))
- for i, _ in enumerate(flat_w):
- for j, _ in enumerate(flat_w):
- delta = flat_w[i] - flat_w[j]
+ for i, flat_w_i in enumerate(flat_w):
+ for j, flat_w_j in enumerate(flat_w):
+ delta = flat_w_i - flat_w_j
norm = np.linalg.norm(delta)
distance_matrix[i, j] = norm**2
return distance_matrix
diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py
index 3269735e9d73..8b3278cc9ba0 100644
--- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py
+++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py
@@ -91,7 +91,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None:
norm_bit_set_count = 0
for client_proxy, fit_res in results:
if "dpfedavg_norm_bit" not in fit_res.metrics:
- raise Exception(
+ raise KeyError(
f"Indicator bit not returned by client with id {client_proxy.cid}."
)
if fit_res.metrics["dpfedavg_norm_bit"]:
diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py
index 0154cfd79fc5..f2f1c206f3de 100644
--- a/src/py/flwr/server/strategy/dpfedavg_fixed.py
+++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py
@@ -46,11 +46,11 @@ def __init__(
self.num_sampled_clients = num_sampled_clients
if clip_norm <= 0:
- raise Exception("The clipping threshold should be a positive value.")
+ raise ValueError("The clipping threshold should be a positive value.")
self.clip_norm = clip_norm
if noise_multiplier < 0:
- raise Exception("The noise multiplier should be a non-negative value.")
+ raise ValueError("The noise multiplier should be a non-negative value.")
self.noise_multiplier = noise_multiplier
self.server_side_noising = server_side_noising
diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py
index e890f7216020..6678b7ced114 100644
--- a/src/py/flwr/server/strategy/fedavg_android.py
+++ b/src/py/flwr/server/strategy/fedavg_android.py
@@ -234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays:
"""Convert parameters object to NumPy weights."""
return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors]
- # pylint: disable=R0201
def ndarray_to_bytes(self, ndarray: NDArray) -> bytes:
"""Serialize NumPy array to bytes."""
return ndarray.tobytes()
- # pylint: disable=R0201
def bytes_to_ndarray(self, tensor: bytes) -> NDArray:
"""Deserialize NumPy array from bytes."""
ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32)
diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py
index 7a5bf1425b44..17e979d92beb 100644
--- a/src/py/flwr/server/strategy/fedmedian.py
+++ b/src/py/flwr/server/strategy/fedmedian.py
@@ -36,7 +36,7 @@
class FedMedian(FedAvg):
- """Configurable FedAvg with Momentum strategy implementation."""
+ """Configurable FedMedian strategy implementation."""
def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py
index 94a67fbcbfae..758e8e608e9f 100644
--- a/src/py/flwr/server/strategy/qfedavg.py
+++ b/src/py/flwr/server/strategy/qfedavg.py
@@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float:
hs_ffl = []
if self.pre_weights is None:
- raise Exception("QffedAvg pre_weights are None in aggregate_fit")
+ raise AttributeError("QffedAvg pre_weights are None in aggregate_fit")
weights_before = self.pre_weights
eval_result = self.evaluate(