Revert "Make PyTorch flwr run template use Flower Datasets (#3133

This reverts commit 0f2df33.
tanertopal authored Mar 14, 2024
1 parent 0f2df33 commit 2764379
Showing 2 changed files with 17 additions and 37 deletions.
src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl: 10 changes (1 addition, 9 deletions)
@@ -1,7 +1,6 @@
"""$project_name: A Flower / PyTorch app."""

from flwr.client import NumPyClient, ClientApp
from flwr.cli.flower_toml import load_and_validate_with_defaults

from $project_name.task import (
Net,
@@ -32,17 +31,10 @@ class FlowerClient(NumPyClient):
         return loss, len(self.valloader.dataset), {"accuracy": accuracy}


-# Load config
-cfg, *_ = load_and_validate_with_defaults()
-
 def client_fn(cid: str):
     # Load model and data
     net = Net().to(DEVICE)
-    engine = cfg["flower"]["engine"]
-    num_partitions = 2
-    if "simulation" in engine:
-        num_partitions = engine["simulation"]["supernode"]["num"]
-    trainloader, valloader = load_data(int(cid), num_partitions)
+    trainloader, valloader = load_data()

     # Return Client instance
     return FlowerClient(net, trainloader, valloader).to_client()
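For orientation (not part of the diff): in the generated app, the client_fn above is wired into a ClientApp, which is what flwr run executes. A minimal sketch of that tail of the template, assuming the collapsed remainder of client.pytorch.py.tpl does the equivalent and exposes the app under the conventional name `app`:

    from flwr.client import ClientApp

    # Sketch (assumption): register the client_fn shown in the diff above.
    # `app` is the object the Flower runtime is pointed at.
    app = ClientApp(client_fn=client_fn)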
src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl: 44 changes (16 additions, 28 deletions)
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from torchvision.datasets import CIFAR10
 from torchvision.transforms import Compose, Normalize, ToTensor
-from flwr_datasets import FederatedDataset


 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -34,39 +34,27 @@ class Net(nn.Module):
         return self.fc3(x)


-def load_data(partition_id: int, total_partitions: int):
-    """Load partition CIFAR10 data."""
-    fds = FederatedDataset(dataset="cifar10", partitioners={"train": total_partitions})
-    partition = fds.load_partition(partition_id)
-    # Divide data on each node: 80% train, 20% test
-    partition_train_test = partition.train_test_split(test_size=0.2)
-    pytorch_transforms = Compose(
-        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
-    )
-
-    def apply_transforms(batch):
-        """Apply transforms to the partition from FederatedDataset."""
-        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
-        return batch
-
-    partition_train_test = partition_train_test.with_transform(apply_transforms)
-    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
-    testloader = DataLoader(partition_train_test["test"], batch_size=32)
-    return trainloader, testloader
+def load_data():
+    """Load CIFAR-10 (training and test set)."""
+    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
+    testset = CIFAR10("./data", train=False, download=True, transform=trf)
+    return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)


 def train(net, trainloader, valloader, epochs, device):
     """Train the model on the training set."""
     print("Starting training...")
     net.to(device)  # move model to GPU if available
     criterion = torch.nn.CrossEntropyLoss().to(device)
     optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
     net.train()
     for _ in range(epochs):
-        for batch in trainloader:
-            images = batch["img"]
-            labels = batch["label"]
+        for images, labels in trainloader:
+            images, labels = images.to(device), labels.to(device)
             optimizer.zero_grad()
-            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
+            loss = criterion(net(images), labels)
+            loss.backward()
             optimizer.step()

     train_loss, train_acc = test(net, trainloader)
@@ -83,13 +71,13 @@ def train(net, trainloader, valloader, epochs, device):

 def test(net, testloader):
     """Validate the model on the test set."""
+    net.to(DEVICE)
     criterion = torch.nn.CrossEntropyLoss()
     correct, loss = 0, 0.0
     with torch.no_grad():
-        for batch in testloader:
-            images = batch["img"].to(DEVICE)
-            labels = batch["label"].to(DEVICE)
-            outputs = net(images)
+        for images, labels in testloader:
+            outputs = net(images.to(DEVICE))
+            labels = labels.to(DEVICE)
             loss += criterion(outputs, labels).item()
             correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
     accuracy = correct / len(testloader.dataset)
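Taken together, the revert swaps the per-partition FederatedDataset loader back to a plain torchvision CIFAR-10 download, so every client trains on the full dataset. Below is a quick smoke test of the reverted task module, assuming the template has been rendered with $project_name set to a hypothetical package name myapp, and that test() returns (loss, accuracy) as the line train_loss, train_acc = test(net, trainloader) suggests:

    from myapp.task import DEVICE, Net, load_data, test, train  # hypothetical rendered package

    net = Net().to(DEVICE)
    trainloader, valloader = load_data()  # full CIFAR-10 loaders, no partitioning
    train(net, trainloader, valloader, epochs=1, device=DEVICE)  # one local epoch
    loss, accuracy = test(net, valloader)  # assumed return: (loss, accuracy)
    print(f"val loss {loss:.4f}, val accuracy {accuracy:.4f}")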
