Skip to content

Commit

Permalink
Merge branch 'main' into fds-bump-up-datasets-version
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Oct 18, 2024
2 parents 1c61eaa + e827ded commit c952d8c
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 168 deletions.
2 changes: 2 additions & 0 deletions examples/flower-authentication/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
keys/
certificates/
116 changes: 75 additions & 41 deletions examples/flower-authentication/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,74 +4,75 @@ dataset: [CIFAR-10]
framework: [torch, torchvision]
---

# Flower Authentication with PyTorch 🧪
# Flower Federations with Authentication 🧪

> 🧪 = This example covers experimental features that might change in future versions of Flower
> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
> \[!NOTE\]
> 🧪 = This example covers experimental features that might change in future versions of Flower.
> Please consult the regular PyTorch examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
The following steps describe how to start a long-running Flower server (SuperLink) and a long-running Flower client (SuperNode) with authentication enabled.
The following steps describe how to start a long-running Flower server (SuperLink+SuperExec) and a long-running Flower clients (SuperNode) with authentication enabled. The task is to train a simple CNN for image classification using PyTorch.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git _tmp && mv _tmp/examples/flower-authentication . && rm -rf _tmp && cd flower-authentication
git clone --depth=1 https://github.com/adap/flower.git _tmp \
&& mv _tmp/examples/flower-authentication . \
&& rm -rf _tmp && cd flower-authentication
```

This will create a new directory called `flower-authentication` with the following project structure:

```bash
$ tree .
.
├── certificate.conf # <-- configuration for OpenSSL
├── generate.sh # <-- generate certificates and keys
├── pyproject.toml # <-- project dependencies
├── client.py # <-- contains `ClientApp`
├── server.py # <-- contains `ServerApp`
└── task.py # <-- task-specific code (model, data)
```shell
flower-authentication
├── authexample
│ ├── __init__.py
│ ├── client_app.py # Defines your ClientApp
│ ├── server_app.py # Defines your ServerApp
│ └── task.py # Defines your model, training and data loading
├── pyproject.toml # Project metadata like dependencies and configs
├── certificate.conf # Configuration for OpenSSL
├── generate.sh # Generate certificates and keys
├── prepare_dataset.py # Generate datasets for each SuperNode to use
└── README.md
```

## Install dependencies
### Install dependencies and project

Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml`. You can install the dependencies by invoking `pip`:
Install the dependencies defined in `pyproject.toml` as well as the `authexample` package.

```shell
# From a new python environment, run:
pip install .
```bash
pip install -e .
```

Then, to verify that everything works correctly you can run the following command:

```shell
python3 -c "import flwr"
```
## Generate public and private keys

If you don't see any errors you're good to go!
The `generate.sh` script by default generates certificates for creating a secure TLS connection
and three private and public key pairs for one server and two clients.

## Generate public and private keys
> \[!NOTE\]
> Note that this script should only be used for development purposes and not for creating production key pairs.
```bash
./generate.sh
```

`generate.sh` is a script that (by default) generates certificates for creating a secure TLS connection
and three private and public key pairs for one server and two clients.
You can generate more keys by specifying the number of client credentials that you wish to generate.
The script also generates a CSV file that includes each of the generated (client) public keys.

⚠️ Note that this script should only be used for development purposes and not for creating production key pairs.

```bash
./generate.sh {your_number_of_clients}
```

## Start the long-running Flower server (SuperLink)
## Start the long-running Flower server-side (SuperLink+SuperExec)

To start a long-running Flower server (SuperLink) and enable authentication is very easy; all you need to do is type
Starting long-running Flower server-side components (SuperLink+SuperExec) and enable authentication is very easy; all you need to do is type
`--auth-list-public-keys` containing file path to the known `client_public_keys.csv`, `--auth-superlink-private-key`
containing file path to the SuperLink's private key `server_credentials`, and `--auth-superlink-public-key` containing file path to the SuperLink's public key `server_credentials.pub`. Notice that you can only enable authentication with a secure TLS connection.

Let's first launche the `SuperLink`:

```bash
flower-superlink \
--ssl-ca-certfile certificates/ca.crt \
Expand All @@ -82,35 +83,68 @@ flower-superlink \
--auth-superlink-public-key keys/server_credentials.pub
```

## Start the long-running Flower client (SuperNode)
Then launch the `SuperExec`:

```bash
flower-superexec \
--ssl-ca-certfile certificates/ca.crt \
--ssl-certfile certificates/server.pem \
--ssl-keyfile certificates/server.key \
--executor-config '--executor-config 'root-certificates=\"certificates/ca.crt\"'' \
--executor flwr.superexec.deployment:executor

```

At this point your server-side is idling. First, let's connect two `SuperNodes`, and then we'll start a run.

## Start the long-running Flower client-side (SuperNode)

> \[!NOTE\]
> Typically each `SuperNode` runs in a different entity/organization which has access to a dataset. In this example we are going to artificially create N dataset splits and saved them into a new directory called `datasets/`. Then, each `SuperNode` will be pointed to the dataset it should load via the `--node-config` argument. We provide a script that does the download, partition and saving of CIFAR-10.
```bash
python prepare_dataset.py
```

In a new terminal window, start the first long-running Flower client (SuperNode):

```bash
flower-client-app client:app \
flower-supernode \
--root-certificates certificates/ca.crt \
--server 127.0.0.1:9092 \
--superlink 127.0.0.1:9092 \
--auth-supernode-private-key keys/client_credentials_1 \
--auth-supernode-public-key keys/client_credentials_1.pub
--auth-supernode-public-key keys/client_credentials_1.pub \
--node-config 'dataset-path="datasets/cifar10_part_1"'
```

In yet another new terminal window, start the second long-running Flower client:

```bash
flower-client-app client:app \
flower-supernode \
--root-certificates certificates/ca.crt \
--server 127.0.0.1:9092 \
--superlink 127.0.0.1:9092 \
--auth-supernode-private-key keys/client_credentials_2 \
--auth-supernode-public-key keys/client_credentials_2.pub
--auth-supernode-public-key keys/client_credentials_2.pub \
--node-config 'dataset-path="datasets/cifar10_part_2"'
```

If you generated more than 2 client credentials, you can add more clients by opening new terminal windows and running the command
above. Don't forget to specify the correct client private and public keys for each client instance you created.

> \[!TIP\]
> Note the `--node-config` passed when spawning the `SuperNode` is accessible to the `ClientApp` via the context. In this example, the `client_fn()` uses it to load the dataset and then proceed with the training of the model.
>
> ```python
> def client_fn(context: Context):
> # retrieve the passed `--node-config`
> dataset_path = context.node_config["dataset-path"]
> # then load the dataset
> ```
## Run the Flower App
With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower ServerApp:
With both the long-running server-side (SuperLink+SuperExec) and two SuperNodes up and running, we can now start run. Note that the command below points to a federation named `my-federation`. Its entry point is defined in the `pyproject.toml`.
```bash
flower-server-app server:app --root-certificates certificates/ca.crt --dir ./ --server 127.0.0.1:9091
flwr run . my-federation
```
1 change: 1 addition & 0 deletions examples/flower-authentication/authexample/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""authexample."""
65 changes: 65 additions & 0 deletions examples/flower-authentication/authexample/client_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""authexample: An authenticated Flower / PyTorch app."""

import torch
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

from authexample.task import (
Net,
get_weights,
load_data_from_disk,
set_weights,
test,
train,
)


# Define Flower Client
class FlowerClient(NumPyClient):
def __init__(self, trainloader, valloader, local_epochs, learning_rate):
self.net = Net()
self.trainloader = trainloader
self.valloader = valloader
self.local_epochs = local_epochs
self.lr = learning_rate
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def fit(self, parameters, config):
"""Train the model with data of this client."""
set_weights(self.net, parameters)
results = train(
self.net,
self.trainloader,
self.valloader,
self.local_epochs,
self.lr,
self.device,
)
return get_weights(self.net), len(self.trainloader.dataset), results

def evaluate(self, parameters, config):
"""Evaluate the model on the data this client has."""
set_weights(self.net, parameters)
loss, accuracy = test(self.net, self.valloader, self.device)
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
"""Construct a Client that will be run in a ClientApp."""

# Read the node_config to get the path to the dataset the SuperNode running
# this ClientApp has access to
dataset_path = context.node_config["dataset-path"]

# Read run_config to fetch hyperparameters relevant to this run
batch_size = context.run_config["batch-size"]
trainloader, valloader = load_data_from_disk(dataset_path, batch_size)
local_epochs = context.run_config["local-epochs"]
learning_rate = context.run_config["learning-rate"]

# Return Client instance
return FlowerClient(trainloader, valloader, local_epochs, learning_rate).to_client()


# Flower ClientApp
app = ClientApp(client_fn)
46 changes: 46 additions & 0 deletions examples/flower-authentication/authexample/server_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""authexample: An authenticated Flower / PyTorch app."""

from typing import List, Tuple

from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg

from authexample.task import Net, get_weights


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# Multiply accuracy of each client by number of examples used
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
examples = [num_examples for num_examples, _ in metrics]

# Aggregate and return custom metric (weighted average)
return {"accuracy": sum(accuracies) / sum(examples)}


def server_fn(context: Context):
"""Construct components that set the ServerApp behaviour."""

# Read from config
num_rounds = context.run_config["num-server-rounds"]

# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)

# Define the strategy
strategy = FedAvg(
fraction_fit=1.0,
fraction_evaluate=context.run_config["fraction-evaluate"],
min_available_clients=2,
evaluate_metrics_aggregation_fn=weighted_average,
initial_parameters=parameters,
)
config = ServerConfig(num_rounds=num_rounds)

return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
Loading

0 comments on commit c952d8c

Please sign in to comment.