Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dp secagg demo example #3134

Merged
merged 39 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7ea2a4f
Add dp secagg demo example
mohammadnaseri Mar 12, 2024
3f32370
add sa
panh99 Mar 12, 2024
31165c2
Merge branch 'main' into add-dp-secagg-demo-example
mohammadnaseri Mar 13, 2024
1f17627
Add clientside fixed clipping
mohammadnaseri Mar 13, 2024
8674db0
Fix
mohammadnaseri Mar 13, 2024
c38fc90
update
panh99 Mar 13, 2024
55739f8
fix a bug causing data loading failure
panh99 Mar 13, 2024
17be0e0
Merge branch 'main' into add-dp-secagg-demo-example
panh99 Mar 13, 2024
0d3d023
update configs
panh99 Mar 13, 2024
503212f
Create README.md
mohammadnaseri Mar 13, 2024
13acf88
Update README.md
mohammadnaseri Mar 13, 2024
6441259
Merge branch 'main' into add-dp-secagg-demo-example
danieljanes Mar 13, 2024
9d1eadf
Update
mohammadnaseri Mar 13, 2024
62a2348
Update README.md
mohammadnaseri Mar 13, 2024
6986158
Add logging to central dp wrappers
mohammadnaseri Mar 13, 2024
ddd3f47
Update
mohammadnaseri Mar 13, 2024
f3b122b
Update
mohammadnaseri Mar 13, 2024
5af8720
Update
mohammadnaseri Mar 13, 2024
6369c8e
Update
mohammadnaseri Mar 13, 2024
b0fa420
Update
mohammadnaseri Mar 13, 2024
9200754
Update
mohammadnaseri Mar 13, 2024
4dfdbd1
Update
mohammadnaseri Mar 13, 2024
50b785d
Update
mohammadnaseri Mar 13, 2024
578d6b5
Update README.md
mohammadnaseri Mar 13, 2024
ef1e99e
Merge branch 'main' into add-dp-secagg-demo-example
mohammadnaseri Mar 13, 2024
b9273c7
use MNIST and Adam
panh99 Mar 14, 2024
c012814
Merge branch 'main' into add-dp-secagg-demo-example
mohammadnaseri Mar 14, 2024
ce3a774
Merge branch 'main' into add-dp-secagg-demo-example
danieljanes Mar 14, 2024
abe665b
Change noise multiplier
mohammadnaseri Mar 14, 2024
53f52f3
Organize imports, update comments
danieljanes Mar 14, 2024
e843038
Merge branch 'add-dp-secagg-demo-example' of github.com:adap/flower i…
danieljanes Mar 14, 2024
23ec20e
Move code into subdir
danieljanes Mar 14, 2024
77f21ed
Update imports
danieljanes Mar 14, 2024
2c52ac7
Add flower.toml/pyproject.toml/requirements.txt
danieljanes Mar 14, 2024
7c3d1b3
Update README
danieljanes Mar 14, 2024
ac31041
Rename example
danieljanes Mar 14, 2024
3e1a121
Update pyproject.toml
danieljanes Mar 14, 2024
3ffe3f5
Remove log
danieljanes Mar 15, 2024
4e2da5e
Merge branch 'main' into add-dp-secagg-demo-example
danieljanes Mar 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/fl-dp-sa/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# fl_dp_sa

This is a simple example that utilizes central differential privacy with client-side fixed clipping and secure aggregation.
Note: This example is designed for a small number of rounds and is intended for demonstration purposes.

## Install dependencies

```bash
# Using pip
pip install .

# Or using Poetry
poetry install
```

## Run

The example uses the CIFAR-10 dataset with a total of 100 clients, with 20 clients sampled in each round. The hyperparameters for DP and SecAgg are specified in `server.py`.

```shell
flower-simulation --server-app fl_dp_sa.server:app --client-app fl_dp_sa.client:app --num-supernodes 100
```
1 change: 1 addition & 0 deletions examples/fl-dp-sa/fl_dp_sa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""fl_dp_sa: A Flower / PyTorch app."""
43 changes: 43 additions & 0 deletions examples/fl-dp-sa/fl_dp_sa/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""fl_dp_sa: A Flower / PyTorch app."""

from flwr.client import ClientApp, NumPyClient
from flwr.client.mod import fixedclipping_mod, secaggplus_mod

from fl_dp_sa.task import DEVICE, Net, get_weights, load_data, set_weights, test, train


# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)


# Define FlowerClient and client_fn
class FlowerClient(NumPyClient):
def __init__(self, trainloader, testloader) -> None:
self.trainloader = trainloader
self.testloader = testloader

def fit(self, parameters, config):
set_weights(net, parameters)
results = train(net, self.trainloader, self.testloader, epochs=1, device=DEVICE)
return get_weights(net), len(self.trainloader.dataset), results

def evaluate(self, parameters, config):
set_weights(net, parameters)
loss, accuracy = test(net, self.testloader)
return loss, len(self.testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
"""Create and return an instance of Flower `Client`."""
trainloader, testloader = load_data(partition_id=int(cid))
return FlowerClient(trainloader, testloader).to_client()


# Flower ClientApp
app = ClientApp(
client_fn=client_fn,
mods=[
secaggplus_mod,
fixedclipping_mod,
],
)
77 changes: 77 additions & 0 deletions examples/fl-dp-sa/fl_dp_sa/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""fl_dp_sa: A Flower / PyTorch app."""

from typing import List, Tuple

from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.server.strategy import (
DifferentialPrivacyClientSideFixedClipping,
FedAvg,
)
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

from fl_dp_sa.task import Net, get_weights


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
examples = [num_examples for num_examples, _ in metrics]

# Multiply accuracy of each client by number of examples used
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
train_accuracies = [
num_examples * m["train_accuracy"] for num_examples, m in metrics
]
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

# Aggregate and return custom metric (weighted average)
return {
"train_loss": sum(train_losses) / sum(examples),
"train_accuracy": sum(train_accuracies) / sum(examples),
"val_loss": sum(val_losses) / sum(examples),
"val_accuracy": sum(val_accuracies) / sum(examples),
}


# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)


# Define strategy
strategy = FedAvg(
fraction_fit=0.2,
fraction_evaluate=0.0, # Disable evaluation for demo purpose
min_fit_clients=20,
min_available_clients=20,
fit_metrics_aggregation_fn=weighted_average,
initial_parameters=parameters,
)
strategy = DifferentialPrivacyClientSideFixedClipping(
strategy, noise_multiplier=0.2, clipping_norm=10, num_sampled_clients=20
)


app = ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
# Construct the LegacyContext
context = LegacyContext(
state=context.state,
config=ServerConfig(num_rounds=3),
strategy=strategy,
)

# Create the train/evaluate workflow
workflow = DefaultWorkflow(
fit_workflow=SecAggPlusWorkflow(
num_shares=7,
reconstruction_threshold=4,
)
)

# Execute
workflow(driver, context)
110 changes: 110 additions & 0 deletions examples/fl-dp-sa/fl_dp_sa/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""fl_dp_sa: A Flower / PyTorch app."""

from collections import OrderedDict
from logging import INFO
from flwr_datasets import FederatedDataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.common.logger import log
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model."""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.size(0)
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(batch_size, -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def load_data(partition_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose([ToTensor(), Normalize((0.5,), (0.5,))])

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader


def train(net, trainloader, valloader, epochs, device):
"""Train the model on the training set."""
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters())
net.train()
for _ in range(epochs):
for batch in trainloader:
images = batch["image"].to(device)
labels = batch["label"].to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()

train_loss, train_acc = test(net, trainloader)
val_loss, val_acc = test(net, valloader)

results = {
"train_loss": train_loss,
"train_accuracy": train_acc,
"val_loss": val_loss,
"val_accuracy": val_acc,
}
return results


def test(net, testloader):
"""Validate the model on the test set."""
net.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["image"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def get_weights(net):
return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
13 changes: 13 additions & 0 deletions examples/fl-dp-sa/flower.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[project]
name = "fl_dp_sa"
version = "1.0.0"
description = ""
license = "Apache-2.0"
authors = [
"The Flower Authors <[email protected]>",
]
readme = "README.md"

[flower.components]
serverapp = "fl_dp_sa.server:app"
clientapp = "fl_dp_sa.client:app"
21 changes: 21 additions & 0 deletions examples/fl-dp-sa/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "fl-dp-sa"
version = "0.1.0"
description = ""
license = "Apache-2.0"
authors = [
"The Flower Authors <[email protected]>",
]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9"
# Mandatory dependencies
flwr-nightly = { version = "1.8.0.dev20240313", extras = ["simulation"] }
flwr-datasets = { version = "0.0.2", extras = ["vision"] }
torch = "2.2.1"
torchvision = "0.17.1"
4 changes: 4 additions & 0 deletions examples/fl-dp-sa/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
flwr-nightly[simulation]==1.8.0.dev20240313
flwr-datasets[vision]==0.0.2
torch==2.2.1
torchvision==0.17.1