Add lightning to FL content for DLI #3208

Merged · 5 commits · Feb 7, 2025
@@ -0,0 +1,44 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from src.lit_net import LitNet

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner

if __name__ == "__main__":
    n_clients = 5
    num_rounds = 2

    job = BaseFedJob(
        name="cifar10_lightning_fedavg",
        initial_model=LitNet(),
    )

    # Define the FedAvg workflow and assign it to the server
    controller = FedAvg(
        num_clients=n_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

    # Add clients, each running the Lightning training script
    for i in range(n_clients):
        runner = ScriptRunner(
            script="src/cifar10_lightning_fl.py",
            script_args="",  # e.g. f"--batch_size 32 --data_path /tmp/data/site-{i}"
        )
        job.to(runner, f"site-{i + 1}")

    # Export the job configuration, then run it in the FL simulator
    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0", log_config="./log_config.json")
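The `script_args` hook is left empty above, with per-site arguments hinted at in the trailing comment. A minimal sketch of enabling them, assuming the client script is extended to parse these flags (it currently is not; see the argparse sketch after src/cifar10_lightning_fl.py below):

    for i in range(n_clients):
        runner = ScriptRunner(
            script="src/cifar10_lightning_fl.py",
            script_args=f"--batch_size 32 --data_path /tmp/data/site-{i + 1}",
        )
        job.to(runner, f"site-{i + 1}")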
log_config.json
@@ -0,0 +1,87 @@
{
    "version": 1,
    "disable_existing_loggers": false,
    "formatters": {
        "baseFormatter": {
            "()": "nvflare.fuel.utils.log_utils.BaseFormatter",
            "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(fl_ctx)s - %(message)s"
        },
        "colorFormatter": {
            "()": "nvflare.fuel.utils.log_utils.ColorFormatter",
            "fmt": "%(asctime)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S"
        },
        "jsonFormatter": {
            "()": "nvflare.fuel.utils.log_utils.JsonFormatter",
            "fmt": "%(asctime)s - %(identity)s - %(name)s - %(fullName)s - %(levelname)s - %(fl_ctx)s - %(message)s"
        }
    },
    "filters": {
        "FLFilter": {
            "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter",
            "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"]
        }
    },
    "handlers": {
        "consoleHandler": {
            "class": "logging.StreamHandler",
            "level": "INFO",
            "formatter": "colorFormatter",
            "filters": ["FLFilter"],
            "stream": "ext://sys.stdout"
        },
        "logFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "baseFormatter",
            "filename": "log.txt",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10
        },
        "errorFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "ERROR",
            "formatter": "baseFormatter",
            "filename": "log_error.txt",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10
        },
        "jsonFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "jsonFormatter",
            "filename": "log.json",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10
        },
        "FLFileHandler": {
            "class": "logging.handlers.RotatingFileHandler",
            "level": "DEBUG",
            "formatter": "baseFormatter",
            "filters": ["FLFilter"],
            "filename": "log_fl.txt",
            "mode": "a",
            "maxBytes": 20971520,
            "backupCount": 10,
            "delay": true
        }
    },
    "loggers": {
        "root": {
            "level": "INFO",
            "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"]
        }
    }
}
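This file is a standard logging dictConfig ("version": 1), so it can be sanity-checked outside the simulator. A minimal sketch, assuming nvflare is installed so the custom formatter and filter classes under nvflare.fuel.utils.log_utils resolve:

    import json
    import logging
    import logging.config

    # Apply the config; the RotatingFileHandlers create log.txt, log_error.txt,
    # log.json, and log_fl.txt in the current working directory.
    with open("log_config.json") as f:
        logging.config.dictConfig(json.load(f))

    # A logger under nvflare.app_common should pass FLFilter and reach the console.
    logging.getLogger("nvflare.app_common.demo").info("hello from the FL filter path")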
requirements.txt
@@ -0,0 +1,5 @@
nvflare~=2.5.0rc
torch
torchvision
pytorch_lightning
tensorboard
src/cifar10_lightning_fl.py
@@ -0,0 +1,105 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torchvision
import torchvision.transforms as transforms
from lit_net import LitNet
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from torch.utils.data import DataLoader, random_split

# (1) import nvflare lightning client API
import nvflare.client.lightning as flare

seed_everything(7)


DATASET_PATH = "/tmp/nvflare/data"
BATCH_SIZE = 4

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class CIFAR10DataModule(LightningDataModule):
    def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage == "validate":
            cifar_full = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=True, download=False, transform=transform
            )
            self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict":
            self.cifar_test = torchvision.datasets.CIFAR10(
                root=self.data_dir, train=False, download=False, transform=transform
            )

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size)


def main():
    model = LitNet()
    cifar10_dm = CIFAR10DataModule()
    trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")

    # (2) patch the lightning trainer
    flare.patch(trainer)

    while flare.is_running():
        # (3) receive the FLModel from NVFlare
        # Note that we don't need to pass this input_model to the trainer,
        # because after flare.patch the trainer.fit/validate will get the
        # global model internally
        input_model = flare.receive()
        print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")

        # (4) evaluate the current global model to allow server-side model selection
        print("--- validate global model ---")
        trainer.validate(model, datamodule=cifar10_dm)

        # perform local training starting with the received global model
        print("--- train new model ---")
        trainer.fit(model, datamodule=cifar10_dm)

        # test the local model
        print("--- test new model ---")
        trainer.test(ckpt_path="best", datamodule=cifar10_dm)

        # get predictions
        print("--- prediction with new best model ---")
        trainer.predict(ckpt_path="best", datamodule=cifar10_dm)


if __name__ == "__main__":
    main()
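The commented-out `script_args` in the job script suggest `--batch_size` and `--data_path` flags, but this script currently reads the `BATCH_SIZE` and `DATASET_PATH` constants instead. A hedged sketch of wiring the flags in (the argument names are assumptions carried over from that comment):

    import argparse

    def parse_args():
        parser = argparse.ArgumentParser()
        parser.add_argument("--batch_size", type=int, default=BATCH_SIZE)
        parser.add_argument("--data_path", type=str, default=DATASET_PATH)
        return parser.parse_args()

    # then, at the top of main():
    #   args = parse_args()
    #   cifar10_dm = CIFAR10DataModule(data_dir=args.data_path, batch_size=args.batch_size)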
src/lit_net.py
@@ -0,0 +1,93 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class LitNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model = Net()
        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
        # (optional) pass additional information via self.__fl_meta__
        self.__fl_meta__ = {}

    def forward(self, x):
        out = self.model(x)
        return out

    def training_step(self, batch, batch_idx):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.train_acc(outputs, labels)
        self.log("train_loss", loss)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
        return loss

    def evaluate(self, batch, stage=None):
        x, labels = batch
        outputs = self(x)
        loss = criterion(outputs, labels)
        self.valid_acc(outputs, labels)

        if stage:
            self.log(f"{stage}_loss", loss)
            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
        return outputs

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.evaluate(batch)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        return {"optimizer": optimizer}