From bfec6bb8fe8093f7fb1d7debb9b328725fadfd2f Mon Sep 17 00:00:00 2001
From: Kevin
Date: Thu, 6 Feb 2025 16:58:01 -0500
Subject: [PATCH 1/2] add lightning to FL content for DLI

---
 .../code/lightning_fl_job.py                  |  46 ++
 .../code/log_config.json                      |  87 ++++
 .../code/requirements.txt                     |   5 +
 .../code/src/cifar10_lightning_fl.py          | 105 +++++
 .../code/src/lit_net.py                       |  93 ++++
 .../convert_torch_lightning_to_dl.ipynb       |   0
 .../convert_torch_lightning_to_fl.ipynb       | 460 ++++++++++++++++++
 7 files changed, 796 insertions(+)
 create mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py
 create mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/log_config.json
 create mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/requirements.txt
 create mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/cifar10_lightning_fl.py
 create mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/lit_net.py
 delete mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_dl.ipynb
 create mode 100644 examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_fl.ipynb

diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py
new file mode 100644
index 0000000000..3dc8a2f215
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
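+# Job script: configure a FedAvg workflow on the server, attach the Lightning
+# client training script to each site, then run the job in the FL simulator.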
+
+from src.lit_net import LitNet
+
+from nvflare.app_common.workflows.fedavg import FedAvg
+from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
+from nvflare.job_config.script_runner import ScriptRunner
+
+if __name__ == "__main__":
+    n_clients = 5
+    num_rounds = 2
+
+    job = BaseFedJob(
+        name="cifar10_lightning_fedavg",
+        initial_model=LitNet(),
+    )
+
+    controller = FedAvg(
+        num_clients=n_clients,
+        num_rounds=num_rounds,
+    )
+    job.to(controller, "server")
+
+    # Add clients
+    for i in range(n_clients):
+        runner = ScriptRunner(
+            script="src/cifar10_lightning_fl.py", script_args=""  # f"--batch_size 32 --data_path /tmp/data/site-{i}"
+        )
+        job.to(runner, f"site-{i+1}")
+
+    job.export_job("/tmp/nvflare/jobs/job_config")
+    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0", log_config="./log_config.json")

diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/log_config.json b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/log_config.json
new file mode 100644
index 0000000000..e5732b4950
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/log_config.json
@@ -0,0 +1,87 @@
+{
+    "version": 1,
+    "disable_existing_loggers": false,
+    "formatters": {
+        "baseFormatter": {
+            "()": "nvflare.fuel.utils.log_utils.BaseFormatter",
+            "fmt": "%(asctime)s - %(name)s - %(levelname)s - %(fl_ctx)s - %(message)s"
+        },
+        "colorFormatter": {
+            "()": "nvflare.fuel.utils.log_utils.ColorFormatter",
+            "fmt": "%(asctime)s - %(levelname)s - %(message)s",
+            "datefmt": "%Y-%m-%d %H:%M:%S"
+        },
+        "jsonFormatter": {
+            "()": "nvflare.fuel.utils.log_utils.JsonFormatter",
+            "fmt": "%(asctime)s - %(identity)s - %(name)s - %(fullName)s - %(levelname)s - %(fl_ctx)s - %(message)s"
+        }
+    },
+    "filters": {
+        "FLFilter": {
+            "()": "nvflare.fuel.utils.log_utils.LoggerNameFilter",
+            "logger_names": ["custom", "nvflare.app_common", "nvflare.app_opt"]
+        }
+    },
+    "handlers": {
+        "consoleHandler": {
+            "class": "logging.StreamHandler",
+            "level": "INFO",
+            "formatter": "colorFormatter",
+            "filters": ["FLFilter"],
+            "stream": "ext://sys.stdout"
+        },
+        "logFileHandler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "DEBUG",
+            "formatter": "baseFormatter",
+            "filename": "log.txt",
+            "mode": "a",
+            "maxBytes": 20971520,
+            "backupCount": 10
+        },
+        "errorFileHandler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "ERROR",
+            "formatter": "baseFormatter",
+            "filename": "log_error.txt",
+            "mode": "a",
+            "maxBytes": 20971520,
+            "backupCount": 10
+        },
+        "jsonFileHandler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "DEBUG",
+            "formatter": "jsonFormatter",
+            "filename": "log.json",
+            "mode": "a",
+            "maxBytes": 20971520,
+            "backupCount": 10
+        },
+        "FLFileHandler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "DEBUG",
+            "formatter": "baseFormatter",
+            "filters": ["FLFilter"],
+            "filename": "log_fl.txt",
+            "mode": "a",
+            "maxBytes": 20971520,
+            "backupCount": 10,
+            "delay": true
+        }
+    },
+    "loggers": {
+        "root": {
+            "level": "INFO",
+            "handlers": ["consoleHandler", "logFileHandler", "errorFileHandler", "jsonFileHandler", "FLFileHandler"]
+        }
+    }
+}
+
+
+
+
+
+
+
+
+

diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/requirements.txt b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/requirements.txt
new file mode 100644
index 0000000000..5db2fd4d24
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/requirements.txt
@@ -0,0 +1,5 @@
+nvflare~=2.5.0rc
+torch
+torchvision
+pytorch_lightning
+tensorboard
\ No newline at end of file

diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/cifar10_lightning_fl.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/cifar10_lightning_fl.py
new file mode 100644
index 0000000000..18861f604f
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/cifar10_lightning_fl.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from lit_net import LitNet
+from pytorch_lightning import LightningDataModule, Trainer, seed_everything
+from torch.utils.data import DataLoader, random_split
+
+# (1) import nvflare lightning client API
+import nvflare.client.lightning as flare
+
+seed_everything(7)
+
+
+DATASET_PATH = "/tmp/nvflare/data"
+BATCH_SIZE = 4
+
+transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+
+class CIFAR10DataModule(LightningDataModule):
+    def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE):
+        super().__init__()
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+
+    def prepare_data(self):
+        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform)
+        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform)
+
+    def setup(self, stage: str):
+        # Assign train/val datasets for use in dataloaders
+        if stage == "fit" or stage == "validate":
+            cifar_full = torchvision.datasets.CIFAR10(
+                root=self.data_dir, train=True, download=False, transform=transform
+            )
+            self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2])
+
+        # Assign test dataset for use in dataloader(s)
+        if stage == "test" or stage == "predict":
+            self.cifar_test = torchvision.datasets.CIFAR10(
+                root=self.data_dir, train=False, download=False, transform=transform
+            )
+
+    def train_dataloader(self):
+        return DataLoader(self.cifar_train, batch_size=self.batch_size)
+
+    def val_dataloader(self):
+        return DataLoader(self.cifar_val, batch_size=self.batch_size)
+
+    def test_dataloader(self):
+        return DataLoader(self.cifar_test, batch_size=self.batch_size)
+
+    def predict_dataloader(self):
+        return DataLoader(self.cifar_test, batch_size=self.batch_size)
+
+
+def main():
+    model = LitNet()
+    cifar10_dm = CIFAR10DataModule()
+    trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")
+
+    # (2) patch the lightning trainer
+    flare.patch(trainer)
+
+    while flare.is_running():
+        # (3) receive FLModel from NVFlare
+        # Note that we don't need to pass this input_model to the trainer
+        # because after flare.patch the trainer.fit/validate will get the
+        # global model internally
+        input_model = flare.receive()
+        print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")
+
+        # (4) evaluate the current global model to allow server-side model selection
+        print("--- validate global model ---")
+        trainer.validate(model, datamodule=cifar10_dm)
+
+        # perform local training starting with the received global model
+        print("--- train new model ---")
+        trainer.fit(model, datamodule=cifar10_dm)
+
+        # test local model
+        print("--- test new model ---")
+        trainer.test(ckpt_path="best", datamodule=cifar10_dm)
+
+        # get predictions
+        print("--- prediction with new best model ---")
+        trainer.predict(ckpt_path="best", datamodule=cifar10_dm)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/lit_net.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/lit_net.py
new file mode 100644
index 0000000000..d70b85ca5a
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/src/lit_net.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from pytorch_lightning import LightningModule
+from torchmetrics import Accuracy
+
+NUM_CLASSES = 10
+criterion = nn.CrossEntropyLoss()
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)  # flatten all dimensions except batch
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class LitNet(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.save_hyperparameters()
+        self.model = Net()
+        self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
+        self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES)
+        # (optional) pass additional information via self.__fl_meta__
+        self.__fl_meta__ = {}
+
+    def forward(self, x):
+        out = self.model(x)
+        return out
+
+    def training_step(self, batch, batch_idx):
+        x, labels = batch
+        outputs = self(x)
+        loss = criterion(outputs, labels)
+        self.train_acc(outputs, labels)
+        self.log("train_loss", loss)
+        self.log("train_acc", self.train_acc, on_step=True, on_epoch=False)
+        return loss
+
+    def evaluate(self, batch, stage=None):
+        x, labels = batch
+        outputs = self(x)
+        loss = criterion(outputs, labels)
+        self.valid_acc(outputs, labels)
+
+        if stage:
+            self.log(f"{stage}_loss", loss)
+            self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True)
+        return outputs
+
+    def validation_step(self, batch, batch_idx):
+        self.evaluate(batch, "val")
+
+    def test_step(self, batch, batch_idx):
+        self.evaluate(batch, "test")
+
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+        return self.evaluate(batch)
+
+    def configure_optimizers(self):
+        optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
+        return {"optimizer": optimizer}

diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_dl.ipynb b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_dl.ipynb
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_fl.ipynb b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_fl.ipynb
new file mode 100644
index 0000000000..0693c707c7
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/convert_torch_lightning_to_fl.ipynb
@@ -0,0 +1,460 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7a5c3d67-a6ea-4f59-84d2-effc3ef016e1",
+   "metadata": {},
+   "source": [
+    "# Converting PyTorch Lightning to FL\n",
+    "\n",
+    "In this notebook, we use the FedAvg workflow and CIFAR-10 PyTorch Lightning client training code to create and run a federated learning job with NVFlare."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fcf2b4a8-ed42-421d-8898-c0c93f9d8a09",
+   "metadata": {},
+   "source": [
+    "## Basic Concepts\n",
+    "At the heart of NVFlare lies the concept of collaboration through\n",
+    "\"tasks.\" An FL controller assigns tasks (e.g., training on local data) to one or more FL clients, processes returned\n",
+    "results (e.g., model weight updates), and may assign additional\n",
+    "tasks based on these results and other factors (e.g., a pre-configured\n",
+    "number of training rounds). The clients run executors, which listen for tasks and perform the necessary computations locally, such as model training. This task-based interaction repeats\n",
+    "until the experiment's objectives are met."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b68cb248-dc6a-48d1-880d-33c4324d9723",
+   "metadata": {},
+   "source": [
+    "## Federated Averaging with NVFlare\n",
+    "Given the flexible controller and executor concepts, it is easy to implement different computing & communication patterns with NVFlare, such as [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) and [cyclic weight transfer](https://academic.oup.com/jamia/article/25/8/945/4956468).\n",
+    "\n",
+    "The controller's `run()` routine is responsible for assigning tasks and processing task results from the Executors."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b2f84fb1-9dd3-4c72-a727-c4614260f02f",
+   "metadata": {},
+   "source": [
+    "### Server Code\n",
+    "First, we provide a simple implementation of the [FedAvg](https://proceedings.mlr.press/v54/mcmahan17a?ref=https://githubhelp.com) algorithm with NVFlare.\n",
+    "The `run()` routine implements the main algorithmic logic.\n",
+    "Subroutines like `sample_clients()` and `send_model_and_wait()` use the communicator object native to each `Controller` to get the list of available clients,\n",
+    "distribute the current global model to the clients, and collect their results.\n",
+    "\n",
+    "The FedAvg controller implements these main steps:\n",
+    "1. FL server initializes an initial model using `self.load_model()`.\n",
+    "2. For each round (global iteration):\n",
+    "    - FL server samples available clients using `self.sample_clients()`.\n",
+    "    - FL server sends the global model to clients and waits for their updates using `self.send_model_and_wait()`.\n",
+    "    - FL server aggregates all the `results` and produces a new global model using `self.update_model()`.\n",
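+    "\n",
+    "As a sketch (assuming equal client weighting), the aggregation in `self.update_model()` amounts to averaging the returned weights, $w_{t+1} = \\frac{1}{K} \\sum_{k=1}^{K} w_k$ for $K$ sampled clients; in practice the average is typically weighted, for example by each client's number of training steps."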
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d62a13d5-1130-44e6-8818-70e30de401e6",
+   "metadata": {},
+   "source": [
+    "```python\n",
+    "class FedAvg(BaseFedAvg):\n",
+    "    def run(self) -> None:\n",
+    "        self.info(\"Start FedAvg.\")\n",
+    "\n",
+    "        model = self.load_model()\n",
+    "        model.start_round = self.start_round\n",
+    "        model.total_rounds = self.num_rounds\n",
+    "\n",
+    "        for self.current_round in range(self.start_round, self.start_round + self.num_rounds):\n",
+    "            self.info(f\"Round {self.current_round} started.\")\n",
+    "            model.current_round = self.current_round\n",
+    "\n",
+    "            clients = self.sample_clients(self.num_clients)\n",
+    "\n",
+    "            results = self.send_model_and_wait(targets=clients, data=model)\n",
+    "\n",
+    "            aggregate_results = self.aggregate(results)\n",
+    "\n",
+    "            model = self.update_model(model, aggregate_results)\n",
+    "\n",
+    "            self.save_model(model)\n",
+    "\n",
+    "        self.info(\"Finished FedAvg.\")\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d24b6476-089a-4e9d-825b-07107bd5d84a",
+   "metadata": {},
+   "source": [
+    "### Client Code\n",
+    "Given a CIFAR-10 [PyTorch Lightning](https://lightning.ai/) code example with the network wrapped in the LightningModule class [LitNet](code/src/lit_net.py), we want to adapt this centralized training code so that it can run in a federated setting."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1c551053-5460-4d83-8578-796074170342",
+   "metadata": {},
+   "source": [
+    "On the client side, the training workflow is as follows:\n",
+    "1. Receive the model from the FL server.\n",
+    "2. Perform local training on the received global model\n",
+    "and/or evaluate the received global model for model\n",
+    "selection.\n",
+    "3. Send the new model back to the FL server."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c02bfc2a-783c-494f-9427-c38f40a2e870",
+   "metadata": {},
+   "source": [
+    "Using NVFlare's Client Lightning API, we can easily adapt machine learning code that was written for centralized training and apply it in a federated scenario.\n",
+    "For general use cases, we can use the Client Lightning API patch function:\n",
+    "- `flare.patch(trainer)`: Patch the Lightning trainer. After `flare.patch`, functions such as `trainer.fit()` and `trainer.validate()` will get the global model internally and automatically send the resulting model to the FL server."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9115ee07-d848-4a7c-99ad-64e20ab7093c",
+   "metadata": {},
+   "source": [
+    "With this method, developers can use the Client API\n",
+    "to convert their centralized training code to an FL scenario with\n",
+    "the simple code changes shown below.\n",
+    "```python\n",
+    "    # (1) import nvflare lightning client API\n",
+    "    import nvflare.client.lightning as flare\n",
+    "\n",
+    "    # (2) patch the lightning trainer\n",
+    "    flare.patch(trainer)\n",
+    "\n",
+    "    while flare.is_running():\n",
+    "\n",
+    "        # Note that we can optionally receive the FLModel from NVFlare.\n",
+    "        # We don't need to pass this input_model to the trainer because after flare.patch\n",
+    "        # the trainer.fit/validate will get the global model internally\n",
+    "        input_model = flare.receive()\n",
+    "\n",
+    "        trainer.validate(...)\n",
+    "\n",
+    "        trainer.fit(...)\n",
+    "\n",
+    "        trainer.test(...)\n",
+    "\n",
+    "        trainer.predict(...)\n",
+    "```"
+   ]
+  },
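+  {
+   "cell_type": "markdown",
+   "id": "e3b7a9c2-1f4d-4b6e-8a2c-5d9e7f013a42",
+   "metadata": {},
+   "source": [
+    "The `FLModel` returned by `flare.receive()` also carries metadata about the federated run. For example, the training script below uses the current round and site name (both provided by NVFlare) for logging:\n",
+    "```python\n",
+    "input_model = flare.receive()\n",
+    "print(f\"\\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\\n\")\n",
+    "```"
+   ]
+  },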
+  {
+   "cell_type": "markdown",
+   "id": "67432f44-4144-4347-8d74-e7f57e065a14",
+   "metadata": {},
+   "source": [
+    "The full client training script is saved in a separate file, [code/src/cifar10_lightning_fl.py](code/src/cifar10_lightning_fl.py), which trains a CNN on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5da34414-bac4-4352-8077-ab7ade998eec",
+   "metadata": {},
+   "source": [
+    "## Run an NVFlare Job\n",
+    "Now that we have defined the FedAvg controller to run our federated compute workflow on the FL server, and our client training script to receive the global models, run local training, and send the results back to the FL server, we can put everything together using NVFlare's Job API."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "17b69013",
+   "metadata": {},
+   "source": [
+    "## Install requirements\n",
+    "Make sure to install the required packages:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d14e8f49",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -r code/requirements.txt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0cedaf75-3a4a-4843-8017-7716b53149a2",
+   "metadata": {},
+   "source": [
+    "#### 1. Define the initial model\n",
+    "First, we define the global model used to initialize the model on the FL server. See [code/src/lit_net.py](code/src/lit_net.py)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "93889e62-b725-427c-8839-2771ca81d24c",
+   "metadata": {},
+   "source": [
+    "```python\n",
+    "from typing import Any\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "from pytorch_lightning import LightningModule\n",
+    "from torchmetrics import Accuracy\n",
+    "\n",
+    "NUM_CLASSES = 10\n",
+    "criterion = nn.CrossEntropyLoss()\n",
+    "\n",
+    "class Net(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
+    "        self.pool = nn.MaxPool2d(2, 2)\n",
+    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
+    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
+    "        self.fc2 = nn.Linear(120, 84)\n",
+    "        self.fc3 = nn.Linear(84, 10)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.pool(F.relu(self.conv1(x)))\n",
+    "        x = self.pool(F.relu(self.conv2(x)))\n",
+    "        x = torch.flatten(x, 1)  # flatten all dimensions except batch\n",
+    "        x = F.relu(self.fc1(x))\n",
+    "        x = F.relu(self.fc2(x))\n",
+    "        x = self.fc3(x)\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "class LitNet(LightningModule):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.save_hyperparameters()\n",
+    "        self.model = Net()\n",
+    "        self.train_acc = Accuracy(task=\"multiclass\", num_classes=NUM_CLASSES)\n",
+    "        self.valid_acc = Accuracy(task=\"multiclass\", num_classes=NUM_CLASSES)\n",
+    "        # (optional) pass additional information via self.__fl_meta__\n",
+    "        self.__fl_meta__ = {}\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        out = self.model(x)\n",
+    "        return out\n",
+    "\n",
+    "    def training_step(self, batch, batch_idx):\n",
+    "        x, labels = batch\n",
+    "        outputs = self(x)\n",
+    "        loss = criterion(outputs, labels)\n",
+    "        self.train_acc(outputs, labels)\n",
+    "        self.log(\"train_loss\", loss)\n",
+    "        self.log(\"train_acc\", self.train_acc, on_step=True, on_epoch=False)\n",
+    "        return loss\n",
+    "\n",
+    "    def evaluate(self, batch, stage=None):\n",
+    "        x, labels = batch\n",
+    "        outputs = self(x)\n",
+    "        loss = criterion(outputs, labels)\n",
+    "        self.valid_acc(outputs, labels)\n",
+    "\n",
+    "        if stage:\n",
+    "            self.log(f\"{stage}_loss\", loss)\n",
+    "            self.log(f\"{stage}_acc\", self.valid_acc, on_step=True, on_epoch=True)\n",
+    "        return outputs\n",
+    "\n",
" def validation_step(self, batch, batch_idx):\n", + " self.evaluate(batch, \"val\")\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " self.evaluate(batch, \"test\")\n", + "\n", + " def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:\n", + " return self.evaluate(batch)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)\n", + " return {\"optimizer\": optimizer}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "1b70da5d-ba8b-4e65-b47f-44bb9bddae4d", + "metadata": {}, + "source": [ + "#### 2. Define a FedJob\n", + "The `FedJob` is used to define how controllers and executors are placed within a federated job using the `to(object, target)` routine.\n", + "\n", + "Here we use a PyTorch `BaseFedJob`, where we can define the job name and the initial global model.\n", + "The `BaseFedJob` automatically configures components for model persistence, model selection, and TensorBoard streaming for convenience." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aaa2b6f4", + "metadata": {}, + "outputs": [], + "source": [ + "from code.src.lit_net import LitNet\n", + "\n", + "from nvflare.app_common.workflows.fedavg import FedAvg\n", + "from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob\n", + "from nvflare.job_config.script_runner import ScriptRunner\n", + "\n", + "job = BaseFedJob(\n", + " name=\"cifar10_lightning_fedavg\",\n", + " initial_model=LitNet(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9361d9f8-54f3-4363-b3ba-706a7ae3a8e9", + "metadata": {}, + "source": [ + "#### 3. Define the Controller Workflow\n", + "Define the controller workflow and send to server." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6962e6cc-995e-4356-8156-3ceba2c7a249", + "metadata": {}, + "outputs": [], + "source": [ + "n_clients = 2\n", + "\n", + "controller = FedAvg(\n", + " num_clients=n_clients,\n", + " num_rounds=2,\n", + ")\n", + "job.to(controller, \"server\")" + ] + }, + { + "cell_type": "markdown", + "id": "77f5bc7f-4fb4-46e9-8f02-5e7245d95070", + "metadata": {}, + "source": [ + "That completes the components that need to be defined on the server." + ] + }, + { + "cell_type": "markdown", + "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0", + "metadata": {}, + "source": [ + "#### 4. Add clients\n", + "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n", + "\n", + "Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad5d36fe-9ae5-43c3-80bc-2cdc66bf7a7e", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(n_clients):\n", + " runner = ScriptRunner(\n", + " script=\"src/cifar10_lightning_fl.py\", script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n", + " )\n", + " job.to(runner, f\"site-{i+1}\")" + ] + }, + { + "cell_type": "markdown", + "id": "113fd6af-85be-4f75-8a8e-4666771252b3", + "metadata": {}, + "source": [ + "That's it!\n", + "\n", + "#### 5. Optionally export the job\n", + "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html). 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99a270bf-c906-425b-b999-2306cb76eb62", + "metadata": {}, + "outputs": [], + "source": [ + "job.export_job(\"/tmp/nvflare/jobs/job_config\")" + ] + }, + { + "cell_type": "markdown", + "id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63", + "metadata": {}, + "source": [ + "#### 6. Run FL Simulation\n", + "Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13068ab7-35cf-49e7-91ed-10993049ef0d", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "job.simulator_run(\"/tmp/nvflare/jobs/workdir\", gpu=\"0\")" + ] + }, + { + "cell_type": "markdown", + "id": "fb2e1266", + "metadata": {}, + "source": [ + "You can see the full code for this job in [lightning_fl_job](code/lightning_fl_job.py)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 0e1334c099b32a4c98f659b8795144bd6e504dae Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 6 Feb 2025 20:27:54 -0500 Subject: [PATCH 2/2] fix ci --- .../code/lightning_fl_job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py index 3dc8a2f215..b51724a962 100644 --- a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py +++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.2_convert_torch_lightning_to_federated_learning/code/lightning_fl_job.py @@ -38,7 +38,7 @@ runner = ScriptRunner( script="src/cifar10_lightning_fl.py", script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" ) - job.to(runner, f"site-{i+1}") + job.to(runner, f"site-{i + 1}") job.export_job("/tmp/nvflare/jobs/job_config") job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0", log_config="./log_config.json")