From b9273c711c052949ed49efda909555bf05d5e253 Mon Sep 17 00:00:00 2001
From: Heng Pan
Date: Thu, 14 Mar 2024 09:11:59 +0000
Subject: [PATCH] use MNIST and Adam

---
 examples/dp-secagg-demo/server.py |  2 +-
 examples/dp-secagg-demo/task.py   | 22 ++++++++++------------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/examples/dp-secagg-demo/server.py b/examples/dp-secagg-demo/server.py
index ff7d81b7ecd4..47411b7b62fe 100644
--- a/examples/dp-secagg-demo/server.py
+++ b/examples/dp-secagg-demo/server.py
@@ -61,7 +61,7 @@ def main(driver: Driver, context: Context) -> None:
     # Construct the LegacyContext
     context = LegacyContext(
         state=context.state,
-        config=fl.server.ServerConfig(num_rounds=10),
+        config=fl.server.ServerConfig(num_rounds=3),
         strategy=dp_strategy,
     )

diff --git a/examples/dp-secagg-demo/task.py b/examples/dp-secagg-demo/task.py
index 2fe68260c3f5..fa17fe608a9c 100644
--- a/examples/dp-secagg-demo/task.py
+++ b/examples/dp-secagg-demo/task.py
@@ -7,7 +7,6 @@
 import torch.nn.functional as F
 from flwr.common.logger import log
 from torch.utils.data import DataLoader
-from torchvision.datasets import CIFAR10
 from torchvision.transforms import Compose, Normalize, ToTensor
@@ -15,11 +14,11 @@


 class Net(nn.Module):
-    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""
+    """Model."""

     def __init__(self) -> None:
         super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.conv1 = nn.Conv2d(1, 6, 3, padding=1)
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(6, 16, 5)
         self.fc1 = nn.Linear(16 * 5 * 5, 120)
@@ -27,9 +26,10 @@ def __init__(self) -> None:
         self.fc3 = nn.Linear(84, 10)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size = x.size(0)
         x = self.pool(F.relu(self.conv1(x)))
         x = self.pool(F.relu(self.conv2(x)))
-        x = x.view(-1, 16 * 5 * 5)
+        x = x.view(batch_size, -1)
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         return self.fc3(x)
@@ -37,17 +37,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 def load_data(partition_id):
     """Load partition CIFAR10 data."""
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": 100})
+    fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
     partition = fds.load_partition(partition_id)
     # Divide data on each node: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2)
-    pytorch_transforms = Compose(
-        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-    )
+    pytorch_transforms = Compose([ToTensor(), Normalize((0.5,), (0.5,))])

     def apply_transforms(batch):
         """Apply transforms to the partition from FederatedDataset."""
-        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
+        batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
         return batch

     partition_train_test = partition_train_test.with_transform(apply_transforms)
@@ -61,11 +59,11 @@ def train(net, trainloader, valloader, epochs, device):
     log(INFO, "Starting training...")
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss().to(device)
-    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    optimizer = torch.optim.Adam(net.parameters())
     net.train()
     for _ in range(epochs):
         for batch in trainloader:
-            images = batch["img"].to(device)
+            images = batch["image"].to(device)
             labels = batch["label"].to(device)
             optimizer.zero_grad()
             loss = criterion(net(images), labels)
@@ -91,7 +89,7 @@ def test(net, testloader):
     correct, loss = 0, 0.0
     with torch.no_grad():
         for batch in testloader:
-            images = batch["img"].to(DEVICE)
+            images = batch["image"].to(DEVICE)
             labels = batch["label"].to(DEVICE)
             outputs = net(images.to(DEVICE))
             labels = labels.to(DEVICE)