Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PyTorch flwr run template use Flower Datasets #3133

Merged
merged 7 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""$project_name: A Flower / PyTorch app."""

from flwr.client import NumPyClient, ClientApp
from flwr.cli.flower_toml import load_and_validate_with_defaults

from $project_name.task import (
Net,
Expand Down Expand Up @@ -31,10 +32,17 @@ class FlowerClient(NumPyClient):
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


# Load config once at module level. Only the first element returned by
# load_and_validate_with_defaults() is kept; client_fn reads
# cfg["flower"]["engine"] from it.
cfg, *_ = load_and_validate_with_defaults()

def client_fn(cid: str):
    """Build the Flower client for the data partition identified by ``cid``.

    Args:
        cid: Client id; it is converted with ``int`` and used as the index
            of the data partition to load.

    Returns:
        The result of ``FlowerClient(net, trainloader, valloader).to_client()``
        for a fresh ``Net`` and this partition's data loaders.
    """
    # Load model and data
    net = Net().to(DEVICE)

    # Default to two partitions; when the engine config has a "simulation"
    # section, use its supernode count as the number of partitions instead.
    engine = cfg["flower"]["engine"]
    num_partitions = 2
    if "simulation" in engine:
        num_partitions = engine["simulation"]["supernode"]["num"]

    # Single call using the partitioned signature
    # load_data(partition_id, total_partitions). A zero-argument call here
    # would not match that signature and its result would be overwritten.
    trainloader, valloader = load_data(int(cid), num_partitions)

    # Return Client instance
    return FlowerClient(net, trainloader, valloader).to_client()
Expand Down
44 changes: 28 additions & 16 deletions src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

from flwr_datasets import FederatedDataset

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Expand All @@ -34,27 +34,39 @@ class Net(nn.Module):
return self.fc3(x)


def load_data(partition_id: int, total_partitions: int):
    """Load one CIFAR-10 partition and return its train/test DataLoaders.

    The "train" split of the ``cifar10`` dataset is divided into
    ``total_partitions`` partitions through ``FederatedDataset``; the
    partition at ``partition_id`` is then split 80/20 into local train and
    test subsets.

    Args:
        partition_id: Index of the partition to load.
        total_partitions: Number of partitions the "train" split is divided
            into.

    Returns:
        A ``(trainloader, testloader)`` tuple. Both loaders use a batch size
        of 32; only the train loader shuffles. Samples keep the dataset's
        dict form, with the ``"img"`` entry passed through ``ToTensor`` and
        ``Normalize``.
    """
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": total_partitions})
    partition = fds.load_partition(partition_id)

    # Divide data on each node: 80% train, 20% test.
    # The seed is fixed so that every call for the same partition produces
    # the same split; an unseeded split could move samples between the
    # train and test subsets from one call to the next.
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    pytorch_transforms = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Apply transforms to the partition from FederatedDataset."""
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    # with_transform applies the conversion when items are accessed.
    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    testloader = DataLoader(partition_train_test["test"], batch_size=32)
    return trainloader, testloader


def train(net, trainloader, valloader, epochs, device):
"""Train the model on the training set."""
print("Starting training...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
net.train()
for _ in range(epochs):
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
for batch in trainloader:
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()

train_loss, train_acc = test(net, trainloader)
Expand All @@ -71,13 +83,13 @@ def train(net, trainloader, valloader, epochs, device):

def test(net, testloader):
"""Validate the model on the test set."""
net.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
for batch in testloader:
images = batch["img"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
Expand Down