diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 603e6c602274..a25d8c3914bc 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -2,6 +2,8 @@ ## Unreleased +- **Add scikit-learn tabular data example** ([#2719](https://github.com/adap/flower/pull/2719)) + - **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381)) - **Update Flower Baselines** diff --git a/examples/quickstart-sklearn-tabular/README.md b/examples/quickstart-sklearn-tabular/README.md new file mode 100644 index 000000000000..d62525c96c18 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/README.md @@ -0,0 +1,77 @@ +# Flower Example using scikit-learn + +This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on +"iris" (tabular) dataset. +It will help you understand how to adapt Flower for use with `scikit-learn`. +Running this example in itself is quite easy. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) to +download, partition and preprocess the dataset. + +## Project Setup + +Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: + +```shell +git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-sklearn-tabular . && rm -rf flower && cd quickstart-sklearn-tabular +``` + +This will create a new directory called `quickstart-sklearn-tabular` containing the following files: + +```shell +-- pyproject.toml +-- requirements.txt +-- client.py +-- server.py +-- utils.py +-- README.md +``` + +### Installing Dependencies + +Project dependencies (such as `scikit-learn` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)) or [pip](https://pip.pypa.io/en/latest/development/), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +#### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +If you don't see any errors you're good to go! + +#### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +## Run Federated Learning with scikit-learn and Flower + +Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: + +```shell +poetry run python3 server.py +``` + +Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: + +```shell +poetry run python3 client.py --node-id 0 # node-id should be any of {0,1,2} +``` + +Alternatively you can run all of it in one shell as follows: + +```shell +poetry run python3 server.py & +poetry run python3 client.py --node-id 0 & +poetry run python3 client.py --node-id 1 +``` + +You will see that Flower is starting a federated training. diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py new file mode 100644 index 000000000000..88f654d4398e --- /dev/null +++ b/examples/quickstart-sklearn-tabular/client.py @@ -0,0 +1,71 @@ +import argparse +import warnings + +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import log_loss + +import flwr as fl +import utils +from flwr_datasets import FederatedDataset + +if __name__ == "__main__": + N_CLIENTS = 3 + + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--node-id", + type=int, + choices=range(0, N_CLIENTS), + required=True, + help="Specifies the artificial data partition", + ) + args = parser.parse_args() + partition_id = args.node_id + + # Load the partition data + fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS}) + + dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:] + X = dataset[["petal_length", "petal_width", "sepal_length", "sepal_width"]] + y = dataset["species"] + unique_labels = fds.load_full("train").unique("species") + # Split the on edge data: 80% train, 20% test + X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] + y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] + + # Create LogisticRegression Model + model = LogisticRegression( + penalty="l2", + max_iter=1, # local epoch + warm_start=True, # prevent refreshing weights when fitting + ) + + # Setting initial parameters, akin to model.compile for keras models + utils.set_initial_params(model, n_features=X_train.shape[1], n_classes=3) + + # Define Flower client + class IrisClient(fl.client.NumPyClient): + def get_parameters(self, config): # type: ignore + return utils.get_model_parameters(model) + + def fit(self, parameters, config): # type: ignore + utils.set_model_params(model, parameters) + # Ignore convergence failure due to low local epochs + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model.fit(X_train, y_train) + accuracy = model.score(X_train, y_train) + return ( + utils.get_model_parameters(model), + len(X_train), + {"train_accuracy": accuracy}, + ) + + def evaluate(self, parameters, config): # type: ignore + utils.set_model_params(model, parameters) + loss = log_loss(y_test, model.predict_proba(X_test), labels=unique_labels) + accuracy = model.score(X_test, y_test) + return loss, len(X_test), {"test_accuracy": accuracy} + + # Start Flower client + fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client()) diff --git a/examples/quickstart-sklearn-tabular/pyproject.toml b/examples/quickstart-sklearn-tabular/pyproject.toml new file mode 100644 index 000000000000..34a78048d3b0 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "sklearn-mnist" +version = "0.1.0" +description = "Federated learning with scikit-learn and Flower" +authors = [ + "The Flower Authors ", + "Kaushik Amar Das " +] + +[tool.poetry.dependencies] +python = "^3.8" +flwr = ">=1.0,<2.0" +flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } +scikit-learn = "^1.3.0" diff --git a/examples/quickstart-sklearn-tabular/requirements.txt b/examples/quickstart-sklearn-tabular/requirements.txt new file mode 100644 index 000000000000..e0f15b31f3f7 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/requirements.txt @@ -0,0 +1,3 @@ +flwr>=1.0, <2.0 +flwr-datasets[vision]>=0.0.2, <1.0.0 +scikit-learn>=1.3.0 diff --git a/examples/quickstart-sklearn-tabular/run.sh b/examples/quickstart-sklearn-tabular/run.sh new file mode 100755 index 000000000000..48cee1b41b74 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash +set -e +cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ + +echo "Starting server" +python server.py & +sleep 3 # Sleep for 3s to give the server enough time to start + +for i in $(seq 0 1); do + echo "Starting client $i" + python client.py --node-id "${i}" & +done + +# This will allow you to use CTRL+C to stop all background processes +trap 'trap - SIGTERM && kill -- -$$' SIGINT SIGTERM +# Wait for all background processes to complete +wait diff --git a/examples/quickstart-sklearn-tabular/server.py b/examples/quickstart-sklearn-tabular/server.py new file mode 100644 index 000000000000..0c779c52a8d6 --- /dev/null +++ b/examples/quickstart-sklearn-tabular/server.py @@ -0,0 +1,19 @@ +import flwr as fl +import utils +from sklearn.linear_model import LogisticRegression + + +# Start Flower server for five rounds of federated learning +if __name__ == "__main__": + model = LogisticRegression() + utils.set_initial_params(model, n_classes=3, n_features=4) + strategy = fl.server.strategy.FedAvg( + min_available_clients=2, + fit_metrics_aggregation_fn=utils.weighted_average, + evaluate_metrics_aggregation_fn=utils.weighted_average, + ) + fl.server.start_server( + server_address="0.0.0.0:8080", + strategy=strategy, + config=fl.server.ServerConfig(num_rounds=25), + ) diff --git a/examples/quickstart-sklearn-tabular/utils.py b/examples/quickstart-sklearn-tabular/utils.py new file mode 100644 index 000000000000..e154f44ef8bf --- /dev/null +++ b/examples/quickstart-sklearn-tabular/utils.py @@ -0,0 +1,75 @@ +from typing import List, Tuple, Dict + +import numpy as np +from sklearn.linear_model import LogisticRegression + +from flwr.common import NDArrays, Metrics, Scalar + + +def get_model_parameters(model: LogisticRegression) -> NDArrays: + """Return the parameters of a sklearn LogisticRegression model.""" + if model.fit_intercept: + params = [ + model.coef_, + model.intercept_, + ] + else: + params = [ + model.coef_, + ] + return params + + +def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression: + """Set the parameters of a sklean LogisticRegression model.""" + model.coef_ = params[0] + if model.fit_intercept: + model.intercept_ = params[1] + return model + + +def set_initial_params(model: LogisticRegression, n_classes: int, n_features: int): + """Set initial parameters as zeros. + + Required since model params are uninitialized until model.fit is called but server + asks for initial parameters from clients at launch. Refer to + sklearn.linear_model.LogisticRegression documentation for more information. + """ + model.classes_ = np.array([i for i in range(n_classes)]) + + model.coef_ = np.zeros((n_classes, n_features)) + if model.fit_intercept: + model.intercept_ = np.zeros((n_classes,)) + + +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]: + """Compute weighted average. + + It is generic implementation that averages only over floats and ints and drops the + other data types of the Metrics. + """ + print(metrics) + # num_samples_list can represent number of sample or batches depending on the client + num_samples_list = [n_batches for n_batches, _ in metrics] + num_samples_sum = sum(num_samples_list) + metrics_lists: Dict[str, List[float]] = {} + for num_samples, all_metrics_dict in metrics: + # Calculate each metric one by one + for single_metric, value in all_metrics_dict.items(): + if isinstance(value, (float, int)): + metrics_lists[single_metric] = [] + # Just one iteration needed to initialize the keywords + break + + for num_samples, all_metrics_dict in metrics: + # Calculate each metric one by one + for single_metric, value in all_metrics_dict.items(): + # Add weighted metric + if isinstance(value, (float, int)): + metrics_lists[single_metric].append(float(num_samples * value)) + + weighted_metrics: Dict[str, Scalar] = {} + for metric_name, metric_values in metrics_lists.items(): + weighted_metrics[metric_name] = sum(metric_values) / num_samples_sum + + return weighted_metrics