diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml
index 36b94b5c7e97..4a1289d9175a 100644
--- a/.github/workflows/_docker-build.yml
+++ b/.github/workflows/_docker-build.yml
@@ -98,7 +98,7 @@ jobs:
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
- uses: actions/upload-artifact@c7d193f32edcb7bfad88892161225aeda64e9392 # v4.0.0
+ uses: actions/upload-artifact@1eb3cb2b3e0f29609092a73eb033bb759a334595 # v4.1.0
with:
name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }}
path: /tmp/digests/*
@@ -114,7 +114,7 @@ jobs:
metadata: ${{ steps.meta.outputs.json }}
steps:
- name: Download digests
- uses: actions/download-artifact@f44cd7b40bfd40b6aa1cc1b9b5b7bf03d3c67110 # v4.1.0
+ uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1
with:
pattern: digests-${{ needs.build.outputs.build-id }}-*
path: /tmp/digests
diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml
index 8758d0e1c5c7..2ca596a59361 100644
--- a/.github/workflows/swift.yml
+++ b/.github/workflows/swift.yml
@@ -20,7 +20,7 @@ jobs:
name: Test
runs-on: macos-latest
steps:
- - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c
+ - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4
with:
swift-version: 5
- uses: actions/checkout@v4
@@ -31,7 +31,7 @@ jobs:
runs-on: macos-latest
name: Build docs
steps:
- - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c
+ - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4
with:
swift-version: 5
- uses: actions/checkout@v4
@@ -44,7 +44,7 @@ jobs:
runs-on: macos-latest
name: Deploy docs
steps:
- - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c
+ - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4
with:
swift-version: 5
- uses: actions/checkout@v4
diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst
index 1107f6798b23..ff3dbb605846 100644
--- a/doc/source/how-to-install-flower.rst
+++ b/doc/source/how-to-install-flower.rst
@@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth
Install stable release
----------------------
+Using pip
+~~~~~~~~~
+
Stable releases are available on `PyPI <https://pypi.org/project/flwr/>`_::
python -m pip install flwr
@@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed
python -m pip install flwr[simulation]
+Using conda (or mamba)
+~~~~~~~~~~~~~~~~~~~~~~
+
+Flower can also be installed from the ``conda-forge`` channel.
+
+If you have not added ``conda-forge`` to your channels, you will first need to run the following::
+
+ conda config --add channels conda-forge
+ conda config --set channel_priority strict
+
+Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``::
+
+ conda install flwr
+
+or with ``mamba``::
+
+ mamba install flwr
+
+
Verify installation
-------------------
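Either install path can be sanity-checked from Python. The following is a minimal sketch that only assumes the `flwr` package is importable; it mirrors the verification step described in the docs' "Verify installation" section:

```python
# Minimal sanity check after installing via pip, conda, or mamba.
# Assumes only that the `flwr` package is importable.
import flwr

print(flwr.__version__)  # prints the installed Flower version, e.g. "1.6.0"
```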
diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md
index db0245e41453..2527e8e4a820 100644
--- a/examples/advanced-pytorch/README.md
+++ b/examples/advanced-pytorch/README.md
@@ -1,6 +1,6 @@
# Advanced Flower Example (PyTorch)
-This example demonstrates an advanced federated learning setup using Flower with PyTorch. It differs from the quickstart example in the following ways:
+This example demonstrates an advanced federated learning setup using Flower with PyTorch. It uses [Flower Datasets](https://flower.dev/docs/datasets/) and differs from the quickstart example in the following ways:
- 10 clients (instead of just 2)
- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set).
@@ -59,12 +59,13 @@ pip install -r requirements.txt
The included `run.sh` will start the Flower server (using `server.py`),
sleep for 2 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`) with only a small subset of the data (in order to run on any machine),
-but this can be changed by removing the `--toy True` argument in the script. You can simply start everything in a terminal as follows:
+but this can be changed by removing the `--toy` argument in the script. You can simply start everything in a terminal as follows:
```shell
-poetry run ./run.sh
+# After activating your environment
+./run.sh
```
The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).
-You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network).
+You can also manually run `python3 server.py` and `python3 client.py --client-id <CLIENT_ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py
index f9ffb6181fd8..b22cbcd70465 100644
--- a/examples/advanced-pytorch/client.py
+++ b/examples/advanced-pytorch/client.py
@@ -6,6 +6,7 @@
import argparse
from collections import OrderedDict
import warnings
+import datasets
warnings.filterwarnings("ignore")
@@ -13,9 +14,9 @@
class CifarClient(fl.client.NumPyClient):
def __init__(
self,
- trainset: torchvision.datasets,
- testset: torchvision.datasets,
- device: str,
+ trainset: datasets.Dataset,
+ testset: datasets.Dataset,
+ device: torch.device,
validation_split: float = 0.1,
):
self.device = device
@@ -41,17 +42,14 @@ def fit(self, parameters, config):
batch_size: int = config["batch_size"]
epochs: int = config["local_epochs"]
- n_valset = int(len(self.trainset) * self.validation_split)
+ train_valid = self.trainset.train_test_split(self.validation_split)
+ trainset = train_valid["train"]
+ valset = train_valid["test"]
- valset = torch.utils.data.Subset(self.trainset, range(0, n_valset))
- trainset = torch.utils.data.Subset(
- self.trainset, range(n_valset, len(self.trainset))
- )
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
+ val_loader = DataLoader(valset, batch_size=batch_size)
- trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
- valLoader = DataLoader(valset, batch_size=batch_size)
-
- results = utils.train(model, trainLoader, valLoader, epochs, self.device)
+ results = utils.train(model, train_loader, val_loader, epochs, self.device)
parameters_prime = utils.get_model_params(model)
num_examples_train = len(trainset)
@@ -73,13 +71,13 @@ def evaluate(self, parameters, config):
return float(loss), len(self.testset), {"accuracy": float(accuracy)}
-def client_dry_run(device: str = "cpu"):
+def client_dry_run(device: torch.device = torch.device("cpu")):
"""Weak tests to check whether all client methods are working as expected."""
model = utils.load_efficientnet(classes=10)
trainset, testset = utils.load_partition(0)
- trainset = torch.utils.data.Subset(trainset, range(10))
- testset = torch.utils.data.Subset(testset, range(10))
+ trainset = trainset.select(range(10))
+ testset = testset.select(range(10))
client = CifarClient(trainset, testset, device)
client.fit(
utils.get_model_params(model),
@@ -102,7 +100,7 @@ def main() -> None:
help="Do a dry-run to check the client",
)
parser.add_argument(
- "--partition",
+ "--client-id",
type=int,
default=0,
choices=range(0, 10),
@@ -112,9 +110,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
- type=bool,
- default=False,
- required=False,
+ action="store_true",
help="Set this flag to quickly run the client using only 10 data samples. \
Useful for testing purposes. Default: False",
)
@@ -136,12 +132,11 @@ def main() -> None:
client_dry_run(device)
else:
# Load a subset of CIFAR-10 to simulate the local data partition
- trainset, testset = utils.load_partition(args.partition)
+ trainset, testset = utils.load_partition(args.client_id)
if args.toy:
- trainset = torch.utils.data.Subset(trainset, range(10))
- testset = torch.utils.data.Subset(testset, range(10))
-
+ trainset = trainset.select(range(10))
+ testset = testset.select(range(10))
# Start Flower client
client = CifarClient(trainset, testset, device)
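The client-side changes above replace manual `torch.utils.data.Subset` slicing with the Hugging Face `datasets` API that Flower Datasets returns. Below is a small, self-contained sketch of the two operations the new client relies on, using a toy in-memory dataset (hypothetical column values, not the example's real partitions):

```python
# Toy illustration of the `datasets` operations used by the new client:
# `select` (replaces Subset for the --toy path) and `train_test_split`
# (replaces the manual range-based validation split in `fit`).
from datasets import Dataset

ds = Dataset.from_dict({"img": list(range(100)), "label": [0] * 100})

toy = ds.select(range(10))  # keep only the first 10 samples
assert len(toy) == 10

split = ds.train_test_split(test_size=0.1)  # 90/10 train/validation split
assert len(split["train"]) == 90 and len(split["test"]) == 10
```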
diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml
index a12f3c47de70..89fd5a32a89e 100644
--- a/examples/advanced-pytorch/pyproject.toml
+++ b/examples/advanced-pytorch/pyproject.toml
@@ -14,6 +14,7 @@ authors = [
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
+flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
torch = "1.13.1"
torchvision = "0.14.1"
validators = "0.18.2"
diff --git a/examples/advanced-pytorch/requirements.txt b/examples/advanced-pytorch/requirements.txt
index ba7b284df90e..f4d6a0774162 100644
--- a/examples/advanced-pytorch/requirements.txt
+++ b/examples/advanced-pytorch/requirements.txt
@@ -1,4 +1,5 @@
flwr>=1.0, <2.0
+flwr-datasets[vision]>=0.0.2, <1.0.0
torch==1.13.1
torchvision==0.14.1
validators==0.18.2
diff --git a/examples/advanced-pytorch/run.sh b/examples/advanced-pytorch/run.sh
index 212285f504f9..3367e1680535 100755
--- a/examples/advanced-pytorch/run.sh
+++ b/examples/advanced-pytorch/run.sh
@@ -2,20 +2,17 @@
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/
-# Download the CIFAR-10 dataset
-python -c "from torchvision.datasets import CIFAR10; CIFAR10('./dataset', download=True)"
-
# Download the EfficientNetB0 model
python -c "import torch; torch.hub.load( \
'NVIDIA/DeepLearningExamples:torchhub', \
'nvidia_efficientnet_b0', pretrained=True)"
-python server.py &
-sleep 3 # Sleep for 3s to give the server enough time to start
+python server.py --toy &
+sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset
for i in `seq 0 9`; do
echo "Starting client $i"
- python client.py --partition=${i} --toy True &
+ python client.py --client-id=${i} --toy &
done
# Enable CTRL+C to stop all background processes
diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py
index 8343e62da69f..fda49b71a311 100644
--- a/examples/advanced-pytorch/server.py
+++ b/examples/advanced-pytorch/server.py
@@ -10,6 +10,8 @@
import warnings
+from flwr_datasets import FederatedDataset
+
warnings.filterwarnings("ignore")
@@ -39,18 +41,13 @@ def evaluate_config(server_round: int):
def get_evaluate_fn(model: torch.nn.Module, toy: bool):
"""Return an evaluation function for server-side evaluation."""
- # Load data and model here to avoid the overhead of doing it in `evaluate` itself
- trainset, _, _ = utils.load_data()
-
- n_train = len(trainset)
+ # Load data here to avoid the overhead of doing it in `evaluate` itself
+ centralized_data = utils.load_centralized_data()
if toy:
# use only 10 samples as validation set
- valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train))
- else:
- # Use the last 5k training examples as a validation set
- valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train))
+ centralized_data = centralized_data.select(range(10))
- valLoader = DataLoader(valset, batch_size=16)
+ val_loader = DataLoader(centralized_data, batch_size=16)
# The `evaluate` function will be called after every round
def evaluate(
@@ -63,7 +60,7 @@ def evaluate(
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
- loss, accuracy = utils.test(model, valLoader)
+ loss, accuracy = utils.test(model, val_loader)
return loss, {"accuracy": accuracy}
return evaluate
@@ -79,9 +76,7 @@ def main():
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--toy",
- type=bool,
- default=False,
- required=False,
+ action="store_true",
help="Set this flag to use only 10 data samples for validation. \
Useful for testing purposes. Default: False",
)
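Both `client.py` and `server.py` swap `type=bool` for `action="store_true"` on the `--toy` flag. This fixes a classic argparse pitfall: the `type` callable receives the raw string from the command line, and `bool()` is truthy for any non-empty string, so `--toy False` used to enable toy mode. A short stdlib-only demonstration:

```python
# Demonstrates why `type=bool` was replaced with `action="store_true"`.
import argparse

broken = argparse.ArgumentParser()
broken.add_argument("--toy", type=bool, default=False)
# bool("False") is True because any non-empty string is truthy:
assert broken.parse_args(["--toy", "False"]).toy is True

fixed = argparse.ArgumentParser()
fixed.add_argument("--toy", action="store_true")
assert fixed.parse_args([]).toy is False        # flag absent -> False
assert fixed.parse_args(["--toy"]).toy is True  # flag present -> True
```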
diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py
index 8788ead90dee..6512010b1f23 100644
--- a/examples/advanced-pytorch/utils.py
+++ b/examples/advanced-pytorch/utils.py
@@ -1,49 +1,45 @@
import torch
-import torchvision.transforms as transforms
-from torchvision.datasets import CIFAR10
+from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
+from torch.utils.data import DataLoader
import warnings
-warnings.filterwarnings("ignore")
-
-# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+from flwr_datasets import FederatedDataset
+warnings.filterwarnings("ignore")
-def load_data():
- """Load CIFAR-10 (training and test set)."""
- transform = transforms.Compose(
- [
- transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ]
- )
- trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
- testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
+def load_partition(node_id: int):
+ """Load partition CIFAR10 data."""
+ fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
+ partition = fds.load_partition(node_id)
+ # Divide data on each node: 80% train, 20% test
+ partition_train_test = partition.train_test_split(test_size=0.2)
+ partition_train_test = partition_train_test.with_transform(apply_transforms)
+ return partition_train_test["train"], partition_train_test["test"]
- num_examples = {"trainset": len(trainset), "testset": len(testset)}
- return trainset, testset, num_examples
+def load_centralized_data():
+ fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
+ centralized_data = fds.load_full("test")
+ centralized_data = centralized_data.with_transform(apply_transforms)
+ return centralized_data
-def load_partition(idx: int):
- """Load 1/10th of the training and test data to simulate a partition."""
- assert idx in range(10)
- trainset, testset, num_examples = load_data()
- n_train = int(num_examples["trainset"] / 10)
- n_test = int(num_examples["testset"] / 10)
- train_parition = torch.utils.data.Subset(
- trainset, range(idx * n_train, (idx + 1) * n_train)
- )
- test_parition = torch.utils.data.Subset(
- testset, range(idx * n_test, (idx + 1) * n_test)
- )
- return (train_parition, test_parition)
+def apply_transforms(batch):
+ """Apply transforms to the partition from FederatedDataset."""
+ pytorch_transforms = Compose([
+ Resize(256),
+ CenterCrop(224),
+ ToTensor(),
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ])
+ batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+ return batch
-def train(net, trainloader, valloader, epochs, device: str = "cpu"):
+def train(net, trainloader, valloader, epochs,
+ device: torch.device = torch.device("cpu")):
"""Train the network on the training set."""
print("Starting training...")
net.to(device) # move model to GPU if available
@@ -53,7 +49,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"):
)
net.train()
for _ in range(epochs):
- for images, labels in trainloader:
+ for batch in trainloader:
+ images, labels = batch["img"], batch["label"]
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
@@ -74,7 +71,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"):
return results
-def test(net, testloader, steps: int = None, device: str = "cpu"):
+def test(net, testloader, steps: int = None,
+ device: torch.device = torch.device("cpu")):
"""Validate the network on the entire test set."""
print("Starting evalutation...")
net.to(device) # move model to GPU if available
@@ -82,7 +80,8 @@ def test(net, testloader, steps: int = None, device: str = "cpu"):
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
- for batch_idx, (images, labels) in enumerate(testloader):
+ for batch_idx, batch in enumerate(testloader):
+ images, labels = batch["img"], batch["label"]
images, labels = images.to(device), labels.to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
@@ -109,12 +108,14 @@ def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int =
entrypoint: EfficientNet model to download.
For supported entrypoints, please refer
https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/
- classes: Number of classes in final classifying layer. Leave as None to get the downloaded
+ classes: Number of classes in final classifying layer. Leave as None to get
+ the downloaded
model untouched.
Returns:
EfficientNet Model
- Note: One alternative implementation can be found at https://github.com/lukemelas/EfficientNet-PyTorch
+ Note: One alternative implementation can be found at
+ https://github.com/lukemelas/EfficientNet-PyTorch
"""
efficientnet = torch.hub.load(
"NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True
diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md
index 31bf5edb64c6..b21c0d2545ca 100644
--- a/examples/advanced-tensorflow/README.md
+++ b/examples/advanced-tensorflow/README.md
@@ -1,9 +1,9 @@
# Advanced Flower Example (TensorFlow/Keras)
-This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. It differs from the quickstart example in the following ways:
+This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. It uses [Flower Datasets](https://flower.dev/docs/datasets/) and differs from the quickstart example in the following ways:
- 10 clients (instead of just 2)
-- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that by default only a small subset of this data is used when running the `run.sh` script)
+- Each client holds a local dataset of 1/10 of the training data, split into 80% training examples and 20% test examples (note that by default only a small subset of this data is used when running the `run.sh` script)
- Server-side model evaluation after parameter aggregation
- Hyperparameter schedule using config functions
- Custom return values
@@ -57,10 +57,11 @@ pip install -r requirements.txt
## Run Federated Learning with TensorFlow/Keras and Flower
-The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 2 seconds to ensure the the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows:
+The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 10 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows:
```shell
-poetry run ./run.sh
+# Once you have activated your environment
+./run.sh
```
The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).
diff --git a/examples/advanced-tensorflow/client.py b/examples/advanced-tensorflow/client.py
index 1c0b61575635..033f20b1b027 100644
--- a/examples/advanced-tensorflow/client.py
+++ b/examples/advanced-tensorflow/client.py
@@ -6,6 +6,8 @@
import flwr as fl
+from flwr_datasets import FederatedDataset
+
# Make TensorFlow logs less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -74,7 +76,7 @@ def main() -> None:
# Parse command line argument `partition`
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
- "--partition",
+ "--client-id",
type=int,
default=0,
choices=range(0, 10),
@@ -84,9 +86,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
- type=bool,
- default=False,
- required=False,
+ action="store_true",
help="Set this flag to quickly run the client using only 10 data samples. "
"Useful for testing purposes. Default: False",
)
@@ -99,7 +99,7 @@ def main() -> None:
model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])
# Load a subset of CIFAR-10 to simulate the local data partition
- (x_train, y_train), (x_test, y_test) = load_partition(args.partition)
+ x_train, y_train, x_test, y_test = load_partition(args.client_id)
if args.toy:
x_train, y_train = x_train[:10], y_train[:10]
@@ -117,15 +117,16 @@ def main() -> None:
def load_partition(idx: int):
"""Load 1/10th of the training and test data to simulate a partition."""
- assert idx in range(10)
- (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
- return (
- x_train[idx * 5000 : (idx + 1) * 5000],
- y_train[idx * 5000 : (idx + 1) * 5000],
- ), (
- x_test[idx * 1000 : (idx + 1) * 1000],
- y_test[idx * 1000 : (idx + 1) * 1000],
- )
+ # Download and partition dataset
+ fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
+ partition = fds.load_partition(idx)
+ partition.set_format("numpy")
+
+ # Divide data on each node: 80% train, 20% test
+ partition = partition.train_test_split(test_size=0.2)
+ x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
+ x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
+ return x_train, y_train, x_test, y_test
if __name__ == "__main__":
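The TensorFlow client uses the same partitioning scheme but switches the partition to NumPy format so the arrays can be fed straight into Keras. A minimal sketch, again assuming `flwr-datasets[vision]` is installed:

```python
# Sketch of the NumPy-format path used by the TensorFlow client.
from flwr_datasets import FederatedDataset

fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
partition = fds.load_partition(0)
partition.set_format("numpy")  # columns are returned as NumPy arrays

split = partition.train_test_split(test_size=0.2)
x_train = split["train"]["img"] / 255.0  # scale pixel values to [0, 1]
y_train = split["train"]["label"]
print(x_train.shape, y_train.shape)  # (4000, 32, 32, 3) (4000,) for CIFAR-10
```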
diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml
index 293ba64b3f43..2f16d8a15584 100644
--- a/examples/advanced-tensorflow/pyproject.toml
+++ b/examples/advanced-tensorflow/pyproject.toml
@@ -11,5 +11,6 @@ authors = ["The Flower Authors "]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
+flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""}
tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""}
diff --git a/examples/advanced-tensorflow/requirements.txt b/examples/advanced-tensorflow/requirements.txt
index 7a70c46a8128..0cb5fe8c07af 100644
--- a/examples/advanced-tensorflow/requirements.txt
+++ b/examples/advanced-tensorflow/requirements.txt
@@ -1,3 +1,4 @@
flwr>=1.0, <2.0
+flwr-datasets[vision]>=0.0.2, <1.0.0
tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64"
tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64"
diff --git a/examples/advanced-tensorflow/run.sh b/examples/advanced-tensorflow/run.sh
index 8ddb6a252b52..4acef1371571 100755
--- a/examples/advanced-tensorflow/run.sh
+++ b/examples/advanced-tensorflow/run.sh
@@ -5,14 +5,11 @@
echo "Starting server"
python server.py &
-sleep 3 # Sleep for 3s to give the server enough time to start
+sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset
-# Ensure that the Keras dataset used in client.py is already cached.
-python -c "import tensorflow as tf; tf.keras.datasets.cifar10.load_data()"
-
-for i in `seq 0 9`; do
+for i in $(seq 0 9); do
echo "Starting client $i"
- python client.py --partition=${i} --toy True &
+ python client.py --client-id=${i} --toy &
done
# This will allow you to use CTRL+C to stop all background processes
diff --git a/examples/advanced-tensorflow/server.py b/examples/advanced-tensorflow/server.py
index e1eb3d4fd8f7..26dde312bee5 100644
--- a/examples/advanced-tensorflow/server.py
+++ b/examples/advanced-tensorflow/server.py
@@ -4,6 +4,8 @@
import flwr as fl
import tensorflow as tf
+from flwr_datasets import FederatedDataset
+
def main() -> None:
# Load and compile model for
@@ -43,11 +45,11 @@ def main() -> None:
def get_evaluate_fn(model):
"""Return an evaluation function for server-side evaluation."""
- # Load data and model here to avoid the overhead of doing it in `evaluate` itself
- (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
-
- # Use the last 5k training examples as a validation set
- x_val, y_val = x_train[45000:50000], y_train[45000:50000]
+ # Load data here to avoid the overhead of doing it in `evaluate` itself
+ fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
+ test = fds.load_full("test")
+ test.set_format("numpy")
+ x_test, y_test = test["img"] / 255.0, test["label"]
# The `evaluate` function will be called after every round
def evaluate(
@@ -56,7 +58,7 @@ def evaluate(
config: Dict[str, fl.common.Scalar],
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
model.set_weights(parameters) # Update model with the latest parameters
- loss, accuracy = model.evaluate(x_val, y_val)
+ loss, accuracy = model.evaluate(x_test, y_test)
return loss, {"accuracy": accuracy}
return evaluate
diff --git a/examples/android/README.md b/examples/android/README.md
index 7931aa96b0c5..f9f2bb93b8dc 100644
--- a/examples/android/README.md
+++ b/examples/android/README.md
@@ -54,4 +54,4 @@ poetry run ./run.sh
Download and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`.
-When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training.
+When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. This will load the local CIFAR10 dataset in memory, establish a connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing the `Clear` and `Refresh` buttons, respectively.
diff --git a/pyproject.toml b/pyproject.toml
index 0616ffdbeffd..cab083b32325 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -84,11 +84,11 @@ types-protobuf = "==3.19.18"
types-requests = "==2.31.0.10"
types-setuptools = "==69.0.0.20240115"
clang-format = "==17.0.4"
-isort = "==5.12.0"
+isort = "==5.13.2"
black = { version = "==23.10.1", extras = ["jupyter"] }
docformatter = "==1.7.5"
mypy = "==1.6.1"
-pylint = "==2.13.9"
+pylint = "==3.0.3"
flake8 = "==5.0.4"
pytest = "==7.4.3"
pytest-cov = "==4.1.0"
@@ -137,7 +137,7 @@ line-length = 88
target-version = ["py38", "py39", "py310", "py311"]
[tool.pylint."MESSAGES CONTROL"]
-disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias"
+disable = "duplicate-code,too-few-public-methods,useless-import-alias"
[tool.pytest.ini_options]
minversion = "6.2"
@@ -184,7 +184,7 @@ target-version = "py38"
line-length = 88
select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
-ignore = ["B024", "B027"]
+ignore = ["B024", "B027", "D205", "D209"]
exclude = [
".bzr",
".direnv",
diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py
index a5b285fbb7fb..91fa5468ae75 100644
--- a/src/py/flwr/client/app.py
+++ b/src/py/flwr/client/app.py
@@ -138,10 +138,12 @@ def _check_actionable_client(
client: Optional[Client], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
- raise Exception("Both `client_fn` and `client` are `None`, but one is required")
+ raise ValueError(
+ "Both `client_fn` and `client` are `None`, but one is required"
+ )
if client_fn is not None and client is not None:
- raise Exception(
+ raise ValueError(
"Both `client_fn` and `client` are provided, but only one is allowed"
)
@@ -150,6 +152,7 @@ def _check_actionable_client(
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
+# pylint: disable=too-many-arguments
def start_client(
*,
server_address: str,
@@ -299,7 +302,7 @@ def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> Client:
if client is None: # Added this to keep mypy happy
- raise Exception(
+ raise ValueError(
"Both `client_fn` and `client` are `None`, but one is required"
)
return client # Always return the same instance
diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py
index 7ef6410debad..56d6308a0fe2 100644
--- a/src/py/flwr/client/app_test.py
+++ b/src/py/flwr/client/app_test.py
@@ -41,19 +41,19 @@ class PlainClient(Client):
def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def fit(self, ins: FitIns) -> FitRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
class NeedsWrappingClient(NumPyClient):
@@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient):
def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def get_parameters(self, config: Config) -> NDArrays:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def fit(
self, parameters: NDArrays, config: Config
) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def evaluate(
self, parameters: NDArrays, config: Config
) -> Tuple[float, int, Dict[str, Scalar]]:
"""Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ raise NotImplementedError()
def test_to_client_with_client() -> None:
diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py
index 41b4d676df43..c39b89b31da3 100644
--- a/src/py/flwr/client/dpfedavg_numpy_client.py
+++ b/src/py/flwr/client/dpfedavg_numpy_client.py
@@ -117,16 +117,16 @@ def fit(
update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]
if "dpfedavg_clip_norm" not in config:
- raise Exception("Clipping threshold not supplied by the server.")
+ raise KeyError("Clipping threshold not supplied by the server.")
if not isinstance(config["dpfedavg_clip_norm"], float):
- raise Exception("Clipping threshold should be a floating point value.")
+ raise TypeError("Clipping threshold should be a floating point value.")
# Clipping
update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])
if "dpfedavg_noise_stddev" in config:
if not isinstance(config["dpfedavg_noise_stddev"], float):
- raise Exception(
+ raise TypeError(
"Scale of noise to be added should be a floating point value."
)
# Noising
@@ -138,7 +138,7 @@ def fit(
# Calculating value of norm indicator bit, required for adaptive clipping
if "dpfedavg_adaptive_clip_enabled" in config:
if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
- raise Exception(
+ raise TypeError(
"dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
)
metrics["dpfedavg_norm_bit"] = not clipped
diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py
index 13b1948eec07..3599e1dfb254 100644
--- a/src/py/flwr/client/message_handler/task_handler.py
+++ b/src/py/flwr/client/message_handler/task_handler.py
@@ -80,8 +80,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}
# Check if certain fields are already initialized
- # pylint: disable-next=too-many-boolean-expressions
- if (
+ if ( # pylint: disable-next=too-many-boolean-expressions
"task_id" in initialized_fields_in_task_res
or "group_id" in initialized_fields_in_task_res
or "run_id" in initialized_fields_in_task_res
diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py
index 2312741f5af6..d67fb90512d4 100644
--- a/src/py/flwr/client/numpy_client.py
+++ b/src/py/flwr/client/numpy_client.py
@@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
# Return FitRes
parameters_prime, num_examples, metrics = results
@@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
- raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
+ raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
# Return EvaluateRes
loss, num_examples, metrics = results
diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py
index d22b246dbd61..87b06dd0be4e 100644
--- a/src/py/flwr/client/rest_client/connection.py
+++ b/src/py/flwr/client/rest_client/connection.py
@@ -143,6 +143,7 @@ def create_node() -> None:
},
data=create_node_req_bytes,
verify=verify,
+ timeout=None,
)
# Check status code and headers
@@ -185,6 +186,7 @@ def delete_node() -> None:
},
data=delete_node_req_req_bytes,
verify=verify,
+ timeout=None,
)
# Check status code and headers
@@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]:
},
data=pull_task_ins_req_bytes,
verify=verify,
+ timeout=None,
)
# Check status code and headers
@@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None:
},
data=push_task_res_request_bytes,
verify=verify,
+ timeout=None,
)
state[KEY_TASK_INS] = None
diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py
index efbb00a9d916..4b74c1ace3de 100644
--- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py
+++ b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py
@@ -333,7 +333,7 @@ def _share_keys(
# Check if the size is larger than threshold
if len(state.public_keys_dict) < state.threshold:
- raise Exception("Available neighbours number smaller than threshold")
+ raise ValueError("Available neighbours number smaller than threshold")
# Check if all public keys are unique
pk_list: List[bytes] = []
@@ -341,14 +341,14 @@ def _share_keys(
pk_list.append(pk1)
pk_list.append(pk2)
if len(set(pk_list)) != len(pk_list):
- raise Exception("Some public keys are identical")
+ raise ValueError("Some public keys are identical")
# Check if public keys of this client are correct in the dictionary
if (
state.public_keys_dict[state.sid][0] != state.pk1
or state.public_keys_dict[state.sid][1] != state.pk2
):
- raise Exception(
+ raise ValueError(
"Own public keys are displayed in dict incorrectly, should not happen!"
)
@@ -393,7 +393,7 @@ def _collect_masked_input(
ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
if len(ciphertexts) + 1 < state.threshold:
- raise Exception("Not enough available neighbour clients.")
+ raise ValueError("Not enough available neighbour clients.")
# Decrypt ciphertexts, verify their sources, and store shares.
for src, ciphertext in zip(srcs, ciphertexts):
@@ -409,7 +409,7 @@ def _collect_masked_input(
f"from {actual_src} instead of {src}."
)
if dst != state.sid:
- ValueError(
+ raise ValueError(
f"Client {state.sid}: received an encrypted message"
f"for Client {dst} from Client {src}."
)
@@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
# Send private mask seed share for every available client (including itself)
# Send first private key share for building pairwise mask for every dropped client
if len(active_sids) < state.threshold:
- raise Exception("Available neighbours number smaller than threshold")
+ raise ValueError("Available neighbours number smaller than threshold")
sids, shares = [], []
sids += active_sids
diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py
new file mode 100644
index 000000000000..3d40c0488baa
--- /dev/null
+++ b/src/py/flwr/common/parametersrecord.py
@@ -0,0 +1,110 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ParametersRecord and Array."""
+
+
+from dataclasses import dataclass, field
+from typing import List, Optional, OrderedDict
+
+
+@dataclass
+class Array:
+ """Array type.
+
+ A dataclass containing serialized data from an array-like or tensor-like object
+ along with some metadata about it.
+
+ Parameters
+ ----------
+ dtype : str
+ A string representing the data type of the serialised object (e.g. `np.float32`)
+
+ shape : List[int]
+ A list representing the shape of the unserialized array-like object. This is
+ used to deserialize the data (depending on the serialization method) or simply
+ as a metadata field.
+
+ stype : str
+ A string indicating the type of serialisation mechanism used to generate the
+ bytes in `data` from an array-like or tensor-like object.
+
+ data : bytes
+ A buffer of bytes containing the data.
+ """
+
+ dtype: str
+ shape: List[int]
+ stype: str
+ data: bytes
+
+
+@dataclass
+class ParametersRecord:
+ """Parameters record.
+
+ A dataclass storing named Arrays in order. This means that it holds entries as an
+ OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
+ PyTorch's state_dict, but holding serialised tensors instead.
+ """
+
+ keep_input: bool
+ data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])
+
+ def __init__(
+ self,
+ array_dict: Optional[OrderedDict[str, Array]] = None,
+ keep_input: bool = False,
+ ) -> None:
+ """Construct a ParametersRecord object.
+
+ Parameters
+ ----------
+ array_dict : Optional[OrderedDict[str, Array]]
+ A dictionary that stores serialized array-like or tensor-like objects.
+ keep_input : bool (default: False)
+ A boolean indicating whether parameters should be kept in the input
+ dictionary after being added to the record. If False, the dictionary
+ passed to `set_parameters()` will be empty once that function returns.
+ This is the desired behaviour when working with very large
+ models/tensors/arrays. However, if you plan to continue working with your
+ parameters after adding them to the record, set this flag to True. When
+ set to True, the data is duplicated in memory.
+ """
+ self.keep_input = keep_input
+ self.data = OrderedDict()
+ if array_dict:
+ self.set_parameters(array_dict)
+
+ def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
+ """Add parameters to record.
+
+ Parameters
+ ----------
+ array_dict : OrderedDict[str, Array]
+ A dictionary that stores serialized array-like or tensor-like objects.
+ """
+ if any(not isinstance(k, str) for k in array_dict.keys()):
+ raise TypeError(f"Not all keys are of valid type. Expected {str}")
+ if any(not isinstance(v, Array) for v in array_dict.values()):
+ raise TypeError(f"Not all values are of valid type. Expected {Array}")
+
+ if self.keep_input:
+ # Copy
+ self.data = OrderedDict(array_dict)
+ else:
+ # Add entries to dataclass without duplicating memory
+ for key in list(array_dict.keys()):
+ self.data[key] = array_dict[key]
+ del array_dict[key]
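A usage sketch for the new `Array`/`ParametersRecord` pair, serialising with `ndarray.tobytes()` the way the accompanying tests do; the key name is hypothetical:

```python
# Usage sketch for the new ParametersRecord / Array dataclasses.
from collections import OrderedDict
import numpy as np
from flwr.common.parametersrecord import Array, ParametersRecord

arr = np.eye(3, dtype=np.float32)
array = Array(
    dtype=str(arr.dtype),
    shape=list(arr.shape),
    stype="numpy.ndarray.tobytes",  # serialization scheme used by the tests
    data=arr.tobytes(),
)

record = ParametersRecord(OrderedDict({"layer0": array}))  # keep_input=False
restored = np.frombuffer(record.data["layer0"].data, dtype=np.float32)
assert np.array_equal(arr, restored.reshape(record.data["layer0"].shape))
```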
diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py
index 0088b7397a6d..dc723a2cea86 100644
--- a/src/py/flwr/common/recordset.py
+++ b/src/py/flwr/common/recordset.py
@@ -14,13 +14,10 @@
# ==============================================================================
"""RecordSet."""
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Dict
-
-@dataclass
-class ParametersRecord:
- """Parameters record."""
+from .parametersrecord import ParametersRecord
@dataclass
@@ -37,9 +34,9 @@ class ConfigsRecord:
class RecordSet:
"""Definition of RecordSet."""
- parameters: Dict[str, ParametersRecord] = {}
- metrics: Dict[str, MetricsRecord] = {}
- configs: Dict[str, ConfigsRecord] = {}
+ parameters: Dict[str, ParametersRecord] = field(default_factory=dict)
+ metrics: Dict[str, MetricsRecord] = field(default_factory=dict)
+ configs: Dict[str, ConfigsRecord] = field(default_factory=dict)
def set_parameters(self, name: str, record: ParametersRecord) -> None:
"""Add a ParametersRecord."""
diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py
new file mode 100644
index 000000000000..90c06dcdb109
--- /dev/null
+++ b/src/py/flwr/common/recordset_test.py
@@ -0,0 +1,147 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RecordSet tests."""
+
+
+from typing import Callable, List, OrderedDict, Type, Union
+
+import numpy as np
+import pytest
+
+from .parameter import ndarrays_to_parameters, parameters_to_ndarrays
+from .parametersrecord import Array, ParametersRecord
+from .recordset_utils import (
+ parameters_to_parametersrecord,
+ parametersrecord_to_parameters,
+)
+from .typing import NDArray, NDArrays, Parameters
+
+
+def get_ndarrays() -> NDArrays:
+ """Return list of NumPy arrays."""
+ arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]])
+ arr2 = np.eye(2, 7, 3)
+
+ return [arr1, arr2]
+
+
+def ndarray_to_array(ndarray: NDArray) -> Array:
+ """Represent NumPy ndarray as Array."""
+ return Array(
+ data=ndarray.tobytes(),
+ dtype=str(ndarray.dtype),
+ stype="numpy.ndarray.tobytes",
+ shape=list(ndarray.shape),
+ )
+
+
+def test_ndarray_to_array() -> None:
+ """Test creation of Array object from NumPy ndarray."""
+ shape = (2, 7, 9)
+ arr = np.eye(*shape)
+
+ array = ndarray_to_array(arr)
+
+ arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape)
+
+ assert np.array_equal(arr, arr_)
+
+
+def test_parameters_to_array_and_back() -> None:
+ """Test conversion between legacy Parameters and Array."""
+ ndarrays = get_ndarrays()
+
+ # Array represents a single array, unlike Parameters, which represents a
+ # list of arrays
+ ndarray = ndarrays[0]
+
+ parameters = ndarrays_to_parameters([ndarray])
+
+ array = Array(
+ data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[]
+ )
+
+ parameters = Parameters(tensors=[array.data], tensor_type=array.stype)
+
+ ndarray_ = parameters_to_ndarrays(parameters=parameters)[0]
+
+ assert np.array_equal(ndarray, ndarray_)
+
+
+def test_parameters_to_parametersrecord_and_back() -> None:
+ """Test conversion between legacy Parameters and ParametersRecords."""
+ ndarrays = get_ndarrays()
+
+ parameters = ndarrays_to_parameters(ndarrays)
+
+ params_record = parameters_to_parametersrecord(parameters=parameters)
+
+ parameters_ = parametersrecord_to_parameters(params_record)
+
+ ndarrays_ = parameters_to_ndarrays(parameters=parameters_)
+
+ for arr, arr_ in zip(ndarrays, ndarrays_):
+ assert np.array_equal(arr, arr_)
+
+
+def test_set_parameters_while_keeping_inputs() -> None:
+ """Test keep_input functionality in ParametersRecord."""
+ # Add parameters to a record without erasing entries in the input `array_dict`
+ p_record = ParametersRecord(keep_input=True)
+ array_dict = OrderedDict(
+ {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
+ )
+ p_record.set_parameters(array_dict)
+
+ # Create a second ParametersRecord, passing the same (still populated) `array_dict`
+ p_record_2 = ParametersRecord(array_dict)
+ assert p_record.data == p_record_2.data
+
+ # Now `array_dict` should be empty (the second ParametersRecord did not keep inputs)
+ assert len(array_dict) == 0
+
+
+def test_set_parameters_with_correct_types() -> None:
+ """Test adding dictionary of Arrays to ParametersRecord."""
+ p_record = ParametersRecord()
+ array_dict = OrderedDict(
+ {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())}
+ )
+ p_record.set_parameters(array_dict)
+
+
+@pytest.mark.parametrize(
+ "key_type, value_fn",
+ [
+ (str, lambda x: x), # correct key, incorrect value
+ (str, lambda x: x.tolist()), # correct key, incorrect value
+ (int, ndarray_to_array), # incorrect key, correct value
+ (int, lambda x: x), # incorrect key, incorrect value
+ (int, lambda x: x.tolist()), # incorrect key, incorrect value
+ ],
+)
+def test_set_parameters_with_incorrect_types(
+ key_type: Type[Union[int, str]],
+ value_fn: Callable[[NDArray], Union[NDArray, List[float]]],
+) -> None:
+ """Test adding dictionary of unsupported types to ParametersRecord."""
+ p_record = ParametersRecord()
+
+ array_dict = {
+ key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays())
+ }
+
+ with pytest.raises(TypeError):
+ p_record.set_parameters(array_dict) # type: ignore
diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py
new file mode 100644
index 000000000000..c1e724fa2758
--- /dev/null
+++ b/src/py/flwr/common/recordset_utils.py
@@ -0,0 +1,87 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RecordSet utilities."""
+
+
+from typing import OrderedDict
+
+from .parametersrecord import Array, ParametersRecord
+from .typing import Parameters
+
+
+def parametersrecord_to_parameters(
+ record: ParametersRecord, keep_input: bool = False
+) -> Parameters:
+ """Convert ParameterRecord to legacy Parameters.
+
+    Warning: Because `Arrays` in `ParametersRecord` encode more information about the
+    array-like or tensor-like data (e.g. their datatype and shape) than `Parameters`,
+    it might not be possible to reconstruct such data structures from `Parameters`
+    objects alone. Additional information or metadata must be provided from elsewhere.
+
+ Parameters
+ ----------
+ record : ParametersRecord
+        The record to be converted into Parameters.
+    keep_input : bool (default: False)
+        A boolean indicating whether entries in the record should be kept after
+        being added to the Parameters object. If False, they are deleted from the
+        record as they are converted.
+ """
+ parameters = Parameters(tensors=[], tensor_type="")
+
+ for key in list(record.data.keys()):
+ parameters.tensors.append(record.data[key].data)
+
+ if not keep_input:
+ del record.data[key]
+
+ return parameters
+
+
+def parameters_to_parametersrecord(
+ parameters: Parameters, keep_input: bool = False
+) -> ParametersRecord:
+ """Convert legacy Parameters into a single ParametersRecord.
+
+ Because there is no concept of names in the legacy Parameters, arbitrary keys will
+ be used when constructing the ParametersRecord. Similarly, the shape and data type
+ won't be recorded in the Array objects.
+
+ Parameters
+ ----------
+ parameters : Parameters
+ Parameters object to be represented as a ParametersRecord.
+ keep_input : bool (default: False)
+ A boolean indicating whether parameters should be deleted from the input
+ Parameters object (i.e. a list of serialized NumPy arrays) immediately after
+ adding them to the record.
+ """
+ tensor_type = parameters.tensor_type
+
+ p_record = ParametersRecord()
+
+ num_arrays = len(parameters.tensors)
+ for idx in range(num_arrays):
+ if keep_input:
+ tensor = parameters.tensors[idx]
+ else:
+ tensor = parameters.tensors.pop(0)
+ p_record.set_parameters(
+ OrderedDict(
+ {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])}
+ )
+ )
+
+ return p_record
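A round-trip sketch for the two converters above, showing the move-not-copy behaviour of the default `keep_input=False`; the imports follow the module paths introduced in this diff:

```python
# Round trip between legacy Parameters and the new ParametersRecord.
import numpy as np
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common.recordset_utils import (
    parameters_to_parametersrecord,
    parametersrecord_to_parameters,
)

ndarrays = [np.arange(6.0).reshape(2, 3)]
parameters = ndarrays_to_parameters(ndarrays)

record = parameters_to_parametersrecord(parameters)  # keep_input=False
assert len(parameters.tensors) == 0  # tensors were moved, not copied

parameters_ = parametersrecord_to_parameters(record)
assert np.array_equal(ndarrays[0], parameters_to_ndarrays(parameters_)[0])
```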
diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py
index a60fff57e7bf..5441e766983a 100644
--- a/src/py/flwr/common/retry_invoker.py
+++ b/src/py/flwr/common/retry_invoker.py
@@ -156,6 +156,7 @@ class RetryInvoker:
>>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
"""
+ # pylint: disable-next=too-many-arguments
def __init__(
self,
wait_factory: Callable[[], Generator[float, None, None]],
diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py
index c8c73e87e04a..59f5387b0a07 100644
--- a/src/py/flwr/common/serde.py
+++ b/src/py/flwr/common/serde.py
@@ -59,7 +59,9 @@ def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessa
server_message.evaluate_ins,
)
)
- raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf")
+ raise ValueError(
+ "No instruction set in ServerMessage, cannot serialize to ProtoBuf"
+ )
def server_message_from_proto(
@@ -91,7 +93,7 @@ def server_message_from_proto(
server_message_proto.evaluate_ins,
)
)
- raise Exception(
+ raise ValueError(
"Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf"
)
@@ -125,7 +127,9 @@ def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessa
client_message.evaluate_res,
)
)
- raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf")
+ raise ValueError(
+ "No instruction set in ClientMessage, cannot serialize to ProtoBuf"
+ )
def client_message_from_proto(
@@ -157,7 +161,7 @@ def client_message_from_proto(
client_message_proto.evaluate_res,
)
)
- raise Exception(
+ raise ValueError(
"Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf"
)
@@ -474,7 +478,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
if isinstance(scalar, str):
return Scalar(string=scalar)
- raise Exception(
+ raise ValueError(
f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})"
)
@@ -518,7 +522,7 @@ def _check_value(value: typing.Value) -> None:
for element in value:
if isinstance(element, data_type):
continue
- raise Exception(
+ raise TypeError(
f"Inconsistent type: the types of elements in the list must "
f"be the same (expected {data_type}, but got {type(element)})."
)
diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py
index 2c3a6d2ccddf..82747e5afb2c 100644
--- a/src/py/flwr/driver/app_test.py
+++ b/src/py/flwr/driver/app_test.py
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Flower Driver app tests."""
-# pylint: disable=no-self-use
import threading
diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py
index 92b4230a3932..8f75bbf78362 100644
--- a/src/py/flwr/driver/driver_test.py
+++ b/src/py/flwr/driver/driver_test.py
@@ -139,6 +139,7 @@ def test_del_with_initialized_driver(self) -> None:
self.driver._get_grpc_driver_and_run_id()
# Execute
+ # pylint: disable-next=unnecessary-dunder-call
self.driver.__del__()
# Assert
@@ -147,6 +148,7 @@ def test_del_with_initialized_driver(self) -> None:
def test_del_with_uninitialized_driver(self) -> None:
"""Test cleanup behavior when Driver is not initialized."""
# Execute
+ # pylint: disable-next=unnecessary-dunder-call
self.driver.__del__()
# Assert
diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py
index b6d42fe799d5..627b95cdb1b4 100644
--- a/src/py/flwr/driver/grpc_driver.py
+++ b/src/py/flwr/driver/grpc_driver.py
@@ -89,7 +89,7 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call Driver API
res: CreateRunResponse = self.stub.CreateRun(request=req)
@@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call gRPC Driver API
res: GetNodesResponse = self.stub.GetNodes(request=req)
@@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call gRPC Driver API
res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
@@ -122,7 +122,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
- raise Exception("`GrpcDriver` instance not connected")
+ raise ConnectionError("`GrpcDriver` instance not connected")
# Call Driver API
res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py
index 6ae38ea3d805..4e68499f018d 100644
--- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py
+++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py
@@ -113,7 +113,7 @@ def _transition(self, next_status: Status) -> None:
):
self._status = next_status
else:
- raise Exception(f"Invalid transition: {self._status} to {next_status}")
+ raise ValueError(f"Invalid transition: {self._status} to {next_status}")
self._cv.notify_all()
@@ -129,7 +129,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper:
self._raise_if_closed()
if self._status != Status.AWAITING_INS_WRAPPER:
- raise Exception("This should not happen")
+ raise ValueError("This should not happen")
self._ins_wrapper = ins_wrapper # Write
self._transition(Status.INS_WRAPPER_AVAILABLE)
@@ -146,7 +146,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper:
self._transition(Status.AWAITING_INS_WRAPPER)
if res_wrapper is None:
- raise Exception("ResWrapper can not be None")
+ raise ValueError("ResWrapper can not be None")
return res_wrapper
@@ -170,7 +170,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]:
self._transition(Status.AWAITING_RES_WRAPPER)
if ins_wrapper is None:
- raise Exception("InsWrapper can not be None")
+ raise ValueError("InsWrapper can not be None")
yield ins_wrapper
@@ -180,7 +180,7 @@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None:
self._raise_if_closed()
if self._status != Status.AWAITING_RES_WRAPPER:
- raise Exception("This should not happen")
+ raise ValueError("This should not happen")
self._res_wrapper = res_wrapper # Write
self._transition(Status.RES_WRAPPER_AVAILABLE)
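
`ValueError` is the conventional builtin for a value that has the right type but is invalid in context, which fits both the rejected transitions and the `None` wrapper checks above. A minimal, self-contained sketch of an allow-listed transition; the `Status` members here are placeholders, not the bridge's full state set:

    from enum import Enum, auto

    class Status(Enum):
        AWAITING_INS_WRAPPER = auto()
        INS_WRAPPER_AVAILABLE = auto()
        CLOSED = auto()

    ALLOWED = {Status.AWAITING_INS_WRAPPER: {Status.INS_WRAPPER_AVAILABLE}}

    def transition(current: Status, next_status: Status) -> Status:
        # Reject any edge not in the allow-list instead of proceeding silently
        if next_status not in ALLOWED.get(current, set()):
            raise ValueError(f"Invalid transition: {current} to {next_status}")
        return next_status
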
diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py
index 18a2144072ed..bcfbe6e6fac8 100644
--- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py
+++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py
@@ -70,6 +70,7 @@ def test_workflow_successful() -> None:
_ = next(ins_wrapper_iterator)
bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage()))
except Exception as exception:
+ # pylint: disable-next=broad-exception-raised
raise Exception from exception
# Wait until worker_thread is finished
diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
index 1c737d31c7fc..0fa6f82a89b5 100644
--- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
+++ b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py
@@ -166,6 +166,6 @@ def _call_client_proxy(
evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res)
return ClientMessage(evaluate_res=evaluate_res_proto)
- raise Exception(
+ raise ValueError(
"Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf"
)
diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py
index 63ec1021ff5c..9b5c03aeeaf9 100644
--- a/src/py/flwr/server/server_test.py
+++ b/src/py/flwr/server/server_test.py
@@ -47,14 +47,14 @@ class SuccessClient(ClientProxy):
def get_properties(
self, ins: GetPropertiesIns, timeout: Optional[float]
) -> GetPropertiesRes:
- """Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ """Raise an error because this method is not expected to be called."""
+ raise NotImplementedError()
def get_parameters(
self, ins: GetParametersIns, timeout: Optional[float]
) -> GetParametersRes:
- """Raise an Exception because this method is not expected to be called."""
- raise Exception()
+ """Raise a error because this method is not expected to be called."""
+ raise NotImplementedError()
def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
"""Simulate fit by returning a success FitRes with simple set of weights."""
@@ -87,26 +87,26 @@ class FailingClient(ClientProxy):
def get_properties(
self, ins: GetPropertiesIns, timeout: Optional[float]
) -> GetPropertiesRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def get_parameters(
self, ins: GetParametersIns, timeout: Optional[float]
) -> GetParametersRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
- """Raise an Exception to simulate failure in the client."""
- raise Exception()
+ """Raise a NotImplementedError to simulate failure in the client."""
+ raise NotImplementedError()
def test_fit_clients() -> None:
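
`NotImplementedError` is the idiomatic choice for both doubles: in `SuccessClient` it marks methods the test must never reach, and in `FailingClient` it stands in for an arbitrary client-side failure. A stripped-down sketch of the idea, with hypothetical names (the real doubles implement the full `ClientProxy` interface):

    from typing import Optional

    class SuccessOnlyClient:
        """Test double: only fit() is expected to be called."""

        def get_properties(self, ins: object, timeout: Optional[float]) -> object:
            # Reaching this line means the code under test took a wrong path
            raise NotImplementedError()

        def fit(self, ins: object, timeout: Optional[float]) -> str:
            return "fit-result"  # Stand-in for a real FitRes
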
diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py
index 4f66be3ff262..26f326819971 100644
--- a/src/py/flwr/server/state/sqlite_state.py
+++ b/src/py/flwr/server/state/sqlite_state.py
@@ -134,7 +134,7 @@ def query(
) -> List[Dict[str, Any]]:
"""Execute a SQL query."""
if self.conn is None:
- raise Exception("State is not initialized.")
+ raise AttributeError("State is not initialized.")
if data is None:
data = []
@@ -459,7 +459,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
"""
if self.conn is None:
- raise Exception("State not intitialized")
+ raise AttributeError("State not intitialized")
with self.conn:
self.conn.execute(query_1, data)
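
`AttributeError` here conveys that `self.conn` has not been set up yet; the same guard appears in both hunks above, protecting every method that dereferences the connection. A compressed sketch of the pattern with simplified signatures (not the real `SqliteState` API):

    import sqlite3
    from typing import Any, List, Optional

    class State:
        def __init__(self) -> None:
            self.conn: Optional[sqlite3.Connection] = None

        def initialize(self) -> None:
            self.conn = sqlite3.connect(":memory:")

        def query(self, sql: str) -> List[Any]:
            if self.conn is None:
                # Fail fast before any use of the missing connection
                raise AttributeError("State is not initialized.")
            with self.conn:
                return self.conn.execute(sql).fetchall()
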
diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py
index efdd288fc308..a3f899386011 100644
--- a/src/py/flwr/server/state/sqlite_state_test.py
+++ b/src/py/flwr/server/state/sqlite_state_test.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Test for utility functions."""
-# pylint: disable=no-self-use, invalid-name, disable=R0904
+# pylint: disable=invalid-name, disable=R0904
import unittest
diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py
index 88b4b53aed4c..204b4ba97b5f 100644
--- a/src/py/flwr/server/state/state_test.py
+++ b/src/py/flwr/server/state/state_test.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests all state implemenations have to conform to."""
-# pylint: disable=no-self-use, invalid-name, disable=R0904
+# pylint: disable=invalid-name, disable=R0904
import tempfile
import unittest
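
Background for these two removals (and the `R0201` ones in `fedavg_android.py` below): the `no-self-use` check was moved out of pylint's default checkers into an optional extension around pylint 2.14, so keeping it in a `disable=` list now draws a warning of its own. A project that still wants the check can presumably opt back in via configuration along these lines:

    [MAIN]
    load-plugins = pylint.extensions.no_self_use
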
diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py
index 4eb76111b266..c668b55eebe6 100644
--- a/src/py/flwr/server/strategy/aggregate.py
+++ b/src/py/flwr/server/strategy/aggregate.py
@@ -27,7 +27,7 @@
def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
"""Compute weighted average."""
# Calculate the total number of examples used during training
- num_examples_total = sum([num_examples for _, num_examples in results])
+ num_examples_total = sum(num_examples for (_, num_examples) in results)
# Create a list of weights, each multiplied by the related number of examples
weighted_weights = [
@@ -45,7 +45,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
"""Compute in-place weighted average."""
# Count total examples
- num_examples_total = sum([fit_res.num_examples for _, fit_res in results])
+ num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
# Compute scaling factors for each result
scaling_factors = [
@@ -95,9 +95,9 @@ def aggregate_krum(
# For each client, take the n-f-2 closest parameters vectors
num_closest = max(1, len(weights) - num_malicious - 2)
closest_indices = []
- for i, _ in enumerate(distance_matrix):
+ for distance in distance_matrix:
closest_indices.append(
- np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist() # noqa: E203
+ np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203
)
# Compute the score for each client, that is the sum of the distances
@@ -202,7 +202,7 @@ def aggregate_bulyan(
def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
"""Aggregate evaluation results obtained from multiple clients."""
- num_total_evaluation_examples = sum([num_examples for num_examples, _ in results])
+ num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results)
weighted_losses = [num_examples * loss for num_examples, loss in results]
return sum(weighted_losses) / num_total_evaluation_examples
@@ -233,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray:
"""
flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights])
distance_matrix = np.zeros((len(weights), len(weights)))
- for i, _ in enumerate(flat_w):
- for j, _ in enumerate(flat_w):
- delta = flat_w[i] - flat_w[j]
+ for i, flat_w_i in enumerate(flat_w):
+ for j, flat_w_j in enumerate(flat_w):
+ delta = flat_w_i - flat_w_j
norm = np.linalg.norm(delta)
distance_matrix[i, j] = norm**2
return distance_matrix
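
Two behavior-preserving refactors in this file: `sum()` accepts any iterable, so a generator expression avoids materializing a throwaway list, and iterating rows directly replaces the index-only `enumerate` pattern. A quick self-contained comparison on toy data:

    # Toy (weights, num_examples) pairs; both forms produce the same total.
    results = [([1.0], 10), ([2.0], 30)]
    total_list = sum([num for _, num in results])  # Builds an intermediate list
    total_gen = sum(num for _, num in results)     # Feeds sum() lazily
    assert total_list == total_gen == 40

    # Index-free iteration, as in the aggregate_krum/_compute_distances hunks
    matrix = [[0.0, 1.0], [1.0, 0.0]]
    rows_a = [sorted(matrix[i]) for i, _ in enumerate(matrix)]  # Index detour
    rows_b = [sorted(row) for row in matrix]                    # Direct
    assert rows_a == rows_b
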
diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py
index 3269735e9d73..8b3278cc9ba0 100644
--- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py
+++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py
@@ -91,7 +91,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None:
norm_bit_set_count = 0
for client_proxy, fit_res in results:
if "dpfedavg_norm_bit" not in fit_res.metrics:
- raise Exception(
+ raise KeyError(
f"Indicator bit not returned by client with id {client_proxy.cid}."
)
if fit_res.metrics["dpfedavg_norm_bit"]:
diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py
index 0154cfd79fc5..f2f1c206f3de 100644
--- a/src/py/flwr/server/strategy/dpfedavg_fixed.py
+++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py
@@ -46,11 +46,11 @@ def __init__(
self.num_sampled_clients = num_sampled_clients
if clip_norm <= 0:
- raise Exception("The clipping threshold should be a positive value.")
+ raise ValueError("The clipping threshold should be a positive value.")
self.clip_norm = clip_norm
if noise_multiplier < 0:
- raise Exception("The noise multiplier should be a non-negative value.")
+ raise ValueError("The noise multiplier should be a non-negative value.")
self.noise_multiplier = noise_multiplier
self.server_side_noising = server_side_noising
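
Validating constructor arguments with `ValueError` keeps invalid strategies from being built at all. A minimal sketch of the pattern using a hypothetical config holder, not the real strategy class:

    class DPConfig:
        """Hypothetical holder for the two validated parameters."""

        def __init__(self, clip_norm: float, noise_multiplier: float) -> None:
            if clip_norm <= 0:
                raise ValueError("The clipping threshold should be a positive value.")
            if noise_multiplier < 0:
                raise ValueError("The noise multiplier should be a non-negative value.")
            self.clip_norm = clip_norm
            self.noise_multiplier = noise_multiplier
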
diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py
index e890f7216020..6678b7ced114 100644
--- a/src/py/flwr/server/strategy/fedavg_android.py
+++ b/src/py/flwr/server/strategy/fedavg_android.py
@@ -234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays:
"""Convert parameters object to NumPy weights."""
return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors]
- # pylint: disable=R0201
def ndarray_to_bytes(self, ndarray: NDArray) -> bytes:
"""Serialize NumPy array to bytes."""
return ndarray.tobytes()
- # pylint: disable=R0201
def bytes_to_ndarray(self, tensor: bytes) -> NDArray:
"""Deserialize NumPy array from bytes."""
ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32)
diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py
index 7a5bf1425b44..17e979d92beb 100644
--- a/src/py/flwr/server/strategy/fedmedian.py
+++ b/src/py/flwr/server/strategy/fedmedian.py
@@ -36,7 +36,7 @@
class FedMedian(FedAvg):
- """Configurable FedAvg with Momentum strategy implementation."""
+ """Configurable FedMedian strategy implementation."""
def __repr__(self) -> str:
"""Compute a string representation of the strategy."""
diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py
index 94a67fbcbfae..758e8e608e9f 100644
--- a/src/py/flwr/server/strategy/qfedavg.py
+++ b/src/py/flwr/server/strategy/qfedavg.py
@@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float:
hs_ffl = []
if self.pre_weights is None:
- raise Exception("QffedAvg pre_weights are None in aggregate_fit")
+ raise AttributeError("QffedAvg pre_weights are None in aggregate_fit")
weights_before = self.pre_weights
eval_result = self.evaluate(