Migrate advanced-pytorch example to use FDS #2805

Merged
merged 6 commits into from
Jan 17, 2024
9 changes: 5 additions & 4 deletions examples/advanced-pytorch/README.md
@@ -1,6 +1,6 @@
 # Advanced Flower Example (PyTorch)

-This example demonstrates an advanced federated learning setup using Flower with PyTorch. It differs from the quickstart example in the following ways:
+This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways:

 - 10 clients (instead of just 2)
 - Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set).
@@ -59,12 +59,13 @@ pip install -r requirements.txt

 The included `run.sh` will start the Flower server (using `server.py`),
 sleep for 2 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`) with only a small subset of the data (in order to run on any machine),
-but this can be changed by removing the `--toy True` argument in the script. You can simply start everything in a terminal as follows:
+but this can be changed by removing the `--toy` argument in the script. You can simply start everything in a terminal as follows:

 ```shell
-poetry run ./run.sh
+# After activating your environment
+./run.sh
 ```

 The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).

-You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network).
+You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
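
The per-client dataset sizes quoted in the README can be checked against the split ratios this PR uses in `utils.py` and `client.py`. The sketch below is plain arithmetic, not Flower code. It assumes CIFAR-10's 50,000 training images are divided evenly across the 10 partitions, that `load_partition` holds out 20% of each partition for testing, and that `fit` holds out a further 10% for validation. `split_sizes` is a hypothetical helper, and rounding the held-out part up is an assumption about how the split behaves.

```python
import math
from typing import Tuple

CIFAR10_TRAIN_SIZE = 50_000  # size of the CIFAR-10 training split
NUM_PARTITIONS = 10          # partitioners={"train": 10} in utils.py


def split_sizes(n: int, held_out_percent: int) -> Tuple[int, int]:
    """Return (kept, held_out) sizes for a percentage split.

    Assumption: the held-out part is rounded up and the rest is kept.
    """
    held_out = math.ceil(n * held_out_percent / 100)
    return n - held_out, held_out


per_client = CIFAR10_TRAIN_SIZE // NUM_PARTITIONS
train, test = split_sizes(per_client, 20)    # test_size=0.2 in load_partition
fit_train, fit_val = split_sizes(train, 10)  # validation_split=0.1 in fit

print(per_client, train, test, fit_train, fit_val)  # 5000 4000 1000 3600 400
```

Under these assumptions each client would hold 4,000 training and 1,000 test examples (3,600 and 400 after the validation holdout inside `fit`), which may be worth comparing with the "5000 training examples" figure the README still quotes.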
41 changes: 18 additions & 23 deletions examples/advanced-pytorch/client.py
@@ -6,16 +6,17 @@
 import argparse
 from collections import OrderedDict
 import warnings
+import datasets

 warnings.filterwarnings("ignore")


 class CifarClient(fl.client.NumPyClient):
     def __init__(
         self,
-        trainset: torchvision.datasets,
-        testset: torchvision.datasets,
-        device: str,
+        trainset: datasets.Dataset,
+        testset: datasets.Dataset,
+        device: torch.device,
         validation_split: int = 0.1,
     ):
         self.device = device
@@ -41,17 +42,14 @@ def fit(self, parameters, config):
         batch_size: int = config["batch_size"]
         epochs: int = config["local_epochs"]

-        n_valset = int(len(self.trainset) * self.validation_split)
+        train_valid = self.trainset.train_test_split(self.validation_split)
+        trainset = train_valid["train"]
+        valset = train_valid["test"]

-        valset = torch.utils.data.Subset(self.trainset, range(0, n_valset))
-        trainset = torch.utils.data.Subset(
-            self.trainset, range(n_valset, len(self.trainset))
-        )
+        train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
+        val_loader = DataLoader(valset, batch_size=batch_size)

-        trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
-        valLoader = DataLoader(valset, batch_size=batch_size)
-
-        results = utils.train(model, trainLoader, valLoader, epochs, self.device)
+        results = utils.train(model, train_loader, val_loader, epochs, self.device)

         parameters_prime = utils.get_model_params(model)
         num_examples_train = len(trainset)
@@ -73,13 +71,13 @@ def evaluate(self, parameters, config):
         return float(loss), len(self.testset), {"accuracy": float(accuracy)}


-def client_dry_run(device: str = "cpu"):
+def client_dry_run(device: torch.device = "cpu"):
     """Weak tests to check whether all client methods are working as expected."""

     model = utils.load_efficientnet(classes=10)
     trainset, testset = utils.load_partition(0)
-    trainset = torch.utils.data.Subset(trainset, range(10))
-    testset = torch.utils.data.Subset(testset, range(10))
+    trainset = trainset.select(range(10))
+    testset = testset.select(range(10))
     client = CifarClient(trainset, testset, device)
     client.fit(
         utils.get_model_params(model),
@@ -102,7 +100,7 @@ def main() -> None:
         help="Do a dry-run to check the client",
     )
     parser.add_argument(
-        "--partition",
+        "--client-id",
         type=int,
         default=0,
         choices=range(0, 10),
@@ -112,9 +110,7 @@ def main() -> None:
     )
     parser.add_argument(
         "--toy",
-        type=bool,
-        default=False,
-        required=False,
+        action='store_true',
         help="Set to true to quickly run the client using only 10 data samples. \
         Useful for testing purposes. Default: False",
     )
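
The switch from `type=bool` to `action='store_true'` matters because `argparse` hands the raw command-line string to `type`, and `bool()` of any non-empty string is `True`. The sketch below uses only the standard library to show the difference; the two parsers are illustrative stand-ins, not the parser defined in `client.py`.

```python
import argparse

# Old style: the string after --toy is passed to bool(), so "False" is truthy.
old_parser = argparse.ArgumentParser()
old_parser.add_argument("--toy", type=bool, default=False, required=False)

# New style: --toy is a flag; present means True, absent means False.
new_parser = argparse.ArgumentParser()
new_parser.add_argument("--toy", action="store_true")

old_false = old_parser.parse_args(["--toy", "False"]).toy
new_absent = new_parser.parse_args([]).toy
new_present = new_parser.parse_args(["--toy"]).toy

print(old_false, new_absent, new_present)  # True False True
```

With the old definition, `--toy False` still enabled toy mode, so the only way to disable it was to drop the argument entirely, which is the behaviour the flag form makes explicit.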
@@ -136,12 +132,11 @@ def main() -> None:
         client_dry_run(device)
     else:
         # Load a subset of CIFAR-10 to simulate the local data partition
-        trainset, testset = utils.load_partition(args.partition)
+        trainset, testset = utils.load_partition(args.client_id)

         if args.toy:
-            trainset = torch.utils.data.Subset(trainset, range(10))
-            testset = torch.utils.data.Subset(testset, range(10))
-
+            trainset = trainset.select(range(10))
+            testset = testset.select(range(10))
         # Start Flower client
         client = CifarClient(trainset, testset, device)

1 change: 1 addition & 0 deletions examples/advanced-pytorch/pyproject.toml
@@ -14,6 +14,7 @@ authors = [
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
 flwr = ">=1.0,<2.0"
+flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
 torch = "1.13.1"
 torchvision = "0.14.1"
 validators = "0.18.2"
1 change: 1 addition & 0 deletions examples/advanced-pytorch/requirements.txt
@@ -1,4 +1,5 @@
 flwr>=1.0, <2.0
+flwr-datasets[vision]>=0.0.2, <1.0.0
 torch==1.13.1
 torchvision==0.14.1
 validators==0.18.2
9 changes: 3 additions & 6 deletions examples/advanced-pytorch/run.sh
@@ -2,20 +2,17 @@
 set -e
 cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

-# Download the CIFAR-10 dataset
-python -c "from torchvision.datasets import CIFAR10; CIFAR10('./dataset', download=True)"
-
 # Download the EfficientNetB0 model
 python -c "import torch; torch.hub.load( \
     'NVIDIA/DeepLearningExamples:torchhub', \
     'nvidia_efficientnet_b0', pretrained=True)"

-python server.py &
-sleep 3 # Sleep for 3s to give the server enough time to start
+python server.py --toy &
+sleep 10 # Sleep for 10s to give the server enough time to start and download the dataset

 for i in `seq 0 9`; do
     echo "Starting client $i"
-    python client.py --partition=${i} --toy True &
+    python client.py --client-id=${i} --toy &
 done

 # Enable CTRL+C to stop all background processes
21 changes: 8 additions & 13 deletions examples/advanced-pytorch/server.py
@@ -10,6 +10,8 @@

 import warnings

+from flwr_datasets import FederatedDataset
+
 warnings.filterwarnings("ignore")


@@ -39,18 +41,13 @@ def evaluate_config(server_round: int):
 def get_evaluate_fn(model: torch.nn.Module, toy: bool):
     """Return an evaluation function for server-side evaluation."""

-    # Load data and model here to avoid the overhead of doing it in `evaluate` itself
-    trainset, _, _ = utils.load_data()
-
-    n_train = len(trainset)
+    # Load data here to avoid the overhead of doing it in `evaluate` itself
+    centralized_data = utils.load_centralized_data()
     if toy:
         # use only 10 samples as validation set
-        valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train))
-    else:
-        # Use the last 5k training examples as a validation set
-        valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train))
+        centralized_data = centralized_data.select(range(10))

-    valLoader = DataLoader(valset, batch_size=16)
+    val_loader = DataLoader(centralized_data, batch_size=16)

     # The `evaluate` function will be called after every round
     def evaluate(
@@ -63,7 +60,7 @@ def evaluate(
         state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
         model.load_state_dict(state_dict, strict=True)

-        loss, accuracy = utils.test(model, valLoader)
+        loss, accuracy = utils.test(model, val_loader)
         return loss, {"accuracy": accuracy}

     return evaluate
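
`get_evaluate_fn` loads the data once and returns a closure, so the per-round `evaluate` call reuses the already-loaded data instead of reloading it. A dependency-free sketch of that pattern follows; `load_rows`, the `load_calls` counter, and the placeholder metrics are all hypothetical, and the `evaluate` signature is simplified compared with the one in `server.py`.

```python
from typing import Callable, Dict, List, Tuple

load_calls = 0  # counts how often the (hypothetical) expensive load runs


def load_rows() -> List[int]:
    """Hypothetical stand-in for an expensive dataset load."""
    global load_calls
    load_calls += 1
    return list(range(100))


def get_evaluate_fn(toy: bool) -> Callable[[int], Tuple[float, Dict[str, float]]]:
    """Load data once, then return a closure that reuses it every round."""
    rows = load_rows()
    if toy:
        rows = rows[:10]  # keep only 10 samples, similar to `.select(range(10))`

    def evaluate(server_round: int) -> Tuple[float, Dict[str, float]]:
        # Placeholder metrics: no model is involved in this sketch.
        loss = 1.0 / (server_round + 1)
        return loss, {"num_samples": float(len(rows))}

    return evaluate


evaluate = get_evaluate_fn(toy=True)
results = [evaluate(server_round) for server_round in range(3)]
print(load_calls, results[0])  # 1 (1.0, {'num_samples': 10.0})
```

The load counter stays at 1 across three rounds, which is the overhead the comment in `get_evaluate_fn` refers to avoiding.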
@@ -79,9 +76,7 @@ def main():
     parser = argparse.ArgumentParser(description="Flower")
     parser.add_argument(
         "--toy",
-        type=bool,
-        default=False,
-        required=False,
+        action='store_true',
         help="Set to true to use only 10 data samples for validation. \
         Useful for testing purposes. Default: False",
     )
77 changes: 39 additions & 38 deletions examples/advanced-pytorch/utils.py
@@ -1,49 +1,45 @@
 import torch
-import torchvision.transforms as transforms
-from torchvision.datasets import CIFAR10
+from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
+from torch.utils.data import DataLoader

 import warnings

-warnings.filterwarnings("ignore")
-
-# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+from flwr_datasets import FederatedDataset

+warnings.filterwarnings("ignore")

-def load_data():
-    """Load CIFAR-10 (training and test set)."""
-    transform = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ]
-    )

-    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
-    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
+def load_partition(node_id, toy: bool = False):
+    """Load partition CIFAR10 data."""
+    fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
+    partition = fds.load_partition(node_id)
+    # Divide data on each node: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2)
+    partition_train_test = partition_train_test.with_transform(apply_transforms)
+    return partition_train_test["train"], partition_train_test["test"]

-    num_examples = {"trainset": len(trainset), "testset": len(testset)}
-    return trainset, testset, num_examples

+def load_centralized_data():
+    fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
+    centralized_data = fds.load_full("test")
+    centralized_data = centralized_data.with_transform(apply_transforms)
+    return centralized_data

-def load_partition(idx: int):
-    """Load 1/10th of the training and test data to simulate a partition."""
-    assert idx in range(10)
-    trainset, testset, num_examples = load_data()
-    n_train = int(num_examples["trainset"] / 10)
-    n_test = int(num_examples["testset"] / 10)

-    train_parition = torch.utils.data.Subset(
-        trainset, range(idx * n_train, (idx + 1) * n_train)
-    )
-    test_parition = torch.utils.data.Subset(
-        testset, range(idx * n_test, (idx + 1) * n_test)
-    )
-    return (train_parition, test_parition)
+def apply_transforms(batch):
+    """Apply transforms to the partition from FederatedDataset."""
+    pytorch_transforms = Compose([
+        Resize(256),
+        CenterCrop(224),
+        ToTensor(),
+        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+    batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+    return batch


-def train(net, trainloader, valloader, epochs, device: str = "cpu"):
+def train(net, trainloader, valloader, epochs,
+          device: torch.device = torch.device("cpu")):
     """Train the network on the training set."""
     print("Starting training...")
     net.to(device)  # move model to GPU if available
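
The `apply_transforms` function added in this hunk is written against a batch laid out as a dict of columns, where `batch["img"]` is a list of images, and it returns the same structure with the images replaced. A minimal sketch of that contract, using a plain dict and a hypothetical `fake_transform` in place of the torchvision pipeline so that it runs without any dependencies:

```python
from typing import Any, Dict, List


def fake_transform(img: List[int]) -> List[float]:
    """Hypothetical per-image transform: scale 0-255 pixel values to 0-1."""
    return [pixel / 255 for pixel in img]


def apply_transforms(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Transform every image in the batch; leave other columns untouched."""
    batch["img"] = [fake_transform(img) for img in batch["img"]]
    return batch


# Two tiny "images" with two pixels each, plus their labels.
batch = {"img": [[0, 255], [51, 102]], "label": [3, 7]}
out = apply_transforms(batch)
print(out["img"], out["label"])
```

Because the function keeps the `"img"` and `"label"` keys, the training and test loops can read `batch["img"]` and `batch["label"]` directly, as the later hunks in this file do.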
@@ -53,7 +49,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"):
     )
     net.train()
     for _ in range(epochs):
-        for images, labels in trainloader:
+        for batch in trainloader:
+            images, labels = batch["img"], batch["label"]
             images, labels = images.to(device), labels.to(device)
             optimizer.zero_grad()
             loss = criterion(net(images), labels)
@@ -74,15 +71,17 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"):
     return results


-def test(net, testloader, steps: int = None, device: str = "cpu"):
+def test(net, testloader, steps: int = None,
+         device: torch.device = torch.device("cpu")):
     """Validate the network on the entire test set."""
     print("Starting evaluation...")
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss()
     correct, loss = 0, 0.0
     net.eval()
     with torch.no_grad():
-        for batch_idx, (images, labels) in enumerate(testloader):
+        for batch_idx, batch in enumerate(testloader):
+            images, labels = batch["img"], batch["label"]
             images, labels = images.to(device), labels.to(device)
             outputs = net(images)
             loss += criterion(outputs, labels).item()
@@ -109,12 +108,14 @@ def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int =
         entrypoint: EfficientNet model to download.
                     For supported entrypoints, please refer
                     https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/
-        classes: Number of classes in final classifying layer. Leave as None to get the downloaded
+        classes: Number of classes in final classifying layer. Leave as None to get
+                 the downloaded
                  model untouched.
     Returns:
         EfficientNet Model

-    Note: One alternative implementation can be found at https://github.com/lukemelas/EfficientNet-PyTorch
+    Note: One alternative implementation can be found at
+    https://github.com/lukemelas/EfficientNet-PyTorch
     """
     efficientnet = torch.hub.load(
         "NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True