feat(examples) Add low-level API advanced-pytorch example #4186

Open

jafermarq wants to merge 33 commits into main from new-advanced-pytorch-example

Commits (33)
ad68718
init
jafermarq Sep 3, 2024
6d58c2c
stateful client + personalization
jafermarq Sep 3, 2024
89940d2
+ reamde
jafermarq Sep 3, 2024
9fad231
stateful clients + plots
jafermarq Sep 3, 2024
1a46f1a
w/ prev
jafermarq Sep 3, 2024
4240deb
fixes
jafermarq Sep 3, 2024
0b95192
init low-level dir
jafermarq Sep 3, 2024
451467e
basic lr decay
jafermarq Sep 3, 2024
43fdd75
format
jafermarq Sep 3, 2024
449b03d
low-level init
jafermarq Sep 3, 2024
f9065bd
enhancements here and there
jafermarq Sep 3, 2024
3cea1e6
miniaml readme for low-level version
jafermarq Sep 4, 2024
a9fb6c8
refactor(examples) Suggest edits to `advanced-pytorch` example (#4134)
chongshenng Sep 4, 2024
1cc1540
format
jafermarq Sep 4, 2024
e306a24
Merge branch 'main' into new-advanced-pytorch-example
jafermarq Sep 4, 2024
34966ed
w/ prev
jafermarq Sep 4, 2024
ec4d1b1
opt-in query; eval loop; create run dir; save as json
jafermarq Sep 5, 2024
265ba2b
w/ previous (active timeout for query stage)
jafermarq Sep 5, 2024
520c82e
w&b support for serverapp
jafermarq Sep 5, 2024
f200630
saving to results json; updated readme low-level
jafermarq Sep 5, 2024
bf52be7
w/ previous
jafermarq Sep 5, 2024
814bb43
wandb plot; gitignore
jafermarq Sep 5, 2024
44c0fab
instructions how to run on GPU
jafermarq Sep 5, 2024
d872b6b
additional federation; dataset partition plot; other tweaks
jafermarq Sep 5, 2024
b54753b
tweak
jafermarq Sep 5, 2024
2a5a963
Merge branch 'main' into new-advanced-pytorch-example
jafermarq Sep 5, 2024
e0a49c1
Merge branch 'main' into new-advanced-pytorch-example
jafermarq Sep 12, 2024
f98a354
merge w/ main
jafermarq Sep 17, 2024
075efe8
merge w/ main
jafermarq Sep 20, 2024
f481444
moved things around
jafermarq Sep 20, 2024
6940f43
discard prev
jafermarq Sep 20, 2024
1204920
updated readme and pyproject.toml
jafermarq Sep 20, 2024
09f161a
loop until nodes are connected
jafermarq Sep 20, 2024
3 changes: 3 additions & 0 deletions examples/advanced-pytorch-low-level/.gitignore
@@ -0,0 +1,3 @@
__pycache__/
outputs/
wandb/
111 changes: 111 additions & 0 deletions examples/advanced-pytorch-low-level/README.md
@@ -0,0 +1,111 @@
---
tags: [advanced, vision, fds, wandb, low-level]
dataset: [Fashion-MNIST]
framework: [torch, torchvision]
---

# Federated Learning with PyTorch and Flower (Advanced Example with Low-Level API)

> \[!CAUTION\]
> This example uses Flower's low-level API, which is a preview feature and subject to change. If you are not ready for the low-level API, the [advanced-pytorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch) example demonstrates near-identical functionality using higher-level components such as Flower's _Strategies_ and _NumPyClient_.

This example demonstrates how to use Flower's low-level API to write a `ServerApp` as a _"for loop"_, enabling you to define what a "round" means and to construct [Message](https://flower.ai/docs/framework/ref-api/flwr.common.Message.html) objects that communicate arbitrary data structures as [RecordSet](https://flower.ai/docs/framework/ref-api/flwr.common.RecordSet.html) objects. Just like its counterpart that uses the strategies API (found in the parent directory), it:

1. Saves model checkpoints
2. Saves the metrics available at the strategy (e.g. accuracies, losses)
3. Logs training artefacts to [Weights & Biases](https://wandb.ai/site)
4. Implements a simple decaying learning rate schedule across rounds (a sketch follows this list)
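
For the learning-rate schedule in point 4, here is a minimal sketch of a round-based exponential decay; the initial value and decay factor are illustrative assumptions, not the exact values used in this example. The resulting value is what gets sent to clients inside the round's `ConfigsRecord` (the clients read it as `configs_records["config"]["lr"]`):

```python
def lr_for_round(server_round: int, initial_lr: float = 0.01, decay: float = 0.95) -> float:
    """Decay the learning rate exponentially as rounds progress."""
    return initial_lr * decay**server_round


print(lr_for_round(0))  # 0.01
print(lr_for_round(9))  # ~0.0063
```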

> \[!NOTE\]
> The code in this example is particularly rich in comments, but the code itself is intended to be easy to follow even without them. Note that `task.py` makes use of many of the same components (model, train/evaluate functions, data loaders) as were first presented in the [advanced-pytorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch) example that uses strategies.

This example uses [Flower Datasets](https://flower.ai/docs/datasets/) with the [Dirichlet Partitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.DirichletPartitioner.html#flwr_datasets.partitioner.DirichletPartitioner) to partition the [Fashion-MNIST](https://huggingface.co/datasets/zalando-datasets/fashion_mnist) dataset in a non-IID fashion into 50 partitions.

![](_static/fmnist_50_lda.png)

> \[!TIP\]
> You can use Flower Datasets [built-in visualization tools](https://flower.ai/docs/datasets/tutorial-visualize-label-distribution.html) to easily generate plots like the one above.
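
For reference, a minimal sketch of such a partitioning with Flower Datasets might look like the following; the `alpha` and `seed` values are illustrative assumptions, not necessarily the ones used in `task.py`:

```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

# Partition Fashion-MNIST into 50 non-IID shards by label
partitioner = DirichletPartitioner(
    num_partitions=50,
    partition_by="label",
    alpha=1.0,  # assumption: lower alpha -> more heterogeneous partitions
    seed=42,  # assumption
)
fds = FederatedDataset(
    dataset="zalando-datasets/fashion_mnist",
    partitioners={"train": partitioner},
)
partition = fds.load_partition(partition_id=0)  # one node's local data
```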

```shell
advanced-pytorch-low-level
├── pytorch_example_low_level
│ ├── __init__.py
│ ├── client_app.py # Defines your ClientApp
│ ├── server_app.py # Defines your ServerApp
│ ├── task.py # Defines your model, training and data loading
│ └── utils.py # Defines utility functions
├── pyproject.toml # Project metadata like dependencies and configs
└── README.md
```

### Install dependencies and project

Install the dependencies defined in `pyproject.toml` as well as the `pytorch_example_low_level` package.

```bash
pip install -e .
```

## Run the project

The low-level `ServerApp` implemented in this example goes through the following steps on each round (a minimal sketch of the loop follows the list):

1. Uniformly sample a fraction of the connected nodes
2. Involve the selected nodes in a round of training, where they train the global model on their local data
3. Aggregate the received models
4. Query all nodes; those that return `True` will be considered in the next step
5. Share the global model with the selected nodes so they evaluate it on their local validation sets
6. Compute the average accuracy and loss from the received results
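
The actual implementation lives in `server_app.py`; below is a minimal, simplified sketch of the train stage of such a loop, assuming the `Driver`-based low-level API of Flower 1.11. The 20% sampling fraction is hardcoded here for brevity, and model (de)serialization and aggregation are omitted:

```python
import random

from flwr.common import Context, MessageType, RecordSet
from flwr.server import Driver, ServerApp

app = ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
    num_rounds = context.run_config["num-server-rounds"]

    for server_round in range(num_rounds):
        # (1) Uniformly sample a fraction of the connected nodes
        node_ids = list(driver.get_node_ids())
        sampled = random.sample(node_ids, k=max(1, int(0.2 * len(node_ids))))

        # (2) Build one TRAIN message per sampled node; `global_model` is a
        # RecordSet carrying the model parameters (construction omitted here)
        global_model = RecordSet()
        messages = [
            driver.create_message(
                content=global_model,
                message_type=MessageType.TRAIN,
                dst_node_id=node_id,
                group_id=str(server_round),
            )
            for node_id in sampled
        ]

        # Send messages and block until replies arrive (or time out)
        replies = driver.send_and_receive(messages)

        # (3) Aggregate the updated models carried in the replies
        contents = [reply.content for reply in replies if reply.has_content()]
        # ... average the `ParametersRecord`s in `contents` into `global_model`
```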

The low-level API also gives you full control over what gets logged when running your Flower apps. Running this example as shown below will generate a log like this:

```bash
...
INFO :
INFO : 🔄 Starting round 2/10
INFO : Sampled 10 out of 50 nodes.
INFO : 📥 Received 10/10 results (TRAIN)
INFO : 💡 Centrally evaluated model -> loss: 1.6017 / accuracy: 0.4556
INFO : 🎉 New best global model found: 0.455600
INFO : 📨 Received 12/50 results (QUERY)
INFO : ✅ 6/50 nodes opted-in for evaluation (QUERY)
INFO : 📥 Received 6/6 results (EVALUATE)
INFO : 📊 Federated evaluation -> loss: 1.605±0.116 / accuracy: 0.522±0.105
INFO :
...
```

By default, the metrics {`centralized_accuracy`, `centralized_loss`, `federated_evaluate_accuracy`, `federated_evaluate_loss`} will be logged to Weights & Biases (they are also stored in a `results.json`). Upon executing `flwr run` you'll see a URL linking to your Weights & Biases dashboard where you can see the metrics.

![](_static/wandb_plots.png)
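
For reference, the W&B side of this can be as small as the following sketch; the project name is an illustrative assumption, and the logged values here are placeholders taken from the sample log above:

```python
import wandb


def log_round_metrics(server_round: int, accuracy: float, loss: float) -> None:
    """Log per-round metrics so W&B plots them against the round number."""
    wandb.log(
        {"centralized_accuracy": accuracy, "centralized_loss": loss},
        step=server_round,
    )


# Initialize once when the `ServerApp` starts; prints the dashboard URL
wandb.init(project="flower-advanced-pytorch-low-level")  # project name assumed
log_round_metrics(server_round=2, accuracy=0.4556, loss=1.6017)
```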

### Run with the Simulation Engine

With default parameters, 20% of the total 50 nodes (see `num-supernodes` in `pyproject.toml`) will be sampled in each round. By default `ClientApp` objects will run on CPU.

> \[!TIP\]
> To run your `ClientApp`s on GPU or to adjust the degree of parallelism of your simulation, edit the `[tool.flwr.federations.local-sim]` section in `pyproject.toml`.

```bash
flwr run .

# To disable W&B
flwr run . --run-config use-wandb=false
```

You can run the app using another federation (see `pyproject.toml`). For example, if you have a GPU available, select the `local-sim-gpu` federation:

```bash
flwr run . local-sim-gpu
```

You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example:

```bash
flwr run . --run-config "num-server-rounds=5 fraction-clients-train=0.5"
```
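
These `--run-config` overrides reach your code through `Context.run_config`. As a minimal sketch (the function and variable names are illustrative, not part of this example):

```python
from flwr.common import Context


def read_run_config(context: Context) -> tuple[int, float]:
    """Read values defined under [tool.flwr.app.config] or via --run-config."""
    num_rounds = context.run_config["num-server-rounds"]  # 10 unless overridden
    fraction_train = context.run_config["fraction-clients-train"]  # 0.2 unless overridden
    return num_rounds, fraction_train
```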

### Run with the Deployment Engine

> \[!NOTE\]
> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.
(Two binary files in this diff, the `_static/` images referenced in the README, cannot be displayed.)
48 changes: 48 additions & 0 deletions examples/advanced-pytorch-low-level/pyproject.toml
@@ -0,0 +1,48 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "pytorch-example-low-level"
version = "1.0.0"
description = "Federated Learning with PyTorch and Flower (Advanced Example with Low-Level API)"
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.11.1",
"flwr-datasets[vision]>=0.3.0",
"torch==2.2.1",
"torchvision==0.17.1",
"wandb==0.17.8",
]

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.flwr.app]
publisher = "flwrlabs"

[tool.flwr.app.components]
serverapp = "pytorch_example_low_level.server_app:app"
clientapp = "pytorch_example_low_level.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 10
fraction-clients-train = 0.2
batch-size = 32
local-epochs = 1
server-device = "cpu"
use-wandb = true

[tool.flwr.federations]
default = "local-sim"

[tool.flwr.federations.local-sim]
options.num-supernodes = 50
options.backend.client-resources.num-cpus = 2 # each ClientApp is assumed to use 2 CPUs
options.backend.client-resources.num-gpus = 0.0 # ratio of a GPU's VRAM each ClientApp has access to


[tool.flwr.federations.local-sim-gpu]
options.num-supernodes = 50
options.backend.client-resources.num-cpus = 2
options.backend.client-resources.num-gpus = 0.25
1 change: 1 addition & 0 deletions examples/advanced-pytorch-low-level/pytorch_example_low_level/__init__.py
@@ -0,0 +1 @@
"""pytorch-example-low-level: A low-level Flower / PyTorch app."""
127 changes: 127 additions & 0 deletions examples/advanced-pytorch-low-level/pytorch_example_low_level/client_app.py
@@ -0,0 +1,127 @@
"""pytorch-example-low-level: A low-level Flower / PyTorch app."""

import random
import time

import torch
from pytorch_example_low_level.task import Net, load_data, test, train
from pytorch_example_low_level.utils import (
parameters_record_to_state_dict,
state_dict_to_parameters_record,
)

from flwr.client import ClientApp
from flwr.common import ConfigsRecord, Context, Message, MetricsRecord, RecordSet

# Flower ClientApp
app = ClientApp()


@app.train()
def train_fn(msg: Message, context: Context):
"""A method that trains the received model on the local train set."""

# Initialize model
model = Net()
# Dynamically determine device (best for simulations)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load this `ClientApp`'s dataset
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
trainloader, _ = load_data(partition_id, num_partitions)

# Extract model received from `ServerApp`
p_record = msg.content.parameters_records["global_model_record"]
state_dict = parameters_record_to_state_dict(p_record)

# apply to local PyTorch model
model.load_state_dict(state_dict)

# Get learning rate value sent from `ServerApp`
lr = msg.content.configs_records["config"]["lr"]
# Train with local dataset
_ = train(
model,
trainloader,
context.run_config["local-epochs"],
lr=lr,
device=device,
)

# Put resulting model into a ParametersRecord
p_record = state_dict_to_parameters_record(model.state_dict())

# Send reply back to `ServerApp`
reply_content = RecordSet()
reply_content.parameters_records["updated_model_dict"] = p_record
# Return message
return msg.create_reply(reply_content)


@app.evaluate()
def eval_fn(msg: Message, context: Context):
"""A method that evaluates the received model on the local validation set."""

# Initialize model
model = Net()
# Dynamically determine device (best for simulations)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load this `ClientApp`'s dataset
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
_, evalloader = load_data(partition_id, num_partitions)

# Extract model received from `ServerApp`
p_record = msg.content.parameters_records["global_model_record"]
state_dict = parameters_record_to_state_dict(p_record)

# apply to local PyTorch model
model.load_state_dict(state_dict)

# Evaluate with local dataset
loss, accuracy = test(
model,
evalloader,
device=device,
)

# Put resulting metrics into a MetricsRecord
m_record = MetricsRecord({"loss": loss, "accuracy": accuracy})

# Send reply back to `ServerApp`
reply_content = RecordSet()
reply_content.metrics_records["clientapp-evaluate"] = m_record
# Return message
return msg.create_reply(reply_content)


@app.query()
def query(msg: Message, context: Context):
"""A basic query method that aims to exemplify some opt-in functionality.

The node running this `ClientApp` reacts to an incomming message by returning
a `True` or a `False`. If `True`, this node will be sampled by the `ServerApp`
to receive the global model and do evaluation in its `@app.eval()` method.
"""

# Inspect message
c_record = msg.content.configs_records["query-config"]
# print(f"Received: {c_record = }")

# Sleep for a random amount of time, will result in some nodes not
# repling back to the `ServerApp` in time
time.sleep(random.randint(0, 2))

# Randomly set True or False as opt-in in the evaluation stage
# Note the keys used for the records below are arbitrary, but both `ServerApp`
# and `ClientApp` need to be aware of them.
c_record_response = ConfigsRecord(
{"opt-in": random.random() > 0.5, "ts": time.time()}
)
reply_content = RecordSet(configs_records={"query-response": c_record_response})

return msg.create_reply(content=reply_content)
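
The serialization helpers imported from `utils.py` are not shown in this diff. As a minimal sketch of what they might look like, assuming `flwr.common.array_from_numpy` and `Array.numpy()` as available in Flower 1.11 (these implementations are illustrative, not the PR's actual code):

```python
from collections import OrderedDict

import torch
from flwr.common import ParametersRecord, array_from_numpy


def state_dict_to_parameters_record(state_dict):
    """Serialize a PyTorch state_dict into a Flower ParametersRecord."""
    return ParametersRecord(
        OrderedDict(
            (k, array_from_numpy(v.detach().cpu().numpy()))
            for k, v in state_dict.items()
        )
    )


def parameters_record_to_state_dict(p_record):
    """Reconstruct a PyTorch state_dict from a Flower ParametersRecord."""
    return OrderedDict(
        (k, torch.from_numpy(arr.numpy())) for k, arr in p_record.items()
    )
```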