From 03c9f797023258148739adbc58852e458b2ecd5e Mon Sep 17 00:00:00 2001
From: mohammadnaseri
Date: Tue, 28 May 2024 22:56:42 +0100
Subject: [PATCH 1/7] fix(examples) Fix TensorFlow-Privacy example (#3488)

Co-authored-by: Taner Topal
Co-authored-by: jafermarq
---
 README.md                                  |   1 +
 examples/dp-sgd-mnist/README.md            | 105 -------------
 examples/dp-sgd-mnist/client.py            | 162 ---------------------
 examples/dp-sgd-mnist/common.py            | 103 -------------
 examples/dp-sgd-mnist/pyproject.toml       |  20 ---
 examples/dp-sgd-mnist/requirements.txt     |   4 -
 examples/dp-sgd-mnist/server.py            |  56 -------
 examples/tensorflow-privacy/README.md      |  60 ++++++++
 examples/tensorflow-privacy/client.py      | 139 ++++++++++++++++++
 examples/tensorflow-privacy/pyproject.toml |  22 +++
 examples/tensorflow-privacy/server.py      |  22 +++
 11 files changed, 244 insertions(+), 450 deletions(-)
 delete mode 100644 examples/dp-sgd-mnist/README.md
 delete mode 100644 examples/dp-sgd-mnist/client.py
 delete mode 100644 examples/dp-sgd-mnist/common.py
 delete mode 100644 examples/dp-sgd-mnist/pyproject.toml
 delete mode 100644 examples/dp-sgd-mnist/requirements.txt
 delete mode 100644 examples/dp-sgd-mnist/server.py
 create mode 100644 examples/tensorflow-privacy/README.md
 create mode 100644 examples/tensorflow-privacy/client.py
 create mode 100644 examples/tensorflow-privacy/pyproject.toml
 create mode 100644 examples/tensorflow-privacy/server.py

diff --git a/README.md b/README.md
index dabeb0a9eba9..a010abfcb2f5 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
 - [Flower through Docker Compose and with Grafana dashboard](https://github.com/adap/flower/tree/main/examples/flower-via-docker-compose)
 - [Flower with KaplanMeierFitter from the lifelines library](https://github.com/adap/flower/tree/main/examples/federated-kaplan-meier-fitter)
 - [Sample Level Privacy with Opacus](https://github.com/adap/flower/tree/main/examples/opacus)
+- [Sample Level Privacy with TensorFlow-Privacy](https://github.com/adap/flower/tree/main/examples/tensorflow-privacy)
 
 ## Community
 
diff --git a/examples/dp-sgd-mnist/README.md b/examples/dp-sgd-mnist/README.md
deleted file mode 100644
index fcf602306c90..000000000000
--- a/examples/dp-sgd-mnist/README.md
+++ /dev/null
@@ -1,105 +0,0 @@
-# Flower Example Using Tensorflow/Keras and Tensorflow Privacy
-
-This example of Flower trains a federated learning system where clients are free to choose
-between non-private and private optimizers. Specifically, clients can choose to train Keras models using the standard SGD optimizer or __Differentially Private__ SGD (DPSGD) from [Tensorflow Privacy](https://github.com/tensorflow/privacy). For this task we use the MNIST dataset which is split artificially among clients. This causes the dataset to be i.i.d. The clients using DPSGD track the amount of privacy spent and display it at the end of the training.
-
-This example is adapted from https://github.com/tensorflow/privacy/blob/master/tutorials/mnist_dpsgd_tutorial_keras.py
-
-## Project Setup
-
-Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
-
-```shell
-git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/dp-sgd-mnist . 
&& rm -rf flower && cd dp-sgd-mnist -``` - -This will create a new directory called `dp-sgd-mnist` containing the following files: - -```shell --- pyproject.toml --- requirements.txt --- client.py --- server.py --- common.py --- README.md -``` - -### Installing Dependencies - -Project dependencies (such as `tensorflow` and `tensorflow-privacy`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. - -#### Poetry - -```shell -poetry install -poetry shell -``` - -Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: - -```shell -poetry run python3 -c "import flwr" -``` - -If you don't see any errors you're good to go! - -#### pip - -Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. - -```shell -pip install -r requirements.txt -``` - -## Run Federated Learning with TensorFlow/Keras/Tensorflow-Privacy and Flower - -Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: - -```shell -# terminal 1 -poetry run python3 server.py -``` - -Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: - -```shell -# terminal 2 -poetry run python3 client.py --partition 0 -``` - -```shell -# terminal 3 -# We will set the second client to use `dpsgd` -poetry run python3 client.py --partition 1 --dpsgd True -``` - -Alternatively you can run all of it in one shell as follows: - -```shell -poetry run python3 server.py & -poetry run python3 client.py --partition 0 & -poetry run python3 client.py --partition 1 --dpsgd True -``` - -It should be noted that when starting more than 2 clients, the total number of clients you intend to run and the data partition the client is expected to use must be specified. This is because the `num_clients` is used to split the dataset. - -For example, in case of 3 clients - -```shell -poetry run python3 server.py --num-clients 3 & -poetry run python3 client.py --num-clients 3 --partition 0 --dpsgd True & -poetry run python3 client.py --num-clients 3 --partition 1 & -poetry run python3 client.py --num-clients 3 --partition 2 --dpsgd True -``` - -Additional training parameters for the client and server can be referenced by passing `--help` to either script. - -Other things to note is that when all clients are running `dpsgd`, either train for more rounds or increase the local epochs to achieve optimal performance. You shall need to carefully tune the hyperparameters to your specific setup. 
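A rough way to see how these choices move the privacy budget is to call the `compute_epsilon` helper defined in `common.py` (shown further down in this patch). The sketch below is illustrative only; the partition size (30,000 examples, half of MNIST with two clients), batch size, and noise multiplier are assumed values matching the script defaults, not measured output:

```python
# Illustrative sketch: estimate the epsilon spent by one DP-SGD client
# per round, using the compute_epsilon helper from this example's common.py.
# Assumes the example's dependencies (tensorflow-privacy) are installed.
from common import compute_epsilon

for epochs in (1, 2, 4):
    eps = compute_epsilon(
        epochs=epochs,
        num_train_examples=30_000,
        batch_size=32,
        noise_multiplier=1.1,
    )
    print(f"local epochs={epochs}: epsilon spent per round = {eps:.2f}")
```

More local epochs (or more rounds) means more privacy spent, which is why the accountant is re-run on every `fit` call in `client.py`. (The commands that follow show the corresponding `--num-rounds` and `--local-epochs` flags.)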
- -```shell -poetry run python3 server.py --num-clients 3 --num-rounds 20 -``` - -```shell -poetry run python3 client.py --num-clients 3 --partition 1 --local-epochs 4 --dpsgd True -``` diff --git a/examples/dp-sgd-mnist/client.py b/examples/dp-sgd-mnist/client.py deleted file mode 100644 index cffe0e241645..000000000000 --- a/examples/dp-sgd-mnist/client.py +++ /dev/null @@ -1,162 +0,0 @@ -import argparse -import os - -import tensorflow as tf -from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import ( - VectorizedDPKerasSGDOptimizer, -) - -import flwr as fl - -import common - - -# Make TensorFlow logs less verbose -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - -# global for tracking privacy -PRIVACY_LOSS = 0 - - -# Define Flower client -class MnistClient(fl.client.NumPyClient): - def __init__(self, model, x_train, y_train, x_test, y_test, args): - self.model = model - self.x_train, self.y_train = x_train, y_train - self.x_test, self.y_test = x_test, y_test - self.batch_size = args.batch_size - self.local_epochs = args.local_epochs - self.dpsgd = args.dpsgd - - if args.dpsgd: - self.noise_multiplier = args.noise_multiplier - if args.batch_size % args.microbatches != 0: - raise ValueError( - "Number of microbatches should divide evenly batch_size" - ) - optimizer = VectorizedDPKerasSGDOptimizer( - l2_norm_clip=args.l2_norm_clip, - noise_multiplier=args.noise_multiplier, - num_microbatches=args.microbatches, - learning_rate=args.learning_rate, - ) - # Compute vector of per-example loss rather than its mean over a minibatch. - loss = tf.keras.losses.CategoricalCrossentropy( - from_logits=True, reduction=tf.losses.Reduction.NONE - ) - else: - optimizer = tf.keras.optimizers.SGD(learning_rate=args.learning_rate) - loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) - - # Compile model with Keras - model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"]) - - def get_parameters(self, config): - """Get parameters of the local model.""" - raise Exception("Not implemented (server-side parameter initialization)") - - def fit(self, parameters, config): - """Train parameters on the locally held training set.""" - # Update local model parameters - global PRIVACY_LOSS - if self.dpsgd: - privacy_spent = common.compute_epsilon( - self.local_epochs, - len(self.x_train), - self.batch_size, - self.noise_multiplier, - ) - PRIVACY_LOSS += privacy_spent - - self.model.set_weights(parameters) - # Train the model - self.model.fit( - self.x_train, - self.y_train, - epochs=self.local_epochs, - batch_size=self.batch_size, - ) - - return self.model.get_weights(), len(self.x_train), {} - - def evaluate(self, parameters, config): - """Evaluate parameters on the locally held test set.""" - - # Update local model with global parameters - self.model.set_weights(parameters) - - # Evaluate global model parameters on the local test data and return results - loss, accuracy = self.model.evaluate(self.x_test, self.y_test) - num_examples_test = len(self.x_test) - return loss, num_examples_test, {"accuracy": accuracy} - - -def main(args) -> None: - # Load Keras model - model = common.create_cnn_model() - - # Load a subset of MNIST to simulate the local data partition - (x_train, y_train), (x_test, y_test) = common.load(args.num_clients)[args.partition] - - # drop samples to form exact batches for dpsgd - # this is necessary since dpsgd is sensitive to uneven batches - # due to microbatching - if args.dpsgd and x_train.shape[0] % args.batch_size != 0: - drop_num = x_train.shape[0] % 
args.batch_size - x_train = x_train[:-drop_num] - y_train = y_train[:-drop_num] - - # Start Flower client - client = MnistClient(model, x_train, y_train, x_test, y_test, args) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) - if args.dpsgd: - print("Privacy Loss: ", PRIVACY_LOSS) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Flower Client") - parser.add_argument( - "--num-clients", - default=2, - type=int, - help="Total number of fl participants, requied to get correct partition", - ) - parser.add_argument( - "--partition", - type=int, - required=True, - help="Data Partion to train on. Must be less than number of clients", - ) - parser.add_argument( - "--local-epochs", - default=1, - type=int, - help="Total number of local epochs to train", - ) - parser.add_argument("--batch-size", default=32, type=int, help="Batch size") - parser.add_argument( - "--learning-rate", default=0.15, type=float, help="Learning rate for training" - ) - # DPSGD specific arguments - parser.add_argument( - "--dpsgd", - default=False, - type=bool, - help="If True, train with DP-SGD. If False, " "train with vanilla SGD.", - ) - parser.add_argument("--l2-norm-clip", default=1.0, type=float, help="Clipping norm") - parser.add_argument( - "--noise-multiplier", - default=1.1, - type=float, - help="Ratio of the standard deviation to the clipping norm", - ) - parser.add_argument( - "--microbatches", - default=32, - type=int, - help="Number of microbatches " "(must evenly divide batch_size)", - ) - args = parser.parse_args() - - main(args) diff --git a/examples/dp-sgd-mnist/common.py b/examples/dp-sgd-mnist/common.py deleted file mode 100644 index fbb2f6374203..000000000000 --- a/examples/dp-sgd-mnist/common.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import List, Tuple - -import numpy as np -import tensorflow as tf - -from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp -from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent - - -XY = Tuple[np.ndarray, np.ndarray] -XYList = List[XY] -PartitionedDataset = List[Tuple[XY, XY]] - - -def compute_epsilon( - epochs: int, num_train_examples: int, batch_size: int, noise_multiplier: float -) -> float: - """Computes epsilon value for given hyperparameters. - - Based on - github.com/tensorflow/privacy/blob/master/tutorials/mnist_dpsgd_tutorial_keras.py - """ - if noise_multiplier == 0.0: - return float("inf") - steps = epochs * num_train_examples // batch_size - orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)) - sampling_probability = batch_size / num_train_examples - rdp = compute_rdp( - q=sampling_probability, - noise_multiplier=noise_multiplier, - steps=steps, - orders=orders, - ) - # Delta is set to approximate 1 / (number of training points). 
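    # Illustrative numbers (assumed, not from the original file): with a
    # 30,000-example partition, batch_size=32 and one epoch, steps = 937 and
    # sampling_probability is about 0.00107. get_privacy_spent converts the
    # accumulated RDP values into a tuple whose first element is the epsilon
    # value, and only that epsilon is returned below.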
- return get_privacy_spent(orders, rdp, target_delta=1 / num_train_examples)[0] - - -def create_cnn_model() -> tf.keras.Model: - """Returns a sequential keras CNN Model.""" - return tf.keras.Sequential( - [ - tf.keras.layers.Conv2D( - 16, - 8, - strides=2, - padding="same", - activation="relu", - input_shape=(28, 28, 1), - ), - tf.keras.layers.MaxPool2D(2, 1), - tf.keras.layers.Conv2D( - 32, 4, strides=2, padding="valid", activation="relu" - ), - tf.keras.layers.MaxPool2D(2, 1), - tf.keras.layers.Flatten(), - tf.keras.layers.Dense(32, activation="relu"), - tf.keras.layers.Dense(10), - ] - ) - - -def shuffle(X: np.ndarray, y: np.ndarray) -> XY: - """Shuffle X and y.""" - rng = np.random.default_rng() - idx = rng.permutation(len(X)) - return X[idx], y[idx] - - -def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList: - """Split X and y into a number of partitions.""" - return list( - zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions)) - ) - - -def preprocess(X: np.ndarray, y: np.ndarray) -> XY: - """Basic preprocessing for MNIST dataset.""" - X = np.array(X, dtype=np.float32) / 255 - X = X.reshape((X.shape[0], 28, 28, 1)) - - y = np.array(y, dtype=np.int32) - y = tf.keras.utils.to_categorical(y, num_classes=10) - - return X, y - - -def create_partitions(source_dataset: XY, num_partitions: int) -> XYList: - """Create partitioned version of a source dataset.""" - X, y = source_dataset - X, y = shuffle(X, y) - X, y = preprocess(X, y) - xy_partitions = partition(X, y, num_partitions) - return xy_partitions - - -def load( - num_partitions: int, -) -> PartitionedDataset: - """Create partitioned version of MNIST.""" - xy_train, xy_test = tf.keras.datasets.mnist.load_data() - xy_train_partitions = create_partitions(xy_train, num_partitions) - xy_test_partitions = create_partitions(xy_test, num_partitions) - return list(zip(xy_train_partitions, xy_test_partitions)) diff --git a/examples/dp-sgd-mnist/pyproject.toml b/examples/dp-sgd-mnist/pyproject.toml deleted file mode 100644 index 161952fd2aa4..000000000000 --- a/examples/dp-sgd-mnist/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" - -[tool.poetry] -name = "dp-sgd-mnist" -version = "0.1.0" -description = "Federated training with Tensorflow Privacy" -authors = [ - "The Flower Authors ", - "Kaushik Amar Das ", -] - -[tool.poetry.dependencies] -python = ">=3.8,<3.11" -# flwr = { path = "../../", develop = true } # Development -flwr = ">=1.0,<2.0" -tensorflow-cpu = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "platform_machine == \"x86_64\"" } -tensorflow-macos = { version = ">=2.9.1,<2.11.1 || >2.11.1", markers = "sys_platform == \"darwin\" and platform_machine == \"arm64\"" } -tensorflow-privacy = "0.8.10" diff --git a/examples/dp-sgd-mnist/requirements.txt b/examples/dp-sgd-mnist/requirements.txt deleted file mode 100644 index bd5478de342b..000000000000 --- a/examples/dp-sgd-mnist/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -flwr>=1.0, <2.0 -tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == "darwin" and platform_machine == "arm64" -tensorflow-cpu>=2.9.1, != 2.11.1 ; platform_machine == "x86_64" -tensorflow-privacy==0.8.10 diff --git a/examples/dp-sgd-mnist/server.py b/examples/dp-sgd-mnist/server.py deleted file mode 100644 index 5f6f7163834e..000000000000 --- a/examples/dp-sgd-mnist/server.py +++ /dev/null @@ -1,56 +0,0 @@ -import argparse -import os - -import tensorflow as tf - -import flwr as fl - 
-import common
-
-# Make TensorFlow logs less verbose
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
-
-def get_evaluate_fn(model):
-    """Return an evaluation function for server-side evaluation."""
-
-    # Load test data here to avoid the overhead of doing it in `evaluate` itself
-    _, test = tf.keras.datasets.mnist.load_data()
-    test_data, test_labels = test
-
-    # preprocessing
-    test_data, test_labels = common.preprocess(test_data, test_labels)
-
-    # The `evaluate` function will be called after every round
-    def evaluate(weights: fl.common.NDArrays):
-        model.set_weights(weights)  # Update model with the latest parameters
-        loss, accuracy = model.evaluate(test_data, test_labels)
-        return loss, {"accuracy": accuracy}
-
-    return evaluate
-
-
-def main(args) -> None:
-    model = common.create_cnn_model()
-    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
-    model.compile("sgd", loss=loss, metrics=["accuracy"])
-    strategy = fl.server.strategy.FedAvg(
-        fraction_fit=args.fraction_fit,
-        min_available_clients=args.num_clients,
-        evaluate_fn=get_evaluate_fn(model),
-        initial_parameters=fl.common.ndarrays_to_parameters(model.get_weights()),
-    )
-    fl.server.start_server(
-        server_address="0.0.0.0:8080",
-        strategy=strategy,
-        config=fl.server.ServerConfig(num_rounds=args.num_rounds),
-    )
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Server Script")
-    parser.add_argument("--num-clients", default=2, type=int)
-    parser.add_argument("--num-rounds", default=1, type=int)
-    parser.add_argument("--fraction-fit", default=1.0, type=float)
-    args = parser.parse_args()
-    main(args)
diff --git a/examples/tensorflow-privacy/README.md b/examples/tensorflow-privacy/README.md
new file mode 100644
index 000000000000..a1f1be00f6b0
--- /dev/null
+++ b/examples/tensorflow-privacy/README.md
@@ -0,0 +1,60 @@
+# Training with Sample-Level Differential Privacy using TensorFlow-Privacy
+
+In this example, we demonstrate how to train a model with sample-level differential privacy (DP) using Flower. We employ TensorFlow and integrate the tensorflow-privacy library to achieve sample-level differential privacy. This setup ensures robust privacy guarantees during the client training phase.
+
+For more information about DP in Flower, please refer to the [tutorial](https://flower.ai/docs/framework/how-to-use-differential-privacy.html). For additional information about tensorflow-privacy, visit the official [website](https://www.tensorflow.org/responsible_ai/privacy/guide).
+
+## Environment Setup
+
+Start by cloning the example. We prepared a single-line command that you can copy into your shell which will checkout the example for you:
+
+```shell
+git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/tensorflow-privacy . && rm -rf flower && cd tensorflow-privacy
+```
+
+This will create a new directory called `tensorflow-privacy` containing the following files:
+
+```shell
+-- pyproject.toml
+-- client.py
+-- server.py
+-- README.md
+```
+
+### Installing dependencies
+
+Project dependencies are defined in `pyproject.toml`. Install them with:
+
+```shell
+pip install .
+```
+
+## Run Flower with tensorflow-privacy and TensorFlow
+
+### 1. Start the long-running Flower server (SuperLink)
+
+```bash
+flower-superlink --insecure
+```
+
+### 2. Start the long-running Flower clients (SuperNodes)
+
+Start 2 Flower `SuperNodes` in 2 separate terminal windows, using:
+
+```bash
+flower-client-app client:appA --insecure
+```
+
+```bash
+flower-client-app client:appB --insecure
+```
+
+tensorflow-privacy hyperparameters can be passed for each client in `ClientApp` instantiation (in `client.py`). In this example, `noise_multiplier=1.5` and `noise_multiplier=1` are used for the first and second client respectively.
+
+### 3. Run the Flower App
+
+With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App:
+
+```bash
+flower-server-app server:app --insecure
+```
diff --git a/examples/tensorflow-privacy/client.py b/examples/tensorflow-privacy/client.py
new file mode 100644
index 000000000000..1032d12e69a2
--- /dev/null
+++ b/examples/tensorflow-privacy/client.py
@@ -0,0 +1,139 @@
+import argparse
+import os
+from flwr.client import ClientApp, NumPyClient
+import tensorflow as tf
+from flwr_datasets import FederatedDataset
+import tensorflow_privacy
+
+from tensorflow_privacy.privacy.analysis.compute_dp_sgd_privacy_lib import (
+    compute_dp_sgd_privacy_statement,
+)
+
+# Make TensorFlow log less verbose
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+
+def load_data(partition_id, batch_size):
+    fds = FederatedDataset(dataset="cifar10", partitioners={"train": 2})
+    partition = fds.load_partition(partition_id, "train")
+    partition.set_format("numpy")
+
+    # Divide data on each node: 80% train, 20% test
+    partition = partition.train_test_split(test_size=0.2, seed=42)
+    x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"]
+    x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"]
+
+    # Adjust the size of the training dataset to make it evenly divisible by the batch size
+    remainder = len(x_train) % batch_size
+    if remainder != 0:
+        x_train = x_train[:-remainder]
+        y_train = y_train[:-remainder]
+
+    return (x_train, y_train), (x_test, y_test)
+
+
+class FlowerClient(NumPyClient):
+    def __init__(
+        self,
+        model,
+        train_data,
+        test_data,
+        l2_norm_clip,
+        noise_multiplier,
+        num_microbatches,
+        learning_rate,
+        batch_size,
+    ) -> None:
+        super().__init__()
+        self.model = model
+        self.x_train, self.y_train = train_data
+        self.x_test, self.y_test = test_data
+        self.noise_multiplier = noise_multiplier
+        self.l2_norm_clip = l2_norm_clip
+        self.num_microbatches = num_microbatches
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+        if self.batch_size % self.num_microbatches != 0:
+            raise ValueError(
+                f"Batch size {self.batch_size} is not divisible by the number of microbatches {self.num_microbatches}"
+            )
+
+        self.optimizer = tensorflow_privacy.DPKerasSGDOptimizer(
+            l2_norm_clip=l2_norm_clip,
+            noise_multiplier=noise_multiplier,
+            num_microbatches=num_microbatches,
+            learning_rate=learning_rate,
+        )
+        loss = tf.keras.losses.SparseCategoricalCrossentropy(
+            reduction=tf.losses.Reduction.NONE
+        )
+        self.model.compile(optimizer=self.optimizer, loss=loss, metrics=["accuracy"])
+
+    def get_parameters(self, config):
+        return self.model.get_weights()
+
+    def fit(self, parameters, config):
+        self.model.set_weights(parameters)
+
+        self.model.fit(
+            self.x_train,
+            self.y_train,
+            epochs=1,
+            batch_size=self.batch_size,
+        )
+
+        compute_dp_sgd_privacy_statement(
+            number_of_examples=self.x_train.shape[0],
+            batch_size=self.batch_size,
+            num_epochs=1,
+            noise_multiplier=self.noise_multiplier,
+            delta=1e-5,
+        )
+
+        return 
self.model.get_weights(), len(self.x_train), {} + + def evaluate(self, parameters, config): + self.model.set_weights(parameters) + self.model.compile( + optimizer=self.optimizer, + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + loss, accuracy = self.model.evaluate(self.x_test, self.y_test) + return loss, len(self.x_test), {"accuracy": accuracy} + + +def client_fn_parameterized( + partition_id, + noise_multiplier, + l2_norm_clip=1.0, + num_microbatches=64, + learning_rate=0.01, + batch_size=64, +): + def client_fn(cid: str): + model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) + train_data, test_data = load_data( + partition_id=partition_id, batch_size=batch_size + ) + return FlowerClient( + model, + train_data, + test_data, + noise_multiplier, + l2_norm_clip, + num_microbatches, + learning_rate, + batch_size, + ).to_client() + + return client_fn + + +appA = ClientApp( + client_fn=client_fn_parameterized(partition_id=0, noise_multiplier=1.0), +) + +appB = ClientApp( + client_fn=client_fn_parameterized(partition_id=1, noise_multiplier=1.5), +) diff --git a/examples/tensorflow-privacy/pyproject.toml b/examples/tensorflow-privacy/pyproject.toml new file mode 100644 index 000000000000..884ba3b5f07b --- /dev/null +++ b/examples/tensorflow-privacy/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "tensorflow-privacy-fl" +version = "0.1.0" +description = "Sample-level Differential Privacy with Tensorflow-Privacy in Flower" +authors = [ + { name = "The Flower Authors", email = "hello@flower.ai" }, +] +dependencies = [ + "flwr>=1.8.0,<2.0", + "flwr-datasets[vision]>=0.1.0,<1.0.0", + "tensorflow-estimator~=2.4", + "tensorflow-probability~=0.22.0", + "tensorflow>=2.4.0,<=2.15.0", + "tensorflow-privacy == 0.9.0" +] + +[tool.hatch.build.targets.wheel] +packages = ["."] diff --git a/examples/tensorflow-privacy/server.py b/examples/tensorflow-privacy/server.py new file mode 100644 index 000000000000..1e399fa7e833 --- /dev/null +++ b/examples/tensorflow-privacy/server.py @@ -0,0 +1,22 @@ +from typing import List, Tuple + +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from flwr.common import Metrics + + +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] + examples = [num_examples for num_examples, _ in metrics] + + return {"accuracy": sum(accuracies) / sum(examples)} + + +strategy = FedAvg(evaluate_metrics_aggregation_fn=weighted_average) + +config = ServerConfig(num_rounds=3) + +app = ServerApp( + config=config, + strategy=strategy, +) From a443f86ee7a59e4dbdb9c76200cc924d19edbc29 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Wed, 29 May 2024 12:47:07 +0200 Subject: [PATCH 2/7] break(datasets) Rename resplitter parameter and type to preprocessor (#3476) --- datasets/flwr_datasets/__init__.py | 4 +- datasets/flwr_datasets/federated_dataset.py | 22 +++--- .../flwr_datasets/federated_dataset_test.py | 14 ++-- .../{resplitter => preprocessor}/__init__.py | 14 ++-- .../divider.py} | 12 +-- .../divider_test.py} | 74 +++++++------------ .../merger.py} | 8 +- .../merger_test.py} | 46 ++++++------ .../preprocessor.py} | 4 +- datasets/flwr_datasets/utils.py | 18 ++--- 10 files changed, 100 insertions(+), 116 deletions(-) rename datasets/flwr_datasets/{resplitter => preprocessor}/__init__.py 
(76%) rename datasets/flwr_datasets/{resplitter/divide_resplitter.py => preprocessor/divider.py} (98%) rename datasets/flwr_datasets/{resplitter/divide_resplitter_test.py => preprocessor/divider_test.py} (79%) rename datasets/flwr_datasets/{resplitter/merge_resplitter.py => preprocessor/merger.py} (96%) rename datasets/flwr_datasets/{resplitter/merge_resplitter_test.py => preprocessor/merger_test.py} (81%) rename datasets/flwr_datasets/{resplitter/resplitter.py => preprocessor/preprocessor.py} (91%) diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index 0b9a6685427b..2d6ecb414498 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -15,7 +15,7 @@ """Flower Datasets main package.""" -from flwr_datasets import partitioner, resplitter +from flwr_datasets import partitioner, preprocessor from flwr_datasets import utils as utils from flwr_datasets.common.version import package_version as _package_version from flwr_datasets.federated_dataset import FederatedDataset @@ -23,7 +23,7 @@ __all__ = [ "FederatedDataset", "partitioner", - "resplitter", + "preprocessor", "utils", ] diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 6c41eaa3562f..5d98d01d4941 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -20,11 +20,11 @@ import datasets from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import Partitioner -from flwr_datasets.resplitter import Resplitter +from flwr_datasets.preprocessor import Preprocessor from flwr_datasets.utils import ( _check_if_dataset_tested, + _instantiate_merger_if_needed, _instantiate_partitioners, - _instantiate_resplitter_if_needed, ) @@ -45,9 +45,11 @@ class FederatedDataset: subset : str Secondary information regarding the dataset, most often subset or version (that is passed to the name in datasets.load_dataset). - resplitter : Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] - `Callable` that transforms `DatasetDict` splits, or configuration dict for - `MergeResplitter`. + preprocessor : Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] + `Callable` that transforms `DatasetDict` by resplitting, removing + features, creating new features, performing any other preprocessing operation, + or configuration dict for `Merger`. Applied after shuffling. If None, + no operation is applied. partitioners : Dict[str, Union[Partitioner, int]] A dictionary mapping the Dataset split (a `str`) to a `Partitioner` or an `int` (representing the number of IID partitions that this split should be partitioned @@ -79,7 +81,7 @@ def __init__( *, dataset: str, subset: Optional[str] = None, - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, + preprocessor: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], shuffle: bool = True, seed: Optional[int] = 42, @@ -87,8 +89,8 @@ def __init__( _check_if_dataset_tested(dataset) self._dataset_name: str = dataset self._subset: Optional[str] = subset - self._resplitter: Optional[Resplitter] = _instantiate_resplitter_if_needed( - resplitter + self._preprocessor: Optional[Preprocessor] = _instantiate_merger_if_needed( + preprocessor ) self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners @@ -242,8 +244,8 @@ def _prepare_dataset(self) -> None: # Note it shuffles all the splits. 
The self._dataset is DatasetDict # so e.g. {"train": train_data, "test": test_data}. All splits get shuffled. self._dataset = self._dataset.shuffle(seed=self._seed) - if self._resplitter: - self._dataset = self._resplitter(self._dataset) + if self._preprocessor: + self._dataset = self._preprocessor(self._dataset) self._dataset_prepared = True def _check_if_no_split_keyword_possible(self) -> None: diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 5d5179122e3b..f65aa6346f3a 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -170,20 +170,20 @@ def test_resplit_dataset_into_one(self) -> None: fds = FederatedDataset( dataset=self.dataset_name, partitioners={"train": 100}, - resplitter={"full": ("train", self.test_split)}, + preprocessor={"full": ("train", self.test_split)}, ) full = fds.load_split("full") self.assertEqual(dataset_length, len(full)) # pylint: disable=protected-access def test_resplit_dataset_to_change_names(self) -> None: - """Test resplitter to change the names of the partitions.""" + """Test preprocessor to change the names of the partitions.""" if self.test_split is None: return fds = FederatedDataset( dataset=self.dataset_name, partitioners={"new_train": 100}, - resplitter={ + preprocessor={ "new_train": ("train",), "new_" + self.test_split: (self.test_split,), }, @@ -195,7 +195,7 @@ def test_resplit_dataset_to_change_names(self) -> None: ) def test_resplit_dataset_by_callable(self) -> None: - """Test resplitter to change the names of the partitions.""" + """Test preprocessor to change the names of the partitions.""" if self.test_split is None: return @@ -209,7 +209,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict: ) fds = FederatedDataset( - dataset=self.dataset_name, partitioners={"train": 100}, resplitter=resplit + dataset=self.dataset_name, partitioners={"train": 100}, preprocessor=resplit ) full = fds.load_split("full") dataset = datasets.load_dataset(self.dataset_name) @@ -298,7 +298,7 @@ def resplit(dataset: DatasetDict) -> DatasetDict: fds = FederatedDataset( dataset="does-not-matter", partitioners={"train": 10}, - resplitter=resplit, + preprocessor=resplit, shuffle=True, ) train = fds.load_split("train") @@ -411,7 +411,7 @@ def test_cannot_use_the_old_split_names(self) -> None: fds = FederatedDataset( dataset="mnist", partitioners={"train": 100}, - resplitter={"full": ("train", "test")}, + preprocessor={"full": ("train", "test")}, ) with self.assertRaises(ValueError): fds.load_partition(0, "train") diff --git a/datasets/flwr_datasets/resplitter/__init__.py b/datasets/flwr_datasets/preprocessor/__init__.py similarity index 76% rename from datasets/flwr_datasets/resplitter/__init__.py rename to datasets/flwr_datasets/preprocessor/__init__.py index bf39786e0593..bab5d82a2035 100644 --- a/datasets/flwr_datasets/resplitter/__init__.py +++ b/datasets/flwr_datasets/preprocessor/__init__.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Resplitter package.""" +"""Preprocessor package.""" -from .divide_resplitter import DivideResplitter -from .merge_resplitter import MergeResplitter -from .resplitter import Resplitter +from .divider import Divider +from .merger import Merger +from .preprocessor import Preprocessor __all__ = [ - "DivideResplitter", - "MergeResplitter", - "Resplitter", + "Merger", + "Preprocessor", + "Divider", ] diff --git a/datasets/flwr_datasets/resplitter/divide_resplitter.py b/datasets/flwr_datasets/preprocessor/divider.py similarity index 98% rename from datasets/flwr_datasets/resplitter/divide_resplitter.py rename to datasets/flwr_datasets/preprocessor/divider.py index 56150b51af85..9d7570de4cea 100644 --- a/datasets/flwr_datasets/resplitter/divide_resplitter.py +++ b/datasets/flwr_datasets/preprocessor/divider.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DivideResplitter class for Flower Datasets.""" +"""Divider class for Flower Datasets.""" import collections @@ -25,7 +25,7 @@ # flake8: noqa: E501 # pylint: disable=line-too-long -class DivideResplitter: +class Divider: """Dive existing split(s) of the dataset and assign them custom names. Create new `DatasetDict` with new split names with corresponding percentages of data @@ -66,14 +66,14 @@ class DivideResplitter: >>> # Assuming there is a dataset_dict of type `DatasetDict` >>> # dataset_dict is {"train": train-data, "test": test-data} - >>> resplitter = DivideResplitter( + >>> divider = Divider( >>> divide_config={ >>> "train": 0.8, >>> "valid": 0.2, >>> } >>> divide_split="train", >>> ) - >>> new_dataset_dict = resplitter(dataset_dict) + >>> new_dataset_dict = divider(dataset_dict) >>> # new_dataset_dict is >>> # {"train": 80% of train, "valid": 20% of train, "test": test-data} @@ -83,7 +83,7 @@ class DivideResplitter: >>> # Assuming there is a dataset_dict of type `DatasetDict` >>> # dataset_dict is {"train": train-data, "test": test-data} - >>> resplitter = DivideResplitter( + >>> divider = Divider( >>> divide_config={ >>> "train": { >>> "train": 0.8, @@ -92,7 +92,7 @@ class DivideResplitter: >>> "test": {"test-a": 0.4, "test-b": 0.6 } >>> } >>> ) - >>> new_dataset_dict = resplitter(dataset_dict) + >>> new_dataset_dict = divider(dataset_dict) >>> # new_dataset_dict is >>> # {"train": 80% of train, "valid": 20% of train, >>> # "test-a": 40% of test, "test-b": 60% of test} diff --git a/datasets/flwr_datasets/resplitter/divide_resplitter_test.py b/datasets/flwr_datasets/preprocessor/divider_test.py similarity index 79% rename from datasets/flwr_datasets/resplitter/divide_resplitter_test.py rename to datasets/flwr_datasets/preprocessor/divider_test.py index 143297fcc1a7..ed282fbc18be 100644 --- a/datasets/flwr_datasets/resplitter/divide_resplitter_test.py +++ b/datasets/flwr_datasets/preprocessor/divider_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""DivideResplitter tests.""" +"""Divider tests.""" import unittest from typing import Dict, Union @@ -20,7 +20,7 @@ from parameterized import parameterized_class from datasets import Dataset, DatasetDict -from flwr_datasets.resplitter import DivideResplitter +from flwr_datasets.preprocessor import Divider @parameterized_class( @@ -80,8 +80,8 @@ ), ], ) -class TestDivideResplitter(unittest.TestCase): - """DivideResplitter tests.""" +class TestDivider(unittest.TestCase): + """Divider tests.""" divide_config: Union[ Dict[str, float], @@ -105,27 +105,27 @@ def setUp(self) -> None: def test_resplitting_correct_new_split_names(self) -> None: """Test if resplitting produces requested new splits.""" - resplitter = DivideResplitter( + divider = Divider( self.divide_config, self.divide_split, self.drop_remaining_splits ) - resplit_dataset = resplitter(self.dataset_dict) + resplit_dataset = divider(self.dataset_dict) new_keys = set(resplit_dataset.keys()) self.assertEqual(set(self.split_name_to_size.keys()), new_keys) def test_resplitting_correct_new_split_sizes(self) -> None: """Test if resplitting produces correct sizes of splits.""" - resplitter = DivideResplitter( + divider = Divider( self.divide_config, self.divide_split, self.drop_remaining_splits ) - resplit_dataset = resplitter(self.dataset_dict) + resplit_dataset = divider(self.dataset_dict) split_to_size = { split_name: len(split) for split_name, split in resplit_dataset.items() } self.assertEqual(self.split_name_to_size, split_to_size) -class TestDivideResplitterIncorrectUseCases(unittest.TestCase): - """Resplitter tests.""" +class TestDividerIncorrectUseCases(unittest.TestCase): + """Divider tests.""" def setUp(self) -> None: """Set up the dataset with 3 splits for tests.""" @@ -144,21 +144,17 @@ def test_doubling_names_in_config(self) -> None: drop_remaining_splits = False with self.assertRaises(ValueError): - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) - _ = resplitter(self.dataset_dict) + divider = Divider(divide_config, divide_split, drop_remaining_splits) + _ = divider(self.dataset_dict) def test_duplicate_names_in_config_and_dataset_split_names_multisplit(self) -> None: """Test if resplitting raises when the name collides with the old name.""" divide_config = {"train": {"valid": 0.5}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_duplicate_names_in_config_and_dataset_split_names_single_split( self, @@ -167,77 +163,63 @@ def test_duplicate_names_in_config_and_dataset_split_names_single_split( divide_config = {"valid": 0.5} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_fraction_sum_up_to_more_than_one_multisplit(self) -> None: """Test if resplitting raises when fractions sum up to > 1.0 .""" divide_config = {"train": {"train_1": 0.5, "train_2": 0.7}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, 
divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_fraction_sum_up_to_more_than_one_single_split(self) -> None: """Test if resplitting raises when fractions sum up to > 1.0 .""" divide_config = {"train_1": 0.5, "train_2": 0.7} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_sample_sizes_sum_up_to_more_than_dataset_size_single_split(self) -> None: """Test if resplitting raises when samples size sum up to > len(datset) .""" divide_config = {"train": {"train_1": 20, "train_2": 25}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_sample_sizes_sum_up_to_more_than_dataset_size_multisplit(self) -> None: """Test if resplitting raises when samples size sum up to > len(datset) .""" divide_config = {"train_1": 20, "train_2": 25} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_too_small_size_values_create_empty_dataset_single_split(self) -> None: """Test if resplitting raises when fraction creates empty dataset.""" divide_config = {"train": {"train_1": 0.2, "train_2": 0.0001}} divide_split = None drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) def test_too_small_size_values_create_empty_dataset_multisplit(self) -> None: """Test if resplitting raises when fraction creates empty dataset.""" divide_config = {"train_1": 0.2, "train_2": 0.0001} divide_split = "train" drop_remaining_splits = False - resplitter = DivideResplitter( - divide_config, divide_split, drop_remaining_splits - ) + divider = Divider(divide_config, divide_split, drop_remaining_splits) with self.assertRaises(ValueError): - _ = resplitter(self.dataset_dict) + _ = divider(self.dataset_dict) if __name__ == "__main__": diff --git a/datasets/flwr_datasets/resplitter/merge_resplitter.py b/datasets/flwr_datasets/preprocessor/merger.py similarity index 96% rename from datasets/flwr_datasets/resplitter/merge_resplitter.py rename to datasets/flwr_datasets/preprocessor/merger.py index 6bb8f23e60dc..2b76dbbafe4b 100644 --- a/datasets/flwr_datasets/resplitter/merge_resplitter.py +++ b/datasets/flwr_datasets/preprocessor/merger.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""MergeResplitter class for Flower Datasets.""" +"""Merger class for Flower Datasets.""" import collections @@ -24,7 +24,7 @@ from datasets import Dataset, DatasetDict -class MergeResplitter: +class Merger: """Merge existing splits of the dataset and assign them custom names. Create new `DatasetDict` with new split names corresponding to the merged existing @@ -43,13 +43,13 @@ class MergeResplitter: >>> # Assuming there is a dataset_dict of type `DatasetDict` >>> # dataset_dict is {"train": train-data, "valid": valid-data, "test": test-data} - >>> merge_resplitter = MergeResplitter( + >>> merger = Merger( >>> merge_config={ >>> "new_train": ("train", "valid"), >>> "test": ("test", ) >>> } >>> ) - >>> new_dataset_dict = merge_resplitter(dataset_dict) + >>> new_dataset_dict = merger(dataset_dict) >>> # new_dataset_dict is >>> # {"new_train": concatenation of train-data and valid-data, "test": test-data} """ diff --git a/datasets/flwr_datasets/resplitter/merge_resplitter_test.py b/datasets/flwr_datasets/preprocessor/merger_test.py similarity index 81% rename from datasets/flwr_datasets/resplitter/merge_resplitter_test.py rename to datasets/flwr_datasets/preprocessor/merger_test.py index ebbdfb4022b0..d5c69387e53d 100644 --- a/datasets/flwr_datasets/resplitter/merge_resplitter_test.py +++ b/datasets/flwr_datasets/preprocessor/merger_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Resplitter tests.""" +"""Preprocessor tests.""" import unittest @@ -21,11 +21,11 @@ import pytest from datasets import Dataset, DatasetDict -from flwr_datasets.resplitter.merge_resplitter import MergeResplitter +from flwr_datasets.preprocessor.merger import Merger -class TestResplitter(unittest.TestCase): - """Resplitter tests.""" +class TestMerger(unittest.TestCase): + """Preprocessor tests.""" def setUp(self) -> None: """Set up the dataset with 3 splits for tests.""" @@ -40,29 +40,29 @@ def setUp(self) -> None: def test_resplitting_train_size(self) -> None: """Test if resplitting for just renaming keeps the lengths correct.""" strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["new_train"]), 3) def test_resplitting_valid_size(self) -> None: """Test if resplitting for just renaming keeps the lengths correct.""" strategy: Dict[str, Tuple[str, ...]] = {"new_valid": ("valid",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["new_valid"]), 2) def test_resplitting_test_size(self) -> None: """Test if resplitting for just renaming keeps the lengths correct.""" strategy: Dict[str, Tuple[str, ...]] = {"new_test": ("test",)} - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["new_test"]), 1) def test_resplitting_train_the_same(self) -> None: """Test if resplitting for just renaming keeps the dataset the same.""" strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} - resplitter = 
MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertTrue( datasets_are_equal(self.dataset_dict["train"], new_dataset["new_train"]) ) @@ -72,8 +72,8 @@ def test_combined_train_valid_size(self) -> None: strategy: Dict[str, Tuple[str, ...]] = { "train_valid_combined": ("train", "valid") } - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["train_valid_combined"]), 5) def test_resplitting_test_with_combined_strategy_size(self) -> None: @@ -82,8 +82,8 @@ def test_resplitting_test_with_combined_strategy_size(self) -> None: "train_valid_combined": ("train", "valid"), "test": ("test",), } - resplitter = MergeResplitter(strategy) - new_dataset = resplitter(self.dataset_dict) + merger = Merger(strategy) + new_dataset = merger(self.dataset_dict) self.assertEqual(len(new_dataset["test"]), 1) def test_invalid_resplit_strategy_exception_message(self) -> None: @@ -92,20 +92,20 @@ def test_invalid_resplit_strategy_exception_message(self) -> None: "new_train": ("invalid_split",), "new_test": ("test",), } - resplitter = MergeResplitter(strategy) + merger = Merger(strategy) with self.assertRaisesRegex( ValueError, "The given dataset key 'invalid_split' is not present" ): - resplitter(self.dataset_dict) + merger(self.dataset_dict) def test_nonexistent_split_in_strategy(self) -> None: """Test if the exception is raised when the nonexistent split name is given.""" strategy: Dict[str, Tuple[str, ...]] = {"new_split": ("nonexistent_split",)} - resplitter = MergeResplitter(strategy) + merger = Merger(strategy) with self.assertRaisesRegex( ValueError, "The given dataset key 'nonexistent_split' is not present" ): - resplitter(self.dataset_dict) + merger(self.dataset_dict) def test_duplicate_merge_split_name(self) -> None: # pylint: disable=R0201 """Test that the new split names are not the same.""" @@ -114,17 +114,17 @@ def test_duplicate_merge_split_name(self) -> None: # pylint: disable=R0201 "test": ("train",), } with pytest.warns(UserWarning): - _ = MergeResplitter(strategy) + _ = Merger(strategy) def test_empty_dataset_dict(self) -> None: """Test that the error is raised when the empty DatasetDict is given.""" empty_dataset = DatasetDict({}) strategy: Dict[str, Tuple[str, ...]] = {"new_train": ("train",)} - resplitter = MergeResplitter(strategy) + merger = Merger(strategy) with self.assertRaisesRegex( ValueError, "The given dataset key 'train' is not present" ): - resplitter(empty_dataset) + merger(empty_dataset) def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: diff --git a/datasets/flwr_datasets/resplitter/resplitter.py b/datasets/flwr_datasets/preprocessor/preprocessor.py similarity index 91% rename from datasets/flwr_datasets/resplitter/resplitter.py rename to datasets/flwr_datasets/preprocessor/preprocessor.py index 206e2e85730c..c137b98eeeee 100644 --- a/datasets/flwr_datasets/resplitter/resplitter.py +++ b/datasets/flwr_datasets/preprocessor/preprocessor.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Resplitter.""" +"""Preprocessor.""" from typing import Callable from datasets import DatasetDict -Resplitter = Callable[[DatasetDict], DatasetDict] +Preprocessor = Callable[[DatasetDict], DatasetDict] diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index c6f6900a99cd..0ecb96ac9456 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -20,8 +20,8 @@ from datasets import Dataset, DatasetDict, concatenate_datasets from flwr_datasets.partitioner import IidPartitioner, Partitioner -from flwr_datasets.resplitter import Resplitter -from flwr_datasets.resplitter.merge_resplitter import MergeResplitter +from flwr_datasets.preprocessor import Preprocessor +from flwr_datasets.preprocessor.merger import Merger tested_datasets = [ "mnist", @@ -75,13 +75,13 @@ def _instantiate_partitioners( return instantiated_partitioners -def _instantiate_resplitter_if_needed( - resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] -) -> Optional[Resplitter]: - """Instantiate `MergeResplitter` if resplitter is merge_config.""" - if resplitter and isinstance(resplitter, Dict): - resplitter = MergeResplitter(merge_config=resplitter) - return cast(Optional[Resplitter], resplitter) +def _instantiate_merger_if_needed( + merger: Optional[Union[Preprocessor, Dict[str, Tuple[str, ...]]]] +) -> Optional[Preprocessor]: + """Instantiate `Merger` if preprocessor is merge_config.""" + if merger and isinstance(merger, Dict): + merger = Merger(merge_config=merger) + return cast(Optional[Preprocessor], merger) def _check_if_dataset_tested(dataset: str) -> None: From ee00d70a4f73abc339c53fcb402d9a164f3b9bf1 Mon Sep 17 00:00:00 2001 From: mohammadnaseri Date: Wed, 29 May 2024 11:52:36 +0100 Subject: [PATCH 3/7] fix(examples) Update TensorFlow-Privacy example dataset and model (#3526) --- examples/tensorflow-privacy/client.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/examples/tensorflow-privacy/client.py b/examples/tensorflow-privacy/client.py index 1032d12e69a2..4aec85da014a 100644 --- a/examples/tensorflow-privacy/client.py +++ b/examples/tensorflow-privacy/client.py @@ -14,14 +14,14 @@ def load_data(partition_id, batch_size): - fds = FederatedDataset(dataset="cifar10", partitioners={"train": 2}) + fds = FederatedDataset(dataset="mnist", partitioners={"train": 2}) partition = fds.load_partition(partition_id, "train") partition.set_format("numpy") # Divide data on each node: 80% train, 20% test partition = partition.train_test_split(test_size=0.2, seed=42) - x_train, y_train = partition["train"]["img"] / 255.0, partition["train"]["label"] - x_test, y_test = partition["test"]["img"] / 255.0, partition["test"]["label"] + x_train, y_train = partition["train"]["image"] / 255.0, partition["train"]["label"] + x_test, y_test = partition["test"]["image"] / 255.0, partition["test"]["label"] # Adjust the size of the training dataset to make it evenly divisible by the batch size remainder = len(x_train) % batch_size @@ -112,7 +112,18 @@ def client_fn_parameterized( batch_size=64, ): def client_fn(cid: str): - model = tf.keras.applications.MobileNetV2((32, 32, 3), classes=10, weights=None) + model = tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=(28, 28, 1)), + tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), + tf.keras.layers.Conv2D(64, kernel_size=(3, 3), 
activation="relu"), + tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10, activation="softmax"), + ] + ) train_data, test_data = load_data( partition_id=partition_id, batch_size=batch_size ) From a1216e2f8a7f7bd4ee2c37974cfbbb07893dfeaf Mon Sep 17 00:00:00 2001 From: Yan Gao Date: Thu, 30 May 2024 18:33:27 +0800 Subject: [PATCH 4/7] fix(examples) Update XGBoost quick-start example for independent object instantiation (#3426) Co-authored-by: jafermarq --- examples/xgboost-quickstart/README.md | 23 ++---- examples/xgboost-quickstart/client.py | 86 ++++++++++++++------ examples/xgboost-quickstart/pyproject.toml | 26 +++--- examples/xgboost-quickstart/requirements.txt | 3 - examples/xgboost-quickstart/server.py | 11 +++ 5 files changed, 91 insertions(+), 58 deletions(-) delete mode 100644 examples/xgboost-quickstart/requirements.txt diff --git a/examples/xgboost-quickstart/README.md b/examples/xgboost-quickstart/README.md index b196520d37e6..713b6eab8bac 100644 --- a/examples/xgboost-quickstart/README.md +++ b/examples/xgboost-quickstart/README.md @@ -21,37 +21,26 @@ This will create a new directory called `xgboost-quickstart` containing the foll -- server.py <- Defines the server-side logic -- client.py <- Defines the client-side logic -- run.sh <- Commands to run experiments --- pyproject.toml <- Example dependencies (if you use Poetry) --- requirements.txt <- Example dependencies +-- pyproject.toml <- Example dependencies ``` ### Installing Dependencies -Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. - -#### Poetry +Project dependencies (such as `xgboost` and `flwr`) are defined in `pyproject.toml`. You can install the dependencies by invoking `pip`: ```shell -poetry install -poetry shell +# From a new python environment, run: +pip install . ``` -Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: +Then, to verify that everything works correctly you can run the following command: ```shell -poetry run python3 -c "import flwr" +python3 -c "import flwr" ``` If you don't see any errors you're good to go! -#### pip - -Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. - -```shell -pip install -r requirements.txt -``` - ## Run Federated Learning with XGBoost and Flower Afterwards you are ready to start the Flower server as well as the clients. diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py index 6ac23ae15148..5a4d88bb7e43 100644 --- a/examples/xgboost-quickstart/client.py +++ b/examples/xgboost-quickstart/client.py @@ -92,9 +92,21 @@ def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core. 
# Define Flower client class XgbClient(fl.client.Client): - def __init__(self): - self.bst = None - self.config = None + def __init__( + self, + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + num_local_round, + params, + ): + self.train_dmatrix = train_dmatrix + self.valid_dmatrix = valid_dmatrix + self.num_train = num_train + self.num_val = num_val + self.num_local_round = num_local_round + self.params = params def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: _ = (self, ins) @@ -106,41 +118,41 @@ def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: parameters=Parameters(tensor_type="", tensors=[]), ) - def _local_boost(self): + def _local_boost(self, bst_input): # Update trees based on local training data. - for i in range(num_local_round): - self.bst.update(train_dmatrix, self.bst.num_boosted_rounds()) + for i in range(self.num_local_round): + bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds()) - # Extract the last N=num_local_round trees for sever aggregation - bst = self.bst[ - self.bst.num_boosted_rounds() - - num_local_round : self.bst.num_boosted_rounds() + # Bagging: extract the last N=num_local_round trees for sever aggregation + bst = bst_input[ + bst_input.num_boosted_rounds() + - self.num_local_round : bst_input.num_boosted_rounds() ] return bst def fit(self, ins: FitIns) -> FitRes: - if not self.bst: + global_round = int(ins.config["global_round"]) + if global_round == 1: # First round local training - log(INFO, "Start training at round 1") bst = xgb.train( - params, - train_dmatrix, - num_boost_round=num_local_round, - evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")], + self.params, + self.train_dmatrix, + num_boost_round=self.num_local_round, + evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")], ) - self.config = bst.save_config() - self.bst = bst else: + bst = xgb.Booster(params=self.params) for item in ins.parameters.tensors: global_model = bytearray(item) # Load global model into booster - self.bst.load_model(global_model) - self.bst.load_config(self.config) + bst.load_model(global_model) - bst = self._local_boost() + # Local training + bst = self._local_boost(bst) + # Save model local_model = bst.save_raw("json") local_model_bytes = bytes(local_model) @@ -150,27 +162,47 @@ def fit(self, ins: FitIns) -> FitRes: message="OK", ), parameters=Parameters(tensor_type="", tensors=[local_model_bytes]), - num_examples=num_train, + num_examples=self.num_train, metrics={}, ) def evaluate(self, ins: EvaluateIns) -> EvaluateRes: - eval_results = self.bst.eval_set( - evals=[(valid_dmatrix, "valid")], - iteration=self.bst.num_boosted_rounds() - 1, + # Load global model + bst = xgb.Booster(params=self.params) + for para in ins.parameters.tensors: + para_b = bytearray(para) + bst.load_model(para_b) + + # Run evaluation + eval_results = bst.eval_set( + evals=[(self.valid_dmatrix, "valid")], + iteration=bst.num_boosted_rounds() - 1, ) auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4) + global_round = ins.config["global_round"] + log(INFO, f"AUC = {auc} at round {global_round}") + return EvaluateRes( status=Status( code=Code.OK, message="OK", ), loss=0.0, - num_examples=num_val, + num_examples=self.num_val, metrics={"AUC": auc}, ) # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client()) +fl.client.start_client( + server_address="127.0.0.1:8080", + client=XgbClient( + train_dmatrix, + valid_dmatrix, + num_train, + num_val, + 
num_local_round, + params, + ).to_client(), +) diff --git a/examples/xgboost-quickstart/pyproject.toml b/examples/xgboost-quickstart/pyproject.toml index c16542ea7ffe..f1e451fe779a 100644 --- a/examples/xgboost-quickstart/pyproject.toml +++ b/examples/xgboost-quickstart/pyproject.toml @@ -1,15 +1,19 @@ [build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] -name = "xgboost-quickstart" +[project] +name = "quickstart-xgboost" version = "0.1.0" -description = "Federated XGBoost with Flower (quickstart)" -authors = ["The Flower Authors "] +description = "XGBoost Federated Learning Quickstart with Flower" +authors = [ + { name = "The Flower Authors", email = "hello@flower.ai" }, +] +dependencies = [ + "flwr>=1.8.0,<2.0", + "flwr-datasets>=0.1.0,<1.0.0", + "xgboost>=2.0.0,<3.0.0", +] -[tool.poetry.dependencies] -python = ">=3.8,<3.11" -flwr = ">=1.7.0,<2.0" -flwr-datasets = ">=0.0.1,<1.0.0" -xgboost = ">=2.0.0,<3.0.0" +[tool.hatch.build.targets.wheel] +packages = ["."] diff --git a/examples/xgboost-quickstart/requirements.txt b/examples/xgboost-quickstart/requirements.txt deleted file mode 100644 index c6949e0651c5..000000000000 --- a/examples/xgboost-quickstart/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -flwr>=1.7.0, <2.0 -flwr-datasets>=0.0.1, <1.0.0 -xgboost>=2.0.0, <3.0.0 diff --git a/examples/xgboost-quickstart/server.py b/examples/xgboost-quickstart/server.py index b45a375ce94f..e9239fde696c 100644 --- a/examples/xgboost-quickstart/server.py +++ b/examples/xgboost-quickstart/server.py @@ -1,3 +1,4 @@ +from typing import Dict import flwr as fl from flwr.server.strategy import FedXgbBagging @@ -19,6 +20,14 @@ def evaluate_metrics_aggregation(eval_metrics): return metrics_aggregated +def config_func(rnd: int) -> Dict[str, str]: + """Return a configuration with global epochs.""" + config = { + "global_round": str(rnd), + } + return config + + # Define strategy strategy = FedXgbBagging( fraction_fit=(float(num_clients_per_round) / pool_size), @@ -27,6 +36,8 @@ def evaluate_metrics_aggregation(eval_metrics): min_evaluate_clients=num_evaluate_clients, fraction_evaluate=1.0, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation, + on_evaluate_config_fn=config_func, + on_fit_config_fn=config_func, ) # Start Flower server From 568ab1e8329e1b80c635abb218fd8dcbdc60d21c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 16:51:53 +0100 Subject: [PATCH 5/7] refactor(framework) Add `adapter` argument to `grpc_request_response` (#3534) --- src/py/flwr/client/grpc_rere_client/connection.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 3778fd4061f9..8ef8e7ebf62a 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -21,7 +21,7 @@ from copy import copy from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, Union, cast import grpc from cryptography.hazmat.primitives.asymmetric import ec @@ -73,6 +73,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 authentication_keys: Optional[ Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, + adapter_cls: Optional[Type[FleetStub]] = None, ) -> 
Iterator[ Tuple[ Callable[[], Optional[Message]], @@ -133,7 +134,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 channel.subscribe(on_channel_state_change) # Shared variables for inner functions - stub = FleetStub(channel) + if adapter_cls is None: + adapter_cls = FleetStub + stub = adapter_cls(channel) metadata: Optional[Metadata] = None node: Optional[Node] = None ping_thread: Optional[threading.Thread] = None From 77ed30b69a15723c60e004ffd4ee2c0462a35cf6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 16:57:33 +0100 Subject: [PATCH 6/7] feat(framework) Add GrpcAdapter proto (#3532) --- src/proto/flwr/proto/grpcadapter.proto | 28 +++++++++ src/py/flwr/proto/grpcadapter_pb2.py | 32 +++++++++++ src/py/flwr/proto/grpcadapter_pb2.pyi | 43 ++++++++++++++ src/py/flwr/proto/grpcadapter_pb2_grpc.py | 66 ++++++++++++++++++++++ src/py/flwr/proto/grpcadapter_pb2_grpc.pyi | 24 ++++++++ src/py/flwr_tool/protoc_test.py | 2 +- 6 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 src/proto/flwr/proto/grpcadapter.proto create mode 100644 src/py/flwr/proto/grpcadapter_pb2.py create mode 100644 src/py/flwr/proto/grpcadapter_pb2.pyi create mode 100644 src/py/flwr/proto/grpcadapter_pb2_grpc.py create mode 100644 src/py/flwr/proto/grpcadapter_pb2_grpc.pyi diff --git a/src/proto/flwr/proto/grpcadapter.proto b/src/proto/flwr/proto/grpcadapter.proto new file mode 100644 index 000000000000..826efec4315e --- /dev/null +++ b/src/proto/flwr/proto/grpcadapter.proto @@ -0,0 +1,28 @@ +// Copyright 2024 Flower Labs GmbH. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================== + +syntax = "proto3"; + +package flwr.proto; + +service GrpcAdapter { + rpc SendReceive(MessageContainer) returns (MessageContainer) {} +} + +message MessageContainer { + map metadata = 1; + string grpc_message_name = 2; + bytes grpc_message_content = 3; +} diff --git a/src/py/flwr/proto/grpcadapter_pb2.py b/src/py/flwr/proto/grpcadapter_pb2.py new file mode 100644 index 000000000000..7c0374736850 --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: flwr/proto/grpcadapter.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.grpcadapter_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_MESSAGECONTAINER_METADATAENTRY']._options = None + _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_options = b'8\001' + _globals['_MESSAGECONTAINER']._serialized_start=45 + _globals['_MESSAGECONTAINER']._serialized_end=231 + _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_start=184 + _globals['_MESSAGECONTAINER_METADATAENTRY']._serialized_end=231 + _globals['_GRPCADAPTER']._serialized_start=233 + _globals['_GRPCADAPTER']._serialized_end=323 +# @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/grpcadapter_pb2.pyi b/src/py/flwr/proto/grpcadapter_pb2.pyi new file mode 100644 index 000000000000..d5f89ac27c4a --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2.pyi @@ -0,0 +1,43 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import typing +import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class MessageContainer(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + class MetadataEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: typing.Text + value: builtins.bytes + def __init__(self, + *, + key: typing.Text = ..., + value: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... + + METADATA_FIELD_NUMBER: builtins.int + GRPC_MESSAGE_NAME_FIELD_NUMBER: builtins.int + GRPC_MESSAGE_CONTENT_FIELD_NUMBER: builtins.int + @property + def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, builtins.bytes]: ... + grpc_message_name: typing.Text + grpc_message_content: builtins.bytes + def __init__(self, + *, + metadata: typing.Optional[typing.Mapping[typing.Text, builtins.bytes]] = ..., + grpc_message_name: typing.Text = ..., + grpc_message_content: builtins.bytes = ..., + ) -> None: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["grpc_message_content",b"grpc_message_content","grpc_message_name",b"grpc_message_name","metadata",b"metadata"]) -> None: ... +global___MessageContainer = MessageContainer diff --git a/src/py/flwr/proto/grpcadapter_pb2_grpc.py b/src/py/flwr/proto/grpcadapter_pb2_grpc.py new file mode 100644 index 000000000000..831f99d7b237 --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2_grpc.py @@ -0,0 +1,66 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from flwr.proto import grpcadapter_pb2 as flwr_dot_proto_dot_grpcadapter__pb2 + + +class GrpcAdapterStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendReceive = channel.unary_unary( + '/flwr.proto.GrpcAdapter/SendReceive', + request_serializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.SerializeToString, + response_deserializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.FromString, + ) + + +class GrpcAdapterServicer(object): + """Missing associated documentation comment in .proto file.""" + + def SendReceive(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GrpcAdapterServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendReceive': grpc.unary_unary_rpc_method_handler( + servicer.SendReceive, + request_deserializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.FromString, + response_serializer=flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'flwr.proto.GrpcAdapter', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class GrpcAdapter(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def SendReceive(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/flwr.proto.GrpcAdapter/SendReceive', + flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.SerializeToString, + flwr_dot_proto_dot_grpcadapter__pb2.MessageContainer.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/grpcadapter_pb2_grpc.pyi b/src/py/flwr/proto/grpcadapter_pb2_grpc.pyi new file mode 100644 index 000000000000..640f983e6e04 --- /dev/null +++ b/src/py/flwr/proto/grpcadapter_pb2_grpc.pyi @@ -0,0 +1,24 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import abc +import flwr.proto.grpcadapter_pb2 +import grpc + +class GrpcAdapterStub: + def __init__(self, channel: grpc.Channel) -> None: ... 
+ SendReceive: grpc.UnaryUnaryMultiCallable[ + flwr.proto.grpcadapter_pb2.MessageContainer, + flwr.proto.grpcadapter_pb2.MessageContainer] + + +class GrpcAdapterServicer(metaclass=abc.ABCMeta): + @abc.abstractmethod + def SendReceive(self, + request: flwr.proto.grpcadapter_pb2.MessageContainer, + context: grpc.ServicerContext, + ) -> flwr.proto.grpcadapter_pb2.MessageContainer: ... + + +def add_GrpcAdapterServicer_to_server(servicer: GrpcAdapterServicer, server: grpc.Server) -> None: ... diff --git a/src/py/flwr_tool/protoc_test.py b/src/py/flwr_tool/protoc_test.py index 2d48582eb441..8dcf4c6474d6 100644 --- a/src/py/flwr_tool/protoc_test.py +++ b/src/py/flwr_tool/protoc_test.py @@ -28,4 +28,4 @@ def test_directories() -> None: def test_proto_file_count() -> None: """Test if the correct number of proto files were captured by the glob.""" - assert len(PROTO_FILES) == 7 + assert len(PROTO_FILES) == 8 From 3d17a5ebffe969765581ab3417a1c15b77f8c138 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 30 May 2024 17:15:30 +0100 Subject: [PATCH 7/7] fix(framework) Fix the metadata type in `MessageContainer` (#3535) --- src/proto/flwr/proto/grpcadapter.proto | 2 +- src/py/flwr/proto/grpcadapter_pb2.py | 2 +- src/py/flwr/proto/grpcadapter_pb2.pyi | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/proto/flwr/proto/grpcadapter.proto b/src/proto/flwr/proto/grpcadapter.proto index 826efec4315e..acf9a9d3d94f 100644 --- a/src/proto/flwr/proto/grpcadapter.proto +++ b/src/proto/flwr/proto/grpcadapter.proto @@ -22,7 +22,7 @@ service GrpcAdapter { } message MessageContainer { - map metadata = 1; + map metadata = 1; string grpc_message_name = 2; bytes grpc_message_content = 3; } diff --git a/src/py/flwr/proto/grpcadapter_pb2.py b/src/py/flwr/proto/grpcadapter_pb2.py index 7c0374736850..2eff4bb78e47 100644 --- a/src/py/flwr/proto/grpcadapter_pb2.py +++ b/src/py/flwr/proto/grpcadapter_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lwr/proto/grpcadapter.proto\x12\nflwr.proto\"\xba\x01\n\x10MessageContainer\x12<\n\x08metadata\x18\x01 \x03(\x0b\x32*.flwr.proto.MessageContainer.MetadataEntry\x12\x19\n\x11grpc_message_name\x18\x02 \x01(\t\x12\x1c\n\x14grpc_message_content\x18\x03 \x01(\x0c\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x32Z\n\x0bGrpcAdapter\x12K\n\x0bSendReceive\x12\x1c.flwr.proto.MessageContainer\x1a\x1c.flwr.proto.MessageContainer\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) diff --git a/src/py/flwr/proto/grpcadapter_pb2.pyi b/src/py/flwr/proto/grpcadapter_pb2.pyi index d5f89ac27c4a..35889b30d2b6 100644 --- a/src/py/flwr/proto/grpcadapter_pb2.pyi +++ b/src/py/flwr/proto/grpcadapter_pb2.pyi @@ -18,11 +18,11 @@ class MessageContainer(google.protobuf.message.Message): KEY_FIELD_NUMBER: builtins.int VALUE_FIELD_NUMBER: builtins.int key: typing.Text - 
value: builtins.bytes + value: typing.Text def __init__(self, *, key: typing.Text = ..., - value: builtins.bytes = ..., + value: typing.Text = ..., ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ... @@ -30,12 +30,12 @@ class MessageContainer(google.protobuf.message.Message): GRPC_MESSAGE_NAME_FIELD_NUMBER: builtins.int GRPC_MESSAGE_CONTENT_FIELD_NUMBER: builtins.int @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, builtins.bytes]: ... + def metadata(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ... grpc_message_name: typing.Text grpc_message_content: builtins.bytes def __init__(self, *, - metadata: typing.Optional[typing.Mapping[typing.Text, builtins.bytes]] = ..., + metadata: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ..., grpc_message_name: typing.Text = ..., grpc_message_content: builtins.bytes = ..., ) -> None: ...
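The TensorFlow-Privacy example updated in PATCH 3/7 defines the MNIST CNN and trims the training set so its length is divisible by the batch size; that trimming matters because DP-SGD splits each batch into microbatches. The sketch below shows one common way to wire tensorflow-privacy's `DPKerasSGDOptimizer` into such a Keras model. It is not taken from the patch: the stand-in model, clipping norm, noise multiplier, and microbatch count are illustrative values only.

```python
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import (
    DPKerasSGDOptimizer,
)

# Minimal stand-in for the MNIST CNN defined in PATCH 3/7
model = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)

batch_size = 64
optimizer = DPKerasSGDOptimizer(
    l2_norm_clip=1.0,             # illustrative clipping norm
    noise_multiplier=1.1,         # illustrative noise level
    num_microbatches=batch_size,  # must divide the batch size, hence the remainder trimming
    learning_rate=0.01,
)

# DP-SGD needs per-example losses, so reduction is disabled here
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE
)
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
```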
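In the reworked `XgbClient` from PATCH 4/7, `_local_boost` keeps only the trees added during the current round so that `FedXgbBagging` aggregates just this round's contribution. The following self-contained sketch illustrates that slicing idea with random data; the data and parameter values are for illustration only and are not part of the example.

```python
import numpy as np
import xgboost as xgb

# Toy binary-classification data standing in for the real DMatrix
X, y = np.random.rand(64, 4), np.random.randint(0, 2, 64)
dtrain = xgb.DMatrix(X, label=y)
params = {"objective": "binary:logistic", "eta": 0.1}

num_local_round = 2
# Pretend the booster already holds a "global" model with 3 trees
bst = xgb.train(params, dtrain, num_boost_round=3)

# Local training adds num_local_round new trees on top of the existing ones
for _ in range(num_local_round):
    bst.update(dtrain, bst.num_boosted_rounds())

# Slice off only the newly added trees, mirroring _local_boost
last_trees = bst[
    bst.num_boosted_rounds() - num_local_round : bst.num_boosted_rounds()
]
assert last_trees.num_boosted_rounds() == num_local_round
```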
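PATCH 5/7 only shows that `grpc_request_response` instantiates `adapter_cls(channel)` and falls back to `FleetStub` when no adapter is given, so anything constructible from a `grpc.Channel` that exposes the same RPC attributes should work. The class below is an invented, minimal illustration of that hook, not an API documented by the patch.

```python
import logging

import grpc

from flwr.proto.fleet_pb2_grpc import FleetStub


class VerboseFleetAdapter(FleetStub):
    """Drop-in FleetStub replacement that logs when it is attached to a channel."""

    def __init__(self, channel: grpc.Channel) -> None:
        super().__init__(channel)
        logging.info("Fleet adapter attached to channel")


# Hypothetical call site (remaining grpc_request_response arguments omitted):
# grpc_request_response(..., adapter_cls=VerboseFleetAdapter)
```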
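PATCH 6/7 and 7/7 add the `GrpcAdapter` service whose `MessageContainer` carries an inner message name, its serialized bytes, and, after the metadata fix, string-valued metadata. The sketch below shows how a client could call `SendReceive` with the generated stubs; the address, metadata key, and inner message name are placeholders, not conventions defined by the patches.

```python
import grpc

from flwr.proto.grpcadapter_pb2 import MessageContainer
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub

# Assumes a server implementing flwr.proto.GrpcAdapter listens on this address
channel = grpc.insecure_channel("127.0.0.1:9092")
stub = GrpcAdapterStub(channel)

request = MessageContainer(
    metadata={"example-key": "example-value"},  # map<string, string> after PATCH 7/7
    grpc_message_name="CreateNodeRequest",      # illustrative inner message name
    grpc_message_content=b"",                   # serialized bytes of the inner message
)
response = stub.SendReceive(request)
print(response.grpc_message_name, len(response.grpc_message_content))
```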