Skip to content

Commit

Permalink
use MNIST and Adam
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Mar 14, 2024
1 parent ef1e99e commit b9273c7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
2 changes: 1 addition & 1 deletion examples/dp-secagg-demo/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def main(driver: Driver, context: Context) -> None:
# Construct the LegacyContext
context = LegacyContext(
state=context.state,
config=fl.server.ServerConfig(num_rounds=10),
config=fl.server.ServerConfig(num_rounds=3),
strategy=dp_strategy,
)

Expand Down
22 changes: 10 additions & 12 deletions examples/dp-secagg-demo/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,45 @@
import torch.nn.functional as F
from flwr.common.logger import log
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
"""Model."""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.size(0)
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = x.view(batch_size, -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def load_data(partition_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 100})
fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
partition = fds.load_partition(partition_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
pytorch_transforms = Compose([ToTensor(), Normalize((0.5,), (0.5,))])

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
Expand All @@ -61,11 +59,11 @@ def train(net, trainloader, valloader, epochs, device):
log(INFO, "Starting training...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = torch.optim.Adam(net.parameters())
net.train()
for _ in range(epochs):
for batch in trainloader:
images = batch["img"].to(device)
images = batch["image"].to(device)
labels = batch["label"].to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
Expand All @@ -91,7 +89,7 @@ def test(net, testloader):
correct, loss = 0, 0.0
with torch.no_grad():
for batch in testloader:
images = batch["img"].to(DEVICE)
images = batch["image"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
Expand Down

0 comments on commit b9273c7

Please sign in to comment.