From 9dd52988ce5d7204a96d74fefad812382c5f18d2 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:03:14 +0100 Subject: [PATCH 01/10] Narrow down Python version in FDS TF e2e test (#2797) Co-authored-by: Javier --- datasets/e2e/tensorflow/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/e2e/tensorflow/pyproject.toml b/datasets/e2e/tensorflow/pyproject.toml index 9c5c72c46400..4d7b5f60e856 100644 --- a/datasets/e2e/tensorflow/pyproject.toml +++ b/datasets/e2e/tensorflow/pyproject.toml @@ -9,7 +9,7 @@ description = "Flower Datasets with TensorFlow" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = "^3.8" +python = ">=3.8,<3.11" flwr-datasets = { path = "./../../", extras = ["vision"] } tensorflow-cpu = "^2.9.1, !=2.11.1" parameterized = "==0.9.0" From 7f48ea2ae7b8ff8653eb7ce013727028a01e0f9a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 08:38:19 +0000 Subject: [PATCH 02/10] Update types-setuptools requirement from ==68.2.0.0 to ==69.0.0.20240115 (#2790) Updates the requirements on [types-setuptools](https://github.com/python/typeshed) to permit the latest version. - [Commits](https://github.com/python/typeshed/commits) --- updated-dependencies: - dependency-name: types-setuptools dependency-type: direct:development ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Taner Topal --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8a300afa8c84..0616ffdbeffd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ rest = ["requests", "starlette", "uvicorn"] types-dataclasses = "==0.6.6" types-protobuf = "==3.19.18" types-requests = "==2.31.0.10" -types-setuptools = "==68.2.0.0" +types-setuptools = "==69.0.0.20240115" clang-format = "==17.0.4" isort = "==5.12.0" black = { version = "==23.10.1", extras = ["jupyter"] } From 0c13d3b5b62351951485048211e27cf301ae2523 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 10:58:31 +0000 Subject: [PATCH 03/10] Bump actions/download-artifact from 4.1.0 to 4.1.1 (#2788) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4.1.0 to 4.1.1. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/f44cd7b40bfd40b6aa1cc1b9b5b7bf03d3c67110...6b208ae046db98c579e8a3aa621ab581ff575935) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Taner Topal --- .github/workflows/_docker-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index 36b94b5c7e97..07c9d0cba0ad 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -114,7 +114,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@f44cd7b40bfd40b6aa1cc1b9b5b7bf03d3c67110 # v4.1.0 + uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests From f1f0299791da5d5bafc07210cd458ad0bbf2f2f9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 17 Jan 2024 12:34:16 +0000 Subject: [PATCH 04/10] Bump actions/upload-artifact from 4.0.0 to 4.1.0 (#2789) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.0.0 to 4.1.0. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/c7d193f32edcb7bfad88892161225aeda64e9392...1eb3cb2b3e0f29609092a73eb033bb759a334595) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/_docker-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index 07c9d0cba0ad..4a1289d9175a 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -98,7 +98,7 @@ jobs: touch "/tmp/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@c7d193f32edcb7bfad88892161225aeda64e9392 # v4.0.0 + uses: actions/upload-artifact@1eb3cb2b3e0f29609092a73eb033bb759a334595 # v4.1.0 with: name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }} path: /tmp/digests/* From 097631c079165fb3d72d89fb1bad66a294746033 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 17 Jan 2024 14:31:27 +0100 Subject: [PATCH 05/10] Fix outdated Android README (#2804) --- examples/android/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/android/README.md b/examples/android/README.md index 7931aa96b0c5..f9f2bb93b8dc 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -54,4 +54,4 @@ poetry run ./run.sh Download and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`. -When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training. +When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. 
This will load the local CIFAR10 dataset in memory, establish connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing `Clear` and `Refresh` buttons respectively. From 0daa3d79d102d62e626e9d951249da12811d87dd Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 17 Jan 2024 16:29:13 +0100 Subject: [PATCH 06/10] Add conda install instructions (#2800) --- doc/source/how-to-install-flower.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index 1107f6798b23..ff3dbb605846 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth Install stable release ---------------------- +Using pip +~~~~~~~~~ + Stable releases are available on `PyPI `_:: python -m pip install flwr @@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed python -m pip install flwr[simulation] +Using conda (or mamba) +~~~~~~~~~~~~~~~~~~~~~~ + +Flower can also be installed from the ``conda-forge`` channel. 
+ +If you have not added ``conda-forge`` to your channels, you will first need to run the following:: + + conda config --add channels conda-forge + conda config --set channel_priority strict + +Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``:: + + conda install flwr + +or with ``mamba``:: + + mamba install flwr + + Verify installation ------------------- From 9e6df75bec9092b89f06939d533d2d8e66a5e0cc Mon Sep 17 00:00:00 2001 From: Edoardo Gabrielli Date: Wed, 17 Jan 2024 18:36:37 +0100 Subject: [PATCH 07/10] Update FedMedian docstring (#2761) --- src/py/flwr/server/strategy/fedmedian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index 7a5bf1425b44..17e979d92beb 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -36,7 +36,7 @@ class FedMedian(FedAvg): - """Configurable FedAvg with Momentum strategy implementation.""" + """Configurable FedMedian strategy implementation.""" def __repr__(self) -> str: """Compute a string representation of the strategy.""" From 815f66277cf42b97b6bcb481a9baf32632119752 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 17 Jan 2024 18:50:42 +0100 Subject: [PATCH 08/10] Favor docformatter for multi-line docstrings (#2807) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0616ffdbeffd..7a8b0d1ad45f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,7 +184,7 @@ target-version = "py38" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] -ignore = ["B024", "B027"] +ignore = ["B024", "B027", "D205", "D209"] exclude = [ ".bzr", ".direnv", From 1fcb147c8360d90cae741ba802e3720ad145010b Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 17 Jan 2024 18:23:57 +0000 Subject: [PATCH 09/10] Add 
`ParametersRecord` (#2799) Co-authored-by: Daniel J. Beutel Co-authored-by: Heng Pan --- src/py/flwr/common/parametersrecord.py | 110 ++++++++++++++++++ src/py/flwr/common/recordset.py | 13 +-- src/py/flwr/common/recordset_test.py | 147 +++++++++++++++++++++++++ src/py/flwr/common/recordset_utils.py | 87 +++++++++++++++ 4 files changed, 349 insertions(+), 8 deletions(-) create mode 100644 src/py/flwr/common/parametersrecord.py create mode 100644 src/py/flwr/common/recordset_test.py create mode 100644 src/py/flwr/common/recordset_utils.py diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py new file mode 100644 index 000000000000..3d40c0488baa --- /dev/null +++ b/src/py/flwr/common/parametersrecord.py @@ -0,0 +1,110 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ParametersRecord and Array.""" + + +from dataclasses import dataclass, field +from typing import List, Optional, OrderedDict + + +@dataclass +class Array: + """Array type. + + A dataclass containing serialized data from an array-like or tensor-like object + along with some metadata about it. + + Parameters + ---------- + dtype : str + A string representing the data type of the serialised object (e.g. `np.float32`) + + shape : List[int] + A list representing the shape of the unserialized array-like object. 
This is + used to deserialize the data (depending on the serialization method) or simply + as a metadata field. + + stype : str + A string indicating the type of serialisation mechanism used to generate the + bytes in `data` from an array-like or tensor-like object. + + data: bytes + A buffer of bytes containing the data. + """ + + dtype: str + shape: List[int] + stype: str + data: bytes + + +@dataclass +class ParametersRecord: + """Parameters record. + + A dataclass storing named Arrays in order. This means that it holds entries as an + OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to + PyTorch's state_dict, but holding serialised tensors instead. + """ + + keep_input: bool + data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) + + def __init__( + self, + array_dict: Optional[OrderedDict[str, Array]] = None, + keep_input: bool = False, + ) -> None: + """Construct a ParametersRecord object. + + Parameters + ---------- + array_dict : Optional[OrderedDict[str, Array]] + A dictionary that stores serialized array-like or tensor-like objects. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + dictionary immediately after adding them to the record. If False, the + dictionary passed to `set_parameters()` will be empty once exiting from that + function. This is the desired behaviour when working with very large + models/tensors/arrays. However, if you plan to continue working with your + parameters after adding it to the record, set this flag to True. When set + to True, the data is duplicated in memory. + """ + self.keep_input = keep_input + self.data = OrderedDict() + if array_dict: + self.set_parameters(array_dict) + + def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: + """Add parameters to record. 
+ + Parameters + ---------- + array_dict : OrderedDict[str, Array] + A dictionary that stores serialized array-like or tensor-like objects. + """ + if any(not isinstance(k, str) for k in array_dict.keys()): + raise TypeError(f"Not all keys are of valid type. Expected {str}") + if any(not isinstance(v, Array) for v in array_dict.values()): + raise TypeError(f"Not all values are of valid type. Expected {Array}") + + if self.keep_input: + # Copy + self.data = OrderedDict(array_dict) + else: + # Add entries to dataclass without duplicating memory + for key in list(array_dict.keys()): + self.data[key] = array_dict[key] + del array_dict[key] diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index 0088b7397a6d..dc723a2cea86 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -14,13 +14,10 @@ # ============================================================================== """RecordSet.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict - -@dataclass -class ParametersRecord: - """Parameters record.""" +from .parametersrecord import ParametersRecord @dataclass @@ -37,9 +34,9 @@ class ConfigsRecord: class RecordSet: """Definition of RecordSet.""" - parameters: Dict[str, ParametersRecord] = {} - metrics: Dict[str, MetricsRecord] = {} - configs: Dict[str, ConfigsRecord] = {} + parameters: Dict[str, ParametersRecord] = field(default_factory=dict) + metrics: Dict[str, MetricsRecord] = field(default_factory=dict) + configs: Dict[str, ConfigsRecord] = field(default_factory=dict) def set_parameters(self, name: str, record: ParametersRecord) -> None: """Add a ParametersRecord.""" diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py new file mode 100644 index 000000000000..90c06dcdb109 --- /dev/null +++ b/src/py/flwr/common/recordset_test.py @@ -0,0 +1,147 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet tests.""" + + +from typing import Callable, List, OrderedDict, Type, Union + +import numpy as np +import pytest + +from .parameter import ndarrays_to_parameters, parameters_to_ndarrays +from .parametersrecord import Array, ParametersRecord +from .recordset_utils import ( + parameters_to_parametersrecord, + parametersrecord_to_parameters, +) +from .typing import NDArray, NDArrays, Parameters + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +def ndarray_to_array(ndarray: NDArray) -> Array: + """Represent NumPy ndarray as Array.""" + return Array( + data=ndarray.tobytes(), + dtype=str(ndarray.dtype), + stype="numpy.ndarray.tobytes", + shape=list(ndarray.shape), + ) + + +def test_ndarray_to_array() -> None: + """Test creation of Array object from NumPy ndarray.""" + shape = (2, 7, 9) + arr = np.eye(*shape) + + array = ndarray_to_array(arr) + + arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape) + + assert np.array_equal(arr, arr_) + + +def test_parameters_to_array_and_back() -> None: + """Test conversion between legacy Parameters and Array.""" + ndarrays = get_ndarrays() + + # Array represents a single array, unlike Paramters, which represent a + # list of arrays + ndarray = 
ndarrays[0] + + parameters = ndarrays_to_parameters([ndarray]) + + array = Array( + data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[] + ) + + parameters = Parameters(tensors=[array.data], tensor_type=array.stype) + + ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] + + assert np.array_equal(ndarray, ndarray_) + + +def test_parameters_to_parametersrecord_and_back() -> None: + """Test conversion between legacy Parameters and ParametersRecords.""" + ndarrays = get_ndarrays() + + parameters = ndarrays_to_parameters(ndarrays) + + params_record = parameters_to_parametersrecord(parameters=parameters) + + parameters_ = parametersrecord_to_parameters(params_record) + + ndarrays_ = parameters_to_ndarrays(parameters=parameters_) + + for arr, arr_ in zip(ndarrays, ndarrays_): + assert np.array_equal(arr, arr_) + + +def test_set_parameters_while_keeping_intputs() -> None: + """Tests keep_input functionality in ParametersRecord.""" + # Adding parameters to a record that doesn't erase entries in the input `array_dict` + p_record = ParametersRecord(keep_input=True) + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.set_parameters(array_dict) + + # Creating a second parametersrecord passing the same `array_dict` (not erased) + p_record_2 = ParametersRecord(array_dict) + assert p_record.data == p_record_2.data + + # Now it should be empty (the second ParametersRecord wasn't flagged to keep it) + assert len(array_dict) == 0 + + +def test_set_parameters_with_correct_types() -> None: + """Test adding dictionary of Arrays to ParametersRecord.""" + p_record = ParametersRecord() + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.set_parameters(array_dict) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # correct key, incorrect value + (str, lambda x: x.tolist()), # correct 
key, incorrect value + (int, ndarray_to_array), # incorrect key, correct value + (int, lambda x: x), # incorrect key, incorrect value + (int, lambda x: x.tolist()), # incorrect key, incorrect value + ], +) +def test_set_parameters_with_incorrect_types( + key_type: Type[Union[int, str]], + value_fn: Callable[[NDArray], Union[NDArray, List[float]]], +) -> None: + """Test adding dictionary of unsupported types to ParametersRecord.""" + p_record = ParametersRecord() + + array_dict = { + key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays()) + } + + with pytest.raises(TypeError): + p_record.set_parameters(array_dict) # type: ignore diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py new file mode 100644 index 000000000000..c1e724fa2758 --- /dev/null +++ b/src/py/flwr/common/recordset_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + + +from typing import OrderedDict + +from .parametersrecord import Array, ParametersRecord +from .typing import Parameters + + +def parametersrecord_to_parameters( + record: ParametersRecord, keep_input: bool = False +) -> Parameters: + """Convert ParameterRecord to legacy Parameters. 
+ + Warning: Because `Arrays` in `ParametersRecord` encode more information of the + array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it + might not be possible to reconstruct such data structures from `Parameters` objects + alone. Additional information or metadta must be provided from elsewhere. + + Parameters + ---------- + record : ParametersRecord + The record to be conveted into Parameters. + keep_input : bool (default: False) + A boolean indicating whether entries in the record should be deleted from the + input dictionary immediately after adding them to the record. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.data.keys()): + parameters.tensors.append(record.data[key].data) + + if not keep_input: + del record.data[key] + + return parameters + + +def parameters_to_parametersrecord( + parameters: Parameters, keep_input: bool = False +) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + Because there is no concept of names in the legacy Parameters, arbitrary keys will + be used when constructing the ParametersRecord. Similarly, the shape and data type + won't be recorded in the Array objects. + + Parameters + ---------- + parameters : Parameters + Parameters object to be represented as a ParametersRecord. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + Parameters object (i.e. a list of serialized NumPy arrays) immediately after + adding them to the record. 
+ """ + tensor_type = parameters.tensor_type + + p_record = ParametersRecord() + + num_arrays = len(parameters.tensors) + for idx in range(num_arrays): + if keep_input: + tensor = parameters.tensors[idx] + else: + tensor = parameters.tensors.pop(0) + p_record.set_parameters( + OrderedDict( + {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + ) + ) + + return p_record From 66b3bbe81484c11be579551175991189a4888476 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 17 Jan 2024 19:34:06 +0100 Subject: [PATCH 10/10] Upgrade Pylint to 3.0.3 (#2488) Co-authored-by: Charles Beauville --- pyproject.toml | 4 +-- src/py/flwr/client/app.py | 9 ++++-- src/py/flwr/client/app_test.py | 16 +++++------ src/py/flwr/client/dpfedavg_numpy_client.py | 8 +++--- .../client/message_handler/task_handler.py | 3 +- src/py/flwr/client/numpy_client.py | 4 +-- src/py/flwr/client/rest_client/connection.py | 4 +++ .../secure_aggregation/secaggplus_handler.py | 12 ++++---- src/py/flwr/common/retry_invoker.py | 1 + src/py/flwr/common/serde.py | 16 +++++++---- src/py/flwr/driver/app_test.py | 1 - src/py/flwr/driver/driver_test.py | 2 ++ src/py/flwr/driver/grpc_driver.py | 8 +++--- .../server/fleet/grpc_bidi/grpc_bridge.py | 10 +++---- .../fleet/grpc_bidi/grpc_bridge_test.py | 1 + .../server/fleet/grpc_bidi/ins_scheduler.py | 2 +- src/py/flwr/server/server_test.py | 28 +++++++++---------- src/py/flwr/server/state/sqlite_state.py | 4 +-- src/py/flwr/server/state/sqlite_state_test.py | 2 +- src/py/flwr/server/state/state_test.py | 2 +- src/py/flwr/server/strategy/aggregate.py | 16 +++++------ .../flwr/server/strategy/dpfedavg_adaptive.py | 2 +- src/py/flwr/server/strategy/dpfedavg_fixed.py | 4 +-- src/py/flwr/server/strategy/fedavg_android.py | 2 -- src/py/flwr/server/strategy/qfedavg.py | 2 +- 25 files changed, 87 insertions(+), 76 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a8b0d1ad45f..24d20c7ced40 100644 --- a/pyproject.toml +++ b/pyproject.toml 
@@ -88,7 +88,7 @@ isort = "==5.12.0" black = { version = "==23.10.1", extras = ["jupyter"] } docformatter = "==1.7.5" mypy = "==1.6.1" -pylint = "==2.13.9" +pylint = "==3.0.3" flake8 = "==5.0.4" pytest = "==7.4.3" pytest-cov = "==4.1.0" @@ -137,7 +137,7 @@ line-length = 88 target-version = ["py38", "py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] -disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +disable = "duplicate-code,too-few-public-methods,useless-import-alias" [tool.pytest.ini_options] minversion = "6.2" diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index a5b285fbb7fb..91fa5468ae75 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -138,10 +138,12 @@ def _check_actionable_client( client: Optional[Client], client_fn: Optional[ClientFn] ) -> None: if client_fn is None and client is None: - raise Exception("Both `client_fn` and `client` are `None`, but one is required") + raise ValueError( + "Both `client_fn` and `client` are `None`, but one is required" + ) if client_fn is not None and client is not None: - raise Exception( + raise ValueError( "Both `client_fn` and `client` are provided, but only one is allowed" ) @@ -150,6 +152,7 @@ def _check_actionable_client( # pylint: disable=too-many-branches # pylint: disable=too-many-locals # pylint: disable=too-many-statements +# pylint: disable=too-many-arguments def start_client( *, server_address: str, @@ -299,7 +302,7 @@ def single_client_factory( cid: str, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy - raise Exception( + raise ValueError( "Both `client_fn` and `client` are `None`, but one is required" ) return client # Always return the same instance diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 7ef6410debad..56d6308a0fe2 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -41,19 +41,19 @@ class 
PlainClient(Client): def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit(self, ins: FitIns) -> FitRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate(self, ins: EvaluateIns) -> EvaluateRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() class NeedsWrappingClient(NumPyClient): @@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient): def get_properties(self, config: Config) -> Dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, config: Config) -> NDArrays: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit( self, parameters: NDArrays, config: Config ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config ) -> Tuple[float, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def test_to_client_with_client() -> None: diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index 41b4d676df43..c39b89b31da3 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -117,16 +117,16 @@ def fit( update = 
[np.subtract(x, y) for (x, y) in zip(updated_params, original_params)] if "dpfedavg_clip_norm" not in config: - raise Exception("Clipping threshold not supplied by the server.") + raise KeyError("Clipping threshold not supplied by the server.") if not isinstance(config["dpfedavg_clip_norm"], float): - raise Exception("Clipping threshold should be a floating point value.") + raise TypeError("Clipping threshold should be a floating point value.") # Clipping update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"]) if "dpfedavg_noise_stddev" in config: if not isinstance(config["dpfedavg_noise_stddev"], float): - raise Exception( + raise TypeError( "Scale of noise to be added should be a floating point value." ) # Noising @@ -138,7 +138,7 @@ def fit( # Calculating value of norm indicator bit, required for adaptive clipping if "dpfedavg_adaptive_clip_enabled" in config: if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool): - raise Exception( + raise TypeError( "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag." 
) metrics["dpfedavg_norm_bit"] = not clipped diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index 13b1948eec07..3599e1dfb254 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -80,8 +80,7 @@ def validate_task_res(task_res: TaskRes) -> bool: initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} # Check if certain fields are already initialized - # pylint: disable-next=too-many-boolean-expressions - if ( + if ( # pylint: disable-next=too-many-boolean-expressions "task_id" in initialized_fields_in_task_res or "group_id" in initialized_fields_in_task_res or "run_id" in initialized_fields_in_task_res diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 2312741f5af6..d67fb90512d4 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) # Return FitRes parameters_prime, num_examples, metrics = results @@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) # Return EvaluateRes loss, num_examples, metrics = results diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d22b246dbd61..87b06dd0be4e 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -143,6 +143,7 @@ def create_node() -> None: }, data=create_node_req_bytes, verify=verify, + timeout=None, ) # 
Check status code and headers @@ -185,6 +186,7 @@ def delete_node() -> None: }, data=delete_node_req_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]: }, data=pull_task_ins_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None: }, data=push_task_res_request_bytes, verify=verify, + timeout=None, ) state[KEY_TASK_INS] = None diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py index efbb00a9d916..4b74c1ace3de 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py @@ -333,7 +333,7 @@ def _share_keys( # Check if the size is larger than threshold if len(state.public_keys_dict) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique pk_list: List[bytes] = [] @@ -341,14 +341,14 @@ def _share_keys( pk_list.append(pk1) pk_list.append(pk2) if len(set(pk_list)) != len(pk_list): - raise Exception("Some public keys are identical") + raise ValueError("Some public keys are identical") # Check if public keys of this client are correct in the dictionary if ( state.public_keys_dict[state.sid][0] != state.pk1 or state.public_keys_dict[state.sid][1] != state.pk2 ): - raise Exception( + raise ValueError( "Own public keys are displayed in dict incorrectly, should not happen!" 
) @@ -393,7 +393,7 @@ def _collect_masked_input( ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: - raise Exception("Not enough available neighbour clients.") + raise ValueError("Not enough available neighbour clients.") # Decrypt ciphertexts, verify their sources, and store shares. for src, ciphertext in zip(srcs, ciphertexts): @@ -409,7 +409,7 @@ def _collect_masked_input( f"from {actual_src} instead of {src}." ) if dst != state.sid: - ValueError( + raise ValueError( f"Client {state.sid}: received an encrypted message" f"for Client {dst} from Client {src}." ) @@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") sids, shares = [], [] sids += active_sids diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index a60fff57e7bf..5441e766983a 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -156,6 +156,7 @@ class RetryInvoker: >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1) """ + # pylint: disable-next=too-many-arguments def __init__( self, wait_factory: Callable[[], Generator[float, None, None]], diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index c8c73e87e04a..59f5387b0a07 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -59,7 +59,9 @@ def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessa server_message.evaluate_ins, ) ) - raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf") 
+ raise ValueError( + "No instruction set in ServerMessage, cannot serialize to ProtoBuf" + ) def server_message_from_proto( @@ -91,7 +93,7 @@ def server_message_from_proto( server_message_proto.evaluate_ins, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) @@ -125,7 +127,9 @@ def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessa client_message.evaluate_res, ) ) - raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ClientMessage, cannot serialize to ProtoBuf" + ) def client_message_from_proto( @@ -157,7 +161,7 @@ def client_message_from_proto( client_message_proto.evaluate_res, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf" ) @@ -474,7 +478,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar: if isinstance(scalar, str): return Scalar(string=scalar) - raise Exception( + raise ValueError( f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})" ) @@ -518,7 +522,7 @@ def _check_value(value: typing.Value) -> None: for element in value: if isinstance(element, data_type): continue - raise Exception( + raise TypeError( f"Inconsistent type: the types of elements in the list must " f"be the same (expected {data_type}, but got {type(element)})." ) diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index 2c3a6d2ccddf..82747e5afb2c 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
# ============================================================================== """Flower Driver app tests.""" -# pylint: disable=no-self-use import threading diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py index 92b4230a3932..8f75bbf78362 100644 --- a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/driver/driver_test.py @@ -139,6 +139,7 @@ def test_del_with_initialized_driver(self) -> None: self.driver._get_grpc_driver_and_run_id() # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert @@ -147,6 +148,7 @@ def test_del_with_initialized_driver(self) -> None: def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index b6d42fe799d5..627b95cdb1b4 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -89,7 +89,7 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: CreateRunResponse = self.stub.CreateRun(request=req) @@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) @@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise 
Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) @@ -122,7 +122,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py index 6ae38ea3d805..4e68499f018d 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py @@ -113,7 +113,7 @@ def _transition(self, next_status: Status) -> None: ): self._status = next_status else: - raise Exception(f"Invalid transition: {self._status} to {next_status}") + raise ValueError(f"Invalid transition: {self._status} to {next_status}") self._cv.notify_all() @@ -129,7 +129,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._raise_if_closed() if self._status != Status.AWAITING_INS_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._ins_wrapper = ins_wrapper # Write self._transition(Status.INS_WRAPPER_AVAILABLE) @@ -146,7 +146,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._transition(Status.AWAITING_INS_WRAPPER) if res_wrapper is None: - raise Exception("ResWrapper can not be None") + raise ValueError("ResWrapper can not be None") return res_wrapper @@ -170,7 +170,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]: self._transition(Status.AWAITING_RES_WRAPPER) if ins_wrapper is None: - raise Exception("InsWrapper can not be None") + raise ValueError("InsWrapper can not be None") yield ins_wrapper @@ -180,7 +180,7 
@@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None: self._raise_if_closed() if self._status != Status.AWAITING_RES_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._res_wrapper = res_wrapper # Write self._transition(Status.RES_WRAPPER_AVAILABLE) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py index 18a2144072ed..bcfbe6e6fac8 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py @@ -70,6 +70,7 @@ def test_workflow_successful() -> None: _ = next(ins_wrapper_iterator) bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage())) except Exception as exception: + # pylint: disable-next=broad-exception-raised raise Exception from exception # Wait until worker_thread is finished diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py index 1c737d31c7fc..0fa6f82a89b5 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py +++ b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py @@ -166,6 +166,6 @@ def _call_client_proxy( evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res) return ClientMessage(evaluate_res=evaluate_res_proto) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 63ec1021ff5c..9b5c03aeeaf9 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -47,14 +47,14 @@ class SuccessClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise 
NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: """Simulate fit by returning a success FitRes with simple set of weights.""" @@ -87,26 +87,26 @@ class FailingClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def test_fit_clients() -> None: diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index 
4f66be3ff262..26f326819971 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -134,7 +134,7 @@ def query( ) -> List[Dict[str, Any]]: """Execute a SQL query.""" if self.conn is None: - raise Exception("State is not initialized.") + raise AttributeError("State is not initialized.") if data is None: data = [] @@ -459,7 +459,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """ if self.conn is None: - raise Exception("State not intitialized") + raise AttributeError("State not intitialized") with self.conn: self.conn.execute(query_1, data) diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py index efdd288fc308..a3f899386011 100644 --- a/src/py/flwr/server/state/sqlite_state_test.py +++ b/src/py/flwr/server/state/sqlite_state_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Test for utility functions.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import unittest diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 88b4b53aed4c..204b4ba97b5f 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================== """Tests all state implemenations have to conform to.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import tempfile import unittest diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 4eb76111b266..c668b55eebe6 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -27,7 +27,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute weighted average.""" # Calculate the total number of examples used during training - num_examples_total = sum([num_examples for _, num_examples in results]) + num_examples_total = sum(num_examples for (_, num_examples) in results) # Create a list of weights, each multiplied by the related number of examples weighted_weights = [ @@ -45,7 +45,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: """Compute in-place weighted average.""" # Count total examples - num_examples_total = sum([fit_res.num_examples for _, fit_res in results]) + num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) # Compute scaling factors for each result scaling_factors = [ @@ -95,9 +95,9 @@ def aggregate_krum( # For each client, take the n-f-2 closest parameters vectors num_closest = max(1, len(weights) - num_malicious - 2) closest_indices = [] - for i, _ in enumerate(distance_matrix): + for distance in distance_matrix: closest_indices.append( - np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist() # noqa: E203 + np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203 ) # Compute the score for each client, that is the sum of the distances @@ -202,7 +202,7 @@ def aggregate_bulyan( def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: """Aggregate evaluation 
results obtained from multiple clients.""" - num_total_evaluation_examples = sum([num_examples for num_examples, _ in results]) + num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results) weighted_losses = [num_examples * loss for num_examples, loss in results] return sum(weighted_losses) / num_total_evaluation_examples @@ -233,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray: """ flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights]) distance_matrix = np.zeros((len(weights), len(weights))) - for i, _ in enumerate(flat_w): - for j, _ in enumerate(flat_w): - delta = flat_w[i] - flat_w[j] + for i, flat_w_i in enumerate(flat_w): + for j, flat_w_j in enumerate(flat_w): + delta = flat_w_i - flat_w_j norm = np.linalg.norm(delta) distance_matrix[i, j] = norm**2 return distance_matrix diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 3269735e9d73..8b3278cc9ba0 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -91,7 +91,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: norm_bit_set_count = 0 for client_proxy, fit_res in results: if "dpfedavg_norm_bit" not in fit_res.metrics: - raise Exception( + raise KeyError( f"Indicator bit not returned by client with id {client_proxy.cid}." 
) if fit_res.metrics["dpfedavg_norm_bit"]: diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index 0154cfd79fc5..f2f1c206f3de 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -46,11 +46,11 @@ def __init__( self.num_sampled_clients = num_sampled_clients if clip_norm <= 0: - raise Exception("The clipping threshold should be a positive value.") + raise ValueError("The clipping threshold should be a positive value.") self.clip_norm = clip_norm if noise_multiplier < 0: - raise Exception("The noise multiplier should be a non-negative value.") + raise ValueError("The noise multiplier should be a non-negative value.") self.noise_multiplier = noise_multiplier self.server_side_noising = server_side_noising diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index e890f7216020..6678b7ced114 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ b/src/py/flwr/server/strategy/fedavg_android.py @@ -234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays: """Convert parameters object to NumPy weights.""" return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors] - # pylint: disable=R0201 def ndarray_to_bytes(self, ndarray: NDArray) -> bytes: """Serialize NumPy array to bytes.""" return ndarray.tobytes() - # pylint: disable=R0201 def bytes_to_ndarray(self, tensor: bytes) -> NDArray: """Deserialize NumPy array from bytes.""" ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32) diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 94a67fbcbfae..758e8e608e9f 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float: hs_ffl = [] if self.pre_weights is None: - raise Exception("QffedAvg pre_weights are None 
in aggregate_fit") + raise AttributeError("QffedAvg pre_weights are None in aggregate_fit") weights_before = self.pre_weights eval_result = self.evaluate(