diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index ca5663a03fa5..bbfac289fa57 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -150,9 +150,6 @@ jobs: - directory: e2e-pytorch-lightning e2e: e2e_pytorch_lightning - dataset: | - from torchvision.datasets import MNIST - MNIST('./data', download=True) - directory: e2e-scikit-learn e2e: e2e_scikit_learn diff --git a/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/client_app.py b/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/client_app.py index 3d2903037e85..ae8f4b3f4a74 100644 --- a/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/client_app.py +++ b/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/client_app.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from random import randint import pytorch_lightning as pl import torch @@ -54,7 +55,11 @@ def _set_parameters(model, parameters): def client_fn(context: Context): model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data() + partition_id = int(context.node_config.get("partition-id", 0)) + num_partitions = int(context.node_config.get("num-partitions", 10)) + train_loader, val_loader, test_loader = mnist.load_data( + partition_id, num_partitions + ) # Flower client return FlowerClient(model, train_loader, val_loader, test_loader).to_client() @@ -68,7 +73,9 @@ def client_fn(context: Context): def main() -> None: # Model and data model = mnist.LitAutoEncoder() - train_loader, val_loader, test_loader = mnist.load_data() + num_partitions = 10 + p_id = randint(0, num_partitions - 1) + train_loader, val_loader, test_loader = mnist.load_data(p_id, num_partitions) # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() diff --git a/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/mnist.py b/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/mnist.py index 977a9ea524e8..c5add816dfbd 100644 --- a/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/mnist.py +++ 
b/e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/mnist.py @@ -3,13 +3,16 @@ Source: pytorchlightning.ai (2021/02/04) """ +from random import randint + import pytorch_lightning as pl import torch +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, Subset, random_split +from torch.utils.data import DataLoader from torchvision import transforms -from torchvision.datasets import MNIST class LitAutoEncoder(pl.LightningModule): @@ -35,7 +38,7 @@ def configure_optimizers(self): return optimizer def training_step(self, train_batch, batch_idx): - x, y = train_batch + x = train_batch["image"] x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) @@ -50,7 +53,7 @@ def test_step(self, batch, batch_idx): self._evaluate(batch, "test") def _evaluate(self, batch, stage=None): - x, y = batch + x = batch["image"] x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) @@ -59,31 +62,54 @@ self.log(f"{stage}_loss", loss, prog_bar=True) -def load_data(): - # Training / validation set - trainset = MNIST( - "./../data", train=True, download=True, transform=transforms.ToTensor() +def apply_transforms(batch): + """Apply transforms to the partition from FederatedDataset.""" + batch["image"] = [transforms.functional.to_tensor(img) for img in batch["image"]] + return batch + + +fds = None # Cache FederatedDataset + + +def load_data(partition_id, num_partitions): + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="ylecun/mnist", + partitioners={"train": partitioner}, + ) + partition = fds.load_partition(partition_id, "train") + + partition = partition.with_transform(apply_transforms) + # 20 % for the federated evaluation + partition_full =
partition.train_test_split(test_size=0.2, seed=42) + # 60 % for the federated train and 20 % for the federated validation (both in fit) + partition_train_valid = partition_full["train"].train_test_split( + train_size=0.75, seed=42 ) - trainset = Subset(trainset, range(1000)) - mnist_train, mnist_val = random_split(trainset, [800, 200]) - train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=0) - val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False, num_workers=0) - - # Test set - testset = MNIST( - "./../data", train=False, download=True, transform=transforms.ToTensor() + trainloader = DataLoader( + partition_train_valid["train"], + batch_size=32, + shuffle=True, + num_workers=0, ) - testset = Subset(testset, range(10)) - test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=0) - - return train_loader, val_loader, test_loader + valloader = DataLoader( + partition_train_valid["test"], + batch_size=32, + num_workers=0, + ) + testloader = DataLoader(partition_full["test"], batch_size=32, num_workers=0) + return trainloader, valloader, testloader def main() -> None: """Centralized training.""" # Load data - train_loader, val_loader, test_loader = load_data() - + num_partitions = 10 + p_id = randint(0, num_partitions - 1) + train_loader, val_loader, test_loader = load_data(p_id, num_partitions) # Load model model = LitAutoEncoder() diff --git a/e2e/e2e-pytorch-lightning/pyproject.toml b/e2e/e2e-pytorch-lightning/pyproject.toml index efb0eb1bebf1..19216ec501dd 100644 --- a/e2e/e2e-pytorch-lightning/pyproject.toml +++ b/e2e/e2e-pytorch-lightning/pyproject.toml @@ -9,6 +9,7 @@ description = "Federated Learning E2E test with Flower and PyTorch Lightning" license = "Apache-2.0" dependencies = [ "flwr[simulation] @ {root:parent:parent:uri}", + "flwr-datasets[vision]>=0.5.0,<1.0.0", "pytorch-lightning==2.4.0", "torchvision>=0.20.1,<0.21.0", ] diff --git a/e2e/test_legacy.sh b/e2e/test_legacy.sh index 
b336ee0cb717..7e092a3eaba6 100755 --- a/e2e/test_legacy.sh +++ b/e2e/test_legacy.sh @@ -6,7 +6,7 @@ if [ "$1" = "e2e-bare-https" ]; then fi # run the first command in background and save output to a temporary file: -timeout 2m python server_app.py & +timeout 3m python server_app.py & pid=$! sleep 3