Add quickstart sklearn tabular example #2719

Merged 14 commits on Dec 20, 2023
2 changes: 2 additions & 0 deletions doc/source/ref-changelog.md
@@ -2,6 +2,8 @@

## Unreleased

- **Add scikit-learn tabular data example** ([#2719](https://github.com/adap/flower/pull/2719))

- **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381))

- **Update Flower Baselines**
77 changes: 77 additions & 0 deletions examples/quickstart-sklearn-tabular/README.md
@@ -0,0 +1,77 @@
# Flower Example using scikit-learn

This Flower example uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system on
the Iris (tabular) dataset.
It will help you understand how to adapt Flower for use with `scikit-learn`.
Running the example is straightforward. It uses [Flower Datasets](https://flower.dev/docs/datasets/) to
download, partition and preprocess the dataset.
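
To make the data flow concrete, the sketch below mimics in plain Python what the example does with the dataset: split the rows into one partition per client, then hold out the last 20% of each partition for local evaluation. The helper names `partition_rows` and `train_test_split_tail` are hypothetical, and the contiguous near-equal split is an assumption made for illustration; Flower Datasets may assign rows differently.

```python
from typing import List, Sequence, Tuple


def partition_rows(rows: Sequence, n_clients: int) -> List[list]:
    """Split rows into n_clients contiguous, near-equal partitions (illustrative only)."""
    base, extra = divmod(len(rows), n_clients)
    partitions, start = [], 0
    for client in range(n_clients):
        # The first `extra` clients receive one additional row
        size = base + (1 if client < extra else 0)
        partitions.append(list(rows[start:start + size]))
        start += size
    return partitions


def train_test_split_tail(rows: Sequence, train_fraction: float = 0.8) -> Tuple[list, list]:
    """Keep the first train_fraction of rows for training and the rest for testing."""
    cut = int(train_fraction * len(rows))
    return list(rows[:cut]), list(rows[cut:])


rows = list(range(150))  # stand-in for a 150-row tabular dataset
partitions = partition_rows(rows, n_clients=3)
train, test = train_test_split_tail(partitions[0])
print(len(partitions), len(train), len(test))
```

Because the split takes the tail of each partition rather than a random sample, the order of the rows matters; shuffling before splitting is a common alternative.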

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell, which will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/quickstart-sklearn-tabular . && rm -rf flower && cd quickstart-sklearn-tabular
```

This will create a new directory called `quickstart-sklearn-tabular` containing the following files:

```shell
-- pyproject.toml
-- requirements.txt
-- client.py
-- server.py
-- utils.py
-- README.md
-- run.sh
```

### Installing Dependencies

Project dependencies (such as `scikit-learn` and `flwr`) are defined in `pyproject.toml` and `requirements.txt`. We recommend installing them with [Poetry](https://python-poetry.org/docs/) ([Poetry installation](https://python-poetry.org/docs/#installation)), which also manages your virtual environment, or with [pip](https://pip.pypa.io/en/latest/development/). Feel free to use a different tool for installing dependencies and managing virtual environments if you prefer.

#### Poetry

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

#### pip

Run the command below in your terminal to install the dependencies listed in `requirements.txt`.

```shell
pip install -r requirements.txt
```
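
Whichever installer you use, a quick way to confirm that the dependencies are present is to query their installed versions. The sketch below relies only on the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the package list is taken from `requirements.txt`.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    required = ["flwr", "flwr-datasets", "scikit-learn"]
    for name, version in installed_versions(required).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

A `NOT INSTALLED` entry usually means the command was run outside the virtual environment in which the dependencies were installed.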

## Run Federated Learning with scikit-learn and Flower

You are now ready to start the Flower server and the clients. Start the server in a terminal as follows:

```shell
poetry run python3 server.py
```

Now you are ready to start the Flower clients that will participate in the learning. Open two more terminals and run the following command in each, using a different `--node-id` per terminal:

```shell
poetry run python3 client.py --node-id 0 # node-id should be any of {0,1,2}
```
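
The `--node-id` flag selects which data partition a client trains on. The sketch below, modelled on the argument parsing in `client.py`, shows how `argparse` rejects values outside `{0,1,2}`; wrapping the parser in a function named `parse_node_id` is a choice made here for illustration only.

```python
import argparse

N_CLIENTS = 3  # same number of partitions as in client.py


def parse_node_id(argv):
    """Parse --node-id and return it; argparse exits if the value is not in 0..N_CLIENTS-1."""
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        type=int,
        choices=range(0, N_CLIENTS),
        required=True,
        help="Specifies the artificial data partition",
    )
    return parser.parse_args(argv).node_id


print(parse_node_id(["--node-id", "1"]))
```

Passing an out-of-range value such as `--node-id 5` makes `argparse` print a usage error and exit instead of starting the client.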

Alternatively you can run all of it in one shell as follows:

```shell
poetry run python3 server.py &
poetry run python3 client.py --node-id 0 &
poetry run python3 client.py --node-id 1
```

You will see Flower start the federated training.
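
During training, the server combines the metrics reported by the clients, weighting each client by the number of examples it used. The sketch below illustrates that weighting for scalar metrics, in the spirit of `weighted_average` in `utils.py`; the function name `weighted_mean` and the sample numbers are made up for illustration and are not output from a real run.

```python
from typing import Dict, List, Tuple


def weighted_mean(results: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Average each numeric metric across clients, weighted by example count."""
    total = sum(num_examples for num_examples, _ in results)
    sums: Dict[str, float] = {}
    for num_examples, metrics in results:
        for name, value in metrics.items():
            # Skip non-numeric metrics such as strings or bytes
            if isinstance(value, (int, float)):
                sums[name] = sums.get(name, 0.0) + num_examples * value
    return {name: weighted_sum / total for name, weighted_sum in sums.items()}


# Two hypothetical clients: 40 examples at 0.5 accuracy, 10 examples at 1.0 accuracy
aggregated = weighted_mean([(40, {"train_accuracy": 0.5}), (10, {"train_accuracy": 1.0})])
print(aggregated)
```

The client with more examples pulls the aggregate towards its own value, so the result here lies closer to 0.5 than to 1.0.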
71 changes: 71 additions & 0 deletions examples/quickstart-sklearn-tabular/client.py
@@ -0,0 +1,71 @@
import argparse
import warnings

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import flwr as fl
import utils
from flwr_datasets import FederatedDataset

if __name__ == "__main__":
    N_CLIENTS = 3

    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--node-id",
        type=int,
        choices=range(0, N_CLIENTS),
        required=True,
        help="Specifies the artificial data partition",
    )
    args = parser.parse_args()
    partition_id = args.node_id

    # Load the partition data
    fds = FederatedDataset(dataset="hitorilabs/iris", partitioners={"train": N_CLIENTS})

    dataset = fds.load_partition(partition_id, "train").with_format("pandas")[:]
    X = dataset[["petal_length", "petal_width", "sepal_length", "sepal_width"]]
    y = dataset["species"]
    unique_labels = fds.load_full("train").unique("species")
    # Split the on-edge data: 80% train, 20% test
    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]

    # Create LogisticRegression Model
    model = LogisticRegression(
        penalty="l2",
        max_iter=1,  # local epoch
        warm_start=True,  # prevent refreshing weights when fitting
    )

    # Setting initial parameters, akin to model.compile for keras models
    utils.set_initial_params(model, n_features=X_train.shape[1], n_classes=3)

    # Define Flower client
    class IrisClient(fl.client.NumPyClient):
        def get_parameters(self, config):  # type: ignore
            return utils.get_model_parameters(model)

        def fit(self, parameters, config):  # type: ignore
            utils.set_model_params(model, parameters)
            # Ignore convergence failure due to low local epochs
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model.fit(X_train, y_train)
            accuracy = model.score(X_train, y_train)
            return (
                utils.get_model_parameters(model),
                len(X_train),
                {"train_accuracy": accuracy},
            )

        def evaluate(self, parameters, config):  # type: ignore
            utils.set_model_params(model, parameters)
            loss = log_loss(y_test, model.predict_proba(X_test), labels=unique_labels)
            accuracy = model.score(X_test, y_test)
            return loss, len(X_test), {"test_accuracy": accuracy}

    # Start Flower client
    fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client())
18 changes: 18 additions & 0 deletions examples/quickstart-sklearn-tabular/pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart-sklearn-tabular"
version = "0.1.0"
description = "Federated learning with scikit-learn and Flower"
authors = [
    "The Flower Authors <[email protected]>",
    "Kaushik Amar Das <[email protected]>"
]

[tool.poetry.dependencies]
python = "^3.8"
flwr = ">=1.0,<2.0"
flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
scikit-learn = "^1.3.0"
3 changes: 3 additions & 0 deletions examples/quickstart-sklearn-tabular/requirements.txt
@@ -0,0 +1,3 @@
flwr>=1.0, <2.0
flwr-datasets[vision]>=0.0.2, <1.0.0
scikit-learn>=1.3.0
17 changes: 17 additions & 0 deletions examples/quickstart-sklearn-tabular/run.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

echo "Starting server"
python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start

for i in $(seq 0 1); do
    echo "Starting client $i"
    python client.py --node-id "${i}" &
done

# This will allow you to use CTRL+C to stop all background processes
trap 'trap - SIGTERM && kill -- -$$' SIGINT SIGTERM
# Wait for all background processes to complete
wait
19 changes: 19 additions & 0 deletions examples/quickstart-sklearn-tabular/server.py
@@ -0,0 +1,19 @@
import flwr as fl
import utils
from sklearn.linear_model import LogisticRegression


# Start Flower server for 25 rounds of federated learning
if __name__ == "__main__":
    model = LogisticRegression()
    utils.set_initial_params(model, n_classes=3, n_features=4)
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        fit_metrics_aggregation_fn=utils.weighted_average,
        evaluate_metrics_aggregation_fn=utils.weighted_average,
    )
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        strategy=strategy,
        config=fl.server.ServerConfig(num_rounds=25),
    )
75 changes: 75 additions & 0 deletions examples/quickstart-sklearn-tabular/utils.py
@@ -0,0 +1,75 @@
from typing import List, Tuple, Dict

import numpy as np
from sklearn.linear_model import LogisticRegression

from flwr.common import NDArrays, Metrics, Scalar


def get_model_parameters(model: LogisticRegression) -> NDArrays:
    """Return the parameters of a sklearn LogisticRegression model."""
    if model.fit_intercept:
        params = [
            model.coef_,
            model.intercept_,
        ]
    else:
        params = [
            model.coef_,
        ]
    return params


def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression:
    """Set the parameters of a sklearn LogisticRegression model."""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model: LogisticRegression, n_classes: int, n_features: int):
    """Set initial parameters as zeros.

    Required since model params are uninitialized until model.fit is called but server
    asks for initial parameters from clients at launch. Refer to
    sklearn.linear_model.LogisticRegression documentation for more information.
    """
    model.classes_ = np.array([i for i in range(n_classes)])

    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Dict[str, Scalar]:
    """Compute weighted average.

    A generic implementation that averages only over floats and ints and drops the
    other data types of the Metrics.
    """
    print(metrics)
    # num_samples_list can represent the number of samples or batches, depending on the client
    num_samples_list = [n_batches for n_batches, _ in metrics]
    num_samples_sum = sum(num_samples_list)
    metrics_lists: Dict[str, List[float]] = {}
    for num_samples, all_metrics_dict in metrics:
        # Create an empty list for each numeric metric reported by the first client
        for single_metric, value in all_metrics_dict.items():
            if isinstance(value, (float, int)):
                metrics_lists[single_metric] = []
        # Just one iteration needed to initialize the keywords
        break

    for num_samples, all_metrics_dict in metrics:
        # Calculate each metric one by one
        for single_metric, value in all_metrics_dict.items():
            # Add weighted metric
            if isinstance(value, (float, int)):
                metrics_lists[single_metric].append(float(num_samples * value))

    weighted_metrics: Dict[str, Scalar] = {}
    for metric_name, metric_values in metrics_lists.items():
        weighted_metrics[metric_name] = sum(metric_values) / num_samples_sum

    return weighted_metrics