diff --git a/examples/fl-dp-sa/README.md b/examples/fl-dp-sa/README.md
new file mode 100644
index 000000000000..99a0a7e50980
--- /dev/null
+++ b/examples/fl-dp-sa/README.md
@@ -0,0 +1,22 @@
+# fl_dp_sa
+
+This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation.
+Note: This example is designed for a small number of rounds and is intended for demonstration purposes.
+
+## Install dependencies
+
+```bash
+# Using pip
+pip install .
+
+# Or using Poetry
+poetry install
+```
+
+## Run
+
+The example uses the MNIST dataset with a total of 100 clients, with 20 clients sampled in each round. The hyperparameters for DP and SecAgg are specified in `server.py`.
+
+```shell
+flower-simulation --server-app fl_dp_sa.server:app --client-app fl_dp_sa.client:app --num-supernodes 100
+```
diff --git a/examples/fl-dp-sa/fl_dp_sa/__init__.py b/examples/fl-dp-sa/fl_dp_sa/__init__.py
new file mode 100644
index 000000000000..741260348ab8
--- /dev/null
+++ b/examples/fl-dp-sa/fl_dp_sa/__init__.py
@@ -0,0 +1 @@
+"""fl_dp_sa: A Flower / PyTorch app."""
diff --git a/examples/fl-dp-sa/fl_dp_sa/client.py b/examples/fl-dp-sa/fl_dp_sa/client.py
new file mode 100644
index 000000000000..104264158833
--- /dev/null
+++ b/examples/fl-dp-sa/fl_dp_sa/client.py
@@ -0,0 +1,43 @@
+"""fl_dp_sa: A Flower / PyTorch app."""
+
+from flwr.client import ClientApp, NumPyClient
+from flwr.client.mod import fixedclipping_mod, secaggplus_mod
+
+from fl_dp_sa.task import DEVICE, Net, get_weights, load_data, set_weights, test, train
+
+
+# Load model and data (simple CNN, MNIST)
+net = Net().to(DEVICE)
+
+
+# Define FlowerClient and client_fn
+class FlowerClient(NumPyClient):
+    def __init__(self, trainloader, testloader) -> None:
+        self.trainloader = trainloader
+        self.testloader = testloader
+
+    def fit(self, parameters, config):
+        set_weights(net, parameters)
+        results = train(net, self.trainloader, self.testloader, epochs=1, device=DEVICE)
+        return get_weights(net), len(self.trainloader.dataset), results
+
+    def evaluate(self, parameters, config):
+        set_weights(net, parameters)
+        loss, accuracy = test(net, self.testloader)
+        return loss, len(self.testloader.dataset), {"accuracy": accuracy}
+
+
+def client_fn(cid: str):
+    """Create and return an instance of Flower `Client`."""
+    trainloader, testloader = load_data(partition_id=int(cid))
+    return FlowerClient(trainloader, testloader).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(
+    client_fn=client_fn,
+    mods=[
+        secaggplus_mod,
+        fixedclipping_mod,
+    ],
+)
diff --git a/examples/fl-dp-sa/fl_dp_sa/server.py b/examples/fl-dp-sa/fl_dp_sa/server.py
new file mode 100644
index 000000000000..f7da75997e98
--- /dev/null
+++ b/examples/fl-dp-sa/fl_dp_sa/server.py
@@ -0,0 +1,77 @@
+"""fl_dp_sa: A Flower / PyTorch app."""
+
+from typing import List, Tuple
+
+from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.server.strategy import (
+    DifferentialPrivacyClientSideFixedClipping,
+    FedAvg,
+)
+from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow
+
+from fl_dp_sa.task import Net, get_weights
+
+
+# Define metric aggregation function
+def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
+    examples = [num_examples for num_examples, _ in metrics]
+
+    # Multiply accuracy of each client by number of examples used
+    train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
+    train_accuracies = [
+        num_examples * m["train_accuracy"] for num_examples, m in metrics
+    ]
+    val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
+    val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]
+
+    # Aggregate and return custom metric (weighted average)
+    return {
+        "train_loss": sum(train_losses) / sum(examples),
+        "train_accuracy": sum(train_accuracies) / sum(examples),
+        "val_loss": sum(val_losses) / sum(examples),
+        "val_accuracy": sum(val_accuracies) / sum(examples),
+    }
+
+
+# Initialize model parameters
+ndarrays = get_weights(Net())
+parameters = ndarrays_to_parameters(ndarrays)
+
+
+# Define strategy
+strategy = FedAvg(
+    fraction_fit=0.2,
+    fraction_evaluate=0.0,  # Disable evaluation for demo purpose
+    min_fit_clients=20,
+    min_available_clients=20,
+    fit_metrics_aggregation_fn=weighted_average,
+    initial_parameters=parameters,
+)
+strategy = DifferentialPrivacyClientSideFixedClipping(
+    strategy, noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20
+)
+
+
+app = ServerApp()
+
+
+@app.main()
+def main(driver: Driver, context: Context) -> None:
+    # Construct the LegacyContext
+    context = LegacyContext(
+        state=context.state,
+        config=ServerConfig(num_rounds=3),
+        strategy=strategy,
+    )
+
+    # Create the train/evaluate workflow
+    workflow = DefaultWorkflow(
+        fit_workflow=SecAggPlusWorkflow(
+            num_shares=7,
+            reconstruction_threshold=4,
+        )
+    )
+
+    # Execute
+    workflow(driver, context)
diff --git a/examples/fl-dp-sa/fl_dp_sa/task.py b/examples/fl-dp-sa/fl_dp_sa/task.py
new file mode 100644
index 000000000000..3d506263d5a3
--- /dev/null
+++ b/examples/fl-dp-sa/fl_dp_sa/task.py
@@ -0,0 +1,110 @@
+"""fl_dp_sa: A Flower / PyTorch app."""
+
+from collections import OrderedDict
+from logging import INFO
+from flwr_datasets import FederatedDataset
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flwr.common.logger import log
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose, Normalize, ToTensor
+
+
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class Net(nn.Module):
+    """Model."""
+
+    def __init__(self) -> None:
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size = x.size(0)
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(batch_size, -1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+
+def load_data(partition_id):
+    """Load partition MNIST data."""
+    fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
+    partition = fds.load_partition(partition_id)
+    # Divide data on each node: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2)
+    pytorch_transforms = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
+
+    def apply_transforms(batch):
+        """Apply transforms to the partition from FederatedDataset."""
+        batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
+        return batch
+
+    partition_train_test = partition_train_test.with_transform(apply_transforms)
+    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
+    testloader = DataLoader(partition_train_test["test"], batch_size=32)
+    return trainloader, testloader
+
+
+def train(net, trainloader, valloader, epochs, device):
+    """Train the model on the training set."""
+    net.to(device)  # move model to GPU if available
+    criterion = torch.nn.CrossEntropyLoss().to(device)
+    optimizer = torch.optim.Adam(net.parameters())
+    net.train()
+    for _ in range(epochs):
+        for batch in trainloader:
+            images = batch["image"].to(device)
+            labels = batch["label"].to(device)
+            optimizer.zero_grad()
+            loss = criterion(net(images), labels)
+            loss.backward()
+            optimizer.step()
+
+    train_loss, train_acc = test(net, trainloader)
+    val_loss, val_acc = test(net, valloader)
+
+    results = {
+        "train_loss": train_loss,
+        "train_accuracy": train_acc,
+        "val_loss": val_loss,
+        "val_accuracy": val_acc,
+    }
+    return results
+
+
+def test(net, testloader):
+    """Validate the model on the test set."""
+    net.to(DEVICE)
+    criterion = torch.nn.CrossEntropyLoss()
+    correct, loss = 0, 0.0
+    with torch.no_grad():
+        for batch in testloader:
+            images = batch["image"].to(DEVICE)
+            labels = batch["label"].to(DEVICE)
+            outputs = net(images.to(DEVICE))
+            labels = labels.to(DEVICE)
+            loss += criterion(outputs, labels).item()
+            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
+    accuracy = correct / len(testloader.dataset)
+    return loss, accuracy
+
+
+def get_weights(net):
+    return [val.cpu().numpy() for _, val in net.state_dict().items()]
+
+
+def set_weights(net, parameters):
+    params_dict = zip(net.state_dict().keys(), parameters)
+    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+    net.load_state_dict(state_dict, strict=True)
diff --git a/examples/fl-dp-sa/flower.toml b/examples/fl-dp-sa/flower.toml
new file mode 100644
index 000000000000..ea2e98206791
--- /dev/null
+++ b/examples/fl-dp-sa/flower.toml
@@ -0,0 +1,13 @@
+[project]
+name = "fl_dp_sa"
+version = "1.0.0"
+description = ""
+license = "Apache-2.0"
+authors = [
+    "The Flower Authors <hello@flower.ai>",
+]
+readme = "README.md"
+
+[flower.components]
+serverapp = "fl_dp_sa.server:app"
+clientapp = "fl_dp_sa.client:app"
diff --git a/examples/fl-dp-sa/pyproject.toml b/examples/fl-dp-sa/pyproject.toml
new file mode 100644
index 000000000000..d30fa4675e34
--- /dev/null
+++ b/examples/fl-dp-sa/pyproject.toml
@@ -0,0 +1,21 @@
+[build-system]
+requires = ["poetry-core>=1.4.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "fl-dp-sa"
+version = "0.1.0"
+description = ""
+license = "Apache-2.0"
+authors = [
+    "The Flower Authors <hello@flower.ai>",
+]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.9"
+# Mandatory dependencies
+flwr-nightly = { version = "1.8.0.dev20240313", extras = ["simulation"] }
+flwr-datasets = { version = "0.0.2", extras = ["vision"] }
+torch = "2.2.1"
+torchvision = "0.17.1"
diff --git a/examples/fl-dp-sa/requirements.txt b/examples/fl-dp-sa/requirements.txt
new file mode 100644
index 000000000000..ddb8a814447b
--- /dev/null
+++ b/examples/fl-dp-sa/requirements.txt
@@ -0,0 +1,4 @@
+flwr-nightly[simulation]==1.8.0.dev20240313
+flwr-datasets[vision]==0.0.2
+torch==2.2.1
+torchvision==0.17.1