diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl index 187b2301f72b..bdb5b8fcadf9 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -1,7 +1,6 @@ """$project_name: A Flower / PyTorch app.""" from flwr.client import NumPyClient, ClientApp -from flwr.cli.flower_toml import load_and_validate_with_defaults from $project_name.task import ( Net, @@ -32,17 +31,10 @@ class FlowerClient(NumPyClient): return loss, len(self.valloader.dataset), {"accuracy": accuracy} -# Load config -cfg, *_ = load_and_validate_with_defaults() - def client_fn(cid: str): # Load model and data net = Net().to(DEVICE) - engine = cfg["flower"]["engine"] - num_partitions = 2 - if "simulation" in engine: - num_partitions = engine["simulation"]["supernode"]["num"] - trainloader, valloader = load_data(int(cid), num_partitions) + trainloader, valloader = load_data() # Return Client instance return FlowerClient(net, trainloader, valloader).to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl index b7f69bf7dce7..1d727599a1e4 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor -from flwr_datasets import FederatedDataset + DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -34,39 +34,27 @@ class Net(nn.Module): return self.fc3(x) -def load_data(partition_id: int, total_partitions: int): - """Load partition CIFAR10 data.""" - fds = FederatedDataset(dataset="cifar10", partitioners={"train": total_partitions}) - partition = fds.load_partition(partition_id) - # Divide data on each node: 80% train, 20% test - partition_train_test = partition.train_test_split(test_size=0.2) - pytorch_transforms = Compose( - [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - ) - - def apply_transforms(batch): - """Apply transforms to the partition from FederatedDataset.""" - batch["img"] = [pytorch_transforms(img) for img in batch["img"]] - return batch - - partition_train_test = partition_train_test.with_transform(apply_transforms) - trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) - testloader = DataLoader(partition_train_test["test"], batch_size=32) - return trainloader, testloader +def load_data(): + """Load CIFAR-10 (training and test set).""" + trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = CIFAR10("./data", train=True, download=True, transform=trf) + testset = CIFAR10("./data", train=False, download=True, transform=trf) + return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) def train(net, trainloader, valloader, epochs, device): """Train the model on the training set.""" + print("Starting training...") net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss().to(device) optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) net.train() for _ in range(epochs): - for batch in trainloader: - images = batch["img"] - labels = batch["label"] + for images, labels in trainloader: + images, labels = images.to(device), labels.to(device) optimizer.zero_grad() - criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward() + loss = criterion(net(images), labels) + loss.backward() optimizer.step() train_loss, train_acc = test(net, trainloader) @@ -83,13 +71,13 @@ def train(net, trainloader, valloader, epochs, device): def test(net, testloader): """Validate the model on the test set.""" + net.to(DEVICE) criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad(): - for batch in testloader: - images = batch["img"].to(DEVICE) - labels = batch["label"].to(DEVICE) - outputs = net(images) + for images, labels in testloader: + outputs = net(images.to(DEVICE)) + labels = labels.to(DEVICE) loss += criterion(outputs, labels).item() correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() accuracy = correct / len(testloader.dataset)