diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index 36b94b5c7e97..4a1289d9175a 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -98,7 +98,7 @@ jobs: touch "/tmp/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@c7d193f32edcb7bfad88892161225aeda64e9392 # v4.0.0 + uses: actions/upload-artifact@1eb3cb2b3e0f29609092a73eb033bb759a334595 # v4.1.0 with: name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }} path: /tmp/digests/* @@ -114,7 +114,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@f44cd7b40bfd40b6aa1cc1b9b5b7bf03d3c67110 # v4.1.0 + uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml index 8758d0e1c5c7..2ca596a59361 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/swift.yml @@ -20,7 +20,7 @@ jobs: name: Test runs-on: macos-latest steps: - - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c + - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4 with: swift-version: 5 - uses: actions/checkout@v4 @@ -31,7 +31,7 @@ jobs: runs-on: macos-latest name: Build docs steps: - - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c + - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4 with: swift-version: 5 - uses: actions/checkout@v4 @@ -44,7 +44,7 @@ jobs: runs-on: macos-latest name: Deploy docs steps: - - uses: fwal/setup-swift@f51889efb55dccf13be0ee727e3d6c89a096fb4c + - uses: fwal/setup-swift@cdbe0f7f4c77929b6580e71983e8606e55ffe7e4 with: swift-version: 5 - uses: actions/checkout@v4 diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index 1107f6798b23..ff3dbb605846 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth Install stable release ---------------------- +Using pip +~~~~~~~~~ + Stable releases are available on `PyPI `_:: python -m pip install flwr @@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed python -m pip install flwr[simulation] +Using conda (or mamba) +~~~~~~~~~~~~~~~~~~~~~~ + +Flower can also be installed from the ``conda-forge`` channel. + +If you have not added ``conda-forge`` to your channels, you will first need to run the following:: + + conda config --add channels conda-forge + conda config --set channel_priority strict + +Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``:: + + conda install flwr + +or with ``mamba``:: + + mamba install flwr + + Verify installation ------------------- diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index db0245e41453..2527e8e4a820 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -1,6 +1,6 @@ # Advanced Flower Example (PyTorch) -This example demonstrates an advanced federated learning setup using Flower with PyTorch. It differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with PyTorch. 
This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) - Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set). @@ -59,12 +59,13 @@ pip install -r requirements.txt The included `run.sh` will start the Flower server (using `server.py`), sleep for 2 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`) with only a small subset of the data (in order to run on any machine), -but this can be changed by removing the `--toy True` argument in the script. You can simply start everything in a terminal as follows: +but this can be changed by removing the `--toy` argument in the script. You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +# After activating your environment +./run.sh ``` The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). -You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network). +You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). 
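For reviewers who want the new data path in one place: the hunks below replace the manual `torchvision` CIFAR-10 slicing with Flower Datasets. A minimal sketch of what each client now does, with the names and arguments taken directly from the `utils.py` changes in this patch (assumes `flwr-datasets[vision]` is installed):

```python
from flwr_datasets import FederatedDataset

# Partition CIFAR-10 into 10 shards, one per client (as in utils.load_partition)
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})

# Each client loads only the shard matching its --client-id (0-9)
partition = fds.load_partition(0)

# Split the local shard 80/20 into train and test, mirroring the patch
partition = partition.train_test_split(test_size=0.2)
trainset, testset = partition["train"], partition["test"]
```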
diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index f9ffb6181fd8..b22cbcd70465 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -6,6 +6,7 @@ import argparse from collections import OrderedDict import warnings +import datasets warnings.filterwarnings("ignore") @@ -13,9 +14,9 @@ class CifarClient(fl.client.NumPyClient): def __init__( self, - trainset: torchvision.datasets, - testset: torchvision.datasets, - device: str, + trainset: datasets.Dataset, + testset: datasets.Dataset, + device: torch.device, validation_split: int = 0.1, ): self.device = device @@ -41,17 +42,14 @@ def fit(self, parameters, config): batch_size: int = config["batch_size"] epochs: int = config["local_epochs"] - n_valset = int(len(self.trainset) * self.validation_split) + train_valid = self.trainset.train_test_split(self.validation_split) + trainset = train_valid["train"] + valset = train_valid["test"] - valset = torch.utils.data.Subset(self.trainset, range(0, n_valset)) - trainset = torch.utils.data.Subset( - self.trainset, range(n_valset, len(self.trainset)) - ) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(valset, batch_size=batch_size) - trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True) - valLoader = DataLoader(valset, batch_size=batch_size) - - results = utils.train(model, trainLoader, valLoader, epochs, self.device) + results = utils.train(model, train_loader, val_loader, epochs, self.device) parameters_prime = utils.get_model_params(model) num_examples_train = len(trainset) @@ -73,13 +71,13 @@ def evaluate(self, parameters, config): return float(loss), len(self.testset), {"accuracy": float(accuracy)} -def client_dry_run(device: str = "cpu"): +def client_dry_run(device: torch.device = "cpu"): """Weak tests to check whether all client methods are working as expected.""" model = utils.load_efficientnet(classes=10) trainset, testset = utils.load_partition(0) - trainset = torch.utils.data.Subset(trainset, range(10)) - testset = torch.utils.data.Subset(testset, range(10)) + trainset = trainset.select(range(10)) + testset = testset.select(range(10)) client = CifarClient(trainset, testset, device) client.fit( utils.get_model_params(model), @@ -102,7 +100,7 @@ def main() -> None: help="Do a dry-run to check the client", ) parser.add_argument( - "--partition", + "--client-id", type=int, default=0, choices=range(0, 10), @@ -112,9 +110,7 @@ def main() -> None: ) parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action='store_true', help="Set to true to quicky run the client using only 10 datasamples. \ Useful for testing purposes. 
Default: False", ) @@ -136,12 +132,11 @@ def main() -> None: client_dry_run(device) else: # Load a subset of CIFAR-10 to simulate the local data partition - trainset, testset = utils.load_partition(args.partition) + trainset, testset = utils.load_partition(args.client_id) if args.toy: - trainset = torch.utils.data.Subset(trainset, range(10)) - testset = torch.utils.data.Subset(testset, range(10)) - + trainset = trainset.select(range(10)) + testset = testset.select(range(10)) # Start Flower client client = CifarClient(trainset, testset, device) diff --git a/examples/advanced-pytorch/pyproject.toml b/examples/advanced-pytorch/pyproject.toml index a12f3c47de70..89fd5a32a89e 100644 --- a/examples/advanced-pytorch/pyproject.toml +++ b/examples/advanced-pytorch/pyproject.toml @@ -14,6 +14,7 @@ authors = [ [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "1.13.1" torchvision = "0.14.1" validators = "0.18.2" diff --git a/examples/advanced-pytorch/requirements.txt b/examples/advanced-pytorch/requirements.txt index ba7b284df90e..f4d6a0774162 100644 --- a/examples/advanced-pytorch/requirements.txt +++ b/examples/advanced-pytorch/requirements.txt @@ -1,4 +1,5 @@ flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 torch==1.13.1 torchvision==0.14.1 validators==0.18.2 diff --git a/examples/advanced-pytorch/run.sh b/examples/advanced-pytorch/run.sh index 212285f504f9..3367e1680535 100755 --- a/examples/advanced-pytorch/run.sh +++ b/examples/advanced-pytorch/run.sh @@ -2,20 +2,17 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the CIFAR-10 dataset -python -c "from torchvision.datasets import CIFAR10; CIFAR10('./dataset', download=True)" - # Download the EfficientNetB0 model python -c "import torch; torch.hub.load( \ 'NVIDIA/DeepLearningExamples:torchhub', \ 'nvidia_efficientnet_b0', pretrained=True)" -python server.py & -sleep 3 # Sleep for 3s to give the server enough time to start +python server.py --toy & +sleep 10 # Sleep for 10s to give the server enough time to start and dowload the dataset for i in `seq 0 9`; do echo "Starting client $i" - python client.py --partition=${i} --toy True & + python client.py --client-id=${i} --toy & done # Enable CTRL+C to stop all background processes diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index 8343e62da69f..fda49b71a311 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -10,6 +10,8 @@ import warnings +from flwr_datasets import FederatedDataset + warnings.filterwarnings("ignore") @@ -39,18 +41,13 @@ def evaluate_config(server_round: int): def get_evaluate_fn(model: torch.nn.Module, toy: bool): """Return an evaluation function for server-side evaluation.""" - # Load data and model here to avoid the overhead of doing it in `evaluate` itself - trainset, _, _ = utils.load_data() - - n_train = len(trainset) + # Load data here to avoid the overhead of doing it in `evaluate` itself + centralized_data = utils.load_centralized_data() if toy: # use only 10 samples as validation set - valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train)) - else: - # Use the last 5k training examples as a validation set - valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train)) + centralized_data = centralized_data.select(range(10)) - valLoader = DataLoader(valset, batch_size=16) + val_loader = DataLoader(centralized_data, 
batch_size=16) # The `evaluate` function will be called after every round def evaluate( @@ -63,7 +60,7 @@ def evaluate( state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) model.load_state_dict(state_dict, strict=True) - loss, accuracy = utils.test(model, valLoader) + loss, accuracy = utils.test(model, val_loader) return loss, {"accuracy": accuracy} return evaluate @@ -79,9 +76,7 @@ def main(): parser = argparse.ArgumentParser(description="Flower") parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action='store_true', help="Set to true to use only 10 datasamples for validation. \ Useful for testing purposes. Default: False", ) diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 8788ead90dee..6512010b1f23 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,49 +1,45 @@ import torch -import torchvision.transforms as transforms -from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop +from torch.utils.data import DataLoader import warnings -warnings.filterwarnings("ignore") - -# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +from flwr_datasets import FederatedDataset +warnings.filterwarnings("ignore") -def load_data(): - """Load CIFAR-10 (training and test set).""" - transform = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - trainset = CIFAR10("./dataset", train=True, download=True, transform=transform) - testset = CIFAR10("./dataset", train=False, download=True, transform=transform) +def load_partition(node_id, toy: bool = False): + """Load partition CIFAR10 data.""" + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(node_id) + # Divide data on each node: 80% train, 20% test + partition_train_test = partition.train_test_split(test_size=0.2) + partition_train_test = partition_train_test.with_transform(apply_transforms) + return partition_train_test["train"], partition_train_test["test"] - num_examples = {"trainset": len(trainset), "testset": len(testset)} - return trainset, testset, num_examples +def load_centralized_data(): + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + centralized_data = fds.load_full("test") + centralized_data = centralized_data.with_transform(apply_transforms) + return centralized_data -def load_partition(idx: int): - """Load 1/10th of the training and test data to simulate a partition.""" - assert idx in range(10) - trainset, testset, num_examples = load_data() - n_train = int(num_examples["trainset"] / 10) - n_test = int(num_examples["testset"] / 10) - train_parition = torch.utils.data.Subset( - trainset, range(idx * n_train, (idx + 1) * n_train) - ) - test_parition = torch.utils.data.Subset( - testset, range(idx * n_test, (idx + 1) * n_test) - ) - return (train_parition, test_parition) +def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + pytorch_transforms = Compose([ + Resize(256), + CenterCrop(224), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + batch["img"] = [pytorch_transforms(img) for img in batch["img"]] + return batch -def train(net, trainloader, valloader, epochs, device: str = "cpu"): +def train(net, trainloader, valloader, 
epochs, + device: torch.device = torch.device("cpu")): """Train the network on the training set.""" print("Starting training...") net.to(device) # move model to GPU if available @@ -53,7 +49,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"): ) net.train() for _ in range(epochs): - for images, labels in trainloader: + for batch in trainloader: + images, labels = batch["img"], batch["label"] images, labels = images.to(device), labels.to(device) optimizer.zero_grad() loss = criterion(net(images), labels) @@ -74,7 +71,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"): return results -def test(net, testloader, steps: int = None, device: str = "cpu"): +def test(net, testloader, steps: int = None, + device: torch.device = torch.device("cpu")): """Validate the network on the entire test set.""" print("Starting evalutation...") net.to(device) # move model to GPU if available @@ -82,7 +80,8 @@ def test(net, testloader, steps: int = None, device: str = "cpu"): correct, loss = 0, 0.0 net.eval() with torch.no_grad(): - for batch_idx, (images, labels) in enumerate(testloader): + for batch_idx, batch in enumerate(testloader): + images, labels = batch["img"], batch["label"] images, labels = images.to(device), labels.to(device) outputs = net(images) loss += criterion(outputs, labels).item() @@ -109,12 +108,14 @@ def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = entrypoint: EfficientNet model to download. For supported entrypoints, please refer https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/ - classes: Number of classes in final classifying layer. Leave as None to get the downloaded + classes: Number of classes in final classifying layer. Leave as None to get + the downloaded model untouched. Returns: EfficientNet Model - Note: One alternative implementation can be found at https://github.com/lukemelas/EfficientNet-PyTorch + Note: One alternative implementation can be found at + https://github.com/lukemelas/EfficientNet-PyTorch """ efficientnet = torch.hub.load( "NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True diff --git a/examples/advanced-tensorflow/README.md b/examples/advanced-tensorflow/README.md index 31bf5edb64c6..b21c0d2545ca 100644 --- a/examples/advanced-tensorflow/README.md +++ b/examples/advanced-tensorflow/README.md @@ -1,9 +1,9 @@ # Advanced Flower Example (TensorFlow/Keras) -This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. It differs from the quickstart example in the following ways: +This example demonstrates an advanced federated learning setup using Flower with TensorFlow/Keras. 
This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways: - 10 clients (instead of just 2) -- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that by default only a small subset of this data is used when running the `run.sh` script) +- Each client holds 1/10 of the training set, split locally into 80% training examples and 20% test examples (note that by default only a small subset of this data is used when running the `run.sh` script) - Server-side model evaluation after parameter aggregation - Hyperparameter schedule using config functions - Custom return values @@ -57,10 +57,11 @@ pip install -r requirements.txt ## Run Federated Learning with TensorFlow/Keras and Flower -The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 2 seconds to ensure the the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: +The included `run.sh` will call a script to generate certificates (which will be used by server and clients), start the Flower server (using `server.py`), sleep for 10 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows: ```shell -poetry run ./run.sh +# Once you have activated your environment +./run.sh ``` The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). diff --git a/examples/advanced-tensorflow/client.py b/examples/advanced-tensorflow/client.py index 1c0b61575635..033f20b1b027 100644 --- a/examples/advanced-tensorflow/client.py +++ b/examples/advanced-tensorflow/client.py @@ -6,6 +6,8 @@ import flwr as fl +from flwr_datasets import FederatedDataset + # Make TensorFlow logs less verbose os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -74,7 +76,7 @@ def main() -> None: # Parse command line argument `partition` parser = argparse.ArgumentParser(description="Flower") parser.add_argument( - "--partition", + "--client-id", type=int, default=0, choices=range(0, 10), @@ -84,9 +86,7 @@ def main() -> None: ) parser.add_argument( "--toy", - type=bool, - default=False, - required=False, + action='store_true', help="Set to true to quicky run the client using only 10 datasamples. " "Useful for testing purposes. 
Default: False", ) @@ -99,7 +99,7 @@ def main() -> None: model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"]) # Load a subset of CIFAR-10 to simulate the local data partition - (x_train, y_train), (x_test, y_test) = load_partition(args.partition) + x_train, y_train, x_test, y_test = load_partition(args.client_id) if args.toy: x_train, y_train = x_train[:10], y_train[:10] @@ -117,15 +117,16 @@ def main() -> None: def load_partition(idx: int): """Load 1/10th of the training and test data to simulate a partition.""" - assert idx in range(10) - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() - return ( - x_train[idx * 5000 : (idx + 1) * 5000], - y_train[idx * 5000 : (idx + 1) * 5000], - ), ( - x_test[idx * 1000 : (idx + 1) * 1000], - y_test[idx * 1000 : (idx + 1) * 1000], - ) + # Download and partition dataset + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + partition = fds.load_partition(idx) + partition.set_format("numpy") + + # Divide data on each node: 80% train, 20% test + partition = partition.train_test_split(test_size=0.2) + x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] + x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] + return x_train, y_train, x_test, y_test if __name__ == "__main__": diff --git a/examples/advanced-tensorflow/pyproject.toml b/examples/advanced-tensorflow/pyproject.toml index 293ba64b3f43..2f16d8a15584 100644 --- a/examples/advanced-tensorflow/pyproject.toml +++ b/examples/advanced-tensorflow/pyproject.toml @@ -11,5 +11,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow-cpu = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\""} tensorflow-macos = {version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\""} diff --git a/examples/advanced-tensorflow/requirements.txt b/examples/advanced-tensorflow/requirements.txt index 7a70c46a8128..0cb5fe8c07af 100644 --- a/examples/advanced-tensorflow/requirements.txt +++ b/examples/advanced-tensorflow/requirements.txt @@ -1,3 +1,4 @@ flwr>=1.0, <2.0 +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" diff --git a/examples/advanced-tensorflow/run.sh b/examples/advanced-tensorflow/run.sh index 8ddb6a252b52..4acef1371571 100755 --- a/examples/advanced-tensorflow/run.sh +++ b/examples/advanced-tensorflow/run.sh @@ -5,14 +5,11 @@ echo "Starting server" python server.py & -sleep 3 # Sleep for 3s to give the server enough time to start +sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset -# Ensure that the Keras dataset used in client.py is already cached. 
-python -c "import tensorflow as tf; tf.keras.datasets.cifar10.load_data()" - -for i in `seq 0 9`; do +for i in $(seq 0 9); do echo "Starting client $i" - python client.py --partition=${i} --toy True & + python client.py --client-id=${i} --toy & done # This will allow you to use CTRL+C to stop all background processes diff --git a/examples/advanced-tensorflow/server.py b/examples/advanced-tensorflow/server.py index e1eb3d4fd8f7..26dde312bee5 100644 --- a/examples/advanced-tensorflow/server.py +++ b/examples/advanced-tensorflow/server.py @@ -4,6 +4,8 @@ import flwr as fl import tensorflow as tf +from flwr_datasets import FederatedDataset + def main() -> None: # Load and compile model for @@ -43,11 +45,11 @@ def main() -> None: def get_evaluate_fn(model): """Return an evaluation function for server-side evaluation.""" - # Load data and model here to avoid the overhead of doing it in `evaluate` itself - (x_train, y_train), _ = tf.keras.datasets.cifar10.load_data() - - # Use the last 5k training examples as a validation set - x_val, y_val = x_train[45000:50000], y_train[45000:50000] + # Load data here to avoid the overhead of doing it in `evaluate` itself + fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10}) + test = fds.load_full("test") + test.set_format("numpy") + x_test, y_test = test["img"] / 255.0, test["label"] # The `evaluate` function will be called after every round def evaluate( @@ -56,7 +58,7 @@ def evaluate( config: Dict[str, fl.common.Scalar], ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: model.set_weights(parameters) # Update model with the latest parameters - loss, accuracy = model.evaluate(x_val, y_val) + loss, accuracy = model.evaluate(x_test, y_test) return loss, {"accuracy": accuracy} return evaluate diff --git a/examples/android/README.md b/examples/android/README.md index 7931aa96b0c5..f9f2bb93b8dc 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -54,4 +54,4 @@ poetry run ./run.sh Download and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`. -When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training. +When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. This will load the local CIFAR10 dataset in memory, establish connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing `Clear` and `Refresh` buttons respectively. 
diff --git a/pyproject.toml b/pyproject.toml index 0616ffdbeffd..cab083b32325 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,11 +84,11 @@ types-protobuf = "==3.19.18" types-requests = "==2.31.0.10" types-setuptools = "==69.0.0.20240115" clang-format = "==17.0.4" -isort = "==5.12.0" +isort = "==5.13.2" black = { version = "==23.10.1", extras = ["jupyter"] } docformatter = "==1.7.5" mypy = "==1.6.1" -pylint = "==2.13.9" +pylint = "==3.0.3" flake8 = "==5.0.4" pytest = "==7.4.3" pytest-cov = "==4.1.0" @@ -137,7 +137,7 @@ line-length = 88 target-version = ["py38", "py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] -disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +disable = "duplicate-code,too-few-public-methods,useless-import-alias" [tool.pytest.ini_options] minversion = "6.2" @@ -184,7 +184,7 @@ target-version = "py38" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] -ignore = ["B024", "B027"] +ignore = ["B024", "B027", "D205", "D209"] exclude = [ ".bzr", ".direnv", diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index a5b285fbb7fb..91fa5468ae75 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -138,10 +138,12 @@ def _check_actionable_client( client: Optional[Client], client_fn: Optional[ClientFn] ) -> None: if client_fn is None and client is None: - raise Exception("Both `client_fn` and `client` are `None`, but one is required") + raise ValueError( + "Both `client_fn` and `client` are `None`, but one is required" + ) if client_fn is not None and client is not None: - raise Exception( + raise ValueError( "Both `client_fn` and `client` are provided, but only one is allowed" ) @@ -150,6 +152,7 @@ def _check_actionable_client( # pylint: disable=too-many-branches # pylint: disable=too-many-locals # pylint: disable=too-many-statements +# pylint: disable=too-many-arguments def start_client( *, server_address: str, @@ -299,7 +302,7 @@ def single_client_factory( cid: str, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy - raise Exception( + raise ValueError( "Both `client_fn` and `client` are `None`, but one is required" ) return client # Always return the same instance diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 7ef6410debad..56d6308a0fe2 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -41,19 +41,19 @@ class PlainClient(Client): def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit(self, ins: FitIns) -> FitRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate(self, ins: EvaluateIns) -> EvaluateRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() class NeedsWrappingClient(NumPyClient): @@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient): def get_properties(self, config: Config) -> Dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise 
NotImplementedError() def get_parameters(self, config: Config) -> NDArrays: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit( self, parameters: NDArrays, config: Config ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config ) -> Tuple[float, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def test_to_client_with_client() -> None: diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index 41b4d676df43..c39b89b31da3 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -117,16 +117,16 @@ def fit( update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)] if "dpfedavg_clip_norm" not in config: - raise Exception("Clipping threshold not supplied by the server.") + raise KeyError("Clipping threshold not supplied by the server.") if not isinstance(config["dpfedavg_clip_norm"], float): - raise Exception("Clipping threshold should be a floating point value.") + raise TypeError("Clipping threshold should be a floating point value.") # Clipping update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"]) if "dpfedavg_noise_stddev" in config: if not isinstance(config["dpfedavg_noise_stddev"], float): - raise Exception( + raise TypeError( "Scale of noise to be added should be a floating point value." ) # Noising @@ -138,7 +138,7 @@ def fit( # Calculating value of norm indicator bit, required for adaptive clipping if "dpfedavg_adaptive_clip_enabled" in config: if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool): - raise Exception( + raise TypeError( "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag." 
) metrics["dpfedavg_norm_bit"] = not clipped diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index 13b1948eec07..3599e1dfb254 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -80,8 +80,7 @@ def validate_task_res(task_res: TaskRes) -> bool: initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} # Check if certain fields are already initialized - # pylint: disable-next=too-many-boolean-expressions - if ( + if ( # pylint: disable-next=too-many-boolean-expressions "task_id" in initialized_fields_in_task_res or "group_id" in initialized_fields_in_task_res or "run_id" in initialized_fields_in_task_res diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 2312741f5af6..d67fb90512d4 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) # Return FitRes parameters_prime, num_examples, metrics = results @@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) # Return EvaluateRes loss, num_examples, metrics = results diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d22b246dbd61..87b06dd0be4e 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -143,6 +143,7 @@ def create_node() -> None: }, data=create_node_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -185,6 +186,7 @@ def delete_node() -> None: }, data=delete_node_req_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]: }, data=pull_task_ins_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None: }, data=push_task_res_request_bytes, verify=verify, + timeout=None, ) state[KEY_TASK_INS] = None diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py index efbb00a9d916..4b74c1ace3de 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py @@ -333,7 +333,7 @@ def _share_keys( # Check if the size is larger than threshold if len(state.public_keys_dict) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique pk_list: List[bytes] = [] @@ -341,14 +341,14 @@ def _share_keys( pk_list.append(pk1) pk_list.append(pk2) if len(set(pk_list)) != len(pk_list): - raise Exception("Some public keys are identical") + raise ValueError("Some public keys are identical") # Check if public keys of this client are correct in the dictionary if ( state.public_keys_dict[state.sid][0] != state.pk1 or state.public_keys_dict[state.sid][1] != state.pk2 ): - raise Exception( + 
raise ValueError( "Own public keys are displayed in dict incorrectly, should not happen!" ) @@ -393,7 +393,7 @@ def _collect_masked_input( ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: - raise Exception("Not enough available neighbour clients.") + raise ValueError("Not enough available neighbour clients.") # Decrypt ciphertexts, verify their sources, and store shares. for src, ciphertext in zip(srcs, ciphertexts): @@ -409,7 +409,7 @@ def _collect_masked_input( f"from {actual_src} instead of {src}." ) if dst != state.sid: - ValueError( + raise ValueError( f"Client {state.sid}: received an encrypted message" f"for Client {dst} from Client {src}." ) @@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") sids, shares = [], [] sids += active_sids diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py new file mode 100644 index 000000000000..3d40c0488baa --- /dev/null +++ b/src/py/flwr/common/parametersrecord.py @@ -0,0 +1,110 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ParametersRecord and Array.""" + + +from dataclasses import dataclass, field +from typing import List, Optional, OrderedDict + + +@dataclass +class Array: + """Array type. + + A dataclass containing serialized data from an array-like or tensor-like object + along with some metadata about it. + + Parameters + ---------- + dtype : str + A string representing the data type of the serialised object (e.g. `np.float32`) + + shape : List[int] + A list representing the shape of the unserialized array-like object. This is + used to deserialize the data (depending on the serialization method) or simply + as a metadata field. + + stype : str + A string indicating the type of serialisation mechanism used to generate the + bytes in `data` from an array-like or tensor-like object. + + data: bytes + A buffer of bytes containing the data. + """ + + dtype: str + shape: List[int] + stype: str + data: bytes + + +@dataclass +class ParametersRecord: + """Parameters record. + + A dataclass storing named Arrays in order. This means that it holds entries as an + OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to + PyTorch's state_dict, but holding serialised tensors instead. 
+ """ + + keep_input: bool + data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) + + def __init__( + self, + array_dict: Optional[OrderedDict[str, Array]] = None, + keep_input: bool = False, + ) -> None: + """Construct a ParametersRecord object. + + Parameters + ---------- + array_dict : Optional[OrderedDict[str, Array]] + A dictionary that stores serialized array-like or tensor-like objects. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + dictionary immediately after adding them to the record. If False, the + dictionary passed to `set_parameters()` will be empty once exiting from that + function. This is the desired behaviour when working with very large + models/tensors/arrays. However, if you plan to continue working with your + parameters after adding it to the record, set this flag to True. When set + to True, the data is duplicated in memory. + """ + self.keep_input = keep_input + self.data = OrderedDict() + if array_dict: + self.set_parameters(array_dict) + + def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: + """Add parameters to record. + + Parameters + ---------- + array_dict : OrderedDict[str, Array] + A dictionary that stores serialized array-like or tensor-like objects. + """ + if any(not isinstance(k, str) for k in array_dict.keys()): + raise TypeError(f"Not all keys are of valid type. Expected {str}") + if any(not isinstance(v, Array) for v in array_dict.values()): + raise TypeError(f"Not all values are of valid type. Expected {Array}") + + if self.keep_input: + # Copy + self.data = OrderedDict(array_dict) + else: + # Add entries to dataclass without duplicating memory + for key in list(array_dict.keys()): + self.data[key] = array_dict[key] + del array_dict[key] diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py index 0088b7397a6d..dc723a2cea86 100644 --- a/src/py/flwr/common/recordset.py +++ b/src/py/flwr/common/recordset.py @@ -14,13 +14,10 @@ # ============================================================================== """RecordSet.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict - -@dataclass -class ParametersRecord: - """Parameters record.""" +from .parametersrecord import ParametersRecord @dataclass @@ -37,9 +34,9 @@ class ConfigsRecord: class RecordSet: """Definition of RecordSet.""" - parameters: Dict[str, ParametersRecord] = {} - metrics: Dict[str, MetricsRecord] = {} - configs: Dict[str, ConfigsRecord] = {} + parameters: Dict[str, ParametersRecord] = field(default_factory=dict) + metrics: Dict[str, MetricsRecord] = field(default_factory=dict) + configs: Dict[str, ConfigsRecord] = field(default_factory=dict) def set_parameters(self, name: str, record: ParametersRecord) -> None: """Add a ParametersRecord.""" diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py new file mode 100644 index 000000000000..90c06dcdb109 --- /dev/null +++ b/src/py/flwr/common/recordset_test.py @@ -0,0 +1,147 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet tests.""" + + +from typing import Callable, List, OrderedDict, Type, Union + +import numpy as np +import pytest + +from .parameter import ndarrays_to_parameters, parameters_to_ndarrays +from .parametersrecord import Array, ParametersRecord +from .recordset_utils import ( + parameters_to_parametersrecord, + parametersrecord_to_parameters, +) +from .typing import NDArray, NDArrays, Parameters + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +def ndarray_to_array(ndarray: NDArray) -> Array: + """Represent NumPy ndarray as Array.""" + return Array( + data=ndarray.tobytes(), + dtype=str(ndarray.dtype), + stype="numpy.ndarray.tobytes", + shape=list(ndarray.shape), + ) + + +def test_ndarray_to_array() -> None: + """Test creation of Array object from NumPy ndarray.""" + shape = (2, 7, 9) + arr = np.eye(*shape) + + array = ndarray_to_array(arr) + + arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape) + + assert np.array_equal(arr, arr_) + + +def test_parameters_to_array_and_back() -> None: + """Test conversion between legacy Parameters and Array.""" + ndarrays = get_ndarrays() + + # Array represents a single array, unlike Paramters, which represent a + # list of arrays + ndarray = ndarrays[0] + + parameters = ndarrays_to_parameters([ndarray]) + + array = Array( + data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[] + ) + + parameters = Parameters(tensors=[array.data], tensor_type=array.stype) + + ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] + + assert np.array_equal(ndarray, ndarray_) + + +def test_parameters_to_parametersrecord_and_back() -> None: + """Test conversion between legacy Parameters and ParametersRecords.""" + ndarrays = get_ndarrays() + + parameters = ndarrays_to_parameters(ndarrays) + + params_record = parameters_to_parametersrecord(parameters=parameters) + + parameters_ = parametersrecord_to_parameters(params_record) + + ndarrays_ = parameters_to_ndarrays(parameters=parameters_) + + for arr, arr_ in zip(ndarrays, ndarrays_): + assert np.array_equal(arr, arr_) + + +def test_set_parameters_while_keeping_intputs() -> None: + """Tests keep_input functionality in ParametersRecord.""" + # Adding parameters to a record that doesn't erase entries in the input `array_dict` + p_record = ParametersRecord(keep_input=True) + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.set_parameters(array_dict) + + # Creating a second parametersrecord passing the same `array_dict` (not erased) + p_record_2 = ParametersRecord(array_dict) + assert p_record.data == p_record_2.data + + # Now it should be empty (the second ParametersRecord wasn't flagged to keep it) + assert len(array_dict) == 0 + + +def test_set_parameters_with_correct_types() -> None: + """Test adding dictionary of Arrays to ParametersRecord.""" + p_record = 
ParametersRecord() + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.set_parameters(array_dict) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # correct key, incorrect value + (str, lambda x: x.tolist()), # correct key, incorrect value + (int, ndarray_to_array), # incorrect key, correct value + (int, lambda x: x), # incorrect key, incorrect value + (int, lambda x: x.tolist()), # incorrect key, incorrect value + ], +) +def test_set_parameters_with_incorrect_types( + key_type: Type[Union[int, str]], + value_fn: Callable[[NDArray], Union[NDArray, List[float]]], +) -> None: + """Test adding dictionary of unsupported types to ParametersRecord.""" + p_record = ParametersRecord() + + array_dict = { + key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays()) + } + + with pytest.raises(TypeError): + p_record.set_parameters(array_dict) # type: ignore diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py new file mode 100644 index 000000000000..c1e724fa2758 --- /dev/null +++ b/src/py/flwr/common/recordset_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + + +from typing import OrderedDict + +from .parametersrecord import Array, ParametersRecord +from .typing import Parameters + + +def parametersrecord_to_parameters( + record: ParametersRecord, keep_input: bool = False +) -> Parameters: + """Convert ParametersRecord to legacy Parameters. + + Warning: Because `Arrays` in `ParametersRecord` encode more information about the + array-like or tensor-like data (e.g. their datatype and shape) than `Parameters`, + it might not be possible to reconstruct such data structures from `Parameters` + objects alone. Additional information or metadata must be provided from elsewhere. + + Parameters + ---------- + record : ParametersRecord + The record to be converted into Parameters. + keep_input : bool (default: False) + A boolean indicating whether entries should be kept in the record after being + added to the returned Parameters object. If False, the record is emptied as + its entries are converted. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.data.keys()): + parameters.tensors.append(record.data[key].data) + + if not keep_input: + del record.data[key] + + return parameters + + +def parameters_to_parametersrecord( + parameters: Parameters, keep_input: bool = False +) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + Because there is no concept of names in the legacy Parameters, arbitrary keys will + be used when constructing the ParametersRecord. Similarly, the shape and data type + won't be recorded in the Array objects. + + Parameters + ---------- + parameters : Parameters + Parameters object to be represented as a ParametersRecord. 
+ keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + Parameters object (i.e. a list of serialized NumPy arrays) immediately after + adding them to the record. + """ + tensor_type = parameters.tensor_type + + p_record = ParametersRecord() + + num_arrays = len(parameters.tensors) + for idx in range(num_arrays): + if keep_input: + tensor = parameters.tensors[idx] + else: + tensor = parameters.tensors.pop(0) + p_record.set_parameters( + OrderedDict( + {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + ) + ) + + return p_record diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index a60fff57e7bf..5441e766983a 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -156,6 +156,7 @@ class RetryInvoker: >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1) """ + # pylint: disable-next=too-many-arguments def __init__( self, wait_factory: Callable[[], Generator[float, None, None]], diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index c8c73e87e04a..59f5387b0a07 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -59,7 +59,9 @@ def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessa server_message.evaluate_ins, ) ) - raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ServerMessage, cannot serialize to ProtoBuf" + ) def server_message_from_proto( @@ -91,7 +93,7 @@ def server_message_from_proto( server_message_proto.evaluate_ins, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) @@ -125,7 +127,9 @@ def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessa client_message.evaluate_res, ) ) - raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ClientMessage, cannot serialize to ProtoBuf" + ) def client_message_from_proto( @@ -157,7 +161,7 @@ def client_message_from_proto( client_message_proto.evaluate_res, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf" ) @@ -474,7 +478,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar: if isinstance(scalar, str): return Scalar(string=scalar) - raise Exception( + raise ValueError( f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})" ) @@ -518,7 +522,7 @@ def _check_value(value: typing.Value) -> None: for element in value: if isinstance(element, data_type): continue - raise Exception( + raise TypeError( f"Inconsistent type: the types of elements in the list must " f"be the same (expected {data_type}, but got {type(element)})." ) diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index 2c3a6d2ccddf..82747e5afb2c 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
# ============================================================================== """Flower Driver app tests.""" -# pylint: disable=no-self-use import threading diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py index 92b4230a3932..8f75bbf78362 100644 --- a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/driver/driver_test.py @@ -139,6 +139,7 @@ def test_del_with_initialized_driver(self) -> None: self.driver._get_grpc_driver_and_run_id() # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert @@ -147,6 +148,7 @@ def test_del_with_initialized_driver(self) -> None: def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index b6d42fe799d5..627b95cdb1b4 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -89,7 +89,7 @@ def create_run(self, req: CreateRunRequest) -> CreateRunResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: CreateRunResponse = self.stub.CreateRun(request=req) @@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) @@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) @@ -122,7 +122,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py index 6ae38ea3d805..4e68499f018d 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py @@ -113,7 +113,7 @@ def _transition(self, next_status: Status) -> None: ): self._status = next_status else: - raise Exception(f"Invalid transition: {self._status} to {next_status}") + raise ValueError(f"Invalid transition: {self._status} to {next_status}") self._cv.notify_all() @@ -129,7 +129,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._raise_if_closed() if self._status != Status.AWAITING_INS_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._ins_wrapper = ins_wrapper # Write self._transition(Status.INS_WRAPPER_AVAILABLE) @@ -146,7 +146,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: 
self._transition(Status.AWAITING_INS_WRAPPER) if res_wrapper is None: - raise Exception("ResWrapper can not be None") + raise ValueError("ResWrapper can not be None") return res_wrapper @@ -170,7 +170,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]: self._transition(Status.AWAITING_RES_WRAPPER) if ins_wrapper is None: - raise Exception("InsWrapper can not be None") + raise ValueError("InsWrapper can not be None") yield ins_wrapper @@ -180,7 +180,7 @@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None: self._raise_if_closed() if self._status != Status.AWAITING_RES_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._res_wrapper = res_wrapper # Write self._transition(Status.RES_WRAPPER_AVAILABLE) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py index 18a2144072ed..bcfbe6e6fac8 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py @@ -70,6 +70,7 @@ def test_workflow_successful() -> None: _ = next(ins_wrapper_iterator) bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage())) except Exception as exception: + # pylint: disable-next=broad-exception-raised raise Exception from exception # Wait until worker_thread is finished diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py index 1c737d31c7fc..0fa6f82a89b5 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py +++ b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py @@ -166,6 +166,6 @@ def _call_client_proxy( evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res) return ClientMessage(evaluate_res=evaluate_res_proto) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 63ec1021ff5c..9b5c03aeeaf9 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -47,14 +47,14 @@ class SuccessClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise a error because this method is not expected to be called.""" + raise NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: """Simulate fit by returning a success FitRes with simple set of weights.""" @@ -87,26 +87,26 @@ class FailingClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def fit(self, ins: 
diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py
index 63ec1021ff5c..9b5c03aeeaf9 100644
--- a/src/py/flwr/server/server_test.py
+++ b/src/py/flwr/server/server_test.py
@@ -47,14 +47,14 @@ class SuccessClient(ClientProxy):
     def get_properties(
         self, ins: GetPropertiesIns, timeout: Optional[float]
     ) -> GetPropertiesRes:
-        """Raise an Exception because this method is not expected to be called."""
-        raise Exception()
+        """Raise an error because this method is not expected to be called."""
+        raise NotImplementedError()

     def get_parameters(
         self, ins: GetParametersIns, timeout: Optional[float]
     ) -> GetParametersRes:
-        """Raise an Exception because this method is not expected to be called."""
-        raise Exception()
+        """Raise an error because this method is not expected to be called."""
+        raise NotImplementedError()

     def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
         """Simulate fit by returning a success FitRes with simple set of weights."""
@@ -87,26 +87,26 @@ class FailingClient(ClientProxy):
     def get_properties(
         self, ins: GetPropertiesIns, timeout: Optional[float]
     ) -> GetPropertiesRes:
-        """Raise an Exception to simulate failure in the client."""
-        raise Exception()
+        """Raise a NotImplementedError to simulate failure in the client."""
+        raise NotImplementedError()

     def get_parameters(
         self, ins: GetParametersIns, timeout: Optional[float]
     ) -> GetParametersRes:
-        """Raise an Exception to simulate failure in the client."""
-        raise Exception()
+        """Raise a NotImplementedError to simulate failure in the client."""
+        raise NotImplementedError()

     def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes:
-        """Raise an Exception to simulate failure in the client."""
-        raise Exception()
+        """Raise a NotImplementedError to simulate failure in the client."""
+        raise NotImplementedError()

     def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes:
-        """Raise an Exception to simulate failure in the client."""
-        raise Exception()
+        """Raise a NotImplementedError to simulate failure in the client."""
+        raise NotImplementedError()

     def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes:
-        """Raise an Exception to simulate failure in the client."""
-        raise Exception()
+        """Raise a NotImplementedError to simulate failure in the client."""
+        raise NotImplementedError()


 def test_fit_clients() -> None:
diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py
index 4f66be3ff262..26f326819971 100644
--- a/src/py/flwr/server/state/sqlite_state.py
+++ b/src/py/flwr/server/state/sqlite_state.py
@@ -134,7 +134,7 @@ def query(
     ) -> List[Dict[str, Any]]:
         """Execute a SQL query."""
         if self.conn is None:
-            raise Exception("State is not initialized.")
+            raise AttributeError("State is not initialized.")

         if data is None:
             data = []
@@ -459,7 +459,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
         """
         if self.conn is None:
-            raise Exception("State not intitialized")
+            raise AttributeError("State not initialized")

         with self.conn:
             self.conn.execute(query_1, data)
diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py
index efdd288fc308..a3f899386011 100644
--- a/src/py/flwr/server/state/sqlite_state_test.py
+++ b/src/py/flwr/server/state/sqlite_state_test.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Test for utility functions."""
-# pylint: disable=no-self-use, invalid-name, disable=R0904
+# pylint: disable=invalid-name, disable=R0904

 import unittest
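In `server_test.py` above, the test doubles now raise `NotImplementedError` from methods that must never be called, which satisfies pylint's broad-exception checks and makes an unexpected call fail with a self-describing error. A stripped-down sketch of the pattern; the class and test names here are illustrative, not the actual test code:

```python
import unittest


class NeverCalledProxy:
    """Illustrative test double: every method is off-limits."""

    def get_properties(self) -> None:
        """Raise an error because this method is not expected to be called."""
        raise NotImplementedError()


class ProxyTest(unittest.TestCase):
    def test_unexpected_call_fails_loudly(self) -> None:
        # Asserting the specific type documents intent better than Exception
        with self.assertRaises(NotImplementedError):
            NeverCalledProxy().get_properties()


if __name__ == "__main__":
    unittest.main()
```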
diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py
index 88b4b53aed4c..204b4ba97b5f 100644
--- a/src/py/flwr/server/state/state_test.py
+++ b/src/py/flwr/server/state/state_test.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Tests all state implemenations have to conform to."""
-# pylint: disable=no-self-use, invalid-name, disable=R0904
+# pylint: disable=invalid-name, disable=R0904

 import tempfile
 import unittest
diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py
index 4eb76111b266..c668b55eebe6 100644
--- a/src/py/flwr/server/strategy/aggregate.py
+++ b/src/py/flwr/server/strategy/aggregate.py
@@ -27,7 +27,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
     """Compute weighted average."""
     # Calculate the total number of examples used during training
-    num_examples_total = sum([num_examples for _, num_examples in results])
+    num_examples_total = sum(num_examples for (_, num_examples) in results)

     # Create a list of weights, each multiplied by the related number of examples
     weighted_weights = [
@@ -45,7 +45,7 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
 def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
     """Compute in-place weighted average."""
     # Count total examples
-    num_examples_total = sum([fit_res.num_examples for _, fit_res in results])
+    num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)

     # Compute scaling factors for each result
     scaling_factors = [
@@ -95,9 +95,9 @@ def aggregate_krum(
     # For each client, take the n-f-2 closest parameters vectors
     num_closest = max(1, len(weights) - num_malicious - 2)
     closest_indices = []
-    for i, _ in enumerate(distance_matrix):
+    for distance in distance_matrix:
         closest_indices.append(
-            np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist()  # noqa: E203
+            np.argsort(distance)[1 : num_closest + 1].tolist()  # noqa: E203
         )

     # Compute the score for each client, that is the sum of the distances
@@ -202,7 +202,7 @@ def aggregate_bulyan(

 def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
     """Aggregate evaluation results obtained from multiple clients."""
-    num_total_evaluation_examples = sum([num_examples for num_examples, _ in results])
+    num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results)
     weighted_losses = [num_examples * loss for num_examples, loss in results]
     return sum(weighted_losses) / num_total_evaluation_examples

@@ -233,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray:
     """
     flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights])
     distance_matrix = np.zeros((len(weights), len(weights)))
-    for i, _ in enumerate(flat_w):
-        for j, _ in enumerate(flat_w):
-            delta = flat_w[i] - flat_w[j]
+    for i, flat_w_i in enumerate(flat_w):
+        for j, flat_w_j in enumerate(flat_w):
+            delta = flat_w_i - flat_w_j
             norm = np.linalg.norm(delta)
             distance_matrix[i, j] = norm**2
     return distance_matrix
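The `aggregate.py` hunks above swap list comprehensions inside `sum()` for generator expressions, which avoids materializing a throwaway list just to add up its elements; the weighted-average logic itself is unchanged. A self-contained sketch of both ideas on toy data (the arrays and example counts below are made up for illustration, not Flower's types):

```python
import numpy as np

# Toy stand-in for List[Tuple[NDArrays, int]]: two clients, one layer each
results = [
    ([np.array([1.0, 2.0])], 100),  # (layer weights, num_examples), client A
    ([np.array([3.0, 4.0])], 300),  # client B
]

# Generator expression: consumed lazily by sum(), no intermediate list built
num_examples_total = sum(num_examples for (_, num_examples) in results)

# Scale every layer by its client's example count, then normalize
weighted = [[layer * num for layer in layers] for layers, num in results]
aggregated = [np.sum(stack, axis=0) / num_examples_total for stack in zip(*weighted)]

print(aggregated)  # [array([2.5, 3.5])]: (1*100 + 3*300)/400 and (2*100 + 4*300)/400
```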
diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py
index 3269735e9d73..8b3278cc9ba0 100644
--- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py
+++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py
@@ -91,7 +91,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None:
         norm_bit_set_count = 0
         for client_proxy, fit_res in results:
             if "dpfedavg_norm_bit" not in fit_res.metrics:
-                raise Exception(
+                raise KeyError(
                     f"Indicator bit not returned by client with id {client_proxy.cid}."
                 )
             if fit_res.metrics["dpfedavg_norm_bit"]:
diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py
index 0154cfd79fc5..f2f1c206f3de 100644
--- a/src/py/flwr/server/strategy/dpfedavg_fixed.py
+++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py
@@ -46,11 +46,11 @@ def __init__(
         self.num_sampled_clients = num_sampled_clients

         if clip_norm <= 0:
-            raise Exception("The clipping threshold should be a positive value.")
+            raise ValueError("The clipping threshold should be a positive value.")
         self.clip_norm = clip_norm

         if noise_multiplier < 0:
-            raise Exception("The noise multiplier should be a non-negative value.")
+            raise ValueError("The noise multiplier should be a non-negative value.")
         self.noise_multiplier = noise_multiplier

         self.server_side_noising = server_side_noising
diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py
index e890f7216020..6678b7ced114 100644
--- a/src/py/flwr/server/strategy/fedavg_android.py
+++ b/src/py/flwr/server/strategy/fedavg_android.py
@@ -234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays:
         """Convert parameters object to NumPy weights."""
         return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors]

-    # pylint: disable=R0201
     def ndarray_to_bytes(self, ndarray: NDArray) -> bytes:
         """Serialize NumPy array to bytes."""
         return ndarray.tobytes()

-    # pylint: disable=R0201
     def bytes_to_ndarray(self, tensor: bytes) -> NDArray:
         """Deserialize NumPy array from bytes."""
         ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32)
diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py
index 7a5bf1425b44..17e979d92beb 100644
--- a/src/py/flwr/server/strategy/fedmedian.py
+++ b/src/py/flwr/server/strategy/fedmedian.py
@@ -36,7 +36,7 @@ class FedMedian(FedAvg):
-    """Configurable FedAvg with Momentum strategy implementation."""
+    """Configurable FedMedian strategy implementation."""

     def __repr__(self) -> str:
         """Compute a string representation of the strategy."""
diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py
index 94a67fbcbfae..758e8e608e9f 100644
--- a/src/py/flwr/server/strategy/qfedavg.py
+++ b/src/py/flwr/server/strategy/qfedavg.py
@@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float:
         hs_ffl = []

         if self.pre_weights is None:
-            raise Exception("QffedAvg pre_weights are None in aggregate_fit")
+            raise AttributeError("QffedAvg pre_weights are None in aggregate_fit")

         weights_before = self.pre_weights
         eval_result = self.evaluate(