Commit

ci(framework:skip) Update dataset sourcing for e2e-pytorch-lightning

jafermarq authored Jan 17, 2025
1 parent dc452ce · commit cb091fb
Showing 5 changed files with 59 additions and 28 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/e2e.yml
@@ -150,9 +150,6 @@ jobs:
       - directory: e2e-pytorch-lightning
         e2e: e2e_pytorch_lightning
-        dataset: |
-          from torchvision.datasets import MNIST
-          MNIST('./data', download=True)
 
       - directory: e2e-scikit-learn
         e2e: e2e_scikit_learn
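
The deleted dataset: step pre-fetched MNIST with torchvision before the test ran. With this commit the data comes from flwr-datasets instead, which pulls "ylecun/mnist" from the Hugging Face Hub on first use, so no warm-up step is required. If one were still wanted, a hypothetical equivalent (not part of this commit) could use the Hugging Face datasets package directly:

    # Hypothetical cache warm-up; FederatedDataset later resolves the same
    # "ylecun/mnist" dataset from the Hugging Face Hub cache.
    from datasets import load_dataset

    load_dataset("ylecun/mnist")
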
11 changes: 9 additions & 2 deletions e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/client_app.py
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+from random import randint
 
 import pytorch_lightning as pl
 import torch
@@ -54,7 +55,11 @@ def _set_parameters(model, parameters):
 
 def client_fn(context: Context):
     model = mnist.LitAutoEncoder()
-    train_loader, val_loader, test_loader = mnist.load_data()
+    partition_id = int(context.node_config.get("partition-id", 0))
+    num_partitions = int(context.node_config.get("num-partitions", 10))
+    train_loader, val_loader, test_loader = mnist.load_data(
+        partition_id, num_partitions
+    )
 
     # Flower client
     return FlowerClient(model, train_loader, val_loader, test_loader).to_client()
@@ -68,7 +73,9 @@ def client_fn(context: Context):
 def main() -> None:
     # Model and data
     model = mnist.LitAutoEncoder()
-    train_loader, val_loader, test_loader = mnist.load_data()
+    num_partitions = 10
+    p_id = randint(0, num_partitions - 1)
+    train_loader, val_loader, test_loader = mnist.load_data(p_id, num_partitions)
 
     # Flower client
     client = FlowerClient(model, train_loader, val_loader, test_loader).to_client()
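
client_fn now reads its partition assignment from context.node_config, falling back to partition 0 of 10 when the keys are missing (e.g. when run outside a simulation). A minimal sketch of that defaulting pattern, with a plain dict standing in for the runtime-supplied node_config:

    # Plain-dict stand-in for context.node_config; the real Context object is
    # supplied by the Flower runtime. int() casts cover string-typed values.
    node_config = {"partition-id": "3", "num-partitions": "10"}

    partition_id = int(node_config.get("partition-id", 0))       # -> 3
    num_partitions = int(node_config.get("num-partitions", 10))  # -> 10

    empty = {}
    assert int(empty.get("partition-id", 0)) == 0  # fallback when unset
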
70 changes: 48 additions & 22 deletions e2e/e2e-pytorch-lightning/e2e_pytorch_lightning/mnist.py
@@ -3,13 +3,16 @@
 Source: pytorchlightning.ai (2021/02/04)
 """
 
+from random import randint
+
 import pytorch_lightning as pl
 import torch
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 from torch import nn
 from torch.nn import functional as F
-from torch.utils.data import DataLoader, Subset, random_split
+from torch.utils.data import DataLoader
 from torchvision import transforms
-from torchvision.datasets import MNIST
 
 
 class LitAutoEncoder(pl.LightningModule):
@@ -35,7 +38,7 @@ def configure_optimizers(self):
         return optimizer
 
     def training_step(self, train_batch, batch_idx):
-        x, y = train_batch
+        x = train_batch["image"]
         x = x.view(x.size(0), -1)
         z = self.encoder(x)
         x_hat = self.decoder(z)
@@ -50,7 +53,7 @@ def test_step(self, batch, batch_idx):
         self._evaluate(batch, "test")
 
     def _evaluate(self, batch, stage=None):
-        x, y = batch
+        x = batch["image"]
         x = x.view(x.size(0), -1)
         z = self.encoder(x)
         x_hat = self.decoder(z)
@@ -59,31 +62,54 @@ def _evaluate(self, batch, stage=None):
         self.log(f"{stage}_loss", loss, prog_bar=True)
 
 
-def load_data():
-    # Training / validation set
-    trainset = MNIST(
-        "./../data", train=True, download=True, transform=transforms.ToTensor()
+def apply_transforms(batch):
+    """Apply transforms to the partition from FederatedDataset."""
+    batch["image"] = [transforms.functional.to_tensor(img) for img in batch["image"]]
+    return batch
+
+
+fds = None  # Cache FederatedDataset
+
+
+def load_data(partition_id, num_partitions):
+    # Only initialize `FederatedDataset` once
+    global fds
+    if fds is None:
+        partitioner = IidPartitioner(num_partitions=num_partitions)
+        fds = FederatedDataset(
+            dataset="ylecun/mnist",
+            partitioners={"train": partitioner},
+        )
+    partition = fds.load_partition(partition_id, "train")
+
+    partition = partition.with_transform(apply_transforms)
+    # 20% of the partition is held out for federated evaluation
+    partition_full = partition.train_test_split(test_size=0.2, seed=42)
+    # 60% for federated training, 20% for federated validation (both used in fit)
+    partition_train_valid = partition_full["train"].train_test_split(
+        train_size=0.75, seed=42
     )
-    trainset = Subset(trainset, range(1000))
-    mnist_train, mnist_val = random_split(trainset, [800, 200])
-    train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True, num_workers=0)
-    val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False, num_workers=0)
-
-    # Test set
-    testset = MNIST(
-        "./../data", train=False, download=True, transform=transforms.ToTensor()
+    trainloader = DataLoader(
+        partition_train_valid["train"],
+        batch_size=32,
+        shuffle=True,
+        num_workers=0,
     )
-    testset = Subset(testset, range(10))
-    test_loader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=0)
-
-    return train_loader, val_loader, test_loader
+    valloader = DataLoader(
+        partition_train_valid["test"],
+        batch_size=32,
+        num_workers=0,
+    )
+    testloader = DataLoader(partition_full["test"], batch_size=32, num_workers=0)
+    return trainloader, valloader, testloader
 
 
 def main() -> None:
     """Centralized training."""
     # Load data
-    train_loader, val_loader, test_loader = load_data()
-
+    num_partitions = 10
+    p_id = randint(0, num_partitions - 1)
+    train_loader, val_loader, test_loader = load_data(p_id, num_partitions)
     # Load model
     model = LitAutoEncoder()
 
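
A hypothetical smoke test for the new loader (assuming the package is installed and the Hugging Face Hub is reachable): batches are now dicts keyed by "image", which is why training_step and _evaluate unpack batch["image"] instead of tuple-unpacking (x, y).

    # Hypothetical check, not part of the commit: load partition 0 of 10
    # and inspect one training batch.
    from e2e_pytorch_lightning.mnist import load_data

    train_loader, val_loader, test_loader = load_data(partition_id=0, num_partitions=10)
    batch = next(iter(train_loader))
    print(batch["image"].shape)  # expected: torch.Size([32, 1, 28, 28])
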
1 change: 1 addition & 0 deletions e2e/e2e-pytorch-lightning/pyproject.toml
@@ -9,6 +9,7 @@ description = "Federated Learning E2E test with Flower and PyTorch Lightning"
 license = "Apache-2.0"
 dependencies = [
     "flwr[simulation] @ {root:parent:parent:uri}",
+    "flwr-datasets[vision]>=0.5.0,<1.0.0",
     "pytorch-lightning==2.4.0",
     "torchvision>=0.20.1,<0.21.0",
 ]
2 changes: 1 addition & 1 deletion e2e/test_legacy.sh
@@ -6,7 +6,7 @@ if [ "$1" = "e2e-bare-https" ]; then
 fi
 
 # run the first command in background and save output to a temporary file:
-timeout 2m python server_app.py &
+timeout 3m python server_app.py &
 pid=$!
 sleep 3
