# Add lightning to FL content for DLI (#3208)

### Description

Add PyTorch Lightning to FL content for DLI.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
nvkevlu authored Feb 7, 2025
1 parent 6c29d7f commit 9bb23a3
Showing 7 changed files with 780 additions and 0 deletions.
@@ -0,0 +1,44 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.lit_net import LitNet

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 5
    num_rounds = 2

    job = BaseFedJob(
        name="cifar10_lightning_fedavg",
        initial_model=LitNet(),
    )

    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

    # Add clients
    for i in range(n_clients):
        runner = ScriptRunner(
            script="src/cifar10_lightning_fl.py", script_args=""  # f"--batch_size 32 --data_path /tmp/data/site-{i}"
        )
        job.to(runner, f"site-{i + 1}")

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0", log_config="./log_config.json")
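
The commented-out `script_args` above hints at per-site arguments. A hypothetical variant that passes them would look like the sketch below; note that the training script added in this commit does not parse any command-line arguments, so this only illustrates the `ScriptRunner` API, not something the commit itself wires up:

```python
# Hypothetical: give each site its own batch size and data path.
# The training script would need to parse --batch_size and --data_path
# for these arguments to take effect.
for i in range(n_clients):
    runner = ScriptRunner(
        script="src/cifar10_lightning_fl.py",
        script_args=f"--batch_size 32 --data_path /tmp/data/site-{i + 1}",
    )
    job.to(runner, f"site-{i + 1}")
```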
@@ -0,0 +1,87 @@
{
  "version": 1,
  "disable_existing_loggers": false,
  "formatters": {
    "baseFormatter": {
      "()": "nvflare.fuel.utils.log_utils.BaseFormatter",
      "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(fl_ctx)s - %(message)s"
    },
    "colorFormatter": {
      "()": "nvflare.fuel.utils.log_utils.ColorFormatter",
      "fmt": "%(asctime)s - %(levelname)s - %(message)s",
      "datefmt": "%Y-%m-%d %H:%M:%S"
    },
    "jsonFormatter": {
      "()": "nvflare.fuel.utils.log_utils.JsonFormatter",
      "fmt": "%(asctime)s - %(identity)s - %(name)s - %(fullName)s - %(levelname)s - %(fl_ctx)s - %(message)s"
    }
  },
  "filters": {
    "FLFilter": {
      "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter",
      "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"]
    }
  },
  "handlers": {
    "consoleHandler": {
      "class": "logging.StreamHandler",
      "level": "INFO",
      "formatter": "colorFormatter",
      "filters": ["FLFilter"],
      "stream": "ext://sys.stdout"
    },
    "logFileHandler": {
      "class": "logging.handlers.RotatingFileHandler",
      "level": "DEBUG",
      "formatter": "baseFormatter",
      "filename": "log.txt",
      "mode": "a",
      "maxBytes": 20971520,
      "backupCount": 10
    },
    "errorFileHandler": {
      "class": "logging.handlers.RotatingFileHandler",
      "level": "ERROR",
      "formatter": "baseFormatter",
      "filename": "log_error.txt",
      "mode": "a",
      "maxBytes": 20971520,
      "backupCount": 10
    },
    "jsonFileHandler": {
      "class": "logging.handlers.RotatingFileHandler",
      "level": "DEBUG",
      "formatter": "jsonFormatter",
      "filename": "log.json",
      "mode": "a",
      "maxBytes": 20971520,
      "backupCount": 10
    },
    "FLFileHandler": {
      "class": "logging.handlers.RotatingFileHandler",
      "level": "DEBUG",
      "formatter": "baseFormatter",
      "filters": ["FLFilter"],
      "filename": "log_fl.txt",
      "mode": "a",
      "maxBytes": 20971520,
      "backupCount": 10,
      "delay": true
    }
  },
  "loggers": {
    "root": {
      "level": "INFO",
      "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"]
    }
  }
}
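
This file follows the standard `logging.config` dictConfig schema (`"version": 1`); the job script points the simulator at it via `log_config="./log_config.json"`. A minimal sketch of applying the same file by hand, assuming NVFlare is installed (the formatters and filter are `nvflare.fuel.utils.log_utils` classes) and the file is saved as `log_config.json`:

```python
import json
import logging
import logging.config

# Load the JSON config above and hand it to the stdlib dictConfig machinery.
# The rotating file handlers will create log.txt, log_error.txt, log.json,
# and log_fl.txt in the current working directory.
with open("log_config.json") as f:
    logging.config.dictConfig(json.load(f))

logging.getLogger("nvflare.app_common").info("logging configured")
```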
@@ -0,0 +1,5 @@
nvflare~=2.5.0rc
torch
torchvision
pytorch_lightning
tensorboard
@@ -0,0 +1,105 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torchvision
import torchvision.transforms as transforms
from lit_net import LitNet
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from torch.utils.data import DataLoader, random_split

# (1) import nvflare lightning client API
import nvflare.client.lightning as flare

seed_everything(7)


DATASET_PATH = "/tmp/nvflare/data"
BATCH_SIZE = 4

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage == "validate":
            cifar_full = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, download=False, transform=transform
            )
            self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.cifar_test = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, download=False, transform=transform
            )

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
    model = LitNet()
    cifar10_dm = CIFAR10DataModule()
    trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")

    # (2) patch the lightning trainer
    flare.patch(trainer)

    while flare.is_running():
        # (3) receives FLModel from NVFlare
        # Note that we don't need to pass this input_model to trainer
        # because after flare.patch the trainer.fit/validate will get the
        # global model internally
        input_model = flare.receive()
        print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")

        # (4) evaluate the current global model to allow server-side model selection
        print("--- validate global model ---")
        trainer.validate(model, datamodule=cifar10_dm)

        # perform local training starting with the received global model
        print("--- train new model ---")
        trainer.fit(model, datamodule=cifar10_dm)

        # test local model
        print("--- test new model ---")
        trainer.test(ckpt_path="best", datamodule=cifar10_dm)

        # get predictions
        print("--- prediction with new best model ---")
        trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
    main()
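
For orientation: `flare.patch(trainer)` automates, for Lightning, the receive/train/send loop that the plain NVFlare Client API expresses by hand. A rough, simplified sketch of that lower-level pattern (`net` and `train_one_round` are hypothetical stand-ins, not part of this commit):

```python
import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel

flare.init()
while flare.is_running():
    input_model = flare.receive()            # global weights from the server
    net.load_state_dict(input_model.params)  # net: hypothetical torch.nn.Module
    train_one_round(net)                     # hypothetical local training step
    flare.send(FLModel(params=net.state_dict()))
```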
@@ -0,0 +1,93 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}
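
Not part of the commit, but a quick way to sanity-check the model definition above (assuming it is saved as `src/lit_net.py`, the path the job script imports from): forward a fake CIFAR-10 batch through the untrained network and confirm the logits have shape `(batch, 10)`. The layer sizes work out because a 32x32 input becomes 28x28 after `conv1`, 14x14 after pooling, 10x10 after `conv2`, and 5x5 after the second pool, matching `fc1`'s `16 * 5 * 5` input.

```python
import torch

from lit_net import LitNet  # the module defined above

# Two fake 3x32x32 CIFAR-10 images through the untrained network.
model = LitNet()
x = torch.randn(2, 3, 32, 32)
logits = model(x)
assert logits.shape == (2, 10)
print(logits.shape)
```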
Empty file.
(The diff for one additional large file is not rendered in this view.)
